"""Set of custom loss functions."""

from __future__ import annotations

from typing import Any, Callable

import torch


class VITOLoss:
    """Weighted combination of several loss functions for VITO.

    Each entry of the configuration selects one loss function, the model
    output it is applied to, the coefficient used when the individual losses
    are combined, and the tag under which its value is reported.
    """

    def __init__(
        self, loss_cfg: list[dict[str, Any]], reduction: str = "sum"
    ) -> None:
        """Initialize Custom VITO loss.

        Parameters
        ----------
        loss_cfg : list[dict[str, Any]]
            One configuration dictionary per loss. The keys read from each
            dictionary are ``"type"`` (resolved through ``get_loss``),
            ``"weights_multipliers"``, ``"loss_coefficient"``,
            ``"output_id"`` and ``"tag"``.
        reduction : str, optional
            Reduction to apply to the computed losses.
            It can be either "sum" or "average", by default "sum"

        Raises
        ------
        ValueError
            If ``reduction`` is neither "sum" nor "average".
        """
        # Explicit raise instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would let an invalid reduction through.
        if reduction not in {"sum", "average"}:
            raise ValueError(f"Reduction {reduction} not implemented.")
        self.losses = [get_loss(loss["type"]) for loss in loss_cfg]
        self.weights_multipliers = [loss["weights_multipliers"] for loss in loss_cfg]
        self.loss_coefficients = [loss["loss_coefficient"] for loss in loss_cfg]
        self.output_ids = [loss["output_id"] for loss in loss_cfg]
        self.tags = [loss["tag"] for loss in loss_cfg]
        self.reduction = reduction

    def __call__(
        self,
        prs: torch.Tensor | list[torch.Tensor],
        gt: torch.Tensor,
        ws: list[torch.Tensor],
    ) -> dict[str, Any]:
        """Compute loss function for each branch.

        Parameters
        ----------
        prs : torch.Tensor | list[torch.Tensor]
            Model output(s). A single tensor is treated as a one-element list.
        gt : torch.Tensor
            Ground truth handed to every loss.
        ws : list[torch.Tensor]
            Weights; ``ws[i]`` is handed to the i-th configured loss.

        Returns
        -------
        dict[str, Any]
            ``"losses"`` maps each tag to the value of its loss and ``"loss"``
            holds the combined value: the coefficient-weighted sum, divided by
            the sum of the coefficients when the reduction is "average".
        """
        # Deal with single-output models. ``isinstance`` also recognises
        # Tensor subclasses (e.g. ``torch.nn.Parameter``), which an exact type
        # comparison would wrongly index as if they were a list of outputs.
        if isinstance(prs, torch.Tensor):
            prs = [prs]

        values = [
            loss(prs[output_id], gt, ws[i])
            for i, (loss, output_id) in enumerate(zip(self.losses, self.output_ids))
        ]
        res: dict[str, Any] = {
            "losses": {f"{tag}": value for tag, value in zip(self.tags, values)}
        }

        # Pair values with coefficients by position rather than through the
        # tag dictionary, so repeated tags cannot shift the pairing.
        weighted_losses = torch.stack(
            [value * coeff for value, coeff in zip(values, self.loss_coefficients)]
        )
        if self.reduction == "sum":
            res["loss"] = torch.sum(weighted_losses)
        elif self.reduction == "average":
            res["loss"] = torch.sum(weighted_losses) / sum(self.loss_coefficients)
        return res

    def __str__(self) -> str:
        """Representation of the VITOLoss."""
        # Losses are plain functions, so the class name would always read
        # "function"; prefer the callable's own name and fall back to the
        # type name for callables that do not define ``__name__``.
        losses = [
            f"{getattr(ls, '__name__', type(ls).__name__)}_{w} -> branch {o}"
            for ls, w, o in zip(self.losses, self.weights_multipliers, self.output_ids)
        ]
        return f"VITOLoss: {','.join(losses)}"

    def __repr__(self) -> str:
        """Representation of the VITOLoss."""
        return str(self)


