"""Custom trainer class."""

from __future__ import annotations

from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from vito_lot_delineation.data import DelineationDataset
from vito_lot_delineation.evaluation import create_iou_curve, get_pr, plot_predictions
from vito_lot_delineation.evaluation.plot import shuffle_preds_ids
from vito_lot_delineation.models import SemanticModel
from vito_lot_delineation.training.base_trainer import BaseTrainer
from vito_lot_delineation.training.loss import VITOLoss


class SemanticTrainer(BaseTrainer):
    """Custom trainer class."""

    def __init__(
        self,
        model: SemanticModel,
        loss_cfg: list[dict[str, Any]],
        input_cfg: list[dict[str, Any]],
        es_tolerance: int = 5,
        es_improvement: float = 0.001,
        es_metric: str = "loss",
        lr_tolerance: int = 3,
    ) -> None:
        """
        Initialise the trainer.

        Parameters
        ----------
        model : Model
            Model to train
        loss_cfg : list[dict[str, Any]]
            List of losses configurations
        input_cfg : list[dict[str, Any]]
            List of model input configurations
        es_tolerance : int
            Early stopping tolerance, by default 5
        es_improvement : float
            Early stopping improvement before best model gets replaced, by default 0.0
        es_metric : int
            Metric used to monitor early stopping, loss by default
            Options: loss
        lr_tolerance : int
            Scheduler tolerance, by default 3
        """
        assert es_metric in {"loss"}, "Invalid early stopping metric"
        super().__init__(
            model=model,
            es_tolerance=es_tolerance,
            es_improvement=es_improvement,
            es_metric=es_metric,
            lr_tolerance=lr_tolerance,
        )
        self._loss = VITOLoss(loss_cfg=loss_cfg, reduction="sum")
        self._input_cfg = input_cfg

    def train(
        self,
        data_dir: Path | None = None,
        steps_max: int | None = None,
        steps_val: int = 100,
        batch_size: int = 32,
        use_augmentation: bool = True,
    ) -> SemanticModel:
        """Train a segmentation model.

        Parameters
        ----------
        data_dir : Path | None, optional
            Dataset directory, if not specified, default one is used, by default None
        steps_max : int | None, optional
            Maximum number of steps used to train a model, by default None (unlimited)
        steps_val : int, optional
            Number of steps used to evaluate a model, by default 100
        batch_size : int, optional
            Batch size used during training, by default 32
        use_augmentation : bool, optional
            Use augmentation techniques while training, by default True

        Returns
        -------
        BaseModel
            Best model obtained during training
        """
        # get datasets
        data_train = DelineationDataset(
            split="training",
            data_dir=data_dir,
            augment=use_augmentation,
            return_watersheds=True,
            return_distance=True,
            bands=self._input_cfg["bands"],
            n_ts=self._input_cfg["n_ts"],
        )
        data_val = DelineationDataset(
            split="validation",
            data_dir=data_dir,
            augment=False,
            return_watersheds=True,
            return_distance=True,
            bands=self._input_cfg["bands"],
            n_ts=self._input_cfg["n_ts"],
        )

        # Start the training
        self.model.logger(" Model configs:")
        self.model.logger(f" - Model: {self.model.tag}")
        self.model.logger(f" - Parameters: {self.model.n_params:_}")
        self.model.logger(f" - Device: {self.model.device}")
        self.model.logger(f" - Input bands: {self.model.cfg['input']['bands']}")
        self.model.logger(f" - N ts: {self.model.cfg['input']['n_ts']}\n ")

        self.model.logger(" Steps configs:")
        self.model.logger(f" - Maximum steps: {steps_max}")
        self.model.logger(f" - Validation steps: {steps_val}")
        self.model.logger(f" - Batch-size: {batch_size}\n ")

        self.model.logger(" Optimizer configs:")
        self.model.logger(f" - Optimizer: {self._optimizer.__class__.__name__}")
        self.model.logger(f" - Early-stopping tolerance: {self._es_tol}")
        self.model.logger(f" - Early-stopping improvement: {self._es_improv}")
        self.model.logger(f" - Early-stopping metric: {self._es_metric}")
        self.model.logger(f" - Scheduler: {self._scheduler.__class__.__name__}")
        self.model.logger(f" - Learninr rate tolerance: {self._scheduler.patience}\n ")

        self.model.logger(" Loss configs:")
        for i, (tag, ls, wm, lc) in enumerate(
            zip(
                self._loss.tags,
                self._loss.losses,
                self._loss.weights_multipliers,
                self._loss.loss_coefficients,
            )
        ):
            self.model.logger(f" - Tag: {tag}")
            self.model.logger(f" - Loss: {ls.__name__}")
            self.model.logger(f" - Weight / multiplier: {wm}")
            self.model.logger(f" - LC coefficient: {lc}")
            self.model.logger(" ---" if i != len(self._loss.tags) - 1 else "\n")

        # start training loop
        steps_global, es_buffer = 0, 0
        metric_best = self._validate(
            data=data_val,
            batch_size=batch_size,
            steps_global=steps_global,
        )
        self._scheduler.step(metric_best)
        self.model.logger("\n")

        while (steps_max is None) or (steps_global < steps_max):
            # run for steps_val steps
            self.model.logger(
                f"Training for {steps_val} steps (current step: {steps_global}{f'/{steps_max}' if steps_max is not None else ''})"
            )
            steps_global = self._train(
                data=data_train,
                n_steps=steps_val
                if (steps_max is None)
                else min(steps_val, steps_max - steps_global),
                batch_size=batch_size,
                steps_global=steps_global,
            )

            # evaluate
            self.model.logger("Validating")
            metric_current = self._validate(
                data=data_val,
                batch_size=batch_size,
                steps_global=steps_global,
            )
            self._scheduler.step(metric_current)

            # check metrics with es
            if not self._es_goal(new=metric_current, old=metric_best):
                es_buffer += 1
                self.model.logger(
                    f" - No improvement for {self._es_metric}: {metric_current:.5f} (current best: {metric_best:.5f})"
                )
                if es_buffer >= self._es_tol:
                    self.model.logger(
                        " - Early Stopping tolerance reached. Stopping training!\n "
                    )
                    break
            else:
                self.model.logger(
                    f" - New best {self._es_metric}: {metric_current:.5f}"
                )
                self.model.logger(
                    f" - Improvement over previous best: {metric_current - metric_best:.5f}"
                )
                metric_best = metric_current
                es_buffer = 0
                self.model.save()
                self.model.logger(" - Best model updated!\n ")

        return SemanticModel.load(self.model.model_folder)

    def _step(self, batch: dict[str, Any], return_outputs: bool = True) -> None:
        """Forward the provided batch through the model."""
        # retrieve data
        inputs = batch["input"].to(self.model.device)
        extents = batch["extent"].to(self.model.device)
        watersheds = batch["watersheds"].to(self.model.device)
        distance = batch["distance"].to(self.model.device)

        # Forward pass
        outputs = self.model.forward_process(inputs)

        # Loss computation
        local_variables = locals()  # generators have their own scope
        losses_weights = self._gather_losses_weights(local_variables)
        res = self._loss(outputs, extents, losses_weights)

        # Add outputs to result if requested
        if return_outputs:
            res["outputs"] = self.model.post_process(outputs).detach().cpu()
        return res

    def _gather_losses_weights(self, variables: dict[Any, Any]) -> list[torch.Tensor]:
        """Gather losses weights from local variables."""
        loss_weights = []
        for weight_multipl in self._loss.weights_multipliers:
            if weight_multipl is not None:
                canvas = torch.ones_like(variables["extents"]).to(self.model.device)
                for weights_name in weight_multipl:
                    canvas += variables[weights_name] * weight_multipl[weights_name]
                loss_weights.append(canvas)
            else:
                loss_weights.append(None)
        return loss_weights

    def _train(
        self,
        data: DelineationDataset,
        n_steps: int,
        batch_size: int,
        steps_global: int,
    ) -> int:
        """
        Train the model over a specified number of steps.

        Parameters
        ----------
        data : DelineationDataset
            The dataset to train on.
        n_steps : int
            The number of steps to train for.
        batch_size : int
            The batch size to use.
        steps_global : int
            The current global step.

        Returns
        -------
        int
            The new global step.
        """
        self.model.train()
        current_step = 0
        running_loss = []
        with tqdm(total=n_steps, desc="Training...") as pbar:
            for batch in data.get_iterator(batch_size=batch_size, n=n_steps):
                # Perform a training step
                self._optimizer.zero_grad()
                result = self._step(
                    batch=batch, return_outputs=(steps_global % 10 == 0)
                )
                result["loss"].backward()
                self._optimizer.step()

                # Write away the evaluation metrics
                if "outputs" in result:
                    # NOTE: '% 100' should be a multiple of '% 10' found at return_outputs
                    incl_pred = (steps_global + current_step) % 100 == 0
                    result["losses"]["loss"] = result["loss"]
                    self._write_tensorboard(
                        tag="train",
                        step=steps_global + current_step,
                        losses=result["losses"],
                        inputs=batch["input"] if incl_pred else None,
                        targets=batch["instance"] if incl_pred else None,
                        predictions=result["outputs"] if incl_pred else None,
                        n_samples=0,
                    )

                # update progress bar
                pbar.update(1)
                running_loss.append(result["loss"].detach())
                pbar.set_postfix({"loss": f"{sum(running_loss) / len(running_loss)}"})
                current_step += 1

        return steps_global + current_step

    def _validate(
        self,
        data: DelineationDataset,
        batch_size: int,
        steps_global: int,
        n_samples: int = 100,
    ) -> float:
        """
        Validate the current performance of the model.

        Parameters
        ----------
        data : DelineationDataset
            The dataset to validate on.
        batch_size : int
            The batch size to use.
        steps_global : int
            The current global step.
        n_samples : int
            The number of samples to use for the PR-curves.

        Returns
        -------
        float
            The metric to use for early stopping.
        """
        self.model.eval()
        n_batches = len(data) // batch_size  # Drop last
        inputs, targets, predictions = (
            torch.Tensor([]),
            torch.Tensor([]),
            torch.Tensor([]),
        )
        metrics = {"loss": []}
        with tqdm(total=n_batches, desc="Validate...") as pbar:
            for batch in data.get_iterator(
                batch_size=batch_size, n=n_batches, reset_first=True
            ):
                # Perform a validation step
                result = self._step(batch=batch)

                # Update the loss metrics
                metrics["loss"] += [result["loss"].detach()]
                for k, v in result["losses"].items():
                    vv = [v.item()] * len(batch)
                    metrics[k] = metrics[k] + vv if k in metrics else vv

                # Aggregate predictions and targets
                if len(inputs) < n_samples:
                    inputs = torch.cat((inputs, batch["input"]), dim=0)[:n_samples]
                    targets = torch.cat((targets, batch["instance"]), dim=0)[:n_samples]
                    predictions = torch.cat((predictions, result["outputs"]), dim=0)[
                        :n_samples
                    ]

                # Update the progress bar
                pbar.update(1)
                pbar.set_postfix(
                    {"loss": f"{sum(metrics['loss']) / len(metrics['loss'])}"}
                )

        # Write away the evaluation metrics
        self._write_tensorboard(
            tag="val",
            step=steps_global,
            losses={k: sum(v) / len(v) for k, v in metrics.items()},
            inputs=inputs,
            targets=targets,
            predictions=predictions,
        )

        return sum(metrics["loss"]) / len(metrics["loss"])

    def _es_goal(self, new: float, old: float) -> bool:
        """Check if the early stopping criterion is met (i.e. significant model improvement)."""
        if self._es_metric == "loss":
            return new < (old - self._es_improv)
        return new > (old + self._es_improv)

    def _write_tensorboard(
        self,
        tag: str,
        step: int,
        losses: dict[str, float],
        inputs: torch.Tensor | None = None,
        targets: torch.Tensor | None = None,
        predictions: torch.Tensor | None = None,
        n_samples: int = 10,
    ) -> None:
        """
        Write away results in TensorBoard.

        Parameters
        ----------
        tag: str
            Logging prefix
        step: int
            Global step of the training
        losses: dict[str, float]
            Dictionary containing all losses to write away
        inputs: torch.Tensor
            Inputs of the model
        targets: torch.Tensor
            Targets predictions for the given input
        predictions: torch.Tensor
            Model predictions for the given input
        """
        assert tag in {"train", "val"}

        # write away learning rate
        self._writer.add_scalar(
            "lr", self._optimizer.param_groups[0]["lr"], global_step=step
        )

        # Write away losses
        if tag == "val":
            self.model.logger("Writing results to TensorBoard:")
            self.model.logger(" - Start loss computation")
        for k, v in losses.items():
            self._writer.add_scalar(f"{k}/{tag}", v, global_step=step)

        # Write away images
        if inputs is not None:
            if tag == "val":
                self.model.logger(f" - Start image generation for {len(inputs)} images")
            assert targets is not None
            assert predictions is not None

            # Create aggregated images
            precision_iou, recall_iou = [], []
            for target, prediction in tqdm(
                zip(targets, predictions),
                desc="Creating PR calculations",
                total=len(targets),
                leave=False,
            ):
                pr = get_pr(gt=target, pr=prediction)
                precision_iou.append(pr["precision"])
                recall_iou.append(pr["recall"])

            # compute curves
            precision_iou = torch.cat(precision_iou, dim=0)
            recall_iou = torch.cat(recall_iou, dim=0)
            precision_curve = [
                (precision_iou > (i / 100)).sum() / max(len(precision_iou), 1)
                for i in range(101)
            ]
            recall_curve = [
                (recall_iou > (i / 100)).sum() / max(len(recall_iou), 1)
                for i in range(101)
            ]
            f1_curve = [
                (2 * pr * rc) / (pr + rc + 1e-12)
                for pr, rc in zip(precision_curve, recall_curve)
            ]

            # save to tensorboard
            for metric, curve in zip(
                ["precision", "recall", "f1"],
                [precision_curve, recall_curve, f1_curve],
            ):
                fig = create_iou_curve(curve)
                plt.title(
                    f"{metric.capitalize()} - (auc = {sum(curve) / len(curve):.4f})"
                )
                self._writer.add_figure(f"{metric}_iou/{tag}", fig, global_step=step)

            # Create sample based images
            for n in range(n_samples):
                inp = inputs[n].flatten(0, 1).permute(1, 2, 0)
                dims = 3 if inp.shape[0] >= 3 else 1  # noqa: PLR2004
                fig = plot_predictions(
                    background=inp[:, :, :dims],
                    target=targets[n],
                    prediction=shuffle_preds_ids(predictions[n]),
                )
                self._writer.add_figure(f"sample_{n}/{tag}", fig, global_step=step)


if __name__ == "__main__":
    # Script entry point: train a semantic model with the "base" configuration
    from datetime import datetime

    from vito_lot_delineation.configs import get_cfg

    model = SemanticModel(
        model_tag=f"{datetime.now().strftime('%Y%m%dT%H%M%S')}-",
        config_file=get_cfg(cfg_tag="base"),
    )

    # 0.0001 is an improvement margin, not a patience count: passed as
    # es_tolerance it would stop training at the first validation round
    # without improvement (es_buffer >= 0.0001).
    trainer = SemanticTrainer(
        model=model,
        loss_cfg=model.cfg["train"]["losses"],
        input_cfg=model.cfg["input"],
        es_improvement=0.0001,
    )
    trainer.train(
        data_dir=Path("data/data"),
        steps_val=250,
        batch_size=32,
        use_augmentation=True,
    )
