"""Model evalution script."""

from __future__ import annotations

import json
from math import ceil
from pathlib import Path
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from vito_lot_delineation.data import DelineationDataset
from vito_lot_delineation.evaluation.metrics import get_pr
from vito_lot_delineation.evaluation.plot import (
    plot_metrics,
    plot_predictions,
    shuffle_preds_ids,
)
from vito_lot_delineation.models import BaseModel, load_model

# Default field-size buckets: bucket name -> (lower bound, upper bound).
# The bounds are applied as ``lower <= size < upper`` when metrics are split
# per bucket. NOTE(review): the unit of "size" comes from ``get_pr`` and is
# presumably a pixel count — confirm.
FIELD_SIZES = dict(
    zip(
        ("small", "medium", "large", "extra large"),
        ((0, 40), (40, 100), (100, 300), (300, 1e9)),
    )
)
# Names of the per-instance statistics collected for every bucket.
KEYS = "precision", "recall", "n_gt", "n_pr"


def main(
    model: BaseModel,
    data_dir: Path | None = None,
    split: str = "testing",
    field_sizes: dict[str, tuple[int, int]] | None = None,
    max_steps: int | None = None,
) -> None:
    """
    Evaluate the provided model and store the results in its model folder.

    The metrics are computed with ``compute_metrics``, summarised through the
    model's logger, written to ``evaluation/results.json`` and finally handed
    to ``plot_metrics``.

    Parameters
    ----------
    model : BaseModel
        The model to evaluate.
    data_dir : Path | None, optional
        Directory to locate the dataset to test on, by default None
    split : str, optional
        Dataset split to test on, by default "testing"
    field_sizes : dict[str,tuple[int,int]], optional
        The field sizes to evaluate the model on, by default None
    max_steps : int, optional
        The maximum number of batches to evaluate over, by default None (all)
    """
    # Create evaluation folder
    model.logger("Start model evaluation")
    eval_dir = model.model_folder / "evaluation"
    eval_dir.mkdir(exist_ok=True, parents=True)

    # Get the testing metrics
    metrics = compute_metrics(
        model=model,
        data_dir=data_dir,
        split=split,
        field_sizes=field_sizes,
        max_steps=max_steps,
    )

    # Log the evaluation results to the model: the global summary first, then
    # every bucket that was actually evaluated (so custom buckets passed via
    # `field_sizes` are reported too, not only the default four).
    model.logger("Evaluation results:")
    for fs in ("global", *(name for name in metrics if name != "global")):
        stats = metrics[fs]
        # Mean over instances; max(..., 1) avoids dividing by zero for a
        # bucket without any instance.
        precision = 100 * stats["precision"].sum() / max(len(stats["n_pr"]), 1)
        recall = 100 * stats["recall"].sum() / max(len(stats["n_gt"]), 1)
        model.logger(f" - Results for field size: {fs}")
        model.logger(f"   - Precision: {precision:.2f}%")
        model.logger(f"   - Recall: {recall:.2f}%")

    # Write away the metrics (tensors are converted to plain lists/numbers)
    with open(eval_dir / "results.json", "w", encoding="utf-8") as f:
        json.dump(
            {
                fs: {k: v.detach().cpu().tolist() for k, v in stats.items()}
                for fs, stats in metrics.items()
            },
            f,
        )

    # Create plots for the metrics
    plot_metrics(
        path=eval_dir,
        metrics=metrics,
    )


