# -*- coding: utf-8 -*-

from cProfile import label
import xarray as xr
from pathlib import Path
from loguru import logger
import geopandas as gpd
import pandas as pd
import os
import glob
import copy

from worldcereal.utils.spark import get_spark_context

from cropclass.features.hrl_settings import get_feat_parameters


def convert_doy_to_date(doy, year):
    """Return the pandas Timestamp for day-of-year ``doy`` of ``year``.

    Day 1 is January 1st. The offset is added with a ``pd.Timedelta``, so
    values past the end of the year roll over into the following year.
    """
    jan_first = pd.to_datetime(f'{year}-01-01')
    offset = pd.Timedelta(days=doy - 1)
    return jan_first + offset


def get_s2_features(data, settings, features_meta,
                    rsi_meta, ignore_def_feat, gddnormalization):
    """Compute Sentinel-2 (optical) features for a single sample.

    Placeholder: the feature computation is not implemented yet.

    Parameters
    ----------
    data : xarray.Dataset
        Data of one sample, restricted to the processing period (as prepared
        by ``process_sample``).
    settings, features_meta, rsi_meta, ignore_def_feat
        The ``'OPTICAL'`` entries of the corresponding feature settings.
    gddnormalization
        The global ``'gddnormalization'`` setting.

    Raises
    ------
    NotImplementedError
        Always, until the feature computation below is implemented.
    """
    # TODO: select the S2 bands from the data

    # TODO: build a satio.Timeseries object based on the data

    # TODO: pre-process the timeseries: compositing and interpolation
    # based on satio functionality :
    # https://github.com/WorldCereal/wc-classification/blob/develop/src/worldcereal/fp.py#L202

    # TODO: compute vegetation indices based on satio compute_rsis functionality
    # https://github.com/WorldCereal/wc-classification/blob/f995677c8e22775f64340bfbe1d64e17a1db9401/src/worldcereal/fp.py#L651

    # TODO: return features
    raise NotImplementedError('S2 feature computation is not implemented yet')


def get_s1_features(data, settings, features_meta,
                    rsi_meta, ignore_def_feat, gddnormalization):
    """Compute Sentinel-1 (SAR) features for a single sample.

    Placeholder: the feature computation is not implemented yet. It is meant
    to follow the same logic as ``get_s2_features``.

    Parameters
    ----------
    data : xarray.Dataset
        Data of one sample, restricted to the processing period (as prepared
        by ``process_sample``).
    settings, features_meta, rsi_meta, ignore_def_feat
        The ``'SAR'`` entries of the corresponding feature settings.
    gddnormalization
        The global ``'gddnormalization'`` setting.

    Raises
    ------
    NotImplementedError
        Always, until the feature computation below is implemented.
    """
    # TODO: use worldcereal SAR processor as source of inspiration:
    # pre-processing:
    # https://github.com/WorldCereal/wc-classification/blob/f995677c8e22775f64340bfbe1d64e17a1db9401/src/worldcereal/fp.py#L1543

    # TODO: feature computation:
    # https://github.com/WorldCereal/wc-classification/blob/f995677c8e22775f64340bfbe1d64e17a1db9401/src/worldcereal/fp.py#L1729

    # TODO: return features
    raise NotImplementedError('S1 feature computation is not implemented yet')


def get_label(data):
    """Return the first value of the ``CTfin`` variable of ``data``.

    NOTE(review): presumably ``CTfin`` holds the crop-type label of the
    sample; confirm against the input netCDF schema.
    """
    label_values = data['CTfin'].values
    return label_values[0]


