######################################################################################################
# croptypeclassification.py
# ---------------------------
# Purpose
#       Build OpenEO croptypeclassification service
#
######################################################################################################

# Imports
# Standard library
import datetime
import json
import os
import time
from pathlib import Path

# Third-party
import geopandas as gpd
import openeo
import pandas as pd
import pyproj
import shapely
import shapely.ops
import utm
from openeo.api.process import Parameter
from openeo.rest.conversions import timeseries_json_to_pandas
from openeo.rest.udp import build_process_dict
from shapely.geometry import Polygon
from shapely.ops import transform

# Project-local
from nextland_services.constants import *



# Local output folder "<script name>.output" next to this file; created at import time.
_SCRIPT_PATH = Path(__file__)
OUTPUT_DIR = _SCRIPT_PATH.with_suffix('.output')
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
# Folder that receives the generated process-graph JSON files.
OUTPUT_DIR_GRAPH = RESOURCES

# Functions



def _get_epsg(lat, zone_nr):
    if lat >= 0:
        epsg_code = '326' + str(zone_nr)
    else:
        epsg_code = '327' + str(zone_nr)
    return int(epsg_code)

def load_udf(udf):
    """Return the source code of a UDF file as a string.

    :param udf: path to the UDF Python file.
    :return: the file content, decoded as UTF-8.
    """
    # Read-only mode: the previous 'r+' required write permission on the file
    # and failed for read-only inputs, although the file is only read.
    with open(udf, 'r', encoding="utf8") as f:
        return f.read()

#Write content to named file
def fwrite(fname, fcontent):
    """Write ``str(fcontent)`` to ``fname``, replacing any existing content.

    :param fname: destination path.
    :param fcontent: any object; it is converted with ``str()`` before writing.
    """
    # ``with`` guarantees the file is closed even if the write raises. UTF-8
    # matches the encoding ``load_udf`` uses for reading.
    with open(fname, 'w', encoding='utf8') as f:
        f.write(str(fcontent))

# Build the model input: a per-polygon mean timeseries (TS=True) or a raster data cube (TS=False)
def get_input_udf(eoconn, time_range, geo, TS=True):
    """Build the merged Sentinel-1 + Sentinel-2 input for the croptype model.

    Sentinel-2 L2A (12 spectral bands plus SCL) is resampled onto the grid of the
    Sentinel-1 GRD cube. Sentinel-1 is converted to gamma0-ellipsoid backscatter
    (with local incidence angle) and then to decibels. Both are merged and
    filtered to ``time_range``.

    :param eoconn: openEO connection used to load the collections.
    :param time_range: temporal extent (date pair or openEO parameter).
    :param geo: geometry (or openEO parameter) to aggregate over / clip to.
    :param TS: if truthy, first mask Sentinel-2 with ``mask_scl_dilation`` on its
        SCL band and return ``polygonal_mean_timeseries(geo)``. Otherwise return
        the unmasked cube clipped with ``filter_spatial(geo)``.
    :return: the resulting openEO cube object (a lazily built process graph).
    """
    s2_band_names = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B11", "B12", "B8A", "SCL"]
    sentinel2 = eoconn.load_collection('SENTINEL2_L2A_SENTINELHUB', bands=s2_band_names)

    # Sentinel-1 backscatter converted to dB (10 * log10).
    sentinel1 = (
        eoconn.load_collection('SENTINEL1_GRD', bands=['VH', 'VV'])
        .sar_backscatter(coefficient="gamma0-ellipsoid", local_incidence_angle=True)
        .apply(lambda x: 10 * x.log(base=10))
    )

    if not TS:
        # Raster mode: no SCL masking; clip the merged cube to the geometry.
        stacked = sentinel1.merge(sentinel2.resample_cube_spatial(sentinel1))
        return stacked.filter_temporal(time_range).filter_spatial(geo)

    # Timeseries mode: mask Sentinel-2 using its SCL band, then aggregate per polygon.
    sentinel2_masked = sentinel2.process("mask_scl_dilation", data=sentinel2, scl_band_name="SCL")
    stacked = sentinel1.merge(sentinel2_masked.resample_cube_spatial(sentinel1))
    return stacked.filter_temporal(time_range).polygonal_mean_timeseries(geo)








