from loguru import logger
from pathlib import Path
import pandas as pd
from typing import List
import re
import ast
import numpy as np

from satio.rsindices import (anir, norm_diff, vh_vv)

# Directory path on the '/vitodata' share.
# NOTE(review): not referenced anywhere in this part of the file; going by
# its name it presumably holds the `classes_file` inputs consumed by
# `select_training_classes` - confirm against the callers.
TRAINING_CLASSES_DIR = Path(
    '/vitodata/EEA_HRL_VLCC/data/training/crop_type/CT_training_classes')


def identity(arr):
    """Return `arr` unchanged.

    Serves as the `formula` of the single-band entries in `INDICES`, where
    the "index" is just the band itself.
    """
    return arr


# Definitions of the features that `remove_outliers` can compute.
#
# Each key is an output column name template made of four dash-separated
# tokens; `compute_index` keeps the first two and the last token and puts
# the timestep token of the input columns in place of 'xxx'.
# Each value holds:
#   'bands':   substrings used to select the input columns of the dataframe,
#              one entry per argument of the formula (same order)
#   'formula': callable that receives one array per band and returns the
#              values of the new column
INDICES = {
    'S2-ANIR-xxx-10m': {
        'bands': ['S2-B04', 'S2-B08', 'S2-B11'],
        'formula': anir},
    'S2-NDGI-xxx-10m': {
        'bands': ['S2-B03', 'S2-B04'],
        'formula': norm_diff},
    'S2-NDWI-xxx-10m': {
        'bands': ['S2-B08', 'S2-B11'],
        'formula': norm_diff},
    'S2-NDRE1-xxx-10m': {
        'bands': ['S2-B08', 'S2-B05'],
        'formula': norm_diff},
    'S2-NDRE5-xxx-10m': {
        'bands': ['S2-B07', 'S2-B05'],
        'formula': norm_diff},
    'S1-VH_VV-xxx-10m': {
        'bands': ['S1-VH', 'S1-VV'],
        'formula': vh_vv},
    'S2-NDVI-xxx-10m': {
        'bands': ['S2-B08', 'S2-B04'],
        'formula': norm_diff},
    # single-band entries: the band is passed through unchanged
    'S2-B12-xxx-20m': {
        'bands': ['S2-B12'],
        'formula': identity},
    'S1-VV-xxx-20m': {
        'bands': ['S1-VV'],
        'formula': identity},
    'S1-VH-xxx-20m': {
        'bands': ['S1-VH'],
        'formula': identity}
}


def labels_to_str(labels):
    """Format numeric labels as dash-separated code strings.

    Every label is cast to `int`, rendered as a string and cut into the
    slices `[0:2]`, `[2:4]`, `[4:6]`, `[6:9]` plus its last character,
    which are then joined with '-'.
    Example: `1101010010` -> `'11-01-01-001-0'`.

    Returns a list with one formatted string per input label.
    """
    formatted = []
    for label in labels:
        digits = str(int(label))
        groups = (digits[:2], digits[2:4], digits[4:6], digits[6:9],
                  digits[-1])
        formatted.append('-'.join(groups))
    return formatted


def translate_timestamp(column: str, shift: int, excluded_list: List[str]):
    """Shift the timestep number in a '...-ts<N>-10m' column name.

    Parameters
    ----------
    column : name of the column, e.g. 'S2-B08-ts5-10m'
    shift : value added to the timestep number (may be negative)
    excluded_list : column names that are returned untouched

    Returns
    -------
    The column name with `<N>` replaced by `<N> + shift`, or `column`
    itself when it is listed in `excluded_list`.

    Raises
    ------
    IndexError
        If `column` is not excluded and does not end in '-ts<N>-10m'.
    """
    if column in excluded_list:
        return column

    match = re.match(r'^(.*-ts)([0-9]+)(-10m)$', column)
    if match is None:
        # same exception type as before, but with an explanation
        raise IndexError(
            f"Column '{column}' does not end in '-ts<N>-10m'")

    # Rebuild the name from the matched parts so that only the trailing
    # timestep token changes; a plain str.replace would also rewrite any
    # earlier 'ts<N>-' occurrence inside the name.
    prefix, timestamp, suffix = match.groups()
    return f'{prefix}{int(timestamp) + shift}{suffix}'


