"""Softmax classifier."""

import numpy as np
import torch
from numpy.typing import NDArray
from torch import nn

from vito_crop_classification.model.classifiers.base import BaseClassifier


class DenseClassifier(BaseClassifier):
    """Classifier whose model is a single dense (linear) layer."""

    def __init__(self, clf_tag: str, input_size: int, classes: NDArray[np.str_]):
        """Initialise the base classifier and attach the dense model.

        Parameters
        ----------
        clf_tag : str
            Tag identifying this classifier
        input_size : int
            Number of input features fed to the classifier
        classes : NDArray[np.str_]
            Class labels; the model gets one output unit per label
        """
        super().__init__(clf_tag=clf_tag, input_size=input_size, classes=classes)
        # The layer width is read back from ``self._input_size`` rather than the
        # raw argument (presumably populated by BaseClassifier.__init__ -- confirm).
        n_classes = len(classes)
        self._model = _create_model(inp_size=self._input_size, n_classes=n_classes)


def _create_model(inp_size: int, n_classes: int) -> torch.nn.Module:
    """Build the dense model: one linear layer inside a sequential container.

    Parameters
    ----------
    inp_size : int
        Number of input features of the linear layer
    n_classes : int
        Number of output features of the linear layer

    Returns
    -------
    torch.nn.Module
        ``nn.Sequential`` holding a single ``nn.Linear`` layer
    """
    projection = nn.Linear(in_features=inp_size, out_features=n_classes)
    return nn.Sequential(projection)
