import torch
import os
import matplotlib.pyplot as plt
import rasterio
import numpy as np
import pandas as pd
import glob
from loguru import logger as log
import matplotlib
from fielddelineation.utils.prediction import window_generation
from vito_lot_delineation.evaluation.main import _compute_single, get_pr


# Field-size classes handed to ``_compute_single`` (``field_sizes=``): each
# class name maps to a (lower, upper) bound pair, listed from small to large.
# NOTE(review): the unit of the bounds is defined by the evaluation library
# (pixels vs. area) -- confirm before changing them.
FIELD_SIZES = {
    "small": (0, 40),
    "medium": (40, 100),
    "large": (100, 300),
    "extra large": (300, 1e9),
}


def _derive_eval_metrics(target, prediction, iou_thr=0.6):
    """Flag which ground-truth and predicted objects are considered found.

    The object matching itself is delegated to ``get_pr``; this function only
    thresholds the per-object scores it returns.

    :param target: ground-truth label raster, passed to ``get_pr`` as ``gt``.
    :param prediction: predicted label raster, passed to ``get_pr`` as ``pr``.
    :param iou_thr: minimum score (inclusive) for an object to count as found.
    :return: tuple ``(prediction_found, target_found, pr)`` where
        ``prediction_found = pr["precision"] >= iou_thr``,
        ``target_found = pr["recall"] >= iou_thr`` and ``pr`` is the raw
        result of ``get_pr``.
    """
    pr = get_pr(gt=target, pr=prediction)

    # A ground-truth object counts as found when its recall score reaches the
    # threshold; a predicted object when its precision score does.
    # NOTE(review): assumes get_pr returns one score per object, ordered like
    # pr["gt_ids"] / pr["pr_ids"] (that is how _evaluate indexes them) -- confirm.
    target_found = pr["recall"] >= iou_thr
    prediction_found = pr["precision"] >= iou_thr

    return prediction_found, target_found, pr


def _read_label_tensor(path, dtype_format, range_filter=None):
    """Read band 1 of the raster at ``path`` as a (cropped) tensor.

    The band is cropped to its top-left ``range_filter`` rows and columns
    (``None`` keeps the full extent), cast to ``dtype_format`` and wrapped in
    ``torch.Tensor``.
    """
    with rasterio.open(path) as src:
        band = src.read(1)
    # NOTE(review): ``torch.Tensor`` produces the default float dtype
    # (float32), which represents integer ids exactly only up to 2**24. It is
    # kept because ``_compute_single`` has always received such tensors here.
    return torch.Tensor(band[:range_filter, :range_filter].astype(dtype_format))


def _evaluate_full_extent(pr_dir, gt_dir, range_filter=None, dtype='int16'):
    """Evaluate a prediction raster against a ground-truth raster in one pass.

    Both rasters are read from band 1, optionally cropped, evaluated with
    ``_compute_single`` (using ``FIELD_SIZES``), and the evaluation output is
    aggregated to one value per evaluation parameter and field-size class.

    :param pr_dir: path to the prediction raster.
    :param gt_dir: path to the ground-truth raster.
    :param range_filter: if given, only the top-left ``range_filter`` x
        ``range_filter`` pixels are evaluated; ``None`` evaluates the full
        extent.
    :param dtype: ``'int16'`` or ``'int32'``; labels are cast to this type
        before being converted to tensors.
    :return: DataFrame with one row per evaluation parameter (``n_gt``,
        ``n_pr``, ``precision``, ``recall``) and one column per entry of the
        evaluation result; an empty DataFrame if the result has no rows.
    :raises ValueError: for an unsupported ``dtype`` or an evaluation
        parameter without an aggregation rule.
    """
    log.info('Start model evaluation')

    # How the values of each evaluation parameter are reduced to one number.
    # NOTE(review): 'sum' is implemented as the number of entries (``len``),
    # not their sum -- kept as-is; confirm this is intended for n_gt / n_pr.
    statistics = {
        'n_gt': 'sum',
        'n_pr': 'sum',
        'precision': 'mean',
        'recall': 'mean',
    }
    reducers = {
        'sum': lambda values: len(values.numpy()),
        'mean': lambda values: np.mean(values.numpy()),
    }

    dtype_formats = {'int16': np.int16, 'int32': np.int32}
    if dtype not in dtype_formats:
        raise ValueError('Datatype not supported for evaluation!!')
    dtype_format = dtype_formats[dtype]

    # One code path for both cases: slicing with ``None`` keeps the full axis.
    gt_torch = _read_label_tensor(gt_dir, dtype_format, range_filter)
    pred_torch = _read_label_tensor(pr_dir, dtype_format, range_filter)

    result = _compute_single(
        gt=gt_torch.cpu(), pr=pred_torch.cpu(), field_sizes=FIELD_SIZES)

    # NOTE(review): assumes ``result`` maps field-size class -> {parameter ->
    # tensor}, so rows are parameters and columns field-size classes -- confirm.
    df_result = pd.DataFrame.from_dict(result)

    lst_df_stats = []
    for param_name, per_size in df_result.iterrows():
        reducer = reducers.get(statistics.get(param_name))
        if reducer is None:
            raise ValueError(
                f'Statistical operation not defined for {param_name}')

        # ``Series.items`` replaces ``iteritems``, which pandas 2.0 removed.
        df_stats_eval_param = pd.DataFrame.from_dict(
            {field_size: [reducer(stats)]
             for field_size, stats in per_size.items()})
        df_stats_eval_param.index = [param_name]
        lst_df_stats.append(df_stats_eval_param)

    if not lst_df_stats:
        # pd.concat raises on an empty list.
        return pd.DataFrame()

    # Summary DataFrame with the evaluation metrics per field size.
    return pd.concat(lst_df_stats)