# Test the croptype classification with a UDF (currently only the input timeseries is extracted and downloaded)
def test_run_udf_crop_type_classification(time_range, shp):
    """Extract the model input timeseries for ``shp`` and download it as JSON.

    A batch job with JSON output is created from ``get_input_udf(..., TS=True)``
    and handed to ``_download_openeo_job_result`` together with a fixed target
    path. Applying the croptype UDF to this cube (``run_udf``) is currently
    disabled.

    :param time_range: temporal extent passed to ``get_input_udf``.
    :param shp: geometry passed to ``get_input_udf``.
    """
    # NOTE(review): credentials are hard-coded; consider reading them from the environment.
    connection = openeo.connect("https://openeo.vito.be/openeo/1.0")
    connection = connection.authenticate_basic('bontek', 'bontek123')

    target_json = r'/data/users/Public/bontek/Nextland/Croptype_classification/tmp/TS_validation_2.json'
    job = get_input_udf(connection, time_range, shp).send_job(out_format='json')
    # NOTE(review): `_download_openeo_job_result` is not defined in this file; presumably it
    # comes from the `nextland_services.constants` star import — confirm it starts the job.
    _download_openeo_job_result(job, target_json)

def test_debug_croptypeclassification_udf():
    """Run the croptype UDF locally on a previously downloaded timeseries JSON.

    The JSON is wrapped as ``StructuredData`` inside a ``UdfData`` (EPSG 4326) and
    passed to ``udf_croptypeclassification``. ``res.to_dict()`` is written as text
    to ``croptypeclassication_testudf_result.res`` in ``OUTPUT_DIR``.
    """
    from openeo_udf.api.udf_data import UdfData
    from openeo_udf.api.structured_data import StructuredData
    from croptypeclassification_udf import udf_croptypeclassification

    ts_path = r'/data/users/Public/bontek/Nextland/Croptype_classification/tmp/TS_validation_2.json'
    with open(ts_path, 'r') as json_file:
        timeseries = json.load(json_file)

    structured = StructuredData(description="timeseries input", data=timeseries, type="dict")
    udf_input = UdfData({"EPSG": 4326}, structured_data_list=[structured])
    result = udf_croptypeclassification(udf_input)
    fwrite(os.path.join(OUTPUT_DIR, 'croptypeclassication_testudf_result.res'), result.to_dict())




# Build the croptypeclassification process graph and write it to a JSON file
def croptypeclassification_build_graph_poly_and_store_udp(eoconn):
    """Build the CROPTYPECLASSIFICATION process graph and write it to JSON.

    Despite its name, this function does not register a UDP on the back-end. It
    only writes ``croptypclassification_graph_poly.json`` into ``OUTPUT_DIR_GRAPH``.
    The graph applies ``croptypeclassification_udf.py`` (via ``run_udf``) to the
    timeseries built by ``get_input_udf``. It is parameterised by a date range
    and a polygon.

    :param eoconn: openEO connection used to construct the graph.
    """
    date_param = date_parameter("Left-closed temporal interval, i.e. an array with exactly two elements:\n\n1. The first element is the start of the temporal interval. The specified instance in time is **included** in the interval.\n2. The second element is the end of the temporal interval. The specified instance in time is **excluded** from the interval.\n\nThe specified temporal strings follow [RFC 3339](https://www.rfc-editor.org/rfc/rfc3339.html). Also supports open intervals by setting one of the boundaries to `null`, but never both.")
    polygon = polygon_param(description=" A polygon object for which the croptype classification will be done.")
    model_input = get_input_udf(eoconn, date_param, polygon)

    # The UDF source is read relative to the current working directory.
    udf_code = load_udf("croptypeclassification_udf.py")
    udf_node = model_input.process("run_udf", data=model_input, udf=udf_code, runtime="Python")

    process_dict = build_process_dict(
        process_id="CROPTYPECLASSIFICATION",
        process_graph=udf_node,
        parameters=[date_param, polygon],
    )
    graph_json = json.dumps(process_dict, indent=4)
    fwrite(os.path.join(OUTPUT_DIR_GRAPH, 'croptypclassification_graph_poly.json'), graph_json)


# Save the JSON graph as a UDP.
# The Getgeometries function was added manually to the JSON so that both loaded geometries and not-yet-loaded geometries (given as a filename) can be handled.
def save_udp(eoconn):
    """Publish the stored CROPTYPECLASSIFICATION graph as a public user-defined process.

    Reads ``croptypclassification_graph_poly.json`` from ``OUTPUT_DIR_GRAPH`` and
    saves its ``process_graph`` under the id ``CROPTYPECLASSIFICATION``. The
    declared parameters are a date range and a polygon.

    :param eoconn: authenticated openEO connection.
    :return: whatever ``eoconn.save_user_defined_process`` returns (previously discarded).
    """
    # Define the service parameters (they must match those used to build the graph).
    time_range = date_parameter("Left-closed temporal interval, i.e. an array with exactly two elements:\n\n1. The first element is the start of the temporal interval. The specified instance in time is **included** in the interval.\n2. The second element is the end of the temporal interval. The specified instance in time is **excluded** from the interval.\n\nThe specified temporal strings follow [RFC 3339](https://www.rfc-editor.org/rfc/rfc3339.html). Also supports open intervals by setting one of the boundaries to `null`, but never both.")
    polygon = polygon_param(description=" A polygon object for which the croptype classification will be done.")
    with open(os.path.join(OUTPUT_DIR_GRAPH, 'croptypclassification_graph_poly.json'), 'r', encoding='utf8') as file:
        graph = json.load(file)
    # NOTE(review): the description mentions a 'file_polygons' input, but only `time_range`
    # and `polygon` are declared. Presumably 'file_polygons' lives in the manually edited
    # graph JSON — confirm whether it should be declared here as well.
    return eoconn.save_user_defined_process(
        'CROPTYPECLASSIFICATION',
        graph["process_graph"],
        description="Predicts crop type for all the input geometries. The input geometries can be given by the input parameter 'file_polygons' or 'polygon'. The outcome is a dictionary with some keys indicating the croptype prediction per year for each requested field (e.g. 2017_CT) and their corresponding confidence (e.g. 2017_CONF). The number of predictions within each key is determined by the number of input geometries. The order of the predictions is the same as the order of the input geometries.",
        parameters=[time_range, polygon], public=True)


