from openeo.processes import array_create, if_, is_nodata, power
from openeo.rest.datacube import DataCube

from cropclass.openeo.masking import scl_mask_erode_dilate
from cropclass.utils import laea20km_id_to_extent
from cropclass.utils.catalogue_check import (catalogue_check_S1,
                                             catalogue_check_S2)


def add_S1_bands(connection, S1_collection,
                 other_bands, bbox, start, end,
                 preprocess=True, **processing_options):
    """Load Sentinel-1 VH/VV bands, scale them and merge them into a datacube.

    Args:
        connection: openEO connection used to load the S1 collection.
        S1_collection (str): name of the S1 collection. When it equals
            "SENTINEL1_GRD", backscatter is computed with `sar_backscatter`
            (sigma0-ellipsoid); otherwise the collection is assumed to
            already contain backscatter.
        other_bands (DataCube): OpenEO datacube to add bands to.
        bbox: spatial extent passed to `load_collection`.
        start (str): start of the temporal extent.
        end (str): end of the temporal extent.
        preprocess (bool, optional): when True, composite to dekads (mean)
            and linearly interpolate missing values along "t".
            Defaults to True.

    Available processing_options:
        s1_orbitdirection: orbit state to filter on (``sat:orbit_state``).
        provider: backend provider name. Providers containing "creo"
            (case-insensitive) get a catalogue check of the orbit direction,
            a VV&VH polarisation filter, the COPERNICUS_30 elevation model
            for backscatter and nodata replacement during scaling.
            ``None`` is treated like an empty string.
        target_crs: if given, resample to this CRS at 10 m resolution.

    Returns:
        DataCube: `other_bands` resampled onto the S1 cube's grid
        (`resample_cube_spatial`) and merged with the scaled S1 bands.
    """
    # `or ""` also covers an explicit provider=None, which would otherwise
    # fail on `.lower()`.
    provider = processing_options.get("provider") or ""
    isCreo = "creo" in provider.lower()
    orbit_direction = processing_options.get('s1_orbitdirection', None)

    if isCreo:
        orbit_direction = catalogue_check_S1(orbit_direction, start, end, bbox)

    if orbit_direction is not None:
        properties = {"sat:orbit_state": lambda orbdir: orbdir == orbit_direction}  # NOQA
    else:
        properties = {}

    # Extra filter on VV&VH because wrong products with HH&HV were found.
    # Only applied on CREO.
    if isCreo:
        properties["polarisation"] = lambda polar: polar == "VV&VH"

    # Load collection
    S1bands = connection.load_collection(
        S1_collection,
        bands=['VH', 'VV'],
        spatial_extent=bbox,
        temporal_extent=[start, end],
        properties=properties
    )

    if S1_collection == "SENTINEL1_GRD":
        # Compute backscatter if starting from raw GRD,
        # otherwise assume preprocessed backscatter.
        S1bands = S1bands.sar_backscatter(
            coefficient='sigma0-ellipsoid',
            local_incidence_angle=False,
            # DO NOT USE MAPZEN
            elevation_model='COPERNICUS_30' if isCreo else None,
            options={"implementation_version": "2",
                     "tile_size": 256, "otb_memory": 1024, "debug": False,
                     "elev_geoid": "/opt/openeo-vito-aux-data/egm96.tif"}
        )

    # Resample to the S2 spatial resolution
    target_crs = processing_options.get("target_crs", None)
    if target_crs is not None:
        S1bands = S1bands.resample_spatial(
            projection=target_crs, resolution=10.0)

    if preprocess:
        # Composite to dekads
        S1bands = S1bands.aggregate_temporal_period(period="dekad",
                                                    reducer="mean")

        # Linearly interpolate missing values
        S1bands = S1bands.apply_dimension(dimension="t",
                                          process="array_interpolate_linear")

    def _to_scaled_dn(band):
        # Convert to dB (10 * log10), offset by 83 and map back with
        # 10 ** (... / 20) so values fit the later 1..65534 integer range.
        return power(base=10, p=(10.0 * band.log(base=10) + 83.) / 20.)

    # Scale to int16
    if isCreo:
        # For CREO, rescaling also replaces nodata introduced by orfeo
        # with a low value (1).
        # https://github.com/Open-EO/openeo-geopyspark-driver/issues/293
        # TODO: check if nodata is correctly handled in Orfeo
        S1bands = S1bands.apply_dimension(
            dimension="bands",
            process=lambda x: array_create(
                [if_(is_nodata(x[0]), 1, _to_scaled_dn(x[0])),
                 if_(is_nodata(x[1]), 1, _to_scaled_dn(x[1]))]))
    else:
        S1bands = S1bands.apply_dimension(
            dimension="bands",
            process=lambda x: array_create(
                [_to_scaled_dn(x[0]), _to_scaled_dn(x[1])]))

    S1bands = S1bands.linear_scale_range(1, 65534, 1, 65534)

    # --------------------------------------------------------------------
    # Merge cubes
    # --------------------------------------------------------------------

    merged_inputs = other_bands.resample_cube_spatial(
        S1bands).merge_cubes(S1bands)

    return merged_inputs


