"""Set of plotting functions used for evaluation purposes."""

from __future__ import annotations

import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure

from vito_lot_delineation.evaluation.metrics import get_pr


def create_iou_curve(
    data: torch.Tensor | dict[str, torch.Tensor]
) -> Figure:
    """Create a score curve over varying IoU thresholds.

    Parameters
    ----------
    data
        Either a single curve, or a mapping from a label to a curve. Every
        curve is plotted against the 101 integer thresholds 0..100 (shown as
        percentages on the x-axis), so it must hold exactly 101 values. When
        a mapping is given, each entry gets its own labelled line and the
        entry named ``"global"`` is drawn with a thicker line.

    Returns
    -------
    Figure
        The newly created figure. It is also the current pyplot figure, so
        callers may keep decorating it through ``plt`` (e.g. ``plt.title``)
        and are responsible for closing it.
    """
    # plt.subplots keeps the figure/axes registered as pyplot's "current" ones
    # while letting every drawing call below target this axes explicitly.
    fig, ax = plt.subplots(figsize=(6, 4))
    thresholds = range(101)

    if isinstance(data, dict):
        for name, curve in data.items():
            ax.plot(
                thresholds,
                curve,
                label=name,
                linewidth=2 if name == "global" else 1,
            )
        if data:  # an empty mapping has nothing to put in a legend
            ax.legend()
    else:
        ax.plot(thresholds, data)

    ticks = range(0, 101, 10)
    ax.set_xlabel("IoU threshold")
    ax.set_xticks(ticks)
    ax.set_xticklabels([f"{i:d}%" for i in ticks])
    ax.set_ylabel("Score")
    ax.set_yticks([i / 10 for i in range(11)])
    # Enable the grid explicitly: a bare grid() call *toggles* visibility and
    # would switch the grid off when rcParams already turn it on.
    ax.grid(True)
    return fig


def _save_iou_figure(
    data: torch.Tensor | dict[str, torch.Tensor], title: str, target: Path
) -> None:
    """Render one IoU curve figure, title it, write it to ``target``, close it."""
    fig = create_iou_curve(data)
    try:
        fig.gca().set_title(title)
        fig.savefig(target)
    finally:
        # Close this specific figure (not whatever happens to be "current"),
        # and do so even when saving fails, so no figure is left open.
        plt.close(fig)


def plot_metrics(path: Path, metrics: dict[str, dict[str, torch.Tensor]]) -> None:
    """Create plots for the provided metrics.

    Parameters
    ----------
    path
        Output directory; created (with parents) when it does not exist.
    metrics
        Mapping from an entry name to that entry's metrics. For each of
        ``precision``, ``recall`` and ``f1`` an entry must provide a
        ``"<metric>_curve"`` value (the curve to plot) and a
        ``"<metric>_auc"`` value (formatted with ``.3f`` in the plot title).

    The following files are written:

    * ``path/<entry>/<metric>.png`` - the curve of a single entry.
    * ``path/<metric>.png`` - the curves of all entries combined in one plot.
    """
    # Tuple instead of a set literal so files are produced in a stable order.
    metric_names = ("precision", "recall", "f1")

    # Make sure the aggregate plots have a directory even if `metrics` is empty.
    path.mkdir(parents=True, exist_ok=True)

    # One plot per entry and metric
    for fs, fs_metrics in metrics.items():
        fs_dir = path / fs
        fs_dir.mkdir(exist_ok=True, parents=True)
        for mt in metric_names:
            _save_iou_figure(
                fs_metrics[f"{mt}_curve"],
                f"{mt.capitalize()} ({fs} - auc = {fs_metrics[f'{mt}_auc']:.3f})",
                fs_dir / f"{mt}.png",
            )

    # Aggregate curves: all entries together, one plot per metric
    for mt in metric_names:
        _save_iou_figure(
            {fs: fs_metrics[f"{mt}_curve"] for fs, fs_metrics in metrics.items()},
            f"{mt.capitalize()} (varying sizes)",
            path / f"{mt}.png",
        )