def test_build_croptypeclassification_graph():
    """Log in on the VITO openEO back-end and regenerate the croptype graph JSON."""
    # NOTE(review): credentials are hard-coded; consider reading them from the environment.
    connection = openeo.connect("https://openeo.vito.be")
    croptypeclassification_build_graph_poly_and_store_udp(
        connection.authenticate_basic('bontek', 'bontek123')
    )

def window_around_centroid(gj, window_size):
    """Return a square lon/lat window centred on the centroid of a field polygon.

    The centroid of the field's exterior ring is projected to the UTM zone of the
    ring's first vertex. There it is buffered with a square cap of half-width
    ``(window_size / 2) * 10`` metres, i.e. a square of side ``window_size * 10`` m.
    This is presumably ``window_size`` pixels of 10 m. The result is projected
    back to EPSG:4326.

    :param gj: GeoJSON-like geometry mapping. For ``"type": "Polygon"`` its exterior
        ring is used. Otherwise it is treated as a MultiPolygon and the exterior
        ring of its first polygon is used (the previous, only supported behaviour).
    :param window_size: window edge length in (presumably 10 m) pixels.
    :return: shapely ``Polygon`` in EPSG:4326 (lon/lat axis order).
    """
    coords = gj['coordinates']
    # Polygon -> [ring, *holes]; MultiPolygon -> [[ring, *holes], ...]
    ring = coords[0] if gj.get('type') == 'Polygon' else coords[0][0]

    lon0, lat0 = float(ring[0][0]), float(ring[0][1])
    utm_zone_nr = utm.from_latlon(lat0, lon0)[2]
    epsg_utm = _get_epsg(lat0, utm_zone_nr)

    field = Polygon([(float(pt[0]), float(pt[1])) for pt in ring])

    # pyproj.transform/Proj are deprecated; a Transformer with always_xy keeps
    # (lon, lat) / (easting, northing) axis order, as before.
    to_utm = pyproj.Transformer.from_crs(4326, epsg_utm, always_xy=True)
    to_wgs = pyproj.Transformer.from_crs(epsg_utm, 4326, always_xy=True)

    centroid_utm = shapely.ops.transform(to_utm.transform, field.centroid)
    window_utm = centroid_utm.buffer((window_size / 2) * 10, cap_style=3)  # cap_style=3: square
    return shapely.ops.transform(to_wgs.transform, window_utm)

