
from loguru import logger
import geopandas as gpd
import pandas as pd
from pathlib import Path
import json
import glob
import numpy as np


# Semicolon-separated CSV with the internal crop type legend; this module
# reads its Code, LC, LC_sub, Crop_group, Crop_type and
# Variety_management_use columns.
INTERNAL_CT_LEGEND = '/'.join((
    '/vitodata', 'EEA_HRL_VLCC', 'data', 'ref', 'crop_type2',
    'INTERNAL_crop_type_legend_20230524.csv',
))


def get_internal_ct_label(cl, internal_legend):
    '''Return the most specific non-empty label of a crop type code.

    The legend row whose ``Code`` equals ``cl`` is looked up. The first
    non-empty value among ``Variety_management_use``, ``Crop_type``,
    ``Crop_group``, ``LC_sub`` and ``LC`` (checked in that order) is
    returned. Missing values count as empty.

    Args:
        cl (str): crop type code to look up.
        internal_legend (pd.DataFrame): legend table with a ``Code`` column
            and the label columns listed above.

    Returns:
        The label value, or '' when all label columns are empty.

    Raises:
        IndexError: if ``cl`` does not occur in ``internal_legend``.
    '''
    # only the matching row needs its NaNs replaced, not the whole legend
    entry = internal_legend.loc[internal_legend.Code == cl].fillna('')
    label = ''
    for column in ('Variety_management_use', 'Crop_type', 'Crop_group',
                   'LC_sub', 'LC'):
        label = entry[column].values[0]
        if label != '':
            break
    return label


def crop_type_breakdown(cl, internal_legend=None):
    '''Break down a crop type class containing "*" wildcards into the list
    of matching crop type codes of the internal legend.

    Each "*" matches any (possibly empty) sequence of characters, and the
    whole code must match. For example, "11*" selects codes starting with
    "11", "*00" selects codes ending with "00", and "1*0" selects codes
    starting with "1" and ending with "0". Any number of wildcards is
    supported. For backward compatibility, a class without any wildcard
    is treated as a prefix, as if it ended with "*".

    Args:
        cl (str): crop type class, possibly containing "*" wildcards.
        internal_legend (pd.DataFrame, optional): legend with a ``Code``
            column. When not given, it is read from ``INTERNAL_CT_LEGEND``.
            Pass it explicitly to avoid re-reading the CSV on every call.

    Returns:
        list: matching codes, in legend order. Non-string codes (e.g.
        empty cells) are skipped.

    Raises:
        ValueError: if ``cl`` is empty or consists only of wildcards.
    '''
    import re

    if internal_legend is None:
        internal_legend = pd.read_csv(INTERNAL_CT_LEGEND,
                                      sep=';', header=0)

    cl_parts = cl.split('*')
    if not any(cl_parts):
        raise ValueError('Encountered problem when trying to break'
                         ' down crop type code!')
    if len(cl_parts) == 1:
        # legacy behaviour: a code without wildcard acts as a prefix
        cl_parts.append('')
    # every '*' becomes '.*'; the literal parts are escaped so characters
    # such as '.' or '+' in codes are matched verbatim
    pattern = re.compile('.*'.join(re.escape(part) for part in cl_parts),
                         flags=re.DOTALL)

    return [code for code in internal_legend['Code'].values
            if isinstance(code, str) and pattern.fullmatch(code)]


def check_ct_abundance(cl, samples):
    '''Count the samples that belong to a (compound) crop type class.

    ``cl`` may combine several codes separated by ";". Codes containing a
    "*" wildcard are expanded with ``crop_type_breakdown`` so that every
    underlying crop type class is taken into account.

    Args:
        cl (str): crop type class definition.
        samples (pd.DataFrame): samples with a ``CT_fin`` column.

    Returns:
        tuple: the number of samples whose ``CT_fin`` is in the expanded
        code list, and that expanded code list.
    '''
    expanded = [
        code
        for part in cl.split(';')
        for code in (crop_type_breakdown(part) if '*' in part else [part])
    ]
    in_class = samples['CT_fin'].isin(expanded)
    return int(in_class.sum()), expanded


