"""
ResUNet class.

Adaptation of: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
"""

from __future__ import annotations

from logging import warning

import torch
from torch import nn

from vito_lot_delineation.models.ResUnet3D.model.layers import (
    BasicLayer,
    BottleneckLayer,
)
from vito_lot_delineation.models.ResUnet3D.model.modules import ConvBlock, OutputBlock


class ResUnet3D(nn.Module):
    """
    Custom ResUNet3d module.

    The network consists of an input ``ConvBlock``, a stack of downsampling
    ``BasicLayer`` modules, a ``BottleneckLayer``, a mirrored stack of upsampling
    ``BasicLayer`` modules fed with skip connections, and a final ``OutputBlock``.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        n_ts: int,
        channels: list[int] | None = None,
        depth_spatial: int = 3,
        depth_temporal: int = 3,
        enc_channels_p: int = 32,
        enc_dropout_rate: float = 0.1,
        enc_activation: str = "relu",
        bn_fraction: float = 0.25,
        bn_dropout_rate: float = 0.1,
        bn_activation: str = "relu",
        dec_channels_p: int = 32,
        dec_dropout_rate: float = 0.0,
        dec_activation: str = "relu",
        out_activation: str = "sigmoid",
    ) -> None:
        """
        Initialise the ResUNet3d module.

        Parameters
        ----------
        channels_in : int
            Number of input channels
        channels_out : int
            Number of output channels
        n_ts: int
            Number of time stamps per channel
            Note: If not divisible by 2**(depth_temporal + 1), the input is zero-padded along
            the time dimension (dim 2) up to the next multiple, and a warning is logged
        channels : list[int], optional
            Number of output channels of each down layer in the UNet, defaults to [64, 128, 256]
            Note: Its length must equal max(depth_spatial, depth_temporal) and be at least two
            Note: The last channel is used for the bottleneck, with
            int(bn_fraction * channels[-1]) bottleneck channels
        depth_spatial : int
            Number of layers (counted from the top of the UNet) that operate spatially
        depth_temporal : int
            Number of layers (counted from the top of the UNet) that operate temporally
        enc_channels_p : int
            Number of input processing channels (spatial encoder)
        enc_dropout_rate : float
            Dropout rate (spatial encoder)
        enc_activation : str
            Activation function to use (spatial encoder)
        bn_fraction : float
            Number of channels to use in the bottleneck, expressed as a fraction from channels[-1]
        bn_dropout_rate : float
            Dropout rate (bottleneck)
        bn_activation : str
            Activation function to use (bottleneck)
        dec_channels_p : int
            Number of output processing channels (spatial decoder)
        dec_dropout_rate : float
            Dropout rate (spatial decoder)
        dec_activation : str
            Activation function to use (spatial decoder)
        out_activation : str
            Activation function to use (output mapping)

        Raises
        ------
        AssertionError
            If the channel list and the requested depths are inconsistent
        """
        super().__init__()

        # Resolve the default here instead of in the signature: a list default would be
        # one shared object reused (and possibly mutated) across all instantiations.
        channels = [64, 128, 256] if channels is None else list(channels)

        # Check attributes
        assert len(channels) == max(depth_spatial, depth_temporal), (
            f"Number of channels ({len(channels)}) should be equal to the maximum of "
            f"spatial ({depth_spatial}) and temporal ({depth_temporal}) depth"
        )
        assert (
            len(channels) >= 2  # noqa: PLR2004
        ), "At least two channels must be provided!"
        assert depth_spatial >= 1, "At least one spatial layer must be provided!"
        assert depth_temporal >= 1, "At least one temporal layer must be provided!"

        # Check if the input needs padding along the time dimension
        self.padding = 0
        temporal_pool_factor = 2 ** (depth_temporal + 1)
        if n_ts % temporal_pool_factor != 0:
            warning(
                f" Number of time stamps ({n_ts}) should be divisible by 2**pooling_operations = {temporal_pool_factor}"
            )
            next_multiple = temporal_pool_factor * (n_ts // temporal_pool_factor + 1)
            self.padding = next_multiple - n_ts
            warning(
                f" Padding input with {self.padding} new time stamps to reach {next_multiple} time stamps"
            )

        # Input processing
        self.conv_in = ConvBlock(
            channels_in=channels_in,
            channels_out=enc_channels_p,
            activation=enc_activation,
            spatial=True,
            temporal=False,
            transpose=False,
        )

        # Channels entering down layer i (and, mirrored, leaving up layer i on the encoder
        # side): the input processing width first, then the previous layer's width.
        channels_before = [enc_channels_p, *channels[:-1]]

        # Downsample layers of the UNet
        #  - Downsample
        #  - ResNet block
        self.layers_down = torch.nn.ModuleList()
        for level, (c_in, c_out) in enumerate(zip(channels_before, channels), start=1):
            self.layers_down.append(
                BasicLayer(
                    mode="down",
                    channels_in=c_in,
                    channels_out=c_out,
                    activation=enc_activation,
                    dropout_rate=enc_dropout_rate,
                    # Only the first depth_spatial / depth_temporal levels act on that axis
                    spatial=(level <= depth_spatial),
                    temporal=(level <= depth_temporal),
                )
            )

        # Bottleneck layer
        self.bottleneck = BottleneckLayer(
            channels=channels[-1],
            channels_bn=int(bn_fraction * channels[-1]),
            activation=bn_activation,
            dropout_rate=bn_dropout_rate,
            spatial=(len(channels) == depth_spatial),
            temporal=(len(channels) == depth_temporal),
        )

        # Upsample layers of the UNet (intermediate layers), deepest level first
        self.layers_up = torch.nn.ModuleList()
        for i in reversed(range(len(channels))):
            self.layers_up.append(
                BasicLayer(
                    mode="up",
                    # Input is the running decoder tensor concatenated with the skip connection
                    channels_in=2 * channels[i],
                    channels_out=dec_channels_p if (i == 0) else channels[i - 1],
                    activation=dec_activation,
                    dropout_rate=dec_dropout_rate,
                    spatial=((i + 1) <= depth_spatial),
                    temporal=((i + 1) <= depth_temporal),
                )
            )
        assert len(self.layers_down) == len(
            self.layers_up
        ), f"Network needs as many layers down ({len(self.layers_down)}) as up ({len(self.layers_up)})"

        # Output mapping (processing + reconstruction)
        self.conv_out = OutputBlock(
            channels_in=enc_channels_p + dec_channels_p,
            channels_p=dec_channels_p,
            channels_out=channels_out,
            prc_activation=dec_activation,
            out_activation=out_activation,
            n_ts=n_ts + self.padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward process.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with channels on dim 1 and time stamps on dim 2
            (presumably (batch, channels, time, height, width) -- TODO confirm)

        Returns
        -------
        torch.Tensor
            Output of the output mapping, with dim 2 (time) squeezed out if it has size one
        """
        if self.padding:
            # Allocate exactly `self.padding` zero time stamps. Slicing a zeros_like(x) copy
            # can never yield more than x.shape[2] frames, which under-pads short inputs.
            pad_shape = list(x.shape)
            pad_shape[2] = self.padding
            x = torch.cat([x, x.new_zeros(pad_shape)], dim=2)

        # Input mapping (processing)
        x = self.conv_in(x)

        # Downward pass, keeping every intermediate result for the skip connections
        memory = [x]
        for layer in self.layers_down:
            memory.append(layer(memory[-1]))

        # Bottleneck (the deepest result stays in memory: it is also the first skip connection)
        y = self.bottleneck(memory[-1])

        # Upward pass, consuming the skip connections deepest-first
        for layer in self.layers_up:
            y = layer(torch.cat([y, memory.pop(-1)], dim=1))
        assert len(memory) == 1, "Memory should contain exactly one!"

        # Output mapping (the remaining memory entry is the input processing result)
        y = self.conv_out(torch.cat([y, memory.pop(-1)], dim=1))

        # Remove the time dimension which is not needed anymore
        return y.squeeze(2)
