# -*- coding: utf-8 -*-

import os
from pathlib import Path
import json
import openeo
import geopandas as gpd
import pandas as pd
from loguru import logger
import utm
import numpy as np
import xarray as xr
import glob

from openeo.rest.conversions import timeseries_json_to_pandas
from openeo.processes import eq


# Reference year -> season window used for the time series: September 1st of
# the previous year up to August 30th of the reference year
# (e.g. '2018' -> 2017-09-01 .. 2018-08-30). These dates also end up in the
# output netCDF file names, so changing them invalidates cached results.
# NOTE(review): August 31st is never covered, and if the back-end treats the
# temporal end bound as exclusive, August 30th is dropped as well -- confirm
# whether an end of 'YYYY-09-01' was intended.
TIMERANGE_LUT = {
    str(year): {'start': f'{year - 1}-09-01',
                'end': f'{year}-08-30'}
    for year in range(2016, 2022)
}

# Band order of the merged openEO cube: the Sentinel-1 bands first, then the
# Sentinel-2 L2A bands (including the SCL layer). During the netCDF
# conversion the first three columns are treated as S1 and the rest as S2.
COLUMNS_ORDER = (['VH', 'VV', 'angle']
                 + [f'B{band:02d}' for band in range(1, 10)]
                 + ['B11', 'B12', 'B8A', 'SCL'])


def _get_epsg(lat, zone_nr):
    if lat >= 0:
        epsg_code = '326' + str(zone_nr)
    else:
        epsg_code = '327' + str(zone_nr)
    return int(epsg_code)


def buffer_geometries(gdf, inw_buffer_size=-10):
    """Buffer field geometries (inwards by default) and return them as GeoJSON.

    The geometries are projected to the UTM zone of the first geometry (taken
    from the lower-left corner of its bounding box), buffered by
    ``inw_buffer_size``, simplified with a tolerance of 2 (UTM units) and
    projected back to EPSG:4326. Geometries that become empty after
    buffering fall back to their unbuffered (UTM-projected) shape.

    Note: ``gdf`` is modified in place -- its geometry column and CRS are
    replaced by the buffered EPSG:4326 geometries.

    Args:
        gdf: GeoDataFrame of field geometries. Assumes geographic lat/lon
            coordinates (bounds are fed to ``utm.from_latlon``) and that all
            geometries share one UTM zone (only the first one is used).
        inw_buffer_size: buffer distance in UTM units (metres); negative
            values shrink the geometries.

    Returns:
        ``gdf.geometry.__geo_interface__`` of the buffered geometries.
    """
    buffer_args = {'cap_style': 1, 'join_style': 3, 'resolution': 4}
    first_bounds = gdf.geometry.iloc[0].bounds
    lat, lon = first_bounds[1], first_bounds[0]
    utm_zone_nr = utm.from_latlon(lat, lon)[2]
    epsg_UTM_field = _get_epsg(lat, utm_zone_nr)
    # pass EPSG codes directly: the {'init': 'epsg:...'} dict form is
    # deprecated since pyproj 2
    parcels_UTM = gdf.to_crs(epsg=epsg_UTM_field)
    parcels_buffered = parcels_UTM.buffer(inw_buffer_size, **buffer_args)
    parcels_buffered = parcels_buffered.simplify(2)
    # in case the buffering resulted in empty geometries
    # -> use the original ones
    empty = parcels_buffered.is_empty
    parcels_buffered[empty] = parcels_UTM.loc[empty].geometry
    # convert back to lat/lon
    parcels_buffered_WGS = parcels_buffered.to_crs(epsg=4326)
    gdf.geometry = parcels_buffered_WGS.geometry.to_list()
    gdf.crs = parcels_buffered_WGS.crs
    return gdf.geometry.__geo_interface__


