import json
import pkg_resources
import openeo
from typing import Union
from openeo.api.process import Parameter
from openeo.processes import if_, ProcessBuilder
from openeo.rest.datacube import DataCube
from openeo.rest.connection import Connection

from cropsar_px.udf_gan import load_cropsar_px_udf
from cropsar_px.preprocessing_openeo import cropsar_pixel_inputs
from cropsar_px.udf_gan import WINDOW_SIZE


def get_connection() -> Connection:
    """
    Open a connection to the VITO openEO development back-end and log in with basic authentication.

    :return: the authenticated connection
    """
    backend_url = "https://openeo-dev.vito.be"
    unauthenticated = openeo.connect(backend_url)
    return unauthenticated.authenticate_basic()


def get_process_graph():
    """
    Load the UDP definition bundled with the ``cropsar_px`` package (``resources/udp.json``).

    :return: the parsed JSON document
    """
    # Use a context manager so the underlying file handle is closed after parsing
    # (the original left the stream returned by resource_stream open).
    with pkg_resources.resource_stream("cropsar_px", "resources/udp.json") as stream:
        return json.load(stream)


def post_process(cropsar_pixel_cube: DataCube, percentiles: Union[bool, Parameter]) -> Union[DataCube, ProcessBuilder]:
    """
    Label the bands of the CropSAR_px output and rescale the values from [-0.08, 1] to [0, 250].

    When ``percentiles`` is a plain bool, the band labels are chosen here in Python. When it is an
    openEO parameter, both labellings are built and an ``if_`` node lets the back-end pick one.

    :param cropsar_pixel_cube: data cube produced by the CropSAR_px UDF
    :param percentiles: True if the cube holds the Q10, Q50, Q90 percentiles, False for a single
        NDVI band; may also be an openEO parameter
    :return: the relabelled and rescaled cube
    """
    percentile_labels = ["Q10", "Q50", "Q90"]
    ndvi_labels = ["NDVI"]

    if not isinstance(percentiles, bool):
        # Value only known at execution time: defer the choice to the back-end.
        labelled = if_(
            percentiles,
            cropsar_pixel_cube.rename_labels("bands", percentile_labels),
            cropsar_pixel_cube.rename_labels("bands", ndvi_labels))
    else:
        chosen = percentile_labels if percentiles else ndvi_labels
        labelled = cropsar_pixel_cube.rename_labels("bands", chosen)

    return labelled.linear_scale_range(-0.08, 1, 0, 250)


def apply_udf(input_cube: DataCube, context: dict) -> DataCube:
    """
    Apply the CropSAR_px GAN UDF on the given data cube with the parameters in the context.

    The cube is processed in spatial chunks via ``apply_neighborhood``. Each chunk has ``chunk_size``
    core pixels plus ``overlap`` pixels on both sides, so every chunk handed to the UDF spans exactly
    ``WINDOW_SIZE`` pixels along x and y.

    :param input_cube: input data cube
    :param context: UDF parameters
    :return: the cube returned by ``apply_neighborhood``
    """
    gan_udf_code = load_cropsar_px_udf()

    # Integer pixel counts: true division ("/") produced floats such as 112.0, and fractional
    # pixel sizes whenever WINDOW_SIZE is not a multiple of 16. Deriving the core size from the
    # overlap guarantees chunk_size + 2 * overlap == WINDOW_SIZE.
    overlap = WINDOW_SIZE // 16
    chunk_size = WINDOW_SIZE - 2 * overlap

    cropsar_pixel_cube = input_cube.apply_neighborhood(
        lambda data: data.run_udf(udf=gan_udf_code, runtime='Python', context=context),
        size=[
            {'dimension': 'x', 'value': chunk_size, 'unit': 'px'},
            {'dimension': 'y', 'value': chunk_size, 'unit': 'px'}
        ],
        overlap=[
            {'dimension': 'x', 'value': overlap, 'unit': 'px'},
            {'dimension': 'y', 'value': overlap, 'unit': 'px'}
        ]
    )
    return cropsar_pixel_cube


