"""SSim time regularization loss."""

from __future__ import annotations

import torch

from vito_cropsar.models.shared.losses.loss_ssim import SSimLoss


class SSimTimeRegLoss(SSimLoss):
    """
    SSim time regularization loss.

    Regularizes the prediction over time: the SSim is computed between every pair of
    consecutive time steps of the prediction, (t, t+1), and the loss is derived from the
    mean of those SSim values. The target and the mask are not used; the loss only
    depends on the prediction.
    """

    def __init__(
        self,
        device: str,
        data_range: float = 2.0,
        data_min: float = -1.0,
        win_size: int = 11,
        win_sigma: float = 1.5,
        n_channels: int = 1,
        spatial_dims: int = 2,
        k: tuple[float, float] = (0.01, 0.03),
    ) -> None:
        """
        Initialize the loss function.

        All arguments are forwarded unchanged to ``SSimLoss.__init__``.

        Parameters
        ----------
        device : str
            Device to use for the loss calculation
        data_range : float
            width of the value range of input images (e.g. 2.0 for data in [-1, 1], 1.0 or 255 otherwise)
        data_min : float
            minimum value of input images (usually 0.0 or -1.0)
        win_size : int
            the size of gauss kernel
        win_sigma : float
            sigma of normal distribution
        n_channels : int
            Number of input channels
        spatial_dims : int
            number of spatial dimensions
        k : tuple[float, float]
            scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
        """
        super().__init__(
            device=device,
            data_range=data_range,
            data_min=data_min,
            win_size=win_size,
            win_sigma=win_sigma,
            n_channels=n_channels,
            spatial_dims=spatial_dims,
            k=k,
        )

    def __call__(
        self,
        target: torch.Tensor,
        pred: torch.Tensor,
        mask: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Calculate the loss.

        Parameters
        ----------
        target : torch.Tensor
            Target (unused; kept for interface compatibility with the other losses)
            Shape: (batch, time, channels, height, width)
        pred : torch.Tensor
            Prediction
            Shape: (batch, time, channels, height, width)
        mask : torch.Tensor
            Mask (cloud) indicating which part of the target was not present in the input
            (unused; kept for interface compatibility with the other losses)
            Shape: (batch, time, height, width)
            Note: Value of 1 if visible (part of the input), 0 if not visible (cloud)

        Returns
        -------
        dict[str, torch.Tensor]
            Computed regularization loss over the time dimension, equal to
            ``-data_min - mean_ssim`` (i.e. ``1 - mean_ssim`` for the default data_min of -1.0)
            Note: This is a dictionary of {<loss_name>: <loss_value>} pairs

        Raises
        ------
        ValueError
            If ``pred`` has fewer than two time steps, so there is no pair of
            consecutive time steps to compare.
        """
        n_steps = pred.shape[1]
        if n_steps < 2:
            raise ValueError(
                f"SSimTimeRegLoss needs at least 2 time steps, got {n_steps}"
            )

        # Shift values so their range starts at 0; the subtraction allocates a new tensor,
        # so the caller's ``pred`` is never modified (no explicit clone needed)
        shifted = pred - self.data_min  # (B, T, CH, H, W)

        # One (B, CH, H, W) view per time step
        steps = shifted.unbind(dim=1)

        # SSim between each time step (used as reference) and the following one
        ssims = [
            self.compute_ssim(
                target=earlier,
                pred=later,
                data_range=self.data_range,
                win=self.win,
                k=self.k,
            )
            for earlier, later in zip(steps[:-1], steps[1:])
        ]

        # Stack the per-pair results along a new leading (time-pair) axis, average over
        # dim 1 (presumably the batch axis of compute_ssim's output — TODO confirm),
        # then average everything that remains into a scalar
        mean_ssim = torch.stack(ssims).mean(dim=1).mean()
        return {"ssim_time_regularization": (-self.data_min) - mean_ssim}
