"""
ResNet-style 3D convolutional building blocks.

Adapted from: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
"""

from __future__ import annotations

import torch
from torch import nn

from vito_lot_delineation.models.utils import parse_activation


class ConvBlock(nn.Module):
    """
    Single convolutional block.

    This block runs the following sequence:
     - Convolution (regular or transposed) with channels_in -> channels_out
     - Activation function

    Axes can be switched off: with ``temporal=False`` the first kernel axis, and with
    ``spatial=False`` the last two kernel axes, get kernel size 1, stride 1 and padding 0.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int | None = None,
        activation: str = "relu",
        kernel_size: int | tuple = (3, 3, 3),
        padding: int | tuple = (1, 1, 1),
        stride: int | tuple = (1, 1, 1),
        spatial: bool = True,
        temporal: bool = True,
        transpose: bool = False,
    ) -> None:
        """
        Initialize the block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int | None
            Number of output channels, same as channels_in if None
        activation : str
            Name of the activation function, resolved via ``parse_activation``
        kernel_size : int | tuple
            Kernel size of the convolution; a single int is used for all three axes,
            otherwise a sequence ordered (temporal, spatial, spatial)
        padding : int | tuple
            Padding of the convolution, same format as ``kernel_size``
        stride : int | tuple
            Stride of the convolution, same format as ``kernel_size``
        spatial : bool
            Whether to use spatial convolutions (axes 1 and 2)
        temporal : bool
            Whether to use temporal convolutions (axis 0)
        transpose : bool
            Whether to use transposed convolutions for upsampling
        """
        super().__init__()
        conv_cls = nn.ConvTranspose3d if transpose else nn.Conv3d
        self.conv = conv_cls(
            in_channels=channels_in,
            out_channels=channels_in if channels_out is None else channels_out,
            # Disabled axes become a no-op: kernel 1, stride 1, no padding.
            kernel_size=self._per_axis(kernel_size, spatial=spatial, temporal=temporal, disabled=1),
            stride=self._per_axis(stride, spatial=spatial, temporal=temporal, disabled=1),
            padding=self._per_axis(padding, spatial=spatial, temporal=temporal, disabled=0),
            bias=True,
        )
        self.act = parse_activation(activation)

    @staticmethod
    def _per_axis(
        value: int | tuple,
        *,
        spatial: bool,
        temporal: bool,
        disabled: int,
    ) -> tuple[int, int, int]:
        """
        Build the (temporal, spatial, spatial) triple passed to the convolution.

        An int is broadcast to all three axes; a sequence contributes its first three items.
        Axes switched off by ``temporal``/``spatial`` receive ``disabled`` instead.
        """
        triple = (value,) * 3 if isinstance(value, int) else value
        return (
            triple[0] if temporal else disabled,
            triple[1] if spatial else disabled,
            triple[2] if spatial else disabled,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process."""
        x = self.conv(x)  # Perform the convolution
        x = self.act(x)  # Perform the activation function
        return x


