import openeo
from openeo.udf import XarrayDataCube
import xarray as xr
from typing import Dict
from pathlib import Path
import os
import sys
import numpy as np
from typing import List, Tuple

# Size (in pixels) of the x/y neighborhood chunks requested in apply_majority_voting_openeo, and the overlap
# (in pixels) requested around each chunk along x and y.
CHUNK_X_SIZE, CHUNK_Y_SIZE = 128, 128
CHUNK_X_OVERLAP, CHUNK_Y_OVERLAP = 16, 16

# Default value written to pixels that a postprocessing step rejects.
DEFAULT_TARGET_EXCLUDED_VALUE = 0
# Default class codes treated as non-crop land cover or invalid data; such pixels keep their original code.
DEFAULT_EXCLUDED_VALUES = [0, 65534, 65535]
# Default class code assigned to pixels whose prediction probability is below the undecided threshold.
DEFAULT_UNDECIDED_CLASS_VALUE = 3000


def remove_small_objects(input_data: np.ndarray,
                         target_excluded_value: int = DEFAULT_TARGET_EXCLUDED_VALUE,
                         minimum_object_size: int = 25,
                         excluded_values: List[int] = DEFAULT_EXCLUDED_VALUES) -> np.ndarray:
    """
    Excludes objects (blobs of pixels) that have an area below `minimum_object_size` pixels. Here, an object is
    considered to be a classified crop field, or an island of crop fields that is surrounded by no-crop land. The
    excluded pixels are set to the specified target_excluded_value, and their probabilities are set to 0.

    Objects are the connected components of the non-excluded pixels, where only orthogonally adjacent pixels are
    connected (diagonal contact does not join two blobs). This is the same rule that
    skimage.morphology.remove_small_objects applies to a boolean mask with its default connectivity.

    :param input_data: A numpy array containing in the first dimensions two elements: the class codes and the associated
                       predicted probability/confidence.
    :param target_excluded_value: Value on which to convert the pixels of too small objects.
    :param minimum_object_size: The minimum size of the crop blobs to be kept
    :param excluded_values: Values that are considered as non-crop land cover or invalid data. These pixels keep their
                            class code and get a probability of 0.
    :return: An array containing the processed results as well as an array with the new prediction probabilities.
    """
    # Imported lazily, like the other scipy imports of this module
    from scipy import ndimage

    prediction = input_data[0]
    probability = input_data[1]

    # Binary mask of the excluded values (plain `bool` dtype: the `np.bool` alias is missing from NumPy 1.24-1.26)
    exclusion_mask = np.isin(prediction, excluded_values)
    valid_mask = ~exclusion_mask

    # Labels the connected blobs of valid pixels; label 0 is the background (excluded pixels)
    labels, _ = ndimage.label(valid_mask)
    component_sizes = np.bincount(labels.ravel(), minlength=1)

    too_small = component_sizes < minimum_object_size
    too_small[0] = False  # the background is not an object and is handled separately below
    small_object_mask = too_small[labels]

    prediction_cleaned = prediction.copy()
    probability_cleaned = probability.copy()

    # Pixels of too small blobs are converted to the target value, excluded pixels keep their original class
    prediction_cleaned[small_object_mask] = target_excluded_value
    probability_cleaned[small_object_mask | exclusion_mask] = 0

    return np.array([prediction_cleaned, probability_cleaned])


