"""Modules used in the CnnTransformer model."""

from __future__ import annotations

from math import log

import torch
from torch import nn

from vito_cropsar.models.shared import ConvBlock, parse_activation


class InputBlock(nn.Module):
    """
    Input processing block.

    A single convolution followed by an activation, without any residual
    (skip) connection.

    This block processes the input as follows:
        - Convolution (regular, spatial=True, temporal=False) with channels_in -> channels_p
        - Activation function
    """

    def __init__(
        self,
        channels_in: int,
        channels_p: int,
        activation: str = "relu",
    ) -> None:
        """
        Initialize the input block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_p : int
            Number of channels to use in the processing (output channels of this block)
        activation : str
            Name of the activation function, resolved through ``parse_activation``
        """
        super().__init__()
        self.conv1 = ConvBlock(
            channels_in=channels_in,
            channels_out=channels_p,
            spatial=True,
            temporal=False,
            transpose=False,
        )
        self.act1 = parse_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward process.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with ``channels_in`` channels (layout as expected by ConvBlock)

        Returns
        -------
        torch.Tensor
            Activated convolution output with ``channels_p`` channels
        """
        return self.act1(self.conv1(x))


class SharpeningBlock(nn.Module):
    """
    Image sharpening block.

    Two convolutions (spatial=True, temporal=True), each followed by an
    activation, and a final convolution mapping channels_p -> channels_out.
    When an edge tensor ``e`` is given to ``forward``, it overwrites the first
    channel of the features right before each of the three convolutions.
    """

    def __init__(
        self,
        channels_p: int,
        channels_out: int,
        activation: str = "relu",
    ) -> None:
        """
        Initialise the sharpening block.

        Parameters
        ----------
        channels_p : int
            Number of processing channels
        channels_out : int
            Number of output channels
        activation : str
            Activation function to use for the processing steps
        """
        super().__init__()
        # channels_out is left to ConvBlock's default for conv1/conv2; conv2 (and
        # the edge injection) need conv1 to emit channels_p channels.
        # NOTE(review): presumably ConvBlock defaults channels_out to channels_in -- confirm.
        self.conv1 = ConvBlock(
            channels_in=channels_p,
            spatial=True,
            temporal=True,
            transpose=False,
        )
        self.act1 = parse_activation(activation)
        self.conv2 = ConvBlock(
            channels_in=channels_p,
            spatial=True,
            temporal=True,
            transpose=False,
        )
        self.act2 = parse_activation(activation)
        self.conv3 = ConvBlock(
            channels_in=channels_p,
            channels_out=channels_out,
            spatial=False,  # Puts kernel etc. to (_,1,1)
            temporal=False,  # Puts kernel etc. to (1,_,_)
            transpose=False,
        )

    @staticmethod
    def _inject_edges(
        x: torch.Tensor,
        e: torch.Tensor | None,
    ) -> torch.Tensor:
        """Replace channel 0 of ``x`` by ``e`` (``e`` has no channel axis); no-op if ``e`` is None."""
        if e is None:
            return x
        return torch.cat([e[:, None], x[:, 1:, :, :]], dim=1)

    def forward(
        self,
        x: torch.Tensor,
        e: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Forward process.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with ``channels_p`` channels along axis 1
        e : torch.Tensor | None
            Edge representative tensor, if provided. It lacks the channel axis
            (one is inserted with ``e[:, None]``) and replaces channel 0 of the
            features before every convolution.

        Returns
        -------
        torch.Tensor
            Output tensor, predicted S2 image
        """
        # First processing
        x = self.act1(self.conv1(self._inject_edges(x, e)))

        # Second processing
        x = self.act2(self.conv2(self._inject_edges(x, e)))

        # Final processing (no activation on the output)
        return self.conv3(self._inject_edges(x, e))


class Resample3d(nn.Module):
    """
    Simple resampling block.

    This block processes the input as follows:
        - Convolution with channels_in -> channels_out, kernel size 4, stride 2
          and padding 1; a regular convolution by default, a transposed one
          when ``transpose=True``
        - Activation function

    For a standard (transposed) 3D convolution these settings halve (double)
    the length of every resampled axis of even length. Which axes are actually
    resampled is governed by ``spatial`` / ``temporal`` inside ConvBlock.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        activation: str = "relu",
        spatial: bool = False,  # Puts kernel etc. to (_,1,1)
        temporal: bool = False,  # Puts kernel etc. to (1,_,_)
        transpose: bool = False,
    ) -> None:
        """
        Initialize the block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int
            Number of output channels
        activation : str
            Name of the activation function, resolved through ``parse_activation``
        spatial : bool
            Forwarded to ConvBlock; when False the kernel etc. is put to (_, 1, 1)
        temporal : bool
            Forwarded to ConvBlock; when False the kernel etc. is put to (1, _, _)
        transpose : bool
            Use a transposed convolution (upsampling) instead of a regular one
        """
        super().__init__()
        self.conv1 = ConvBlock(
            channels_in=channels_in,
            channels_out=channels_out,
            kernel_size=(4, 4, 4),
            padding=(1, 1, 1),
            stride=(2, 2, 2),
            spatial=spatial,
            temporal=temporal,
            transpose=transpose,
        )
        self.act1 = parse_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resample ``x`` with the (transposed) convolution, then apply the activation."""
        return self.act1(self.conv1(x))


