from pathlib import Path
from typing import Dict, Optional

import numpy as np
import rasterio.enums
import xarray as xr

# Class code for pixels without valid data: NaN/None crop labels, the
# reprojection fill value, and pixels flagged 255 in the BVL mask.
# 65535 is the largest uint16 value (translate_layer produces uint16 codes).
DEFAULT_NO_DATA = 65535

# Class code for pixels that are not crop fields: the "NoCrop" label and pixels
# excluded by the BVL mask. translate_layer also leaves unknown labels at 0.
DEFAULT_NO_FIELD = 0

# Maps crop-type labels, as stored in the input layer, to integer class codes.
# NOTE(review): "1-1-5-1" maps to 1150, whereas the other entries concatenate
# their digits (e.g. "2-0-3-1" -> 2031) -- confirm this is intended.
DEFAULT_TRANSLATION_TABLE = {
    "1-1-1-0": 1110,  # Wheat
    "1-1-2-0": 1120,  # Barley
    "1-1-3-0": 1130,  # Maize
    "1-1-4-0": 1140,  # Rice
    "1-1-5-1": 1150,  # Other cereals
    "1-2-1-0": 1210,  # Dry pulses
    "1-2-2-0": 1220,  # Fresh vegetables
    "1-3-1-0": 1310,  # Potatoes
    "1-3-2-0": 1320,  # Sugar beet
    "1-4-1-0": 1410,  # Sunflower
    "1-4-2-0": 1420,  # Soybeans
    "1-4-3-0": 1430,  # Rapeseed
    "1-4-4-0": 1440,  # Flax, cotton and hemp
    "1-5-0-0": 1500,  # Grass and fodder crops
    "2-0-1-0": 2010,  # Grapes
    "2-0-2-0": 2020,  # Olives
    "2-0-3-1": 2031,  # Fruits
    "2-0-3-2": 2032,  # Nuts
    "NoCrop": DEFAULT_NO_FIELD,
    # Missing labels in their various spellings are treated as no data.
    "NaN": DEFAULT_NO_DATA,
    "nan": DEFAULT_NO_DATA,
    "NAN": DEFAULT_NO_DATA,
    None: DEFAULT_NO_DATA,
}


def translate_layer(input: xr.DataArray, translation_table: Optional[Dict[str, int]] = None) -> xr.DataArray:
    """Translate crop-type labels into integer class codes.

    Every element equal to a key of ``translation_table`` is replaced by the
    corresponding value. Elements matching no key become 0. Codes are computed
    as ``uint16``.

    Two input layouts are supported:

    * no ``bands`` dimension: the whole array is translated and returned;
    * a ``bands`` dimension: only the ``bands='croptype'`` slice is translated
      and written back into a copy of ``input``. Other bands are untouched.

    Args:
        input: Array of crop-type labels (e.g. ``"1-1-1-0"``). It may be
            numpy- or dask-backed.
        translation_table: Mapping from label to class code. ``None`` selects
            ``DEFAULT_TRANSLATION_TABLE``.

    Returns:
        The translated array, or a copy of ``input`` whose ``croptype`` band
        has been translated.
    """
    translation_table = DEFAULT_TRANSLATION_TABLE if translation_table is None else translation_table

    def map_func(input_data: np.ndarray, translation_table: Dict[str, int]) -> np.ndarray:
        # Labels that match no key in the table are left at 0.
        output_data = np.zeros_like(input_data, dtype=np.uint16)
        for orig, dst in translation_table.items():
            output_data[input_data == orig] = dst
        return output_data

    # Shared apply_ufunc options, so both layouts behave the same:
    # dask='parallelized' lets dask-backed inputs be processed chunk by chunk
    # (without it apply_ufunc rejects dask arrays), and older xarray versions
    # require output_dtypes in that mode. Both options are ignored for numpy data.
    ufunc_options = dict(
        kwargs={'translation_table': translation_table},
        dask='parallelized',
        output_dtypes=[np.uint16],
    )

    if 'bands' not in input.dims:
        # Single-band layout: the whole array holds labels.
        return xr.apply_ufunc(map_func, input, **ufunc_options)

    # Multiband layout: translate only the band labelled 'croptype'.
    translated_layer = xr.apply_ufunc(map_func, input.sel({'bands': 'croptype'}), **ufunc_options)
    translated_data = input.copy()
    # NOTE(review): the assigned codes take the dtype of `input`'s copy here,
    # not uint16 -- confirm that is acceptable for the caller.
    translated_data.loc[{'bands': 'croptype'}] = translated_layer
    return translated_data


def reproject_layer(input: xr.DataArray,
                    target_crs: Optional[str] = None,
                    clip_minx: Optional[float] = None,
                    clip_maxx: Optional[float] = None,
                    clip_miny: Optional[float] = None,
                    clip_maxy: Optional[float] = None,
                    resolution: float = 10) -> xr.DataArray:
    """Optionally clip ``input`` to a bounding box, then optionally reproject it.

    Clipping is applied only when all four ``clip_*`` bounds are given. If any
    of them is ``None``, the array is passed through unclipped. The box is
    handed to ``rio.clip`` without a ``crs`` argument.
    NOTE(review): the bounds are presumably expected in the CRS of ``input``;
    confirm.

    Reprojection is applied only when ``target_crs`` is given. It uses
    nearest-neighbour resampling, fills with ``DEFAULT_NO_DATA`` and produces
    a grid of ``resolution``.

    Args:
        input: Georeferenced array with an ``.rio`` accessor.
        target_crs: CRS to reproject to, or ``None`` to skip reprojection.
        clip_minx, clip_maxx, clip_miny, clip_maxy: Bounding-box edges used
            for clipping.
        resolution: Output pixel size passed to ``rio.reproject``. Defaults
            to 10.

    Returns:
        The clipped and/or reprojected array (``input`` itself if neither
        step applies).
    """
    import rasterio
    from shapely.geometry import box, mapping

    if None in [clip_minx, clip_maxy, clip_miny, clip_maxx]:
        clipped_output = input
    else:
        # mapping() produces the GeoJSON-like dict that rio.clip accepts
        # directly, with no serialize/parse round trip.
        bounding_box = mapping(box(
            minx=clip_minx,
            miny=clip_miny,
            maxx=clip_maxx,
            maxy=clip_maxy,
        ))

        clipped_output = input.rio.clip(geometries=[bounding_box], drop=True)

    if target_crs is None:
        return clipped_output

    return clipped_output.rio.reproject(
        target_crs,
        nodata=DEFAULT_NO_DATA,
        resolution=resolution,
        # Nearest neighbour keeps categorical class codes intact.
        resampling=rasterio.enums.Resampling.nearest
    )


