"""Custom CnnTransformer transformer trainer."""

from __future__ import annotations

from typing import Any

import torch
from torch import nn

from vito_cropsar.models.CnnTransformer.config import TrainerConfig
from vito_cropsar.models.CnnTransformer.main import InpaintingModel
from vito_cropsar.models.CnnTransformer.trainers.base import (
    CnnTransformerTrainerBase,
)


class CnnTransformerTrainerTransformer(CnnTransformerTrainerBase):
    """
    Custom CnnTransformer transformer trainer.

    Trains the temporal layer, spatial decoder and output block of the wrapped
    model while keeping its input block and spatial encoder frozen.
    """

    # Sub-modules of ``model._model`` whose parameters are frozen.
    _FROZEN_BLOCKS = ("input_block", "spatial_encoder")
    # Sub-modules of ``model._model`` whose parameters are trained.
    _TRAINABLE_BLOCKS = ("temporal_layer", "spatial_decoder", "output_block")

    def __init__(
        self,
        model: InpaintingModel,
        cfg: TrainerConfig,
        shared_state: dict[str, Any],
    ) -> None:
        """
        Initialise the trainer.

        Parameters
        ----------
        model : InpaintingModel
            Model to train
        cfg : TrainerConfig
            Trainer configuration
        shared_state : dict[str, Any]
            Shared state passed between the trainers
        """
        super().__init__(
            model=model,
            cfg=cfg,
            shared_state=shared_state,
            tag="transformer",
        )
        self._new_model = self._create_model()
        self.initialise_new_model(self._new_model)

    def forward(self, inp: dict[str, torch.Tensor | None]) -> torch.Tensor:
        """
        Forward the necessary data through the model.

        Parameters
        ----------
        inp : dict[str, torch.Tensor | None]
            Batch dictionary. ``inp["inputs"]`` must be a 5-D tensor; it is run
            through all sub-modules. ``inp["edges"]`` is passed as ``e`` to the
            spatial decoder and the output block.

        Returns
        -------
        torch.Tensor
            Model output. Axes 1 and 2 are swapped before the sub-modules run
            and swapped back afterwards.
        """
        # Swap axes 1 and 2 for the sub-modules
        # (presumably time <-> channels; TODO confirm the layout).
        x = inp["inputs"].permute(0, 2, 1, 3, 4)
        x = self._new_model.input_block(x)
        x = self._new_model.spatial_encoder(x)
        x = self._new_model.temporal_layer(x)
        x = self._new_model.spatial_decoder(x, e=inp["edges"])
        x = self._new_model.output_block(x, e=inp["edges"])
        # Restore the original order of axes 1 and 2.
        return x.permute(0, 2, 1, 3, 4)

    def _create_model(self) -> nn.Module:
        """
        Create a PyTorch model suitable for this training regime.

        The returned ``nn.ModuleDict`` holds references to (not copies of) the
        sub-modules of ``self._model._model``. Their parameters are therefore
        shared with the wrapped model, and the device move and the
        freezing/unfreezing done here also apply to it.

        Returns
        -------
        nn.Module
            ModuleDict with the input block, spatial encoder, temporal layer,
            spatial decoder and output block (in that order), moved to
            ``self._model.device``. The first two are frozen; the rest are
            trainable.
        """
        source = self._model._model
        mdl = nn.ModuleDict(
            {
                name: getattr(source, name)
                for name in (*self._FROZEN_BLOCKS, *self._TRAINABLE_BLOCKS)
            }
        )
        # Moves the shared sub-modules in place.
        mdl.to(self._model.device)

        # Freeze the encoder side, (re-)enable gradients for the rest.
        for name in self._FROZEN_BLOCKS:
            mdl[name].requires_grad_(False)
        for name in self._TRAINABLE_BLOCKS:
            mdl[name].requires_grad_(True)

        return mdl