def add_DEM(connection, DEM_collection, other_bands, bbox,
            **processing_options):
    """Load a DEM collection, drop its time axis and merge it into a cube.

    Args:
        connection: openEO connection used to load the DEM collection.
        DEM_collection (str): name of the DEM collection.
        other_bands (DataCube): datacube the DEM bands are merged into.
        bbox: spatial extent passed to `load_collection`.

    Available processing_options:
        target_crs: if given, resample the DEM to this CRS at 10 m with
            cubic interpolation.

    Returns:
        DataCube: `other_bands` merged with the time-reduced DEM.
    """
    dem_cube = connection.load_collection(DEM_collection, spatial_extent=bbox)

    target_crs = processing_options.get("target_crs", None)
    if target_crs is not None:
        # Match the S2 spatial resolution.
        # TODO: check interpolation method
        # TODO: check no-data near edges of cube
        dem_cube = dem_cube.resample_spatial(
            projection=target_crs, resolution=10.0, method='cubic')

    # The collection carries timestamps; reduce them away with max over time
    # before merging.
    return other_bands.merge_cubes(dem_cube.max_time())

def add_mask_stac(connection, mask_collection, other_bands, bbox,
                  **processing_options):
    """Optionally mask a datacube with a yearly mask loaded via ``load_stac``.

    The mask is only loaded and applied when ``apply_bvl_mask`` is truthy;
    otherwise `other_bands` is returned unchanged and nothing is loaded.

    Args:
        connection: openEO connection used for ``load_stac``.
        mask_collection (str): STAC reference passed to ``load_stac``.
        other_bands (DataCube): datacube to mask.
        bbox: spatial extent for the mask.

    Available processing_options:
        apply_bvl_mask (bool): whether to load and apply the mask.
            Defaults to False.
        year: year whose 1 Jan - 31 Dec range is loaded. Defaults to "2021".
        target_crs: if given, resample the mask to this CRS at 10 m.

    Returns:
        DataCube: `other_bands`, masked when ``apply_bvl_mask`` is set.
    """
    if not processing_options.get('apply_bvl_mask', False):
        # Nothing to apply: skip loading a collection whose result would be
        # discarded anyway.
        return other_bands

    from openeo.processes import and_

    year = str(processing_options.get("year", "2021"))
    mask = connection.load_stac(
        mask_collection,
        spatial_extent=bbox,
        temporal_extent=[f"{year}-01-01", f"{year}-12-31"]
    )

    target_crs = processing_options.get("target_crs", None)
    if target_crs is not None:
        mask = mask.resample_spatial(projection=target_crs, resolution=10.0)

    # Boolean mask that is true for every value except 4, 5 and 7.
    # Per the openEO `mask` process, pixels where the mask is true are
    # replaced by nodata, so only pixels with mask values 4, 5 or 7 are kept.
    # NOTE(review): confirm this polarity is the intended one.
    mask = mask.apply(
        lambda mask_band: and_(and_(mask_band != 4, mask_band != 5),
                               mask_band != 7))
    # max_time collapses the temporal dimension of the mask before applying.
    return other_bands.mask(mask.max_time())

