import glob
from pathlib import Path
from loguru import logger
import tensorflow as tf
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
from shapely.geometry import Point
from sklearn.utils import class_weight


from worldcereal.utils.scalers import minmaxscaler, scale_df
from worldcereal.classification.weights import (get_refid_weight,
                                                load_refidweights)
from worldcereal.utils import aez
from cropclass import get_lut_ct
from worldcereal.utils.training import (select_within_aez,
                                        do_smote,
                                        NotEnoughSamplesError)

# Loaded once at import time; get_sample_weight() passes this object to
# get_refid_weight() as its `refidweights` argument.
REFIDWEIGHTS = load_refidweights()


def get_label(row, output, lut):
    """Translate the value stored in ``row[output]`` through a lookup table.

    Args:
        row: mapping-like record (e.g. a DataFrame row) indexed by ``output``.
        output: key/column name whose value has to be translated.
        lut: lookup table offering ``.get``; maps raw values to labels.

    Returns:
        The label found in ``lut`` or 0 when the raw value is not in the LUT.
    """
    raw_value = row[output]
    return lut.get(raw_value, 0)


def get_sample_weight(sample, options):
    """Compute the training weight of one sample.

    The base weight is 0 when ``sample['OUTPUT']`` equals 0 and 1 otherwise.
    It is then multiplied by the ref_id weighting factor (as returned by
    ``get_refid_weight`` for the 'LC' label type) divided by 100.

    Args:
        sample: record providing at least the 'OUTPUT' and 'ref_id' entries.
        options: currently not used by this function.

    Returns:
        float: the resulting sample weight.
    """
    base_weight = 0 if sample['OUTPUT'] == 0 else 1

    # Scaling factor attached to the dataset (ref_id) the sample comes from
    refid_factor = get_refid_weight(sample['ref_id'], 'LC',
                                    refidweights=REFIDWEIGHTS)

    return float(base_weight * (refid_factor / 100.))


