import os
import sys
from typing import Dict
from pathlib import Path
from typing import List, Tuple, Optional
import openeo
import xarray as xr
import numpy as np
from openeo.udf import XarrayDataCube

# Size, in pixels, of the spatial chunks requested through `apply_neighborhood`.
CHUNK_X_SIZE = 128
CHUNK_Y_SIZE = 128
# Overlap, in pixels, requested around each chunk.
CHUNK_X_OVERLAP = 16
CHUNK_Y_OVERLAP = 16

# Class code written to pixels discarded by the post-processing (NoCrop value).
DEFAULT_TARGET_EXCLUDED_VALUE = 65533
# Class codes that the post-processing steps leave untouched (described in this
# file as NoCrop, NoData or OutOfScope values).
DEFAULT_EXCLUDED_VALUES = [0, 65533, 65534, 65535]

# Class codes assigned to low-confidence arable / perennial pixels.
DEFAULT_UNDECIDED_ARABLE_VALUE = 3100
DEFAULT_UNDECIDED_PERENNIAL_VALUE = 3200

# Crop type names. NOTE(review): not referenced in the visible part of this
# file; presumably ordered like the per-class probability bands - confirm.
DEFAULT_CLASS_NAMES = [
    'winter_wheat',
    'spring_wheat',
    'winter_barley',
    'spring_barley',
    'maize',
    'rice',
    'other_winter_cereals',
    'other_spring_cereals',
    'fresh_vegetables',
    'dry_pulses',
    'potatoes',
    'sugar_beet',
    'sunflower',
    'soybeans',
    'rapeseed',
    'flax_cotton_and_hemp',
    'grass/fodder',
    'grapes',
    'olives',
    'fruits',
    'nuts',
]

class InvalidProbabilityError(Exception):
    """Raised by `_check_probability` when prediction probabilities are incoherent."""

def _check_probability(
    input_data: np.ndarray,
    phase_name: str,
    excluded_values: List[int] = DEFAULT_EXCLUDED_VALUES
) -> np.ndarray:
    """ Checks the coherence of prediction data in three ways.
        1. Checks if no max probability is above 100
        2. If there are per-class probabilities, checks that the sum of those
           per class probability is 100.
        3. If there are per-class probabilities, checks if the highest of those
           probabilities is reflected by max probability.

        It is asumed at the parameter input_data is a 3D array where the first
        dimension/axis has minimum two elements (croptype and max probability),
        if it has more than two elements, then the per-class probabilties are
        supposed to be the rest of the elements
    """
    if input_data.shape[0] < 2:
        return input_data
    # Invalid mask (mask for pixels that shouldn't be checked)
    invalid_mask = np.isin(input_data[0], excluded_values)

    # Checks if the max_probability is not over 100
    probabilties_over_100 = input_data[1] > 100
    if (probabilties_over_100 & ~invalid_mask * 1).sum() > 0:
        raise InvalidProbabilityError(
            f'Phase: {phase_name} - Some pixels have max_probability over 100 ' 
            f'(max: {np.max(input_data[1])})'
        )
    if input_data.shape[0] > 2:
        # Assumes that the arrays starting from the third one are per-class 
        # probability. Checks if the sum of those probabilities is not over 100
        sum_class_probabilities = np.sum(input_data[2:], axis=0)
        sum_class_probabilities_over_100 = sum_class_probabilities > 105  # Rounding error tolerance
        if (sum_class_probabilities_over_100 & ~invalid_mask * 1).sum() > 0:
            raise InvalidProbabilityError(
                f'Phase {phase_name} - Sum of per class probabilities lead to'
                f' higher probabilities than 100 '
                f'({np.max(sum_class_probabilities)})'
            )

        # Checks if max probability reflects the higher value in the per class
        # probabilities
        highest_class_probability = np.max(input_data[2:], axis=0)

        non_matching_pixels = abs(
            input_data[1].astype(int) - highest_class_probability.astype(int)
        ) > 1
        non_match_pixels_count = (non_matching_pixels & ~invalid_mask * 1).sum()
        if non_match_pixels_count > 0:
            raise InvalidProbabilityError(
                f'Phase: {phase_name} - Max probability doesn\'t match the '
                f'highest probability in the classes probability for some '
                f'pixels ({non_match_pixels_count})'
            )
    return input_data


