"""
ResUNet3d class.

Adaptation of: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
"""

from __future__ import annotations

import torch
from torch import nn

from vito_cropsar.models.ResUNet3d.model.layers import BasicLayer, BottleneckLayer
from vito_cropsar.models.ResUNet3d.model.modules import (
    InputBlock,
    SharpeningBlock,
)


class ResUNet3d(nn.Module):
    """Custom ResUNet3d module."""

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        channels: list[int] = [64, 128, 256],
        depth_spatial: int = 3,
        depth_temporal: int = 3,
        enc_channels_p: int = 32,
        enc_dropout: float = 0.1,
        enc_activation: str = "relu",
        bn_fraction: float = 0.25,
        bn_dropout: float = 0.1,
        bn_activation: str = "relu",
        dec_channels_p: int = 32,
        dec_dropout: float = 0.0,
        dec_activation: str = "relu",
    ) -> None:
        """
        Initialise the ResUNet3d module.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int
            Number of output channels
        activation : str
            Activation function to use
        channels : list[int]
            Number of channels per layer in the UNet
            Note: The first channels is used for input and output mapping
            Note: The last channel is used for the bottleneck, with channel//4 as bottleneck
        depth_spatial : int
            Number of spatial convolution layers in the UNet
        depth_temporal : int
            Number of temporal convolution layers in the UNet
        enc_channels_p : int
            Number of input processing channels (spatial encoder)
        enc_dropout : float
            Dropout rate (spatial encoder)
        enc_activation : str
            Activation function to use (spatial encoder)
        bn_fraction : float
            Number of channels to use in the bottleneck, expressed as a fraction from channels[-1]
        bn_dropout : float
            Dropout rate (bottleneck)
        bn_activation : str
            Activation function to use (bottleneck)
        dec_channels_p : int
            Number of output processing channels (spatial decoder)
        dec_dropout : float
            Dropout rate (spatial decoder)
        dec_activation : str
            Activation function to use (spatial decoder)
        """
        super().__init__()

        # Check attributes
        assert len(channels) == max(depth_spatial, depth_temporal), (
            f"Number of channels ({len(channels)}) should be equal to the maximum of "
            f"spatial ({depth_spatial}) and temporal ({depth_temporal}) depth"
        )
        assert (
            len(channels) >= 2  # noqa: PLR2004
        ), "At least two channels must be provided!"
        assert depth_spatial >= 1, "At least one spatial layer must be provided!"
        assert depth_temporal >= 1, "At least one temporal layer must be provided!"

        # Input processing
        self.conv_in = InputBlock(
            channels_in=channels_in,
            channels_p=enc_channels_p,
            activation=enc_activation,
        )

        # Downsample layers of the UNet
        #  - Downsample
        #  - ResNet block
        self.layers_down = torch.nn.ModuleList()
        for i in range(len(channels)):
            self.layers_down.append(
                BasicLayer(
                    mode="down",
                    channels_in=enc_channels_p if (i == 0) else channels[i - 1],
                    channels_out=channels[i],
                    activation=enc_activation,
                    dropout=enc_dropout,
                    spatial=((i + 1) <= depth_spatial),
                    temporal=((i + 1) <= depth_temporal),
                )
            )

        # Bottleneck layer
        self.bottleneck = BottleneckLayer(
            channels=channels[-1],
            channels_bn=int(bn_fraction * channels[-1]),
            activation=bn_activation,
            dropout=bn_dropout,
            spatial=(len(channels) == depth_spatial),
            temporal=(len(channels) == depth_temporal),
        )

        # Upsample layers of the UNet (intermediate layers)
        self.layers_up = torch.nn.ModuleList()
        for i in range(len(channels) - 1, -1, -1):
            self.layers_up.append(
                BasicLayer(
                    mode="up",
                    channels_in=2 * channels[i],
                    channels_out=dec_channels_p if (i == 0) else channels[i - 1],
                    activation=dec_activation,
                    dropout=dec_dropout,
                    spatial=((i + 1) <= depth_spatial),
                    temporal=((i + 1) <= depth_temporal),
                )
            )
        assert len(self.layers_down) == len(
            self.layers_up
        ), f"Network needs as many layers down ({len(self.layers_down)}) as up ({len(self.layers_up)})"

        # Output mapping (processing + reconstruction)
        self.conv_out = SharpeningBlock(
            channels_in=enc_channels_p + dec_channels_p,
            channels_p=dec_channels_p,
            channels_out=channels_out,
            activation=dec_activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process."""
        # Input mapping (processing)
        x = self.conv_in(x)

        # Downward pass
        memory = [x]
        for layer in self.layers_down:
            memory.append(layer(memory[-1]))

        # Bottleneck
        y = self.bottleneck(memory[-1])

        # Upward pass
        for layer in self.layers_up:
            y = layer(torch.cat([y, memory.pop(-1)], dim=1))
        assert len(memory) == 1, "Memory should contain exactly one!"

        # Output mapping
        y = self.conv_out(torch.cat([y, memory.pop(-1)], dim=1))
        return y
