"""ResNet variant of the UNet3d inpainting model."""

from __future__ import annotations

from pathlib import Path
from typing import Iterator

import torch
from torch import nn

from vito_cropsar.models.base_model import InpaintingBase
from vito_cropsar.models.ResUNet3d.config import ModelConfig
from vito_cropsar.models.ResUNet3d.model import ResUNet3d
from vito_cropsar.models.shared import interpolate_s1_batch


class InpaintingModel(InpaintingBase):
    """ResNet variant of the UNet3d inpainting model."""

    def __init__(self, cfg: ModelConfig) -> None:
        """
        Initialize the model from its configuration.

        Parameters
        ----------
        cfg : ModelConfig
            Configuration specific to the ResUNet3d model
        """
        super().__init__(cfg=cfg)
        # NOTE(review): the hyperparameters read from `self` below are not defined in
        # this class; presumably the base class exposes them from `cfg` - confirm.
        self._model = ResUNet3d(
            channels_in=self.n_channels_in,
            channels_out=self.n_channels_out,
            channels=self.channels,
            depth_spatial=self.depth_spatial,
            depth_temporal=self.depth_temporal,
            enc_channels_p=self.enc_channels_p,
            enc_dropout=self.enc_dropout,
            bn_fraction=self.bn_fraction,
            bn_dropout=self.bn_dropout,
            dec_channels_p=self.dec_channels_p,
            dec_dropout=self.dec_dropout,
        )
        self._model.to(self.device)

    def forward(
        self,
        s1: torch.Tensor,
        s2: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward process of the model.

        The input tensors are not modified; NaN filling happens on a copy.

        Parameters
        ----------
        s1 : torch.Tensor
            Sentinel-1 data
            Shape: (batch, time, channels_s1, width, height)
        s2 : torch.Tensor
            Sentinel-2 data (partially obscured with NaN)
            Shape: (batch, time, channels_s2, width, height)
            Note: The NaNs are still present for S2-data!
        mask : torch.Tensor
            Mask indicating which pixels are visible (0 if obscured, 1 if visible)
            Shape: (batch, time, width, height)

        Returns
        -------
        torch.Tensor
            Inpainted Sentinel-2 data (NaN free)
            Shape: (batch, time, channels_s2, width, height)

        Raises
        ------
        AssertionError
            If the mask contains NaNs, or if S1 and S2 differ in any dimension
            other than the channel dimension.
        """
        assert not torch.isnan(mask).any(), "Assuming no NaN in mask!"
        assert (
            s1[:, :, 0].shape == s2[:, :, 0].shape
        ), f"Assuming same shape for S1 ({s1.shape}) and S2 ({s2.shape})!"

        # Interpolate NaNs in S1
        s1 = interpolate_s1_batch(s1)

        # Ensure S2 has no NaNs; fill a copy so the caller's tensor keeps its NaNs
        s2 = s2.clone()
        s2[s2.isnan()] = self.fill_nan

        # Stack S2 and S1 along the channel dimension
        inputs = torch.cat([s2, s1], dim=2)

        # The network receives dims 1 and 2 swapped (channels before time);
        # swap them back afterwards to restore the documented output layout
        logits = self._model(inputs.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4)
        return logits

    def train(self) -> None:
        """
        Put the model in training mode.

        Note: this also enables torch's gradient tracking through
        ``torch.set_grad_enabled``, which is not scoped to this model.
        """
        torch.set_grad_enabled(True)
        self._model.train()

    def eval(self) -> None:
        """
        Put the model in evaluation mode.

        Note: this also disables torch's gradient tracking through
        ``torch.set_grad_enabled``, which is not scoped to this model.
        """
        torch.set_grad_enabled(False)
        self._model.eval()

    def parameters(self) -> Iterator[nn.Parameter]:
        """Iterate over all parameters of the underlying network."""
        return self._model.parameters()

    def get_parameter_count(self) -> tuple[int, int]:
        """
        Count the trainable and untrainable parameters.

        Returns
        -------
        tuple[int, int]
            Number of trainable and number of untrainable parameter elements
        """
        trainable, untrainable = 0, 0
        for p in self._model.parameters():
            if p.requires_grad:
                trainable += p.numel()
            else:
                untrainable += p.numel()
        return trainable, untrainable

    @classmethod
    def load(cls, mdl_f: Path) -> InpaintingModel:
        """
        Load a saved model and restore its network weights.

        Parameters
        ----------
        mdl_f : Path
            Model folder, containing the "model_weights" file

        Returns
        -------
        InpaintingModel
            The loaded model, put in evaluation mode
        """
        mdl = super().load(mdl_f=mdl_f)
        # NOTE: torch.load unpickles the file; only load weights from a trusted source
        mdl._model.load_state_dict(
            torch.load(mdl_f / "model_weights", map_location=mdl.device)
        )
        mdl.eval()
        return mdl

    def save(self) -> None:
        """Save the model, including the weights of the underlying network."""
        super().save()
        torch.save(self._model.state_dict(), self.model_folder / "model_weights")


if __name__ == "__main__":
    import numpy as np

    from vito_cropsar.constants import get_models_folder
    from vito_cropsar.data import Scaler

    # Smoke test: build a fresh model, persist it, reload it, and run it once

    # Create new model
    model_cfg = ModelConfig(
        tag="test_resunet3d",
        scaler=Scaler.load(
            bands_s1=["s1_vv", "s1_vh"],
            bands_s2=["s2_fapar"],
        ),
    )
    my_model = InpaintingModel(cfg=model_cfg)
    print(my_model.get_parameter_count())
    my_model.save()

    # Load model again
    saved_folder = get_models_folder() / "test_resunet3d"
    new_model = InpaintingModel.load(mdl_f=saved_folder)

    # The first parameter tensor must be the same before and after the round trip
    first_saved = next(my_model._model.parameters())
    first_loaded = next(new_model._model.parameters())
    assert torch.allclose(first_saved, first_loaded)

    # Run the model in inference
    s1 = np.random.rand(32, 2, 128, 128).astype(np.float32)
    # Blank out four complete entries along the first axis of S1
    s1[[0, 4, 18, 19]] = np.nan
    s2 = np.random.rand(32, 1, 128, 128).astype(np.float32)
    # Blank out every S2 value above one half
    s2[s2 > 0.5] = np.nan  # noqa: PLR2004
    _ = new_model(s1=s1, s2=s2)
