"""Structural Similarity calculation."""

from __future__ import annotations

import numpy as np
from numpy.typing import NDArray
from skimage.metrics import structural_similarity


def get_ssim_score(
    pred: NDArray[np.float_],
    target: NDArray[np.float_],
    mask: NDArray[np.int_],
    *,
    min_mask_fraction: float = 0.1,
) -> list[float]:
    """
    Calculate the structural similarity score.

    Each time step is scored with ``masked_ssim``; time steps whose mean
    mask value is below ``min_mask_fraction`` are skipped. The per-channel
    scores of the remaining time steps are then averaged.

    Parameters
    ----------
    pred : NDArray[np.float_]
        The predicted image
        Shape: (time, channel, width, height)
        Note: Assumed to be within the 0..1 range
    target : NDArray[np.float_]
        The target image
        Shape: (time, channel, width, height)
        Note: Assumed to be within the 0..1 range
    mask : NDArray[np.int_]
        The original mask
        Shape: (time, width, height)
    min_mask_fraction : float, optional
        Minimum mean mask value a time step must have to be scored.
        Defaults to 0.1.

    Returns
    -------
    list[float]
        The Structural Similarity (SSIM) score, one value per channel
        Note: Perfect similarity is 1.0

    Raises
    ------
    ValueError
        If no time step reaches ``min_mask_fraction``, so there is
        nothing to average.
    """
    coverage = mask.mean(axis=(1, 2))
    per_frame = [
        masked_ssim(
            pred=pred[i],
            target=target[i],
            mask=mask[i],
        )
        for i, fraction in enumerate(coverage)
        if fraction >= min_mask_fraction
    ]
    if not per_frame:
        # np.stack([]) would raise an opaque "need at least one array to
        # stack"; keep the exception type but say what actually went wrong.
        msg = (
            "Cannot compute SSIM: no time step has a valid-mask fraction "
            f">= {min_mask_fraction}"
        )
        raise ValueError(msg)
    # Average over the time axis, leaving one score per channel.
    return [float(x) for x in np.stack(per_frame).mean(axis=0)]


def masked_ssim(
    pred: NDArray[np.float_],
    target: NDArray[np.float_],
    mask: NDArray[np.int_],
) -> NDArray[np.float_]:
    """
    Calculate the masked structural similarity score.

    Parameters
    ----------
    pred : NDArray[np.float_]
        The predicted image
        Shape: (channel, width, height)
        Note: Assumed to be within the 0..1 range
    target : NDArray[np.float_]
        The target image
        Shape: (channel, width, height)
        Note: Assumed to be within the 0..1 range
    mask : NDArray[np.int_]
        The original mask; non-zero (or True) marks pixels to score
        Shape: (width, height)

    Returns
    -------
    NDArray[np.float_]
        The mean SSIM over the masked-in pixels, one value per channel
        Note: A channel with no masked-in pixel yields NaN

    Notes
    -----
    The inputs are not modified; NaN handling is done on copies.
    """
    # Pixels that are NaN in the target are zeroed in both images so they do
    # not poison the SSIM windows. Work on copies: the caller typically passes
    # views into larger arrays, which must not be overwritten.
    target_nan = np.isnan(target)
    pred = pred.copy()
    target = target.copy()
    pred[target_nan] = 0.0
    target[target_nan] = 0.0

    _, ssim_im = structural_similarity(
        target,
        pred,
        win_size=11,
        data_range=1,
        channel_axis=0,
        full=True,
    )

    # Convert to boolean before inverting: `~` on an integer mask is a bitwise
    # NOT (0 -> -1, 1 -> -2), which would index channels instead of pixels.
    valid = np.asarray(mask, dtype=bool)
    ssim_im[:, ~valid] = np.nan
    return np.nanmean(ssim_im, axis=(1, 2))
