


import openeo
import json
from openeo.processes import eq
from openeo.internal.graph_building import PGNode
from nextland_services.helper_functions import preprocessing_datacube_for_udf, fwrite
from nextland_services.constants import *
from openeo.rest.udp import build_process_dict
import os
from nextland_services.helpers import load_markdown


# Generate result paths.
# `Path` and `RESOURCES` presumably come from the
# `nextland_services.constants` wildcard import (they are not imported
# explicitly in this file).
_THIS_FILE = Path(__file__)
BASEPATH = _THIS_FILE.parent
# Per-script output folder next to this file, named "<script name>.output".
OUTPUT_DIR = _THIS_FILE.with_suffix('.output')
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
# Folder into which the generated UDP process graph JSON is written.
OUTPUT_DIR_GRAPH = RESOURCES


# General approach:
#
# 1. Merge all required inputs into a single multiband raster datacube.
# 2. Compute timeseries for one or more fields, for all bands (in one step).
# 3. Post-process the timeseries (in a UDF):
#    - compute CropSAR based on cleaned fAPAR + sigma0,
#    - combine CropSAR + S1 to determine the crop calendar,
#    - return the crop calendar output in a custom JSON format.

class Cropcalendars:
    """Build and run the openEO crop-calendar (harvest detection) workflow.

    The input datacube merges Sentinel-1 ascending and descending sigma0
    backscatter (VH, VV, local incidence angle) with a Sentinel-2 derived
    fAPAR band. The cube is then post-processed by the
    ``crop_calendar_udf.py`` UDF, either on the backend (``run_udf``) or
    locally.
    """

    # File name (under OUTPUT_DIR) of the cached input timeseries used by
    # ``generate_cropcalendars_local``.
    _LOCAL_TIMESERIES_CACHE = "Harvest_inputs_poly_ref_fields_inwbuf_dateshift.json"

    def __init__(self, connection=None):
        """Set the UDF parameters and the openEO connection.

        :param connection: an existing openEO connection. When None, a new
            connection to https://openeo.vito.be/openeo/1.0.0 is opened and
            authenticated with basic auth.
        """
        # crop calendar independent variables (forwarded to the UDF context)
        self.fAPAR_rescale_Openeo = 0.005
        self.path_harvest_model = r'/data/users/Public/bontek/e_shape/model/model_update1.0_iteration24.h5'
        self.VH_VV_range_normalization = [-13,-3.5]
        self.fAPAR_range_normalization = [0,1]
        # NOTE(review): this order appears to mirror the band order produced
        # by get_bands (ascending VH/VV/angle, descending VH/VV/angle, fAPAR)
        # - confirm the UDF relies on it.
        self.metrics_order = ['sigma_ascending_VH', 'sigma_ascending_VV','sigma_ascending_angle','sigma_descending_VH', 'sigma_descending_VV','sigma_descending_angle', 'fAPAR']
        self.window_values = 5
        self.thr_detection = 0.75
        self.crop_calendar_event = 'Harvest'
        self.metrics_crop_event = ['cropSAR', 'VH_VV_{}']
        self.max_gap_prediction = 24
        self.shub = True
        self.index_window_above_thr = 2

        # assuming UDF-s packaged in the same folder as driver
        basedir = Path(__file__).parent
        self.fapar_udf_path = basedir / 'biopar_udf.py'
        self.crop_calendar_udf_path = basedir / 'crop_calendar_udf.py'

        # openeo connection
        if connection is None:
            # NOTE(review): hard-coded basic-auth credentials; consider
            # reading them from the environment or switching to OIDC.
            self._eoconn = openeo \
                .connect('https://openeo.vito.be/openeo/1.0.0') \
                .authenticate_basic('bontek', 'bontek123')
        else:
            self._eoconn = connection


    #####################################################
    ################# FUNCTIONS #########################
    #####################################################

    def get_resource(self, relative_path):
        """Return ``relative_path`` normalised to a string path."""
        return str(Path(relative_path))

    def load_udf(self, relative_path):
        """Read and return the source code of the UDF file at ``relative_path``."""
        with open(self.get_resource(relative_path), 'r', encoding="utf8") as f:
            return f.read()

    def _build_udf_context(self, gjson, date):
        """Return the context dictionary passed to the crop-calendar UDF.

        :param gjson: the field polygon(s), or an openEO parameter reference
            such as ``{"from_parameter": "polygon"}``.
        :param date: the temporal range, or an openEO parameter reference
            such as ``{"from_parameter": "date"}``.
        """
        # Key order is kept identical to the original inline dicts so the
        # serialized process graph does not change.
        return {
            'window_values': self.window_values,
            'thr_detection': self.thr_detection,
            'crop_calendar_event': self.crop_calendar_event,
            'metrics_crop_event': self.metrics_crop_event,
            'VH_VV_range_normalization': self.VH_VV_range_normalization,
            'fAPAR_range_normalization': self.fAPAR_range_normalization,
            'fAPAR_rescale_Openeo': self.fAPAR_rescale_Openeo,
            'index_window_above_thr': self.index_window_above_thr,
            'metrics_order': self.metrics_order,
            'path_harvest_model': self.path_harvest_model,
            'shub': self.shub,
            'max_gap_prediction': self.max_gap_prediction,
            'gjson': gjson,
            'date': date,
        }

    def get_bands(self, time_range, geo, shub=True):
        """Build the merged multiband input datacube.

        Bands, in merge order: VH, VV, angle (S1 ascending), VH_2, VV_2,
        angle_2 (S1 descending, resampled onto the ascending grid) and
        band_0 (fAPAR computed by the biopar UDF on SCL-masked
        SENTINEL2_L2A_SENTINELHUB data).

        :param time_range: temporal extent passed on to
            ``preprocessing_datacube_for_udf``.
        :param geo: geometry passed on to ``preprocessing_datacube_for_udf``.
        :param shub: not used in this method body.
        :return: the preprocessed datacube.
        """
        sigma_ascending = self._eoconn.load_collection('SENTINEL1_GRD', bands=['VH', 'VV'], properties={
            "orbitDirection": lambda od: eq(od, "ASCENDING")})
        sigma_ascending = sigma_ascending.sar_backscatter(coefficient="sigma0-ellipsoid",
                                                          local_incidence_angle=True)
        sigma_ascending = sigma_ascending.rename_labels("bands", target=["VH", "VV", "angle"])
        sigma_descending = self._eoconn.load_collection('SENTINEL1_GRD', bands=['VH', 'VV'], properties={
            "orbitDirection": lambda od: eq(od, "DESCENDING")})
        sigma_descending = sigma_descending.sar_backscatter(coefficient="sigma0-ellipsoid",
                                                            local_incidence_angle=True).resample_cube_spatial(
            sigma_ascending)
        # Distinct labels so the descending bands do not collide with the
        # ascending ones when the cubes are merged.
        sigma_descending = sigma_descending.rename_labels("bands", target=["VH_2", "VV_2", "angle_2"])

        S2_bands = self._eoconn.load_collection('SENTINEL2_L2A_SENTINELHUB',
                                                bands=["B03", "B04", "B08", "sunAzimuthAngles", "sunZenithAngles",
                                                       "viewAzimuthMean", "viewZenithMean", "SCL"])
        S2_bands_mask = S2_bands.process("mask_scl_dilation", data=S2_bands,
                                         scl_band_name="SCL")

        S2_bands_mask = S2_bands_mask.resample_cube_spatial(sigma_ascending)
        udf = self.load_udf(self.fapar_udf_path)
        fapar_masked = S2_bands_mask.reduce_dimension(dimension="bands", reducer=PGNode(
            process_id="run_udf",
            data={"from_parameter": "data"},
            udf=udf,
            runtime="Python",
            context={'biopar': 'FAPAR'}
        ))
        # The band reduction removed the "bands" dimension; re-add it so the
        # fAPAR cube can be merged with the S1 cubes.
        fapar_masked = fapar_masked.add_dimension('bands', label='band_0', type='bands')
        all_bands = sigma_ascending.merge_cubes(sigma_descending).merge_cubes(fapar_masked)

        ### apply inwards buffering and date shift on the data
        all_bands = preprocessing_datacube_for_udf(all_bands, time_range, geo, bbox = False)

        return all_bands



    ##### FUNCTION TO BUILD PROCESS GRAPH NEEDED FOR HARVEST PREDICTIONS
    def generate_cropcalendars_workflow(self, time_range, gjson_path, run_local=False, create_pg = False):
        """Build the crop-calendar workflow cube.

        :param time_range: temporal extent of the input data.
        :param gjson_path: field polygon(s) used for the input data.
        :param run_local: if True, return only the input datacube (no UDF
            post-processing), e.g. to execute it and run the UDF locally.
        :param create_pg: if True, the UDF context refers to the UDP
            parameters ``polygon`` and ``date`` instead of concrete values.
        :return: the datacube with ``run_udf`` applied (or the input cube
            when ``run_local`` is True).
        """
        # get the datacube containing the time series data
        timeseries = self.get_bands(time_range, gjson_path,self.shub)

        ##### POST PROCESSING TIMESERIES USING A UDF
        if run_local:
            return timeseries

        udf = self.load_udf(self.crop_calendar_udf_path)

        # Default parameters are ingested in the UDF. For a UDP, point the
        # context at the process parameters so any polygon/date can be used.
        if create_pg:
            gjson_path = {"from_parameter": "polygon"}
            time_range = {"from_parameter": "date"}

        context_to_udf = self._build_udf_context(gjson_path, time_range)

        crop_calendars_graph = timeseries.process("run_udf", data=timeseries._pg, udf=udf, runtime='Python',
                                                  context=context_to_udf)
        return crop_calendars_graph

    def generate_cropcalendars(self, time_range, gjson_path, create_pg = False):
        """Return the crop-calendar result, or the workflow itself.

        :param create_pg: if True, return the parameterized workflow cube
            without executing it; otherwise submit it as a batch job, wait
            for completion and return the JSON result.
        """
        workflow = self.generate_cropcalendars_workflow(time_range, gjson_path, create_pg = create_pg)
        if create_pg:
            return workflow

        crop_calendars = workflow.send_job().start_and_wait().get_result().load_json()

        return crop_calendars

    def generate_cropcalendars_local(self, time_range, gjson_path):
        """Run the crop-calendar UDF locally on (cached) backend timeseries.

        The input timeseries are fetched from the backend only when the JSON
        cache file under OUTPUT_DIR does not exist yet; afterwards the cache
        is reused.

        :return: whatever ``udf_cropcalendars`` returns.
        """
        cache_path = OUTPUT_DIR / self._LOCAL_TIMESERIES_CACHE
        if not cache_path.exists():
            timeseries = self.generate_cropcalendars_workflow(time_range, gjson_path, run_local=True)
            timeseries = timeseries.execute()
            with open(cache_path, 'w', encoding='utf-8') as json_file:
                json.dump(timeseries, json_file)
        with open(cache_path, 'r', encoding='utf-8') as f:
            ts = json.load(f)

        # Local-only dependencies, imported lazily so the backend workflow
        # does not require them.
        from openeo_udf.api.udf_data import UdfData
        from openeo_udf.api.structured_data import StructuredData
        from crop_calendar_udf import udf_cropcalendars

        context_to_udf = self._build_udf_context(gjson_path, time_range)

        # NOTE(review): proj is given as a set literal ({"EPSG:4326"});
        # confirm this matches what UdfData expects.
        udfdata = UdfData({"EPSG:4326"},
                          structured_data_list=[StructuredData(description='timeseries input', data=ts, type='dict')])
        udfdata.user_context = context_to_udf
        crop_calendars_df = udf_cropcalendars(udfdata)
        return crop_calendars_df


