"""Custom CnnTransformer autoencoder trainer."""

from __future__ import annotations

from typing import Any

import torch
from torch import nn

from vito_cropsar.models.CnnTransformer.config import TrainerConfig
from vito_cropsar.models.CnnTransformer.main import InpaintingModel
from vito_cropsar.models.CnnTransformer.model.modules import SharpeningBlock
from vito_cropsar.models.CnnTransformer.trainers.base import (
    CnnTransformerTrainerBase,
    prepare_input,
)
from vito_cropsar.models.shared.loss import VITOLoss


class CnnTransformerTrainerAutoencoder(CnnTransformerTrainerBase):
    """
    Custom CnnTransformer autoencoder trainer.

    Trains the wrapped model's input block, spatial encoder and spatial
    decoder together with a freshly created ``SharpeningBlock`` head. The
    head predicts the target channels plus one extra channel that is
    compared against the mask.
    """

    def __init__(
        self,
        model: InpaintingModel,
        cfg: TrainerConfig,
        shared_state: dict[str, Any],
    ) -> None:
        """
        Initialise the trainer.

        Parameters
        ----------
        model : InpaintingModel
            Model to train
        cfg : TrainerConfig
            Trainer configuration
        shared_state : dict[str, Any]
            Shared state passed between the trainers
        """
        super().__init__(
            model=model,
            cfg=cfg,
            shared_state=shared_state,
            tag="autoencoder",
        )
        self._new_model = self._create_model()
        self.initialise_new_model(self._new_model)
        # +1 channel: the mask is appended to the target as an extra channel
        # (see ``_step``), so the loss covers n_channels_out + 1 channels.
        self.loss = VITOLoss(
            device=self._model.device,
            n_channels=self._model.n_channels_out + 1,
            cfg=self.loss_cfg,
        )

    def forward(self, inp: dict[str, torch.Tensor | None]) -> torch.Tensor:
        """
        Forward the prepared input through the autoencoder.

        Only the spatial path is used: input block -> spatial encoder ->
        spatial decoder -> sharpening output block. Dimensions 1 and 2 of the
        input are swapped before the blocks and swapped back afterwards.

        Parameters
        ----------
        inp : dict[str, torch.Tensor | None]
            Output of ``prepare_input``; the ``"inputs"`` entry is a 5-D
            tensor (presumably batch, time, channels, height, width -- TODO
            confirm) and ``"edges"`` is forwarded to the decoder and the
            output block.

        Returns
        -------
        torch.Tensor
            Reconstruction, with dimensions 1 and 2 in the same order as the
            input.
        """
        edges = inp["edges"]
        x = inp["inputs"].permute(0, 2, 1, 3, 4)
        x = self._new_model.input_block(x)
        x = self._new_model.spatial_encoder(x)
        x = self._new_model.spatial_decoder(x, e=edges)
        x = self._new_model.output_block(x, e=edges)
        return x.permute(0, 2, 1, 3, 4)

    def _create_model(self) -> nn.Module:
        """
        Create a PyTorch model suitable for this training regime.

        The input block, spatial encoder and spatial decoder are the *same*
        module objects as in the wrapped model, so training here updates the
        wrapped model's weights. The output block is a new ``SharpeningBlock``
        producing ``n_channels_out + 1`` channels (target + mask).

        Returns
        -------
        nn.Module
            ``nn.ModuleDict`` with keys ``input_block``, ``spatial_encoder``,
            ``spatial_decoder`` and ``output_block``, moved to the model
            device, with every parameter trainable.
        """
        wrapped = self._model._model
        mdl = nn.ModuleDict(
            {
                "input_block": wrapped.input_block,  # 0, trainable
                "spatial_encoder": wrapped.spatial_encoder,  # 1, trainable
                "spatial_decoder": wrapped.spatial_decoder,  # 2, trainable
                "output_block": SharpeningBlock(  # 3, trainable
                    channels_p=self._model.dec_channels_p,
                    channels_out=self._model.cfg.n_channels_out + 1,
                ),
            }
        )
        mdl.to(self._model.device)

        # Explicitly unfreeze every parameter, including those of the blocks
        # reused from the wrapped model.
        mdl.requires_grad_(True)

        return mdl

    def _step(
        self,
        batch: dict[str, torch.Tensor],
        incl_pred: bool = False,
    ) -> dict[str, Any]:
        """
        Create one step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            Batch of data; reads the ``s1``, ``s2``, ``mask``, ``target`` and
            ``tile`` entries
        incl_pred : bool
            Whether to include the prediction in the metric results

        Returns
        -------
        dict[str, Any]
            Metric results (plus ``"pred"`` when ``incl_pred`` is set)

        Raises
        ------
        ValueError
            If the prediction shape differs from the target+mask shape.
        """
        device = self._model.device
        s1 = batch["s1"].to(device)
        # Always work on a fresh copy of s2 so the batch tensor is never
        # aliased (single copy, unlike clone() followed by to()).
        s2 = batch["s2"].to(device, copy=True)
        mask = batch["mask"].to(device)
        target = batch["target"].to(device)
        # Unpacking also enforces a 5-D target.
        _, _, n_target_channels, _, _ = target.shape

        # The autoencoder reconstructs the target with the mask appended as
        # an extra channel along dim 2.
        target_ = torch.cat([target, mask[:, :, None]], dim=2)

        # Make the prediction
        try:
            inp = prepare_input(
                s1=s1,
                s2=s2,
                fill_nan=self._model.fill_nan,
                incl_edge=self._model.edge_repr,
            )
            pred = self.forward(inp)
            # Explicit check (not ``assert``) so it survives ``python -O``.
            if pred.shape != target_.shape:
                raise ValueError(
                    f"Prediction shape {tuple(pred.shape)} does not match "
                    f"target shape {tuple(target_.shape)}"
                )
        except Exception as e:
            self.log_exception(e)
            self.log(f"Tile IDs: {', '.join(batch['tile'])}")
            raise

        # Calculate the loss
        metrics = self.loss(target=target_, pred=pred, mask=mask)

        # Append the prediction (without the mask channel) to the results.
        # Slice before moving to CPU so only the needed channels are copied,
        # and detach so the stored tensor does not keep the autograd graph.
        if incl_pred:
            metrics["pred"] = pred[:, :, :n_target_channels].detach().cpu()
        return metrics