def smooth_and_reclassify(
    input_data: np.ndarray,
    aggregated_class_codes: Optional[List[int]] = None,
    excluded_values: Optional[List[int]] = None
) -> np.ndarray:
    """
    Smooths the per-class prediction scores of a croptype tile with a 3x3
    weighted kernel and derives a new prediction and max probability from the
    smoothed scores.

    It is assumed that the probabilities of all classes are included. The
    input array is left untouched; all the work is done on a copy.

    :param input_data: 3D array with the 'band' dimension in the first axis.
                       The first axis must contain the prediction, the
                       max_probability and a probability for every crop type
                       class.
    :param aggregated_class_codes: The codes of the classes, one per class
                                   probability band of the input array and in
                                   the same order. Defaults to the values of
                                   DEFAULT_REGROUPMENT_TABLE from
                                   cropclass.postprocessing.layer_format.
    :param excluded_values: A list of values that are either NoCrop, NoData or
                            OutOfScope. Pixels predicted as one of them keep
                            their prediction and max probability, and get 0 for
                            every class probability. Defaults to
                            DEFAULT_EXCLUDED_VALUES.
    :return: Array stacking the new prediction, the new max probability and the
             smoothed class probabilities along the first axis.
    """
    from scipy.signal import convolve2d

    if aggregated_class_codes is None:
        # Imported lazily: only needed when the caller gives no class codes
        from cropclass.postprocessing.layer_format import DEFAULT_REGROUPMENT_TABLE
        aggregated_class_codes = DEFAULT_REGROUPMENT_TABLE.values()
    class_codes = list(aggregated_class_codes)

    if excluded_values is None:
        excluded_values = DEFAULT_EXCLUDED_VALUES

    prediction = input_data[0]
    max_probability = input_data[1]

    # All probabilities are assumed to exist here and be in the same number
    # as the given translation table. Copied, because slicing returns a view
    # and the smoothing would otherwise overwrite the caller's array.
    all_probabilities = input_data[2:].copy()

    assert all_probabilities.shape[0] == len(class_codes),\
         f'The number of class probabilities {all_probabilities.shape[0]} is ' \
         f'not equals to the number of crop types in the translation table ' \
         f'{len(class_codes)}'

    # Mask of the pixels whose prediction remains unchanged
    excluded_mask = np.isin(prediction, excluded_values)

    # Excluded pixels must not contribute to the smoothing of their neighbours
    all_probabilities[:, excluded_mask] = 0

    conv_kernel = np.array([
        [1, 2, 1],
        [2, 3, 2],
        [1, 2, 1]
    ], dtype=np.float64)
    kernel_weight = conv_kernel.sum()

    # Performs smoothing for all classes
    for cls_idx in range(len(class_codes)):
        smoothed = convolve2d(
            all_probabilities[cls_idx].astype(np.float64),
            conv_kernel,
            mode='same',
            fillvalue=0
        ) / kernel_weight
        # Round to the nearest integer while still in floating point; storing
        # the float result in an integer band first would truncate it.
        smoothed = (smoothed + 0.5).astype(np.uint16)
        smoothed[excluded_mask] = 0
        all_probabilities[cls_idx] = smoothed

    # Compute new max probability
    new_max_probability = np.amax(all_probabilities, axis=0)
    new_max_probability[excluded_mask] = max_probability[excluded_mask]

    # Compute the new crop type: translate the winning band index into its code
    new_prediction_idx = np.argmax(all_probabilities, axis=0)
    code_lut = np.asarray(class_codes, dtype=prediction.dtype)
    new_prediction = code_lut[new_prediction_idx]
    new_prediction[excluded_mask] = prediction[excluded_mask]

    return np.array([new_prediction, new_max_probability, *all_probabilities])


def remove_small_objects(input_data: np.ndarray,
                         target_excluded_value: int = DEFAULT_TARGET_EXCLUDED_VALUE,
                         minimum_object_size: int = 25,
                         excluded_values: Optional[List[int]] = None) -> np.ndarray:
    """
    Excludes objects (blobs of pixels) that have an area below `minimum_object_size` pixels. Here, an object is
    considered to be a classified crop field, or an island of crop fields that is surrounded by excluded pixels. The
    removed pixels are set to the specified target_excluded_value, and their probabilities are set to 0.

    :param input_data: A numpy array containing in the first dimensions two elements: the class codes and the associated
                       predicted probability/confidence. Any further element is returned unchanged.
    :param target_excluded_value: Value on which to convert the pixels of removed objects.
    :param minimum_object_size: The minimum size of the crop blobs to be kept
    :param excluded_values: Values that are considered as non-crop land cover or invalid data. Pixels holding one of
                            them keep their class code and get a probability of 0. Defaults to
                            DEFAULT_EXCLUDED_VALUES.
    :return: An array containing the processed class codes, the new prediction probabilities and the remaining
             elements of input_data.
    """
    # Aliased so that the import does not shadow this function's own name
    from skimage.morphology import remove_small_objects as _remove_small_blobs

    if excluded_values is None:
        excluded_values = DEFAULT_EXCLUDED_VALUES

    prediction = input_data[0]
    probability = input_data[1]

    # Boolean mask of the excluded pixels. np.isin returns a bool array, which
    # avoids the np.bool alias that NumPy 1.24 removed.
    exclusion_mask = np.isin(prediction, excluded_values)
    valid_mask = ~exclusion_mask

    valid_mask_cleaned = _remove_small_blobs(valid_mask, min_size=minimum_object_size)

    prediction_cleaned = prediction.copy()
    probability_cleaned = probability.copy()

    # Everything outside the kept objects is discarded...
    prediction_cleaned[~valid_mask_cleaned] = target_excluded_value
    probability_cleaned[~valid_mask_cleaned] = 0

    # ...but pixels that were excluded from the start keep their original code
    prediction_cleaned[exclusion_mask] = prediction[exclusion_mask]
    probability_cleaned[exclusion_mask] = 0

    return np.array([prediction_cleaned, probability_cleaned, *input_data[2:]])