def load_timeseries_data():
    """Execute the input-datacube workflow for the 2019 reference fields and cache the result.

    The reference polygons are read from
    ``tests/resources/ref_fields_harvest_detector.geojson``, resolved relative to
    ``Path(__file__).parent.parent.parent``. The result of executing the input cube
    (no crop-calendar UDF applied) is written as JSON into OUTPUT_DIR.
    """
    resources_dir = Path(__file__).parent.parent.parent / 'tests' / 'resources'
    with open(resources_dir / 'ref_fields_harvest_detector.geojson', 'r', encoding='utf-8') as geojson_file:
        polygons = json.load(geojson_file)

    time_range = ('2019-01-01', '2019-12-31')

    # instantiate the cropcalendar class
    calendars = Cropcalendars()
    # TODO check if default udf path now works within unittests as well
    input_cube = calendars.generate_cropcalendars_workflow(time_range, polygons, run_local=True)
    timeseries_input_poly = input_cube.execute()  # .send_job().start_and_wait().get_result().load_json()

    cache_path = OUTPUT_DIR / "Harvest_inputs_poly_ref_fields_inwbuf_dateshift.json"
    with open(cache_path, 'w') as out_file:
        json.dump(timeseries_input_poly, out_file)


def debug_cropcalendar_udf():
    """Run the crop-calendar UDF locally on the 2019 reference fields and return its output."""
    # instantiate the cropcalendar class
    calendars = Cropcalendars()
    time_range = ('2019-01-01', '2019-12-31')
    ref_fields_path = (Path(__file__).parent.parent.parent
                       / 'tests' / 'resources' / 'ref_fields_harvest_detector.geojson')

    with open(ref_fields_path, 'r', encoding='utf-8') as geojson_file:
        polygons = json.load(geojson_file)

    return calendars.generate_cropcalendars_local(time_range, polygons)


