"""Shared PyTorch modules."""

from __future__ import annotations

from typing import Callable

import torch
from torch import nn


class ConvBlock(nn.Module):
    """
    Single 3D convolutional block.

    This block wraps one convolution (regular or transposed) mapping
    channels_in -> channels_out. No activation is applied inside the block;
    callers are expected to apply their own afterwards.

    The first entry of each of ``kernel_size``, ``padding`` and ``stride`` is
    gated by ``temporal``; the second and third entries are gated by
    ``spatial``.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int | None = None,
        kernel_size: tuple = (3, 3, 3),
        padding: tuple = (1, 1, 1),
        stride: tuple = (1, 1, 1),
        spatial: bool = True,
        temporal: bool = True,
        transpose: bool = False,
        bias: bool = True,
    ) -> None:
        """
        Initialize the block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int, optional
            Number of output channels, same as channels_in if None
        kernel_size : tuple
            Kernel size of the convolution, one entry per convolved dimension
        padding : tuple
            Padding of the convolution, one entry per convolved dimension
        stride : tuple
            Stride of the convolution, one entry per convolved dimension
        spatial : bool
            Whether to use spatial convolutions; if False, the last two
            dimensions get kernel size 1, stride 1 and padding 0
        temporal : bool
            Whether to use temporal convolutions; if False, the first
            dimension gets kernel size 1, stride 1 and padding 0
        transpose : bool
            Whether to use transposed convolutions for upsampling
        bias : bool
            Whether to use bias in the convolution
        """
        super().__init__()
        conv_cls = nn.ConvTranspose3d if transpose else nn.Conv3d
        # A disabled axis degenerates to a pointwise (size-1, unpadded,
        # unstrided) convolution along that axis.
        self.conv = conv_cls(
            in_channels=channels_in,
            # Falsy channels_out (None) falls back to channels_in.
            out_channels=channels_out or channels_in,
            kernel_size=(
                kernel_size[0] if temporal else 1,
                kernel_size[1] if spatial else 1,
                kernel_size[2] if spatial else 1,
            ),
            stride=(
                stride[0] if temporal else 1,
                stride[1] if spatial else 1,
                stride[2] if spatial else 1,
            ),
            padding=(
                padding[0] if temporal else 0,
                padding[1] if spatial else 0,
                padding[2] if spatial else 0,
            ),
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x`` and return the result."""
        return self.conv(x)


class DynamicLayerNorm(nn.Module):
    """
    Dynamic Layer Normalization.

    Normalizes a 5-D input over dimensions 2, 3 and 4 (independently for every
    index along dimensions 0 and 1), then applies a learnable per-feature
    scale (``gamma``) and shift (``beta``) along dimension 1.
    """

    def __init__(self, n_features: int, eps: float = 1e-5) -> None:
        """
        Initialize the layer.

        Parameters
        ----------
        n_features : int
            Size of dimension 1 of the input; one scale and one shift are
            learned per feature
        eps : float
            Constant added to the standard deviation to avoid division by zero
        """
        super().__init__()
        self.num_features = n_features
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(n_features))
        self.beta = nn.Parameter(torch.zeros(n_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` and apply the learned scale and shift.

        Parameters
        ----------
        x : torch.Tensor
            5-D tensor whose dimension 1 has size ``n_features``
            (presumably batch, channels, time, height, width - confirm
            against callers)

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as ``x``
        """
        reduce_dims = (2, 3, 4)
        mean = x.mean(reduce_dims, keepdim=True)
        n_reduced = x.shape[2] * x.shape[3] * x.shape[4]
        if n_reduced > 1:
            # Unbiased (Bessel-corrected) standard deviation, torch's default.
            std_dev = x.std(dim=reduce_dims, keepdim=True)
        else:
            # The unbiased estimator is NaN for a single element; a lone value
            # has no spread, so use zero and let eps guard the division.
            std_dev = torch.zeros_like(mean)
        normalized_x = (x - mean) / (std_dev + self.eps)
        # Reshape the per-feature parameters so they broadcast along dim 1.
        scale = self.gamma.view(1, -1, 1, 1, 1)
        shift = self.beta.view(1, -1, 1, 1, 1)
        return scale * normalized_x + shift


class BasicBlock(nn.Module):
    """
    Basic block, inspired by the basic ResNet block.

    Adaptation of: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L59

    The block runs the following sequence:
     - Convolution, activation, dropout
     - Convolution
     - Residual addition of the block input
     - Normalization, activation

    Both convolutions keep the number of channels unchanged.
    """

    def __init__(
        self,
        channels: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = True,
        transpose: bool = False,
        dropout: float = 0.0,
    ) -> None:
        """
        Initialise the basic ResNet block.

        Parameters
        ----------
        channels : int
            Number of channels (input and output)
        activation : str
            Name of the activation function to use, as understood by
            ``parse_activation``
        spatial : bool
            Whether to use spatial convolutions
        temporal : bool
            Whether to use temporal convolutions
        transpose : bool
            Whether or not to use transposed convolutions
        dropout : float
            Dropout probability
        """
        super().__init__()
        self.conv1 = ConvBlock(
            channels_in=channels,
            spatial=spatial,
            temporal=temporal,
            transpose=transpose,
        )
        self.act1 = parse_activation(activation)
        self.dropout = nn.Dropout3d(dropout)
        self.conv2 = ConvBlock(
            channels_in=channels,
            spatial=spatial,
            temporal=temporal,
            transpose=transpose,
        )
        self.ln2 = DynamicLayerNorm(channels)
        # Separate instance from act1 so each activation is its own submodule.
        self.act2 = parse_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward process.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor; the residual addition requires the output of the
            second convolution to have the same shape as ``x``

        Returns
        -------
        torch.Tensor
            Output of the final activation
        """
        y = self.conv1(x)
        y = self.act1(y)
        y = self.dropout(y)
        y = self.conv2(y)
        # In-place add mutates only y (the conv2 output); x is left untouched.
        y += x  # Residual connection
        y = self.ln2(y)  # Normalization layer
        y = self.act2(y)
        return y


def parse_activation(
    activation: str,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Build the activation module named by ``activation``.

    Parameters
    ----------
    activation : str
        One of "relu", "leaky_relu", "elu", "gelu", "selu", "silu", "tanh",
        "sigmoid" or "none" ("none" yields an identity module)

    Returns
    -------
    Callable[[torch.Tensor], torch.Tensor]
        A freshly constructed module with default settings

    Raises
    ------
    ValueError
        If ``activation`` is not one of the supported names
    """
    # Name -> module class, checked in order; a new instance is built per call.
    known = (
        ("relu", nn.ReLU),
        ("leaky_relu", nn.LeakyReLU),
        ("elu", nn.ELU),
        ("gelu", nn.GELU),
        ("selu", nn.SELU),
        ("silu", nn.SiLU),
        ("tanh", nn.Tanh),
        ("sigmoid", nn.Sigmoid),
        ("none", nn.Identity),
    )
    for key, factory in known:
        if activation == key:
            return factory()
    raise ValueError(f"Unknown activation function: {activation}")
