"""Set of loss functions."""

from __future__ import annotations

import torch

from vito_cropsar.models import LossConfig
from vito_cropsar.models.shared.losses import (
    MAELoss,
    MAETimeRegLoss,
    SSimLoss,
    SSimTimeRegLoss,
)


class VITOLoss:
    """
    Linear combination of multiple losses.

    Each enabled loss term (MAE, SSIM, MAE time regularization, SSIM time
    regularization) is computed and reported individually. The weighted sum of
    the terms is reported under the key ``"loss"``.
    """

    def __init__(
        self,
        cfg: LossConfig,
        device: str,
        n_channels: int,
    ) -> None:
        """
        Initialize loss function.

        Parameters
        ----------
        cfg : LossConfig
            Global loss configuration (which terms are enabled, and their weights)
        device : str
            The device on which the loss is calculated
        n_channels : int
            Number of channels in the prediction, to calculate the loss over
        """
        self.cfg = cfg
        self.device = device
        self.n_channels = n_channels
        self.mae_loss = MAELoss()
        self.ssim_loss = SSimLoss(device=self.device, n_channels=self.n_channels)
        self.mae_time_reg_loss = MAETimeRegLoss()
        self.ssim_time_reg_loss = SSimTimeRegLoss(
            device=self.device, n_channels=self.n_channels
        )

    def __call__(
        self,
        target: torch.Tensor,
        pred: torch.Tensor,
        mask: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Calculate the total (linearly aggregated) loss.

        Parameters
        ----------
        target : torch.Tensor
            Target
            Shape: (batch, time, channels, height, width)
        pred : torch.Tensor
            Prediction
            Shape: (batch, time, channels, height, width)
        mask : torch.Tensor
            Mask (cloud) indicating which part of the target was not present in the input
            Shape: (batch, time, height, width)
            Note: Value of 1 if visible (part of the input), 0 if not visible (cloud)

        Returns
        -------
        dict[str, torch.Tensor]
            Every individual loss term that was computed, as
            {<loss_name>: <loss_value>} pairs, plus the weighted total.
            Note: The linearly aggregated loss is located under the key "loss"
            Note: NaN terms are still reported but are excluded from "loss";
                  only terms with a strictly positive weight contribute to it
        """
        # Create the accumulator on the configured device, so the returned total
        # lives next to the loss terms even when no term ends up contributing.
        # requires_grad keeps ``losses["loss"].backward()`` valid in that case.
        losses: dict[str, torch.Tensor] = {
            "loss": torch.tensor(0.0, device=self.device, requires_grad=True)
        }

        # The order of evaluation matches the order in which terms are reported.
        loss_terms = (
            (self.cfg.use_mae, self.mae_loss),
            (self.cfg.use_ssim, self.ssim_loss),
            (self.cfg.use_mae_time_reg, self.mae_time_reg_loss),
            (self.cfg.use_ssim_time_reg, self.ssim_time_reg_loss),
        )
        for enabled, loss_fn in loss_terms:
            if enabled:
                self._accumulate(losses, loss_fn(target=target, pred=pred, mask=mask))
        return losses

    def _accumulate(
        self,
        losses: dict[str, torch.Tensor],
        terms: dict[str, torch.Tensor],
    ) -> None:
        """Add the weighted ``terms`` to ``losses["loss"]`` and record each term."""
        for name, value in terms.items():
            if torch.isnan(value):
                continue  # Reported below, but kept out of the backpropagated loss
            weight = self.cfg.weights[name]
            if weight > 0:
                losses["loss"] = losses["loss"] + weight * value
        losses.update(terms)

    def __str__(self) -> str:
        """Loss representation, listing the weight of every enabled term."""
        weights = self.cfg.weights
        parts = []
        if self.cfg.use_mae:
            parts.append(f"mae_masked={weights['mae_masked']:0.2f}")
            parts.append(f"mae_unmasked={weights['mae_unmasked']:0.2f}")
        if self.cfg.use_ssim:
            parts.append(f"ssim={weights['ssim']:0.2f}")
        if self.cfg.use_mae_time_reg:
            parts.append(f"mae_time_reg={weights['mae_time_regularization']:0.2f}")
        if self.cfg.use_ssim_time_reg:
            parts.append(f"ssim_time_reg={weights['ssim_time_regularization']:0.2f}")
        return "VITOLoss(" + ", ".join(parts) + ")"

    def __repr__(self) -> str:
        """Loss representation."""
        return str(self)
