from openeo.rest.datacube import DataCube
from cropclass.openeo.masking import scl_mask_erode_dilate


def delineation_preprocessed_inputs(
        connection, bbox, start, end,
        S2_collection='TERRASCOPE_S2_TOC_V2',
        preprocess=True,
        **processing_options
) -> DataCube:
    """Build the Sentinel-2 input cube used for delineation.

    Loads bands B02, B03, B04 and B08 from ``S2_collection`` over ``bbox``
    and ``[start, end]``, optionally resamples them, applies a cloud mask,
    optionally composites/interpolates along time, and finally clips the
    values into the 0-65534 range.

    Args:
        connection: openEO connection used to load the collection.
        bbox: spatial extent passed to ``load_collection``.
        start: start of the temporal extent (also used for the final
            ``filter_temporal``).
        end: end of the temporal extent (also used for the final
            ``filter_temporal``).
        S2_collection: id of the Sentinel-2 collection to load; its
            ``SCL`` band is also used by the default masking path.
        preprocess: when true, aggregate to dekads with a median reducer
            and linearly interpolate missing values along ``t``.
        **processing_options: optional settings. Only these keys are read:
            ``masking`` -- ``"mask_scl_dilation"`` selects the backend
            ``mask_scl_dilation`` process; any other value (or absence)
            applies the mask from ``scl_mask_erode_dilate``.
            ``target_crs`` -- if not None, ``resample_spatial`` to this
            projection with ``resolution=10.0``.

    Returns:
        The resulting openEO ``DataCube``.
    """

    # TODO: ADD SAR (e.g. S1_collection='S1_GRD_SIGMA0_DESCENDING')
    # TODO: ADD METEO

    # Evaluate the masking choice once so the band selection below and the
    # masking branch further down always agree.
    use_scl_dilation = (
        processing_options.get("masking", None) == "mask_scl_dilation")

    # --------------------------------------------------------------------
    # Optical data
    # --------------------------------------------------------------------

    S2_bands = ["B02", "B03", "B04", "B08"]
    # The mask_scl_dilation process reads SCL from the cube itself, so it
    # has to be loaded alongside the optical bands (and is dropped later).
    load_bands = S2_bands + (["SCL"] if use_scl_dilation else [])
    bands = connection.load_collection(
        S2_collection,
        bands=load_bands,
        spatial_extent=bbox,
        temporal_extent=[start, end],
        max_cloud_cover=95
    )

    # Backend feature flags attached directly to the load_collection node.
    temporal_partition_options = {
        "indexreduction": 2,
        "temporalresolution": "ByDay",
        "tilesize": 512
    }
    bands.result_node().update_arguments(
        featureflags=temporal_partition_options)

    target_crs = processing_options.get("target_crs", None)
    if target_crs is not None:
        bands = bands.resample_spatial(projection=target_crs, resolution=10.0)

    if use_scl_dilation:

        # TODO: add cloud masking parameters
        bands = bands.process(
            "mask_scl_dilation",
            data=bands,
            scl_band_name="SCL", kernel1_size=17, kernel2_size=77,
            mask1_values=[2, 4, 5, 6, 7], mask2_values=[3, 8, 9, 10, 11],
            erosion_kernel_size=3,
        ).filter_bands(S2_bands)  # drop SCL by name, not by position
    else:
        # Apply conservative mask and select bands
        mask = scl_mask_erode_dilate(
            connection,
            bbox,
            scl_layer_band=S2_collection + ':SCL').resample_cube_spatial(bands)
        bands = bands.mask(mask)

    if preprocess:
        # Composite 10-daily
        bands = bands.aggregate_temporal_period(period="dekad",
                                                reducer="median")

        # Linearly interpolate missing values
        bands = bands.apply_dimension(dimension="t",
                                      process="array_interpolate_linear")

    # Dekad composites may extend past the requested window; clip back to it.
    bands = bands.filter_temporal(start, end)

    # Force UINT16 to avoid overflow issue with S2 data. With equal input and
    # output ranges this only clips values to [0, 65534]; being idempotent,
    # a single application at the end is sufficient.
    bands = bands.linear_scale_range(0, 65534, 0, 65534)

    return bands