def compute_metrics(
    model: BaseModel,
    data_dir: Path | None = None,
    split: str = "testing",
    field_sizes: dict[str, tuple[int, int]] | None = None,
    max_steps: int | None = None,
    save_predictions: bool = True,
) -> dict[str, dict[str, torch.Tensor]]:
    """
    Compute the evaluation metrics for the specified model.

    Builds a non-augmented ``DelineationDataset`` for the requested split,
    configured with the model's input bands and number of time steps, and
    evaluates the model on it.

    Parameters
    ----------
    model : BaseModel
        The model to evaluate.
    data_dir : Path | None, optional
        Directory to locate the dataset, by default None
    split : str, optional
        Dataset split to evaluate on, by default "testing"
    field_sizes : dict[str,tuple[int,int]], optional
        Field size boundaries; falls back to ``FIELD_SIZES`` when missing or
        empty. Must contain "small", "medium", "large" and "extra large".
    max_steps : int, optional
        The maximum number of batches to evaluate over, by default None (all)
    save_predictions : bool, optional
        Save a batch of predictions on disk to analyze, by default True

    Returns
    -------
    dict[str, dict[str, torch.Tensor]]
        The evaluation metrics per field size.
    """
    # Input parsing
    if not field_sizes:
        field_sizes = FIELD_SIZES
    required_buckets = ("small", "medium", "large", "extra large")
    assert all(name in field_sizes for name in required_buckets)

    # Dataset matching the model's input configuration, without augmentation
    # or auxiliary targets
    input_cfg = model.cfg["input"]
    dataset = DelineationDataset(
        split=split,
        data_dir=data_dir,
        bands=input_cfg["bands"],
        n_ts=input_cfg["n_ts"],
        augment=False,
        return_distance=False,
        return_watersheds=False,
    )

    # get testing metrics
    return _evaluate_model(
        model=model,
        data=dataset,
        field_sizes=field_sizes,
        max_steps=max_steps,
        save_predictions=save_predictions,
    )


def _evaluate_model(
    model: BaseModel,
    data: DelineationDataset,
    field_sizes: dict[str, tuple[int, int]],
    batch_size: int = 16,
    max_steps: int | None = None,
    save_predictions: bool = True,
) -> dict[str, dict[str, torch.Tensor]]:
    """
    Run full evaluation over a trained model.

    Parameters
    ----------
    model : BaseModel
        The model to evaluate.
    data : DelineationDataset
        The dataset to evaluate on.
    field_sizes : dict[str,tuple[int,int]]
        Field size boundaries used to compute metrics.
    batch_size : int, optional
        The batch size to use, by default 16
    max_steps : int|None, optional
        The maximum number of batches to evaluate over, by default None (all)
    save_predictions : bool
        Save the first batch of predictions on disk to analyze

    Returns
    -------
    dict[str, dict[str, torch.Tensor]]
        The evaluation metrics per field size (plus "global"), extended with
        the curves and AUCs added by ``_compute_ioucurves_and_auc``.
    """
    model.eval()
    metrics = {fs: {k: [] for k in KEYS} for fs in field_sizes}
    metrics["global"] = {k: [] for k in KEYS}
    # NOTE: a falsy max_steps (None or 0) means "evaluate the full dataset"
    n_steps = max_steps or ceil(len(data) / batch_size)
    total_time = 0.0
    n_samples = 0
    for step in tqdm(range(n_steps), desc="Evaluating..."):
        # NOTE(review): assumes successive get_batch calls walk through the
        # dataset — confirm against DelineationDataset.
        batch = data.get_batch(batch_size=batch_size)
        inputs = batch["input"].to(model.device)
        instances = batch["instance"].to(model.device)

        # predict and measure time; no autograd graph is needed for evaluation
        # NOTE(review): on an asynchronous device this wall-clock measurement
        # may under-report unless the model synchronizes — confirm.
        time_a = time()
        with torch.no_grad():
            outputs = model(inputs)
        total_time += time() - time_a
        n_samples += len(inputs)

        # save handful of predictions
        if save_predictions and step == 0:
            _save_prediction_plots(
                model=model, inputs=inputs, instances=instances, outputs=outputs
            )

        # Compute metrics
        for gt, pr in zip(instances, outputs):
            result = _compute_single(gt=gt.cpu(), pr=pr.cpu(), field_sizes=field_sizes)
            for fs, stats in result.items():
                for k, v in stats.items():
                    metrics[fs][k].append(v)
                    metrics["global"][k].append(v)

    # log inference time, averaged over every evaluated sample (batches may
    # differ in size, so dividing by the last batch's size would be off)
    if n_samples:
        model.logger(
            f"Avg prediction time per sample: {total_time / n_samples: .4f}s"
        )
    else:
        model.logger("No samples were evaluated")

    # generate curves; an empty tensor stands in for statistics that never
    # received a value, since torch.concat rejects an empty list
    res = {
        fs: {k: torch.concat(v) if v else torch.empty(0) for k, v in stats.items()}
        for fs, stats in metrics.items()
    }
    return _compute_ioucurves_and_auc(res)