def add_worldcover(connection, WORLDCOVER_collection, other_bands, bbox,
                   **processing_options):
    """Load the WorldCover "MAP" band, optionally mask with it, and merge it.

    The temporal window is 2020-12-01..2021-01-30 when the collection name
    contains "2020" (e.g. ESA_WORLDCOVER_10M_2020_V1), and
    2021-12-01..2022-01-30 otherwise.

    Args:
        connection: openEO connection used to load the collection.
        WORLDCOVER_collection (str): name of the WorldCover collection.
        other_bands (DataCube): datacube to (optionally) mask and merge into.
        bbox: spatial extent passed to `load_collection`.

    Available processing_options:
        apply_worldcovermask (bool): when truthy, mask `other_bands` wherever
            the MAP band is not 40. Defaults to False.

    Returns:
        DataCube: (optionally masked) `other_bands` merged with the
        time-reduced WorldCover cube.
    """
    if "2020" in WORLDCOVER_collection:
        temporal_extent = ["2020-12-01", "2021-01-30"]
    else:
        temporal_extent = ["2021-12-01", "2022-01-30"]

    worldcover = connection.load_collection(
        WORLDCOVER_collection,
        spatial_extent=bbox,
        temporal_extent=temporal_extent,
        bands=["MAP"]
    )

    # Align with the grid of `other_bands`; nearest neighbour avoids blending
    # values. Then reduce away the collection's timestamps.
    worldcover = worldcover.resample_cube_spatial(
        other_bands, method='near').max_time()

    if processing_options.get('apply_worldcovermask', False):
        other_bands = other_bands.mask(worldcover.band("MAP") != 40)

    return other_bands.merge_cubes(worldcover)


def add_meteo(connection, METEO_collection, other_bands, bbox,
              start, end, target_crs=None):
    """Load mean temperature from a meteo collection (e.g. AGERA5).

    The 'temperature-mean' band is composited to dekads (mean), linearly
    interpolated along "t" and renamed to 'temperature_mean' (the name the
    Radix model expects). This always happens; there is no preprocess switch.

    Args:
        connection: openEO connection used to load the collection.
        METEO_collection (str): name of the meteo collection.
        other_bands (DataCube or None): cube to merge into; when None the
            meteo cube is returned on its own.
        bbox: spatial extent passed to `load_collection`.
        start (str): start of the temporal extent.
        end (str): end of the temporal extent.
        target_crs: if given, resample to this CRS at 10 m resolution.

    Returns:
        DataCube: the meteo cube, or `other_bands` merged with it.
    """
    meteo = connection.load_collection(
        METEO_collection,
        spatial_extent=bbox,
        bands=['temperature-mean'],
        temporal_extent=[start, end]
    )

    if target_crs is not None:
        meteo = meteo.resample_spatial(projection=target_crs, resolution=10.0)

    # Dekad composite -> gap interpolation (missing values shouldn't occur
    # in this dataset, but it is done for safety) -> model band name.
    meteo = (
        meteo
        .aggregate_temporal_period(period="dekad", reducer="mean")
        .apply_dimension(dimension="t", process="array_interpolate_linear")
        .rename_labels('bands', ['temperature_mean'])
    )

    return meteo if other_bands is None else other_bands.merge_cubes(meteo)


def cropclass_features_sar_optical_dem(year, laea_id, connection,
                                       provider="terrascope"):
    """Retrieve cropclass input features for one LAEA 20 km grid cell.

    Requests 1 March - 31 October of `year` through
    `cropclass_preprocessed_inputs`, with S2/S1 collection names chosen per
    provider, WorldCover 2021 v2, the COPERNICUS_30 DEM, no meteo
    collection, target_crs=3035 and 'mask_scl_dilation' cloud masking.

    When saving directly to netCDF, sufficient driver memory is
    needed to prevent out of memory errors:
    "driver-memory": "5G",
    "driver-memoryOverhead": "2G",

    Args:
        year: season year; converted with str() into the date strings.
        laea_id: LAEA 20 km grid id, turned into an extent with
            `laea20km_id_to_extent`.
        connection: openEO connection instance.
        provider (str): backend provider name. Any name containing "terra"
            (case-insensitive) selects the Terrascope S2/S1 collections.

    Returns:
        The result of `cropclass_preprocessed_inputs` (an openEO DataCube).
    """
    year_str = str(year)
    start_date = f"{year_str}-03-01"
    end_date = f"{year_str}-10-31"
    extent = laea20km_id_to_extent(laea_id)

    on_terrascope = "terra" in provider.lower()
    if on_terrascope:
        s2_name, s1_name = "TERRASCOPE_S2_TOC_V2", "SENTINEL1_GRD_SIGMA0"
    else:
        s2_name, s1_name = "SENTINEL2_L2A", "SENTINEL1_GRD"

    collections = {
        "S2_collection": s2_name,
        "WORLDCOVER_collection": "ESA_WORLDCOVER_10M_2021_V2",
        "METEO_collection": None,
        "S1_collection": s1_name,
        "DEM_collection": "COPERNICUS_30",
    }
    processing = {
        "target_crs": 3035,
        "masking": "mask_scl_dilation",
        "provider": provider,
    }

    return cropclass_preprocessed_inputs(
        connection, extent, start_date, end_date,
        **collections, **processing)


