"""The CNN Transformer PyTorch model."""

from __future__ import annotations

import torch
from torch import nn

from vito_cropsar.models.CnnTransformer.model.layers import (
    SpatialDecoder,
    SpatialEncoder,
    TemporalLayer,
)
from vito_cropsar.models.CnnTransformer.model.modules import (
    InputBlock,
    SharpeningBlock,
)


class CnnTransformer(nn.Module):
    """The CNN Transformer PyTorch model."""

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        n_ts: int,
        enc_channels_p: int = 32,
        enc_channels_resnet: list[int] = [64, 128, 512],
        enc_dropout: float = 0.1,
        enc_activation: str = "relu",
        temp_layers: int = 4,
        temp_heads: int = 8,
        temp_dropout: float = 0.1,
        temp_activation: str = "relu",
        dec_channels_p: int = 32,
        dec_channels_resnet: list[int] = [512, 256, 128],
        dec_dropout: float = 0.0,  # No dropout in decoder!
        dec_activation: str = "relu",
        sha_activation: str = "relu",
    ) -> None:
        """
        Initialise the model.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int
            Number of output channels
        n_ts : int
            Length of the temporal dimension
        activation : str
            Activation function to use
        enc_channels_p : int
            Number of input processing channels (spatial encoder)
        enc_channels_resnet : list[int]
            Number of channels in each spatial layer (spatial encoder)
        enc_dropout : float
            Dropout rate (spatial encoder)
        enc_activation : str
            Activation function to use (spatial encoder)
        temp_layers : int
            Number of layers (temporal layer)
        temp_heads : int
            Number of heads (temporal layer)
        temp_dropout : float
            Dropout rate (temporal layer)
        temp_activation : str
            Activation function to use (temporal layer)
        dec_channels_p : int
            Number of output processing channels (spatial decoder)
        dec_channels_resnet : list[int]
            Number of channels in each spatial layer (spatial decoder)
        dec_dropout : float
            Dropout rate (spatial decoder)
        dec_activation : str
            Activation function to use (spatial decoder)
        sha_activation : str
            Activation function to use (sharpening block)
        """
        assert len(enc_channels_resnet) == len(
            dec_channels_resnet
        ), "Spatial encoder and decoder must have the same number of layers"
        assert enc_channels_resnet[-1] == dec_channels_resnet[0], (
            "Spatial encoder and decoder must have the same number of channels "
            "in the last layer of the encoder and the first layer of the decoder"
        )

        super().__init__()
        self.input_block = InputBlock(
            channels_in=channels_in,
            channels_p=enc_channels_p,
            activation=enc_activation,
        )
        self.spatial_encoder = SpatialEncoder(
            channels_prev=enc_channels_p,
            channels_resnet=enc_channels_resnet,
            dropout=enc_dropout,
            activation=enc_activation,
        )
        self.temporal_layer = TemporalLayer(
            n_ts=n_ts,
            dim=enc_channels_resnet[-1],
            n_layers=temp_layers,
            n_heads=temp_heads,
            dropout=temp_dropout,
            activation=temp_activation,
        )
        self.spatial_decoder = SpatialDecoder(
            channels_last=dec_channels_p,
            channels_resnet=dec_channels_resnet,
            dropout=dec_dropout,
            activation=dec_activation,
        )
        self.output_block = SharpeningBlock(
            channels_p=dec_channels_p,
            channels_out=channels_out,
            activation=sha_activation,
        )

    def forward(self, x: torch.Tensor, e: torch.Tensor | None) -> torch.Tensor:
        """Forward pass."""
        # Encode spatial dimensions (S1 and S2)
        x = self.input_block(x)
        x = self.spatial_encoder(x)

        # Temporal layer encoding (S1 and S2)
        x = self.temporal_layer(x)

        # Decode spatial dimensions (S2)
        x = self.spatial_decoder(x, e=e)
        x = self.output_block(x, e=e)

        return x
