"""Evaluation functions."""

from __future__ import annotations

import json
from pathlib import Path

import numpy as np
from numpy.typing import NDArray

from vito_crop_classification.constants import get_models_folder
from vito_crop_classification.evaluation.plots import (
    plot_confusion,
    plot_metric,
    plot_metric_comparison,
)
from vito_crop_classification.evaluation.report import generate_report
from vito_crop_classification.vito_logger import Logger, bh_logger


def evaluate(
    y_true: NDArray[np.int64],
    y_pred: NDArray[np.int64],
    class_ids: NDArray[np.str_],
    class_names: NDArray[np.str_],
    save_dir: Path | None = None,
    logger: Logger | None = None,
    verbose: bool = True,
    title: str | None = None,
) -> None:
    """Evaluate predictions: report, confusion matrices and metric plot.

    The three steps run in a fixed order (report, confusion matrices, metric
    plot) and each one receives the same targets, predictions and save_dir.

    Parameters
    ----------
    y_true : NDArray[np.int64]
        Target classes.
    y_pred : NDArray[np.int64]
        Predicted classes by the model.
    class_ids : NDArray[np.str_]
        Class identifiers, forwarded to the confusion-matrix plot.
    class_names : NDArray[np.str_]
        Class names, forwarded to the confusion-matrix plot.
        Presumably index-aligned with class_ids - confirm in plot_confusion.
    save_dir : Path | None
        Where to save evaluation files, by default None. When given, the
        folder (including missing parents) is created before anything runs.
    logger : Logger | None
        Model logger used to log results; bh_logger is used when None.
    verbose : bool
        Whether to print out the report.
    title : str | None
        Optional title to add to each plot.

    Note
    ----
    NOTE(review): the demo under ``__main__`` in this module passes string
    class ids as y_true / y_pred although the annotations say int64 - confirm
    which dtype the downstream functions expect.
    """
    log = bh_logger if logger is None else logger

    log("Evaluating:")
    log(f" - Writing results to {save_dir}")
    if save_dir is not None:
        save_dir.mkdir(parents=True, exist_ok=True)

    # Keyword arguments common to every evaluation step below
    shared = {"y_true": y_true, "y_pred": y_pred, "save_dir": save_dir}

    log(" - Generating report..")
    generate_report(**shared, logger=log, verbose=verbose)

    log(" - Plotting confusion matrices..")
    plot_confusion(
        **shared,
        class_ids=class_ids,
        class_names=class_names,
        title=title,
    )

    log(" - Plotting F1 scores..\n")
    plot_metric(**shared, title=title)


def compare(
    mdl_tags: list[str],
    save_dir: Path | None = None,
    mdl_f: Path | None = None,
) -> None:
    """
    Compare different models.

    Loads ``evaluation/report.json`` of every referenced model and hands the
    tags together with their reports to the metric-comparison plot.

    Parameters
    ----------
    mdl_tags : list[str]
        List of tags referencing the models to compare.
    save_dir : Path | None
        Where to save evaluation files, by default None.
    mdl_f : Path, optional
        Folder where models are stored; falls back to get_models_folder().

    Raises
    ------
    FileNotFoundError
        If a referenced model has no ``evaluation/report.json``.

    Note
    ----
    This method assumes that each referenced model contains a report.json in its evaluation folder.
    """
    mdl_f = mdl_f or get_models_folder()

    # Load in all the model reports, in the same order as the tags.
    # JSON is UTF-8 encoded, so decode explicitly rather than relying on the
    # platform's locale-dependent default encoding.
    reports = []
    for tag in mdl_tags:
        with open(mdl_f / tag / "evaluation/report.json", encoding="utf-8") as f:
            reports.append(json.load(f))

    # Compare the models' main metrics
    plot_metric_comparison(
        mdl_tags=mdl_tags,
        mdl_reports=reports,
        save_dir=save_dir,
    )


if __name__ == "__main__":
    # Smoke run on synthetic data: dummy classes with randomly drawn labels
    n_classes, n_samples = 10, 1000
    class_ids = np.asarray([str(i) for i in range(n_classes)])
    class_names = np.asarray([f"test_{i}" for i in range(n_classes)])

    # Two independent random label vectors, drawn by indexing into class_ids
    targets = class_ids[np.random.randint(0, len(class_ids), n_samples)]
    noise = class_ids[np.random.randint(0, len(class_ids), n_samples)]

    # First half of the predictions is random, second half copies the targets
    predictions = np.concatenate([noise[:500], targets[500:]])

    evaluate(
        y_true=targets,
        y_pred=predictions,
        class_ids=class_ids,
        class_names=class_names,
        save_dir=Path(__file__).parent / "../../../reports",
        verbose=False,
    )

    # Example of comparing previously evaluated models (kept disabled):
    # compare(
    #     mdl_tags=[
    #         "20220925T223025-dense_output_size=32_depth=2",
    #         "20220925T223908-dense_output_size=32_depth=4",
    #         "20220925T224804-dense_output_size=64_depth=2",
    #         "20220925T225629-dense_output_size=64_depth=4",
    #     ],
    #     save_dir=(Path(__file__)).parent / "../../../reports",
    # )
