"""Complete network layers."""


from __future__ import annotations

import torch
from torch import nn

from vito_cropsar.models.CnnTransformer.model.modules import (
    Downsample3d,
    PositionalEncoder,
    Resample3d,
    Upsample3d,
)
from vito_cropsar.models.shared import BasicBlock
from vito_cropsar.vito_logger import LogLevel, bh_logger


class SpatialLayer(nn.Module):
    """ResNet-style layer: an optional (re)sampling module combined with a processing block."""

    def __init__(
        self,
        mode: str,
        channels_in: int,
        channels_out: int | None = None,
        activation: str = "relu",
        dropout: float = 0.0,
        k_up_spatial: bool = True,
        k_up_temporal: bool = False,
        k_process_spatial: bool = True,
        k_process_temporal: bool = False,
    ) -> None:
        """
        Initialize the basic layer.

        A downward layer ("down") flows as follows:
            - Downsample
            - Processing layer
        An upward layer ("up") flows as follows:
            - Processing layer
            - Upsample
        A resampling layer ("none") flows as follows:
            - Resample (only created when channels_in != channels_out)
            - Processing layer

        Notes
        -----
        - The processing layer is a BasicBlock
        - The number of channels used in the processing layer is equal to
            - channels_out if mode is "down" (downsample first, then process)
            - channels_in if mode is "up" (process first, then upsample)
            - channels_out if mode is "none" (resample first, then process)
        - Only the sampling module that belongs to the mode is registered as an attribute
          (`down`, `up` or `resample`); the forward pass relies on its presence

        Parameters
        ----------
        mode : str
            Layer mode, either "up", "down", or "none"
        channels_in : int
            Number of input channels
        channels_out : int | None
            Number of channels produced by the sampling module (same as channels_in if None or 0)
        activation : str
            Activation function to use, forwarded to the processing block
        dropout : float
            Dropout rate to apply on the processing layer
        k_up_spatial : bool
            Forwarded as `spatial` to the sampling module (down-, up- or resampling)
        k_up_temporal : bool
            Forwarded as `temporal` to the sampling module (down-, up- or resampling)
        k_process_spatial : bool
            Forwarded as `spatial` to the processing block
        k_process_temporal : bool
            Forwarded as `temporal` to the processing block

        Raises
        ------
        AssertionError
            If mode is not one of "up", "down", or "none"
        """
        super().__init__()
        channels_out = channels_out or channels_in
        assert mode in (
            "up",
            "down",
            "none",
        ), f"Mode should be either 'up', 'down', or 'none'. Got {mode}"
        self.mode = mode

        # All three sampling modules share the same constructor arguments
        sampling_kwargs = {
            "channels_in": channels_in,
            "channels_out": channels_out,
            "spatial": k_up_spatial,
            "temporal": k_up_temporal,
        }

        # Register only the sampling module this mode needs
        if mode == "down":
            self.down = Downsample3d(**sampling_kwargs)
        elif mode == "up":
            self.up = Upsample3d(**sampling_kwargs)
        elif mode == "none" and channels_in != channels_out:
            # Without a channel change there is nothing to resample
            self.resample = Resample3d(**sampling_kwargs)

        # In "up" mode the block runs before the upsampling, hence still on channels_in
        is_up = mode == "up"
        self.process = BasicBlock(
            channels=channels_in if is_up else channels_out,
            activation=activation,
            dropout=dropout,
            spatial=k_process_spatial,
            temporal=k_process_temporal,
            transpose=is_up,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, passed to the registered sampling module and the processing block

        Returns
        -------
        torch.Tensor
            Output of the last module that was applied
        """
        # Sampling that happens before the processing block (at most one of both exists)
        if hasattr(self, "down"):
            x = self.down(x)
        if hasattr(self, "resample"):
            x = self.resample(x)

        # Processing
        x = self.process(x)

        # Upsampling happens after the processing block
        if hasattr(self, "up"):
            x = self.up(x)

        return x


class SpatialEncoder(nn.Module):
    """Spatial encoder: a chain of SpatialLayers in "down" mode."""

    def __init__(
        self,
        channels_prev: int,
        channels_resnet: list[int],
        dropout: float = 0.0,
        activation: str = "relu",
    ) -> None:
        """
        Initialize the layer.

        One "down" SpatialLayer is created per entry of channels_resnet; every layer takes the
        channel count of its predecessor as input (channels_prev for the first one).

        Note: The logged message assumes an input resolution of 128 that is halved by every
        SpatialLayer.

        Parameters
        ----------
        channels_prev : int
            Number of channels that came right before this block
        channels_resnet : list[int]
            List of ResNet layer (SpatialLayer) channels
        dropout : float
            Dropout rate to apply on the processing layers
        activation : str
            Activation function to use
        """
        super().__init__()

        # Report the expected resolution after all downsampling steps
        bh_logger(
            f"Images of resolution 128 will be downsampled to a resolution of {128//2**len(channels_resnet)}",
            LogLevel.WARNING,
        )

        # Channel widths along the chain: [input, layer 1, layer 2, ...]
        widths = [channels_prev, *channels_resnet]

        # Each consecutive pair of widths defines one (downsample -> ResNet block) layer
        self.layers = nn.ModuleList(
            [
                SpatialLayer(
                    mode="down",
                    channels_in=width_in,
                    channels_out=width_out,
                    activation=activation,
                    dropout=dropout,
                )
                for width_in, width_out in zip(widths[:-1], widths[1:])
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
            NOTE(review): the axis order is not established by this class, confirm with the caller

        Returns
        -------
        torch.Tensor
            Result of applying all layers in order
        """
        out = x
        for stage in self.layers:
            out = stage(out)
        return out


class SpatialDecoder(nn.Module):
    """Spatial decoder: a chain of SpatialLayers in "up" mode with optional edge injection."""

    def __init__(
        self,
        channels_last: int,
        channels_resnet: list[int],
        dropout: float = 0.0,
        activation: str = "relu",
    ) -> None:
        """
        Initialize the layer.

        One "up" SpatialLayer is created per entry of channels_resnet; every layer outputs the
        channel count of its successor, the final layer outputs channels_last.

        Note: The logged message assumes an output resolution of 128 that is reached by doubling
        the resolution at every SpatialLayer.

        Parameters
        ----------
        channels_last : int
            Number of channels that go out of this block
        channels_resnet : list[int]
            List of ResNet layer (SpatialLayer) channels
        dropout : float
            Dropout rate to apply on the processing layers
        activation : str
            Activation function to use
        """
        super().__init__()

        # Report the expected resolution before all upsampling steps
        bh_logger(
            f"Images of resolution {128//2**len(channels_resnet)} will be upsampled to a resolution of 128",
            LogLevel.WARNING,
        )
        if dropout > 0:
            bh_logger(
                "It's advised not to use dropout in the decoder",
                LogLevel.WARNING,
            )

        # Output width of every layer: the input width of the next one, channels_last at the end
        widths_out = [*channels_resnet[1:], channels_last]

        # Each layer is a (ResNet block -> upsample) step
        self.layers = nn.ModuleList(
            [
                SpatialLayer(
                    mode="up",
                    channels_in=width_in,
                    channels_out=width_out,
                    activation=activation,
                    dropout=dropout,
                )
                for width_in, width_out in zip(channels_resnet, widths_out)
            ]
        )

        # Pools the last two axes by a factor of two, used to shrink the edges
        self.avgpool = nn.AvgPool3d((1, 2, 2))

    def forward(
        self,
        x: torch.Tensor,
        e: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
            NOTE(review): the axis order is not established by this class, confirm with the caller
        e : torch.Tensor | None
            Edge representative tensor, if provided
            Note: The edges are average pooled once per layer; before every layer the pooled
            version replaces the first slice of x along dimension 1
            Note: The unpooled edges themselves are never injected

        Returns
        -------
        torch.Tensor
            Result of applying all layers in order
        """
        # Without edges the decoder is a plain chain of layers
        if e is None:
            for stage in self.layers:
                x = stage(x)
            return x

        # Edge pyramid: index k holds the edges after k pooling steps
        pyramid = [e]
        for _ in self.layers:
            pyramid.append(self.avgpool(pyramid[-1]))

        # Walk from the most pooled edges upwards; zip stops before the unpooled edges
        for stage, edges in zip(self.layers, pyramid[::-1]):
            # Build a new tensor instead of writing in place (must be the case for backpropagation)
            merged = torch.cat((edges.unsqueeze(1), x[:, 1:]), dim=1)
            x = stage(merged)
        return x


class TemporalLayer(nn.Module):
    """Temporal layer: positional encoding followed by a Transformer encoder over time."""

    def __init__(
        self,
        n_ts: int,
        dim: int,
        n_heads: int = 8,
        n_layers: int = 4,
        activation: str = "relu",
        dropout: float = 0.1,
    ) -> None:
        """
        Initialize the layer.

        Parameters
        ----------
        n_ts : int
            Number of time steps, forwarded to the positional encoder
        dim : int
            Dimensionality of the transformer
            Note: This is the model (hidden) dimension
            Note: The feedforward dimension is 4 times this
        n_heads : int
            Number of heads in the multi-head attention
        n_layers : int
            Number of layers in the transformer
        activation : str
            Activation function to use
        dropout : float
            Dropout rate used by both the positional encoder and the transformer layers
        """
        super().__init__()

        # Positional information over the time axis
        self.pos_enc = PositionalEncoder(n_ts=n_ts, dim=dim, dropout=dropout)

        # Stack of identical batch-first encoder layers
        self.enc = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=dim,
                dim_feedforward=dim * 4,
                nhead=n_heads,
                batch_first=True,
                activation=activation,
                dropout=dropout,
            ),
            num_layers=n_layers,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Every position of the two trailing axes is treated as an independent sequence over time.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
            Shape: (batch_size, embedding, time, dim_3, dim_4), where the two trailing axes are
            folded into the batch dimension

        Returns
        -------
        torch.Tensor
            Output tensor
            Shape: same as the input
        """
        n_batch, n_embed, n_steps, size_a, size_b = x.shape

        # Fold the two trailing axes into the batch: (batch * a * b, time, embedding)
        seq = x.permute(0, 3, 4, 2, 1)
        seq = seq.reshape(n_batch * size_a * size_b, n_steps, n_embed)

        # Add positional information, then run the transformer over the time axis
        seq = self.enc(self.pos_enc(seq))

        # Unfold the batch again and move the axes back to their original order
        seq = seq.reshape(n_batch, size_a, size_b, n_steps, n_embed)
        return seq.permute(0, 4, 3, 1, 2)
