import json
import pkg_resources
import openeo
from typing import Union, Optional
from openeo.api.process import Parameter
from openeo.processes import if_, ProcessBuilder
from openeo.rest.datacube import DataCube
from openeo.rest.connection import Connection

from cropsar_px.udf_gan import load_cropsar_px_udf
from cropsar_px.preprocessing_openeo import cropsar_pixel_inputs
from cropsar_px.udf_gan import WINDOW_SIZE


def get_connection(url: str = "https://openeo.vito.be") -> Connection:
    """
    Connect to an openEO back-end and authenticate with HTTP basic authentication.

    :param url: back-end URL (defaults to the VITO openEO back-end)
    :return: the authenticated connection
    """
    return openeo.connect(url).authenticate_basic()


def get_process_graph():
    """
    Load the UDP specification bundled as ``resources/udp.json`` in the ``cropsar_px`` package.

    :return: the parsed JSON document (presumably the spec generated by the ``__main__`` block of this module)
    """
    # resource_string returns the whole resource as bytes, so no stream handle is left open;
    # json.loads accepts UTF-encoded bytes directly.
    return json.loads(pkg_resources.resource_string("cropsar_px", "resources/udp.json"))


def post_process(
        cropsar_pixel_cube: DataCube,
        geo,
        output_mask: Union[bool, Parameter]) -> Union[DataCube, ProcessBuilder]:
    """
    Clip the CropSAR_px output to the geometry, label its bands and rescale them to the 0..250 byte range.

    Both possible outputs are built, and an openEO ``if`` process selects between them based on `output_mask`:
    a single "NDVI" band, or "NDVI" plus a "mask" band.

    :param cropsar_pixel_cube: data cube produced by the CropSAR_px UDF
    :param geo: geometry to clip the cube to
    :param output_mask: whether the ground truth mask band is included in the output
    :return: the post-processed cube
    """
    clipped = cropsar_pixel_cube.filter_spatial(geo)

    # Reject branch (`output_mask` falsy): NDVI only, mapped from [-0.08, 1] to [0, 250].
    ndvi_only = clipped.rename_labels("bands", ["NDVI"]).linear_scale_range(-0.08, 1, 0, 250)

    # Accept branch (`output_mask` truthy): the NDVI is scaled as above; the mask keeps its values,
    # since the range 0..250 maps onto itself.
    labelled = clipped.rename_labels("bands", ["NDVI", "mask"])
    scaled_ndvi = labelled.band("NDVI").linear_scale_range(-0.08, 1, 0, 250)
    scaled_mask = labelled.band("mask").linear_scale_range(0, 250, 0, 250)
    ndvi_and_mask = scaled_ndvi.merge_cubes(scaled_mask)

    # The branch is chosen in the process graph itself, so `output_mask` may be an openEO parameter.
    return clipped.process("if", value=output_mask, accept=ndvi_and_mask, reject=ndvi_only)


def apply_udf(input_cube: DataCube, context: dict) -> DataCube:
    """
    Run the CropSAR_px GAN UDF over the input cube in overlapping spatial tiles.

    :param input_cube: input data cube
    :param context: UDF parameters, forwarded as the UDF context
    :return: the data cube produced by the UDF
    """
    udf_source = load_cropsar_px_udf()

    def _run_gan(data):
        return data.run_udf(udf=udf_source, runtime='Python', context=context)

    # Each tile core spans 7/8 of the GAN window. With 1/16 window of overlap on each side,
    # core + 2 * overlap equals WINDOW_SIZE.
    tile_core = WINDOW_SIZE - WINDOW_SIZE / 8
    tile_overlap = WINDOW_SIZE / 16

    return input_cube.apply_neighborhood(
        _run_gan,
        size=[{'dimension': axis, 'value': tile_core, 'unit': 'px'} for axis in ('x', 'y')],
        overlap=[{'dimension': axis, 'value': tile_overlap, 'unit': 'px'} for axis in ('x', 'y')]
    )


def get_context(
        geo: Union[dict, Parameter],
        startdate: Union[str, Parameter],
        enddate: Union[str, Parameter],
        inpaint_only: Union[bool, Parameter],
        output_mask: Union[bool, Parameter],
        nrt_mode: Union[bool, Parameter],
        drop_dates: Union[Optional[list], Parameter],
        **extra_context) -> dict:
    """
    Build the UDF context dictionary.

    Every value that is an openEO ``Parameter`` is stored as a ``{"from_parameter": <name>}`` reference;
    concrete values are stored as-is. Each argument is resolved on its own, so concrete values and
    parameters can be mixed freely. An example is a parameterised geometry combined with the
    default ``inpaint_only=True`` and ``drop_dates=None``.

    :param geo: geometry; it does not end up in the context (accepted for interface compatibility)
    :param startdate: start date
    :param enddate: end date
    :param inpaint_only: whether only inpainting is used instead of predicting the full output
    :param output_mask: indicates whether to output ground truth mask
    :param nrt_mode: whether only prior information is used (NRT mode)
    :param drop_dates: dates of Sentinel-2 acquisitions to drop from the input
    :param extra_context: additional entries, merged in last (they override the keys above)
    :return: the UDF context
    """
    def _resolve(value):
        # Parameters become back-end references, resolved when the process graph runs.
        if isinstance(value, Parameter):
            return {"from_parameter": value.name}
        return value

    context = {
        "startdate": _resolve(startdate),
        "enddate": _resolve(enddate),
        "inpaint_only": _resolve(inpaint_only),
        "output_mask": _resolve(output_mask),
        "nrt_mode": _resolve(nrt_mode),
        "drop_dates": _resolve(drop_dates),
    }
    # concrete context variables (not from parameters)
    context['gan_window_size'] = WINDOW_SIZE

    # **extra_context is always a dict (possibly empty), so it can be merged unconditionally.
    context.update(extra_context)

    return context