class Downsample3d(Resample3d):
    """
    Downsampling block.

    This block processes the input as follows:
        - Convolution (regular) with channels_in -> channels_out
            - Kernel size 4 enlarges the context along each resampled axis
            - Stride 2 downsamples those axes by a factor of 2
            - Padding 1 makes the output exactly half the input length (for even lengths)
            - With the defaults (spatial=True, temporal=False) only spatial
              downsampling is performed, i.e. an effective 1x4x4 kernel, 1x2x2
              stride and 0x1x1 padding (per ConvBlock's ``temporal=False`` convention)
        - Activation function

    Adaptation of: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L240
    Note: Rather than MaxPool2d we use a strided convolution (kernel size 4, stride 2).
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = False,
    ) -> None:
        """
        Initialize the block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int
            Number of output channels
        activation : str
            Name of the activation function, resolved through ``parse_activation``
        spatial : bool
            Whether the spatial axes are downsampled (forwarded to ConvBlock)
        temporal : bool
            Whether the temporal axis is downsampled (forwarded to ConvBlock)
        """
        super().__init__(
            channels_in=channels_in,
            channels_out=channels_out,
            activation=activation,
            spatial=spatial,
            temporal=temporal,
            transpose=False,
        )


class Upsample3d(Resample3d):
    """
    Simple upsampling block.

    This block processes the input as follows:
        - Convolution (transpose) with channels_in -> channels_out
            - Kernel size 4 is used to enlarge the context along each resampled axis
            - Stride 2 is used to upsample those axes by a factor of 2
            - Padding 1 makes the output exactly twice the input length
            - With the defaults (spatial=True, temporal=False) only spatial
              upsampling is performed, i.e. an effective 1x4x4 kernel, 1x2x2
              stride and 0x1x1 padding (per ConvBlock's ``temporal=False`` convention)
        - Activation function
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = False,
    ) -> None:
        """
        Initialize the block.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int
            Number of output channels
        activation : str
            Name of the activation function, resolved through ``parse_activation``
        spatial : bool
            Whether the spatial axes are upsampled (forwarded to ConvBlock)
        temporal : bool
            Whether the temporal axis is upsampled (forwarded to ConvBlock)
        """
        super().__init__(
            channels_in=channels_in,
            channels_out=channels_out,
            activation=activation,
            spatial=spatial,
            temporal=temporal,
            transpose=True,
        )


class PositionalEncoder(nn.Module):
    """
    Sinusoidal positional encoder for batch-first sequences.

    Adds the fixed sine/cosine positional encoding (as in the PyTorch
    transformer tutorial) to an input of shape ``(batch, seq_len, dim)`` and
    then applies dropout.
    """

    def __init__(
        self,
        n_ts: int,
        dim: int,
        dropout: float = 0.0,
    ) -> None:
        """
        Initialise the positional encoder.

        Parameters
        ----------
        n_ts : int
            Maximum length of the input sequences (number of encoded positions)
        dim : int
            The dimension of the output of sub-layers in the model (odd values allowed)
        dropout : float
            Dropout rate to use, applied after the positional encoding
        """
        super().__init__()

        # Set attributes
        self.dim = dim
        self.dropout = nn.Dropout(p=dropout)

        # Adapted from the PyTorch tutorial
        position = torch.arange(n_ts).unsqueeze(1)  # (n_ts, 1): positions 0..n_ts-1
        div_term = torch.exp(torch.arange(0, dim, 2) * (-log(10000.0) / dim))  # (ceil(dim/2),)
        pe = torch.zeros(n_ts, 1, dim)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # An odd `dim` has one fewer odd index than frequencies; drop the last frequency there.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: dim // 2])
        # (n_ts, 1, dim) -> (1, n_ts, dim) so the buffer broadcasts over a batch-first input
        pe = pe.transpose(0, 1)
        self.register_buffer("pe", pe)  # Store non-trainable tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the positional encoding.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(batch, seq_len, dim)`` with ``seq_len <= n_ts``

        Returns
        -------
        torch.Tensor
            ``x`` plus the encoding of its first ``seq_len`` positions, after dropout
        """
        # Slice along the sequence axis (axis 1); axis 0 of the buffer has size 1.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
