"""Mean absolute error (MAE) loss."""

from __future__ import annotations

import torch

from vito_cropsar.models.shared.losses.loss_base import LossBase
from vito_cropsar.models.shared.metrics import get_mask_errors


class MAELoss(LossBase):
    """Mean absolute error (MAE) loss, reported separately for unmasked and masked regions."""

    @staticmethod
    def _sample_means(
        t: torch.Tensor,
        p: torch.Tensor,
        m: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (unmasked, masked) mean error for a single sample."""
        err_unmasked, err_masked = get_mask_errors(t=t, p=p, m=m)
        return err_unmasked.mean(), err_masked.mean()

    def __call__(
        self,
        target: torch.Tensor,
        pred: torch.Tensor,
        mask: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Compute the MAE over the unmasked and masked regions of a batch.

        Each sample is reduced on its own first: its errors are split into an
        unmasked and a masked part by ``get_mask_errors`` and each part is
        averaged. The per-sample averages are then combined across the batch
        with ``nanmean``, so a sample whose average is NaN is left out of that
        region's result.

        NOTE(review): this assumes ``get_mask_errors`` returns absolute errors,
        and may return an empty tensor for a region with no pixels (in which
        case ``.mean()`` yields NaN). Confirm against its implementation.

        Parameters
        ----------
        target : torch.Tensor
            Target
            Shape: (batch, time, channels, height, width)
        pred : torch.Tensor
            Prediction
            Shape: (batch, time, channels, height, width)
        mask : torch.Tensor
            Mask (cloud) indicating which part of the target was not present in the input
            Shape: (batch, time, height, width)
            Note: Value of 1 if visible (part of the input), 0 if not visible (cloud)

        Returns
        -------
        dict[str, torch.Tensor]
            ``{"mae_unmasked": ..., "mae_masked": ...}``, each entry being the
            NaN-ignoring mean of the per-sample averages for that region.
        """
        # Only the two scalar means per sample are kept; the full per-sample
        # error tensors go out of scope as soon as each helper call returns.
        per_sample = [
            self._sample_means(sample_t, sample_p, sample_m)
            for sample_t, sample_p, sample_m in zip(target, pred, mask)
        ]

        unmasked_means = torch.stack([unm for unm, _ in per_sample])
        masked_means = torch.stack([msk for _, msk in per_sample])

        return {
            "mae_unmasked": unmasked_means.nanmean(),
            "mae_masked": masked_means.nanmean(),
        }