def create_cube(
        connection: Connection,
        geometry: Union[dict, Parameter],
        startdate: Union[str, Parameter],
        enddate: Union[str, Parameter],
        s1_collection: Union[str, Parameter] = "SENTINEL1_GRD",
        s2_collection: Union[str, Parameter] = "SENTINEL2_L2A",
        inpaint_only: Union[bool, Parameter] = True,
        output_mask: Union[bool, Parameter] = False,
        nrt_mode: Union[bool, Parameter] = False,
        drop_dates: Union[Optional[list], Parameter] = None,
        **extra_context
) -> Union[DataCube, ProcessBuilder]:
    """
    Assemble the CropSAR_px processing chain in four steps:
    1. load the Sentinel-1/Sentinel-2 inputs;
    2. build the UDF context;
    3. run the GAN UDF;
    4. post-process the result.

    The type hints allow each argument after `connection` to be either a concrete value or an openEO parameter.

    :param connection: OpenEO connection object
    :param geometry: geometry
    :param startdate: requested start date
    :param enddate: requested end date
    :param s1_collection: Sentinel-1 collection
    :param s2_collection: Sentinel-2 collection
    :param inpaint_only: indicates whether only to use inpainting or predict the full output
    :param output_mask: indicates whether to output a ground truth mask
    :param nrt_mode: only use prior information (NRT mode)
    :param drop_dates: drop Sentinel-2 acquisitions in input for dates in list
    :param extra_context: additional UDF context entries
    :return: the resulting CropSAR_px cube
    """
    inputs = cropsar_pixel_inputs(
        connection, geometry, startdate, enddate,
        s1_collection=s1_collection, s2_collection=s2_collection,
    )

    udf_context = get_context(
        geo=geometry,
        startdate=startdate,
        enddate=enddate,
        inpaint_only=inpaint_only,
        output_mask=output_mask,
        nrt_mode=nrt_mode,
        drop_dates=drop_dates,
        **extra_context
    )

    gan_output = apply_udf(inputs, udf_context)
    return post_process(gan_output, geometry, output_mask)


if __name__ == "__main__":
    # Build the CropSAR_px user-defined process (UDP), exposing its inputs as openEO parameters,
    # and write the resulting JSON specification to disk.

    # mandatory parameters
    param_geo = Parameter(
        name="geometry",
        description="Geometry as GeoJSON feature(s).",
        schema={
            "type": "object",
            "subtype": "geojson"
        }
    )

    param_startdate = Parameter.string(name="startdate", description="start of the temporal interval")
    param_enddate = Parameter.string(name="enddate", description="end of the temporal interval")

    # optional parameters
    param_nrt = Parameter.boolean(name="nrt", default=False, description="only use prior information (NRT mode)")
    param_inpaint_only = Parameter.boolean(
        name="inpaint_only", default=True,
        description="indicates whether to only use inpainting or predict the whole output")
    param_output_mask = Parameter.boolean(
        name="output_mask", default=False,
        description="indicates whether to output a ground truth mask")
    param_drop_dates = Parameter.array(
        name="drop_dates", default=None,
        description="list of dates for which Sentinel-2 acquisitions will be dropped in the input")

    # Single source of truth for the UDP parameters (order is the order they appear in the spec).
    udp_parameters = [
        param_geo,
        param_startdate,
        param_enddate,
        param_nrt,
        param_inpaint_only,
        param_output_mask,
        param_drop_dates,
    ]

    connection = get_connection()
    cropsar_pixel_cube = create_cube(
        connection,
        param_geo,
        param_startdate,
        param_enddate,
        inpaint_only=param_inpaint_only,
        output_mask=param_output_mask,
        nrt_mode=param_nrt,
        drop_dates=param_drop_dates,
    )

    spec = {
        "id": "CropSAR_px",
        "summary": "CropSAR_px",
        "description": pkg_resources.resource_string("cropsar_px", "resources/udp_description.md").decode("utf-8"),
        "parameters": [param.to_dict() for param in udp_parameters],
        "process_graph": cropsar_pixel_cube.flat_graph()
    }

    # Write the UDP to a file. The path is relative to the current working directory.
    # NOTE(review): get_process_graph() reads the packaged cropsar_px resource instead, so run this
    # from the directory that contains the package's `resources/` folder to keep the two in sync.
    with open("resources/udp.json", "w", encoding="utf-8") as f:
        json.dump(spec, f, indent=4)

    # To also save the UDP in the back-end for the current user, enable:
    # connection.save_user_defined_process(spec["id"], cropsar_pixel_cube, parameters=udp_parameters)