def remove_isolated_pixels(input_data: np.ndarray,
                           target_excluded_value: int = DEFAULT_TARGET_EXCLUDED_VALUE,
                           number_of_neighbors: int = 5,
                           excluded_values: List[int] = DEFAULT_EXCLUDED_VALUES) -> np.ndarray:
    """
    Excludes the pixels classified as crops that have fewer valid neighbors than the specified parameter. Their
    probabilities are set to 0.

    The neighbors are the 8 pixels surrounding a pixel. Positions outside of the array count as non-valid, so pixels
    on the array border can have at most 5 valid neighbors (3 in the corners).

    :param input_data: A numpy array containing in the first dimensions two elements: the class codes and the associated
                       predicted probability/confidence.
    :param target_excluded_value: Value on which to convert isolated pixels.
    :param number_of_neighbors: Number of minimum neighbors to have to remain valid.
    :param excluded_values: Values that are considered as non-crop land cover or invalid data. These pixels keep their
                            class code and get a probability of 0.
    :return: An array containing the processed results as well as an array with the new prediction probabilities.
    """
    from scipy.signal import convolve2d

    prediction = input_data[0]
    probability = input_data[1]

    # Mask containing all the excluded values (plain `bool` dtype: the `np.bool` alias is missing from
    # NumPy 1.24-1.26)
    exclusion_mask = np.isin(prediction, excluded_values)

    # Inverts the mask, creating a mask that includes valid pixels
    valid_mask = (~exclusion_mask).astype(np.uint8)

    # 3x3 kernel, used for counting neighbor valid pixels (8 neighbors around each pixel)
    kernel = np.ones(shape=(3, 3), dtype=np.uint8)
    # Setting the central value of the kernel to 0 (we don't count the pixel itself as a neighbor)
    kernel[1, 1] = 0

    # Performs convolution, essentially counting the neighbors for each pixel
    neighbor_counts = convolve2d(valid_mask, kernel, mode='same', fillvalue=0)  # TODO: evaluate impact of fillvalue

    isolated_pixels = neighbor_counts < number_of_neighbors

    output = prediction.copy()
    output_proba = probability.copy()

    # Pixels that originally hold one of the excluded classes keep that class
    # (For example, we don't want to change nodata pixels into nocrop pixels)
    output[isolated_pixels & ~exclusion_mask] = target_excluded_value
    output_proba[isolated_pixels | exclusion_mask] = 0

    return np.array([output, output_proba])


def majority_vote(input_data: np.ndarray, kernel_size: int = 7, conf_threshold: int = 30,
                  target_excluded_value: int = DEFAULT_TARGET_EXCLUDED_VALUE,
                  excluded_values: List[int] = DEFAULT_EXCLUDED_VALUES) -> np.ndarray:
    """
    Majority vote is performed using a sliding local kernel. For each pixel, the voting of a final class is done from
    neighbours values weighted with their prediction probabilities. Pixels that have one of the specified excluded
    values are excluded in the voting process and are unchanged.

    The prediction probabilities are reevaluated by taking, for each pixel, the average of probabilities of the
    neighbors that belong to the winning class. (For example, if a pixel was voted to class 2 and there are three
    neighbors of that class, then the new probability is the sum of the old probabilities of each pixels divided by 3)
    The averaged probabilities are stored as uint16, so their fractional part is dropped.

    :param input_data: A numpy array containing in the first dimensions two elements: the class codes and the associated
                       predicted probability/confidence.
    :param kernel_size: The size of the kernel used for the neighbour around the pixel. Cannot be larger than 25.
    :param conf_threshold: Pixels whose probability is lower than or equal to this threshold do not count into the
                           voting process.
    :param target_excluded_value: Pixels that have a null score for every class are turned into this exclusion value
    :param excluded_values: Pixels that have one of the excluded values do not count into the voting process and are
                            unchanged.
    :return: An array containing the processed results as well as an array with the new prediction probabilities.
    :raises AssertionError: If kernel_size is larger than 25, or if a recomputed probability is over 100.
    """
    from scipy.signal import convolve2d

    # As the probabilities are in integers between 0 and 100, we use uint16 matrices to store the vote scores
    assert kernel_size <= 25, f'Kernel value cannot be larger than 25 (currently: {kernel_size}) because it might lead to scenarios where the 16-bit count matrix is overflown'

    prediction = input_data[0]
    probability = input_data[1]

    # Build a class mapping, so classes are converted to indexes and vice-versa
    unique_values = sorted(set(np.unique(prediction)) - set(excluded_values))
    index_value_lut = list(enumerate(unique_values))

    counts = np.zeros(shape=(*prediction.shape, len(unique_values)), dtype=np.uint16)
    probabilities = np.zeros(shape=(*probability.shape, len(unique_values)), dtype=np.uint16)

    # The kernel and the low confidence mask are the same for every class, so they are built only once
    kernel = np.ones(shape=(kernel_size, kernel_size), dtype=np.uint16)
    non_voting_mask = probability <= conf_threshold

    # Iterates for each classes
    for cls_idx, cls_value in index_value_lut:

        # Take the binary mask of the interest class, and multiplies by the probabilities. Excluded labels never
        # equal cls_value (they were removed from unique_values), so their score is already 0 here.
        class_mask = ((prediction == cls_value) * probability).astype(np.uint16)

        # Sets to 0 the class scores of the pixels that are not confident enough to vote
        class_mask[non_voting_mask] = 0

        # Binary class mask, used to count HOW MANY neighbours pixels are used for this class
        binary_class_mask = (class_mask > 0).astype(np.uint16)

        # Counts around the window the sum of probabilities for that given class
        counts[:, :, cls_idx] = convolve2d(class_mask, kernel, mode='same')

        # Counts the number of neighbors pixels that voted for that given class
        class_voters = convolve2d(binary_class_mask, kernel, mode='same')
        class_voters[class_voters == 0] = 1  # Remove the 0 values because might create divide by 0 issues

        probabilities[:, :, cls_idx] = np.divide(counts[:, :, cls_idx], class_voters)

    # Initializes output array
    aggregated_predictions = np.zeros(shape=(counts.shape[0], counts.shape[1]), dtype=np.uint16)
    # Initializes prediction output array
    aggregated_probabilities = np.zeros(shape=(counts.shape[0], counts.shape[1]), dtype=np.uint16)

    if unique_values:
        # Takes the indices that have the biggest scores
        aggregated_predictions_indices = np.argmax(counts, axis=2)

        # Get the new confidence score for the indices. Only the class axis is dropped: an unqualified squeeze()
        # would also drop a spatial axis of length 1 and break the boolean indexing below.
        aggregated_probabilities = np.take_along_axis(
            probabilities,
            aggregated_predictions_indices[..., np.newaxis],
            axis=2
        )[..., 0]

        # Check which pixels have a counts value to 0
        no_score_mask = np.sum(counts, axis=2) == 0

        # convert back to values from indices
        for cls_idx, cls_value in index_value_lut:
            aggregated_predictions[aggregated_predictions_indices == cls_idx] = cls_value

        aggregated_predictions[no_score_mask] = target_excluded_value
        aggregated_probabilities[no_score_mask] = 0

    # Setting excluded values back to their original values
    for excluded_value in excluded_values:
        originally_excluded = prediction == excluded_value
        aggregated_predictions[originally_excluded] = excluded_value
        aggregated_probabilities[originally_excluded] = 0

    # Check if aggregated_probabilities has no value over 100 (it shouldn't by the nature of the input proba)
    assert not np.any(aggregated_probabilities > 100), "Some new probabilities have a value over 100"

    return np.array([aggregated_predictions, aggregated_probabilities])


