from openeo.processes import array_create, is_nodata, if_
from openeo.rest.datacube import DataCube
from cropclass.openeo.masking import scl_mask_erode_dilate

from cropclass.openeo.classification import SENTINEL1_AS_SHORTS


def add_S1_bands(connection, S1_collection,
                 other_bands, bbox, start, end,
                 preprocess=True, **processing_options):
    """Load Sentinel-1 VH/VV bands and merge them into ``other_bands``.

    Processing steps, in order:

    1. Load bands ``VH`` and ``VV`` of ``S1_collection`` for ``bbox`` and
       the temporal extent ``[start, end]``.
    2. If the collection is the raw ``"SENTINEL1_GRD"`` product, compute
       sigma0-ellipsoid backscatter with ``sar_backscatter``. Any other
       collection is assumed to already hold backscatter.
    3. If ``processing_options["target_crs"]`` is set, resample to that CRS
       at a resolution of 10.0.
    4. If ``processing_options[SENTINEL1_AS_SHORTS]`` is truthy, map every
       band to ``30 + 10*log10(x)`` (nodata pixels become 1). Then rescale
       the range 0..30 to 0..30000.
    5. If ``preprocess`` is true, build dekad mean composites and linearly
       interpolate along the ``t`` dimension.

    Finally ``other_bands`` is resampled onto the S1 cube's grid and the two
    cubes are merged.

    :param connection: openEO connection used to load the collection.
    :param S1_collection: id of the Sentinel-1 collection.
    :param other_bands: data cube that the S1 bands are merged into.
    :param bbox: spatial extent for ``load_collection``.
    :param start: start of the temporal extent.
    :param end: end of the temporal extent.
    :param preprocess: whether to composite and interpolate in time.
    :param processing_options: extra options. ``"target_crs"`` and the
        ``SENTINEL1_AS_SHORTS`` key are read here.
    :return: the merged data cube.
    """
    s1_cube = connection.load_collection(
        S1_collection,
        bands=['VH', 'VV'],
        spatial_extent=bbox,
        temporal_extent=[start, end],
        properties={},
    )

    if S1_collection == "SENTINEL1_GRD":
        # Raw GRD input: backscatter still has to be computed here.
        backscatter_options = {
            "implementation_version": "2",
            "tile_size": 256,
            "otb_memory": 1024,
            "debug": False,
            "elev_geoid": "/opt/openeo-vito-aux-data/egm96.tif",
        }
        s1_cube = s1_cube.sar_backscatter(
            coefficient='sigma0-ellipsoid',
            local_incidence_angle=False,
            # DO NOT USE MAPZEN
            elevation_model='srtmgl1',
            options=backscatter_options,
        )

    # Resample to the S2 spatial resolution when a target CRS is requested.
    target_crs = processing_options.get("target_crs", None)
    if target_crs is not None:
        s1_cube = s1_cube.resample_spatial(projection=target_crs, resolution=10.0)

    if processing_options.get(SENTINEL1_AS_SHORTS, False):
        def _shifted_db(values, index):
            # The rescaling also replaces nodata introduced by orfeo with a
            # low value (1), see
            # https://github.com/Open-EO/openeo-geopyspark-driver/issues/293
            return if_(is_nodata(values[index]), 1,
                       30.0 + 10.0 * values[index].log(base=10))

        s1_cube = s1_cube.apply_dimension(
            dimension="bands",
            process=lambda x: array_create([_shifted_db(x, 0), _shifted_db(x, 1)]),
        )
        # Scale the shifted-dB range 0..30 to 0..30000 (int16-friendly).
        s1_cube = s1_cube.linear_scale_range(0, 30, 0, 30000)

    if preprocess:
        # 10-daily mean composites, then fill gaps by linear interpolation in time.
        s1_cube = s1_cube.aggregate_temporal_period(period="dekad", reducer="mean")
        s1_cube = s1_cube.apply_dimension(dimension="t",
                                          process="array_interpolate_linear")

    # TODO: CONVERT HERE TO DB! (currently only done when SENTINEL1_AS_SHORTS is set)

    # Align the other bands with the S1 grid before merging the cubes.
    return other_bands.resample_cube_spatial(s1_cube).merge_cubes(s1_cube)