def get_context(
        geo: Union[dict, Parameter],
        startdate: Union[str, Parameter],
        enddate: Union[str, Parameter],
        percentiles: Union[bool, Parameter]) -> dict:
    """
    Get the UDF context. Each value is handled on its own: openEO parameters become
    ``{"from_parameter": <name>}`` references and concrete values are embedded as-is, so any
    mix of parameters and concrete values is supported.

    :param geo: geometry (not part of the UDF context; accepted for symmetry with ``create_cube``)
    :param startdate: start date
    :param enddate: end date
    :param percentiles: get Q10, Q50, Q90 percentiles
    :return: dict with ``startdate``, ``enddate``, ``percentiles`` and ``gan_window_size``
    """
    def _context_value(value):
        # Previously, a single Parameter among the arguments made the code call `.name`
        # on every argument, which crashed on concrete values.
        if isinstance(value, Parameter):
            return {"from_parameter": value.name}
        return value

    context = {
        "startdate": _context_value(startdate),
        "enddate": _context_value(enddate),
        "percentiles": _context_value(percentiles)
    }
    # concrete context variables (not from parameters)
    context['gan_window_size'] = WINDOW_SIZE

    return context


def create_cube(
        connection: Connection,
        geo: Union[dict, Parameter],
        startdate: Union[str, Parameter],
        enddate: Union[str, Parameter],
        percentiles: Union[bool, Parameter]) -> Union[DataCube, ProcessBuilder]:
    """
    Build the full CropSAR_px processing chain: load the inputs, run the GAN UDF and post-process
    the result. Each of ``geo``, ``startdate``, ``enddate`` and ``percentiles`` may be either a
    concrete value or an openEO parameter.

    :param connection: OpenEO connection object
    :param geo: geometry
    :param startdate: requested start date
    :param enddate: requested end date
    :param percentiles: get Q10, Q50, Q90 percentiles
    :return: the post-processed CropSAR_px cube
    """
    inputs = cropsar_pixel_inputs(connection, geo, startdate, enddate)
    udf_context = get_context(geo, startdate, enddate, percentiles)
    return post_process(apply_udf(inputs, udf_context), percentiles)


if __name__ == "__main__":
    # Build the CropSAR_px UDP from openEO parameters and write its JSON definition to disk.
    param_geo = Parameter(
        name="geometry",
        description="Geometry as GeoJSON feature(s).",
        schema={
            "type": "object",
            "subtype": "geojson"
        }
    )

    param_startdate = Parameter.string(name="startdate")
    param_enddate = Parameter.string(name="enddate")

    param_percentiles = Parameter.boolean(name="percentiles", default=False)

    # Single source of truth for the UDP parameters. The process graph references all four
    # (including "percentiles" via the UDF context and the if_ in post_process), so all of them
    # must be declared; previously "percentiles" was missing from the spec.
    udp_parameters = [param_geo, param_startdate, param_enddate, param_percentiles]

    connection = get_connection()
    cropsar_pixel_cube = create_cube(connection, param_geo, param_startdate, param_enddate, param_percentiles)

    spec = {
        "id": "CropSAR_px",
        "summary": "CropSAR_px",
        "description": pkg_resources.resource_string("cropsar_px", "resources/udp_description.md").decode("utf-8"),
        "parameters": [param.to_dict() for param in udp_parameters],
        "process_graph": cropsar_pixel_cube.flat_graph()
    }

    # write the UDP to a file
    # NOTE(review): this path is relative to the current working directory, whereas
    # get_process_graph() reads the copy packaged in cropsar_px; confirm the script is run from
    # the package directory.
    with open("resources/udp.json", "w", encoding="utf-8") as f:
        json.dump(spec, f, indent=4)

    # To also save the UDP in the back-end for the current user:
    # connection.save_user_defined_process(spec["id"], cropsar_pixel_cube, udp_parameters)
