"""Plots used for the evaluation."""

from __future__ import annotations

from math import ceil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from numpy.typing import NDArray
from scipy.stats import gaussian_kde
from seaborn import heatmap
from sklearn.metrics import classification_report, confusion_matrix


def plot_confusion(
    y_true: NDArray[np.int_],
    y_pred: NDArray[np.int_],
    class_ids: NDArray[np.str_],
    class_names: NDArray[np.str_],
    save_dir: Path | None = None,
    title: str | None = None,
) -> None:
    """Render the confusion matrix twice: normalised ratios, then raw counts.

    Each variant is drawn via ``create_confusion_matrix``. It is then written to
    ``save_dir`` (``confusion_ratio.png`` / ``confusion_count.png``) when a
    directory is given, or shown interactively otherwise. The current figure is
    closed after each variant.
    """
    variants = (
        (True, "confusion_ratio.png"),  # normalised matrix
        (False, "confusion_count.png"),  # raw-count matrix
    )
    for normalise, file_name in variants:
        create_confusion_matrix(
            y_true=y_true,
            y_pred=y_pred,
            class_ids=class_ids,
            class_names=class_names,
            normalise=normalise,
            title=title,
        )
        if save_dir is None:
            plt.show()
        else:
            plt.savefig(save_dir / file_name)
        plt.close()


def plot_metric(
    y_true: NDArray[np.int_],
    y_pred: NDArray[np.int_],
    save_dir: Path | None = None,
    metric: str = "f1-score",
    title: str | None = None,
) -> None:
    """Plot the distribution of a per-class metric together with its averages.

    The per-class values of ``metric`` from ``classification_report`` are
    rounded to two decimals and drawn as a Gaussian kernel density estimate on
    ``[0, 1]``. Two vertical lines are added:

    - red: the macro average, labelled "Class weighted";
    - green: the weighted average, labelled "Sample weighted".

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
        save_dir: Directory to write ``<metric>.png`` into. When ``None`` the
            figure is shown interactively instead.
        metric: Per-class key of the classification report to plot, e.g.
            ``"precision"``, ``"recall"`` or ``"f1-score"``.
        title: Optional figure title.
    """
    # Every label seen in either array; classification_report keys its
    # per-class rows by these values.
    target_names = sorted(set(y_true) | set(y_pred))
    report: dict[str, Any] = classification_report(
        y_true, y_pred, target_names=target_names, zero_division=0, output_dict=True
    )
    rounded = [round(report[cls][metric], 2) for cls in target_names]
    class_avg = report["macro avg"][metric]
    sample_avg = report["weighted avg"][metric]

    xs = np.linspace(0, 1, 100)
    # gaussian_kde needs a non-singular covariance, i.e. at least two distinct
    # values; otherwise it raises. In that case only the average lines are drawn.
    ys = gaussian_kde(rounded)(xs) if len(set(rounded)) > 1 else None
    y_max = (float(ys.max()) if ys is not None else 1.0) + 0.1

    # Create the figure
    plt.figure(figsize=(10, 5))
    plt.title(title)
    if ys is not None:
        plt.plot(xs, ys)
    plt.vlines(
        class_avg,  # Class weighted (macro average)
        ymin=0,
        ymax=y_max,
        label=f"Class weighted {metric}: {100 * class_avg:.2f}%",
        colors="r",
    )
    plt.vlines(
        sample_avg,  # Sample weighted (support-weighted average)
        ymin=0,
        ymax=y_max,
        label=f"Sample weighted {metric}: {100 * sample_avg:.2f}%",
        colors="g",
    )
    plt.xlim(0, 1)
    # 11 ticks (0.0 .. 1.0) so the right edge of the [0, 1] axis is labelled too.
    plt.xticks([i / 10 for i in range(11)])
    plt.xlabel(metric)
    plt.yticks([])  # density values carry no meaningful scale
    plt.ylim(0, y_max)
    plt.grid(axis="x")
    plt.legend()
    plt.tight_layout()
    if save_dir is not None:
        plt.savefig(save_dir / f"{metric}.png")
    else:
        plt.show()
    plt.close()