def select_tsteps(df, month_start, month_end):
    """Keep the timesteps of a range of months and renumber them from 0.

    Three timesteps (dekads) are assumed per month, numbered from 0, so
    month `m` covers timesteps `3 * (m - 1)` up to `3 * (m - 1) + 2`.

    Columns whose name ends in 'ts<N>-10m' with `<N>` inside the requested
    range are kept and renamed so that the first selected timestep becomes
    'ts0'. Columns without '-ts' in their name are kept unchanged. All
    other columns are dropped.

    Parameters
    ----------
    df : dataframe to select columns from
    month_start : first month to keep (1-based)
    month_end : last month to keep (1-based, inclusive)

    Returns
    -------
    A dataframe holding only the selected, renamed columns.
    """
    logger.info('Selecting tsteps...')

    # first dekad of the start month up to and including the last dekad
    # of the end month
    first_tstep = 3 * (month_start - 1)
    stop_tstep = 3 * (month_end - 1) + 3

    wanted_suffixes = tuple(
        f'ts{tstep}-10m' for tstep in range(first_tstep, stop_tstep))
    static_columns = [name for name in df.columns if '-ts' not in name]

    # old name -> new name for every column that has to be kept
    column_map = {}
    for name in df.columns:
        if name.endswith(wanted_suffixes) or name in static_columns:
            column_map[name] = translate_timestamp(
                name, -first_tstep, static_columns)

    df = df[df.columns.intersection(column_map.keys())]
    df = df.rename(column_map, axis=1)

    logger.info('Tsteps selected!')

    return df


def code_to_uint16(string_code):
    """Convert a dash-separated code string into a `np.uint16` value.

    All '-' characters are removed and the remaining text is parsed as a
    single number, e.g. `'1-1-1-0'` -> `1110`.
    """
    digits = string_code.replace('-', '')
    return np.uint16(digits)


def translate_labels(df, translation_file, attribute='LABEL'):
    """Map the labels in `df[attribute]` to HRL codes and names.

    The translation file is a ';'-separated CSV with (at least) the
    columns 'HRL_code', 'Name' and 'Internal_codes'. 'Internal_codes' holds
    the text of a Python list of regular expressions; rows without it are
    ignored. A label is mapped to an HRL code when one of the code's
    expressions matches the start of the label (`re.match`). When several
    expressions match, the one processed last wins.

    Parameters
    ----------
    df : dataframe with the labels; the columns 'HRL_code' and 'HRL_name'
        are added to it IN PLACE (0 / '' for rows without a match)
    translation_file : path or buffer accepted by `pd.read_csv`
    attribute : name of the column holding the (string) labels

    Returns
    -------
    The same dataframe object, with the two columns filled in.
    """
    logger.info('Start translation of CT labels...')
    logger.info(f'Reading translation key from: {translation_file}')
    legend_df = pd.read_csv(translation_file, sep=';')

    # get rid of all entries not having Internal_codes
    legend_df = legend_df.dropna(subset=['Internal_codes'])
    # get rid of redundant columns
    legend_df = legend_df.loc[
        :, ~legend_df.columns.str.startswith('Unnamed:')]

    logger.info('Construct mapping dict')
    # HRL code -> list of regular expressions / HRL code -> name
    lut = {hrl_code: ast.literal_eval(patterns)
           for hrl_code, patterns in zip(legend_df['HRL_code'],
                                         legend_df['Internal_codes'])}
    hrl_names = dict(zip(legend_df['HRL_code'], legend_df['Name']))

    # adding two new columns to df
    df['HRL_code'] = [0] * len(df)
    df['HRL_name'] = [''] * len(df)

    logger.info('Execute mapping')
    labels = df[attribute]
    # the expressions only need to be tried once per distinct label
    unique_labels = list(labels.unique())
    for hrl_code, patterns in lut.items():
        for pattern in patterns:
            regex = re.compile(pattern)
            matched = [label for label in unique_labels
                       if regex.match(label)]
            # Boolean mask instead of index labels: with duplicate index
            # labels, label-based selection would also overwrite rows
            # that did not match.
            mask = labels.isin(matched)
            df.loc[mask, 'HRL_code'] = code_to_uint16(hrl_code)
            df.loc[mask, 'HRL_name'] = hrl_names[hrl_code]

    logger.info('Mapping done!')

    unmapped = int((df['HRL_code'] == 0).sum())
    if unmapped > 0:
        logger.info(f'{unmapped} rows containing unmapped code')

    return df