def _draw_comparison_row(
    row_axes,
    label: str,
    instances: torch.Tensor,
    ids,
    found,
    background: torch.Tensor,
) -> None:
    """Fill one row of three axes: input, instance map, found/not-found map.

    ``ids`` and ``found`` are index-aligned: ``found[i]`` tells whether the
    instance with id ``ids[i]`` counts as found.
    """
    input_ax, instance_ax, match_ax = row_axes

    input_ax.imshow(background)
    input_ax.set_title("Input")

    _plot_with_background(
        instance_ax, im=instances, background=background, norm=False
    )
    instance_ax.set_title(label)

    # Recode every listed instance: 1 = found, 2 = not found, 0 = untouched.
    # (The original notes these render as green / yellow respectively.)
    match_map = torch.zeros_like(instances)
    for idx, instance_id in enumerate(ids):
        match_map[instances == instance_id] = 1 if found[idx] else 2
    _plot_with_background(match_ax, im=match_map, background=background, norm=True)

    # Share of recoded pixels that belong to found instances; the small
    # epsilon avoids a division by zero when nothing was recoded.
    n_ok = (match_map == 1).sum()
    n_miss = (match_map == 2).sum()  # noqa: PLR2004
    match_ax.set_title(f"{label} (correct: {100*(n_ok)/(n_ok+n_miss+1e-5):.2f}%)")


def plot_predictions(
    background: torch.Tensor,
    target: torch.Tensor,
    prediction: torch.Tensor,
    iou_thr: float = 0.5,
) -> Figure:
    """Plot a 2x3 figure comparing predicted instances with target instances.

    The top row shows the input, the target and the target recoded by whether
    each instance was found; the bottom row shows the same for the prediction.
    An instance counts as found when its score reported by ``get_pr``
    (``"recall"`` for targets, ``"precision"`` for predictions) is at least
    ``iou_thr``.

    Parameters
    ----------
    background
        Image drawn underneath every panel; values are clipped to [0, 1].
    target
        Target instance-id map.
    prediction
        Predicted instance-id map.
    iou_thr
        Minimum score for an instance to count as found.

    Returns
    -------
    Figure
        The created figure, with all axes hidden and a tight layout applied.
    """
    background = background.clip(0, 1)
    fig, axs = plt.subplots(2, 3)

    # Per-instance scores and the instance ids they refer to
    pr = get_pr(gt=target, pr=prediction)
    gt_found = pr["recall"] >= iou_thr
    pr_found = pr["precision"] >= iou_thr

    _draw_comparison_row(
        axs[0],
        label="Target",
        instances=target,
        ids=pr["gt_ids"],
        found=gt_found,
        background=background,
    )
    _draw_comparison_row(
        axs[1],
        label="Prediction",
        instances=prediction,
        ids=pr["pr_ids"],
        found=pr_found,
        background=background,
    )

    # Formatting
    for axis in axs.ravel():
        axis.set_axis_off()
    plt.tight_layout()
    return fig


def _plot_with_background(
    ax: plt.Axes, im: torch.Tensor, background: torch.Tensor, norm: bool
) -> None:
    """Draw ``im`` on top of ``background`` on the given axes.

    Entries of ``im`` equal to 0 are replaced by NaN before drawing -
    presumably so matplotlib leaves them unpainted and the background stays
    visible there. Neither input tensor is modified.

    Parameters
    ----------
    ax
        Axes to draw on.
    im
        Overlay image; it is converted to float for drawing.
    background
        Image drawn first, underneath the overlay.
    norm
        If true, draw the overlay with a fixed value range of 0..2 (and the
        default colormap); otherwise use the "plasma" colormap with
        automatic scaling.
    """
    # The background is drawn from a copy, never from the caller's tensor.
    ax.imshow(background.clone())

    # Out-of-place fill: yields a fresh float tensor and leaves `im` untouched.
    overlay = im.float().masked_fill(im == 0, torch.nan)

    overlay_options = {"vmin": 0, "vmax": 2} if norm else {"cmap": "plasma"}
    ax.imshow(overlay, **overlay_options)


def shuffle_preds_ids(pred: torch.Tensor) -> torch.Tensor:
    """Shuffle instance segmentation ids to obtain a cleaner visualization.

    Every distinct non-zero value in ``pred`` is replaced by a random id drawn
    without replacement, so distinct instances always keep distinct ids.
    The value 0 is treated as background and left untouched.

    Parameters
    ----------
    pred
        Instance-id map. It is modified in place.

    Returns
    -------
    torch.Tensor
        The same tensor object as ``pred``, now holding the shuffled ids.
    """
    # Match against a frozen copy: comparing with the tensor being rewritten
    # would let a freshly assigned id be mistaken for a not-yet-processed old
    # id, silently merging two instances.
    snapshot = pred.clone()
    old_ids = [value for value in torch.unique(snapshot).tolist() if value != 0]

    # Ids are drawn from 1..100_000; the range only grows in the unlikely case
    # that there are more instances than that, so sampling cannot fail.
    upper = max(100_000, len(old_ids))
    new_ids = random.sample(range(1, upper + 1), len(old_ids))

    for old_id, new_id in zip(old_ids, new_ids):
        pred[snapshot == old_id] = new_id
    return pred