def plot_metric_comparison(
    mdl_tags: list[str],
    mdl_reports: list[dict[str, Any]],
    save_dir: Path | None = None,
) -> None:
    """Plot a heatmap comparing headline metrics across models.

    Each column is one model. The rows, all shown as percentages, are:
    sample-weighted F1, class-weighted F1, worst per-class F1, best per-class
    F1, and accuracy.

    Args:
        mdl_tags: One tag per model. The text after the first ``-`` (if any)
            is used as the column label.
        mdl_reports: One report dict per model, aligned with ``mdl_tags``.
            Required keys:

            - ``"sample avg"`` and ``"class avg"``, each with an
              ``"f1-score"`` entry;
            - ``"accuracy"``.

            Per-class entries are recognised by keys made of exactly four
            ``-``-separated parts.
            NOTE(review): these key names differ from sklearn's
            ``classification_report`` output; presumably a project-specific
            report format — confirm against the producer.
        save_dir: Directory to write ``metrics.png`` into. When ``None`` the
            figure is shown interactively instead.

    Raises:
        ValueError: If ``mdl_tags`` and ``mdl_reports`` differ in length.
    """
    # Without this check, extra reports raised IndexError and missing reports
    # silently showed up as 0% columns.
    if len(mdl_tags) != len(mdl_reports):
        raise ValueError(
            f"Expected one report per model tag, got {len(mdl_tags)} tags "
            f"and {len(mdl_reports)} reports."
        )

    row_labels = [
        "F1 (sample weighted)",
        "F1 (class weighted)",
        "F1 (worst)",
        "F1 (best)",
        "Accuracy",
    ]
    arr = np.zeros((len(row_labels), len(mdl_tags)))
    for i, report in enumerate(mdl_reports):
        f1_scores = [v["f1-score"] for k, v in report.items() if len(k.split("-")) == 4]
        arr[:, i] = (
            report["sample avg"]["f1-score"],  # Sample weighted
            report["class avg"]["f1-score"],  # Class weighted
            # A report without per-class entries yields NaN (masked by the
            # heatmap) instead of min()/max() raising on an empty list.
            min(f1_scores, default=np.nan),
            max(f1_scores, default=np.nan),
            report["accuracy"],
        )

    plt.figure(figsize=(len(mdl_tags) * 2, 5))
    plt.title("Model comparison")
    annot = np.asarray(
        ["" if np.isnan(x) else f"{100 * x:.1f}%" for x in arr.flatten()]
    ).reshape(arr.shape)
    sns.heatmap(
        100 * arr,
        vmin=0,
        vmax=100,
        annot=annot,
        fmt="",
        linewidths=0.01,
        square=True,
    )
    # Drop the prefix before the first "-" of each tag for a shorter label.
    xticks = [tag.split("-", 1)[1] if "-" in tag else tag for tag in mdl_tags]
    plt.xticks(
        [i + 0.5 for i in range(len(xticks))],
        xticks,
        rotation=20,
        horizontalalignment="right",
    )
    plt.yticks(
        [i + 0.5 for i in range(len(row_labels))],
        row_labels,
        rotation=0,
    )
    plt.tight_layout()
    if save_dir is not None:
        plt.savefig(save_dir / "metrics.png")
    else:
        plt.show()
    plt.close()


def create_confusion_matrix(
    y_true: NDArray[np.int_],
    y_pred: NDArray[np.int_],
    class_ids: NDArray[np.str_],
    class_names: NDArray[np.str_],
    title: str | None = None,
    normalise: bool = True,
    labels: NDArray[np.int_] | None = None,
) -> plt.Figure:
    """Draw a confusion-matrix heatmap on a new figure and return that figure.

    Args:
        y_true: Ground-truth labels (matrix rows, axis "Target").
        y_pred: Predicted labels (matrix columns, axis "Predicted").
        class_ids: Class identifiers used in the tick labels, paired
            element-wise with ``class_names``.
        class_names: Human-readable class names shown under each id.
        title: Optional figure title.
        normalise: If ``True``, each row is normalised by its total
            (sklearn ``normalize="true"``). Cells are then shown as
            percentages on a fixed 0-100 colour scale. Otherwise raw counts
            are shown.
        labels: Optional label values forwarded to ``confusion_matrix`` to fix
            the matrix's rows/columns and their order. When omitted, only
            labels occurring in ``y_true`` or ``y_pred`` get a row/column. The
            matrix can then be smaller than the number of tick labels if a
            class is absent.

    Returns:
        The newly created figure, which is also the current pyplot figure.
    """
    # Create the confusion matrix
    cm = confusion_matrix(
        y_true, y_pred, labels=labels, normalize="true" if normalise else None
    )

    def _annot(x: float) -> str:
        """Format one cell for display; non-positive cells are left blank."""
        if normalise:
            return f"{100 * x:.1f}%" if x > 0.0 else ""
        return f"{x}" if x > 0 else ""

    # NOTE(review): the tick labels assume the i-th row/column of ``cm`` matches
    # the i-th (class_id, class_name) pair — confirm the label encoding with callers.
    class_tags = [f"{x}\n({y})" for x, y in zip(class_ids, class_names)]
    # Scale the figure with the number of classes, but never below 10x9 inches.
    side = ceil(0.7 * len(class_tags))
    figsize = (side + 1, side) if side + 1 >= 10 else (10, 9)
    fig = plt.figure(figsize=figsize)
    plt.title(title)
    heatmap(
        100 * cm if normalise else cm,
        vmin=0,
        vmax=100 if normalise else None,
        annot=np.asarray([_annot(x) for x in cm.flatten()]).reshape(cm.shape),
        fmt="",
        xticklabels=class_tags,
        yticklabels=class_tags,
        linewidths=0.01,
        square=True,
    )
    # Re-place ticks at cell centres so the rotation can be controlled explicitly.
    plt.xticks(
        [i + 0.5 for i in range(len(class_tags))],
        class_tags,
        rotation=90,
    )
    plt.yticks(
        [i + 0.5 for i in range(len(class_tags))],
        class_tags,
        rotation=0,
    )
    plt.xlabel("Predicted")
    plt.ylabel("Target")
    plt.tight_layout()
    return fig
