"""Base loss class."""

from __future__ import annotations

import torch


class LossBase:
    """
    Base loss class.

    Subclasses implement ``__call__`` to compute one or more named loss terms
    from a target, a prediction and a visibility (cloud) mask.
    """

    def __call__(
        self,
        target: torch.Tensor,
        pred: torch.Tensor,
        mask: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Calculate the loss.

        Parameters
        ----------
        target : torch.Tensor
            Target
            Shape: (batch, time, channels, height, width)
        pred : torch.Tensor
            Prediction
            Shape: (batch, time, channels, height, width)
        mask : torch.Tensor
            Visibility (cloud) mask of the input
            Shape: (batch, time, height, width)
            Note: Value of 1 where the pixel was visible (part of the input),
            0 where it was not visible (cloud), i.e. the zeros mark the parts
            of the target that were not present in the input.
            Note: The mask has no channel axis, so it has to be broadcast over
            ``channels`` (e.g. ``mask.unsqueeze(2)``) before being combined
            element-wise with ``target`` / ``pred``.

        Returns
        -------
        dict[str, torch.Tensor]
            Computed loss
            Note: This is a dictionary of {<loss_name>: <loss_value>} pairs

        Raises
        ------
        NotImplementedError
            Always, on the base class; concrete losses must override this method.
        """
        # Name the concrete class so a subclass that forgot to override
        # __call__ is easy to identify from the traceback.
        raise NotImplementedError(f"{type(self).__name__} must implement __call__")
