"""Naive interpolation version of the inpainting model."""

from __future__ import annotations

from typing import Any, Iterator

import torch
from torch import nn

from vito_cropsar.models.base_config import ModelBaseConfig
from vito_cropsar.models.base_model import InpaintingBase
from vito_cropsar.models.utils import interpolate_time


class InpaintingModel(InpaintingBase):
    """
    Naive interpolation version of the inpainting model.

    The model has no learnable weights: each sample in the batch is inpainted
    independently by ``interpolate_time``. The ``train``/``eval``/``parameters``
    overrides mirror the ``torch.nn.Module`` API so the model can (presumably)
    be driven by the same training/evaluation code as the learned models.
    """

    def __init__(self, cfg: ModelBaseConfig) -> None:
        """
        Initialize the model.

        Parameters
        ----------
        cfg : ModelBaseConfig
            Configuration file specific to the naive model
            Note: Naive doesn't need configuration so simply pass the base config
        """
        super().__init__(cfg=cfg)

    def forward(
        self,
        s2: torch.Tensor,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """
        Forward process of the model.

        Parameters
        ----------
        s2 : torch.Tensor
            Sentinel-2 data (partially obscured with NaN)
            Shape: (batch, time, channels_s2, width, height)
        *args, **kwargs : Any
            Ignored; accepted so the call signature matches other models.

        Returns
        -------
        torch.Tensor
            Inpainted Sentinel-2 data (NaN free)
            Shape: (batch, time, channels_s2, width, height)
        """
        # Output keeps the input's shape, dtype and device; each batch entry
        # is overwritten below.
        results = torch.zeros_like(s2)
        for idx, sample in enumerate(s2):  # Iterate over the batch
            results[idx] = interpolate_time(arr=sample)
        return results

    def train(self, mode: bool = True) -> InpaintingModel:
        """
        Put the model in training mode (or evaluation mode if ``mode`` is False).

        Mirrors the ``torch.nn.Module.train`` signature so that calls such as
        ``model.train(False)`` and ``model = model.train()`` keep working. The
        model has no weights or submodules, so only the ``training`` flag is
        recorded.

        Parameters
        ----------
        mode : bool, optional
            True for training mode, False for evaluation mode. Default True.

        Returns
        -------
        InpaintingModel
            ``self``, to allow call chaining.
        """
        self.training = mode
        return self

    def eval(self) -> InpaintingModel:
        """
        Put the model in evaluation mode.

        Equivalent to ``self.train(False)``.

        Returns
        -------
        InpaintingModel
            ``self``, to allow call chaining.
        """
        return self.train(False)

    def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
        """
        Enlist all the parameters.

        The model has no real parameters. A single empty placeholder parameter
        is yielded, presumably so that consumers requiring a non-empty
        parameter list (e.g. optimizer construction) do not fail.

        Parameters
        ----------
        recurse : bool, optional
            Accepted for ``torch.nn.Module`` compatibility; there are no
            submodules, so it has no effect.

        Yields
        ------
        nn.Parameter
            One empty placeholder parameter.
        """
        yield nn.Parameter()