def set_undecided_by_threshold(input_data: np.ndarray, threshold: int = 35,
                               undecided_value: int = DEFAULT_UNDECIDED_CLASS_VALUE) -> np.ndarray:
    """
    Relabels every pixel whose prediction probability is strictly below `threshold` with `undecided_value`.

    :param input_data: A numpy array whose first axis holds, in this order, the class mask and the probability mask.
    :param threshold: Pixels with a prediction confidence under this value are set to the undecided class.
    :param undecided_value: The class code given to the low probability pixels.

    :return: A numpy array stacking the relabelled class mask and the untouched prediction probability/confidence
    values.
    """
    class_codes, confidence = input_data[0], input_data[1]

    # Work on a copy so that the class mask of the caller is left untouched
    relabelled = class_codes.copy()
    low_confidence = confidence < threshold
    relabelled[low_confidence] = undecided_value

    return np.array([relabelled, confidence])


def _step_option_names(step) -> List[str]:
    """Returns the names of the keyword options of a postprocessing step (all parameters but the data array)."""
    import inspect

    # The first parameter of every step is the array handed over by xr.apply_ufunc
    return list(inspect.signature(step).parameters)[1:]


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """
    UDF entry point: runs the postprocessing steps on the first timestamp of the cube, in this order:
    set_undecided_by_threshold, majority_vote, remove_isolated_pixels and remove_small_objects.

    :param cube: The input datacube. Only its first timestamp is processed.
    :param context: UDF context. Its optional 'postprocessing' entry is a dictionary of options for the steps; every
                    option is handed to each step that has a parameter of that name (for example 'threshold' only
                    goes to set_undecided_by_threshold, while 'excluded_values' goes to the three steps accepting it).
    :return: A datacube with the postprocessed array, carrying the attributes of the input array.
    :raises TypeError: If an option of the 'postprocessing' entry is not a parameter of any step.
    """
    # Makes the dependencies unpacked in these folders importable. Guarded so that sys.path does not grow every time
    # the UDF is called.
    for dependency_path in ('tmp/venv_static', 'tmp/venv'):
        if dependency_path not in sys.path:
            sys.path.insert(0, dependency_path)

    context_postproc = (context or {}).get('postprocessing') or {}

    steps = (set_undecided_by_threshold, majority_vote, remove_isolated_pixels, remove_small_objects)
    accepted_options = {step: _step_option_names(step) for step in steps}

    # The steps have different signatures, so handing the whole dictionary to each of them fails as soon as one
    # option is given. Options are routed by name instead, and only the ones known by no step are rejected.
    known_options = {name for names in accepted_options.values() for name in names}
    unknown_options = sorted(set(context_postproc) - known_options)
    if unknown_options:
        raise TypeError(f'Unsupported postprocessing option(s): {unknown_options}')

    # Selects the first and only timestamp
    inarr = cube.get_array().isel(t=0)

    smoothed_output = inarr
    for step in steps:
        step_kwargs = {name: context_postproc[name] for name in accepted_options[step] if name in context_postproc}
        smoothed_output = xr.apply_ufunc(
            step, smoothed_output,
            output_dtypes='uint16',
            dask='forbidden',
            kwargs=step_kwargs
        )

    smoothed_output.attrs.update(inarr.attrs)

    return XarrayDataCube(smoothed_output)