def add_DEM(connection, DEM_collection, other_bands, bbox, **processing_options):
    """Load a DEM collection and merge it into ``other_bands``.

    If ``processing_options["target_crs"]`` is set, the DEM is resampled to
    that CRS at a resolution of 10.0. The time dimension is then collapsed
    with ``max_time`` because the collection carries timestamps.

    :param connection: openEO connection used to load the collection.
    :param DEM_collection: id of the DEM collection.
    :param other_bands: data cube that the DEM is merged into.
    :param bbox: spatial extent for ``load_collection``.
    :param processing_options: extra options; only ``"target_crs"`` is read.
    :return: the merged data cube.
    """
    dem_cube = connection.load_collection(DEM_collection, spatial_extent=bbox)

    # NOTE: an experimental featureflag (to avoid gaps near bbox edges, at the
    # cost of nearest-neighbour-only resampling) was used previously and is
    # currently disabled.

    # TODO: check interpolation method
    # TODO: check no-data near edges of cube
    crs = processing_options.get("target_crs", None)
    if crs is not None:
        dem_cube = dem_cube.resample_spatial(projection=crs, resolution=10.0)

    # The collection has timestamps that must be dropped before merging.
    dem_cube = dem_cube.max_time()

    # TODO: casting to int (e.g. linear_scale_range(-1000., 10000., -1000, 10000))
    # has no effect for the moment, so it is not applied.
    return other_bands.merge_cubes(dem_cube)


def add_worldcover(connection, WORLDCOVER_collection, other_bands, bbox, **processing_options):
    """Load ESA WorldCover, mask ``other_bands`` with it and merge both cubes.

    The temporal extent is chosen from the collection id. Ids containing
    ``"2020"`` (e.g. ``ESA_WORLDCOVER_10M_2020_V1``) use 2020-12-01 to
    2021-01-30; all others use 2021-12-01 to 2022-01-30.

    :param connection: openEO connection used to load the collection.
    :param WORLDCOVER_collection: id of the WorldCover collection.
    :param other_bands: data cube that is masked and then merged with the
        WorldCover ``MAP`` band.
    :param bbox: spatial extent for ``load_collection``.
    :param processing_options: extra options. Only ``"provider"`` is read;
        if it contains ``"creo"``, extra featureflags are set on the load
        node. ``None`` or a missing value is treated as ``""``.
    :return: the merged data cube.
    """
    if "2020" in WORLDCOVER_collection:
        start, end = "2020-12-01", "2021-01-30"
    else:
        start, end = "2021-12-01", "2022-01-30"

    worldcover = connection.load_collection(
        WORLDCOVER_collection,
        spatial_extent=bbox,
        temporal_extent=[start, end],
        bands=["MAP"],
    )

    # `or ""` also handles an explicit provider=None. The former
    # `.get("provider", "")` raised TypeError on the `in` test in that case.
    provider = processing_options.get("provider") or ""
    if "creo" in provider:
        temporal_partition_options = {
            "indexreduction": 0,
            "temporalresolution": "ByDay",
            "tilesize": 1024,
        }
        worldcover.result_node().update_arguments(
            featureflags=temporal_partition_options)

    # Resample onto the grid of other_bands (S2 resolution), nearest neighbour.
    worldcover = worldcover.resample_cube_spatial(other_bands, method='near')

    # The collection has timestamps that must be dropped before merging.
    worldcover = worldcover.max_time()

    # The mask is true wherever MAP != 40. Under the openEO `mask` process,
    # those pixels are replaced by nodata, so only MAP == 40 pixels are kept.
    # (40 is presumably the "cropland" class of the WorldCover legend;
    # TODO confirm.)
    other_bands = other_bands.mask(worldcover.band("MAP") != 40)

    return other_bands.merge_cubes(worldcover)