def get_trainingdata(df, inputfeatures, season, options, aez_id=None,
                     aez_group=None, minsamples=500,
                     logdir=None, worldcover=False, outliers=False,
                     outlierinputs=None, buffer=500000,
                     scale_features=True, impute_missing=True,
                     return_pandas=False, outlierfraction=0.10):
    """Function that returns inputs/outputs from training DataFrame

    Args:
        df (pd.DataFrame): dataframe containing input/output features
        inputfeatures (list[str]): list of inputfeatures to be used
        season (str): season identifier (not used inside this function)
        options (dict): dictionary containing options. The keys read here
            are 'output' (column holding the raw label), 'lut' (lookup
            table translating raw labels) and 'classweights' (whether to
            apply balanced class weights).
        aez_id (int, optional): ID of the AEZ to subset on. Defaults to None.
        aez_group (int, optional): AEZ group ID to subset on. Defaults to None.
        minsamples (int, optional): minimum number of samples. Defaults to 500.
        buffer (int, optional): the buffer (in m) to take around AEZs before
            selecting matching location_ids. Defaults to 500000 (500km)
        logdir (str, optional): output path for logs/figures. The directory
            is created when it does not exist yet. Defaults to None.
        worldcover (bool, optional): whether or not to remove outliers based on
            worldcover data. Defaults to False.
        outliers (bool, optional): whether or not to remove outliers
            per label and per ref_id based on CBLOF (pyod implementation).
            Defaults to False.
        outlierinputs (list, optional): in case outliers need to be removed,
            this list specifies which input variables need to be used here.
            Defaults to None.
        scale_features (bool, optional): If True, input features are scaled.
        impute_missing (bool, optional): If True, impute NaN by 0
        return_pandas (bool, optional): If True, the inputs are returned as
            a pd.DataFrame instead of a numpy array. Defaults to False.
        outlierfraction (float, optional): contamination fraction passed to
            the CBLOF outlier detector. Defaults to 0.10.

    Raises:
        NotEnoughSamplesError: when the dataset holds less than
            2 * minsamples rows, or when nothing is left for outlier
            detection.
        ValueError: when both `aez_id` and `aez_group` are set, or when
            `outliers` is requested without `outlierinputs`.

    Returns:
        list: [inputs, outputs, weights, ref_id_counts] where the first three
            are the arrays to use in training and `ref_id_counts` is a dict
            with the number of retained samples per ref_id.
    """

    # Check if we have enough samples at all to start with
    # that's at least 2X minsamples (binary classification)
    if df.shape[0] < 2 * minsamples:
        errormessage = (f'Got less than {2 * minsamples} '
                        f'in total for this dataset. '
                        'Cannot continue!')
        logger.error(errormessage)
        raise NotEnoughSamplesError(errormessage)

    # Input clamping
    minclamp = -0.1  # Do not allow inputs below this value (None to disable)
    maxclamp = 1.1  # Do not allow inputs above this value (None to disable)

    if aez_id is not None and aez_group is not None:
        raise ValueError('Cannot set both `aez_id` and `aez_group`')

    # Select samples within (buffered) AEZ
    if aez_id is not None or aez_group is not None:
        df = select_within_aez(df, aez_id, aez_group, buffer)

    # Remove corrupt rows: rows with more than 30 null-or-zero cells
    remove_idx = ((df.isnull()) | (df == 0)).sum(axis=1) > 30
    df = df[~remove_idx].copy()

    # Translate LC/CT to output values
    df['LABEL'] = df.apply(lambda row: get_label(
        row, options['output'], options['lut']), axis=1)

    # Remove the unknown labels
    remove_idx = (df['LABEL'] == 0)
    logger.info(f'Removed {remove_idx.sum()} unknown/ignore samples.')
    df = df[~remove_idx].copy()

    # Get rid of urban/bare samples according to worldcover
    if worldcover:
        agri_labels = [10, 11, 12]
        ewoc_ignore = [50, 60]
        # NOTE(review): this filter reads the fixed 'OUTPUT' column rather
        # than options['output'] -- confirm this is intended.
        remove_idx = ((df['OUTPUT'].isin(agri_labels)) &
                      (df['WORLDCOVER-LABEL-10m'].isin(ewoc_ignore)))

        df = df[~remove_idx].copy()
        logger.info(f'Removed {remove_idx.sum()} bare/urban samples'
                    f' according to worldcover information.')

    # get rid of outliers (per label and per ref_id)
    if outliers:
        if outlierinputs is None:
            raise ValueError('No outlier inputs provided!')
        from pyod.models.cblof import CBLOF

        # split the dataset and run outlier removal on each part separately
        nremoved = 0
        dfs = []
        logger.info(f'Obs before OD: {df.shape[0]}')
        if df.shape[0] > 0:
            unique_labels = df[options['output']].unique()
            ref_ids = df['ref_id'].unique()
            for label in unique_labels:
                logger.info(f'Removing outliers for label: {label}')
                for ref_id in ref_ids:
                    dftoclean = df[(df[options['output']] == label) &
                                   (df['ref_id'] == ref_id)]
                    if dftoclean.shape[0] < 30:
                        # Not enough samples for this label: skip
                        # OD routine
                        dfs.append(dftoclean)
                        continue
                    # get the variables used for outlier removal
                    outlier_x = dftoclean[outlierinputs].values
                    # Get rid of any existing NaN values
                    outlier_x[np.isnan(outlier_x)] = 0
                    # fit the model
                    clf = CBLOF(contamination=outlierfraction, alpha=0.75)
                    clf.fit(outlier_x)
                    # get the prediction labels
                    y_train_pred = clf.labels_  # binary labels (0: inliers, 1: outliers)  # NOQA
                    nremoved += y_train_pred.sum()
                    retain_idx = np.where(y_train_pred == 0)
                    dfclean = dftoclean.iloc[retain_idx]
                    dfs.append(dfclean)
        # merge all cleaned dataframes
        if len(dfs) == 0:
            raise NotEnoughSamplesError('Not enough samples to perform OD!')
        df = pd.concat(dfs)
        logger.info(f'Removed {nremoved} outliers')
        logger.info(f'Obs after OD: {df.shape[0]}')

    # Scale the input data
    if scale_features:
        df[inputfeatures] = scale_df(df[inputfeatures],
                                     clamp=(minclamp, maxclamp),
                                     nodata=0,
                                     )

    # Get the inputs
    inputs = df[inputfeatures].values

    # Get the weights
    weights = df.apply(lambda row: get_sample_weight(
        row, options), axis=1).values

    # Make sure the log directory exists BEFORE anything is written to it.
    # Previously the directory was only created right before saving the
    # figure, i.e. after the first CSV had already been written.
    logpath = None
    if logdir is not None:
        logpath = Path(logdir)
        logpath.mkdir(parents=True, exist_ok=True)

    # Log the number of samples still present
    # per ref_id
    ref_id_counts = (df.groupby('ref_id')[
        'LABEL'].value_counts().unstack().fillna(0).astype(int))
    if logpath is not None:
        ref_id_counts.to_csv(logpath / 'sample_counts.csv')
    ref_id_counts = ref_id_counts.sum(axis=1).to_dict()

    # Get the outputs AFTER the weights
    outputs = df['LABEL'].copy()

    if logpath is not None:
        # Plot histogram of original outputs
        counts = outputs.value_counts()
        labels = counts.index.astype(int)
        plt.bar(range(len(labels)), counts.values)
        plt.xticks(range(len(labels)), labels, rotation=90)
        plt.ylabel('Amounts')
        plt.title('Output label distribution')
        plt.savefig(logpath / 'output_distribution.png')
        plt.close()

    idx = np.where(np.sum(inputs, axis=1) == 0)
    logger.info(f'#Rows with all-zero inputs: {len(idx[0])}')

    # Compute class weights
    if options['classweights']:
        unique_classes = list(np.unique(outputs))
        class_weights = class_weight.compute_class_weight(
            class_weight='balanced', classes=unique_classes,
            y=outputs)
        class_weights = dict(zip(unique_classes, class_weights))
        logger.info(f'Inferred class weights: {class_weights}')
        if logpath is not None:
            pd.DataFrame.from_dict(class_weights, orient='index',
                                   columns=['class_weight'],
                                   dtype='float').to_csv(
                logpath / 'class_weights.csv'
            )
        for label in unique_classes:
            # Adjust sample weight by class weight
            weights[outputs == label] *= class_weights.get(label, 1)

    # Make sure all NaNs are gone!
    if impute_missing:
        if np.sum(np.isnan(inputs)) > 0:
            logger.warning(f'Removing {np.sum(np.isnan(inputs))} NaN values!')
        inputs[np.isnan(inputs)] = 0

    if return_pandas:
        logger.info('Transforming inputs to pandas.DataFrame ...')
        inputs = pd.DataFrame(data=inputs, columns=inputfeatures)

    return [inputs, outputs.values, weights, ref_id_counts]