class BasicBlock(nn.Module):
    """
    Basic ResNet block.

    This block processes the input as follows:
        - Processing 1
            - Convolution (regular or transposed) with channels -> channels
            - Activation function
            - Dropout (``nn.Dropout3d``)
        - Processing 2
            - Convolution (regular or transposed) with channels -> channels
        - Add input to output of processing 2
        - Batch normalization
        - Activation function

    Both convolutions use ConvBlock's default kernel 3, stride 1 and padding 1 on the enabled
    axes, which keeps the tensor shape unchanged so the residual sum is well defined.
    """

    def __init__(
        self,
        channels: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = True,
        transpose: bool = False,
        dropout_rate: float = 0.0,
    ) -> None:
        """
        Initialise the basic ResNet block.

        Parameters
        ----------
        channels : int
            Number of channels (input and output)
        activation : str
            Activation function to use
        spatial : bool
            Whether to use spatial convolutions
        temporal : bool
            Whether to use temporal convolutions
        transpose : bool
            Whether or not to use transposed convolutions
        dropout_rate : float
            Dropout rate to use, applied in between the two convolutions
        """
        super().__init__()
        self.conv1 = ConvBlock(
            channels_in=channels,
            activation=activation,
            spatial=spatial,
            temporal=temporal,
            transpose=transpose,
        )
        self.dropout = nn.Dropout3d(dropout_rate)
        self.conv2 = ConvBlock(
            channels_in=channels,
            activation="none",
            spatial=spatial,
            temporal=temporal,
            transpose=transpose,
        )
        self.bn = nn.BatchNorm3d(channels)
        self.act_out = parse_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process."""
        y = self.conv1(x)  # With activation
        y = self.dropout(y)  # Apply dropout on the partially processed data
        y = self.conv2(y)  # Without activation
        # Out-of-place residual sum: never mutates a tensor that is part of the autograd graph.
        y = y + x
        y = self.bn(y)  # Batch normalization
        return self.act_out(y)  # Apply output activation


class BottleneckBlock(nn.Module):
    """
    Residual bottleneck block (UNet-style).

    Pipeline:
        1. conv1: regular convolution, channels -> channels_bn, activation, then dropout
        2. conv2: regular convolution, channels_bn -> channels_bn, activation, then dropout
        3. conv3: transposed convolution, channels_bn -> channels, no activation
        4. Sum with the block input, followed by the activation

    conv1 and conv3 are built with ConvBlock's default kernel/stride/padding and axis flags;
    only conv2 receives the ``spatial``/``temporal`` arguments.

    Note that this block implements the UNet bottleneck and not the ResNet bottleneck. This is
    because we don't want to throw away information, since most of it is useful for the
    reconstruction phase.
    """

    def __init__(
        self,
        channels: int,
        channels_bn: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = True,
        dropout_rate: float = 0.0,
    ) -> None:
        """
        Build the bottleneck layers.

        Parameters
        ----------
        channels : int
            Channel count of the block input and output
        channels_bn : int
            Channel count inside the bottleneck
        activation : str
            Activation used after conv1, conv2 and the residual sum
        spatial : bool
            Spatial-axis flag forwarded to conv2
        temporal : bool
            Temporal-axis flag forwarded to conv2
        dropout_rate : float
            Probability for both Dropout3d layers
        """
        super().__init__()
        # Submodules are created in a fixed order: this determines the order in which their
        # parameters are randomly initialised, so it is kept as-is.
        self.conv1 = ConvBlock(channels_in=channels, channels_out=channels_bn, activation=activation)
        self.dropout1 = nn.Dropout3d(p=dropout_rate)
        self.conv2 = ConvBlock(
            channels_in=channels_bn,
            channels_out=channels_bn,
            activation=activation,
            spatial=spatial,
            temporal=temporal,
        )
        self.dropout2 = nn.Dropout3d(p=dropout_rate)
        self.conv3 = ConvBlock(
            channels_in=channels_bn,
            channels_out=channels,
            activation="none",
            transpose=True,
        )
        self.act = parse_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three stages, add the skip connection and activate."""
        out = x
        for stage in (self.conv1, self.dropout1, self.conv2, self.dropout2, self.conv3):
            out = stage(out)
        return self.act(out + x)


class OutputBlock(nn.Module):
    """
    ResNet output block.

    Collapses the temporal axis and maps the features to the output channels. No dropout is
    applied.

    This block processes the input as follows:
        - Processing 1
            - Convolution (regular) with channels_in -> channels_p, spatial kernel 3x3
              (padding 1) and temporal kernel ``n_ts // 2 + 1`` (no temporal padding)
            - Activation function
        - Processing 2
            - Convolution (regular) with channels_p -> channels_p, spatial kernel 3x3
              (padding 1) and temporal kernel ``ceil(n_ts / 2)`` (no temporal padding)
            - Activation function
        - Processing 3
            - Convolution (regular) with channels_p -> channels_out
            - Note: 1x1x1 kernels are used (with padding 0 and stride of 1)
            - Output activation function

    If the temporal length of the input equals ``n_ts``, processing 1 reduces it to
    ``ceil(n_ts / 2)`` and processing 2 reduces it to 1; the spatial size is preserved.
    """

    def __init__(
        self,
        channels_in: int,
        channels_p: int,
        channels_out: int,
        n_ts: int,
        prc_activation: str = "relu",
        out_activation: str = "sigmoid",
    ) -> None:
        """
        Initialise the Output ResNet block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_p : int
            Number of processing channels
        channels_out : int
            Number of output channels
        n_ts : int
            Number of time stamps; determines the temporal kernel sizes
        prc_activation : str
            Activation function to use for the processing steps
        out_activation : str
            Activation function to use for the final output

        Raises
        ------
        ValueError
            If ``n_ts`` is smaller than 1.
        """
        super().__init__()
        t_kernel_1, t_kernel_2 = self._temporal_kernel_sizes(n_ts)
        # NOTE: Substituted the Resample3D with custom ConvBlock
        self.conv1 = ConvBlock(
            channels_in=channels_in,
            channels_out=channels_p,
            activation=prc_activation,
            kernel_size=(t_kernel_1, 3, 3),
            padding=(0, 1, 1),
            spatial=True,
            temporal=True,
            transpose=False,
        )
        self.conv2 = ConvBlock(
            channels_in=channels_p,
            activation=prc_activation,
            kernel_size=(t_kernel_2, 3, 3),
            padding=(0, 1, 1),
            spatial=True,
            temporal=True,
            transpose=False,
        )
        self.conv3 = ConvBlock(
            channels_in=channels_p,
            channels_out=channels_out,
            activation=out_activation,
            spatial=False,  # Puts kernel etc. to (_,1,1)
            temporal=False,  # Puts kernel etc. to (1,_,_)
            transpose=False,
        )

    @staticmethod
    def _temporal_kernel_sizes(n_ts: int) -> tuple[int, int]:
        """
        Return the temporal kernel sizes of conv1 and conv2 for ``n_ts`` time stamps.

        Without temporal padding, a length-``n_ts`` axis becomes
        ``n_ts - k1 + 1 == ceil(n_ts / 2)`` after conv1 and ``ceil(n_ts / 2) - k2 + 1 == 1``
        after conv2.

        Raises
        ------
        ValueError
            If ``n_ts`` is smaller than 1 (the kernels would be empty or negative).
        """
        if n_ts < 1:
            raise ValueError(f"n_ts must be at least 1, got {n_ts}")
        return n_ts // 2 + 1, (n_ts + 1) // 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x


