from typing import Dict, List, Tuple

import rasterio.enums
import xarray as xr
from pathlib import Path
import numpy as np

# Sentinel codes for the integer output layers. They sit at the top of the
# uint16 range (maximum 65535), away from the 4-digit HRL crop codes below.

# Not referenced in the visible part of this module; presumably flags pixels
# outside the area of interest -- TODO confirm against callers.
DEFAULT_OUT_OF_SCOPE = 65535

# Code used when no prediction is available: "NaN"/"nan"/"NAN"/None labels in
# the translation table, and mask value 255 in `apply_bvl_mask`.
DEFAULT_NO_DATA = 65534

# Code used for pixels that are not a crop field: the "NoCrop" label in the
# translation table, and mask values < 4, 6 and 8 in `apply_bvl_mask`.
DEFAULT_NO_FIELD = 65533

# Default mapping used by `translate_layer`: string labels -> integer HRL codes.
# For crop labels the integer code is the label with its hyphens removed
# ("1-1-1-2" -> 1112). Non-crop and missing labels map to the sentinel codes.
DEFAULT_TRANSLATION_TABLE = {
    "1-1-1-0": 1110,  # Wheat
    "1-1-1-1": 1111,  # Winter wheat
    "1-1-1-2": 1112,  # Spring wheat
    "1-1-2-0": 1120,  # Barley
    "1-1-2-1": 1121,  # Winter barley
    "1-1-2-2": 1122,  # Spring barley
    "1-1-3-0": 1130,  # Maize
    "1-1-4-0": 1140,  # Rice
    "1-1-5-0": 1150,  # Other cereals
    "1-1-5-1": 1151,  # Other winter cereals
    "1-1-5-2": 1152,  # Other spring cereals
    "1-2-1-0": 1210,  # Fresh vegetables
    "1-2-2-0": 1220,  # Dry pulses
    "1-3-1-0": 1310,  # Potatoes
    "1-3-2-0": 1320,  # Sugar beet
    "1-4-1-0": 1410,  # Sunflower
    "1-4-2-0": 1420,  # Soybeans
    "1-4-3-0": 1430,  # Rapeseed
    "1-4-4-0": 1440,  # Flax, cotton and hemp
    "1-5-0-0": 1500,  # Grass and fodder crops
    "2-1-0-0": 2100,  # Grapes
    "2-2-0-0": 2200,  # Olives
    "2-3-1-0": 2310,  # Fruits
    "2-3-2-0": 2320,  # Nuts
    # Non-crop / missing labels
    "NoCrop": DEFAULT_NO_FIELD,
    "NaN": DEFAULT_NO_DATA,
    "nan": DEFAULT_NO_DATA,
    "NAN": DEFAULT_NO_DATA,
    None: DEFAULT_NO_DATA,
}

# HRL integer code -> machine-friendly class name. These names are runtime
# identifiers: the aggregation functions build band names from them as
# `probability_<name>`, so their spelling must not be changed casually
# (including the existing "undecided_perrenial").
# NOTE(review): the "undecided" codes here (1000/2000/3000) differ from those
# in DEFAULT_HRL_NAMES_FULL_TABLE (3000/3100/3200) -- confirm which is correct.
DEFAULT_HRL_NAMES_TABLE: Dict[int, str] = {
    1110: "wheat",
    1111: "winter_wheat",
    1112: "spring_wheat",
    1120: "barley",
    1121: "winter_barley",
    1122: "spring_barley",
    1130: "maize",
    1140: "rice",
    1150: "other_cereals",
    1151: "other_winter_cereals",
    1152: "other_spring_cereals",
    1210: "fresh_vegetables",
    1220: "dry_pulses",
    1310: "potatoes",
    1320: "sugar_beet",
    1410: "sunflower",
    1420: "soybeans",
    1430: "rapeseed",
    1440: "flax_cotton_and_hemp",
    1500: "grass/fodder",
    2100: "grapes",
    2200: "olives",
    2310: "fruits",
    2320: "nuts",
    # Undecided and other special classes
    1000: "undecided_arable",
    2000: "undecided_perrenial",
    3000: "undecided",
}