def select_training_classes(df, classes_file):
    """Assign training classes to the samples and drop unused samples.

    The classes file is a ';'-separated CSV with (at least) the columns
    'HRL_code', 'Name' and 'Codes'. 'Codes' holds the text of a Python list
    with the values of `df['HRL_code']` that belong to the training class;
    rows without it are ignored. The training code of a class is its
    'HRL_code' entry converted with `code_to_uint16`.

    Parameters
    ----------
    df : dataframe with an 'HRL_code' column; the columns 'train_code' and
        'train_name' are added to it IN PLACE
    classes_file : path or buffer accepted by `pd.read_csv`

    Returns
    -------
    The rows of `df` that were assigned a training class
    (`train_code != 0`).
    """
    logger.info('Start selection of training classes...')
    logger.info(f'Reading classes file: {classes_file}')
    classes_df = pd.read_csv(classes_file, sep=';')

    # get rid of all entries not having Codes
    classes_df = classes_df.dropna(subset=['Codes'])
    # get rid of redundant columns
    classes_df = classes_df.loc[
        :, ~classes_df.columns.str.startswith('Unnamed:')]
    logger.info(f'File contains {len(classes_df)} unique training classes')

    logger.info('Construct mapping dict')
    # training class -> list of HRL codes / training class -> name
    lut = {train_class: ast.literal_eval(codes)
           for train_class, codes in zip(classes_df['HRL_code'],
                                         classes_df['Codes'])}
    names = dict(zip(classes_df['HRL_code'], classes_df['Name']))

    # adding two new columns to df
    df['train_code'] = [0] * len(df)
    df['train_name'] = [''] * len(df)

    logger.info('Execute mapping')
    for train_class, hrl_codes in lut.items():
        # Boolean mask instead of index labels: with duplicate index
        # labels, label-based selection would also overwrite rows that
        # did not match.
        mask = df['HRL_code'].isin(hrl_codes)
        df.loc[mask, 'train_code'] = code_to_uint16(train_class)
        df.loc[mask, 'train_name'] = names[train_class]
    logger.info('Mapping done!')

    # removing irrelevant samples
    unused = int((df['train_code'] == 0).sum())
    logger.info(f'Deleting {unused} non-matching samples...')
    df = df.loc[df['train_code'] != 0]

    # Report on content of remaining dataframe
    labels_present = df['train_name'].unique()
    logger.info(f'{len(labels_present)} unique classes remaining:')
    logger.info(f'{labels_present}')
    logger.info(f'A total of {len(df)} samples remaining!')

    logger.info('Training classes selected')

    return df


def compute_index(df, bands, formula, outname):
    """Compute an index for every timestep and add it to `df` in place.

    For each entry of `bands`, all columns whose name contains that entry
    are collected and sorted alphabetically. The columns at the same
    position of these sorted lists are combined with `formula`.

    Parameters
    ----------
    df : dataframe holding the band columns; new columns are added to it
    bands : substrings identifying the input columns, one per argument of
        `formula` (same order)
    formula : callable taking one array per band and returning the values
        of the new column
    outname : dash-separated name template; the new columns are named
        '<token 1>-<token 2>-<ts>-<last token>', where `<ts>` is the
        second-to-last token of the corresponding column of the first band

    Returns
    -------
    The same dataframe object, extended with the computed columns.
    """
    # per band: sorted names of all columns containing the band name
    band_columns = [sorted(name for name in df.columns if band in name)
                    for band in bands]

    name_parts = outname.split('-')
    prefix = '-'.join(name_parts[:2])
    suffix = name_parts[-1]

    # the columns of the first band determine the timesteps
    for position, reference_col in enumerate(band_columns[0]):
        tstep = reference_col.split('-')[-2]
        inputs = [df[columns[position]].values for columns in band_columns]
        df[f'{prefix}-{tstep}-{suffix}'] = formula(*inputs)

    return df