def get_pixel_data(season, options, inputfeatures,
                   aez_id=None, aez_group=None, buffer=500000,
                   logdir=None, outlierinputs=None, minsamples=500,
                   **kwargs):
    """Load the cal/val/test training dataframes and prepare them.

    For each of the 'cal', 'val' and 'test' splits, the file
    `training_df_LC.parquet` is read from every folder listed in
    `options['trainingfiles_<split>']`, the frames are concatenated,
    optionally filtered on ref_id, and passed to `get_trainingdata`.

    Args:
        season (str): season identifier, forwarded to `get_trainingdata`.
        options (dict): dictionary containing options. Keys read here are
            'trainingfiles_cal', 'trainingfiles_val', 'trainingfiles_test'
            (a single folder or a list/tuple of folders) and the optional
            'target_ref_ids', 'worldcover' and 'outliers'.
        inputfeatures (list[str]): list of inputfeatures to be used
        aez_id (int, optional): ID of the AEZ to subset on. Defaults to None.
        aez_group (int, optional): AEZ group ID to subset on. Defaults to None.
        buffer (int, optional): buffer forwarded to `get_trainingdata`.
            Defaults to 500000.
        logdir (str, optional): output path for logs/figures. The same path
            is passed for all three splits. Defaults to None.
        outlierinputs (list, optional): input variables used for outlier
            removal. Defaults to None.
        minsamples (int, optional): minimum number of samples. Defaults to 500.
        **kwargs: extra keyword arguments forwarded to `get_trainingdata`.

    Returns:
        tuple: the `get_trainingdata` results for cal, val and test
            (in that order).
    """

    results = {}
    alltypes = ['cal', 'val', 'test']

    for currenttype in alltypes:

        # Accept either a single folder or a collection of folders
        trainingfiles = options[f'trainingfiles_{currenttype}']
        if not isinstance(trainingfiles, (list, tuple)):
            trainingfiles = [trainingfiles]

        # Read everything first and concatenate once, instead of growing
        # a DataFrame inside the loop
        frames = [pd.read_parquet(Path(folder) / 'training_df_LC.parquet')
                  for folder in trainingfiles]
        df_data = pd.concat(frames) if frames else pd.DataFrame()

        target_ref_ids = options.get('target_ref_ids')
        if target_ref_ids is not None:
            logger.info('Filtering for ref_ids ...')
            df_data = df_data.loc[df_data.ref_id.isin(target_ref_ids)]

        # Hard-coded exclusion of every ref_id containing 'Andalucia'
        df_data = df_data.loc[['Andalucia' not in x for x in df_data.ref_id]]  # NOQA

        results[currenttype] = get_trainingdata(
            df_data, inputfeatures, season,
            options,
            aez_id=aez_id, aez_group=aez_group,
            logdir=logdir,
            worldcover=options.get('worldcover', False),
            outliers=options.get('outliers', False),
            outlierinputs=outlierinputs,
            buffer=buffer,
            minsamples=minsamples,
            **kwargs)

        logger.info((f'Nr of {currenttype} samples: '
                     f'{results[currenttype][0].shape[0]}'))

    return results['cal'], results['val'], results['test']
