"""Loading functions for classifiers."""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import torch
from numpy.typing import NDArray

from vito_crop_classification.model.classifiers.base import BaseClassifier
from vito_crop_classification.model.classifiers.dense import DenseClassifier
from vito_crop_classification.model.classifiers.similarity import SimilarityClassifier


def load_classifier(
    mdl_f: Path,
    clf_type: str,
    clf_tag: str,
    device: str | None = None,
) -> BaseClassifier:
    """Load the specified pre-trained classifier.

    Args:
        mdl_f: Path forwarded unchanged to the classifier's ``load`` method
            (presumably the model folder — confirm against ``BaseClassifier.load``).
        clf_type: Name of the classifier class to load; either
            ``"DenseClassifier"`` or ``"SimilarityClassifier"``.
        clf_tag: Tag forwarded to the classifier's ``load`` method.
        device: Device to load onto. When ``None``, ``"cuda"`` is used if
            ``torch.cuda.is_available()``, otherwise ``"cpu"``.

    Returns:
        The classifier instance returned by the matching class's ``load``.

    Raises:
        ValueError: If ``clf_type`` is not a recognised classifier type.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if clf_type == "DenseClassifier":
        return DenseClassifier.load(mdl_f, clf_tag=clf_tag, device=device)
    if clf_type == "SimilarityClassifier":
        return SimilarityClassifier.load(mdl_f, clf_tag=clf_tag, device=device)
    # ValueError subclasses Exception, so existing `except Exception` callers still catch it.
    raise ValueError(f"Cannot parse classifier type: '{clf_type}'!")


def parse_classifier(
    cfg_clf: dict[str, Any],
    input_size: int,
    classes: NDArray[np.str_],
) -> BaseClassifier | None:
    """Instantiate the classifier described by a configuration dictionary.

    Args:
        cfg_clf: Configuration with a ``"type"`` key naming the classifier
            class (``"DenseClassifier"`` or ``"SimilarityClassifier"``) and a
            ``"params"`` key holding extra constructor keyword arguments.
            ``"params"`` is only read when ``"type"`` is recognised.
        input_size: Forwarded to the classifier constructor as ``input_size``.
        classes: Forwarded to the classifier constructor as ``classes``.

    Returns:
        A new classifier instance, or ``None`` if ``"type"`` is not recognised.
    """
    requested = cfg_clf["type"]
    if requested == "DenseClassifier":
        clf_cls = DenseClassifier
    elif requested == "SimilarityClassifier":
        clf_cls = SimilarityClassifier
    else:
        # Unknown types are not an error here; the caller receives None.
        return None
    return clf_cls(input_size=input_size, classes=classes, **cfg_clf["params"])
