"""Structural Similarity (SSim) loss."""

from __future__ import annotations

import warnings

import torch
import torch.nn.functional as torch_f

from vito_cropsar.models.shared.losses.loss_base import LossBase


class SSimLoss(LossBase):
    """
    Structural Similarity (SSim) loss.

    The loss value is ``1 - SSIM``, where the SSIM is computed per time step with a Gaussian window and then
    averaged over batch and time.

    https://github.com/VainF/pytorch-msssim
    """

    def __init__(
        self,
        device: str,
        data_range: float = 2.0,
        data_min: float = -1.0,
        win_size: int = 11,
        win_sigma: float = 1.5,
        n_channels: int = 1,
        spatial_dims: int = 2,
        k: tuple[float, float] = (0.01, 0.03),
    ) -> None:
        """
        Initialize the loss function.

        Parameters
        ----------
        device : str
            Device to use for the loss calculation
        data_range : float
            value range of input images (usually 1.0 or 255)
        data_min : float
            minimum value of input images (usually 0.0 or -1.0), subtracted from target and prediction
            before the SSIM is computed
        win_size : int
            the size of gauss kernel
        win_sigma : float
            sigma of normal distribution
        n_channels : int
            Number of input channels
        spatial_dims : int
            number of spatial dimensions
        k : tuple[float, float]
            scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
        """
        super().__init__()
        self.win_size = win_size
        # one copy of the 1-D kernel per channel, with a singleton axis for every spatial dimension but the last
        self.win = (
            _fspecial_gauss_1d(win_size, win_sigma)
            .repeat([n_channels, 1] + [1] * spatial_dims)
            .to(device)
        )
        self.data_range = data_range
        self.data_min = data_min
        self.k = k

    def __call__(
        self,
        target: torch.Tensor,
        pred: torch.Tensor,
        mask: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Calculate the loss.

        Parameters
        ----------
        target : torch.Tensor
            Target
            Shape: (batch, time, channels, height, width)
        pred : torch.Tensor
            Prediction
            Shape: (batch, time, channels, height, width)
        mask : torch.Tensor
            Mask (cloud) indicating which part of the target was not present in the input
            Shape: (batch, time, height, width)
            Note: Value of 1 if visible (part of the input), 0 if not visible (cloud)
            Note: This argument is currently not used by the computation

        Returns
        -------
        dict[str, torch.Tensor]
            Computed SSim loss (1 - mean SSIM)
            Note: This is a dictionary of {<loss_name>: <loss_value>} pairs
        """
        # prepare target and pred: shift both so that data_min maps to 0
        # (the subtraction already returns new tensors, so the inputs are never modified in place)
        t = target - self.data_min  # (B, T, CH, H, W)
        p = pred - self.data_min  # (B, T, CH, H, W)
        nan_mask = t.isnan()
        t[nan_mask] = p[nan_mask]  # set t and p equal where t is nan
        # TODO: Setting target as the prediction is odd IMO --> investigate the implication of this!

        # compute ssim time step wise
        ssims = [
            self.compute_ssim(
                target=t[:, i],
                pred=p[:, i],
                data_range=self.data_range,
                win=self.win,
                k=self.k,
            )
            for i in range(t.shape[1])
        ]

        # (T, B) -> mean over the batch -> mean over time
        ssim = torch.stack(ssims).mean(dim=1).mean()

        # SSIM equals 1 for identical images, so the loss is 1 - SSIM regardless of the configured data_min
        return {"ssim": 1.0 - ssim}

    def compute_ssim(
        self,
        target: torch.Tensor,
        pred: torch.Tensor,
        win: torch.Tensor,
        data_range: float = 2.0,
        k: tuple[float, float] = (0.01, 0.03),
    ) -> torch.Tensor:
        """
        SSIM calculation on a batch of images over a single timestamp.

        Parameters
        ----------
        target : torch.Tensor
            batch of target images (N,C,H,W)
        pred : torch.Tensor
            batch of predicted images (N,C,H,W)
        win : torch.Tensor
            1-D gaussian tensor filter
        data_range : float
            value range of input images. (usually 1.0 or 255)
        k : tuple[float, float]
            scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.

        Returns
        -------
        torch.Tensor
            ssim results, one value per batch element (averaged over channels, negative values clipped to 0)
        """
        assert win.shape[-1] % 2 == 1, "Window size should be odd."
        assert (
            target.size() == pred.size()
        ), f"Input images should have the same dimensions, but got {target.shape} and {pred.shape}."

        # prepare gaussian window
        win = win.to(target.device, dtype=pred.dtype)

        # stabilising constants of the SSIM formula
        K1, K2 = k  # noqa: N806
        C1 = (K1 * data_range) ** 2  # noqa: N806
        C2 = (K2 * data_range) ** 2  # noqa: N806

        # local (gaussian weighted) means
        mu1 = _apply_gaussian_filter(target, win)
        mu2 = _apply_gaussian_filter(pred, win)
        mu1_sq = mu1**2
        mu2_sq = mu2**2
        mu1_mu2 = mu1 * mu2

        # local variances and covariance: E[x*y] - E[x]*E[y]
        sigma1_sq = _apply_gaussian_filter(target * target, win) - mu1_sq
        sigma2_sq = _apply_gaussian_filter(pred * pred, win) - mu2_sq
        sigma12 = _apply_gaussian_filter(target * pred, win) - mu1_mu2

        # contrast-structure term, then the full ssim map (luminance term * contrast-structure term)
        cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
        ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
        ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)

        # force positivity of ssim
        ssim_per_channel = torch.relu(ssim_per_channel)
        return ssim_per_channel.mean(dim=1)


def _apply_gaussian_filter(x: torch.Tensor, win: torch.Tensor) -> torch.Tensor:
    """
    Blur x with 1-D kernel, applied along every spatial dimension in turn.

    No padding is used, so each blurred dimension shrinks by ``win.shape[-1] - 1``. A spatial dimension that is
    smaller than the kernel is left unblurred and a warning is emitted instead.

    Parameters
    ----------
    x : torch.Tensor
        a batch of tensors to be blurred, (N,C,H,W) or (N,C,D,H,W)
    win : torch.Tensor
        1-D gauss kernel. All dimensions between the first and the last one must have size 1.

    Returns
    -------
    torch.Tensor
        blurred tensors
    """
    assert all(ws == 1 for ws in win.shape[1:-1]), win.shape
    # inputs with three spatial dimensions need a 3-D convolution; every other rank keeps using conv2d
    conv = torch_f.conv3d if x.dim() == 5 else torch_f.conv2d
    channels = x.shape[1]
    out = x
    for i, s in enumerate(x.shape[2:]):
        if s >= win.shape[-1]:
            out = conv(
                out,
                weight=win.transpose(2 + i, -1),  # move the kernel taps onto spatial dimension i
                stride=1,
                padding=0,
                groups=channels,  # one group per channel: channels are blurred independently
            )
        else:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {x.shape} and win size: {win.shape[-1]}",
                stacklevel=1,
            )

    return out


def _fspecial_gauss_1d(size: int, sigma: float) -> torch.Tensor:
    """
    Create 1-D gauss kernel.

    Parameters
    ----------
    size : int
        number of taps of the kernel
    sigma : float
        sigma of normal distribution

    Returns
    -------
    torch.Tensor
        kernel of shape (1, 1, size) whose taps sum to 1
    """
    # tap positions relative to the centre tap (index size // 2)
    offsets = torch.arange(size, dtype=torch.float) - size // 2
    weights = torch.exp(-offsets.pow(2) / (2 * sigma**2))
    normalised = weights / weights.sum()
    return normalised[None, None, :]