class Resample3d(nn.Module):
    """
    Strided resampling block wrapping a single ConvBlock.

    The wrapped convolution uses kernel 4, stride 2 and padding 1 on every enabled axis,
    followed by the activation. With ``transpose=False`` it is a regular (downsampling)
    convolution; with ``transpose=True`` it is a transposed (upsampling) convolution.
    Axes disabled through ``spatial``/``temporal`` get kernel 1, stride 1 and padding 0.
    """

    # Convolution geometry shared by every resampling block, ordered (temporal, spatial, spatial).
    _KERNEL_SIZE = (4, 4, 4)
    _PADDING = (1, 1, 1)
    _STRIDE = (2, 2, 2)

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        activation: str = "relu",
        spatial: bool = False,
        temporal: bool = False,
        transpose: bool = False,
    ) -> None:
        """
        Build the resampling convolution.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int
            Number of output channels
        activation : str
            Activation applied after the convolution
        spatial : bool
            Resample the two spatial axes; when False they get kernel etc. of 1
        temporal : bool
            Resample the temporal axis; when False it gets kernel etc. of 1
        transpose : bool
            Use a transposed convolution instead of a regular one
        """
        super().__init__()
        self.f = ConvBlock(
            channels_in=channels_in,
            channels_out=channels_out,
            activation=activation,
            kernel_size=self._KERNEL_SIZE,
            padding=self._PADDING,
            stride=self._STRIDE,
            spatial=spatial,
            temporal=temporal,
            transpose=transpose,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the strided convolution block."""
        resampled = self.f(x)
        return resampled


class Downsample3d(Resample3d):
    """
    Downsampling block: a strided regular convolution followed by an activation.

    Steps:
        - Convolution (regular) with channels_in -> channels_out; which axes shrink depends
          on the ``spatial`` and ``temporal`` flags
        - Activation function

    Note: instead of max pooling, the reduction comes from a 4x4x4 kernel with stride 2x2x2.
    Note: No BatchNorm is used.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = True,
    ) -> None:
        """
        Build a downsampling block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int
            Number of output channels
        activation : str
            Activation applied after the convolution
        spatial : bool
            Downsample the spatial axes
        temporal : bool
            Downsample the temporal axis
        """
        super().__init__(channels_in, channels_out, activation, spatial, temporal, transpose=False)


class Upsample3d(Resample3d):
    """
    Upsampling block: a strided transposed convolution followed by an activation.

    Steps:
        - Convolution (transpose) with channels_in -> channels_out; which axes grow depends
          on the ``spatial`` and ``temporal`` flags
        - Activation function
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = True,
    ) -> None:
        """
        Build an upsampling block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int
            Number of output channels
        activation : str
            Activation applied after the convolution
        spatial : bool
            Upsample the spatial axes
        temporal : bool
            Upsample the temporal axis
        """
        super().__init__(channels_in, channels_out, activation, spatial, temporal, transpose=True)
