"""Plotting functions."""

from __future__ import annotations

from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from vito_crop_classification.vito_logger import Logger


def plot_median_vs_others(
    df: pd.DataFrame,
    target: str,
    sensor: str,
    logger: Logger,
    vs_original_median: bool = False,
    min_ratio: float = 0.05,
) -> plt.Figure:
    """
    Compare median of samples correctly predicted as class target with misclassified predictions.

    This comparison can be with either:
        - the median of samples being targets but predicted as y.
        - the original median of classes y predicted instead of target.

    The thick line is the median over *all* samples whose true class is
    ``target``. The shaded band is the 20%-80% quantile range of the
    ``target`` samples that were predicted correctly. Every sufficiently
    frequent predicted class is drawn as a dashed line. The correct class is
    always drawn; misclassified classes are drawn only if they reach
    ``min_ratio``.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing test data. Must provide the columns ``target``,
        ``predicted``, ``target_id``, ``target_name`` and ``sensor``. Every
        cell of the ``sensor`` column must hold an array of the same shape,
        since rows are combined with ``np.stack``.
    target : str
        Target class to investigate.
    sensor : str
        Sensor over which compute investigation.
    logger : Logger
        Logger used to write out messages.
    vs_original_median : bool, optional
        If True, draw each misclassified class ``y`` with the median of all
        samples whose true class is ``y``. Otherwise, draw it with the median
        of the ``target`` samples that were predicted as ``y``.
        Default False.
    min_ratio : float, optional
        Misclassified classes accounting for less than this fraction of all
        misclassifications of ``target`` are not drawn. Default 0.05.

    Returns
    -------
    plt.Figure:
        Computed figure
    """
    logger(f" - plotting median vs wrong preds for class {target} on sensor {sensor}..")
    # class id -> human readable name, used in title and legend
    label_desc = dict(zip(df.target_id, df.target_name))

    # extract target data
    df_target = df[df.target == target]
    data = np.stack(df_target[sensor].values)
    correct = (df_target.predicted == target).to_numpy()

    # frequency of every predicted class (correct one included), most common first
    counts = Counter(df_target.predicted).most_common()
    # denominator for the misclassification ratios: only the wrong predictions
    n_wrong = sum(v for k, v in counts if k != target)

    # compute target statistics
    median = np.median(data, axis=0)
    lower = np.quantile(data[correct], q=0.20, axis=0)
    upper = np.quantile(data[correct], q=0.80, axis=0)

    # plot
    fig = plt.figure(figsize=(10, 7))
    plt.title(
        f"{label_desc[target]} vs {'miscls original' if vs_original_median else 'misclassified'}"
    )

    plt.plot(median, linewidth=3, label="Median preds")
    plt.fill_between(np.arange(len(median)), lower, upper, color="b", alpha=0.2)

    for k, v in counts:
        name = label_desc.get(k, k)  # a predicted class may never occur as a target
        if k == target:
            label = f"median {name}"
        else:
            # v > 0 and k != target, hence n_wrong >= v > 0: no division by zero
            ratio = v / n_wrong
            if ratio < min_ratio:
                # `continue` rather than `break`: the correct class may come
                # later in the ordering and must still be drawn
                continue
            label = (
                f"median {name} - ({ratio * 100:.1f}%)"
                if vs_original_median
                else f"miscls {name}  - ({ratio * 100:.1f}%)"
            )

        if vs_original_median:
            reference = df[df.target == k][sensor]
            if reference.empty:
                # np.stack cannot stack zero arrays
                logger(f"   no samples with true class {k}, skipping..")
                continue
            other_median = np.median(np.stack(reference.to_numpy()), axis=0)
        else:
            other_median = np.median(data[(df_target.predicted == k).to_numpy()], axis=0)
        plt.plot(other_median, label=label, linestyle="--")

    plt.ylim(-1, 1)
    plt.legend()
    plt.tight_layout()
    plt.grid(axis="y")
    return fig