def get_input_TS(eoconn, time_range, geo):
    """Build the openEO process graph for the model input time series.

    The merged cube combines:

    * Sentinel-1 GRD VH/VV backscatter from dual-pol ('DV') products,
      computed as gamma0-ellipsoid with the local incidence angle requested,
      then converted to decibels (10 * log10) by ``apply`` on the whole S1
      cube (presumably including the incidence-angle band -- confirm);
    * Sentinel-2 L2A bands B01-B12/B8A plus SCL from scenes with at most
      85 % cloud cover, masked with the ``mask_scl_dilation`` process and
      resampled onto the Sentinel-1 grid.

    The merged cube is restricted to ``time_range`` and averaged ('mean')
    over every geometry in ``geo``. This only builds the cube;
    timeseries_from_openeo runs it with ``execute_batch``.

    Args:
        eoconn: openEO connection.
        time_range: temporal extent for ``filter_temporal`` (presumably
            ``[start, end]`` date strings).
        geo: GeoJSON geometries for ``aggregate_spatial``.

    Returns:
        The ``aggregate_spatial`` result cube.
    """
    s2_band_names = ["B01", "B02", "B03",
                     "B04", "B05", "B06",
                     "B07", "B08", "B09",
                     "B11", "B12", "B8A",
                     "SCL"]
    # skip scenes with more than 85% cloud coverage
    s2_raw = eoconn.load_collection(
        'SENTINEL2_L2A_SENTINELHUB',
        bands=s2_band_names,
        properties={"eo:cloud_cover": lambda v: v <= 85})
    s2_clean = s2_raw.process("mask_scl_dilation", data=s2_raw,
                              scl_band_name="SCL")

    s1_raw = eoconn.load_collection(
        'SENTINEL1_GRD',
        bands=['VH', 'VV'],
        properties={"polarization": lambda p: eq(p, "DV")})
    s1_gamma0 = s1_raw.sar_backscatter(
        coefficient="gamma0-ellipsoid",
        local_incidence_angle=True)
    s1_db = s1_gamma0.apply(lambda x: 10 * x.log(base=10))

    s2_on_s1_grid = s2_clean.resample_cube_spatial(s1_db)
    combined = s1_db.merge_cubes(s2_on_s1_grid)

    temporal_subset = combined.filter_temporal(time_range)
    return temporal_subset.aggregate_spatial(geo, reducer='mean')


def add_epsg(fields):
    """Add an integer ``epsg`` column with each field's UTM EPSG code.

    The UTM zone is derived from the lower-left corner (miny, minx) of each
    geometry's bounding box. Values are assigned positionally, so the result
    is correct even when ``fields`` has duplicate index labels; the old
    per-label ``.loc`` assignment was not.

    Args:
        fields: GeoDataFrame; assumes geographic lat/lon coordinates, since
            the bounds are fed to ``utm.from_latlon``.

    Returns:
        ``fields`` itself, modified in place with the added ``epsg`` column.
    """
    bounds = fields.geometry.bounds
    fields['epsg'] = [
        _get_epsg(lat, utm.from_latlon(lat, lon)[2])
        for lat, lon in zip(bounds['miny'], bounds['minx'])
    ]
    return fields


