"""Base inpainting model."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Iterator

import numpy as np
import torch
from numpy.typing import NDArray
from torch import nn

from vito_cropsar.constants import (
    PRECISION_FLOAT_NP,
    PRECISION_INT_NP,
)
from vito_cropsar.data import Scaler, apply_mask
from vito_cropsar.models.base_config import ModelBaseConfig
from vito_cropsar.models.utils import JSONEncoder
from vito_cropsar.vito_logger import Logger


class InpaintingBase:
    """
    Base inpainting model.

    Subclasses are expected to implement ``forward``, ``train``, ``eval`` and
    ``parameters`` (the base versions raise ``NotImplementedError``). This class
    provides the NumPy inference wrapper (``__call__``), configuration access
    and (de)serialisation.

    Attributes not found on the instance itself are looked up on the model
    configuration ``self.cfg``; e.g. ``self.tag`` resolves to ``self.cfg.tag``.
    """

    def __init__(self, cfg: ModelBaseConfig) -> None:
        """
        Initialize the model from its configuration.

        Parameters
        ----------
        cfg : ModelBaseConfig
            Model configuration

        Notes
        -----
        Creates ``cfg.model_folder`` (with parents) on disk and attaches a
        model-specific ``Logger`` with ``log_f`` set to that folder.
        """
        # Set the model configuration (must happen first: attribute fallback
        # below, e.g. self.model_folder, reads from it)
        self.cfg = cfg

        # Run on GPU when one is available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Create model specific logger
        self.model_folder.mkdir(parents=True, exist_ok=True)
        self.logger = Logger(log_f=self.model_folder)

    def __str__(self) -> str:
        """Representation of the model."""
        return f"{self.__class__.__name__}({self.tag})"

    def __repr__(self) -> str:
        """Representation of the model."""
        return str(self)

    def __call__(
        self,
        s1: NDArray[PRECISION_FLOAT_NP],
        s2: NDArray[PRECISION_FLOAT_NP],
    ) -> NDArray[PRECISION_FLOAT_NP]:
        """
        Create an inpainting for the provided S2 data.

        Parameters
        ----------
        s1 : NDArray[PRECISION_FLOAT_NP]
            Sentinel-1 data
            Shape: (time, channels_s1, width, height)
        s2 : NDArray[PRECISION_FLOAT_NP]
            Sentinel-2 data (partially obscured with NaN)
            Shape: (time, channels_s2, width, height)

        Returns
        -------
        NDArray[PRECISION_FLOAT_NP]
            Inpainted Sentinel-2 data
            Shape: (time, channels_s2, width, height)
        """
        # Scale if the scaler is provided
        if self.scaler:
            # NOTE(review): relies on the scaler's default `safe` behaviour to
            # leave the caller's arrays untouched - confirm against Scaler.
            s1, s2 = self.scaler(s1=s1, s2=s2, inference=True)

        # Visibility mask over (time, width, height): 1 if any channel of the
        # pixel is observed, 0 if all channels are NaN (obscured)
        mask = (~np.isnan(s2)).any(axis=1).astype(PRECISION_INT_NP)
        s2 = apply_mask(s2, mask=mask)  # Mask out s2

        # Predict on a batch of one. No autograd graph is needed since the
        # result is converted to NumPy right away.
        with torch.no_grad():
            pred_t = self.forward(
                s2=torch.tensor(s2[None, :], device=self.device),
                mask=torch.tensor(mask[None, :], device=self.device),
                s1=torch.tensor(s1[None, :], device=self.device),
            )
        pred = pred_t.detach().cpu().numpy()[0]

        # Inverse the scaling if the scaler is provided
        if self.scaler:
            # safe=False: `pred` is a fresh array owned by this call, so the
            # scaler may modify it in place
            _, pred = self.scaler(s2=pred, reverse=True, safe=False)

        return pred

    def __getitem__(self, key: str) -> Any:
        """Get the attribute ``key`` (instance first, then configuration)."""
        return getattr(self, key)

    def __getattr__(self, attr: str) -> Any:
        """
        Fall back to the model configuration for unknown attributes.

        Only called by Python after normal attribute lookup failed.

        Raises
        ------
        AttributeError
            If neither the instance nor ``self.cfg`` has ``attr``, or if
            ``cfg`` itself has not been set yet.
        """
        # Never look up `cfg` through the fallback: when it is missing (e.g.
        # before __init__ ran, or on the bare instance copy/pickle builds)
        # doing so would recurse forever.
        if attr != "cfg":
            try:
                cfg = self.cfg
            except AttributeError:
                pass
            else:
                try:
                    return getattr(cfg, attr)
                except AttributeError:
                    pass
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {attr!r}"
        )

    def forward(
        self,
        s1: torch.Tensor,
        s2: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward process of the model.

        Parameters
        ----------
        s1 : torch.Tensor
            Sentinel-1 data
            Shape: (batch, time, channels_s1, width, height)
        s2 : torch.Tensor
            Sentinel-2 data (partially obscured; ``__call__`` passes it through
            ``apply_mask`` first)
            Shape: (batch, time, channels_s2, width, height)
        mask : torch.Tensor
            Mask indicating which pixels are visible (0 if obscured, 1 if
            visible), shared across channels
            Shape (as built by ``__call__``): (batch, time, width, height)

        Returns
        -------
        torch.Tensor
            Inpainted Sentinel-2 data (NaN free)
            Shape: (batch, time, channels_s2, width, height)
        """
        raise NotImplementedError

    def train(self) -> None:
        """Put the model in training mode."""
        raise NotImplementedError

    def eval(self) -> None:
        """Put the model in evaluation mode."""
        raise NotImplementedError

    def parameters(self) -> Iterator[nn.Parameter]:
        """Enlist all the parameters."""
        raise NotImplementedError

    def get_metadata(self) -> dict[str, Any]:
        """Get the metadata of the model (the configuration as a dict)."""
        return self.cfg.dict()

    @classmethod
    def load(cls, mdl_f: Path) -> InpaintingBase:
        """
        Load a model previously written by ``save``.

        Parameters
        ----------
        mdl_f : Path
            Model folder containing ``config_model.json`` and, if the model
            has a scaler, the ``scaler`` sub-folder. A ``str`` is accepted too.

        Returns
        -------
        InpaintingBase
            Instance of ``cls`` built from the stored configuration
        """
        mdl_f = Path(mdl_f)

        # Load in the model's configuration
        with open(mdl_f / "config_model.json", encoding="utf-8") as f:
            cfg = json.load(f)
        # Folder locations come from where the model was found, overriding
        # whatever was stored in the file
        cfg["folder"] = mdl_f.parent
        cfg["model_folder"] = mdl_f

        # The stored scaler entry holds the keyword arguments for Scaler.load;
        # a model saved without a scaler stores null and is left untouched
        scaler_kwargs = cfg.get("scaler")
        if scaler_kwargs is not None:
            cfg["scaler"] = Scaler.load(
                path=mdl_f / "scaler",
                **scaler_kwargs,
            )

        # Create the model
        return cls(cfg=ModelBaseConfig(**cfg))

    def save(self) -> None:
        """
        Save the model to ``model_folder``.

        Writes ``config_model.json`` (the output of ``get_metadata``) and, when
        a scaler is configured, saves it to ``model_folder / "scaler"``.
        """
        self.model_folder.mkdir(parents=True, exist_ok=True)
        config_f = self.model_folder / "config_model.json"
        with open(config_f, "w", encoding="utf-8") as f:
            json.dump(
                self.get_metadata(),
                f,
                indent=2,
                cls=JSONEncoder,
            )
        # `__call__` treats the scaler as optional, so saving must too
        if self.scaler is not None:
            self.scaler.save(self.model_folder / "scaler")


if __name__ == "__main__":
    from vito_cropsar.constants import get_models_folder

    # Smoke test: build a model, persist it to disk, then read it back.
    demo_scaler = Scaler.load(
        bands_s1=["s1_vv", "s1_vh"],
        bands_s2=["s2_fapar"],
    )
    demo_cfg = ModelBaseConfig(
        tag="tag",
        scaler=demo_scaler,
        optional="optional",
    )
    model = InpaintingBase(cfg=demo_cfg)
    model.save()

    # Reload from the models folder under the same tag
    reload_f = get_models_folder() / "tag"
    reloaded_model = InpaintingBase.load(reload_f)
    print(reloaded_model)
