from openeo.rest.datacube import DataCube


def add_S1_bands(connection, S1_collection,
                 other_bands, bbox, start, end,
                 preprocess=True):
    """Load Sentinel-1 backscatter, align it with ``other_bands`` and merge.

    The ``VH`` and ``VV`` bands of ``S1_collection`` are loaded over ``bbox``
    and ``[start, end]`` and run through ``sar_backscatter`` (sigma0,
    ellipsoid). Optionally they are composited to dekads (mean) and
    gap-filled linearly along ``t``. They are then resampled onto the
    spatial grid of ``other_bands`` and merged into it.

    :param connection: openEO connection used to load the collection.
    :param S1_collection: collection id of the Sentinel-1 data.
    :param other_bands: data cube the S1 bands are merged into; also the
        target grid for spatial resampling.
    :param bbox: spatial extent passed to ``load_collection``.
    :param start: start of the temporal extent.
    :param end: end of the temporal extent.
    :param preprocess: when True, composite to dekads and interpolate gaps.
    :return: ``other_bands`` merged with the processed S1 cube.
    """
    sar = connection.load_collection(
        S1_collection,
        bands=['VH', 'VV'],
        spatial_extent=bbox,
        temporal_extent=[start, end],
    ).sar_backscatter(
        coefficient='sigma0-ellipsoid',
        local_incidence_angle=False,
        # DO NOT USE MAPZEN as elevation model here.
        # NOTE(review): 'strmgl1' looks like a misspelling of 'srtmgl1'
        # (SRTM GL1); confirm which value the backend accepts before changing.
        elevation_model='strmgl1',
    )

    if preprocess:
        # 10-daily (dekad) mean composites, then linear gap filling over time.
        sar = sar.aggregate_temporal_period(
            period="dekad", reducer="mean"
        ).apply_dimension(
            dimension="t", process="array_interpolate_linear"
        )

    # TODO: CONVERT HERE TO DB!

    # Put S1 on the same spatial grid as other_bands (the S2 resolution)
    # so the two cubes can be merged.
    sar = sar.resample_cube_spatial(other_bands)

    return other_bands.merge_cubes(sar)


def add_DEM(connection, DEM_collection, other_bands, bbox):
    """Load an elevation model, align it with ``other_bands`` and merge it in.

    The DEM is loaded over ``bbox`` only (no temporal extent). It is
    resampled with nearest neighbour onto the grid of ``other_bands``, and
    its time dimension is collapsed with ``max_time`` before merging.

    :param connection: openEO connection used to load the collection.
    :param DEM_collection: collection id of the elevation model.
    :param other_bands: data cube to merge into; also the resampling target.
    :param bbox: spatial extent passed to ``load_collection``.
    :return: ``other_bands`` merged with the DEM cube.
    """
    elevation = connection.load_collection(DEM_collection, spatial_extent=bbox)

    # An experimental backend feature flag was considered here to avoid
    # gaps near the bbox edges; with it only nearest-neighbour resampling
    # would be possible. It is currently not enabled.
    # TODO: check interpolation method
    # TODO: check no-data near edges of cube
    elevation = elevation.resample_cube_spatial(other_bands, method='near')

    # The collection carries timestamps that must be removed; reduce the
    # time dimension with a max.
    elevation = elevation.max_time()

    # TODO: casting to int (e.g. via linear_scale_range) has no effect yet.

    return other_bands.merge_cubes(elevation)


def add_worldcover(connection, WORLDCOVER_collection, other_bands, bbox,
                   temporal_extent=("2020-12-01", "2021-01-30")):
    """Load a WorldCover map, align it with ``other_bands`` and merge it in.

    Only the ``MAP`` band is loaded. It is resampled with nearest neighbour
    onto the grid of ``other_bands`` and its time dimension is collapsed
    with ``max_time`` before merging.

    :param connection: openEO connection used to load the collection.
    :param WORLDCOVER_collection: collection id, e.g.
        ``ESA_WORLDCOVER_10M_2020_V1``.
    :param other_bands: data cube to merge into; also the resampling target.
    :param bbox: spatial extent passed to ``load_collection``.
    :param temporal_extent: ``(start, end)`` window used to select the map.
        The default is the window previously hard-coded for
        ``ESA_WORLDCOVER_10M_2020_V1``. Pass another window when using a
        different WorldCover release. That release is presumably timestamped
        elsewhere; confirm against the collection metadata.
    :return: ``other_bands`` merged with the WorldCover cube.
    """
    worldcover = connection.load_collection(
        WORLDCOVER_collection,
        spatial_extent=bbox,
        # Always pass a list, exactly as the original hard-coded value did.
        temporal_extent=list(temporal_extent),
        bands=["MAP"],
    )

    # Nearest neighbour avoids blending values, which matters if MAP holds
    # categorical class codes (presumably the case for a land-cover map).
    worldcover = worldcover.resample_cube_spatial(other_bands, method='near')

    # The collection has timestamps which we need to get rid of.
    worldcover = worldcover.max_time()

    return other_bands.merge_cubes(worldcover)