def cropclass_preprocessed_inputs(
        connection, bbox, start: str, end: str,
        S2_collection='TERRASCOPE_S2_TOC_V2',
        S1_collection='SENTINEL1_GRD_SIGMA0',
        DEM_collection='COPERNICUS_30',
        METEO_collection='AGERA5',
        WORLDCOVER_collection=None,
        BVL_collection=None,
        preprocess=True,
        masking='mask_scl_dilation',
        **processing_options) -> DataCube:
    """Main method to get preprocessed inputs from OpenEO for
    downstream crop type mapping.

    Args:
        connection: OpenEO connection instance
        bbox: spatial extent passed to every collection load
        start (str): Start date for requested input data (yyyy-mm-dd)
        end (str): End date for requested input data (yyyy-mm-dd)
        S2_collection (str, optional): Collection name for S2 data.
                        Defaults to 'TERRASCOPE_S2_TOC_V2'.
        S1_collection (str, optional): Collection name for S1 data;
                        None skips S1. Defaults to 'SENTINEL1_GRD_SIGMA0'.
        DEM_collection (str, optional): Collection name for DEM data;
                        None skips the DEM. Defaults to 'COPERNICUS_30'.
        METEO_collection (str, optional): Collection name for
                        meteo data; None skips meteo. Defaults to 'AGERA5'.
        WORLDCOVER_collection (str, optional): Collection name for
                        WorldCover data; None skips it. Defaults to None.
        BVL_collection (str, optional): STAC reference for an optional
                        mask applied via `add_mask_stac`. Defaults to None.
        preprocess (bool, optional): Apply dekad compositing and temporal
                        interpolation to the S2 and S1 data.
                        Defaults to True.
        masking (str, optional): Masking method to be applied.
                        One of ['satio', 'mask_scl_dilation',
                        'to_scl_dilation_mask', None].
                        Defaults to 'mask_scl_dilation'.
        **processing_options: forwarded to the helper functions. Read
                        here: 'provider', 'target_crs' and
                        'apply_worldcovermask'.

    Returns:
        DataCube: OpenEO DataCube with the requested inputs

    Raises:
        ValueError: if `masking` is not a supported option, or if
            'apply_worldcovermask' is set without a WORLDCOVER_collection.
    """

    # --------------------------------------------------------------------
    # Optical data
    # --------------------------------------------------------------------

    S2_bands = ["B02", "B03", "B04", "B05",
                "B06", "B07", "B08", "B11",
                "B12"]
    if masking not in ['satio', 'mask_scl_dilation', 'to_scl_dilation_mask', None]:
        raise ValueError(f'Unknown masking option `{masking}`')
    if masking == 'mask_scl_dilation':
        # Need SCL band to mask; it is appended last and dropped again
        # after masking (band_names[:-1] below).
        S2_bands.append("SCL")
    bands = connection.load_collection(
        S2_collection,
        bands=S2_bands,
        spatial_extent=bbox,
        temporal_extent=[start, end],
        max_cloud_cover=95
    )

    # S2URL creo only accepts request in EPSG:4326
    # `or ""` also covers an explicit provider=None.
    provider = processing_options.get("provider") or ""
    isCreo = "creo" in provider.lower()
    if isCreo:
        catalogue_check_S2(start, end, bbox)

    # Temporal partitioning tuning; only applied on the
    # 'to_scl_dilation_mask' path.
    temporal_partition_options = {
        "indexreduction": 2,
        "temporalresolution": "ByDay",
        "tilesize": 1024
    }
    if masking == "to_scl_dilation_mask":
        bands.result_node().update_arguments(featureflags=temporal_partition_options)

    target_crs = processing_options.get("target_crs", None)
    if target_crs is not None:
        bands = bands.resample_spatial(projection=target_crs, resolution=10.0)

    # NOTE: For now we mask again snow/ice because clouds
    # are sometimes marked as SCL value 11!
    if masking == 'mask_scl_dilation':
        # TODO: double check cloud masking parameters
        # https://github.com/Open-EO/openeo-geotrellis-extensions/blob/develop/geotrellis-common/src/main/scala/org/openeo/geotrelliscommon/CloudFilterStrategy.scala#L54  # NOQA
        bands = bands.process(
            "mask_scl_dilation",
            data=bands,
            scl_band_name="SCL",
            kernel1_size=17, kernel2_size=77,
            mask1_values=[2, 4, 5, 6, 7],
            mask2_values=[3, 8, 9, 10, 11],
            erosion_kernel_size=3).filter_bands(bands.metadata.band_names[:-1])
    elif masking == 'to_scl_dilation_mask':
        # TODO: double check cloud masking parameters
        # https://github.com/Open-EO/openeo-geotrellis-extensions/blob/develop/geotrellis-common/src/main/scala/org/openeo/geotrelliscommon/CloudFilterStrategy.scala#L54  # NOQA
        scl = connection.load_collection(
            S2_collection,
            bands=["SCL"],
            spatial_extent=bbox,
            temporal_extent=[start, end],
            max_cloud_cover=95
        )
        scl.result_node().update_arguments(featureflags=temporal_partition_options)
        if target_crs is not None:
            scl = scl.resample_spatial(projection=target_crs, resolution=10.0)
        mask = scl.process(
            "to_scl_dilation_mask",
            data=scl,
            kernel1_size=17, kernel2_size=77,
            mask1_values=[2, 4, 5, 6, 7],
            mask2_values=[3, 8, 9, 10, 11],
            erosion_kernel_size=3)
        bands = bands.mask(mask)
    elif masking == 'satio':
        # Apply satio-based mask
        mask = scl_mask_erode_dilate(
            connection,
            bbox,
            scl_layer_band=S2_collection + ':SCL',
            target_crs=target_crs).resample_cube_spatial(bands)
        bands = bands.mask(mask)

    if preprocess:
        # Composite to dekads
        bands = bands.aggregate_temporal_period(period="dekad",
                                                reducer="median")

        # TODO: if we would disable it here, nodata values
        # will be 65535 and we need to cope with that later
        # Linearly interpolate missing values
        bands = bands.apply_dimension(dimension="t",
                                      process="array_interpolate_linear")

    # Force UINT16 to avoid overflow issue with S2 data
    bands = bands.linear_scale_range(0, 65534, 0, 65534)

    # --------------------------------------------------------------------
    # AGERA5 Meteo data
    # NOTE: add_meteo has no `preprocess` switch and always composites
    # to dekads.
    # --------------------------------------------------------------------
    if METEO_collection is not None:
        bands = add_meteo(connection, METEO_collection,
                          bands, bbox, start, end,
                          target_crs=target_crs)

    # --------------------------------------------------------------------
    # SAR data
    # --------------------------------------------------------------------
    if S1_collection is not None:
        # Forward `preprocess` so raw requests (preprocess=False) do not get
        # dekad-composited S1 merged with un-composited S2.
        bands = add_S1_bands(connection, S1_collection,
                             bands, bbox, start, end,
                             preprocess=preprocess,
                             **processing_options)

    bands = bands.filter_temporal(start, end)

    # --------------------------------------------------------------------
    # DEM data
    # --------------------------------------------------------------------
    if DEM_collection is not None:
        bands = add_DEM(connection, DEM_collection,
                        bands, bbox, **processing_options)

    if BVL_collection is not None:
        bands = add_mask_stac(connection, BVL_collection,
                              bands, bbox, **processing_options)

    # --------------------------------------------------------------------
    # WorldCover data
    # --------------------------------------------------------------------
    if WORLDCOVER_collection is not None:
        bands = add_worldcover(
            connection, WORLDCOVER_collection, bands,
            bbox, **processing_options)
    elif processing_options.get('apply_worldcovermask', False):
        raise ValueError('Cannot mask without WorldCoverCollection')

    # forcing 16bit
    bands = bands.linear_scale_range(0, 65534, 0, 65534)

    return bands


def cropclass_raw_inputs(*args, **kwargs):
    """Call `cropclass_preprocessed_inputs` with ``preprocess=False``.

    All positional and keyword arguments are forwarded unchanged. Passing
    ``preprocess`` explicitly raises TypeError (duplicate keyword).
    """
    return cropclass_preprocessed_inputs(*args, preprocess=False, **kwargs)
