"""Base model class."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Iterator

import torch
from torch.nn.parameter import Parameter

from vito_lot_delineation.data import get_models_folder
from vito_lot_delineation.vito_logger import Logger


class BaseModel:
    """
    Base model class.

    Holds the state shared by all models: tag, device, configuration,
    inference threshold, model folder and logger. It also declares the
    interface that concrete models must provide. Every interface method
    defined here raises ``NotImplementedError`` and is meant to be overridden
    by subclasses.
    """

    def __init__(
        self,
        model_tag: str,
        config_file: dict[str, Any],
        mdl_f: Path | None = None,
        threshold: float = 0.5,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the shared model state.

        Parameters
        ----------
        model_tag : str
            Name of the model; also used as the name of the model sub-folder
        config_file : dict[str, Any]
            Model configuration, specifying model creation
        mdl_f : Path | None, optional
            Parent folder under which the model folder is created. When None,
            the folder returned by ``get_models_folder()`` is used.
        threshold : float, optional
            Threshold used during inference to remove uncertainty, by default 0.5
        **kwargs : Any
            Extra keyword arguments; accepted but not used by this base class

        Notes
        -----
        This constructor has filesystem side effects. It creates the folder
        ``(mdl_f or get_models_folder()) / model_tag``, including any missing
        parents, and creates a ``Logger`` that logs into that folder.
        """
        self.tag: str = model_tag
        # Prefer the GPU whenever torch can see one.
        self.device: str = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg: dict[str, Any] = config_file
        self._threshold = threshold

        # Create the model folder (no error if it already exists).
        self.model_folder: Path = (mdl_f or get_models_folder()) / self.tag
        self.model_folder.mkdir(exist_ok=True, parents=True)

        # Create a model-specific logger that logs into the model folder.
        self.logger = Logger(log_f=self.model_folder)

    def __str__(self) -> str:
        """Return the class name followed by the JSON-formatted configuration."""
        return f"{self.__class__.__name__} : {json.dumps(self.cfg, indent=2)}"

    def __repr__(self) -> str:
        """Return the same representation as ``__str__``."""
        return str(self)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Make predictions on the provided image.

        Parameters
        ----------
        x : torch.Tensor
            Input image of shape (width, height, channels, time)

        Returns
        -------
        torch.Tensor
            Field prediction of shape (width, height) where:
                0: No fields
                1: Index of first field
                2: Index of second field
                ...
                N: Index of Nth field

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError

    def forward_process(self, x: torch.Tensor) -> torch.Tensor:
        """Obtain the raw model output over a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Batch of samples over which to predict

        Returns
        -------
        torch.Tensor
            Raw model output

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError

    def train(self) -> None:
        """Set the model components (e.g. encoder and classifier) to train mode.

        Raises ``NotImplementedError``; subclasses must override this method.
        """
        raise NotImplementedError

    def eval(self) -> None:  # noqa: A003
        """Set the model components (e.g. encoder and classifier) to eval mode.

        Raises ``NotImplementedError``; subclasses must override this method.
        """
        raise NotImplementedError

    def to(self, device: str) -> None:
        """Put the model on the requested device.

        Raises ``NotImplementedError``; subclasses must override this method.
        """
        raise NotImplementedError

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Iterate over the trainable parameters of the model components.

        Raises ``NotImplementedError``; subclasses must override this method.
        """
        raise NotImplementedError

    def set_threshold(self, threshold: float) -> None:
        """Set the threshold used during inference."""
        self._threshold = threshold

    def get_threshold(self) -> float:
        """Return the threshold currently used during inference."""
        return self._threshold

    def save(self) -> None:
        """Save the model.

        Raises ``NotImplementedError``; subclasses must override this method.
        """
        raise NotImplementedError

    @classmethod
    def load(cls, mdl_f: Path) -> BaseModel:
        """Load a model from the given model folder.

        Raises ``NotImplementedError``; subclasses must override this method.
        """
        raise NotImplementedError