# HRL integer code -> human-readable class name (display form of the names in
# DEFAULT_HRL_NAMES_TABLE). Not referenced by the functions visible in this
# part of the module.
# NOTE(review): the "undecided" codes here (3000/3100/3200) differ from those
# in DEFAULT_HRL_NAMES_TABLE (1000/2000/3000) -- confirm which is correct.
DEFAULT_HRL_NAMES_FULL_TABLE: Dict[int, str] = {
    1110: "Wheat",
    1111: "Winter wheat",
    1112: "Spring wheat",
    1120: "Barley",
    1121: "Winter barley",
    1122: "Spring barley",
    1130: "Maize",
    1140: "Rice",
    1150: "Other cereals (rye, oats, triticale)",
    1151: "Other winter cereals",
    1152: "Other spring cereals",
    1210: "Fresh Vegetables",
    1220: "Dry Pulses",
    1310: "Potatoes",
    1320: "Sugar Beet",
    1410: "Sunflower",
    1420: "Soybeans",
    1430: "Rapeseed",
    1440: "Flax, cotton and hemp",
    1500: "Grass/fodder",
    2100: "Grapes",
    2200: "Olives",
    2310: "Fruits",
    2320: "Nuts",
    # Undecided and other special classes
    3000: "Undecided",
    3100: "Undecided arable",
    3200: "Undecided perennial",
}

# Default grouping used by `aggregate_croptypes`.
# Key: leading digits (prefix) of an HRL code; a crop code belongs to the
# group when its decimal string starts with this prefix (111 matches 1110,
# 1111 and 1112).
# Value: HRL code assigned to every pixel of the group.
DEFAULT_REGROUPMENT_TABLE: Dict[int, int] = {
    111: 1110,
    112: 1120,
    113: 1130,
    114: 1140,
    115: 1150,
    121: 1210,
    122: 1220,
    131: 1310,
    132: 1320,
    141: 1410,
    142: 1420,
    143: 1430,
    144: 1440,
    150: 1500,
    210: 2100,
    220: 2200,
    231: 2310,
    232: 2320,
}

# HRL codes listed as perennial classes. Per DEFAULT_HRL_NAMES_TABLE these are
# grapes, olives, fruits, nuts and the "undecided_perrenial" code (2000).
# Not referenced by the functions visible in this part of the module.
DEFAULT_PERENNIAL_CLASSES: List[int] = [
    2100, 2200, 2310, 2320, 2000
]


def translate_layer(inarr: xr.DataArray, translation_table: Dict[str, int] = None) -> xr.DataArray:
    """
    Translates the crop type values predicted by the model (string type) into HRL codes using the given
    translation table. Probabilities are not changed.

    Labels that do not appear in the translation table are translated to 0.

    Both numpy-backed and dask-backed inputs are supported, for single-band as well as multi-band arrays.

    :param inarr: The input DataArray to process. Either a single-band array of labels (no `bands` dimension),
                  or a multi-band array with a `bands` dimension containing a band named `croptype`.
    :param translation_table: A mapping string -> int to perform changes in the croptype labels. Defaults to
                              `DEFAULT_TRANSLATION_TABLE` when None.
    :return: For a single-band input, a new uint16 array of HRL codes. For a multi-band input, a copy of the
             input in which the `croptype` band is replaced by the translated codes.
    """
    if translation_table is None:
        translation_table = DEFAULT_TRANSLATION_TABLE

    def map_func(input_data: np.ndarray, translation_table: Dict[str, int]) -> np.ndarray:
        # Start from zeros so that any label missing from the table ends up as 0
        output_data = np.zeros_like(input_data, dtype=np.uint16)
        for orig, dst in translation_table.items():
            output_data[input_data == orig] = dst
        return output_data

    # Same options for both modes. `dask` and `output_dtypes` are only used by xarray when the input is backed
    # by a dask array; without them the single band mode would fail on lazy inputs.
    ufunc_options = dict(
        kwargs={'translation_table': translation_table},
        dask='parallelized',
        output_dtypes=[np.uint16],
    )

    # Two modes, either the input is a single band array, in that case no selection is required, or the input
    # array is multiband. In that case, only the `croptype` band is translated.
    if 'bands' not in inarr.dims:
        return xr.apply_ufunc(map_func, inarr, **ufunc_options)

    translated_layer = xr.apply_ufunc(
        map_func,
        inarr.sel({'bands': 'croptype'}),
        **ufunc_options
    )

    # Writes the translated labels in a copy, the other bands (probabilities) are left untouched
    translated_data = inarr.copy()
    translated_data.loc[{'bands': 'croptype'}] = translated_layer

    return translated_data