def dice_loss(
    pr: torch.Tensor, gt: torch.Tensor, w: torch.Tensor | None = None
) -> torch.Tensor:
    """Compute DiceLoss.

    Parameters
    ----------
    pr : torch.Tensor
        Predictions. Singleton axes are dropped so that the tensor lines up
        with ``gt``.
    gt : torch.Tensor
        Ground truth.
    w : torch.Tensor | None, optional
        Element-wise weights, by default None (every element weighs one).

    Returns
    -------
    torch.Tensor
        Scalar ``1 - dice``, where the dice score is smoothed by adding one to
        both numerator and denominator.
    """
    if w is None:
        # ones_like already allocates on gt's device.
        w = torch.ones_like(gt)

    # A bare squeeze() drops *every* size-1 axis (e.g. a batch of one or a
    # one-pixel-wide map), after which pr and gt may silently broadcast
    # against each other. When both hold the same number of elements, align
    # pr with gt's shape explicitly; otherwise keep the plain squeeze.
    pr_ = pr.reshape(gt.shape) if pr.numel() == gt.numel() else pr.squeeze()

    tp = (w * pr_ * gt).sum()
    fn = (w * gt).sum() - tp
    fp = (w * pr_).sum() - tp
    dice = (2 * tp + 1) / (2 * tp + fn + fp + 1)
    return 1.0 - dice


def binary_cross_entropy(
    pr: torch.Tensor, gt: torch.Tensor, w: torch.Tensor | None = None
) -> torch.Tensor:
    """Compute BinaryCrossEntropy loss.

    Parameters
    ----------
    pr : torch.Tensor
        Predictions, given as logits (the sigmoid is applied inside the
        loss). Singleton axes are dropped so that the tensor lines up with
        ``gt``.
    gt : torch.Tensor
        Ground truth; converted to float before the loss is evaluated.
    w : torch.Tensor | None, optional
        Element-wise weights, by default None (every element weighs one).

    Returns
    -------
    torch.Tensor
        Scalar loss: the weighted element-wise loss is averaged over axes 1
        and 2 and then over whatever axes remain.
    """
    if w is None:
        # ones_like already allocates on gt's device.
        w = torch.ones_like(gt)

    # A bare squeeze() also removes a batch axis of size one, which makes the
    # loss reject the prediction/target pair for mismatching sizes. When both
    # hold the same number of elements, align pr with gt's shape explicitly;
    # otherwise keep the plain squeeze.
    pr_ = pr.reshape(gt.shape) if pr.numel() == gt.numel() else pr.squeeze()
    gt_ = gt.float()

    per_element = torch.nn.functional.binary_cross_entropy_with_logits(
        pr_, gt_, reduction="none"
    )
    return (per_element * w).mean([1, 2]).mean()


# TODO: Use a custom threshold. Keep gradient computation in mind: thresholding
# is not differentiable, so no gradient flows back through it in backpropagation.
def classification_loss(
    pr: torch.Tensor, gt: torch.Tensor, w: torch.Tensor | None = None
) -> torch.Tensor:
    """Compute Classification loss -> one minus the weighted share of correct pixels.

    Parameters
    ----------
    pr : torch.Tensor
        Predictions; rounded to the nearest integer before being compared
        with ``gt``. Singleton axes are dropped so that the tensor lines up
        with ``gt``.
    gt : torch.Tensor
        Ground truth.
    w : torch.Tensor | None, optional
        Element-wise weights, by default None (every element weighs one).

    Returns
    -------
    torch.Tensor
        Scalar ``1 - sum(w * correct) / sum(w)``. The equality comparison
        yields a boolean tensor, so the value carries no gradient with
        respect to ``pr``.
    """
    if w is None:
        # ones_like already allocates on gt's device.
        w = torch.ones_like(gt)

    # A bare squeeze() drops *every* size-1 axis (e.g. a batch of one or a
    # one-pixel-wide map), after which pr and gt may silently broadcast
    # against each other. When both hold the same number of elements, align
    # pr with gt's shape explicitly; otherwise keep the plain squeeze.
    aligned = pr.reshape(gt.shape) if pr.numel() == gt.numel() else pr.squeeze()
    pr_ = aligned.round()
    return 1.0 - (w * (pr_ == gt)).sum() / w.sum()


def get_loss(name: str) -> Callable:
    """Retrieve the loss function from its name.

    Parameters
    ----------
    name : str
        Name of a callable defined at module level in this file.

    Returns
    -------
    Callable
        The module-level object registered under ``name``.

    Raises
    ------
    KeyError
        If the module defines nothing called ``name``.
    TypeError
        If ``name`` refers to a module-level object that is not callable
        (for instance an imported module) and so cannot be used as a loss.
    """
    try:
        loss = globals()[name]
    except KeyError:
        raise KeyError(f"Unknown loss function '{name}'.") from None
    # The module namespace also holds imports and other non-function objects;
    # fail here rather than later, when the loss would be called.
    if not callable(loss):
        raise TypeError(f"'{name}' is not a callable loss function.")
    return loss
