"""Custom CnnTransformer base trainer."""

from __future__ import annotations

from typing import Any

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

from vito_cropsar.models.base_trainer import TrainerBase
from vito_cropsar.models.CnnTransformer.config import TrainerConfig
from vito_cropsar.models.CnnTransformer.main import InpaintingModel
from vito_cropsar.models.CnnTransformer.utils import get_repr
from vito_cropsar.models.shared import (
    CropSarDataset,
    LearningRateScheduler,
    TensorBoardLogger,
)
from vito_cropsar.models.shared.loss import VITOLoss
from vito_cropsar.models.shared.utils import interpolate_s1_batch


class CnnTransformerTrainerBase(TrainerBase):
    """
    Custom CnnTransformer base trainer.

    Provides the generic training loop around a sub-model: periodic
    validation, early stopping, learning-rate scheduling and TensorBoard
    logging. Subclasses must implement ``create_model`` and ``forward``.
    """

    def __init__(
        self,
        model: InpaintingModel,
        cfg: TrainerConfig,
        shared_state: dict[str, Any],
        tag: str,
    ) -> None:
        """
        Initialise the trainer.

        Parameters
        ----------
        model : InpaintingModel
            Model to train
        cfg : TrainerConfig
            Trainer configuration
        shared_state : dict[str, Any]
            Shared state passed between the trainers
        tag : str
            Trainer tag, used in the per-trainer log folder name
            (``logs_<tag>``) and as a postfix for the logged plots
        """
        super().__init__(
            model=model,
            cfg=cfg,
        )
        self.shared_state = shared_state
        self.tag = tag
        # Per-trainer TensorBoard logger, used in addition to
        # ``self._tensorboard`` (presumably created by ``TrainerBase``)
        self._new_tensorboard = TensorBoardLogger(
            mdl_f=self._model.model_folder,
            n_plots=self.cfg.n_plots,
            scaler=self._model.scaler,
            fill_nan=self._model.fill_nan,
            logs_f=f"logs_{self.tag}",
        )
        self.loss = VITOLoss(
            device=self._model.device,
            n_channels=self._model.n_channels_out,  # In- and output should never change
            cfg=self.loss_cfg,
        )

    def initialise_new_model(self, new_model: nn.Module) -> None:
        """
        Create the optimiser and learning-rate scheduler for ``new_model``.

        Parameters
        ----------
        new_model : nn.Module
            Sub-model whose parameters are optimised
        """
        # NOTE(review): Adam starts from ``self.cfg.lr`` while the scheduler
        # bounds use ``self.lr``; presumably the same value — confirm.
        self._optimizer = Adam(new_model.parameters(), lr=self.cfg.lr)
        self._lr_scheduler = LearningRateScheduler(
            optimizer=self._optimizer,
            lr_min=self.lr / 2,
            lr_max=self.lr,
            val_mode="min",  # Loss (minimize)
            plateau_improvement_ratio=self.es_improvement_ratio,
            warmup_steps=self.steps_val,
            cycle_steps=self.steps_val,
            logger=self._model.logger,
        )

    def create_model(self) -> nn.Module:
        """Create a PyTorch model suitable for this training regime (subclasses must override)."""
        raise NotImplementedError

    def forward(self, x: dict[str, torch.Tensor | None]) -> torch.Tensor:
        """Forward through the model (subclasses must override)."""
        raise NotImplementedError

    def _create_dataloader(self, split: str, size: Any) -> DataLoader:  # type: ignore[type-arg]
        """Build the dataloader for one data split with the trainer's shared settings."""
        # NOTE(review): the validation split is also built with ``augm_cfg``;
        # confirm that CropSarDataset does not augment outside "training".
        return CropSarDataset(
            data_f=self.data_f,
            split=split,
            scaler=self._model.scaler,
            size=size,
            n_ts=self._model.n_ts,
            resolution=self._model.resolution,
            cache_tag=self.cache_tag,
            align=self._model.align_input,
            smooth_s1=self._model.smooth_s1,
            augm_cfg=self.augm_cfg,
        ).get_dataloader(batch_size=self.batch_size, n_loaders=self.n_data_loaders)

    def train(self, global_step: int = 0) -> tuple[int, InpaintingModel]:
        """
        Training script.

        Alternates ``steps_val`` training steps with a validation pass. The
        model is saved whenever the validation loss improves by more than
        ``es_improvement_ratio``. Training stops after ``es_tolerance``
        validations without improvement, or once ``steps_max`` is reached
        (if set).

        Parameters
        ----------
        global_step : int
            Global step counter, only used for logging

        Returns
        -------
        tuple[int,InpaintingModel]
            global_step
                Global step after training finishes
            model
                Trained model, recovered from the best checkpoint
        """
        # ``_new_model`` is not set in this class; presumably it is assigned
        # from ``create_model`` by a subclass or caller — confirm.
        assert hasattr(self, "_new_model"), "Please create a sub-model first"
        dataloader_train = self._create_dataloader(split="training", size=self.size_train)
        dataloader_val = self._create_dataloader(split="validation", size=self.size_val)

        # Start the training
        self.log("")
        self.log(f"Training '{self.tag}':")
        param_train, param_freeze = get_parameter_count(self._new_model)
        self.log(f" - Number of trainable parameters: {param_train:,}")
        self.log(f" - Number of frozen parameters: {param_freeze:,}")
        self.log(f" - Validation steps: {self.steps_val}")
        self.log(f" - Maximum steps: {self.steps_max}")
        self.log(f" - Batch-size: {self.batch_size}")
        self.log(f" - Early-stopping tolerance: {self.es_tolerance}")
        self.log(f" - Early-stopping improvement: {self.es_improvement_ratio}")
        self.log(f" - Optimizer: {self._optimizer.__class__.__name__}")
        self.log(f" - Learning rate scheduling: {self._lr_scheduler}")
        self.log(f" - Loss function: {self.loss}")
        self.log(f" - Device: {self._model.device}")
        self.log("")

        # Start with initial validation; this also serves as the first checkpoint
        es_buffer = 0
        best_metric = self._validate(
            dataloader=dataloader_val,
            global_step=global_step,
        )
        self._lr_scheduler(best_metric)
        self.log(f"Initial validation loss: {best_metric:.5f}")
        self.log("")
        self._model.save()

        # start training loop
        while (self.steps_max is None) or (global_step < self.steps_max):
            self.log(
                f"Training for {self.steps_val} steps (current step: {global_step})"
            )
            global_step = self._train(
                n=self.steps_val,
                dataloader=dataloader_train,
                global_step=global_step,
            )
            new_metric = self._validate(
                dataloader=dataloader_val,
                global_step=global_step,
            )
            self._lr_scheduler(new_metric)

            # Check for early stopping: only a relative improvement larger
            # than ``es_improvement_ratio`` counts as progress
            self.log("Early stopping:")
            if new_metric < (1 - self.es_improvement_ratio) * best_metric:
                self.log(f" - New best loss: {new_metric:.5f}")
                # Positive value = how much lower the new loss is
                self.log(
                    f" - Improvement over previous best: {best_metric - new_metric:.5f}"
                )
                self.log(" - Best model updated!")
                best_metric = new_metric
                es_buffer = 0
                self._model.save()
            else:
                es_buffer += 1
                self.log(
                    f" - No improvement: {new_metric:.5f} (current best: {best_metric:.5f})"
                )
                self.log(f" - Patience: {es_buffer}/{self.es_tolerance}")
                if es_buffer >= self.es_tolerance:
                    self.log(" - Early Stopping tolerance reached. Training stopped!")
                    self.log("")
                    return global_step, self.recover()
            self.log("")
        return global_step, self.recover()

    def _step(
        self,
        batch: dict[str, torch.Tensor],
        incl_pred: bool = False,
    ) -> dict[str, Any]:
        """
        Create one step: forward a batch and compute its loss metrics.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            Batch of data; reads the keys "s1", "s2", "mask", "target" (and
            "tile" for error reporting)
        incl_pred : bool
            Whether to include the prediction (on CPU) under "pred" in the
            metric results

        Returns
        -------
        dict[str, Any]
            Metric results as returned by the loss function, plus "pred" when
            requested

        Raises
        ------
        ValueError
            If the prediction shape does not match the target shape
        """
        device = self._model.device
        s1 = batch["s1"].to(device)
        # Clone S2 because ``prepare_input`` overwrites its NaNs in place
        s2 = batch["s2"].clone().to(device)
        mask = batch["mask"].to(device)
        target = batch["target"].to(device)

        # Make the prediction
        try:
            inp = prepare_input(
                s1=s1,
                s2=s2,
                fill_nan=self._model.fill_nan,
                incl_edge=self._model.edge_repr,
            )
            pred = self.forward(inp)
            if target.shape != pred.shape:
                raise ValueError(
                    f"Prediction shape {tuple(pred.shape)} does not match "
                    f"target shape {tuple(target.shape)}"
                )
        except Exception as e:
            self.log_exception(e)
            # Defensive: a missing or non-string tile ID must not mask the original error
            self.log(f"Tile IDs: {', '.join(map(str, batch.get('tile', [])))}")
            raise

        # Calculate the loss
        metrics = self.loss(target=target, pred=pred, mask=mask)

        # Append the prediction to the metric results
        if incl_pred:
            metrics["pred"] = pred.cpu()
        return metrics

    def _train(self, n: int, dataloader: DataLoader, global_step: int) -> int:  # type: ignore[type-arg]
        """
        Train for ``n`` optimisation steps, restarting the dataloader as often as needed.

        Parameters
        ----------
        n : int
            Number of optimisation steps to run
        dataloader : DataLoader
            Training dataloader
        global_step : int
            Global step counter at the start, used for logging

        Returns
        -------
        int
            Global step counter after the ``n`` steps

        Raises
        ------
        RuntimeError
            If the dataloader yields no batches (would otherwise loop forever)
        """
        assert hasattr(self, "_new_model"), "Model not initialized!"
        torch.set_grad_enabled(True)
        self._new_model.train()

        # Log the batch metrics roughly 10 times per validation period
        log_every = max(self.steps_val // 10, 1)

        # Run a training loop
        loss_train: list[float] = []
        train_step = 0
        with tqdm(total=n, desc="Training..") as pbar:
            while True:
                n_batches = 0
                for batch in dataloader:
                    n_batches += 1

                    # Run the batch
                    self._optimizer.zero_grad()
                    b_result = self._step(batch, incl_pred=False)
                    loss = b_result["loss"]
                    loss.backward()
                    self._optimizer.step()
                    self._lr_scheduler.step()
                    self.log_lr(
                        step=global_step,
                        lr=self._optimizer.param_groups[0]["lr"],
                    )

                    # Write away the evaluation metrics
                    loss_train.append(loss.item())
                    if global_step % log_every == 0:
                        self.log_metrics(
                            split="train",
                            step=global_step,
                            metrics={
                                k: v.item() for k, v in b_result.items() if k != "pred"
                            },
                        )

                    # Update the progress bar (running mean of the last 20 losses)
                    recent = loss_train[-20:]
                    pbar.update(1)
                    pbar.set_postfix({"loss": f"{sum(recent) / len(recent)}"})
                    train_step += 1
                    global_step += 1

                    # Stop training if number of steps exceeded
                    if train_step >= n:
                        pbar.set_postfix(  # Total training summary
                            {
                                "loss": f"{sum(loss_train) / len(loss_train)}",
                            }
                        )
                        return global_step

                if n_batches == 0:
                    raise RuntimeError("Training dataloader yielded no batches.")

    @torch.no_grad()
    def _validate(self, dataloader: DataLoader, global_step: int) -> float:  # type: ignore[type-arg]
        """
        Run one validation sequence over the full dataloader.

        Metrics are averaged per sample (each batch result is weighted by the
        batch size). Plots are logged for the first batch.

        Parameters
        ----------
        dataloader : DataLoader
            Validation dataloader
        global_step : int
            Global step counter, used for logging

        Returns
        -------
        float
            Sample-weighted mean validation loss

        Raises
        ------
        RuntimeError
            If the dataloader yields no batches
        """
        assert hasattr(self, "_new_model"), "Model not initialized!"
        torch.set_grad_enabled(False)  # Redundant with @torch.no_grad, kept explicit
        self._new_model.eval()

        # Running, sample-weighted sums per metric (avoids O(n^2) list rebuilding)
        sums: dict[str, float] = {}
        counts: dict[str, int] = {}
        with tqdm(total=len(dataloader), desc="Validating..") as pbar:
            for i, batch in enumerate(dataloader):
                n_samples = batch["s1"].shape[0]

                # Run the batch
                b_result = self._step(batch, incl_pred=(i == 0))
                for k, v in b_result.items():
                    if k == "pred":
                        continue
                    sums[k] = sums.get(k, 0.0) + v.item() * n_samples
                    counts[k] = counts.get(k, 0) + n_samples

                # Update the progress bar
                pbar.update(1)
                pbar.set_postfix({"loss": f"{sums['loss'] / counts['loss']}"})

                # Plot if it's the first batch
                if "pred" in b_result:
                    self.log_plots(
                        step=global_step,
                        preds=b_result["pred"].cpu().detach(),
                        targets=batch["target"].cpu().detach(),
                        inputs=batch["s2"].cpu().detach(),
                        masks=batch["mask"].cpu().detach(),
                    )

        if "loss" not in sums:
            raise RuntimeError("Validation dataloader yielded no batches.")

        # Write away the evaluation metrics
        metrics = {k: sums[k] / counts[k] for k in sums}
        self.log_metrics(
            split="val",
            step=global_step,
            metrics=metrics,
        )
        return metrics["loss"]

    def log_metrics(self, split: str, metrics: dict[str, Any], step: int) -> None:
        """Log metrics to both the shared and the per-trainer TensorBoard loggers."""
        self._tensorboard.log_metrics(
            split=split,
            step=step,
            **metrics,
        )
        self._new_tensorboard.log_metrics(
            split=split,
            step=step,
            **metrics,
        )

    def log_plots(
        self,
        step: int,
        preds: torch.Tensor,
        targets: torch.Tensor,
        inputs: torch.Tensor,
        masks: torch.Tensor,
    ) -> None:
        """Log plots to the shared TensorBoard logger, postfixed with the trainer tag."""
        # NOTE(review): unlike log_metrics/log_lr, plots are not sent to
        # ``_new_tensorboard`` even though it is configured with n_plots,
        # scaler and fill_nan — confirm whether this is intentional.
        self._tensorboard.log_plots(
            step=step,
            preds=preds,
            targets=targets,
            inputs=inputs,
            masks=masks,
            postfix=f"_{self.tag}",
        )

    def log_lr(self, step: int, lr: float) -> None:
        """Log the learning rate to both the shared and the per-trainer TensorBoard loggers."""
        self._tensorboard.log_lr(step=step, lr=lr)
        self._new_tensorboard.log_lr(step=step, lr=lr)


def prepare_input(
    s1: torch.Tensor,
    s2: torch.Tensor,
    fill_nan: float,
    incl_edge: bool,
) -> dict[str, torch.Tensor | None]:
    """
    Prepare the input for the models.

    Parameters
    ----------
    s1 : torch.Tensor
        S1 batch; its NaNs are interpolated with ``interpolate_s1_batch``
    s2 : torch.Tensor
        S2 batch. NOTE: modified in place, its NaNs are overwritten with
        ``fill_nan`` (pass a clone if the original must be kept)
    fill_nan : float
        Replacement value for NaNs in ``s2``
    incl_edge : bool
        Whether to compute the per-sample edge representation of ``s2``

    Returns
    -------
    dict[str, torch.Tensor | None]
        "inputs": S2 and S1 concatenated along dimension 2;
        "edges": stacked per-sample representations from ``get_repr``
        (computed before NaN filling), or None when ``incl_edge`` is False
    """
    s1_filled = interpolate_s1_batch(s1)

    edges: torch.Tensor | None = None
    if incl_edge:
        # get_repr works per sample on NumPy arrays of the still-unfilled S2
        per_sample = [
            torch.tensor(get_repr(arr), dtype=s2.dtype, device=s2.device)
            for arr in s2.cpu().numpy()
        ]
        edges = torch.stack(per_sample)

    # In-place NaN replacement (callers may rely on s2 being updated)
    s2[s2.isnan()] = fill_nan

    return {
        "inputs": torch.cat([s2, s1_filled], dim=2),
        "edges": edges,
    }


def get_parameter_count(model: nn.Module) -> tuple[int, int]:
    """
    Get the number of trainable and frozen parameters.

    Parameters
    ----------
    model : nn.Module
        Model whose parameters are counted

    Returns
    -------
    tuple[int, int]
        Number of elements in parameters with ``requires_grad=True``,
        followed by the number in parameters with ``requires_grad=False``
    """
    params = list(model.parameters())
    n_trainable = sum(p.numel() for p in params if p.requires_grad)
    n_frozen = sum(p.numel() for p in params if not p.requires_grad)
    return n_trainable, n_frozen