def _setup_aggregation_template(
    inarr: xr.DataArray,
    regroupment_table: Dict[int, int] = DEFAULT_REGROUPMENT_TABLE,
    hrl_names: Dict[int, str] = DEFAULT_HRL_NAMES_TABLE
) -> xr.DataArray:
    """
    Builds an uninitialised DataArray whose shape, dimensions and coordinates
    match the output of an aggregation of `inarr`.

    Dask needs to know the dimensions and coordinates of a result before
    executing the aggregation lazily; this template provides them. The values
    of the returned array are meaningless (allocated with `np.empty`).

    :param inarr: The array that will be aggregated. Its first dimension is
                  taken to be the bands; it needs `y` and `x` coordinates.
    :param regroupment_table: The regroupment table of the aggregation, one
                              probability band is declared per entry.
    :param hrl_names: A table mapping HRL integer codes to class names, used
                      to name the probability band of each group.
    :return: A uint16 array with dims (`bands`, `y`, `x`) and bands
             `croptype`, `probability`, then `probability_<group_name>`
             for every group.
    """
    # One probability band per target group, in the order of the table
    group_bands = [
        f'probability_{hrl_names[group_code]}'
        for group_code in regroupment_table.values()
    ]
    band_labels = ['croptype', 'probability', *group_bands]

    # The spatial shape is inherited from the input, only the bands differ
    spatial_shape = inarr.shape[1:]
    placeholder = np.empty((len(band_labels), *spatial_shape), dtype=np.uint16)

    return xr.DataArray(
        data=placeholder,
        dims=('bands', 'y', 'x'),
        coords=dict(
            bands=band_labels,
            y=inarr.coords['y'],
            x=inarr.coords['x'],
        ),
    )


def aggregate_croptypes(
    inarr: xr.DataArray,
    regroupment_table: Dict[int, int] = DEFAULT_REGROUPMENT_TABLE,
    hrl_names: Dict[int, str] = DEFAULT_HRL_NAMES_TABLE,
) -> xr.DataArray:
    """
    Using the regroupment table, croptypes will be aggregated in groups.

    Every pixel whose crop code starts with one of the prefixes of the
    regroupment table is relabelled with the code of the matching group;
    pixels matching no prefix keep their original code. One probability band
    is produced per group, as the sum of the available probability bands of
    the classes belonging to that group.

    The classes of a group are the codes of `hrl_names` starting with the
    group prefix (plus any matching code found in the raster), so the group
    probabilities do not depend on which classes happen to be present as
    labels in the processed raster.

    :param inarr: The input DataArray to process, with the bands as first
                  dimension followed by the two spatial dimensions, and `x`
                  and `y` coordinates. Needs a band named `croptype` holding
                  HRL codes. Per-class probability bands, when present, must
                  be named `probability_<class_name>`, with the class names
                  of `hrl_names`.
    :param regroupment_table: A table regrouping crop types. Keys are the
                              leading digits (prefix) of the HRL codes of a
                              group, values are the target group codes. Both
                              are integers.
    :param hrl_names: A table mapping HRL codes in integer with their class
                      names.
    :return: A DataArray with the dimensions of `inarr` and the bands
             `croptype` (aggregated codes), `probability` (maximum of the
             group probabilities) and `probability_<group_name>` for every
             group of the regroupment table.
    :raises KeyError: If a crop code of the raster matches a prefix but is
                      absent from `hrl_names`, or if a group code is absent
                      from `hrl_names`.
    """
    # Crop type layer, extracted once and reused for the class inventory and
    # for the masks
    croptype = inarr.sel(bands='croptype').to_numpy()

    # Getting the unique values of crop types present in the raster
    present_croptypes = np.unique(croptype)

    # Names of the bands available in the raster
    available_bands = set(inarr.coords['bands'].to_numpy().tolist())

    spatial_shape = inarr.shape[1:]

    # Creation of new probability array. Stored as uint8: presumably the
    # probabilities are in 0-100 and the group sums fit -- TODO confirm.
    new_probabilities = np.zeros(
        (len(regroupment_table), *spatial_shape), dtype=np.uint8
    )
    new_croptype = np.zeros(spatial_shape, dtype=np.uint16)

    # Track the pixels that were selected for modification
    modified_pixels = np.zeros(spatial_shape, dtype=np.bool_)

    for group_idx, (croptype_prefix, croptype_group) in enumerate(regroupment_table.items()):
        prefix = str(croptype_prefix)

        # Finding the croptypes of the raster associated with the prefix
        croptype_classes = [
            cropclass for cropclass in present_croptypes
            if str(cropclass).startswith(prefix)
        ]

        # Mask selecting pixels with croptype of the current group of classes
        croptype_mask = np.isin(croptype, croptype_classes)
        modified_pixels |= croptype_mask
        new_croptype[croptype_mask] = croptype_group

        # Probability bands of the group: those of the classes found in the
        # raster (an unknown class raises KeyError, as before)...
        group_band_names = {
            f'probability_{hrl_names[cropclass]}'
            for cropclass in croptype_classes
        }
        # ...and those of every known class of the group, even when that
        # class is never the label of a pixel in this raster.
        group_band_names.update(
            f'probability_{class_name}'
            for class_code, class_name in hrl_names.items()
            if str(class_code).startswith(prefix)
        )

        # Keeping only the probabilities that are available in the raster
        probability_band_names = sorted(group_band_names & available_bands)

        # Without any band for the group, the probability stays at zero
        if probability_band_names:
            new_probabilities[group_idx] = inarr.sel(
                bands=probability_band_names
            ).sum(dim='bands').to_numpy()

    # Resets untouched pixels to their original croptype
    new_croptype[~modified_pixels] = croptype[~modified_pixels]

    # Computation of max probabilities
    new_probability = np.max(new_probabilities, axis=0)

    # Get group names
    all_group_names = [
        f'probability_{hrl_names[group]}'
        for group in regroupment_table.values()
    ]

    # Creation of an array with the new croptypes and probabilities
    outarr = xr.DataArray(
        data=np.concatenate([
            new_croptype[np.newaxis],
            new_probability[np.newaxis],
            new_probabilities
        ]),
        dims=inarr.dims,
        coords={
            'x': inarr.coords['x'],
            'y': inarr.coords['y'],
            'bands': ['croptype', 'probability', *all_group_names]
        }
    )

    return outarr