def remove_isolated_pixels(input_data: np.ndarray,
                           target_excluded_value: int = DEFAULT_TARGET_EXCLUDED_VALUE,
                           number_of_neighbors: int = 5,
                           excluded_values: Optional[List[int]] = None) -> np.ndarray:
    """
    Excluding the pixels classified as crops that have less valid neighbors (among the 8 surrounding pixels) than the
    specified parameter. Their probability are set to 0. Positions outside the array count as non-valid neighbors.

    :param input_data: A numpy array containing in the first dimensions two elements: the class codes and the associated
                       predicted probability/confidence. Any further element is returned unchanged.
    :param target_excluded_value: Value on which to convert isolated pixels.
    :param number_of_neighbors: Number of minimum neighbors to have to remain valid.
    :param excluded_values: Values that are considered as non-crop land cover or invalid data. Pixels holding one of
                            them keep their class code and get a probability of 0. Defaults to
                            DEFAULT_EXCLUDED_VALUES.
    :return: An array containing the processed class codes, the new prediction probabilities and the remaining
             elements of input_data.
    """
    from scipy.signal import convolve2d

    if excluded_values is None:
        excluded_values = DEFAULT_EXCLUDED_VALUES

    prediction = input_data[0]
    probability = input_data[1]

    # Boolean mask of the excluded pixels. np.isin returns a bool array, which
    # avoids the np.bool alias that NumPy 1.24 removed.
    exclusion_mask = np.isin(prediction, excluded_values)

    # Inverts the mask, creating a 0/1 mask of the valid pixels
    valid_mask = (~exclusion_mask).astype(np.uint8)

    # 3x3 kernel, used for counting neighbor valid pixels (8 neighbors around each pixel)
    kernel = np.ones(shape=(3, 3), dtype=np.uint8)
    # Setting the central value of the kernel to 0 (we don't count the pixel itself as a neighbor)
    kernel[1, 1] = 0

    # Performs convolution, essentially counting the neighbors for each pixel
    neighbor_counts = convolve2d(valid_mask, kernel, mode='same', fillvalue=0)  # TODO: evaluate impact of fillvalue

    # Create the final isolated mask and return a copy of the input data but with isolated pixels marked as excluded
    isolated_pixels = neighbor_counts < number_of_neighbors
    output = prediction.copy()
    output_proba = probability.copy()

    output[isolated_pixels] = target_excluded_value
    output_proba[isolated_pixels] = 0

    # The pixels that were originally considered as one of the excluded classes are reverted back to this original class
    # (For example, we don't want to change nodata pixels into nocrop pixels)
    output[exclusion_mask] = prediction[exclusion_mask]
    output_proba[exclusion_mask] = 0

    return np.array([output, output_proba, *input_data[2:]])