def test_load_input_datacube():
    """Download the raster model input (``get_input_udf(..., TS=False)``) as netCDF.

    Field polygons are read from a GeoJSON file.

    * ``per_field=True``: one batch job per feature; the result is downloaded to
      ``<outdir_raster_per_field>/<FID>.nc``. Features whose output file already
      exists are skipped unless ``overwrite`` is set.
    * ``per_field=False``: a single batch job for all features
      (``sample_by_feature=True``); all result files are fetched with
      ``download_files``.

    With ``window_extraction`` each field is replaced by a square window of
    ``window_size`` pixels around its centroid (``window_around_centroid``).
    """
    per_field = False
    ID_identifier = 'FID'
    time_range = '2021-04-01','2021-06-30'
    outdir_raster = r'/data/sigma/Nextland/Precifield/Products/Cropmap/Input_data/'
    # NOTE(review): the name says UTM31 but the GeoJSON read below is the UTM30 file — confirm.
    outname = 'Test_input_datacube_UTM31'
    overwrite = False
    window_extraction = True
    window_size = 1024

    # Alternative input:
    # r'/data/sigma/Nextland/Precifield/In_situ/Shapes_Precifield/Precifield_fields_Nextland_UTM31.geojson'
    # GeoJSON is UTF-8 by specification (RFC 7946), so decode it explicitly.
    with open(r'/data/users/Public/bontek/Nextland/Croptype_classification/Precifield_fields_Nextland_UTM30.geojson', 'r', encoding='utf8') as file:
        js_shp = json.load(file)

    # NOTE(review): credentials are hard-coded; consider reading them from the environment.
    eoconn = openeo.connect("https://openeo-dev.vito.be/openeo/1.0").authenticate_basic('bontek', 'bontek123')

    if per_field:
        outdir_raster_per_field = r'/data/sigma/Nextland/Precifield/Products/Cropmap/Input_data/per_field/tmp/'
        for poly in js_shp['features']:
            polyid = poly['properties'][ID_identifier]
            out_file = os.path.join(outdir_raster_per_field, '%s.nc' % (polyid))
            # Check before building the window and process graph: nothing is
            # needed for a field that will be skipped anyway.
            if os.path.exists(out_file) and not overwrite:
                continue

            polyg = poly['geometry']
            if window_extraction:
                polyg = window_around_centroid(polyg, window_size)
                # TODO temporarily needed to avoid error: pass a GeometryCollection, not a bare polygon
                polyg = shapely.geometry.GeometryCollection([shapely.geometry.shape(polyg)])

            input_cube = get_input_udf(eoconn, time_range, polyg, TS=False)
            job = input_cube.send_job(out_format='netCDF', sample_by_feature=True,
                                      job_options={"executor-memory": "10g"}, options={'stitch': True})
            job.start_and_wait().download_result(out_file)
    else:
        if window_extraction:
            windows = [window_around_centroid(poly['geometry'], window_size) for poly in js_shp['features']]
            # TODO temporarily needed to avoid error
            js_shp = shapely.geometry.GeometryCollection([shapely.geometry.shape(w) for w in windows])

        input_cube = get_input_udf(eoconn, time_range, js_shp, TS=False)
        job = input_cube.send_job(out_format='netCDF', options={'stitch': True}, sample_by_feature=True)
        # NOTE(review): download_files() expects a target *directory*; this path ends in
        # ".nc" and will be used as a folder name — confirm that is intended.
        job.start_and_wait().get_results().download_files(os.path.join(outdir_raster, '%s_202104_202106.nc' % (outname)))







#### TEST THE CROPTYPECLASSIFICATION
# Only run the experiments when executed as a script. Importing this module (e.g.
# when pytest collects the ``test_*`` functions above) must not log in on the
# back-end or launch batch jobs.
if __name__ == "__main__":
    shp_file = "/data/users/Public/bontek/Nextland/Croptype_classification/SHP_validation.shp"
    time_range = ['2018-10-01','2019-08-15']
    # NOTE(review): credentials are hard-coded; consider reading them from the environment.
    eoconn = openeo.connect("https://openeo.vito.be/openeo/1.0").authenticate_basic('bontek', 'bontek123')

    # Optional pre-processing kept for reference: inward-buffer the validation parcels in UTM.
    # shp = gpd.read_file(shp_file)
    # inw_buffer_size = -10
    # buffer_args = {'cap_style': 1, 'join_style': 3, 'resolution': 4}
    # utm_zone_nr = utm.from_latlon(shp.iloc[0, :].geometry.bounds[1], shp.iloc[0, :].geometry.bounds[0])[2]
    # epsg_UTM_field = _get_epsg(shp.iloc[0, :].geometry.bounds[1], utm_zone_nr)
    # parcels_UTM = shp.to_crs({'init': 'epsg:{}'.format(str(epsg_UTM_field))})
    # parcels_buffered = parcels_UTM.buffer(inw_buffer_size, **buffer_args)
    # parcels_buffered = parcels_buffered.simplify(2)
    # parcels_buffered_WGS = parcels_buffered.to_crs({'init': 'epsg:4326'})
    # shp.geometry = parcels_buffered_WGS.geometry.to_list()
    # shp.crs = parcels_buffered_WGS.crs
    # crs = int(shp.crs.get('init').split('epsg:')[1])
    # geo = shapely.geometry.GeometryCollection([shapely.geometry.shape(feature.geometry) for index, feature in shp.iterrows()])

    # Uncomment the step to run:
    # pred = test_run_udf_crop_type_classification(time_range, shp_file)
    # test_debug_croptypeclassification_udf()
    # test_build_croptypeclassification_graph()
    # save_udp(eoconn)
    test_load_input_datacube()

    # Example: call the published UDP for each field of the shapefile.
    # gpd_shp = gpd.read_file(shp_file)
    # js_shp = json.loads(gpd_shp.to_json())
    # for poly in js_shp['features']:
    #     polyid = poly['id']
    #     polyg = poly['geometry']
    # croptype_pred = eoconn.datacube_from_process('CROPTYPECLASSIFICATION', date= time_range, polygon = polyg).send_job().start_and_wait().get_result().load_json()


















