"""CNN embedded transformer version of the inpainting model."""

from __future__ import annotations

from pathlib import Path
from typing import Iterator

import numpy as np
import torch
from torch import nn

from vito_cropsar.models.base_model import InpaintingBase
from vito_cropsar.models.CnnTransformer.config import ModelConfig
from vito_cropsar.models.CnnTransformer.model import CnnTransformer
from vito_cropsar.models.CnnTransformer.utils import get_repr
from vito_cropsar.models.shared.utils import interpolate_s1_batch


class InpaintingModel(InpaintingBase):
    """CNN-embedded transformer version of the inpainting model.

    Wraps a ``CnnTransformer`` network. It prepares the Sentinel-1/Sentinel-2
    inputs (S1 NaN interpolation, optional S2 edge representation, S2 NaN
    filling). It also stores/loads the network weights next to the
    configuration handled by ``InpaintingBase``.
    """

    def __init__(self, cfg: ModelConfig) -> None:
        """
        Initialize the model from its configuration.

        Parameters
        ----------
        cfg : ModelConfig
            Configuration file specific to the CnnTransformer model
        """
        super().__init__(cfg=cfg)

        # Build the network. The hyper-parameters are read from attributes
        # presumably populated by InpaintingBase from ``cfg`` (TODO confirm).
        self._model = CnnTransformer(
            channels_in=self.n_channels_in,
            channels_out=self.n_channels_out,
            n_ts=self.n_ts,
            enc_channels_p=self.enc_channels_p,
            enc_channels_resnet=self.enc_channels_resnet,
            enc_dropout=self.enc_dropout,
            enc_activation=self.enc_activation,
            temp_layers=self.temp_layers,
            temp_heads=self.temp_heads,
            temp_dropout=self.temp_dropout,
            temp_activation=self.temp_activation,
            dec_channels_p=self.dec_channels_p,
            dec_channels_resnet=self.dec_channels_resnet,
            dec_dropout=self.dec_dropout,
            dec_activation=self.dec_activation,
            sha_activation=self.sha_activation,
        )

        # Move all models to device
        self.to(self.device)

    def forward(
        self,
        s1: torch.Tensor,
        s2: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward process of the model.

        The input tensors are not modified: NaN filling of S2 is done on a
        copy.

        Parameters
        ----------
        s1 : torch.Tensor
            Sentinel-1 data
            Shape: (batch, time, channels_s1, width, height)
        s2 : torch.Tensor
            Sentinel-2 data (partially obscured with NaN)
            Shape: (batch, time, channels_s2, width, height)
            Note: The NaNs are still present for S2-data!
        mask : torch.Tensor
            Mask indicating which pixels are visible (0 if obscured, 1 if visible)
            Shape: (batch, time, width, height)
            Only validated here (must be NaN free); it is not fed to the network.

        Returns
        -------
        torch.Tensor
            Inpainted Sentinel-2 data (NaN free)
            Shape: (batch, time, channels_s2, width, height)

        Raises
        ------
        AssertionError
            If ``mask`` contains NaNs, or S1 and S2 differ in
            batch/time/spatial dimensions.
        """
        assert not torch.isnan(mask).any(), "Assuming no NaN in mask!"
        assert (
            s1[:, :, 0].shape == s2[:, :, 0].shape
        ), f"Assuming same shape for S1 ({s1.shape}) and S2 ({s2.shape})!"

        # Interpolate NaNs in S1
        s1 = interpolate_s1_batch(s1)

        # The edge representation is computed per sample on the raw S2 data
        # (NaNs still present), so it must run before the NaN fill below.
        # get_repr works on NumPy arrays. detach() keeps this working when s2
        # is part of an autograd graph.
        s2_e = None
        if self.edge_repr:
            s2_e = torch.stack(
                [
                    torch.tensor(get_repr(sample), dtype=s2.dtype, device=s2.device)
                    for sample in s2.detach().cpu().numpy()
                ]
            )

        # Ensure S2 has no NaNs. This is done out-of-place so the caller's
        # tensor keeps its NaN gap markers; an in-place assignment here would
        # silently mutate the input.
        s2 = s2.masked_fill(s2.isnan(), self.fill_nan)

        # Stack S2 and S1 along the channel axis. The network expects
        # (batch, channels, time, width, height), hence the permutes around it.
        inputs = torch.cat([s2, s1], dim=2)
        logits = self._model(
            x=inputs.permute(0, 2, 1, 3, 4),
            e=s2_e,
        ).permute(0, 2, 1, 3, 4)
        return logits

    def train(self) -> None:
        """Put the model in training mode.

        Note: also enables autograd globally via ``torch.set_grad_enabled``,
        which affects all torch code running in this thread.
        """
        torch.set_grad_enabled(True)
        self._model.train()

    def eval(self) -> None:
        """Put the model in evaluation mode.

        Note: also disables autograd globally via ``torch.set_grad_enabled``,
        which affects all torch code running in this thread.
        """
        torch.set_grad_enabled(False)
        self._model.eval()

    def to(self, device: torch.device) -> None:
        """Move the underlying network to the given device."""
        self._model.to(device)

    def parameters(self) -> Iterator[nn.Parameter]:
        """Return an iterator over the underlying network's parameters."""
        return self._model.parameters()

    def get_parameter_count(self) -> tuple[int, int]:
        """Count the trainable and untrainable parameters.

        Returns
        -------
        tuple[int, int]
            ``(trainable, untrainable)`` element counts, split on
            ``requires_grad``.
        """
        trainable, untrainable = 0, 0
        for p in self._model.parameters():
            if p.requires_grad:
                trainable += p.numel()
            else:
                untrainable += p.numel()
        return trainable, untrainable

    @classmethod
    def load(cls, mdl_f: Path) -> InpaintingModel:
        """Load a saved model and put it in evaluation mode.

        The configuration is restored by ``InpaintingBase.load``; the network
        weights are read from ``mdl_f / "model_weights"``.
        """
        mdl = super().load(mdl_f=mdl_f)
        # NOTE(review): torch.load unpickles the file, so only load weights
        # from trusted locations (consider weights_only=True on torch>=1.13).
        mdl._model.load_state_dict(
            torch.load(mdl_f / "model_weights", map_location=mdl.device)
        )
        mdl.eval()
        return mdl

    def save(self) -> None:
        """Save the configuration (via the base class) and the network weights."""
        super().save()
        torch.save(self._model.state_dict(), self.model_folder / "model_weights")


if __name__ == "__main__":
    # Manual smoke test: build -> save -> reload -> compare weights -> infer.
    from vito_cropsar.constants import get_models_folder
    from vito_cropsar.data import Scaler

    model_tag = "test_transf"

    # Build a fresh model and persist it to the models folder
    scaler = Scaler.load(
        bands_s1=["s1_asc_vv", "s1_des_vv", "s1_asc_vh", "s1_des_vh"],
        bands_s2=["s2_fapar"],
    )
    my_model = InpaintingModel(cfg=ModelConfig(tag=model_tag, scaler=scaler))
    print(my_model.get_parameter_count())
    my_model.save()

    # Reload it and verify the first weight tensor survived the round trip
    new_model = InpaintingModel.load(mdl_f=get_models_folder() / model_tag)
    first_saved = next(iter(my_model._model.parameters()))
    first_loaded = next(iter(new_model._model.parameters()))
    assert torch.allclose(first_saved, first_loaded)

    # Run inference on random data with artificial gaps.
    # NOTE(review): forward() takes torch tensors plus a mask. This call
    # relies on InpaintingBase.__call__ (not visible here) accepting NumPy
    # arrays without a mask. Four S1 bands are configured above, while s1
    # has 2 entries on axis 1 — confirm which axis holds the channels.
    gap_frames = [0, 4, 18, 19]
    cloud_threshold = 0.5
    s1 = np.random.rand(32, 2, 128, 128).astype(np.float32)
    s1[gap_frames] = np.nan
    s2 = np.random.rand(32, 1, 128, 128).astype(np.float32)
    s2[s2 > cloud_threshold] = np.nan
    _ = new_model(s1=s1, s2=s2)