def majority_vote(input_data: np.ndarray, kernel_size: int = 7, conf_threshold: int = 30,
                  target_excluded_value: int = DEFAULT_TARGET_EXCLUDED_VALUE,
                  excluded_values: Optional[List[int]] = None) -> np.ndarray:
    """
    Majority vote is performed using a sliding local kernel. For each pixel, the final class is the one with the
    highest sum of probabilities among the neighbours inside the kernel window. Neighbours whose probability is lower
    than or equal to the confidence threshold do not vote. Pixels that have one of the specified excluded values are
    excluded in the voting process and are unchanged.

    The prediction probabilities are reevaluated by taking, for each pixel, the average of probabilities of the
    neighbors that voted for the winning class. (For example, if a pixel was voted to class 2 and there are three
    voting neighbors of that class, then the new probability is the sum of their old probabilities divided by 3,
    truncated to an integer.)

    :param input_data: A numpy array containing in the first dimensions two elements: the class codes and the associated
                       predicted probability/confidence. Any further element is returned unchanged.
    :param kernel_size: The size of the kernel used for the neighbour around the pixel. Must not exceed 25.
    :param conf_threshold: Pixels with a confidence lower than or equal to this threshold do not count into the voting
                           process.
    :param target_excluded_value: Pixels that have a null score for every class are turned into this exclusion value
    :param excluded_values: Pixels that have one of the excluded values do not count into the voting process and keep
                            their class code, with a probability set to 0. Defaults to DEFAULT_EXCLUDED_VALUES.
    :return: An array containing the processed class codes, the new prediction probabilities and the remaining
             elements of input_data.
    """
    from scipy.signal import convolve2d

    # As the probabilities are in integers between 0 and 100, we use uint16 matrices to store the vote scores
    assert kernel_size <= 25, (
        f'Kernel value cannot be larger than 25 (currently: {kernel_size}) '
        f'because it might lead to scenarios where the 16-bit count matrix is overflown'
    )

    if excluded_values is None:
        excluded_values = DEFAULT_EXCLUDED_VALUES

    prediction = input_data[0]
    probability = input_data[1]

    excluded_mask = np.isin(prediction, excluded_values)

    # Sorted class codes taking part in the vote; a class index is its position in this array
    class_values = np.unique(prediction[~excluded_mask])
    n_classes = len(class_values)

    counts = np.zeros(shape=(*prediction.shape, n_classes), dtype=np.uint16)
    probabilities = np.zeros(shape=(*probability.shape, n_classes), dtype=np.uint16)

    # Both are identical for every class, so they are built once
    kernel = np.ones(shape=(kernel_size, kernel_size), dtype=np.uint16)
    low_confidence = probability <= conf_threshold

    # Iterates for each classes
    for cls_idx, cls_value in enumerate(class_values):

        # Take the binary mask of the interest class, and multiplies by the probabilities
        class_mask = ((prediction == cls_value) * probability).astype(np.uint16)

        # Sets to 0 the class scores of the pixels that are not confident enough to vote
        class_mask[low_confidence] = 0

        # Set to 0 the class scores where the label is excluded
        class_mask[excluded_mask] = 0

        # Binary class mask, used to count HOW MANY neighbours pixels are used for this class
        binary_class_mask = (class_mask > 0).astype(np.uint16)

        # Counts around the window the sum of probabilities for that given class
        counts[:, :, cls_idx] = convolve2d(class_mask, kernel, mode='same')

        # Counts the number of neighbors pixels that voted for that given class
        class_voters = convolve2d(binary_class_mask, kernel, mode='same')
        class_voters[class_voters == 0] = 1  # Remove the 0 values because might create divide by 0 issues

        # Integer division: the average probability is truncated
        probabilities[:, :, cls_idx] = counts[:, :, cls_idx] // class_voters

    # Initializes output array
    aggregated_predictions = np.zeros(shape=prediction.shape, dtype=np.uint16)
    # Initializes prediction output array
    aggregated_probabilities = np.zeros(shape=prediction.shape, dtype=np.uint16)

    if n_classes > 0:
        # Takes the indices that have the biggest scores
        aggregated_predictions_indices = np.argmax(counts, axis=2)

        # Get the new confidence score for the indices. Only the class axis is
        # dropped: a bare squeeze() would also drop a spatial axis of length 1.
        aggregated_probabilities = np.take_along_axis(
            probabilities,
            aggregated_predictions_indices[..., np.newaxis],
            axis=2
        )[..., 0]

        # Check which pixels have a counts value to 0
        no_score_mask = np.sum(counts, axis=2) == 0

        # convert back to values from indices
        aggregated_predictions = class_values[aggregated_predictions_indices].astype(np.uint16)

        aggregated_predictions[no_score_mask] = target_excluded_value
        aggregated_probabilities[no_score_mask] = 0

    # Setting excluded values back to their original values
    aggregated_predictions[excluded_mask] = prediction[excluded_mask]
    aggregated_probabilities[excluded_mask] = 0

    return np.array(
        [aggregated_predictions, aggregated_probabilities, *input_data[2:]]
    )