def _apply_window_evaluation(window_id, target_dir, pred_dir,
                             outfolder, window, pred_context):
    """Evaluate a single raster window, skipping windows without content.

    Band 1 of the target and prediction rasters is read for ``window``. If
    either band holds fewer than two distinct values, the window is only
    logged. Otherwise both bands are cast to int32, wrapped in
    ``torch.Tensor`` and handed to ``_evaluate``.

    :param window_id: id of the window, used for logging and output names.
    :param target_dir: path to the ground-truth raster.
    :param pred_dir: path to the prediction raster.
    :param outfolder: root output folder forwarded to ``_evaluate``.
    :param window: window forwarded to rasterio's ``read(window=...)``.
    :param pred_context: post-processing configuration forwarded to
        ``_evaluate``.
    :return: always ``None``.
    """
    with rasterio.open(target_dir, 'r') as target_src, \
            rasterio.open(pred_dir, 'r') as pred_src:
        target_band = target_src.read(1, window=window)
        pred_band = pred_src.read(1, window=window)

    # Evaluation (and plotting) only makes sense when the target has fields
    # and there is at least one prediction, i.e. both bands contain at least
    # two distinct values.
    evaluable = (np.unique(target_band.data).size >= 2
                 and np.unique(pred_band.data).size >= 2)
    if not evaluable:
        log.info(f'No fields covering patch {window_id} --> skip evaluation')
        return

    _evaluate(torch.Tensor(pred_band.astype(np.int32)),
              torch.Tensor(target_band.astype(np.int32)),
              outfolder, window_id, pred_context)