def timeseries_from_openeo(fields, geomtype, eoconn, time_range,
                           outdir, batch_size=500,
                           overwrite=False):
    """Download openEO time series for all fields, one UTM zone at a time.

    Fields are grouped by their UTM EPSG code (see ``add_epsg``). For each
    zone the geometries are prepared (inward buffer for 'POLY', unchanged
    for 'POINT'), features without a geometry are dropped, and the remaining
    fields are sent in batches of at most ``batch_size`` as openEO batch
    jobs. For zone ``<epsg>`` and batch ``b`` this writes, in
    ``outdir/<epsg>/``:

    * ``Timeseries_<b>.json`` -- the downloaded job result, and
    * ``FieldIDS_<b>.txt`` -- the matching ``sampleID`` values, one per
      line, in the same order as the features sent to openEO.

    Existing files are reused unless ``overwrite`` is True. A failing job
    is logged with its traceback and that batch is skipped.

    Args:
        fields: GeoDataFrame with a ``sampleID`` column; ``add_epsg`` adds
            an ``epsg`` column to it in place.
        geomtype: 'POLY' or 'POINT'.
        eoconn: openEO connection.
        time_range: temporal extent passed on to ``get_input_TS``.
        outdir: output directory (Path or str).
        batch_size: maximum number of fields per openEO job.
        overwrite: re-run jobs and rewrite ID files even if they exist.

    Raises:
        ValueError: if ``geomtype`` is neither 'POLY' nor 'POINT'.
    """
    logger.info('Getting timeseries through openEO...')

    # make sure we are processing one utm zone at a time
    fields = add_epsg(fields)
    epsg_uni = list(fields['epsg'].unique())
    n_zones = len(epsg_uni)
    logger.info(f'Need to process {n_zones} utm zones')

    # distinct loop names: reusing `i` in the nested loops clobbered the
    # zone counter used in the final log message
    for zone_idx, epsg in enumerate(epsg_uni):

        logger.info(f'Processing zone {zone_idx+1}/{n_zones}')

        outdir_e = Path(outdir) / str(epsg)
        outdir_e.mkdir(exist_ok=True, parents=True)

        # explicit copy: buffer_geometries modifies its input in place
        fields_e = fields.loc[fields['epsg'] == epsg].copy()
        field_ids = list(fields_e['sampleID'].values)

        if geomtype == 'POLY':
            # buffer the field geometries inwards
            geo = buffer_geometries(fields_e)
        elif geomtype == 'POINT':
            geo = fields_e.geometry.__geo_interface__
        else:
            raise ValueError(
                f'{geomtype} not recognized as valid geometry type')

        # remove invalid geometries, keeping IDs aligned with the features
        valid = [(fid, feature)
                 for fid, feature in zip(field_ids, geo['features'])
                 if feature['geometry'] is not None]
        field_ids = [fid for fid, _ in valid]
        geo['features'] = [feature for _, feature in valid]

        nfields = len(field_ids)
        if nfields == 0:
            logger.warning(f'No valid geometries in zone {epsg}, skipping!')
            continue

        # check number of fields and split up in multiple requests if necessary
        nbatches = int(np.ceil(nfields / batch_size))
        if nbatches > 1:
            logger.info(f'Splitting request in {nbatches} requests')

        for b in range(nbatches):
            logger.info(f'Working on batch {b+1}/{nbatches}')
            # slicing past the end is safe: the last batch gets the remainder
            start, stop = b * batch_size, (b + 1) * batch_size
            geo_batch = dict(geo, features=geo['features'][start:stop])
            field_ids_batch = field_ids[start:stop]

            # get the actual time series
            outjson = str(outdir_e / f'Timeseries_{b}.json')
            if not os.path.exists(outjson) or overwrite:
                try:
                    TS_cube = get_input_TS(eoconn, time_range, geo_batch)
                    job_options = {'soft-errors': 'true',
                                   'max-executors': '50'}
                    job = TS_cube.execute_batch(job_options=job_options)
                    results = job.get_results()
                    results.download_file(outjson)
                    logger.info('Timeseries downloaded!')
                except Exception:
                    # best effort: log with traceback and move on, but let
                    # KeyboardInterrupt/SystemExit propagate
                    logger.exception('Job failed, skipping!')
                    continue

            # save field id's (one per line, same order as geo_batch)
            outfile_ids = str(outdir_e / f'FieldIDS_{b}.txt')
            if not os.path.exists(outfile_ids) or overwrite:
                with open(outfile_ids, 'w') as tf:
                    tf.writelines(f'{fid}\n' for fid in field_ids_batch)
                logger.info('Field IDs saved!')

            logger.info(f'Batch {b+1}/{nbatches} processed!')

        logger.info(f'Zone {zone_idx+1} processed!')


def _read_field_ids(path):
    """Return the field IDs listed in ``path``, one per line, without newlines."""
    with open(path, 'r') as f:
        return f.read().splitlines()


def _frames_to_dataset(frames):
    """Concatenate per-field frames into one xarray Dataset.

    Duplicate (time, CODE_OBJ) rows are dropped (first one kept) with a
    warning, because ``xr.Dataset.from_dataframe`` rejects a non-unique
    MultiIndex.
    """
    df = pd.concat(frames)
    dupes = df.index.duplicated(keep='first')
    if dupes.any():
        logger.warning(
            f'Dropping {int(dupes.sum())} duplicate (time, CODE_OBJ) rows')
        df = df[~dupes]
    return xr.Dataset.from_dataframe(df)


