"""Similarity classifier."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import torch
from numpy.typing import NDArray
from torch import nn

from vito_crop_classification.model.classifiers.base import BaseClassifier


class SimilarityClassifier(BaseClassifier):
    """Cosine-similarity based classifier.

    Every class is described by one representative vector. An embedding is scored
    by its cosine similarity to each representative, and these similarities are
    mapped to logits.
    """

    # Similarities are clamped to [_LOGIT_EPS, 1 - _LOGIT_EPS] before the inverse
    # sigmoid, which keeps the logits finite (log(0) and log(1 / 0) are avoided).
    _LOGIT_EPS = 1e-6
    # Upper bound on random re-initialisations when the optimisation yields NaNs.
    _MAX_INIT_RETRIES = 100

    def __init__(
        self,
        clf_tag: str,
        input_size: int,
        classes: NDArray[np.str_],
    ):
        """
        Similarity classifier.

        Parameters
        ----------
        clf_tag : str
            Classifier tag
        input_size : int
            Input size of the classifier
        classes : NDArray[np.str_]
            List of classes used for classification
        """
        super().__init__(
            clf_tag=clf_tag,
            input_size=input_size,
            classes=classes,
        )
        self._model = _create_model()  # Identity: embeddings are compared as-is
        self._sim = torch.cosine_similarity
        # One representative per class; filled by create_representatives() or load()
        self._reps: torch.Tensor | None = None

    def __call__(self, embs: torch.Tensor) -> torch.Tensor:
        """
        Return the class logits for the provided embeddings.

        Parameters
        ----------
        embs : torch.Tensor
            Embeddings to classify (assumed to be of shape (batch, input_size))

        Returns
        -------
        torch.Tensor
            Logits with one row per embedding and one column per class representative

        Raises
        ------
        RuntimeError
            If no class representatives have been created or loaded yet
        """
        if self._reps is None:
            raise RuntimeError(
                "No class representatives available: call create_representatives() or load() first."
            )

        sims = torch.stack([self._sim(embs, rep) for rep in self._reps], dim=1)
        v_max = sims.max(dim=1).values
        v_min = sims.min(dim=1).values
        logits = torch.zeros_like(sims)
        eps = self._LOGIT_EPS

        # Samples with at least one similarity larger than zero: reverse sigmoid.
        # Clamping strictly inside (0, 1) prevents infinite logits.
        positive = v_max > 0
        if positive.any():
            probs = sims[positive].clamp(min=eps, max=1 - eps)
            logits[positive] = torch.log(probs / (1 - probs))

        # All other samples (unlikely): linearly rescale to the [-6, 6] range.
        # The span is bounded away from zero so identical similarities do not yield NaNs.
        remaining = ~positive
        if remaining.any():
            low = v_min[remaining, None]
            span = (v_max[remaining, None] - low).clamp(min=eps)
            logits[remaining] = 12 * (sims[remaining] - low) / span - 6

        return logits

    def create_representatives(self, relax_thr: float = 1e-6) -> dict[str, torch.Tensor]:
        """
        Create class representatives.

        Randomly initialised vectors (one per class) are optimised with Adam on the
        pairwise vector loss until the loss improves by less than ``relax_thr``.

        Parameters
        ----------
        relax_thr : float
            Minimal loss improvement between two steps required to keep optimising

        Returns
        -------
        dict[str, torch.Tensor]
            Representative vector for every class

        Raises
        ------
        RuntimeError
            If every optimisation attempt produced NaN values
        """
        classes = self._classes.tolist()
        n_classes = len(classes)

        # Define all the class couples by position (built once, reused every step)
        pairs = [(i, j) for i in range(n_classes) for j in range(i + 1, n_classes)]

        if not pairs:
            # Fewer than two classes: there is nothing to optimise
            reps = torch.rand(n_classes, self._input_size)
        else:
            first = torch.tensor([i for i, _ in pairs], dtype=torch.long)
            second = torch.tensor([j for _, j in pairs], dtype=torch.long)
            for _ in range(self._MAX_INIT_RETRIES):
                reps = self._optimise_representatives(
                    n_classes=n_classes,
                    first=first,
                    second=second,
                    relax_thr=relax_thr,
                )
                if reps is not None:
                    break
            else:
                raise RuntimeError(
                    f"Unable to create class representatives after {self._MAX_INIT_RETRIES} attempts."
                )

        # Set the new representatives
        self._reps = reps.to(self._device)

        return {name: self._reps[i] for i, name in enumerate(classes)}

    def _optimise_representatives(
        self,
        n_classes: int,
        first: torch.Tensor,
        second: torch.Tensor,
        relax_thr: float,
    ) -> torch.Tensor | None:
        """Run one optimisation attempt; return the vectors, or None if NaNs showed up."""
        weights = torch.nn.Parameter(torch.rand(n_classes, self._input_size))
        optimizer = torch.optim.Adam([weights], lr=1e-3)
        prev_loss = float("inf")
        while True:
            optimizer.zero_grad()
            loss = _vector_loss(v1=weights[first], v2=weights[second])
            if loss.isnan().any():
                return None
            loss.backward()
            optimizer.step()
            if not torch.isfinite(weights).all():
                return None

            # Stop the loop once the model learned enough
            current = float(loss)
            if (prev_loss - current) < relax_thr:
                break
            prev_loss = current

        # Detach so the stored representatives do not drag the optimisation graph along
        return weights.detach()

    @classmethod
    def load(cls, mdl_f: Path, clf_tag: str, device: str | None = None) -> BaseClassifier:
        """
        Load in the corresponding classifier.

        Parameters
        ----------
        mdl_f : Path
            Model folder, the representatives are read from ``mdl_f / "modules" / clf_tag``
        clf_tag : str
            Classifier tag
        device : str | None
            Device to map the representatives to, None keeps the location stored in the file

        Returns
        -------
        BaseClassifier
            Loaded classifier, including its class representatives
        """
        clf = super().load(
            mdl_f=mdl_f,
            clf_tag=clf_tag,
            device=device,
        )
        load_path = mdl_f / "modules" / clf_tag
        # Only build a torch.device when a device was actually requested
        map_location = None if device is None else torch.device(device)
        # NOTE: torch.load unpickles the file, only load model folders from trusted sources
        clf._reps = torch.load(load_path / "clf_reps", map_location=map_location)  # type: ignore
        return clf

    def save(self, mdl_f: Path) -> None:
        """
        Save the classifier.

        Parameters
        ----------
        mdl_f : Path
            Model folder, the representatives are written to ``mdl_f / "modules" / <tag>``
        """
        super().save(
            mdl_f=mdl_f,
        )
        save_path = mdl_f / "modules" / self._tag
        torch.save(self._reps, save_path / "clf_reps")


def _create_model() -> torch.nn.Module:
    """
    Build the network used by the classifier.

    Returns
    -------
    torch.nn.Module
        Sequential container holding a single identity layer, so its input is
        returned unchanged
    """
    layers = [nn.Identity()]
    return nn.Sequential(*layers)


def _vector_loss(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Custom vector-based comparison.

    Computes ``mean(1 - sqrt(1 - cosine_similarity(v1, v2)))``, so the loss gets
    smaller the further the paired vectors point away from each other.

    Parameters
    ----------
    v1 : torch.Tensor
        First vector of every pair
    v2 : torch.Tensor
        Second vector of every pair
    eps : float
        Lower bound applied to the cosine distance before taking the square root

    Returns
    -------
    torch.Tensor
        Scalar tensor containing the mean loss over all pairs
    """
    dist = 1 - torch.nn.functional.cosine_similarity(v1, v2)
    # Rounding can push the distance to zero or slightly below it, where the square
    # root is NaN or has an infinite gradient; keep it strictly positive.
    dist = torch.clamp(dist, min=eps)
    return torch.mean(1 - torch.sqrt(dist))


def _get_cls_distance(cls1: str, cls2: str) -> int:
    """
    Get the distance between the two classes.

    Class names are split into levels on "-". A trailing "w" or "s" is replaced by
    "." first, so two names differing only in that suffix are equal on every level
    (presumably seasonal variants of the same class; confirm against the class list).

    Parameters
    ----------
    cls1 : str
        First class name
    cls2 : str
        Second class name

    Returns
    -------
    int
        Number of levels of ``cls1`` from the first differing level onwards, or 0
        if every level of ``cls1`` matches ``cls2``
    """

    def _levels(name: str) -> list[str]:
        # endswith() is safe on empty strings, unlike indexing the last character
        if name.endswith(("w", "s")):
            name = f"{name[:-1]}."
        return name.split("-")

    lvl1, lvl2 = _levels(cls1), _levels(cls2)
    for i, level in enumerate(lvl1):
        # A level missing in the second class counts as a mismatch
        if i >= len(lvl2) or level != lvl2[i]:
            return len(lvl1) - i
    return 0