def comp_percentiles_std(df, bands):
    """Add row-wise percentile and spread metrics for each band to `df`.

    For every entry of `bands`, all columns whose name contains that entry
    are combined per row into five metrics: the 90th, 10th and 50th
    percentile ('p90', 'p10', 'p50'), the difference p90 - p10 ('iqr') and
    the standard deviation ('std').
    The metric columns are named '<token 1>-<token 2>-<metric>-<last
    token>', with the tokens taken from the first matching column.

    Parameters
    ----------
    df : dataframe holding the band columns; metric columns are added to
        it in place
    bands : substrings identifying the columns of each band

    Returns
    -------
    Tuple of the same dataframe object and the list of metric column names
    (five per band, in the order p90, p10, p50, iqr, std).
    """
    metric_names = ('p90', 'p10', 'p50', 'iqr', 'std')
    cols_added = []

    for band in bands:
        source_cols = [name for name in df.columns if band in name]

        # derive the names of the metric columns from the first source
        tokens = source_cols[0].split('-')
        prefix, suffix = '-'.join(tokens[:2]), tokens[-1]
        target_cols = [f'{prefix}-{metric}-{suffix}'
                       for metric in metric_names]

        # one row per sample, one column per source column
        values = df[source_cols].values
        p90, p10, p50 = np.percentile(values, [90, 10, 50], axis=1)
        stats = (p90, p10, p50, p90 - p10, np.std(values, axis=1))
        for target, stat in zip(target_cols, stats):
            df[target] = stat

        cols_added.extend(target_cols)

    return df, cols_added


def prepare_filtering(df, filters):
    """Compute the derived features needed for training data filtering.

    Depending on the names of the filter criteria, NDVI timestep columns
    are computed (any name containing 'NDVI') and percentile metrics are
    added for NDVI and/or S1-VV (any name containing 'VV').

    Parameters
    ----------
    df : dataframe with the band columns; feature columns are added to it
    filters : dict mapping a group name to a dict of
        {filter name: threshold}

    Returns
    -------
    Tuple of the dataframe and the list of column names added here, so
    that the caller can remove them again afterwards.
    """
    # unique names of the filter criteria over all groups
    variables = {name for groupfilters in filters.values()
                 for name in groupfilters}

    # keep track of additional columns added to dataframe
    cols_added = []

    # bands for which percentiles need to be computed
    bands = []

    # check if NDVI needs to be computed
    if any('NDVI' in var for var in variables):
        logger.info('Computing NDVI...')
        existing_cols = set(df.columns)
        df = compute_index(df, ['S2-B08', 'S2-B04'], norm_diff,
                           'S2-NDVI-xxx-10m')
        bands.append('S2-NDVI-')
        # Only report columns created here: NDVI columns that were already
        # part of the input must not be removed by the caller.
        cols_added = [c for c in df.columns
                      if 'S2-NDVI-ts' in c and c not in existing_cols]

    # Check if S1-VV is included in filters. The band is added only once,
    # however many filter names mention it: a second pass would compute
    # percentiles over the metric columns added by the first one.
    if any('VV' in var for var in variables):
        bands.append('S1-VV-')

    logger.info('Computing percentiles...')
    df, perc_cols_added = comp_percentiles_std(df, bands)
    cols_added.extend(perc_cols_added)

    return df, cols_added