def sample_dataset(basedir, dataset, version,
                   shapefile):
    '''Draw a sample selection from a preprocessed dataset and save it.

    The selection is driven by ``settings_{version}.json`` in
    ``basedir / dataset / 'samples'``. Settings keys read here:

    - ``ignore_labels_fully``: ``CT_fin`` codes removed before sampling.
    - ``sample_with_repeat`` (default False): when false, samples whose
      ``sampleID`` occurs in any ``*.gpkg`` of the samples folder are
      excluded.
    - ``bufferregion``: vector file; only samples intersecting it are kept.
    - ``nsamples``: size of the random first-round selection. If it is not
      smaller than the number of available samples, all are retained.
    - ``focus_labels``: comma-separated CSV with ``name`` and ``code``
      columns, listing crop type classes that get a second sampling round.
    - ``ignore_labels_part``: focus label *names* excluded from the random
      first round; only used together with ``focus_labels``.
    - ``min_samples_ct``: minimum number of samples per focus label;
      required when ``focus_labels`` and ``nsamples`` are both set.

    Writes ``{dataset}_{version}.gpkg`` and
    ``{dataset}_{version}_count.csv`` (samples per class) to the samples
    folder.

    Args:
        basedir (Path): root directory holding the datasets.
        dataset (str): dataset name (subdirectory of ``basedir``).
        version (str): sampling version used for settings and output names.
        shapefile (str or Path): dataset to sample from; must contain
            ``sampleID`` and ``CT_fin`` columns.

    Returns:
        None. Returns early without writing anything if no samples remain
        after filtering.

    Raises:
        ValueError: if an ``ignore_labels_part`` name is not among the
            focus labels, or if ``min_samples_ct`` is missing when needed.
    '''
    # Get settings
    samplesdir = basedir / dataset / 'samples'
    settingsfile = samplesdir / f'settings_{version}.json'
    with open(settingsfile, 'r') as fh:
        settings = json.load(fh)

    logger.info('Reading original dataset...')
    gdf = gpd.read_file(shapefile)
    logger.info(f'Dataset has {len(gdf)} samples.')

    # remove unwanted samples
    ignore_labels = settings.get('ignore_labels_fully', None)
    if ignore_labels is not None:
        logger.info(f'Removing samples with labels {ignore_labels}')
        gdf = gdf.loc[~gdf['CT_fin'].isin(ignore_labels)]
        logger.info(f'{len(gdf)} samples remaining.')

    if not settings.get('sample_with_repeat', False):
        # samples already selected in an earlier run (any *.gpkg in the
        # samples folder) are not eligible anymore
        samplefiles = glob.glob(str(samplesdir / '*.gpkg'))
        if samplefiles:
            logger.info('Removing previously selected samples...')
            selected_ids = set()
            for f in samplefiles:
                selected_ids.update(gpd.read_file(f)['sampleID'].values)
            gdf = gdf.loc[~gdf['sampleID'].isin(selected_ids)]
            logger.info(f'{len(gdf)} samples remaining.')

    bufferregion = settings.get('bufferregion', None)
    if bufferregion is not None:
        logger.info('Applying buffer on full geometry...')
        gdfbuffer = gpd.read_file(bufferregion)
        # intersects() against a bare geometry does not reproject, so
        # bring the buffer region into the dataset's CRS first
        if (gdfbuffer.crs is not None and gdf.crs is not None
                and gdfbuffer.crs != gdf.crs):
            gdfbuffer = gdfbuffer.to_crs(gdf.crs)
        in_region = gdf.intersects(gdfbuffer.unary_union)
        gdf = gdf.loc[in_region]
        logger.info(f'{len(gdf)} samples remaining.')

    if len(gdf) == 0:
        logger.info('No samples remaining, exiting...')
        return

    # convert to WGS84
    logger.info('Converting sampled dataset to WSG84 projection...')
    gdf = gdf.to_crs(epsg=4326)

    focus_labels_file = settings.get('focus_labels', None)
    focus_labels = None
    if focus_labels_file is not None:
        focus_labels = pd.read_csv(focus_labels_file, sep=',', header=0)

    samples = None
    nsamples = settings.get('nsamples', None)

    if nsamples is not None:
        nsamples = int(float(nsamples))

        if nsamples < len(gdf):
            # Start with random sampling, but first check whether some
            # labels need to be kept out of this stage
            ignore_labels_round1 = settings.get('ignore_labels_part', None)
            ignored = None
            if ignore_labels_round1 and focus_labels is None:
                logger.warning('ignore_labels_part is only used together '
                               'with focus_labels; setting ignored.')
            elif ignore_labels_round1:
                logger.warning(
                    f'Ignoring {ignore_labels_round1} in first round!')
                ignored_parts = []
                for name in ignore_labels_round1:
                    codes = focus_labels.loc[focus_labels['name'] == name,
                                             'code'].values
                    if len(codes) == 0:
                        raise ValueError(
                            f'Label "{name}" from ignore_labels_part not '
                            f'found in {focus_labels_file}!')
                    _, cl_lst = check_ct_abundance(codes[0], gdf)
                    ignored_parts.append(gdf.loc[gdf['CT_fin'].isin(cl_lst)])
                # a sample may belong to several ignored classes: keep once
                ignored = pd.concat(ignored_parts, axis=0, ignore_index=True
                                    ).drop_duplicates(subset='sampleID')
                gdf = gdf.loc[~gdf['sampleID'].isin(
                    ignored['sampleID'].values)]

            # removing ignored labels may leave fewer than nsamples
            n_first = min(nsamples, len(gdf))
            if n_first < nsamples:
                logger.warning(f'Only {n_first} samples available for the '
                               f'random selection ({nsamples} requested).')
            logger.info('First making a random selection...')
            samples = gdf.sample(n=n_first)

            # remove the selected ones from the database
            gdf = gdf.loc[~gdf['sampleID'].isin(samples['sampleID'].values)]

            if ignored is not None:
                # now merge the ignored ones back into the database
                gdf = pd.concat([gdf, ignored], axis=0, ignore_index=True)

            if focus_labels is not None:
                logger.info(
                    'Second round of sampling on dedicated crop types...')
                min_samples_ct = settings.get('min_samples_ct', None)
                if min_samples_ct is None:
                    raise ValueError('Setting "min_samples_ct" is required '
                                     'when using focus_labels!')
                min_samples = int(float(min_samples_ct))
                # top up every focus label that is under-represented in
                # the first-round selection
                extra_samples = []
                for cl in focus_labels['code'].values:
                    n_sampled, cl_lst = check_ct_abundance(cl, samples)
                    if n_sampled >= min_samples:
                        continue
                    subset = gdf.loc[gdf['CT_fin'].isin(cl_lst)]
                    n_extra = min(min_samples - n_sampled, len(subset))
                    if n_extra > 0:
                        sampled = subset.sample(n=n_extra)
                        extra_samples.append(sampled)
                        gdf = gdf.loc[~gdf['sampleID'].isin(
                            sampled['sampleID'].values)]

                # add samples from first round and combine all
                extra_samples.append(samples)
                samples = pd.concat(
                    extra_samples, axis=0, ignore_index=True)

    if samples is None:
        logger.info('Retaining all samples of this dataset...')
        samples = gdf.copy()

    logger.info(f'Saving {len(samples)} samples to file...')
    if 'fid' in samples.columns:
        samples = samples.drop(columns=['fid'])
    outfile = str(samplesdir / f'{dataset}_{version}.gpkg')
    samples.to_file(outfile, driver='GPKG')

    logger.info('Checking crop type distribution in samples...')
    cl_count = {}
    if focus_labels is not None:
        for _row_idx, row in focus_labels.iterrows():
            n_sampled, _ = check_ct_abundance(row['code'], samples)
            cl_count[row['name']] = n_sampled
    else:
        internal_legend = pd.read_csv(INTERNAL_CT_LEGEND,
                                      sep=';', header=0)
        for cl in samples['CT_fin'].unique():
            label = get_internal_ct_label(cl, internal_legend)
            n_sampled = int((samples['CT_fin'] == cl).sum())
            # distinct codes can share a label: sum instead of overwrite
            cl_count[label] = cl_count.get(label, 0) + n_sampled

    cl_count = pd.DataFrame.from_dict(
        cl_count, orient='index', columns=['count'])
    outfile_count = str(samplesdir / f'{dataset}_{version}_count.csv')
    cl_count.to_csv(outfile_count)

    logger.success(f'{outfile} created successfully!')