def _save_prediction_plots(
    model: BaseModel,
    inputs: torch.Tensor,
    instances: torch.Tensor,
    outputs: torch.Tensor,
) -> None:
    """Plot every sample of one batch and save the figures in the model's evaluation folder."""
    pred_path = model.model_folder / "evaluation" / "predictions"
    pred_path.mkdir(parents=True, exist_ok=True)
    for idx, (inp, trg, prd) in enumerate(zip(inputs, instances, outputs)):
        # Merge the two leading axes and move the merged axis last, so the
        # sample is laid out as (dim 1, dim 2, merged channels).
        inp = inp.flatten(0, 1).permute(1, 2, 0)
        # Plot three channels when available, else a single one. The check
        # must look at the last axis, which is the one being sliced below.
        dims = 3 if inp.shape[-1] >= 3 else 1  # noqa: PLR2004
        fig = plot_predictions(
            inp[:, :, :dims].cpu(), trg.cpu(), shuffle_preds_ids(prd.cpu())
        )
        fig.savefig(pred_path / f"{idx}.png")
        plt.close()


def _compute_ioucurves_and_auc(
    res: dict[str, dict[str, torch.Tensor]]
) -> dict[str, dict[str, torch.Tensor]]:
    """Update dictionary containing precisions and recalls per samples by caclulating iou curves and aucs."""
    for fs in res:
        res[fs]["precision_curve"] = torch.Tensor(
            [
                (res[fs]["precision"] > (i / 100)).sum()
                / max(len(res[fs]["precision"]), 1)
                for i in range(101)
            ]
        )
        res[fs]["recall_curve"] = torch.Tensor(
            [
                (res[fs]["recall"] > (i / 100)).sum() / max(len(res[fs]["recall"]), 1)
                for i in range(101)
            ]
        )
        res[fs]["f1_curve"] = torch.Tensor(
            [
                (2 * pr * rc) / (pr + rc + 1e-12)
                for pr, rc in zip(res[fs]["precision_curve"], res[fs]["recall_curve"])
            ]
        )
        res[fs]["precision_auc"] = torch.mean(res[fs]["precision_curve"])
        res[fs]["recall_auc"] = torch.mean(res[fs]["recall_curve"])
        res[fs]["f1_auc"] = torch.mean(res[fs]["f1_curve"])
    return res


def _compute_single(
    gt: torch.Tensor,
    pr: torch.Tensor,
    field_sizes: dict[str, tuple[int, int]],
) -> dict[str, dict[str, torch.Tensor]]:
    """
    Compute a single prediction's metrics, split per field size.

    Parameters
    ----------
    gt : torch.Tensor
        The ground-truth instances of one sample.
    pr : torch.Tensor
        The predicted instances of the same sample.
    field_sizes : dict[str,tuple[int,int]]
        Field size boundaries; a size belongs to a bucket when
        ``lower <= size < upper``.

    Returns
    -------
    dict[str, dict[str, torch.Tensor]]
        Per field size: "n_pr" / "n_gt" hold the sizes of the predicted /
        ground-truth instances that fall in the bucket, "precision" holds the
        precision entries selected by predicted size and "recall" the recall
        entries selected by ground-truth size.
    """
    result = {}
    # NOTE(review): get_pr is expected to return per-instance "pr_size",
    # "gt_size", "precision" and "recall" tensors (they are concatenated with
    # torch.concat by the caller) — confirm against its implementation.
    arr = get_pr(gt=gt, pr=pr)
    for fs, (s_min, s_max) in field_sizes.items():
        # Precision is bucketed by predicted size, recall by ground-truth size
        fs_pr = (s_min <= arr["pr_size"]) & (arr["pr_size"] < s_max)
        fs_gt = (s_min <= arr["gt_size"]) & (arr["gt_size"] < s_max)
        result[fs] = {
            "n_pr": arr["pr_size"][fs_pr],
            "n_gt": arr["gt_size"][fs_gt],
            "precision": arr["precision"][fs_pr],
            "recall": arr["recall"][fs_gt],
        }
    return result


if __name__ == "__main__":
    from pathlib import Path

    # Evaluate the requested model
    paths = ["data/models/20230609T083721-EnanchedResUnet3D_ts24_newAugm"]
    for p in paths:
        print(p)
        model = load_model(Path(p))
        main(
            model=model,
            data_dir=Path(__file__).parent.parent.parent.parent / "data/data",
        )