def run_cropcalendar_udf():
    """Run the crop-calendar workflow as a backend batch job on the 2019 reference fields."""
    # instantiate the cropcalendar class
    calendars = Cropcalendars()
    time_range = ('2019-01-01', '2019-12-31')
    ref_fields_path = (Path(__file__).parent.parent.parent
                       / 'tests' / 'resources' / 'ref_fields_harvest_detector.geojson')

    with open(ref_fields_path, 'r', encoding='utf-8') as geojson_file:
        polygons = json.load(geojson_file)

    return calendars.generate_cropcalendars(time_range, polygons)

def crop_calendar_build_graph_poly_and_store_udp():
    """Build the parameterized CropCalendar process graph and write it to OUTPUT_DIR_GRAPH.

    The graph takes two UDP parameters, ``date`` and ``polygon``. It is
    written as ``CropCalendar.json`` together with the description loaded
    from ``crop_calendar.md``.
    """
    # NOTE: the previous version opened (and authenticated) a connection to
    # openeo-dev here that was never used. Cropcalendars() opens its own
    # connection to openeo.vito.be; pass a connection explicitly if the dev
    # backend is intended.

    #Define the service parameters
    time_range = date_parameter("Left-closed temporal interval, i.e. an array with exactly two elements:\n\n1. The first element is the start of the temporal interval. The specified instance in time is **included** in the interval.\n2. The second element is the end of the temporal interval. The specified instance in time is **excluded** from the interval.\n\nThe specified temporal strings follow [RFC 3339](https://www.rfc-editor.org/rfc/rfc3339.html). Also supports open intervals by setting one of the boundaries to `null`, but never both.")

    polygon = polygon_param(description="A geojson polygon on which to compute the metric timeseries.")

    cp = Cropcalendars()
    # With create_pg=True this is the (unexecuted) workflow cube whose UDF
    # context refers to the "polygon" and "date" parameters.
    process_graph = cp.generate_cropcalendars(time_range, polygon, create_pg=True)
    process_id = "CropCalendar"

    #Build service dict
    cropcalendar_dict = build_process_dict(
        process_id=process_id,
        description=load_markdown("crop_calendar.md"),
        process_graph=process_graph,
        parameters=[time_range, polygon]
    )

    #Write service graph to json file
    fwrite(os.path.join(OUTPUT_DIR_GRAPH, '{}.json'.format(process_id)), json.dumps(cropcalendar_dict, indent = 4))