def ts_json_to_netcdf(outdir, s1_file, s2_file):
    """Convert downloaded openEO time series to S1 and S2 netCDF files.

    Every ``*.json`` file under ``outdir`` is read, including the per-UTM-zone
    subfolders written by timeseries_from_openeo. Each JSON is paired with
    the ``FieldIDS_<b>.txt`` file in the same folder. Per field, the columns
    are labelled with COLUMNS_ORDER. The first three (S1) are converted from
    dB back to linear scale and the rest (S2) are multiplied by 0.0001. Both
    parts are indexed by (``time``, ``CODE_OBJ`` = field ID) and written to
    ``s1_file`` and ``s2_file``.

    Raises:
        FileNotFoundError: if no JSON file exists under ``outdir``.
        ValueError: if the ID files list no fields at all.
    """
    outdir = Path(outdir)
    # results live in outdir/<epsg>/; '**' also matches outdir itself, so a
    # flat layout is still picked up
    jsons = sorted(outdir.glob('**/*.json'))
    if not jsons:
        raise FileNotFoundError(
            f'No time series JSON files found under {outdir}')

    s1_frames = []
    s2_frames = []
    for json_nr, injson in enumerate(jsons):
        logger.info(f'Reading results from batch {json_nr+1}/{len(jsons)}')
        # get time series
        with open(injson, 'r') as json_file:
            ts = json.load(json_file)
        ts_df = timeseries_json_to_pandas(ts)
        # keep only the date, as naive timestamps
        times = pd.to_datetime(
            pd.to_datetime(ts_df.index).date).tz_localize(None)

        # field IDs sit next to their JSON; batch numbers restart per zone
        batchnr = injson.stem.split('_')[-1]
        field_ids = _read_field_ids(injson.parent / f'FieldIDS_{batchnr}.txt')
        n_fields = len(field_ids)

        logger.info('Processing each field...')
        for field_nr, field_id in enumerate(field_ids):
            if n_fields > 1:
                # first column level is the feature index within the batch
                ts_df_field = ts_df.loc[
                    :, ts_df.columns.get_level_values(0) == field_nr].copy()
            else:
                ts_df_field = ts_df.copy()
            ts_df_field.columns = COLUMNS_ORDER
            ts_df_field.index = pd.MultiIndex.from_product(
                [times, [field_id]], names=['time', 'CODE_OBJ'])

            # S1 was converted to dB in get_input_TS -> back to linear
            s1_frames.append(np.power(10, ts_df_field.iloc[:, 0:3] / 10.))
            # S2 values scaled by 1e-4 (presumably DN -> reflectance)
            s2_frames.append(ts_df_field.iloc[:, 3:] * 0.0001)

    if not s1_frames:
        raise ValueError(f'No field IDs found for the time series under {outdir}')

    # build each dataset once; merging field by field is quadratic in the
    # number of fields
    merged_s1 = _frames_to_dataset(s1_frames)
    merged_s2 = _frames_to_dataset(s2_frames)

    # save as netCDF
    merged_s1.to_netcdf(s1_file)
    merged_s2.to_netcdf(s2_file)
    logger.info('Time series saved to netCDF!')


def process_dataset(fields, time_range, outdir_d, eoconn,
                    name, geomtype, overwrite=False):
    """Extract and store the S1/S2 time series of one reference dataset.

    The outputs ``S1_<start>_<end>_<name>.nc`` and
    ``S2_<start>_<end>_<name>.nc`` are placed in ``outdir_d / 'TS'``. When
    both already exist and ``overwrite`` is False, nothing is recomputed.
    Otherwise the time series are fetched with timeseries_from_openeo, which
    reuses earlier downloads unless ``overwrite`` is set, and converted to
    netCDF with ts_json_to_netcdf.
    """
    logger.info('Inferring output files')
    ts_dir = outdir_d / 'TS'
    ts_dir.mkdir(parents=True, exist_ok=True)

    file_suffix = f'{time_range[0]}_{time_range[1]}_{name}.nc'
    s1_file = str(ts_dir / f'S1_{file_suffix}')
    s2_file = str(ts_dir / f'S2_{file_suffix}')

    already_done = os.path.exists(s1_file) and os.path.exists(s2_file)
    if already_done and not overwrite:
        logger.info('Time series were already there, skipped!')
    else:
        timeseries_from_openeo(fields, geomtype, eoconn, time_range,
                               ts_dir, overwrite=overwrite)
        # now translate to xarrays and save as NetCDFs
        ts_json_to_netcdf(ts_dir, s1_file, s2_file)

    logger.success(f'{name} processed!')