def add_meteo(connection, METEO_collection, other_bands, bbox, start, end):
    """Load the mean-temperature band of a meteo collection and merge it.

    The band ``temperature-mean`` (e.g. from AGERA5) goes through these
    steps:

    1. Resample onto the grid of ``other_bands``.
    2. Composite into dekad (10-daily) means.
    3. Linearly interpolate along ``t``.
    4. Rename to ``temperature_mean``, which matches the Radix model
       requirements.

    The result is then merged into ``other_bands``.

    :param connection: openEO connection used to load the collection.
    :param METEO_collection: id of the meteo collection.
    :param other_bands: data cube that the meteo band is merged into.
    :param bbox: spatial extent for ``load_collection``.
    :param start: start of the temporal extent.
    :param end: end of the temporal extent.
    :return: the merged data cube.
    """
    meteo_cube = connection.load_collection(
        METEO_collection,
        spatial_extent=bbox,
        bands=['temperature-mean'],
        temporal_extent=[start, end],
    )
    meteo_cube.result_node().update_arguments(featureflags={'experimental': True})

    # Upsample to the 10 m grid of other_bands (the original notes nearest
    # neighbour; this relies on the process default method).
    meteo_cube = meteo_cube.resample_cube_spatial(other_bands)

    # Interpolation should not be needed for this dataset, but it is kept as
    # a safety net. Casting to float via linear_scale_range is not applied.
    meteo_cube = (
        meteo_cube
        .aggregate_temporal_period(period="dekad", reducer="mean")
        .apply_dimension(dimension="t", process="array_interpolate_linear")
        .rename_labels('bands', ['temperature_mean'])
    )

    return other_bands.merge_cubes(meteo_cube)