def apply_bvl_mask(inarr: xr.DataArray, bvl_mask_path: Path) -> xr.DataArray:
    """
    Masks a prediction array with the raster mask stored at `bvl_mask_path`.

    The mask is loaded and reprojected (nearest neighbour) on the grid of the
    input. Pixels where the mask is lower than 4, equal to 6 or equal to 8
    are set to `DEFAULT_NO_FIELD`. Pixels where the mask equals 255 are set
    to `DEFAULT_NO_DATA` in the label layer and to `DEFAULT_NO_FIELD` in the
    probability layer.

    :param inarr: The array to mask. Without band dimension it is handled as
                  a single label layer. With a `band`/`bands` dimension, the
                  first band (by position) is handled as the labels and the
                  second one as the probabilities; further bands are not
                  part of the output.
    :param bvl_mask_path: Path of the mask raster, opened with rioxarray.
    :return: The masked array, carrying the CRS of the input. For a
             multi-band input the bands are named `croptype` and
             `probability`.
    """
    import rioxarray

    # Loading the mask and aligning it to the grid of the input array
    bvl_mask = rioxarray.open_rasterio(bvl_mask_path).rio.reproject_match(
        match_data_array=inarr,
        resampling=rasterio.enums.Resampling.nearest
    )

    # Renaming dimension band to bands to match new repo name standards
    if 'band' in bvl_mask.dims:
        bvl_mask = bvl_mask.rename({'band': 'bands'})
    if 'band' in inarr.dims:
        inarr = inarr.rename({'band': 'bands'})

    def _mask_layer(layer, nodata_fill):
        # Mask classes flagged as "no field" first, then the 255 mask value
        for no_field_condition in (bvl_mask < 4, bvl_mask == 6, bvl_mask == 8):
            layer = xr.where(no_field_condition, DEFAULT_NO_FIELD, layer)
        return xr.where(bvl_mask == 255, nodata_fill, layer)

    if 'bands' in inarr.dims:
        # Applies masking on the cropclass label layer
        label_layer = _mask_layer(inarr.isel(bands=0), DEFAULT_NO_DATA)

        # Applies masking on the probability layer.
        # NOTE(review): mask value 255 yields DEFAULT_NO_FIELD here while the
        # label layer gets DEFAULT_NO_DATA -- confirm this is intended.
        probability_layer = _mask_layer(inarr.isel(bands=1), DEFAULT_NO_FIELD)

        # Combines the two layers together
        outarr = xr.concat(
            [label_layer, probability_layer],
            dim='bands',
            coords='all',
            combine_attrs='no_conflicts',
        ).assign_coords(bands=['croptype', 'probability'])
    else:
        outarr = _mask_layer(inarr, DEFAULT_NO_DATA)

    # Carries the georeferencing of the input over to the result
    outarr.spatial_ref.attrs.update(inarr.spatial_ref.attrs)
    return outarr.rio.write_crs(inarr.rio.crs)