def add_meteo(connection, METEO_collection, other_bands, bbox, start, end):
    """Load mean temperature from a meteo collection and merge it in.

    The ``temperature-mean`` band of ``METEO_collection`` (e.g. AGERA5) is
    loaded over ``bbox`` and ``[start, end]`` with the backend's
    ``experimental`` feature flag enabled. It is then resampled onto the
    grid of ``other_bands``, composited to dekads (mean), linearly
    interpolated along ``t`` and merged. Unlike the S2/S1 inputs, this
    compositing is applied unconditionally.

    :param connection: openEO connection used to load the collection.
    :param METEO_collection: collection id of the meteo data.
    :param other_bands: data cube to merge into; also the resampling target.
    :param bbox: spatial extent passed to ``load_collection``.
    :param start: start of the temporal extent.
    :param end: end of the temporal extent.
    :return: ``other_bands`` merged with the meteo cube.
    """
    temperature = connection.load_collection(
        METEO_collection,
        spatial_extent=bbox,
        bands=['temperature-mean'],
        temporal_extent=[start, end],
    )
    # Mutates the load node in place to enable the experimental flag.
    temperature.result_node().update_arguments(
        featureflags={'experimental': True})

    # Upsample onto the grid of other_bands (intended as nearest-neighbour
    # upsampling to 10 m/px; uses the client's default resampling method).
    temperature = temperature.resample_cube_spatial(other_bands)

    # Dekad mean composites, then linear gap filling over time. Gaps are
    # not expected in this dataset, but interpolating is good practice.
    temperature = (
        temperature
        .aggregate_temporal_period(period="dekad", reducer="mean")
        .apply_dimension(dimension="t", process="array_interpolate_linear")
    )

    return other_bands.merge_cubes(temperature)


def cropclass_preprocessed_inputs(
        connection, bbox, start, end,
        S2_collection='TERRASCOPE_S2_TOC_V2',
        S1_collection='SENTINEL1_GRD',
        DEM_collection='COPERNICUS_30',
        METEO_collection='AGERA5',
        WORLDCOVER_collection=None,
        preprocess=True) -> DataCube:
    """Build the merged input data cube for crop classification.

    Sentinel-2 optical bands are always loaded and SCL-masked. Meteo,
    Sentinel-1, DEM and WorldCover inputs are then merged in, in that order.
    Each is added only if its ``*_collection`` argument is not None.

    :param connection: openEO connection used to load all collections.
    :param bbox: spatial extent used for every ``load_collection`` call.
    :param start: start of the temporal extent (S2, meteo and S1).
    :param end: end of the temporal extent (S2, meteo and S1).
    :param S2_collection: optical collection id.
    :param S1_collection: SAR collection id, or None to skip SAR.
    :param DEM_collection: DEM collection id, or None to skip the DEM.
    :param METEO_collection: meteo collection id, or None to skip meteo.
    :param WORLDCOVER_collection: WorldCover collection id, or None (the
        default) to skip it.
    :param preprocess: when True, composite optical (median) and SAR (mean)
        data to dekads and linearly interpolate gaps over time. Meteo data
        is composited regardless of this flag.
    :return: the merged data cube.
    """
    # --------------------------------------------------------------------
    # Optical data
    # --------------------------------------------------------------------

    bands = connection.load_collection(
        S2_collection,
        bands=["B02", "B03", "B04", "B05",
               "B06", "B07", "B08", "B11",
               "B12", "SCL"],
        spatial_extent=bbox,
        temporal_extent=[start, end]
    )

    # These settings force openEO to read a full time series at a time,
    # which allows for some performance optimizations.
    temporal_partition_options = {
        "indexreduction": 0,
        "temporalresolution": "None",
        "tilesize": 256
    }
    bands.result_node().update_arguments(featureflags=temporal_partition_options)

    # Apply a conservative SCL-based mask, then drop the SCL band itself.
    # Filtering by name (rather than assuming SCL is the last band) keeps
    # this correct if the band order ever changes.
    # TODO: add cloud masking parameters
    optical_bands = [name for name in bands.metadata.band_names
                     if name != "SCL"]
    bands = bands.process("mask_scl_dilation",
                          data=bands,
                          scl_band_name="SCL").filter_bands(optical_bands)

    if preprocess:
        # Composite 10-daily
        bands = bands.aggregate_temporal_period(period="dekad",
                                                reducer="median")

        # Linearly interpolate missing values
        bands = bands.apply_dimension(dimension="t",
                                      process="array_interpolate_linear")

    # --------------------------------------------------------------------
    # AGERA5 Meteo data
    # --------------------------------------------------------------------
    if METEO_collection is not None:
        bands = add_meteo(connection, METEO_collection, bands, bbox, start, end)

    # --------------------------------------------------------------------
    # SAR data
    # --------------------------------------------------------------------
    if S1_collection is not None:
        bands = add_S1_bands(connection, S1_collection,
                             bands, bbox, start, end, preprocess=preprocess)

    # --------------------------------------------------------------------
    # DEM data
    # --------------------------------------------------------------------
    if DEM_collection is not None:
        bands = add_DEM(connection, DEM_collection, bands, bbox)

    # --------------------------------------------------------------------
    # WorldCover data
    # --------------------------------------------------------------------
    if WORLDCOVER_collection is not None:
        bands = add_worldcover(connection, WORLDCOVER_collection, bands, bbox)

    return bands


def cropclass_raw_inputs(*args, **kwargs):
    """Build the crop-classification input cube without temporal preprocessing.

    Thin wrapper around ``cropclass_preprocessed_inputs`` that forwards all
    positional and keyword arguments and forces ``preprocess=False``.
    Supplying ``preprocess`` yourself (by keyword or positionally) raises a
    ``TypeError`` for the duplicate argument.
    """
    return cropclass_preprocessed_inputs(*args, preprocess=False, **kwargs)