def process_sample(sample, id, data, settings, index=None):
    """Compute the feature vector of a single training sample.

    Parameters
    ----------
    sample : pandas.Series
        Attribute row of the sample (currently unused).
    id :
        Value of the ``CODE_OBJ`` coordinate that identifies the sample in
        ``data``.
    data : xarray.Dataset
        Dataset holding all samples, with ``CODE_OBJ`` and ``time``
        coordinates.
    settings : dict
        Must provide ``'start_doy'``, ``'end_doy'``, ``'year'`` and
        ``'gddnormalization'``, plus ``'settings'``, ``'features_meta'``,
        ``'rsi_meta'`` and ``'ignore_def_feat'`` dicts keyed by ``'OPTICAL'``
        and ``'SAR'``.
    index : optional
        When None, a DataFrame is returned. Otherwise a ``(key, row_dict)``
        tuple is returned (used by the spark path of ``process_dataset``).

    Returns
    -------
    pandas.DataFrame or tuple
        The features DataFrame, or ``(key, row_dict)`` for its first row.
        When there are no feature rows in tuple mode, ``(None, None)`` is
        returned so that the caller's ``r[0] is not None`` filter drops the
        sample.
    """
    sdoy = settings.get('start_doy')
    edoy = settings.get('end_doy')
    year = settings.get('year')
    # assume start date is always in previous year
    start_date = convert_doy_to_date(sdoy, year - 1)
    end_date = convert_doy_to_date(edoy, year)

    # filter data to keep only relevant sample + time period
    filt_data = data.sel(CODE_OBJ=id)
    filt_data = filt_data.sel(time=slice(start_date, end_date))

    s2feat = get_s2_features(filt_data, settings['settings']['OPTICAL'],
                             settings['features_meta']['OPTICAL'],
                             settings['rsi_meta']['OPTICAL'],
                             settings['ignore_def_feat']['OPTICAL'],
                             settings['gddnormalization'])

    s1feat = get_s1_features(filt_data, settings['settings']['SAR'],
                             settings['features_meta']['SAR'],
                             settings['rsi_meta']['SAR'],
                             settings['ignore_def_feat']['SAR'],
                             settings['gddnormalization'])

    label = get_label(filt_data)

    # TODO: get bounds and epsg (from preprocess_trainingpoints)

    # TODO: get DEM feature
    if 'DEM' in settings['settings']:
        ...

    # TODO: merge s2feat, s1feat, label (and DEM) using the merge function of
    # satio Features, then convert the result to a dataframe.
    features_df = pd.DataFrame()

    if index is None:
        return features_df

    # dataframe to dict, keyed by the dataframe index
    resultdict = features_df.to_dict(orient='index')
    if not resultdict:
        # No feature rows: signal "skip" instead of raising IndexError.
        return None, None
    key = next(iter(resultdict))
    return key, resultdict[key]


def _collect_with_spark(sc, samples, data, data_settings):
    """Compute features of all samples on spark, one partition per sample.

    Returns a DataFrame built from the ``(key, row_dict)`` pairs returned by
    ``process_sample``; pairs whose key is None are dropped.
    """
    rows = list(samples.iterrows())
    rdd = sc.parallelize(rows, len(rows)).map(
        lambda row: process_sample(row[1], row[0], data, data_settings,
                                   index=row[0])).filter(
        lambda r: r[0] is not None)
    result = rdd.collectAsMap()
    # TODO: optionally convert to a GeoDataFrame
    # (geometry='geometry', crs='EPSG:4326').
    return pd.DataFrame.from_dict(result, orient='index')


def _collect_locally(samples, data, data_settings):
    """Compute features of all samples sequentially.

    Returns the row-wise concatenation of the non-None DataFrames returned by
    ``process_sample``, or None when there are none.
    """
    frames = [process_sample(row, sample_id, data, data_settings)
              for sample_id, row in samples.iterrows()]
    frames = [frame for frame in frames if frame is not None]
    if not frames:
        return None
    return pd.concat(frames, axis=0)