def main(indir, datasets, overwrite=False):
    """Extract S1/S2 time series for a list of reference datasets.

    Each entry of ``datasets`` names a ``<dataset>.gpkg`` file in ``indir``.
    The name's first '_'-separated token is the reference year (a key of
    TIMERANGE_LUT) and its second-to-last token is the geometry type, e.g.
    '2018_EU_LUCAS_POINT_110'. Results are written under ``indir/<dataset>``.

    Args:
        indir: folder with the input GeoPackages (Path or str).
        datasets: dataset names (file stems without '.gpkg').
        overwrite: recompute results even if they already exist.

    Raises:
        KeyError: if a dataset's year has no entry in TIMERANGE_LUT.
    """
    # accept plain strings too: the '/' joins below need a Path
    indir = Path(indir)

    # NOTE(review): credentials are hard-coded in source; consider reading
    # them from the environment or a config file instead.
    eoconn = openeo.connect(
        "openeo-dev.vito.be").authenticate_basic(
            'demeter', 'demeter123')

    for dataset in datasets:
        shp_file = str(indir / f'{dataset}.gpkg')
        name_parts = dataset.split('_')
        year = name_parts[0]
        geomtype = name_parts[-2]
        if year not in TIMERANGE_LUT:
            raise KeyError(
                f'No time range defined for year {year!r} '
                f'(dataset {dataset!r})')
        time_range = [TIMERANGE_LUT[year]['start'],
                      TIMERANGE_LUT[year]['end']]
        fields = gpd.read_file(shp_file)

        outdir_d = indir / dataset
        outdir_d.mkdir(exist_ok=True, parents=True)

        logger.info('*' * 50)
        logger.info(f'START PROCESSING {dataset}')
        process_dataset(fields, time_range, outdir_d, eoconn, dataset,
                        geomtype, overwrite=overwrite)

    logger.success('All done!')


if __name__ == "__main__":

    # Folder holding the reference '<dataset>.gpkg' files; main() creates an
    # output folder named after each processed dataset inside it.
    ref_dir = Path('/data/EEA_HRL_VLCC/data/ref/crop_type')  # NOQA

    # Datasets to process -- (un)comment entries as needed. main() reads the
    # year from the first '_'-separated token and the geometry type
    # (POLY/POINT) from the second-to-last one.
    to_process = [
        # '2018_BE_LPIS-Flanders_POLY_110',
        # '2019_BE_LPIS-Flanders_POLY_110',
        # '2021_BE_LPIS-Flanders_POLY_110',
        # '2018_FR_LPIS_POLY_110',
        # '2019_FR_LPIS_POLY_110',
        # '2020_FR_LPIS_POLY_110',
        # '2019_LV_LPIS_POLY_110',
        # '2021_LV_LPIS_POLY_110',
        # '2018_AT_LPIS_POLY_110',
        # '2019_AT_LPIS_POLY_110',
        # '2020_AT_LPIS_POLY_110',
        # '2021_AT_LPIS_POLY_110',
        # '2019_ESP_ESYRCE_POLY_111',
        # '2020_ESP_ESYRCE_POLY_111',
        # '2021_ESP_ESYRCE_POLY_111',
        # '2019_ES_SIGPAC-Catalunya_POLY_111',
        # # '2019_UA_ARABLE-LAND_POLY_110',
        '2018_EU_LUCAS_POINT_110',
    ]

    # One-off cleanup kept for reference: remove invalid geometries from the
    # FR 2020 data (found with the QGIS tool
    # Vector > Geometry Tools > Check Validity... -> method 1 (QGIS)).
    # shp_file = str(ref_dir / '2020_FR_LPIS_POLY_110.gpkg')
    # fields = gpd.read_file(shp_file)
    # invalid_file = str(ref_dir / '2020_FR_invalid.gpkg')
    # invalid = gpd.read_file(invalid_file)
    # invalid = list(invalid['sampleID'].values)
    # fields = fields.loc[~fields['sampleID'].isin(invalid)]
    # fields.to_file(shp_file, driver='GPKG')

    main(ref_dir, to_process, overwrite=False)