"""Time regularization loss."""

from __future__ import annotations

import torch

from vito_cropsar.models.shared.losses.loss_base import LossBase


class MAETimeRegLoss(LossBase):
    """
    Time regularization loss.

    Penalizes abrupt temporal changes in the prediction: the loss is the mean absolute
    difference between every pair of consecutive time steps of the prediction.
    """

    def __call__(
        self,
        target: torch.Tensor,
        pred: torch.Tensor,
        mask: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Calculate the loss.

        Parameters
        ----------
        target : torch.Tensor
            Target
            Shape: (batch, time, channels, height, width)
            Note: Not used by this loss (presumably accepted to match the shared loss call signature)
        pred : torch.Tensor
            Prediction
            Shape: (batch, time, channels, height, width)
        mask : torch.Tensor
            Mask (cloud) indicating which part of the target was not present in the input
            Shape: (batch, time, height, width)
            Note: Value of 1 if visible (part of the input), 0 if not visible (cloud)
            Note: Not used by this loss, the regularization is applied to every pixel

        Returns
        -------
        dict[str, torch.Tensor]
            Computed regularization loss over the time dimension
            Note: This is a dictionary of {<loss_name>: <loss_value>} pairs
            Note: The loss is zero when there are no consecutive time steps to compare
        """
        # Absolute difference between each time step and its predecessor (dim 1 is time)
        diff = torch.abs(pred[:, 1:, :, :, :] - pred[:, :-1, :, :, :])

        # No consecutive pairs (single time step or empty batch): the mean of an empty tensor
        # is NaN, which would corrupt the total loss. The sum of an empty tensor is zero and
        # stays attached to the autograd graph, so return that instead.
        if diff.numel() == 0:
            return {"mae_time_regularization": diff.sum()}

        # Calculate the mean loss for each sample in the batch and aggregate over the batch
        return {"mae_time_regularization": diff.mean(dim=(1, 2, 3, 4)).mean()}