def filter_df(df, filters):
    """Apply rule-based outlier filtering to the training samples.

    Samples are split on the first character of their 'train_code':
    '1' -> annual, '2' -> perennial. Samples with another first character
    are not part of the result. Thresholds given in `filters['annuals']`
    and `filters['perennials']` are applied to the respective group using
    the features computed by `prepare_filtering`.

    Parameters
    ----------
    df : dataframe with a 'train_code' column and the required band
        columns
    filters : dict with the optional keys 'annuals' and 'perennials', each
        mapping filter names to thresholds

    Returns
    -------
    Dataframe with the remaining annual and perennial samples, indexed
    from 0, including the 'crop_cat' column but without the helper
    feature columns.
    """
    logger.info('Start rule-based outlier filtering...')

    logger.info('Computing required indices...')
    df, cols_added = prepare_filtering(df, filters)

    logger.info(f'Dataframe contains {len(df)} samples')

    logger.info('Splitting crops into annuals and perennials...')
    df['crop_cat'] = [''] * len(df)
    code_text = df['train_code'].astype(str)
    df.loc[code_text.str.startswith('1'), 'crop_cat'] = 'annual'
    df.loc[code_text.str.startswith('2'), 'crop_cat'] = 'perennial'
    df_annual = df.loc[df['crop_cat'] == 'annual']
    logger.info(f'{len(df_annual)} samples for annual cropland')
    df_perennial = df.loc[df['crop_cat'] == 'perennial']
    logger.info(f'{len(df_perennial)} samples for perennial cropland')

    # rules for annual crops, applied in this order:
    # (filter name, log message, feature column, threshold is lower bound)
    annual_rules = (
        ('NDVI-IQR', 'Filter: NDVI IQR should be reasonably large',
         'S2-NDVI-iqr-10m', True),
        ('NDVI-P90', 'Filter: NDVI P90 should not be too low',
         'S2-NDVI-p90-10m', True),
        ('NDVI-P10-1', 'Filter: NDVI PP10 should be low',
         'S2-NDVI-p10-10m', False),
        ('NDVI-P10-2', 'Filter: NDVI PP10 should not be extreme negative',
         'S2-NDVI-p10-10m', True),
        ('S1-VV', 'Filter: S1 VV should not be positive',
         'S1-VV-p90-10m', False),
    )

    filters_annual = filters.get('annuals', None)
    if filters_annual is not None:
        logger.info('Applying NDVI filters for ANNUAL crops...')
        for name, message, column, is_lower_bound in annual_rules:
            if name not in filters_annual:
                continue
            logger.info(message)
            threshold = filters_annual[name]
            if is_lower_bound:
                keep = df_annual[column] >= threshold
            else:
                keep = df_annual[column] <= threshold
            df_annual = df_annual.loc[keep]
            logger.info(f'{len(df_annual)} annual samples remaining')

    filters_perennial = filters.get('perennials', None)
    if filters_perennial is not None:
        logger.info('Applying filters for PERENNIAL crops...')
        if 'NDVI-P90' in filters_perennial:
            logger.info('Filter: NDVI PP90 should be not too low')
            keep = (df_perennial['S2-NDVI-p90-10m'] >=
                    filters_perennial['NDVI-P90'])
            df_perennial = df_perennial.loc[keep]
            logger.info(f'{len(df_perennial)} perennial samples remaining')

    logger.info('Merging annuals and perennials...')
    df = pd.concat([df_annual, df_perennial], axis=0)
    df.index = list(range(len(df)))
    logger.info(f'A total of {len(df)} samples remaining after filtering')

    # remove all additional columns that were added during this process
    df = df.drop(columns=cols_added)

    return df


def remove_outliers(df, inputs):
    """Prepare the features for outlier detection (detection itself is
    not implemented yet).

    For every entry of `inputs` the corresponding index is computed per
    timestep, followed by the percentile metrics of each index. As there
    is no detection step yet, the helper columns are removed again and all
    samples are returned.

    Parameters
    ----------
    df : dataframe with the band columns required by the chosen indices
    inputs : keys of `INDICES` to use as features

    Returns
    -------
    The same dataframe object, holding the columns it had on entry.

    Raises
    ------
    KeyError
        If an entry of `inputs` is not a key of `INDICES`.
    """
    logger.info('Computing needed features for outlier detection...')
    # Remember the input columns so that only columns created below get
    # removed at the end. Collecting them by name pattern would also catch
    # the band columns of the single-band indices (e.g. 'S1-VV-ts...').
    original_cols = set(df.columns)

    for index in inputs:
        # first compute indices
        params = INDICES[index]
        if params['formula'] is not None:
            df = compute_index(df, params['bands'], params['formula'],
                               index)
    # computing percentiles of all indices
    bands_percentiles = [
        '-'.join(i.split('-')[0:2]) + '-' for i in inputs]
    df, _ = comp_percentiles_std(df, bands_percentiles)

    logger.info('Start outlier detection...')
    logger.warning('Outlier detection not implemented yet!')

    logger.info('Removing added columns from dataframe...')
    cols_added = [c for c in df.columns if c not in original_cols]
    df.drop(columns=cols_added, inplace=True)

    return df
