# -*- coding: utf-8 -*-

from pathlib import Path
import os
import geopandas as gpd
import pandas as pd
import numpy as np
from loguru import logger

from sklearn.model_selection import train_test_split


# NOTE: run this script on MathWin07 !!


# Harmonised crop-type legend.
#
# Keys are class IDs (as strings). Each value holds:
#   'label' - the class name written to the 'CTfin' column by process_dataset.
#   'CT'    - optional list of crop-type codes; rows whose (stringified) 'CT'
#             column is in this list receive the label.
#   'LC'    - optional list of land-cover codes; rows whose (stringified) 'LC'
#             column is in this list receive the label.
# Entries without 'CT'/'LC' (e.g. '0', '3', '6', '7', '52', '255') are never
# matched by process_dataset. process_dataset applies entries in dict order,
# so when a code appears in more than one entry, the later entry wins.
CROPTYPES = {
    '0': {
        'label': 'Unknown'
    },
    '1': {
        'label': 'Maize',
        'CT': ['1200']},
    '2': {
        'label': 'Common wheat',
        'CT': ['1100', '1110', '1120']},
    '3': {
        'label': 'Durum wheat'},
    '4': {
        'label': 'Barley',
        'CT': ['1500', '1510', '1520']},
    '5': {
        'label': 'Rye',
        'CT': ['1600', '1610', '1620']},
    '6': {
        'label': 'Triticale'},
    '7': {
        'label': 'Spelt'},
    '8': {
        'label': 'Oats',
        'CT': ['1700']},
    '9': {
        'label': 'Rice',
        'CT': ['1300']},
    '10': {
        'label': 'Other cereals',
        'CT': ['1900', '1910', '1920', '1400', '1800', '8300']},
    '11': {
        'label': 'Vegetables and herbs',
        'CT': ['2100', '2110', '2120', '2130', '2140',
               '2150', '2160', '2170', '2190',
               '2200', '2210', '2220', '2230', '2240',
               '2250', '2260', '2290',
               '2300', '2310', '2320', '2330', '2340', '2350', '2390',
               '2400', '2900',
               '6211', '6212', '6219']},
    '12': {
        'label': 'Peas and beans',
        # NOTE(review): '7500' is also listed under '13' ('Other pulses'),
        # which comes later and therefore wins. Confirm the intended class.
        'CT': ['7100', '7200', '7300', '7400', '7500',
               '7700', '7800']},
    '13': {
        'label': 'Other pulses',
        'CT': ['7500', '7600', '7910']},
    '21': {
        'label': 'Potatoes',
        'CT': ['5100']},
    '22': {
        'label': 'Beet',
        'CT': ['8100']},
    '23': {
        'label': 'Other root crops',
        'CT': ['5200', '5300', '5400', '5900']},
    '31': {
        'label': 'Sunflower',
        'CT': ['4380']},
    '32': {
        'label': 'Soybean',
        'CT': ['4100']},
    '33': {
        'label': 'Rapeseed',
        'CT': ['4350', '4351', '4352']},
    '34': {
        'label': 'Flax, cotton and hemp',
        'CT': ['9210', '9211', '9212', '9213', '9219']},
    '35': {
        'label': 'Other industrial crops',
        'CT': ['4200', '4300', '4310', '4320', '4330',
               '4340', '4360', '4370', '4390',
               '8200', '8900']},
    '41': {
        'label': 'Grapes',
        'CT': ['3300']},
    '42': {
        'label': 'Olives',
        'CT': ['4420']},
    '43': {
        'label': 'Pome and stone fruit',
        'CT': ['3500', '3510', '3520', '3530', '3540',
               '3550', '3560', '3590']},
    '44': {
        'label': 'Citrus fruit',
        'CT': ['3200', '3210', '3220', '3230', '3240', '3290']},
    '45': {
        'label': 'Berries',
        'CT': ['3400', '3410', '3420', '3430', '3440',
               '3450', '3460', '3490']},
    '46': {
        'label': 'Nuts',
        'CT': ['3600', '3610', '3620', '3630', '3640',
               '3650', '3660', '3690']},
    '47': {
        'label': 'Other fruit',
        'CT': ['3100', '3110', '3120', '3130', '3140',
               '3150', '3160', '3170',
               '3190', '3900']},
    '48': {
        'label': 'Other permanent crops',
        'CT': ['4410', '4430', '4490',
               '6110', '6120', '6130', '6140', '6190',
               '6221', '6222', '6223', '6224', '6225', '6226', '6229',
               '9220', '9320', '9400', '7920', '9520', '9920']},
    '51': {
        'label': 'Grass',
        # Matched on land-cover code as well as on crop-type codes.
        'LC': ['13'],
        'CT': ['9110', '9120']},
    '52': {
        'label': 'Alfalfa'},
    '53': {
        'label': 'Other fodder crops',
        'CT': ['7900',
               '9100']},
    '61': {
        'label': 'Forest',
        # Matched on land-cover code only.
        'LC': ['40', '41', '42']},
    '71': {
        'label': 'Other',
        'CT': ['9000', '9300', '9310', '9500', '9510',
               '9600', '9900', '9910', '9998']},
    '255': {
        'label': 'Unclassified'}
}


