"""Custom loss functions."""

from __future__ import annotations

import numpy as np
import torch
from numpy.typing import NDArray
from torch.nn.functional import cosine_embedding_loss, cross_entropy

# Names of the supported loss types. CustomLoss rejects any loss_type that is not
# listed here and dispatches on substrings of the name ("cross_entropy", "focal",
# "hierarchical_acc", "hierarchical_smooth", "cosine_embedding").
LOSSES = [
    "cross_entropy",
    "cross_entropy_focal",
    "cross_entropy_hierarchical_acc",
    "cross_entropy_hierarchical_smooth",
    "cosine_embedding",
]


class CustomLoss:
    """
    Callable wrapper that dispatches to one of the supported loss functions.

    The loss is selected by name: every ``cross_entropy*`` type compares ``inputs``
    against ``targets``, while ``cosine_embedding`` compares two batches of inputs
    against each other.
    """

    def __init__(self, loss_type: str, classes: NDArray[np.str_]) -> None:
        """
        Initialise the loss class.

        Parameters
        ----------
        loss_type : str
            Type of loss to be used, should be one of ``LOSSES``
        classes : NDArray[np.str_]
            Classes used for classification (forwarded to the hierarchical losses)

        Raises
        ------
        AssertionError
            If ``loss_type`` is not one of ``LOSSES``
        """
        assert (
            loss_type in LOSSES
        ), f"Loss type should be any of {LOSSES} (got '{loss_type}' instead)"

        self._type = loss_type
        self._classes = classes

    def __str__(self) -> str:
        """String representation of the loss."""
        return f"CustomLoss(type={self._type})"

    def __call__(
        self,
        inputs: torch.Tensor,
        inputs2: torch.Tensor | None = None,
        targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Calculate the loss.

        Parameters
        ----------
        inputs : torch.Tensor
            First input of the loss
        inputs2 : torch.Tensor | None
            Second input, required for (and only allowed with) the cosine embedding loss
        targets : torch.Tensor | None
            Targets, required for the cross entropy losses

        Returns
        -------
        torch.Tensor
            The calculated loss

        Raises
        ------
        AssertionError
            If ``inputs2`` is provided for a non cosine embedding loss (or missing for a
            cosine embedding loss), or if ``targets`` is missing for a cross entropy loss
        ValueError
            If the configured loss type is not supported
        """
        assert ("cosine_embedding" in self._type) == (
            inputs2 is not None
        ), "Second inputs should only be provided if train_type='similarity'"
        if "cross_entropy" in self._type:
            return self._cross_entropy(inputs=inputs, targets=targets)
        if "cosine_embedding" in self._type:
            return self._cosine_embedding(inputs1=inputs, inputs2=inputs2)
        # A bare string cannot be raised (it would surface as a TypeError), so raise a real exception
        raise ValueError(f"Loss type '{self._type}' not supported!")

    def _cross_entropy(self, inputs: torch.Tensor, targets: torch.Tensor | None) -> torch.Tensor:
        """
        Cross entropy loss.

        Selects the focal, hierarchical or plain cross entropy, based on the loss type.
        """
        assert targets is not None, "Targets should be provided for cross entropy losses!"
        if "focal" in self._type:
            return focal_cross_entropy(inputs=inputs, targets=targets)
        if "hierarchical_acc" in self._type:
            return cross_entropy_hierarchical_acc(
                inputs=inputs, targets=targets, classes=self._classes
            )
        if "hierarchical_smooth" in self._type:
            return cross_entropy_hierarchical_smooth(
                inputs=inputs, targets=targets, classes=self._classes
            )
        return cross_entropy(input=inputs, target=targets)

    def _cosine_embedding(self, inputs1: torch.Tensor, inputs2: torch.Tensor) -> torch.Tensor:
        """
        Cosine similarity loss.

        Every pair is labelled as similar (target of one), the per-pair losses are squared
        before being averaged.
        """
        # Create the targets directly on the inputs' device
        ones = torch.ones(len(inputs1), device=inputs1.device)
        losses = cosine_embedding_loss(
            input1=inputs1, input2=inputs2, target=ones, reduction="none"
        )
        return (losses**2).mean()


def focal_cross_entropy(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 1,
    gamma: float = 2.0,
) -> torch.float64:
    """
    Compute the focal variant of the cross entropy loss.

    Parameters
    ----------
    inputs : torch.Tensor
        Inputs of the cross entropy
    targets : torch.Tensor
        Targets of the cross entropy
    alpha : float
        Constant factor applied to every sample's loss
    gamma : float
        Exponent of the modulating factor ``(1 - exp(-ce))``

    Returns
    -------
    torch.Tensor
        Mean of the modulated per-sample cross entropy losses
    """
    per_sample = cross_entropy(inputs, targets, reduction="none")
    # exp(-ce) recovers the probability corresponding to each sample's loss
    modulator = (1 - torch.exp(-per_sample)) ** gamma
    weighted = alpha * modulator * per_sample
    return weighted.mean()


def cross_entropy_hierarchical_acc(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    classes: NDArray[np.str_],
) -> torch.Tensor:
    """
    Compute the hierarchical cross entropy.

    Averages the cross entropy calculated on the first three hierarchy levels of the
    classes (class probabilities merged per level) together with the regular cross
    entropy over all classes.

    Parameters
    ----------
    inputs : torch.Tensor
        Logits, softmaxed over dimension 1
    targets : torch.Tensor
        Targets, indexed per class column when merging levels
    classes : NDArray[np.str_]
        Class names, with hierarchy levels separated by '-'

    Returns
    -------
    torch.Tensor
        Mean of the four loss terms
    """
    probs = torch.nn.functional.softmax(input=inputs, dim=1)
    losses = [
        _cross_entropy_leveled(probs=probs, targets=targets, classes=classes, lvl=lvl)
        for lvl in (1, 2, 3)
    ]
    losses.append(cross_entropy(input=inputs, target=targets))
    # Divide by the number of terms that are summed, so the result is a true average
    return sum(losses) / len(losses)


def _cross_entropy_leveled(
    probs: torch.Tensor,
    targets: torch.Tensor,
    classes: NDArray[np.str_],
    lvl: int,
) -> torch.Tensor:
    """
    Apply cross entropy on a given level.

    Classes sharing the same first ``lvl`` '-'-separated levels are merged by summing
    their probabilities and targets, after which the cross entropy is calculated over
    the merged groups.

    Parameters
    ----------
    probs : torch.Tensor
        Class probabilities (not logits), one column per class
    targets : torch.Tensor
        Targets, one column per class
    classes : NDArray[np.str_]
        Class names, in the column order of ``probs`` and ``targets``
    lvl : int
        Number of leading hierarchy levels to keep

    Returns
    -------
    torch.Tensor
        Mean cross entropy over the batch
    """
    # Group of every class: its name truncated to the first `lvl` levels
    prefixes = ["-".join(c.split("-")[:lvl]) for c in classes]
    group_of = {name: idx for idx, name in enumerate(sorted(set(prefixes)))}
    index = torch.tensor(
        [group_of[prefix] for prefix in prefixes], dtype=torch.long, device=probs.device
    )

    # Merge on the given level, accumulating on the same device and dtype as the probabilities
    zeros = torch.zeros((len(probs), len(group_of)), dtype=probs.dtype, device=probs.device)
    probs_mapped = zeros.index_add(1, index, probs)
    targets_mapped = zeros.index_add(
        1, index, targets.to(device=probs.device, dtype=probs.dtype)
    )

    # Calculate cross entropy using probabilities (not logits!)
    # Clamp to the smallest positive normal value, since log(0) = -inf turns 0 * log(0) into NaN
    tiny = torch.finfo(probs_mapped.dtype).tiny
    log_probs = torch.log(probs_mapped.clamp_min(tiny))
    return (-torch.sum(targets_mapped * log_probs, dim=1)).mean()


def cross_entropy_hierarchical_smooth(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    classes: NDArray[np.str_],
    max_distance: int = 4,
) -> torch.Tensor:
    """
    Compute smoothed hierarchical cross entropy, based on model classes.

    Every sample's target is replaced by a soft target over all classes, where a class
    at hierarchical distance ``d`` of the sample's target class (the argmax of its
    targets) receives the weight ``(2 ** (max_distance - d) - 1) / (2 ** max_distance - 1)``.
    The target class itself (distance 0) gets weight 1, classes at ``max_distance`` or
    further get weight 0.

    Parameters
    ----------
    inputs : torch.Tensor
        Logits
    targets : torch.Tensor
        Targets, the argmax over dimension 1 selects the target class
    classes : NDArray[np.str_]
        Class names, in the column order of ``targets``
    max_distance : int
        Hierarchical distance from which on a class receives no weight (at least 1)

    Returns
    -------
    torch.Tensor
        Cross entropy between the inputs and the smoothed targets
    """
    target_ids = torch.argmax(targets, dim=1).tolist()

    # Distances only depend on the target class, so compute each row once
    rows = {i: [_hier_distance(classes[i], c) for c in classes] for i in set(target_ids)}
    distances = torch.tensor(
        [rows[i] for i in target_ids], dtype=torch.float32, device=targets.device
    )

    # Classes further away than max_distance would otherwise get a negative weight
    exponent = (max_distance - distances).clamp_min(0)
    # Normalise by the largest possible weight (equal to max_distance**2 - 1 only for max_distance=4)
    targets_smoothed = (2**exponent - 1) / (2**max_distance - 1)
    return cross_entropy(input=inputs, target=targets_smoothed)


def _hier_levels(name: str) -> list[str]:
    """Split a class name in its '-'-separated levels, replacing a trailing 'w' or 's' by '.'."""
    # Names that only differ in a trailing 'w' or 's' end up with identical levels
    if name.endswith(("w", "s")):
        name = f"{name[:-1]}."
    return name.split("-")


def _hier_distance(x: str, y: str) -> int:
    """
    Compute hierarchical distance between 2 classes.

    The distance is the number of levels of ``x`` from the first level that differs
    from ``y`` onwards, hence 0 if all levels of ``x`` match. A level missing in ``y``
    counts as a difference. The distance is measured in the levels of ``x`` and is
    therefore not symmetric when ``x`` and ``y`` have a different number of levels.

    Parameters
    ----------
    x : str
        Reference class, levels separated by '-'
    y : str
        Class to compare against

    Returns
    -------
    int
        Hierarchical distance, between 0 and the number of levels of ``x``
    """
    x_lvls = _hier_levels(x)
    y_lvls = _hier_levels(y)
    for i, x_lvl in enumerate(x_lvls):
        # Guard the index: y may have fewer levels than x
        if i >= len(y_lvls) or x_lvl != y_lvls[i]:
            return len(x_lvls) - i
    return 0