def load_cropclass_postprocess_udf() -> str:
    """
    Returns the source code of this module, so it can be submitted as UDF code.

    :return: The content of this file as a string.
    """
    # Opened read-only: 'r+' would also require write permission on the module file, which fails when the module
    # sits in a read-only location although the file is only read here.
    with open(os.path.realpath(__file__), 'r', encoding='utf-8') as f:
        return f.read()


def apply_majority_voting_openeo(input: openeo.DataCube,
                                 context: Dict) -> openeo.DataCube:
    """
    Runs this module as a Python UDF over spatial neighborhoods of the given datacube.

    :param input: The datacube on which apply_neighborhood is called.
    :param context: The context dictionary passed along to the UDF.
    :return: The datacube resulting from the apply_neighborhood process.
    """
    udf_source = load_cropclass_postprocess_udf()

    def run_postprocessing(data):
        return data.run_udf(udf_source, context=context, runtime='Python')

    # Extent of one neighborhood: 2 along the bands dimension and one chunk along x and y
    neighborhood_size = [
        {'dimension': 'bands', 'value': 2, 'unit': 'px'},
        {'dimension': 'x', 'value': CHUNK_X_SIZE, 'unit': 'px'},
        {'dimension': 'y', 'value': CHUNK_Y_SIZE, 'unit': 'px'},
    ]
    # Overlap requested around each chunk along x and y
    neighborhood_overlap = [
        {'dimension': 'x', 'value': CHUNK_X_OVERLAP, 'unit': 'px'},
        {'dimension': 'y', 'value': CHUNK_Y_OVERLAP, 'unit': 'px'},
    ]

    return input.apply_neighborhood(
        process=run_postprocessing,
        size=neighborhood_size,
        overlap=neighborhood_overlap
    )


def apply_postprocessing_local(input: xr.DataArray,
                                projection_bbox: Dict = None,
                                bvl_masking_path: str = None) -> xr.DataArray:
    """ Runs the optional local postprocessing operations that are enabled through the parameters.

    :param input: A two band DataArray (prediction, probability)
    :param projection_bbox: A dictionary holding the clipping bounding box and the target crs to reproject to. It must
                            contain the keys target_crs, clip_minx, clip_maxx, clip_miny and clip_maxy, with the
                            clipping coordinates expressed in the input data CRS. The reprojection is performed AFTER
                            the clipping. Default: None (operation not performed)
    :param bvl_masking_path: Path to the BVL mask raster. Default: None (operation not performed).

    :returns: The DataArray after the enabled operations, or the input itself when none is enabled.
    """
    from cropclass.postprocessing import translate_layer, reproject_layer, apply_bvl_mask

    result = input

    # Both bands are clipped and reprojected in a single call
    if projection_bbox is not None:
        required_parameters = ['target_crs', 'clip_minx', 'clip_maxx', 'clip_miny', 'clip_maxy']
        missing_parameters = [name for name in required_parameters if name not in projection_bbox.keys()]
        assert not missing_parameters,\
            f'The following parameters must be specified in the projection_bbox parameters: {required_parameters}'
        result = reproject_layer(result, **projection_bbox)

    # The masking gives different resulting values depending on the band
    if bvl_masking_path is not None:
        mask_file_found = os.path.exists(bvl_masking_path)
        assert mask_file_found, f'Specified path for BVL masking path is invalid: {bvl_masking_path}'
        result = apply_bvl_mask(result, Path(bvl_masking_path))

    return result