def _translate_legend(gdf):
    """Add a 'CTfin' column holding the CROPTYPES label of every row.

    The 'CT' and 'LC' columns are normalised in place to integer strings
    (e.g. 1200 -> '1200') so they compare equal to the string codes in
    CROPTYPES. Entries are applied in CROPTYPES order, so a later match
    overwrites an earlier one. Rows that match nothing keep ''.
    """
    gdf['CT'] = gdf['CT'].astype(int).astype(str)
    gdf['LC'] = gdf['LC'].astype(int).astype(str)
    gdf['CTfin'] = [''] * len(gdf)
    for ctsettings in CROPTYPES.values():
        label = ctsettings['label']
        if ctsettings.get('CT') is not None:
            gdf.loc[gdf['CT'].isin(ctsettings['CT']), 'CTfin'] = label
        if ctsettings.get('LC') is not None:
            gdf.loc[gdf['LC'].isin(ctsettings['LC']), 'CTfin'] = label
    return gdf


def _allocate_samples(class_sizes, nsamples):
    """Decide how many samples to draw from each class.

    Samples are handed out round-robin in chunks of
    ``int(nsamples / (2 * n_classes))`` (at least 1). Each chunk is capped by
    what a class still has available. Allocation continues until at least
    ``nsamples`` are allocated or every class is exhausted. The last round
    may overshoot ``nsamples``.

    Parameters
    ----------
    class_sizes : pandas.Series
        Number of available samples per class, indexed by class label.
    nsamples : int
        Target total number of samples.

    Returns
    -------
    pandas.Series
        Number of samples to draw per class, with the same index as
        ``class_sizes``.
    """
    allocated = pd.Series(0, index=class_sizes.index, dtype=int)
    if class_sizes.empty:
        return allocated
    remaining = class_sizes.astype(int)  # astype copies: caller untouched
    # A chunk size of 0 (nsamples < 2 * n_classes) would never make progress.
    per_round = max(1, int(nsamples / (2 * len(class_sizes))))
    # Stop once every class is exhausted: otherwise the loop never ends
    # when fewer than nsamples rows exist in total.
    while allocated.sum() < nsamples and remaining.sum() > 0:
        take = np.minimum(remaining, per_round)
        allocated += take
        remaining -= take
    return allocated


def process_dataset(infile, dataset, settings, outdir,
                    overwrite=False):
    """Translate a reference dataset to the CROPTYPES legend and save it.

    The file is read with geopandas. Each row gets a 'CTfin' label via
    CROPTYPES, and rows without a label are dropped. The rows are optionally
    subsampled per class and written to ``outdir/<dataset>.gpkg``.

    Parameters
    ----------
    infile : str
        Path of the input vector file (read with ``gpd.read_file``). It must
        contain numeric 'CT' and 'LC' columns and, when 'add_ref_id' is set,
        a 'sampleID' column.
    dataset : str
        Dataset name, used for the output filename and the sampleID prefix.
    settings : dict
        Optional keys:
        'nsamples'     - approximate target number of samples, drawn
                         class-balanced via ``_allocate_samples``. If absent,
                         all labelled rows are kept.
        'add_ref_id'   - if not None, prefix every sampleID with
                         '<dataset>-'.
        'random_state' - seed passed to ``DataFrame.sample`` for
                         reproducible selections (default: unseeded).
        'bufferregion' is currently ignored (buffering is disabled).
    outdir : pathlib.Path
        Output directory.
    overwrite : bool
        Regenerate the output even if it already exists.
    """
    outfile = str(outdir / f'{dataset}.gpkg')
    if os.path.exists(outfile) and not overwrite:
        logger.info(f'{outfile} already exists, skipping')
        return

    logger.info('Reading file')
    gdf = gpd.read_file(infile)

    logger.info('Translating legend')
    gdf = _translate_legend(gdf)

    logger.info('Removing irrelevant samples')
    gdf = gdf.loc[gdf['CTfin'] != '']
    if gdf.empty:
        logger.warning(f'{dataset}: no samples map onto a CROPTYPES class, '
                       f'{outfile} not written')
        return

    logger.info('Making selection ...')
    nsamples = settings.get('nsamples', None)
    if nsamples is not None:
        class_sizes = gdf.groupby('CTfin')['geometry'].count()
        allocation = _allocate_samples(class_sizes, nsamples)
        if allocation.sum() < nsamples:
            logger.warning(
                f'{dataset}: only {int(allocation.sum())} samples available, '
                f'fewer than the requested {nsamples}')
        random_state = settings.get('random_state', None)
        # Draw the allocated number of samples randomly within each class.
        samples = pd.concat(
            [gdf.loc[gdf['CTfin'] == cl].sample(n=int(n),
                                                random_state=random_state)
             for cl, n in allocation.items()],
            axis=0, ignore_index=True)
    else:
        samples = gdf.copy()

    # Make sampleIDs unique across datasets by prefixing the dataset name.
    if settings.get('add_ref_id', None) is not None:
        samples['sampleID'] = [
            f'{dataset}-{sid}' for sid in samples['sampleID'].values]

    samples.to_file(outfile, driver='GPKG')
    logger.success(f'{outfile} generated!')