def save_udp():
    """Publish the stored CropCalendar process graph as a public user-defined process.

    Reads ``CropCalendar.json`` from OUTPUT_DIR_GRAPH and saves its
    ``process_graph`` on the openeo-dev backend.

    :return: the object returned by ``save_user_defined_process``. The
        original version discarded it and returned None.
    """
    eoconn = openeo.connect("https://openeo-dev.vito.be/openeo/1.0").authenticate_basic('bontek', 'bontek123')
    # Define the service parameters
    time_range = date_parameter(
        "Left-closed temporal interval, i.e. an array with exactly two elements:\n\n1. The first element is the start of the temporal interval. The specified instance in time is **included** in the interval.\n2. The second element is the end of the temporal interval. The specified instance in time is **excluded** from the interval.\n\nThe specified temporal strings follow [RFC 3339](https://www.rfc-editor.org/rfc/rfc3339.html). Also supports open intervals by setting one of the boundaries to `null`, but never both.")

    polygon = polygon_param(description="A geojson polygon on which to compute the metric timeseries.")

    process_id = "CropCalendar"

    with open(os.path.join(OUTPUT_DIR_GRAPH, '{}.json'.format(process_id)), 'r', encoding='utf-8') as file:
        graph = json.load(file)
    udp = eoconn.save_user_defined_process(
        process_id,
        graph["process_graph"],
        description=load_markdown("crop_calendar.md"),
        parameters=[time_range, polygon], public=True)
    return udp


def run_udp():
    """Execute the published CropCalendar UDP synchronously on the 2019 reference fields."""
    connection = openeo.connect("https://openeo-dev.vito.be/openeo/1.0").authenticate_basic('bontek', 'bontek123')
    process_id = "CropCalendar"

    time_range = ('2019-01-01', '2019-12-31')
    ref_fields_path = (Path(__file__).parent.parent.parent
                       / 'tests' / 'resources' / 'ref_fields_harvest_detector.geojson')

    with open(ref_fields_path, 'r', encoding='utf-8') as geojson_file:
        polygons = json.load(geojson_file)

    udp_cube = connection.datacube_from_process(process_id, date=time_range, polygon=polygons)
    # Alternative: udp_cube.send_job().start_and_wait().get_result().load_json()
    return udp_cube.execute()



# Manual entry points: uncomment one of the calls below to run it.
#load_timeseries_data()
#res = debug_cropcalendar_udf()
#res = run_cropcalendar_udf()
#crop_calendar_build_graph_poly_and_store_udp()
#save_udp()
#harvest_out = run_udp()