def plot_learned_trends(
    df: pd.DataFrame,
    targets: list[str],
    sensor: str,
    logger: Logger,
) -> plt.Figure:
    """Plot learned behavior for the specified crops.

    For each class in ``targets``, draw the median of its samples together
    with a shaded 20%-80% quantile band in the same colour. Classes without
    any sample in ``df`` are logged and skipped.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing test data. Must provide the columns ``target``,
        ``target_id``, ``target_name`` and ``sensor``. Every cell of the
        ``sensor`` column must hold an array of the same shape, since rows are
        combined with ``np.stack``.
    targets : list[str]
        List of target classes to visualize.
    sensor : str
        Sensor over which compute investigation.
    logger : Logger
        Logger used to write out messages.

    Returns
    -------
    plt.Figure
        Computed figure
    """
    logger(f" - plotting learned trends for {targets} on sensor {sensor}..")
    # class id -> human readable name, used in the legend
    label_desc = dict(zip(df.target_id, df.target_name))

    fig = plt.figure(figsize=(10, 7))
    plt.title("Learned trends")
    for target in targets:
        df_target = df[df.target == target]
        if df_target.empty:
            # np.stack cannot stack zero arrays
            logger(f"   no samples of class {target}, skipping..")
            continue
        vectors = np.stack(df_target[sensor].to_numpy())
        median = np.median(vectors, axis=0)
        lower = np.quantile(vectors, q=0.20, axis=0)
        upper = np.quantile(vectors, q=0.80, axis=0)
        (line,) = plt.plot(median, label=label_desc[target])
        # reuse the line colour so each band is visibly tied to its curve
        plt.fill_between(
            np.arange(len(median)), lower, upper, color=line.get_color(), alpha=0.3
        )

    plt.ylim(-1, 1)
    plt.ylabel(sensor)
    plt.legend()
    plt.tight_layout()
    plt.grid(axis="y")
    return fig


def plot_learned_noise(
    df: pd.DataFrame,
    target: str,
    sensor: str,
    logger: Logger,
) -> plt.Figure:
    """Plot learned noise for the specified target class.

    The red band is the 3%-97% quantile range of all samples whose true class
    is ``target``. The thick line and blue band are the median and the
    15%-85% quantile range of all samples *predicted* as ``target``. If the
    class is never predicted, only the red band is drawn and a message is
    logged.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing test data. Must provide the columns ``target``,
        ``predicted``, ``target_id``, ``target_name`` and ``sensor``. Every
        cell of the ``sensor`` column must hold an array of the same shape,
        since rows are combined with ``np.stack``.
    target : str
        Target class to visualize.
    sensor : str
        Sensor over which compute investigation.
    logger : Logger
        Logger used to write out messages.

    Returns
    -------
    plt.Figure
        Computed figure
    """
    logger(f" - plotting learned noise for {target} on sensor {sensor}..")
    # class id -> human readable name, used in the title
    label_desc = dict(zip(df.target_id, df.target_name))

    fig = plt.figure(figsize=(10, 7))
    plt.title(f"Learned noise - {label_desc[target]}")

    # spread of the ground-truth samples of the class
    df_target = df[df.target == target]
    target_samples = np.stack(df_target[sensor].values)
    upper = np.quantile(target_samples, q=0.97, axis=0)
    lower = np.quantile(target_samples, q=0.03, axis=0)
    plt.fill_between(np.arange(len(upper)), lower, upper, color="r", alpha=0.3, label="all data")

    # behaviour of everything the model assigns to the class
    df_predicted = df[df.predicted == target]
    if df_predicted.empty:
        # np.stack cannot stack zero arrays
        logger(f"   class {target} is never predicted, skipping learned behavior..")
    else:
        vectors = np.stack(df_predicted[sensor].to_numpy())
        median = np.median(vectors, axis=0)
        upper = np.quantile(vectors, q=0.85, axis=0)
        lower = np.quantile(vectors, q=0.15, axis=0)
        plt.plot(median, linewidth=3, label="learned behavior")
        plt.fill_between(np.arange(len(upper)), lower, upper, color="b", alpha=0.3)

    plt.ylim(-1, 1)
    plt.legend()
    plt.tight_layout()
    plt.grid(axis="y")
    return fig