def main(indir, outdir, datasets, overwrite=False):
    """Process every configured reference dataset in turn.

    The input of each dataset is looked up as
    ``indir/<labeltype>/<contenttype>/<year>/<dataset>/<dataset>.shp``. If
    that file does not exist, ``.gpkg`` is tried. Here ``year`` is the first
    '_'-separated field of the dataset name, and ``labeltype`` and
    ``contenttype`` are the last two.

    Parameters
    ----------
    indir : str or pathlib.Path
        Root directory of the processed reference data.
    outdir : pathlib.Path
        Directory where the output GeoPackages are written.
    datasets : dict
        Maps each dataset name to its settings dict, which is passed to
        ``process_dataset``.
    overwrite : bool
        Regenerate outputs that already exist.

    Raises
    ------
    ValueError
        If neither a .shp nor a .gpkg input exists for a dataset.
    """
    for dataset, settings in datasets.items():
        fields = dataset.split('_')
        year, labeltype, contenttype = fields[0], fields[-2], fields[-1]

        datasetdir = Path(indir) / labeltype / contenttype / year / dataset
        # Only the file suffix is swapped, never other '.shp' substrings
        # that may occur elsewhere in the path.
        candidates = [datasetdir / f'{dataset}{ext}'
                      for ext in ('.shp', '.gpkg')]
        infile = next((path for path in candidates if path.exists()), None)
        if infile is None:
            raise ValueError(
                f'Neither {candidates[0]} nor {candidates[1]} found, '
                'cannot continue!')

        logger.info(f'START PROCESSING {dataset}')
        process_dataset(str(infile), dataset, settings, outdir,
                        overwrite=overwrite)


if __name__ == "__main__":

    # Input/output locations, derived from one data root.
    # Earlier Windows equivalents:
    #   input:  W:\data\ref\VITO_processed
    #   output: W:\tmp\jeroen\demeter_shapefiles
    data_root = Path('/data/worldcereal')
    input_dir = data_root / 'data' / 'ref' / 'VITO_processed'
    output_dir = data_root / 'tmp' / 'jeroen' / 'demeter_shapefiles'
    output_dir.mkdir(parents=True, exist_ok=True)

    # Set to True to regenerate outputs that already exist.
    overwrite_existing = False

    # Datasets to process. main() parses each key as
    # '<year>_..._<labeltype>_<contenttype>' to locate the input file.
    # Values are the per-dataset settings handed to process_dataset
    # ('nsamples', 'add_ref_id'). The 'bufferregion' entries below are
    # historical: the buffering step in process_dataset is disabled.
    # Previously processed configurations are kept, commented out, for
    # reference:
    #
    # '2018_BE_LPIS-Flanders_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\flanders_buffered1000m.shp',
    #     'nsamples': 2000,
    #     'add_ref_id': 1},
    # '2019_BE_LPIS-Flanders_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\flanders_buffered1000m.shp',
    #     'nsamples': 2000,
    #     'add_ref_id': 1},
    # '2021_BE_LPIS-Flanders_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\flanders_buffered1000m.shp',
    #     'nsamples': 2000},
    #
    # '2018_FR_LPIS_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\france_buffered1000m.shp',
    #     'nsamples': 4000},
    # '2019_FR_LPIS_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\france_buffered1000m.shp',
    #     'nsamples': 4000},
    # '2020_FR_LPIS_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\france_buffered1000m.shp',
    #     'nsamples': 4000},
    #
    # '2019_LV_LPIS_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\latvia_buffered1000m.shp',
    #     'nsamples': 4000,
    #     'add_ref_id': 1},
    # '2021_LV_LPIS_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\latvia_buffered1000m.shp',
    #     'nsamples': 4000,
    #     'add_ref_id': 1},
    #
    # '2018_AT_LPIS_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\austria_buffered1000m.shp',
    #     'nsamples': 2000},
    # '2019_AT_LPIS_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\austria_buffered1000m.shp',
    #     'nsamples': 2000},
    # '2020_AT_LPIS_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\austria_buffered1000m.shp',
    #     'nsamples': 2000},
    # '2021_AT_LPIS_POLY_110': {
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\austria_buffered1000m.shp',
    #     'nsamples': 2000},
    #
    # '2019_ESP_ESYRCE_POLY_111': {
    #     'nsamples': 3000},
    # '2020_ESP_ESYRCE_POLY_111': {
    #     'nsamples': 3000},
    # '2021_ESP_ESYRCE_POLY_111': {
    #     'nsamples': 3000},
    #
    # '2019_ES_SIGPAC-Catalunya_POLY_111': {
    #     'nsamples': 3000,
    #     'bufferregion':
    #         r'W:\data\ref\VITO\buffering\catalunya_buffered1000m.shp',
    #     'add_ref_id': 1},
    #
    # '2019_UA_ARABLE-LAND_POLY_110': {
    # },
    datasets = dict()
    datasets['2018_EU_LUCAS_POINT_110'] = {}

    main(input_dir, output_dir, datasets, overwrite=overwrite_existing)