def set_undecided_by_threshold(
    input_data: np.ndarray,
    threshold: int = 35,
    undecided_arable_value: int = DEFAULT_UNDECIDED_ARABLE_VALUE,
    undecided_perennial_value: int = DEFAULT_UNDECIDED_PERENNIAL_VALUE
) -> np.ndarray:
    """
    Relabels the pixels whose prediction probability is below the threshold
    with the matching undecided class code.

    :param input_data: A 3D numpy array containing in the first axis at least
                       the class mask and the probability mask.
    :param threshold: The minimum prediction confidence value under which the
                      pixels are set to an undecided class.
    :param undecided_arable_value: Value given to the undecided arable pixels.
                      Class codes from this value up to, but not including,
                      `undecided_perennial_value` are considered as arable.
    :param undecided_perennial_value: Value given to the undecided perennial
                      pixels. Class codes greater than or equal to this value
                      are considered as perennial.

    :return: An array stacking the new class mask, the original prediction
             probability/confidence values and the remaining elements of
             input_data. Pixels with a class code below
             `undecided_arable_value` are never relabelled.
    """
    class_band, confidence_band = input_data[0], input_data[1]

    low_confidence = confidence_band < threshold

    # Split the class codes in two families; the two masks never overlap
    is_perennial = class_band >= undecided_perennial_value
    is_arable = (class_band >= undecided_arable_value) & ~is_perennial

    relabelled = class_band.copy()
    relabelled[low_confidence & is_arable] = undecided_arable_value
    relabelled[low_confidence & is_perennial] = undecided_perennial_value

    return np.array([relabelled, confidence_band, *input_data[2:]])


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """
    UDF entry point: runs the post-processing chain on the first time step of
    the cube.

    The steps are applied in this order, each one receiving the output of the
    previous one: set_undecided_by_threshold, majority_vote,
    remove_isolated_pixels and remove_small_objects.

    :param cube: Data cube holding the array to process.
    :param context: Dictionary with the optional keys 'postprocessing_un',
                    'postprocessing_maj', 'postprocessing_isol' and
                    'postprocessing_smal'. Each one holds the keyword arguments
                    of the matching step, in the order listed above. A missing
                    key means that the step runs with its defaults.
    :return: A data cube wrapping the processed array, carrying the attributes
             of the input array.
    """
    for venv_dir in ('tmp/venv_static', 'tmp/venv'):
        sys.path.insert(0, venv_dir)

    # Pairs every post-processing step with its keyword arguments
    pipeline = [
        (set_undecided_by_threshold, context.get('postprocessing_un', {})),
        (majority_vote, context.get('postprocessing_maj', {})),
        (remove_isolated_pixels, context.get('postprocessing_isol', {})),
        (remove_small_objects, context.get('postprocessing_smal', {})),
    ]

    # Selects the first and only timestamp
    inarr = cube.get_array().isel(t=0)

    smoothed_output = inarr
    for step, step_kwargs in pipeline:
        smoothed_output = xr.apply_ufunc(
            step, smoothed_output,
            output_dtypes='uint16',
            dask='forbidden',
            kwargs=step_kwargs
        )

    smoothed_output.attrs.update(inarr.attrs)

    return XarrayDataCube(smoothed_output)


def load_cropclass_postprocess_udf() -> str:
    """
    Returns the source code of this module as a string, so that it can be
    submitted as UDF code.

    :return: The full text of the file this function is defined in.
    """
    # Read-only mode: opening with 'r+' also asks for write access and fails
    # when the module file is not writable.
    with open(os.path.realpath(__file__), 'r', encoding='utf-8') as f:
        return f.read()


def apply_majority_voting_openeo(input: openeo.DataCube,
                                 context: Dict) -> openeo.DataCube:
    """
    Runs this module as a UDF over the given cube, chunk by chunk, through
    `apply_neighborhood`.

    :param input: The data cube to post-process.
    :param context: Dictionary forwarded to the UDF as its context.
    :return: The data cube resulting from the `apply_neighborhood` call.
    """
    udf_source = load_cropclass_postprocess_udf()

    def _run_postprocessing(data):
        # Process applied on every chunk
        return data.run_udf(udf_source, context=context, runtime='Python')

    chunk_size = [
        {'dimension': 'bands', 'value': 2, 'unit': 'px'},
        {'dimension': 'x', 'value': CHUNK_X_SIZE, 'unit': 'px'},
        {'dimension': 'y', 'value': CHUNK_Y_SIZE, 'unit': 'px'},
    ]
    chunk_overlap = [
        {'dimension': 'x', 'value': CHUNK_X_OVERLAP, 'unit': 'px'},
        {'dimension': 'y', 'value': CHUNK_Y_OVERLAP, 'unit': 'px'},
    ]

    return input.apply_neighborhood(
        process=_run_postprocessing,
        size=chunk_size,
        overlap=chunk_overlap
    )