def _evaluate_window(patch_id, target_dir, pred_dir, outfolder,
                     windowsize, pred_context, valid_windows=None,
                     sql=None):
    """Evaluate a prediction raster window by window and summarise the result.

    Each window is processed by ``_apply_window_evaluation``, sequentially or,
    when ``sql`` is given, on Spark executors. Afterwards every readable
    per-window accuracy CSV of this patch is averaged column-wise into
    ``<outfolder>/csv/<method>/<patch_id>_summary_acc.csv``. ``<method>`` is
    ``radix`` when ``pred_context['postproc_method'] == 'radix'`` and
    ``felzenswalb`` otherwise.

    :param patch_id: patch identifier; prefix of every window id and of the
        per-window output file names.
    :param target_dir: path to the ground-truth raster.
    :param pred_dir: path to the prediction raster.
    :param outfolder: root folder of the evaluation outputs.
    :param windowsize: window size forwarded to ``window_generation``.
    :param pred_context: post-processing configuration; ``'postproc_method'``
        selects the output sub-folder.
    :param valid_windows: optional explicit windows, each indexable as
        ``((x0, x1), (y0, y1))``. When ``None``, windows are generated by
        ``window_generation`` from the shape of ``pred_dir``.
    :param sql: optional Spark entry point exposing ``createDataFrame``
        (presumably a ``SparkSession``); ``None`` runs sequentially.
    """
    method_dir = ('radix' if pred_context.get('postproc_method') == 'radix'
                  else 'felzenswalb')
    outfolder_patches = os.path.join(outfolder, 'png', method_dir)
    outfolder_summary_acc = os.path.join(outfolder, 'csv', method_dir)
    outfolder_acc = os.path.join(outfolder_summary_acc, 'per_window')

    if valid_windows is None:
        with rasterio.open(pred_dir) as src:
            xdim, ydim = src.shape
        windowlist = window_generation(xdim, ydim,
                                       windowsize, stride=0,
                                       force_match_grid=True)
    else:
        windowlist = valid_windows

    # Remove outputs of a previous run of this patch so the summary below only
    # aggregates the current run. The trailing '_' leaves patches whose id
    # merely starts with ``patch_id`` untouched. Escaping guards against glob
    # metacharacters in paths or ids.
    id_prefix = glob.escape(patch_id) + '_'
    stale_patterns = [
        os.path.join(glob.escape(outfolder_acc), f'{id_prefix}*.csv'),
        os.path.join(glob.escape(outfolder_patches), f'{id_prefix}*.png'),
        # Figures written by ``_evaluate`` carry this prefix.
        os.path.join(glob.escape(outfolder_patches),
                     f'VISUAL_CHECKS_PREDICTION_{id_prefix}*.png'),
    ]
    for pattern in stale_patterns:
        for item in glob.glob(pattern):
            os.unlink(item)

    log.info((f'A total of {len(windowlist)} windows have'
              ' been defined in this patch for evaluation ...'))

    # Unique window ids built from the patch id (e.g. the S2 tile id) and the
    # window position itself.
    window_ids = ['_'.join([patch_id, str(window[0][0]),
                            str(window[0][1]), str(window[1][0]),
                            str(window[1][1])]) for window in windowlist]

    df = pd.DataFrame({
        'window_id': window_ids,
        'x0': [window[0][0] for window in windowlist],
        'x1': [window[0][1] for window in windowlist],
        'y0': [window[1][0] for window in windowlist],
        'y1': [window[1][1] for window in windowlist],
    })
    df['id'] = patch_id

    if df.empty:
        # Spark cannot repartition into 0 partitions, and there is nothing to do.
        log.warning(f'No windows to evaluate for patch {patch_id}')
    elif sql is None:
        for i, row in df.iterrows():
            log.info(f'Processing window {str(i)} out of {str(df.shape[0])}')

            _apply_window_evaluation(row.window_id, target_dir, pred_dir,
                                     outfolder, ((row.x0, row.x1),
                                                 (row.y0, row.y1)),
                                     pred_context)
    else:
        log.info('Creating spark dataframe ...')
        df_spark = sql.createDataFrame(df).persist()
        try:
            log.info('Start processing patch prediction on the executors ...')
            # collectAsMap() only triggers the distributed job:
            # _apply_window_evaluation returns None, and the per-window
            # results are the files it writes.
            df_spark.repartition(len(df)) \
                .rdd.map(lambda row: (row.window_id,
                                      _apply_window_evaluation(
                                          row.window_id,
                                          target_dir,
                                          pred_dir,
                                          outfolder,
                                          ((row.x0, row.x1),
                                           (row.y0, row.y1)),
                                          pred_context,
                                      ))).collectAsMap()
        finally:
            df_spark.unpersist()

    # Concatenate the accuracy per window into one summary file.
    outname_summary_acc = f'{patch_id}_summary_acc.csv'
    files_acc = glob.glob(os.path.join(glob.escape(outfolder_acc),
                                       f'{id_prefix}*.csv'))

    # For safety, skip files that cannot be parsed (observed with Spark --
    # root cause still to be investigated). Each file is read only once.
    summary_frames = []
    for file in files_acc:
        try:
            summary_frames.append(pd.read_csv(file))
        except Exception as e:
            log.warning(
                f'Evaluation file: {file} could not be opened correctly due to {e}')

    if summary_frames:
        df_summary = pd.concat(summary_frames).reset_index(drop=True)
        df_summary = df_summary.mean(axis=0)
        df_summary.to_csv(os.path.join(outfolder_summary_acc,
                                       outname_summary_acc),
                          index=True)


def _relabel_consecutive(labels: np.ndarray) -> np.ndarray:
    """Map the distinct values of ``labels`` onto 1..K, keeping their order.

    Each distinct value gets its own label, unlike sequential in-place
    replacement, which merges labels once a new label equals a value that has
    not been processed yet.
    """
    _, inverse = np.unique(labels, return_inverse=True)
    # Depending on the NumPy version ``inverse`` is flat or input-shaped;
    # reshaping handles both.
    return inverse.reshape(labels.shape) + 1


def _mark_found(labels: torch.Tensor, ids, found) -> torch.Tensor:
    """Return a copy of ``labels`` with 1 for found objects and 2 otherwise.

    All pixels with a label > 0 start as 2. Every object ``ids[i]`` whose
    flag ``found[i]`` is truthy is then set to 1. Pixels with a label <= 0
    are left unchanged.
    """
    marked = labels.detach().clone()
    marked[marked > 0] = 2
    for i, obj_id in enumerate(ids):
        marked[labels == obj_id] = 1 if found[i] else 2
    return marked


