"""Trainer class."""

from __future__ import annotations

from typing import Any

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from sklearn.metrics import precision_recall_fscore_support
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from vito_crop_classification.evaluation import create_confusion_matrix
from vito_crop_classification.model import Model
from vito_crop_classification.model.classifiers import SimilarityClassifier
from vito_crop_classification.training.dataset import get_dataset
from vito_crop_classification.training.loss import CustomLoss
from vito_crop_classification.training.lr_scheduler import LearningRateScheduler


class Trainer:
    """Train a ``Model`` with periodic validation, early stopping and TensorBoard logging."""

    def __init__(
        self,
        model: Model,
        train_type: str,
        loss_type: str,
        es_tolerance: int = 5,
        es_improvement: float = 0.0,
        es_metric: str = "f1_class",
        lr_scheduler: str | None = "reduce_on_plateau",
    ) -> None:
        """
        Trainer class.

        Parameters
        ----------
        model : Model
            Model to train
        train_type : str
            Type of model training to use, depends on which classifier is used
            Options: classification, similarity
        loss_type : str
            Type of loss to use, options vary based on which classifier is used
        es_tolerance : int
            Number of consecutive validations without improvement after which
            training is stopped, by default 5
        es_improvement : float
            Minimal improvement of ``es_metric`` over the current best before the
            best model gets replaced, by default 0.0
        es_metric : str
            Metric used to monitor early stopping, by default "f1_class"
            Options: loss, f1_sample, f1_class
        lr_scheduler : str, optional
            Learning rate scheduling to use, by default "reduce_on_plateau"
            Options: None, reduce_on_plateau
        """
        assert train_type in {"classification", "similarity"}
        assert es_metric in {"loss", "f1_sample", "f1_class"}

        self.model = model
        self._train_type = train_type
        self._loss = CustomLoss(
            loss_type=loss_type,
            classes=self.model.get_class_ids(),
        )
        self.optimizer = Adam(self.model.parameters(), lr=1e-3)
        # The loss should go down ("min"); the F1 metrics should go up ("max")
        self.lr_scheduler = LearningRateScheduler(
            tag=lr_scheduler,
            optimizer=self.optimizer,
            mode="min" if (es_metric == "loss") else "max",
            logger=self.model.logger,
        )
        self.es_tolerance = es_tolerance
        self.es_improvement = es_improvement
        self.es_metric = es_metric
        self._writer = SummaryWriter(log_dir=str(self.model.model_folder / "logs"))

    def train(
        self,
        data_train: pd.DataFrame,
        data_val: pd.DataFrame,
        steps_max: int = int(1e9),
        steps_val: int = 500,
        batch_size: int = 512,
        augmentation: bool = True,
    ) -> Model:
        """
        Train the model with periodic validation and early stopping.

        Parameters
        ----------
        data_train : pd.DataFrame
            Data used to train the model
        data_val : pd.DataFrame
            Data used to evaluate a model
        steps_max : int, optional
            Maximum number of steps used to train a model (nearly infinite, 1e9, by default)
        steps_val : int, optional
            Number of training steps between two validation runs, by default 500
        batch_size : int, optional
            Batch size used during training, by default 512
        augmentation : bool, optional
            Either use augmentation or not, by default True

        Returns
        -------
        Model
            Best model obtained during training, reloaded from ``model_folder``.
        """
        # Accept float-like values (e.g. 1e9) while keeping all step arithmetic integral
        steps_max = int(steps_max)

        # Create class representations if necessary
        reps = self._set_reps() if (self._train_type == "similarity") else None

        # Prepare the datasets
        dataset_train = get_dataset(
            train_type=self._train_type,
            data=data_train,
            process_f=self.model.enc.preprocess_df,
            balance=True,
            augment=augmentation,
            classes=self.model.get_class_ids(),
            representatives=reps,
        )
        dataloader_train = dataset_train.get_dataloader(batch_size=batch_size)
        dataset_val = get_dataset(
            train_type=self._train_type,
            data=data_val,
            process_f=self.model.enc.preprocess_df,
            balance=False,
            augment=False,
            classes=self.model.get_class_ids(),
            representatives=reps,
        )
        dataloader_val = dataset_val.get_dataloader(batch_size=batch_size)

        # Start the training
        self.model.logger("Start the training:")
        self.model.logger(f" - Maximum steps: {steps_max}")
        self.model.logger(f" - Validation steps: {steps_val}")
        self.model.logger(f" - Batch-size: {batch_size}")
        self.model.logger(f" - Loss function: {self._loss}")
        self.model.logger(f" - Early-stopping tolerance: {self.es_tolerance}")
        self.model.logger(f" - Early-stopping improvement: {self.es_improvement}")
        self.model.logger(f" - Early-stopping metric: {self.es_metric}")
        self.model.logger(f" - Optimizer: {self.optimizer.__class__.__name__}")
        self.model.logger(f" - Learning rate scheduling: {self.lr_scheduler}")
        self.model.logger(f" - Device: {self.model.device}")
        self.model.logger("\n")
        global_step, es_buffer = 0, 0

        # Start with initial validation
        best_metric = self._validate(
            dataloader=dataloader_val,
            global_step=global_step,
        )
        self.lr_scheduler(best_metric)
        self.model.logger(f"Initial validation score for '{self.es_metric}': {best_metric:.5f}")
        self.model.logger("\n")
        # ``best_metric`` currently describes the untrained model: persist it so the final
        # ``Model.load`` finds a matching model even if no later validation improves on it
        self.model.save()

        # Start training loop
        while global_step < steps_max:
            n_steps = min(steps_val, steps_max - global_step)
            self.model.logger(
                f"Training for {n_steps} steps (current step: {global_step}/{steps_max})"
            )
            global_step = self._train(
                dataloader=dataloader_train,
                global_step=global_step,
                n=n_steps,
            )
            new_metric = self._validate(
                dataloader=dataloader_val,
                global_step=global_step,
            )
            # Step the scheduler on the latest score: passing the running best would hide
            # every improvement from a plateau-based scheduler
            self.lr_scheduler(new_metric)

            # Check for early stopping
            self.model.logger("Early stopping:")
            if not self._es_goal(new=new_metric, old=best_metric):
                es_buffer += 1
                message = f" - No improvement for {self.es_metric}: {new_metric:.5f} (current best: {best_metric:.5f})"
                self.model.logger(message)
                if es_buffer >= self.es_tolerance:
                    self.model.logger(" - Early Stopping tolerance reached. Training stopped!")
                    self.model.logger("\n")
                    break
            else:
                self.model.logger(f" - New best {self.es_metric}: {new_metric:.5f}")
                self.model.logger(
                    f" - Improvement over previous best: {new_metric-best_metric:.5f}"
                )
                self.model.logger(" - Best model updated!")
                best_metric = new_metric
                es_buffer = 0
                self.model.save()
            self.model.logger("\n")

        return Model.load(self.model.model_folder)

    def _step(
        self, batch: dict[str, torch.Tensor | np.ndarray[list[torch.Tensor]]]
    ) -> dict[str, Any]:
        """
        Run encoder, classifier and loss on a single batch.

        Parameters
        ----------
        batch : dict
            Batch from the dataloader; the keys read are "inputs", "targets" and,
            optionally, "representatives".

        Returns
        -------
        dict[str, Any]
            "preds" (classifier output), "targets" (the batch targets as given) and
            "loss" (the loss tensor).
        """
        enc = self.model.enc.forward_process(batch["inputs"])
        output = self.model.clf.forward_process(enc)
        loss = self._loss(
            inputs=output,
            inputs2=batch.get("representatives", None),
            targets=batch["targets"].to(self.model.device),
        )
        return {
            # NOTE(review): runs the classifier a second time on ``enc``; presumably
            # ``clf(enc)`` yields the final predictions while ``forward_process`` yields
            # the raw output the loss expects — confirm before merging the two calls
            "preds": self.model.clf(enc),
            "targets": batch["targets"],
            "loss": loss,
        }

    def _train(self, dataloader: DataLoader, global_step: int, n: int) -> int:
        """
        Train for ``n`` optimizer steps, cycling through ``dataloader`` as often as needed.

        Parameters
        ----------
        dataloader : DataLoader
            Training batches
        global_step : int
            Global step at which this training chunk starts
        n : int
            Number of optimizer steps to take (at least one step is always taken)

        Returns
        -------
        int
            Updated global step.

        Raises
        ------
        ValueError
            If ``dataloader`` yields no batches; cycling it would otherwise never end.
        """
        self.model.train()
        loss_sum, train_step = 0.0, 0
        with tqdm(total=n, desc="Training..") as pbar:
            while True:
                yielded_batch = False
                for batch in dataloader:
                    yielded_batch = True

                    # Run the batch
                    self.optimizer.zero_grad()
                    result = self._step(batch)
                    loss = result["loss"]
                    # NOTE(review): retain_graph kept as in the original; presumably needed
                    # because part of the graph (e.g. representatives) outlives one step
                    loss.backward(retain_graph=True)
                    self.optimizer.step()

                    # Write away the evaluation metrics
                    loss_value = float(loss.cpu().detach())
                    self._write_eval(
                        loss=loss_value,
                        step=global_step,
                        y_true=torch.argmax(result["targets"], dim=1).cpu().detach().numpy(),
                        y_pred=torch.argmax(result["preds"], dim=1).cpu().detach().numpy(),
                        is_val=False,
                        incl_conf=(train_step + 1) == n,
                    )
                    train_step += 1
                    global_step += 1

                    # Update the progress bar with the running mean loss of this chunk
                    loss_sum += loss_value
                    pbar.update(1)
                    pbar.set_postfix({"loss": f"{loss_sum / train_step}"})

                    # Stop training if number of steps exceeded
                    if train_step >= n:
                        return global_step

                if not yielded_batch:
                    raise ValueError("The training dataloader did not yield any batch.")

    def _validate(self, dataloader: DataLoader, global_step: int) -> float:
        """
        Run one validation pass and log its metrics.

        Parameters
        ----------
        dataloader : DataLoader
            Validation batches
        global_step : int
            Global step used for logging

        Returns
        -------
        float
            Value of ``self.es_metric`` for this validation pass.

        Raises
        ------
        ValueError
            If ``dataloader`` is empty (no metrics can be computed).
        """
        self.model.eval()
        if len(dataloader) == 0:
            raise ValueError("The validation dataloader is empty.")

        loss_val: list[float] = []
        targets: list[torch.Tensor] = []
        preds: list[torch.Tensor] = []
        # No gradients are needed for validation; skip building the autograd graph
        with torch.no_grad(), tqdm(total=len(dataloader), desc="Validating..") as pbar:
            for batch in dataloader:
                # Run the batch
                result = self._step(batch)

                # Keep the results; concatenated once after the loop
                loss_val.append(float(result["loss"].cpu().detach()))
                targets.append(torch.argmax(result["targets"].cpu(), dim=1))
                preds.append(torch.argmax(result["preds"].cpu(), dim=1))

                # Update the progress bar
                pbar.update(1)
                pbar.set_postfix({"loss": f"{sum(loss_val) / len(loss_val)}"})

        all_targets = torch.cat(targets)
        all_preds = torch.cat(preds)

        # Write away the evaluation metrics
        val_report = self._write_eval(
            loss=sum(loss_val) / len(loss_val),
            step=global_step,
            y_true=all_targets.detach().numpy(),
            y_pred=all_preds.detach().numpy(),
            is_val=True,
            incl_conf=True,
        )

        # Return the monitored early-stopping metric
        return val_report[self.es_metric]

    def _es_goal(self, new: float, old: float) -> bool:
        """
        Check whether ``new`` improves on ``old`` by more than ``es_improvement``.

        For "loss" lower is better; for the F1 metrics higher is better.
        """
        if self.es_metric == "loss":
            return new < (old - self.es_improvement)
        return new > (old + self.es_improvement)

    def _write_eval(
        self,
        loss: float,
        step: int,
        y_true: NDArray[np.int_],
        y_pred: NDArray[np.int_],
        is_val: bool,
        incl_conf: bool,
    ) -> dict[str, Any]:
        """
        Write the evaluation metrics (and optionally confusion matrices) to TensorBoard.

        Parameters
        ----------
        loss : float
            Average loss
        step : int
            Global step of the training
        y_true : NDArray[np.int_]
            Target labels (argmax over the class dimension of the targets)
        y_pred : NDArray[np.int_]
            Predicted labels (argmax over the class dimension of the predictions)
        is_val : bool
            Whether to write it away as validation or training results
        incl_conf : bool
            Include the confusion matrix plots

        Returns
        -------
        dict[str, Any]
            "loss", "f1_sample" (sample-weighted F1) and "f1_class" (macro F1).
        """
        # Word on averages for 'precision_recall_fscore_support':
        #  - micro is the same as accuracy sum(true positives) / len(all predictions)
        #  - macro is class-weighted (every class is as important)
        #  - weighted is sample-weighted (every sample is as important, keeps class imbalance)
        tag = "val" if is_val else "train"
        p_s, r_s, f1_s, _ = precision_recall_fscore_support(
            y_true=y_true,
            y_pred=y_pred,
            zero_division=0,
            average="weighted",  # Every sample is as important
        )
        p_c, r_c, f1_c, _ = precision_recall_fscore_support(
            y_true=y_true,
            y_pred=y_pred,
            zero_division=0,
            average="macro",  # Every class is as important
        )
        # Separate tags keep the train and validation loss curves apart
        self._writer.add_scalar(f"loss/{tag}", loss, global_step=step)
        self._writer.add_scalar(f"precision/{tag}_sample", p_s, global_step=step)
        self._writer.add_scalar(f"precision/{tag}_class", p_c, global_step=step)
        self._writer.add_scalar(f"recall/{tag}_sample", r_s, global_step=step)
        self._writer.add_scalar(f"recall/{tag}_class", r_c, global_step=step)
        self._writer.add_scalar(f"f1/{tag}_sample", f1_s, global_step=step)
        self._writer.add_scalar(f"f1/{tag}_class", f1_c, global_step=step)

        # Create and write the confusion matrices (absolute counts and normalised ratios)
        if incl_conf:
            for normalise, suffix in ((False, "count"), (True, "ratio")):
                fig = create_confusion_matrix(
                    y_true=y_true,
                    y_pred=y_pred,
                    class_ids=self.model.get_class_ids(),
                    class_names=self.model.get_class_names(),
                    normalise=normalise,
                )
                self._writer.add_figure(f"confusion/{tag}_{suffix}", fig, global_step=step)

        return {
            "loss": loss,
            "f1_sample": f1_s,
            "f1_class": f1_c,
        }

    def _set_reps(self) -> dict[str, torch.Tensor]:
        """Create the class representatives of the (similarity) classifier."""
        assert isinstance(self.model.clf, SimilarityClassifier)
        return self.model.clf.create_representatives()
