"""
ResNet modules.

Source: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
"""

from __future__ import annotations

import torch
from torch import nn

from vito_cropsar.models.shared import ConvBlock, parse_activation


class BottleneckBlock(nn.Module):
    """
    Residual bottleneck block.

    The input is processed as follows:
        - Stage 1
            - Convolution (regular) with channels -> channels_bn
            - Activation function
            - Channel-wise 3D dropout (``nn.Dropout3d``)
        - Stage 2
            - Convolution (regular) with channels_bn -> channels_bn, honouring the
              ``spatial`` / ``temporal`` flags
            - Activation function
            - Channel-wise 3D dropout (``nn.Dropout3d``)
        - Stage 3
            - Convolution (transpose) with channels_bn -> channels
        - Add the block input to the output of stage 3 (residual connection)
        - Activation function

    Note that this block implements the UNet bottleneck and not the ResNet bottleneck. This is
    because we don't want to throw away information, since most of it is useful for the
    reconstruction phase.
    """

    def __init__(
        self,
        channels: int,
        channels_bn: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = True,
        dropout: float = 0.0,
    ) -> None:
        """
        Build the bottleneck block.

        Parameters
        ----------
        channels : int
            Number of channels of the block input (and of its output)
        channels_bn : int
            Number of channels inside the bottleneck
        activation : str
            Name of the activation function, resolved via ``parse_activation``
        spatial : bool
            Whether the middle convolution uses spatial kernels
        temporal : bool
            Whether the middle convolution uses temporal kernels
        dropout : float
            Dropout probability applied after the first two activations
        """
        super().__init__()
        # Sub-modules are registered in a fixed order under fixed attribute names:
        # this keeps parameter ordering and state_dict keys stable across versions.
        self.conv1 = ConvBlock(channels_in=channels, channels_out=channels_bn)
        self.act1 = parse_activation(activation)
        self.dropout1 = nn.Dropout3d(p=dropout)
        self.conv2 = ConvBlock(
            channels_in=channels_bn,
            channels_out=channels_bn,
            spatial=spatial,
            temporal=temporal,
        )
        self.act2 = parse_activation(activation)
        self.dropout2 = nn.Dropout3d(p=dropout)
        self.conv3 = ConvBlock(
            channels_in=channels_bn,
            channels_out=channels,
            transpose=True,
        )
        self.act3 = parse_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the three bottleneck stages, add the residual, then activate.

        The residual sum requires the output of ``conv3`` to be addable to ``x``
        (in practice: the same shape as the block input).
        """
        out = x
        for stage in (
            self.conv1,
            self.act1,
            self.dropout1,
            self.conv2,
            self.act2,
            self.dropout2,
            self.conv3,
        ):
            out = stage(out)
        return self.act3(out + x)


class InputBlock(nn.Module):
    """
    Input processing block.

    A single spatial-only convolution followed by an activation function. There is no
    residual connection and no normalization layer.

    This block processes the input as follows:
        - Convolution (regular, space only) with channels_in -> channels_p
        - Activation function
    """

    def __init__(
        self,
        channels_in: int,
        channels_p: int,
        activation: str = "relu",
    ) -> None:
        """
        Initialize the input block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_p : int
            Number of channels to use in the processing (the block's output channels)
        activation : str
            Activation function to use
        """
        super().__init__()
        # temporal=False keeps the convolution from mixing information along the time axis.
        self.conv1 = ConvBlock(
            channels_in=channels_in,
            channels_out=channels_p,
            spatial=True,
            temporal=False,
            transpose=False,
        )
        self.act1 = parse_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process: convolution followed by activation."""
        x = self.conv1(x)
        x = self.act1(x)
        return x


class SharpeningBlock(nn.Module):
    """
    Image sharpening block.

    The first two convolutions operate only in the spatial dimensions. The final
    convolution uses 1x1x1 kernels, so it only mixes channels. No dropout is applied.

    This block processes the input as follows:
        - Processing 1
            - Convolution (regular, space only) with channels_in -> channels_p
            - Activation function
        - Processing 2
            - Convolution (regular, space only) with channels_p -> channels_p
            - Activation function
        - Processing 3
            - Convolution (regular) with channels_p -> channels_out
            - Note: 1x1x1 kernels are used (with padding 0 and stride of 1)
            - No activation is applied to the output
    """

    def __init__(
        self,
        channels_in: int,
        channels_p: int,
        channels_out: int,
        activation: str = "relu",
    ) -> None:
        """
        Initialise the sharpening block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_p : int
            Number of processing channels
        channels_out : int
            Number of output channels
        activation : str
            Activation function to use for the processing steps
        """
        super().__init__()
        self.conv1 = ConvBlock(
            channels_in=channels_in,
            channels_out=channels_p,
            spatial=True,
            temporal=False,
            transpose=False,
        )
        self.act1 = parse_activation(activation)
        # channels_out must be explicit: conv3 below consumes exactly channels_p channels.
        self.conv2 = ConvBlock(
            channels_in=channels_p,
            channels_out=channels_p,
            spatial=True,
            temporal=False,
            transpose=False,
        )
        self.act2 = parse_activation(activation)
        self.conv3 = ConvBlock(
            channels_in=channels_p,
            channels_out=channels_out,
            spatial=False,  # Puts kernel etc. to (_,1,1)
            temporal=False,  # Puts kernel etc. to (1,_,_)
            transpose=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process: two activated spatial convolutions, then a 1x1x1 projection."""
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.conv3(x)
        return x


class Resample3d(nn.Module):
    """
    Simple resampling block.

    This block processes the input as follows:
        - Convolution with channels_in -> channels_out, using kernel size 4x4x4,
          padding 1x1x1 and stride 2x2x2 (regular, or transposed when ``transpose``)
        - Activation function

    With these settings a regular convolution presumably halves, and a transposed
    convolution doubles, each enabled axis. This depends on how ``ConvBlock`` applies
    the ``spatial`` / ``temporal`` flags — TODO confirm against ConvBlock.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        activation: str = "relu",
        spatial: bool = False,  # Puts kernel etc. to (_,1,1)
        temporal: bool = False,  # Puts kernel etc. to (1,_,_)
        transpose: bool = False,
    ) -> None:
        """
        Initialize the block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int
            Number of output channels
        activation : str
            Activation function applied after the convolution
        spatial : bool
            Whether to resample along the spatial axes
        temporal : bool
            Whether to resample along the temporal axis
        transpose : bool
            Use a transposed convolution instead of a regular one
        """
        super().__init__()
        geometry = {
            "kernel_size": (4, 4, 4),
            "padding": (1, 1, 1),
            "stride": (2, 2, 2),
        }
        self.f = ConvBlock(
            channels_in=channels_in,
            channels_out=channels_out,
            spatial=spatial,
            temporal=temporal,
            transpose=transpose,
            **geometry,
        )
        self.act = parse_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resample ``x`` with the convolution, then apply the activation."""
        return self.act(self.f(x))


class Downsample3d(Resample3d):
    """
    Downsampling block.

    A ``Resample3d`` with a regular (non-transposed) convolution:
        - Convolution (regular) with channels_in -> channels_out
            - Note: which input axes shrink depends on whether spatial and/or temporal
              downsampling is enabled
        - Activation function

    Adaptation of: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L240
    Note: Rather than MaxPool2d we use a kernel size of 4x4x4 and stride of 2x2x2.
    Note: No BatchNorm is used.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = True,
    ) -> None:
        """Initialize the block with a regular (non-transposed) convolution."""
        super().__init__(
            transpose=False,
            spatial=spatial,
            temporal=temporal,
            activation=activation,
            channels_in=channels_in,
            channels_out=channels_out,
        )


class Upsample3d(Resample3d):
    """
    Simple upsample block.

    A ``Resample3d`` with a transposed convolution:
        - Convolution (transpose) with channels_in -> channels_out
            - Note: which input axes grow depends on whether spatial and/or temporal
              upsampling is enabled
        - Activation function
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = True,
    ) -> None:
        """Initialize the block with a transposed convolution."""
        super().__init__(
            transpose=True,
            spatial=spatial,
            temporal=temporal,
            activation=activation,
            channels_in=channels_in,
            channels_out=channels_out,
        )