def preprocess_dataset(basedir, dataset):
    '''Harmonize a raw reference dataset and write it as a GeoPackage.

    Reads ``Settings.json`` from ``basedir / dataset / 'original'`` and the
    first ``*.gpkg`` in that folder, falling back to the first ``*.shp``
    if there is none. Files ending in ``_preproc.gpkg`` are outputs of
    this function and are never used as input. The steps are:

    1. Keep only the ``keep_attrs`` columns plus ``ct_attr``,
       ``valtime_attr`` (if set) and ``geometry``.
    2. Add a ``sampleID`` of the form ``{dataset}-{row number}``.
    3. Translate the codes in ``ct_attr`` into a new ``CT_fin`` column.
       The ``translation_key`` CSV is ``;``-separated with ``OLD_CODE``
       and ``INTERNAL_CODE`` columns. Samples without a valid translation
       are dropped.
    4. Set ``validityTi`` as a ``YYYY-MM-DD`` string, taken either from
       ``valtime_attr`` or from the fixed ``valtime`` setting.

    Args:
        basedir (Path): root directory holding the datasets.
        dataset (str): dataset name (subdirectory of ``basedir``).

    Returns:
        Path: the written ``<input stem>_preproc.gpkg`` file.

    Raises:
        FileNotFoundError: if no input dataset is found.
        ValueError: if ``ct_attr`` is missing from the settings, if neither
            ``valtime_attr`` nor ``valtime`` is set, or if the validity
            times span more than 4 (whole) months.
    '''
    indir = basedir / dataset / 'original'

    # get settings
    with open(indir / 'Settings.json', 'r') as fh:
        settings = json.load(fh)

    ct_attr = settings.get('ct_attr')
    if ct_attr is None:
        raise ValueError('ct_attr should be included in the settings!')
    valtime_attr = settings.get('valtime_attr', None)

    # get filename; skip earlier outputs of this function, which are
    # written to the same folder and would otherwise be picked up again
    candidates = [f for f in sorted(glob.glob(str(indir / '*.gpkg')))
                  if not f.endswith('_preproc.gpkg')]
    if not candidates:
        candidates = sorted(glob.glob(str(indir / '*.shp')))
    if not candidates:
        raise FileNotFoundError('No valid dataset found!')
    shapefile = candidates[0]

    logger.info('Reading shapefile...')
    gdf = gpd.read_file(shapefile)

    logger.info(f'Dataset has {len(gdf)} samples.')

    logger.info('Dropping unwanted columns')
    # copy so that the list stored in the settings is not mutated
    keep_attrs = list(settings.get('keep_attrs', []))
    keep_attrs.append(ct_attr)
    if valtime_attr is not None:
        keep_attrs.append(valtime_attr)
    keep_attrs.append('geometry')
    gdf = gdf.drop(columns=[c for c in gdf.columns if c not in keep_attrs])

    logger.info('Creating unique sample ID...')
    gdf['sampleID'] = [f'{dataset}-{x}' for x in range(len(gdf))]

    logger.info('Translating crop type legend...')
    if ct_attr == 'CT_fin':
        # CT_fin is reserved for the translated (internal) code
        gdf = gdf.rename(columns={'CT_fin': 'CT'})
        ct_attr = 'CT'
    # Delete samples without crop type information; copy so the columns
    # added below are set on an independent frame, not a view
    gdf = gdf.loc[~gdf[ct_attr].isna()].copy()
    logger.info(f'{len(gdf)} samples with valid CT information')
    # make sure they are strings
    gdf[ct_attr] = gdf[ct_attr].astype(str)
    # read translation key
    translation_csv = settings.get('translation_key')
    translation_df = pd.read_csv(translation_csv,
                                 header=0, sep=';')
    # rows lacking either code cannot translate anything; drop them before
    # the string cast turns NaN into the literal 'nan'
    translation_df = translation_df.dropna(
        subset=['OLD_CODE', 'INTERNAL_CODE']).astype(str)
    translation_key = dict(zip(translation_df['OLD_CODE'],
                               translation_df['INTERNAL_CODE']))
    # do translation; '' marks codes missing from the translation key
    gdf['CT_fin'] = [translation_key.get(x, '') for x in gdf[ct_attr].values]
    # drop samples which weren't translated correctly
    gdf = gdf.loc[gdf['CT_fin'] != '']
    logger.info(f'{len(gdf)} samples correctly translated')

    logger.info('Checking validity time...')
    if valtime_attr is not None:
        gdf['validityTi'] = pd.to_datetime(gdf[valtime_attr])
        # check range of validity time, expressed in whole months
        mindate = np.min(gdf['validityTi'].values)
        maxdate = np.max(gdf['validityTi'].values)
        timedif = ((maxdate - mindate).astype('timedelta64[M]') /
                   np.timedelta64(1, 'M'))
        if timedif > 4:
            raise ValueError('More than 4 months between first and last '
                             'validity time in the dataset --> split dataset!')
        gdf['validityTi'] = [
            x.astype('datetime64[s]').item().strftime(format='%Y-%m-%d')
            for x in gdf['validityTi'].values]
    else:
        valtime = settings.get('valtime', None)
        if valtime is None:
            raise ValueError('Either valtime_attr or valtime '
                             'should be included in the settings!')
        gdf['validityTi'] = [valtime] * len(gdf)

    # save translated full file
    outfile = indir / (Path(shapefile).stem + '_preproc.gpkg')
    logger.info(f'Writing translated dataset to {outfile}')
    gdf.to_file(outfile, driver='GPKG')

    return outfile