def cropclass_preprocessed_inputs(
        connection, bbox, start, end,
        S2_collection='TERRASCOPE_S2_TOC_V2',
        S1_collection='SENTINEL1_GRD',
        DEM_collection='COPERNICUS_30',
        METEO_collection='AGERA5',
        WORLDCOVER_collection=None,
        preprocess=True,
        **processing_options
) -> DataCube:
    """Build the merged openEO input cube for crop classification.

    Sentinel-2 bands are always loaded and cloud-masked. The meteo, Sentinel-1,
    DEM and WorldCover cubes are merged in when their collection id is not
    ``None``.

    :param connection: openEO connection.
    :param bbox: spatial extent passed to every ``load_collection``.
    :param start: start of the temporal extent.
    :param end: end of the temporal extent.
    :param S2_collection: Sentinel-2 collection id.
    :param S1_collection: Sentinel-1 collection id, or ``None`` to skip SAR.
    :param DEM_collection: DEM collection id, or ``None`` to skip it.
    :param METEO_collection: meteo collection id, or ``None`` to skip it.
    :param WORLDCOVER_collection: WorldCover collection id, or ``None``.
    :param preprocess: if true, S2 (and S1) are composited into dekads and
        linearly interpolated in time.
    :param processing_options: extra options.

        * ``"masking"``: ``"mask_scl_dilation"`` selects the SCL-dilation
          mask. Anything else uses ``scl_mask_erode_dilate``.
        * ``"provider"``: if it contains ``"creo"``, provider-specific
          featureflags are used. ``None`` or missing is treated as ``""``.
        * ``SENTINEL1_AS_SHORTS`` and ``"target_crs"`` are forwarded to the
          helper functions.
    :return: the merged data cube.
    """
    use_scl_dilation = processing_options.get("masking", None) == "mask_scl_dilation"
    # `or ""` also handles an explicit provider=None. Otherwise the `in` test
    # below raises TypeError.
    provider = processing_options.get("provider") or ""

    # --------------------------------------------------------------------
    # Optical data
    # --------------------------------------------------------------------

    S2_bands = ["B02", "B03", "B04", "B05",
                "B06", "B07", "B08", "B11",
                "B12"]
    if use_scl_dilation:
        # SCL must stay the last band: it is dropped again below via
        # band_names[:-1].
        S2_bands.append("SCL")
    bands = connection.load_collection(
        S2_collection,
        bands=S2_bands,
        spatial_extent=bbox,
        temporal_extent=[start, end]
    )

    if "creo" in provider:
        temporal_partition_options = {
            "indexreduction": 2,
            "temporalresolution": "ByDay",
            "tilesize": 1024
        }
    else:
        # These settings force openeo to read a full timeseries at a time,
        # which allows for some performance optimizations.
        temporal_partition_options = {
            "indexreduction": 0,
            "temporalresolution": "None",
            "tilesize": 256
        }
    bands.result_node().update_arguments(
        featureflags=temporal_partition_options)

    if use_scl_dilation:
        # TODO: add cloud masking parameters
        # `bands.metadata` on the right-hand side is the pre-mask cube, whose
        # last band is SCL. filter_bands removes it again.
        bands = bands.process("mask_scl_dilation",
                              data=bands,
                              scl_band_name="SCL", kernel1_size=0, kernel2_size=91,
                              erosion_kernel_size=3, ).filter_bands(
            bands.metadata.band_names[:-1])
        # NOTE(review): this branch hard-codes EPSG:3035 rather than reading
        # processing_options["target_crs"]. Confirm this is intended.
        bands = bands.resample_spatial(projection=3035, resolution=10.0)
    else:
        # Apply conservative mask and select bands
        mask = scl_mask_erode_dilate(
            connection,
            bbox,
            scl_layer_band=S2_collection + ':SCL').resample_cube_spatial(bands)
        bands = bands.mask(mask)

    if preprocess:
        # Composite 10-daily
        bands = bands.aggregate_temporal_period(period="dekad",
                                                reducer="median")

        # Linearly interpolate missing values
        bands = bands.apply_dimension(dimension="t",
                                      process="array_interpolate_linear")

    # --------------------------------------------------------------------
    # AGERA5 Meteo data
    # --------------------------------------------------------------------
    if METEO_collection is not None:
        bands = add_meteo(connection, METEO_collection,
                          bands, bbox, start, end)

    # --------------------------------------------------------------------
    # SAR data
    # --------------------------------------------------------------------
    if S1_collection is not None:
        bands = add_S1_bands(connection, S1_collection,
                             bands, bbox, start, end, preprocess=preprocess, **processing_options)

    bands = bands.filter_temporal(start, end)

    # --------------------------------------------------------------------
    # DEM data
    # --------------------------------------------------------------------
    if DEM_collection is not None:
        bands = add_DEM(connection, DEM_collection, bands, bbox, **processing_options)

    # --------------------------------------------------------------------
    # WorldCover data
    # --------------------------------------------------------------------
    if WORLDCOVER_collection is not None:
        bands = add_worldcover(connection, WORLDCOVER_collection, bands, bbox, **processing_options)

    if S1_collection is None or processing_options.get(SENTINEL1_AS_SHORTS, False):
        # Identity rescale on integer bounds, intended to force 16-bit output.
        bands = bands.linear_scale_range(0, 30000, 0, 30000)

    return bands


def cropclass_raw_inputs(*args, **kwargs):
    """Build the crop-classification inputs without temporal preprocessing.

    This is ``cropclass_preprocessed_inputs`` with ``preprocess`` forced to
    ``False``, so no dekad compositing or interpolation is applied. All other
    arguments are forwarded unchanged. Supplying ``preprocess`` yourself
    raises ``TypeError`` (duplicate keyword argument).
    """
    forced = {"preprocess": False}
    return cropclass_preprocessed_inputs(*args, **forced, **kwargs)