def _percent_found(marked: torch.Tensor) -> np.ndarray:
    """Percentage of marked pixels (value 1 or 2) that belong to found objects."""
    n_correct, n_wrong = (marked == 1).sum(), (marked == 2).sum()
    return (100 * n_correct / (n_correct + n_wrong + 1e-5)).numpy()


def _show_found_status(ax, marked: torch.Tensor, title: str):
    """Draw a found-status raster from ``_mark_found`` with fixed colours."""
    # Fixed colour per value: 0 purple, 1 (found) green, 2 (not found) red;
    # values <= 0 clip to purple. Choosing a colormap from the number of
    # distinct values would paint found objects red when all are found.
    cmap = matplotlib.colors.ListedColormap(['purple', 'green', 'red'])
    ax.imshow(marked.numpy(), cmap=cmap, vmin=0, vmax=2)
    ax.set_title(title)


def _evaluate(pred: torch.Tensor, target: torch.Tensor,
              outfolder: str, window_id: str, pred_context: dict,
              iou_thr: float = 0.5):
    """Evaluate one window's prediction against its target and save the results.

    Writes a 2x2 diagnostic figure to
    ``<outfolder>/png/<method>/VISUAL_CHECKS_PREDICTION_<window_id>.png``. It
    also writes a one-row CSV to
    ``<outfolder>/csv/<method>/per_window/<window_id>_accuracy.csv`` with
    columns ``precision``, ``recall``, ``Pred_correct`` and ``Targ_correct``.
    ``<method>`` is ``radix`` when ``pred_context['postproc_method'] ==
    'radix'`` and ``felzenswalb`` otherwise.

    :param pred: predicted label raster (one value per object).
    :param target: ground-truth label raster (one value per field).
    :param outfolder: root folder in which the evaluation results are written.
    :param window_id: unique id of the evaluated window, used in file names.
    :param pred_context: post-processing configuration; only
        ``'postproc_method'`` is read.
    :param iou_thr: score threshold (inclusive) for an object to count as
        found.
    :return: ``None``. Returns early, after logging a warning with the
        traceback, when the metric computation fails.
    """
    method_dir = ('radix' if pred_context.get('postproc_method') == 'radix'
                  else 'felzenswalb')
    outfolder_patches = os.path.join(outfolder, 'png', method_dir)
    outfolder_acc = os.path.join(outfolder, 'csv', method_dir, 'per_window')
    os.makedirs(outfolder_patches, exist_ok=True)
    os.makedirs(outfolder_acc, exist_ok=True)

    try:
        prediction_found, target_found, pr = _derive_eval_metrics(
            target, pred, iou_thr=iou_thr)
    except Exception:
        # TODO: failures inside the metric computation are not understood yet;
        # skip the window but keep the traceback in the log.
        log.opt(exception=True).warning(
            f'SKIPPED EVALUATION FOR WINDOW: {window_id}')
        return

    precision = np.mean(pr['precision'].numpy().flatten())
    recall = np.mean(pr['recall'].numpy().flatten())

    target_marked = _mark_found(target, pr['gt_ids'], target_found)
    prediction_marked = _mark_found(pred, pr['pr_ids'], prediction_found)
    # Pixel-based share of labelled pixels that belong to found objects.
    target_correct = _percent_found(target_marked)
    prediction_correct = _percent_found(prediction_marked)

    plt.ioff()
    fig, axs = plt.subplots(2, 2)
    try:
        axs[0, 0].imshow(_relabel_consecutive(target.numpy()), cmap="plasma")
        axs[0, 0].set_title("Target")
        _show_found_status(axs[0, 1], target_marked,
                           f"Target (correct: {str(target_correct)}%)")

        axs[1, 0].imshow(_relabel_consecutive(pred.numpy()), cmap="plasma")
        axs[1, 0].set_title("Prediction")
        _show_found_status(axs[1, 1], prediction_marked,
                           f"Prediction (correct: {str(prediction_correct)}%)")

        for axi in axs.ravel():
            axi.set_axis_off()

        fig.suptitle(f'iou_thr: {iou_thr}')
        fig.tight_layout()
        # Save through the figure handle and without plt.show(): on inline
        # backends show() closes the figure, so a later plt.savefig() writes
        # an empty image.
        fig.savefig(os.path.join(outfolder_patches,
                                 f'VISUAL_CHECKS_PREDICTION_{window_id}.png'))
    finally:
        plt.close(fig)

    # Save the metrics to a CSV file (one row per window).
    df_acc = pd.DataFrame([{
        'precision': float(precision),
        'recall': float(recall),
        'Pred_correct': float(prediction_correct),
        'Targ_correct': float(target_correct),
    }])
    df_acc.to_csv(os.path.join(outfolder_acc, f'{window_id}_accuracy.csv'),
                  index=False)