def get_start_end_from_valtime(shapefile, weeks_before=10 * 4,
                               weeks_after=8 * 4):
    '''Derive a processing period around a dataset's validity time.

    The centre date is the earliest ``validityTi`` plus half of the span
    between the earliest and latest ``validityTi``. The span is counted in
    whole days and halved with floor division. The period starts
    ``weeks_before`` weeks before and ends ``weeks_after`` weeks after
    that centre. The defaults correspond to 10 and 8 blocks of 4 weeks.

    Args:
        shapefile (str or Path): vector file with a ``validityTi`` column.
        weeks_before (int or float): weeks between period start and centre.
        weeks_after (int or float): weeks between centre and period end.

    Returns:
        tuple(str, str): start and end date formatted as ``YYYY-MM-DD``.

    Raises:
        ValueError: if the file contains no valid validity time.
    '''
    df = gpd.read_file(shapefile)
    valtime = pd.to_datetime(df['validityTi'])
    # pandas min/max skip NaT; they only return NaT if nothing is valid
    mindate = valtime.min()
    maxdate = valtime.max()
    if pd.isna(mindate) or pd.isna(maxdate):
        raise ValueError(f'No valid validity time found in {shapefile}!')
    half_span = pd.Timedelta(days=(maxdate - mindate).days // 2)
    centerdate = mindate + half_span
    start = centerdate - pd.Timedelta(weeks=weeks_before)
    end = centerdate + pd.Timedelta(weeks=weeks_after)

    return start.strftime('%Y-%m-%d'), end.strftime('%Y-%m-%d')