def process_dataset(basedir, dataset, settings, sc=None):
    """Compute the features of all samples of one dataset and save them.

    Inputs are ``<basedir>/<dataset>.gpkg`` (the samples, indexed by their
    ``sampleID`` column) and the first (sorted) match of
    ``<basedir>/<dataset>/*InputsOutputs_crops.nc``. The result is written to
    ``<basedir>/<dataset>/features.parquet``.

    Parameters
    ----------
    basedir : str or pathlib.Path
        Root folder of the datasets.
    dataset : str
        Dataset name. Its first ``_``-separated token must be the year.
    settings : dict
        Feature settings. A deep copy is extended with ``'year'``, so the
        caller's dict is not modified.
    sc : optional
        Spark context. When given, samples are processed on spark; otherwise
        they are processed sequentially.

    Missing inputs, unmatched samples and empty results are logged and the
    dataset is skipped (the function returns None in all cases).
    """
    basedir = Path(basedir)

    # get geopackage with samples
    shapefile = basedir / f'{dataset}.gpkg'
    if not shapefile.exists():
        logger.error('Matching shapefile not found, skipping!')
        return

    # get netcdf containing all data. glob only returns existing paths, so an
    # empty match means "no data". Sort so that the pick is deterministic.
    datafiles = sorted(glob.glob(
        str(basedir / dataset / '*InputsOutputs_crops.nc')))
    if not datafiles:
        logger.error('Matching data not found, skipping!')
        return
    datafile = datafiles[0]

    # get year and add it to a private copy of the settings
    data_settings = copy.deepcopy(settings)
    data_settings['year'] = int(dataset.split('_')[0])

    samples = gpd.read_file(str(shapefile))
    samples.set_index('sampleID', inplace=True)

    # The context manager closes the netcdf file once all samples are done.
    with xr.open_dataset(datafile, engine='h5netcdf') as data:
        # Only retain samples for which we have data (in data order). Codes
        # absent from the geopackage are skipped instead of raising KeyError.
        known_ids = set(samples.index)
        codes = [code for code in data.coords['CODE_OBJ'].data
                 if code in known_ids]
        samples = samples.loc[codes]
        if len(samples) == 0:
            # Also avoids sc.parallelize(..., 0) and pd.concat([]) failures.
            logger.warning(f'No matching samples for {dataset}, skipping!')
            return

        if sc is not None:
            result = _collect_with_spark(sc, samples, data, data_settings)
        else:
            result = _collect_locally(samples, data, data_settings)

    if result is None:
        logger.warning(f'No features computed for {dataset}, skipping!')
        return

    # save result
    outfile = basedir / dataset / 'features.parquet'
    result.to_parquet(str(outfile))

    logger.success(f'{dataset} successfully processed!')


def main(basedir, datasets, settings, spark=False):
    """Run ``process_dataset`` for every dataset name in ``datasets``.

    When ``spark`` is true, a spark context is obtained from
    ``get_spark_context`` and passed to every ``process_dataset`` call.
    Otherwise ``None`` is passed and samples are processed locally.
    """
    sc = None
    if spark:
        logger.info('Setting up spark ...')
        sc = get_spark_context()

    for dataset_name in datasets:
        logger.info(f'Working on dataset {dataset_name}')
        process_dataset(basedir, dataset_name, settings, sc=sc)

    logger.success('All done!')


if __name__ == '__main__':

    # Run locally by default; set to True to distribute samples on spark.
    use_spark = False
    feature_settings = get_feat_parameters()

    basedir = Path("/data/EEA_HRL_VLCC/data/ref/crop_type")
    # Dataset names start with the reference year (parsed in process_dataset).
    datasets = [
        '2018_BE_LPIS-Flanders_POLY_110',
        '2019_BE_LPIS-Flanders_POLY_110',
        '2021_BE_LPIS-Flanders_POLY_110',
        '2018_FR_LPIS_POLY_110',
        '2019_FR_LPIS_POLY_110',
        '2020_FR_LPIS_POLY_110',
        '2019_LV_LPIS_POLY_110',
        '2021_LV_LPIS_POLY_110',
        '2018_AT_LPIS_POLY_110',
        '2019_AT_LPIS_POLY_110',
        '2020_AT_LPIS_POLY_110',
        '2021_AT_LPIS_POLY_110',
        '2019_ESP_ESYRCE_POLY_111',
        '2020_ESP_ESYRCE_POLY_111',
        '2021_ESP_ESYRCE_POLY_111',
        '2019_ES_SIGPAC-Catalunya_POLY_111',
        '2019_UA_ARABLE-LAND_POLY_110',
    ]

    main(basedir, datasets, feature_settings, spark=use_spark)