"""Copy of the plots.py file from the vito_lot_delineation.evaluation package."""

from __future__ import annotations

import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure


def _get_coverage_and_overlap(
    gt: torch.Tensor, pr: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the pairwise coverage and overlap matrices of two label maps.

    Only ids strictly greater than zero are treated as instances.

    Args:
        gt: Target label map.
        pr: Predicted label map, compared element-wise against ``gt``.

    Returns:
        A ``(coverage, overlap)`` tuple where

        * ``coverage[t, p]`` is the fraction of target ``t`` that is covered
          by prediction ``p`` (intersection / target area), and
        * ``overlap[p, t]`` is the fraction of prediction ``p`` that lies on
          target ``t`` (intersection / prediction area).

        Rows and columns follow the order of ``torch.unique``.
    """
    target_ids = [idx for idx in torch.unique(gt) if idx > 0]
    pred_ids = [idx for idx in torch.unique(pr) if idx > 0]

    coverage = torch.zeros(len(target_ids), len(pred_ids))
    overlap = torch.zeros(len(pred_ids), len(target_ids))

    for col, pred_id in enumerate(pred_ids):
        # The prediction mask and its area do not depend on the target
        pred_mask = pr == pred_id
        pred_area = pred_mask.sum()
        for row, target_id in enumerate(target_ids):
            target_mask = gt == target_id
            shared = (target_mask & pred_mask).sum()

            # share of the prediction that falls on this target
            overlap[col, row] = shared / pred_area
            # share of the target that this prediction covers
            coverage[row, col] = shared / target_mask.sum()

    return coverage, overlap


def _merge_oversegmentation(
    pr: torch.Tensor,
    coverage_m: torch.Tensor,
    overlap_m: torch.Tensor,
    overlapping_thr: float = 0.8,
) -> torch.Tensor:
    """Merge oversegmented prediction ids.

    For every target (row of ``coverage_m``), all predictions that cover it
    and whose overlap with that target is at least ``overlapping_thr`` are
    relabelled to the first such prediction id.

    Args:
        pr: Predicted label map. It is modified in place.
        coverage_m: Coverage matrix, indexed ``[target, prediction]``, whose
            columns follow the order of the positive ids in
            ``torch.unique(pr)``.
        overlap_m: Overlap matrix, indexed ``[prediction, target]``.
        overlapping_thr: Minimum overlap for a prediction to be considered
            part of a target.

    Returns:
        The same ``pr`` tensor, with merged ids.
    """
    # Keep the ids in pr's own dtype: converting them to a float32 tensor
    # rounds ids above 2**24, which would then be written back into pr.
    pr_ids = torch.unique(pr)
    pr_ids = pr_ids[pr_ids > 0]

    # Ignore predictions that do not overlap the target well enough
    coverage_m_ = coverage_m.clone()
    coverage_m_[(overlap_m < overlapping_thr).T] = 0

    for target_coverage in coverage_m_:
        p_ids = pr_ids[target_coverage > 0]
        if len(p_ids) == 0:
            continue
        keep_id = p_ids[0]
        # Relabelling keep_id to itself is a no-op, so start from the second
        for p_id in p_ids[1:]:
            pr[pr == p_id] = keep_id
    return pr


def _get_overflows(
    pr: torch.Tensor, coverage_m: torch.Tensor, overflow_thr: float = 0.4
) -> torch.Tensor:
    """Get overflowing prediction ids.

    A prediction is overflowing when it covers more than one target by more
    than ``overflow_thr``.

    Args:
        pr: Predicted label map; ids greater than zero are instances.
        coverage_m: Coverage matrix, indexed ``[target, prediction]``, whose
            columns follow the order of the positive ids in
            ``torch.unique(pr)``.
        overflow_thr: Coverage above which a target counts as covered.

    Returns:
        1-D tensor with the overflowing prediction ids, in ``pr``'s dtype.
    """
    # Stay in pr's dtype: a float32 id tensor cannot represent ids above
    # 2**24 exactly, so the returned ids would not match the ones in pr.
    pr_ids = torch.unique(pr)
    pr_ids = pr_ids[pr_ids > 0]
    overflowing_mask = (coverage_m > overflow_thr).sum(0) > 1
    return pr_ids[overflowing_mask]


def get_pr(gt: torch.Tensor, pr: torch.Tensor) -> dict[str, torch.Tensor]:
    """Calculate per-instance precision and recall scores.

    Overflowing predictions are detected and oversegmented predictions are
    merged first; the IoU between every target and every remaining prediction
    is then computed. Overflowing predictions score zero against all targets.

    Args:
        gt: Target label map; ids greater than zero are instances.
        pr: Predicted label map; ids greater than zero are instances. It is
            handed to the merge step, which returns the label map used for
            the rest of the computation.

    Returns:
        A dictionary with

        * ``gt_ids`` / ``pr_ids``: the target ids and the (merged) prediction
          ids,
        * ``precision``: for each prediction, its best IoU over all targets,
        * ``recall``: for each target, its best IoU over all predictions,
        * ``gt_size`` / ``pr_size``: the element count of every instance.
    """
    # The overflow check uses the coverage of the unmerged predictions, so it
    # has to run before the merge step relabels them.
    coverage, overlap = _get_coverage_and_overlap(gt=gt, pr=pr)
    overflowing = _get_overflows(pr=pr, coverage_m=coverage)
    pr = _merge_oversegmentation(pr=pr, coverage_m=coverage, overlap_m=overlap)

    target_ids = [idx for idx in torch.unique(gt) if idx > 0]
    pred_ids = [idx for idx in torch.unique(pr) if idx > 0]

    # IoU matrix, indexed [target, prediction]
    iou = torch.zeros(len(target_ids), len(pred_ids))
    for col, pred_id in enumerate(pred_ids):
        if int(pred_id) in overflowing:
            # column keeps its initial zeros
            continue
        pred_mask = pr == pred_id
        for row, target_id in enumerate(target_ids):
            target_mask = gt == target_id
            intersection = (target_mask & pred_mask).sum()
            union = (target_mask | pred_mask).sum()
            iou[row, col] = intersection / union

    n_targets, n_preds = iou.shape

    # A max over an empty dimension is undefined, fall back to zeros
    if n_targets == 0:
        precision = torch.zeros(n_preds)
    else:
        precision = iou.max(axis=0).values
    if n_preds == 0:
        recall = torch.zeros(n_targets)
    else:
        recall = iou.max(axis=1).values

    return {
        "gt_ids": torch.stack(target_ids) if target_ids else torch.IntTensor([]),
        "pr_ids": torch.stack(pred_ids) if pred_ids else torch.IntTensor([]),
        "precision": precision,
        "recall": recall,
        "gt_size": torch.IntTensor([(gt == idx).sum() for idx in target_ids]),
        "pr_size": torch.IntTensor([(pr == idx).sum() for idx in pred_ids]),
    }


def create_iou_curve(
    data: torch.Tensor | dict[str, dict[str, torch.Tensor]]
) -> plt.Figure:
    """Create a score curve over varying IoU thresholds.

    Args:
        data: Either a single curve, or a dictionary whose values are each
            plotted as one curve labelled with their key. Every curve is
            drawn against the thresholds 0..100, so it needs 101 values. The
            curve stored under the key ``"global"`` is drawn with a thicker
            line.

    Returns:
        The newly created figure. It is left open.
    """
    fig = plt.figure(figsize=(6, 4))
    thresholds = range(101)

    if not isinstance(data, dict):
        plt.plot(thresholds, data)
    else:
        for name, curve in data.items():
            width = 2 if name == "global" else 1
            plt.plot(thresholds, curve, label=name, linewidth=width)
        plt.legend()

    # Axis formatting: percentages on x, tenths on y
    tick_positions = range(0, 101, 10)
    plt.xlabel("IoU threshold")
    plt.xticks(tick_positions, [f"{i:d}%" for i in tick_positions])
    plt.ylabel("Score")
    plt.yticks([i / 10 for i in range(11)])
    plt.grid()
    return fig


def plot_metrics(path: Path, metrics: dict[str, dict[str, torch.Tensor]]) -> None:
    """Create and save the IoU-curve plots for the provided metrics.

    For every key ``fs`` of ``metrics`` the precision, recall and f1 curves
    are written to ``path/fs/<metric>.png``. In addition, one figure per
    metric that combines the curves of all keys is written to
    ``path/<metric>.png``.

    Args:
        path: Output folder. The per-key sub-folders are created if missing.
        metrics: Mapping from a group name to its metrics. Each group must
            provide ``<metric>_curve`` and ``<metric>_auc`` entries for
            ``precision``, ``recall`` and ``f1``.
    """
    metric_names = ("precision", "recall", "f1")

    # One figure per group and metric
    for group, values in metrics.items():
        out_dir = path / group
        out_dir.mkdir(exist_ok=True, parents=True)
        for name in metric_names:
            fig = create_iou_curve(values[f"{name}_curve"])
            auc = values[f"{name}_auc"]
            plt.title(f"{name.capitalize()} ({group} - auc = {auc:.3f})")
            fig.savefig(out_dir / f"{name}.png")
            plt.close()

    # One figure per metric with the curves of all groups
    for name in metric_names:
        curves = {group: values[f"{name}_curve"] for group, values in metrics.items()}
        fig = create_iou_curve(curves)
        plt.title(f"{name.capitalize()} (varying sizes)")
        fig.savefig(path / f"{name}.png")
        plt.close()


def plot_predictions(
    background: torch.Tensor,
    target: torch.Tensor,
    prediction: torch.Tensor,
    iou_thr: float = 0.5,
) -> Figure:
    """Plot a figure that compares predictions to targets.

    The figure has two rows (target on top, prediction below) with three
    panels each: the input, the label map drawn over the input, and a status
    map in which instances scoring at least ``iou_thr`` get value 1 and all
    other instances get value 2. The title of the status panel reports the
    share of instance area that belongs to the first group.

    Args:
        background: Image shown behind every panel; clipped to ``[0, 1]``.
        target: Target label map.
        prediction: Predicted label map.
        iou_thr: Score threshold from which an instance counts as found.

    Returns:
        The created figure.
    """
    background = background.clip(0, 1)
    fig, axs = plt.subplots(2, 3)

    # Score the instances first: the ids returned here are the ones that are
    # looked up in the label maps below.
    scores = get_pr(gt=target, pr=prediction)
    rows = (
        ("Target", target, scores["gt_ids"], scores["recall"] >= iou_thr),
        ("Prediction", prediction, scores["pr_ids"], scores["precision"] >= iou_thr),
    )

    for row, (label, labels, ids, found) in enumerate(rows):
        ax_input, ax_labels, ax_status = axs[row]

        ax_input.imshow(background)
        ax_input.set_title("Input")

        _plot_with_background(ax_labels, im=labels, background=background, norm=False)
        ax_labels.set_title(label)

        # 1 = found (green), 2 = not found (yellow), 0 = no instance
        status = torch.zeros_like(labels)
        for is_found, instance_id in zip(found, ids):
            status[labels == instance_id] = 1 if is_found else 2
        _plot_with_background(ax_status, im=status, background=background, norm=True)

        n_ok = (status == 1).sum()
        n_missed = (status == 2).sum()  # noqa: PLR2004
        ax_status.set_title(
            f"{label} (correct: {100*(n_ok)/(n_ok+n_missed+1e-5):.2f}%)"
        )

    # Formatting
    for ax in axs.ravel():
        ax.set_axis_off()
    plt.tight_layout()
    return fig


def _plot_with_background(
    ax: plt.Axes, im: torch.Tensor, background: torch.Tensor, norm: bool
) -> None:
    """Draw a label image on top of a background image.

    Args:
        ax: Axes to draw on.
        im: Image to draw; elements equal to zero are replaced by NaN before
            drawing. The tensor itself is not modified.
        background: Image drawn first, underneath ``im``.
        norm: If true, draw ``im`` with a fixed value range of 0 to 2;
            otherwise draw it with the "plasma" colormap.
    """
    ax.imshow(background.clone())

    # masked_fill returns a new tensor, so the caller's image stays untouched
    overlay = im.float().masked_fill(im == 0, torch.nan)
    style = {"vmin": 0, "vmax": 2} if norm else {"cmap": "plasma"}
    ax.imshow(overlay, **style)


def shuffle_preds_ids(pred: torch.Tensor) -> torch.Tensor:
    """Shuffle instance segmentation ids to obtain a cleaner visualization.

    Every id greater than zero is replaced by a random id drawn with the
    ``random`` module. Distinct instances always receive distinct new ids,
    and values that are zero or negative are left unchanged.

    Args:
        pred: Label map. It is modified in place.

    Returns:
        The same ``pred`` object, with shuffled ids.
    """
    # No-op for tensors; a NumPy array is wrapped without copying, so the
    # in-place write below still reaches the caller's data.
    values = torch.as_tensor(pred)

    ids, inverse = torch.unique(values, return_inverse=True)
    is_instance = ids > 0
    n_instances = int(is_instance.sum())
    if n_instances == 0:
        return pred

    # Sample without replacement so that two instances can never end up with
    # the same id; widen the range if there are more instances than ids.
    upper = max(100_000, n_instances)
    new_ids = random.sample(range(1, upper + 1), n_instances)

    # Build the old -> new lookup once and apply it in a single pass, so a
    # freshly assigned id can never be mistaken for a not-yet-processed one.
    lookup = ids.clone()
    lookup[is_instance] = torch.tensor(new_ids, dtype=ids.dtype, device=ids.device)
    values.copy_(lookup[inverse])
    return pred
