"""Complete network layers."""

from __future__ import annotations

import torch
from torch import nn

from vito_cropsar.models.ResUNet3d.model.modules import (
    BottleneckBlock,
    Downsample3d,
    Resample3d,
    Upsample3d,
)
from vito_cropsar.models.shared import BasicBlock


class BasicLayer(nn.Module):
    """Basic ResUNet3d layer: optional (re)sampling combined with a processing block."""

    def __init__(
        self,
        mode: str,
        channels_in: int,
        channels_out: int | None = None,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = True,
        dropout: float = 0.0,
    ) -> None:
        """
        Initialize the basic layer.

        A downward layer ("down") flows as follows:
            - Downsample
            - Processing layer
        An upward layer ("up") flows as follows:
            - Processing layer
            - Upsample
        A resampling layer ("none") flows as follows:
            - Resample (only when channels_in != channels_out)
            - Processing layer

        Notes
        -----
        - The processing layer is a ResNet block (``BasicBlock``)
        - The number of channels used in the processing layer is equal to
            - channels_out if mode is "down" (downsample first, then process)
            - channels_in if mode is "up" (process first, then upsample)
            - channels_out if mode is "none" (resample first, then process)
        - The sampling submodules (``down``, ``up``, ``resample``) are only
          created for the mode that needs them; ``forward`` detects their
          presence with ``hasattr``.

        Parameters
        ----------
        mode : str
            Layer mode, either "up", "down", or "none"
        channels_in : int
            Number of input channels
        channels_out : int | None
            Number of output channels of the layer (same as channels_in if None)
        activation : str
            Activation function to use in the processing block
        spatial : bool
            Whether to use spatial convolutions in the processing block
        temporal : bool
            Whether to use temporal convolutions in the processing block
        dropout : float
            Dropout rate to apply in the processing block

        Raises
        ------
        AssertionError
            If mode is not one of "up", "down", "none", or if both spatial
            and temporal are False.
        """
        super().__init__()
        # Explicit None check: only an omitted value falls back to channels_in.
        if channels_out is None:
            channels_out = channels_in
        assert mode in (
            "up",
            "down",
            "none",
        ), f"Mode should be either 'up', 'down', or 'none'. Got {mode}"
        assert any(
            [spatial, temporal]
        ), "At least one of spatial or temporal should be True"

        # Create the sampling module required by the mode (modes are exclusive)
        # NOTE(review): unlike BottleneckLayer, `spatial`/`temporal` are not
        # forwarded to the sampling modules here, so they use their own
        # defaults — confirm this is intentional.
        self.mode = mode
        if mode == "down":
            self.down = Downsample3d(channels_in=channels_in, channels_out=channels_out)
        elif mode == "up":
            self.up = Upsample3d(channels_in=channels_in, channels_out=channels_out)
        elif channels_in != channels_out:
            # mode == "none": a Resample3d is only needed when channels change
            self.resample = Resample3d(
                channels_in=channels_in, channels_out=channels_out
            )

        # Create the processing block; on the upward path it runs before the
        # upsampling, hence on channels_in and as a transposed block
        self.process = BasicBlock(
            channels=channels_in if mode == "up" else channels_out,
            activation=activation,
            dropout=dropout,  # Not recommended to do when "up"
            spatial=spatial,
            temporal=temporal,
            transpose=(mode == "up"),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor (expected to carry `channels_in` channels)

        Returns
        -------
        torch.Tensor
            Output of the layer (expected to carry `channels_out` channels)
        """
        # Pre-processing: downsampling ("down") or channel resampling ("none")
        if hasattr(self, "down"):
            x = self.down(x)
        if hasattr(self, "resample"):
            x = self.resample(x)

        # Processing
        x = self.process(x)

        # Post-processing: upsampling ("up")
        if hasattr(self, "up"):
            x = self.up(x)

        return x


class BottleneckLayer(nn.Module):
    """Bottleneck ResUNet3d layer: downsample, bottleneck block, upsample."""

    def __init__(
        self,
        channels: int,
        channels_bn: int,
        activation: str = "relu",
        spatial: bool = True,
        temporal: bool = True,
        dropout: float = 0.0,
    ) -> None:
        """
        Initialize the bottleneck layer.

        The layer flows as follows:
            - Downsample (channels -> channels)
            - Bottleneck processing block (uses channels_bn internally)
            - Upsample (channels -> channels)

        Parameters
        ----------
        channels : int
            Number of channels in the layer (right before bottleneck)
        channels_bn : int
            Number of channels in the bottleneck layer
        activation : str
            Activation function to use
        spatial : bool
            Whether to use spatial convolutions in the sampling modules
        temporal : bool
            Whether to use temporal convolutions in the sampling modules
        dropout : float
            Dropout rate of the bottleneck block

        Raises
        ------
        AssertionError
            If both spatial and temporal are False, or if channels_bn is not
            strictly smaller than channels.
        """
        super().__init__()
        assert any(
            [spatial, temporal]
        ), "At least one of spatial or temporal should be True"
        assert (
            channels_bn < channels
        ), "Bottleneck channels should be smaller than channels"

        self.down = Downsample3d(
            channels_in=channels,
            channels_out=channels,
            spatial=spatial,
            temporal=temporal,
        )
        # NOTE(review): `spatial`/`temporal` are not passed to BottleneckBlock;
        # confirm whether it supports/needs them.
        self.process_d = BottleneckBlock(
            channels=channels,
            channels_bn=channels_bn,
            activation=activation,
            dropout=dropout,
        )
        self.up = Upsample3d(
            channels_in=channels,
            channels_out=channels,
            spatial=spatial,
            temporal=temporal,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor (expected to carry `channels` channels)

        Returns
        -------
        torch.Tensor
            Output of downsample -> bottleneck block -> upsample
        """
        x = self.down(x)
        x = self.process_d(x)
        x = self.up(x)
        return x