def _mask_with_bvl(layer: xr.DataArray, bvl_mask: xr.DataArray, nodata_value: int) -> xr.DataArray:
    """Apply the BVL mask rules to a single layer.

    Pixels whose mask value is below 4, or equal to 6 or 8, are set to
    ``DEFAULT_NO_FIELD``. Pixels whose mask value is 255 are set to
    ``nodata_value``. All other pixels keep their value from ``layer``.
    """
    non_field = (bvl_mask < 4) | (bvl_mask == 6) | (bvl_mask == 8)
    masked = xr.where(non_field, DEFAULT_NO_FIELD, layer)
    return xr.where(bvl_mask == 255, nodata_value, masked)


def apply_bvl_mask(input: xr.DataArray, bvl_mask_path: Path) -> xr.DataArray:
    """Mask a crop-type layer (and its probability band, if any) with a BVL raster.

    The mask is reprojected onto the grid of ``input`` (nearest neighbour).
    The rules of ``_mask_with_bvl`` are then applied:

    * Prediction layer (``bands`` index 0, or the whole array if ``input``
      has no band dimension): mask value 255 becomes ``DEFAULT_NO_DATA``.
    * Probability layer (``bands`` index 1, multiband input only): mask
      value 255 becomes ``DEFAULT_NO_FIELD``.

    A ``band`` dimension on either array is renamed to ``bands``.

    Args:
        input: Georeferenced prediction array, single-band or with a
            ``bands``/``band`` dimension.
        bvl_mask_path: Path to the BVL mask raster.

    Returns:
        The masked array. For multiband input, bands are labelled
        ``'croptype'`` and ``'probability'``. The CRS and ``spatial_ref``
        attributes are copied from ``input``.
    """
    import rioxarray

    # Load the BVL mask and align it to the grid of `input`. The aligned mask
    # is loaded into memory so the source file can be closed here instead of
    # leaving the handle open.
    with rioxarray.open_rasterio(bvl_mask_path) as bvl_source:
        bvl_mask = bvl_source.rio.reproject_match(
            match_data_array=input,
            resampling=rasterio.enums.Resampling.nearest
        ).load()

    # Rename dimension `band` to `bands` to match the new naming standard.
    if 'band' in bvl_mask.dims:
        bvl_mask = bvl_mask.rename({'band': 'bands'})
    if 'band' in input.dims:
        input = input.rename({'band': 'bands'})

    is_multiband = 'bands' in input.dims

    # Mask the crop-type prediction layer.
    input_prediction = input.isel(bands=0) if is_multiband else input
    output_prediction = _mask_with_bvl(input_prediction, bvl_mask, DEFAULT_NO_DATA)

    if is_multiband:
        # Mask the probability layer.
        # NOTE(review): mask value 255 maps to DEFAULT_NO_FIELD here, not
        # DEFAULT_NO_DATA as for the prediction -- confirm this is intended.
        input_probability = input.isel(bands=1)
        output_probability = _mask_with_bvl(input_probability, bvl_mask, DEFAULT_NO_FIELD)

        # Stack both layers back together along `bands`.
        output = xr.concat(
            [output_prediction, output_probability],
            dim='bands',
            coords='all',
            combine_attrs='no_conflicts',
        ).assign_coords(bands=['croptype', 'probability'])
    else:
        output = output_prediction

    # Copy the spatial-reference metadata and CRS of `input` onto the output.
    output.spatial_ref.attrs.update(input.spatial_ref.attrs)
    output = output.rio.write_crs(input.rio.crs)

    return output


def apply_layer_format(input: xr.DataArray,
                       reprojection_params: Dict,
                       bvl_mask_path: Path,
                       translation_table: Optional[Dict[str, int]] = None
                       ) -> xr.DataArray:
    """Run the full formatting pipeline on a crop-type layer.

    The steps run in this order:

    1. ``translate_layer`` maps crop-type labels to class codes.
    2. ``reproject_layer`` optionally clips and reprojects.
    3. ``apply_bvl_mask`` masks with the BVL raster.

    Args:
        input: Crop-type layer to format.
        reprojection_params: Keyword arguments forwarded to
            ``reproject_layer`` (e.g. ``target_crs`` and the ``clip_*``
            bounds).
        bvl_mask_path: Path to the BVL mask raster.
        translation_table: Label-to-code mapping. ``None`` (the default) lets
            ``translate_layer`` fall back to ``DEFAULT_TRANSLATION_TABLE``.
            This avoids sharing a mutable module-level dict as a default
            argument.

    Returns:
        The translated, reprojected and masked array.
    """
    output = translate_layer(input, translation_table)
    output = reproject_layer(output, **reprojection_params)
    output = apply_bvl_mask(output, bvl_mask_path)

    return output
