from openeo.rest.datacube import DataCube
from openeo.processes import date_shift, if_, eq, max
from openeo.internal.graph_building import PGNode
from openeo.rest.connection import Connection
from cropsar_px.udf_gan import WINDOW_SIZE


def cropsar_pixel_inputs(
        vito_connection: Connection,
        geo,
        start,
        end,
        gan_window_half: int = 80,
        s1_collection="SENTINEL1_GRD",
        s2_collection="SENTINEL2_L2A") -> DataCube:
    """
    Build the input data cube for the CropSAR_px model.

    The cube holds a cloud-masked Sentinel-2 NDVI band ("S2ndvi") merged with the
    Sentinel-1 VH and VV bands, which are resampled onto the NDVI cube's grid.
    For the supported Sentinel-1 collections, the S1 bands are converted to
    decibels (10 * log10).
    The requested period is widened by `gan_window_half` days on both sides.
    The result is spatially filtered on `geo` buffered by `WINDOW_SIZE * 10`.

    :param vito_connection: openEO connection object
    :param geo: geometry for spatial filter
    :param start: start date
    :param end: end date
    :param gan_window_half: half-width of the GAN time window, in days
    :param s1_collection: Sentinel-1 collection (SENTINEL1_GRD, S1_GRD_SIGMA0 or TERRASCOPE_S1_GAMMA0_V1).
        A non-string value is resolved inside the process graph with `if_`/`eq`
        (presumably a process parameter).
    :param s2_collection: Sentinel-2 collection (SENTINEL2_L2A or TERRASCOPE_S2_TOC_V2)
    :return: merged openEO datacube
    """

    def to_db(x):
        # Linear values to decibels: 10 * log10(x).
        return 10 * x.log(base=10)

    # Widen the requested period by the GAN half window on both sides.
    shifted_start = date_shift(start, -gan_window_half, "day")
    shifted_end = date_shift(end, gan_window_half, "day")

    def load_s1(collection_id):
        # Sentinel-1 VH/VV bands over the widened period.
        return vito_connection.load_collection(
            collection_id,
            bands=['VH', 'VV'],
            temporal_extent=[shifted_start, shifted_end]
        )

    # --- Sentinel-2: SCL-dilation mask, then NDVI as a single "S2ndvi" band.
    s2_cube = vito_connection.load_collection(
        s2_collection,
        bands=['B04', 'B08', 'SCL'],
        temporal_extent=[shifted_start, shifted_end]
    )
    s2_cube = s2_cube.process("mask_scl_dilation", data=s2_cube, scl_band_name="SCL")
    red, nir = s2_cube.band('B04'), s2_cube.band('B08')
    ndvi = ((nir - red) / (nir + red)).add_dimension("bands", "S2ndvi", type="bands")

    # --- Sentinel-1: VH/VV bands.
    if s1_collection == "S1_GRD_SIGMA0":
        # Combine ascending and descending orbits; overlaps resolved with openEO `max`.
        s1_cube = load_s1("S1_GRD_SIGMA0_ASCENDING").merge_cubes(
            load_s1("S1_GRD_SIGMA0_DESCENDING"), overlap_resolver=max)
    else:
        s1_cube = load_s1(s1_collection)

    # TODO: remove duplicate code when supported in openEO
    # https://github.com/Open-EO/openeo-python-driver/issues/109
    if not isinstance(s1_collection, str):
        # Collection id not known client-side: pick the branch inside the process graph.
        s1_cube = if_(
            eq(s1_collection, "SENTINEL1_GRD"),
            s1_cube.sar_backscatter().apply(to_db),
            s1_cube.apply(to_db)
        )
    elif s1_collection == "SENTINEL1_GRD":
        # SENTINEL1_GRD gets `sar_backscatter` before the dB conversion.
        s1_cube = s1_cube.sar_backscatter().apply(to_db)
    elif "SIGMA0" in s1_collection or s1_collection == "TERRASCOPE_S1_GAMMA0_V1":
        s1_cube = s1_cube.apply(to_db)

    # Put the S1 bands on the NDVI cube's grid and merge them with the NDVI band.
    s1_cube = s1_cube.resample_cube_spatial(ndvi)
    merged = ndvi.merge_cubes(s1_cube)

    # Spatial filter on `geo` buffered by WINDOW_SIZE * 10 (unit: the helper's default).
    return filter_spatial_with_buffer(merged, geo, WINDOW_SIZE * 10)


def filter_spatial_with_buffer(datacube: DataCube, geo, distance: int, unit: str = "meter") -> DataCube:
    """
    Apply a buffer to a geometry and use the buffered geometry as spatial filter.

    Custom method because `vector_buffer` is not properly supported in the openEO Python client:
    a `vector_buffer` node is built by hand and passed as the `geometries` argument
    of a `filter_spatial` node on `datacube`.

    :param datacube: openEO datacube to filter
    :param geo: geometry to be buffered and used for the spatial filter
        (anything the client's geometry-argument handling accepts for the listed GeoJSON types)
    :param distance: size of the buffer
    :param unit: unit in which the distance is measured
    :return: new datacube with a `filter_spatial` process appended
    """
    geo = datacube._get_geometry_argument(geo, valid_geojson_types=[
        "Point", "MultiPoint", "LineString", "MultiLineString",
        "Polygon", "MultiPolygon", "GeometryCollection", "FeatureCollection"
    ])

    buffer = PGNode(process_id="vector_buffer", arguments={
        "geometry": geo,
        "distance": distance,
        "unit": unit
    })
    # Build the `filter_spatial` node directly with the buffer node as its geometries.
    # This avoids calling `datacube.filter_spatial(geo)`, which re-validates the
    # already-normalized geometry only to have it discarded. It also avoids
    # overwriting the private `_pg.arguments` of the resulting node afterwards.
    return datacube.process("filter_spatial", data=datacube, geometries={"from_node": buffer})
