"""Base classifier."""

from __future__ import annotations

import inspect
import json
from pathlib import Path
from typing import Any, Iterator

import numpy as np
import torch
import torch.nn as nn
from numpy.typing import NDArray


class BaseClassifier:
    """Base wrapper around a PyTorch classification model.

    Stores the classifier tag, the expected input embedding size and the
    ordered class labels. The underlying ``nn.Module`` lives in
    ``self._model``; it starts as ``None`` and is set by :meth:`load`
    (and presumably by subclasses). Methods that need the model assert it
    has been set.
    """

    def __init__(
        self,
        clf_tag: str,
        input_size: int,
        classes: NDArray[np.str_] | list[str],
    ):
        """
        Base classifier.

        Parameters
        ----------
        clf_tag : str
            Name of the classifier; used as the sub-directory name when saving.
        input_size : int
            Size of the input embedding to classify.
        classes : NDArray[np.str_] | list[str]
            Ordered class labels. ``self._class_mapping`` maps index ``i`` to
            ``classes[i]`` (presumably the i-th model output).
        """
        self._tag = clf_tag
        self._input_size = input_size
        self._classes = np.asarray(classes)
        self._class_mapping = dict(enumerate(classes))
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._model: nn.Module | None = None

    def __str__(self) -> str:
        """String representation of the classifier."""
        return (
            f"{self.__class__.__name__}(inp_dim={self._input_size}, n_classes={len(self._classes)})"
        )

    def __repr__(self) -> str:
        """String representation of the classifier."""
        return str(self)

    def __call__(self, embs: torch.Tensor) -> torch.Tensor:
        """Run ``embs`` through the underlying model and return its raw output.

        NOTE(review): subclasses presumably override this to turn the raw
        output into class predictions; the base class does no post-processing.
        """
        assert self._model is not None
        return self._model(embs)

    def forward_process(self, embs: torch.Tensor) -> torch.Tensor:
        """Forward the embeddings through the classifier's underlying model."""
        # Same guard as the other model-dependent methods, instead of a
        # confusing "'NoneType' object is not callable".
        assert self._model is not None
        return self._model(embs)

    def parameters(self) -> Iterator[nn.parameter.Parameter]:
        """Return model parameters to train."""
        assert self._model is not None
        return self._model.parameters()

    def train(self, device: str | None = None) -> None:
        """Put the model in training mode and move it to ``device``.

        Parameters
        ----------
        device : str | None
            Target device; ``None`` keeps the currently recorded device.
        """
        assert self._model is not None
        self._model.train()
        self.to(device)

    def eval(self, device: str | None = None) -> None:
        """Put the model in evaluation mode and move it to ``device``.

        Parameters
        ----------
        device : str | None
            Target device; ``None`` keeps the currently recorded device.
        """
        assert self._model is not None
        self._model.eval()
        self.to(device)

    def to(self, device: str | None = None) -> None:
        """Record ``device`` (if given) and move the model there, if one exists.

        Parameters
        ----------
        device : str | None
            Target device; ``None`` re-applies the currently recorded device.
        """
        if device is not None:
            self._device = device
        # Before a model is attached only the device preference is stored.
        if self._model is not None:
            self._model.to(self._device)

    def get_type(self) -> str:
        """Return the classifier's type, i.e. its concrete class name."""
        return self.__class__.__name__

    def get_tag(self) -> str:
        """Get the model's tag."""
        return self._tag

    def get_classes(self) -> NDArray[np.str_]:
        """Return the classes."""
        return self._classes

    def get_metadata(self) -> dict[str, Any]:
        """Return the constructor arguments (except the tag) needed by :meth:`load`."""
        return {
            "input_size": self._input_size,
            # ``tolist`` yields native Python values; numpy scalars such as
            # ``np.int64`` are not JSON-serialisable.
            "classes": self._classes.tolist(),
        }

    @classmethod
    def load(cls, mdl_f: Path, clf_tag: str, device: str | None = None) -> BaseClassifier:
        """Load a classifier previously written by :meth:`save`.

        Parameters
        ----------
        mdl_f : Path
            Root model folder; files are read from ``mdl_f/modules/<clf_tag>``.
        clf_tag : str
            Tag of the classifier to load.
        device : str | None
            Device to map the weights onto; defaults to CUDA when available,
            otherwise CPU.

        Returns
        -------
        BaseClassifier
            Instance of ``cls`` with its model loaded onto ``device``.
        """
        load_path = Path(mdl_f) / "modules" / clf_tag
        with open(load_path / "clf_metadata.json", "r", encoding="utf-8") as f:
            metadata = json.load(f)
        clf = cls(
            clf_tag=clf_tag,
            **metadata,
        )
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Keep the recorded device in sync with where the weights are mapped, so
        # a later ``train()``/``eval()`` without arguments doesn't move them.
        clf._device = device

        load_kwargs: dict[str, Any] = {"map_location": torch.device(device)}
        # SECURITY: the checkpoint is a fully pickled nn.Module (see ``save``), so
        # loading it runs arbitrary pickle code -- only load trusted files.
        # torch >= 2.6 defaults to ``weights_only=True``, which cannot unpickle a
        # whole module; older torch (< 1.13) does not know the keyword at all.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        clf._model = torch.load(load_path / "clf_weights", **load_kwargs)
        return clf

    def save(self, mdl_f: Path) -> None:
        """Save metadata and the pickled model under ``mdl_f/modules/<tag>``.

        Parameters
        ----------
        mdl_f : Path
            Root model folder; created if it does not exist.
        """
        # Fail before writing anything rather than persisting a model-less checkpoint.
        assert self._model is not None
        save_path = Path(mdl_f) / "modules" / self._tag
        save_path.mkdir(parents=True, exist_ok=True)
        with open(save_path / "clf_metadata.json", "w", encoding="utf-8") as f:
            json.dump(self.get_metadata(), f, indent=2)
        torch.save(self._model, save_path / "clf_weights")
