"""Artifact calculation."""

from __future__ import annotations

import cv2
import numpy as np
from numpy.typing import NDArray


def get_artifacts_score(
    pred: NDArray[np.float64],
    target: NDArray[np.float64],
    mask: NDArray[np.int_],
    mask_agg: NDArray[np.int_],
) -> float:
    """
    Calculate the artifacts score.

    Only time-steps whose aggregated mask contains at least one non-zero pixel
    are evaluated. For each of them the channel-averaged artifact sum and the
    channel-averaged focus size are computed (see ``get_artifacts_score_ts``).
    The result is the total artifact sum divided by the total focus size,
    with the denominator floored at 1 to avoid division by zero.

    Parameters
    ----------
    pred : NDArray[np.float64]
        The predicted image
        Shape: (time, channel, width, height)
    target : NDArray[np.float64]
        The target image
        Shape: (time, channel, width, height)
    mask : NDArray[np.int_]
        The original mask
        Shape: (time, width, height)
    mask_agg : NDArray[np.int_]
        The aggregated mask
        Shape: (time, width, height)

    Returns
    -------
    float
        The artifacts score, normalized by the number of focus (mask-edge)
        pixels. ``0.0`` when no time-step has a non-empty aggregated mask.
    """
    # Only consider the time-steps where the aggregated mask is non-empty
    has_mask = np.flatnonzero(mask_agg.any(axis=(1, 2)))

    # Per time-step (artifact sum, focus size), both averaged over channels
    scores = [
        get_artifacts_score_ts(
            pred=pred[ts], target=target[ts], mask=mask[ts], mask_agg=mask_agg[ts]
        )
        for ts in has_mask
    ]

    # Nothing to score: unpacking an empty ``zip`` below would raise ValueError
    if not scores:
        return 0.0

    artifacts, focus = zip(*scores)
    return float(sum(artifacts) / max(sum(focus), 1))


def get_artifacts_score_ts(
    pred: NDArray[np.float64],
    target: NDArray[np.float64],
    mask: NDArray[np.int_],
    mask_agg: NDArray[np.int_],
) -> tuple[float, float]:
    """
    Calculate the artifacts score for a single time-step, averaged over channels.

    Parameters
    ----------
    pred : NDArray[np.float64]
        The predicted image of one time-step
        Shape: (channel, width, height)
    target : NDArray[np.float64]
        The target image of one time-step
        Shape: (channel, width, height)
    mask : NDArray[np.int_]
        The original mask of one time-step
        Shape: (width, height)
    mask_agg : NDArray[np.int_]
        The aggregated mask of one time-step
        Shape: (width, height)

    Returns
    -------
    tuple[float, float]
        The channel-average pixel-based artifact score
        The channel-average number of pixels in the focus
        Both are ``0.0`` when ``pred`` has no channels.
    """
    assert len(pred.shape) == len(target.shape) == 3  # noqa: PLR2004
    assert len(mask.shape) == len(mask_agg.shape) == 2  # noqa: PLR2004

    n_channels = pred.shape[0]
    # Guard against an empty channel axis: the mean would be undefined and
    # unpacking an empty ``zip`` would raise ValueError.
    if n_channels == 0:
        return 0.0, 0.0

    # The same masks (and hence the same focus) are used for every channel
    scores = [
        get_artifacts_score_ts_ch(
            pred=pred[ch], target=target[ch], mask=mask, mask_agg=mask_agg
        )
        for ch in range(n_channels)
    ]

    # Return the mean values for artifacts and focus
    artifacts, focus = zip(*scores)
    return float(sum(artifacts) / n_channels), float(sum(focus) / n_channels)


def get_artifacts_score_ts_ch(
    pred: NDArray[np.float64],
    target: NDArray[np.float64],
    mask: NDArray[np.int_],
    mask_agg: NDArray[np.int_],
) -> tuple[float, float]:
    """
    Calculate the artifacts score for a specified time-step and channel.

    An artifact is an edge that is stronger in the prediction than in the
    target. Only pixels in the focus (see ``get_focus``) are counted.

    Parameters
    ----------
    pred : NDArray[np.float64]
        The predicted image of one time-step and channel
        Shape: (width, height)
    target : NDArray[np.float64]
        The target image of one time-step and channel
        Shape: (width, height)
    mask : NDArray[np.int_]
        The original mask of one time-step
        Shape: (width, height)
    mask_agg : NDArray[np.int_]
        The aggregated mask of one time-step
        Shape: (width, height)

    Returns
    -------
    tuple[float, float]
        The pixel-based artifact score (sum of excess edge strength in the focus)
        The number of pixels in the focus
    """
    assert len(pred.shape) == len(target.shape) == 2  # noqa: PLR2004
    assert len(mask.shape) == len(mask_agg.shape) == 2  # noqa: PLR2004

    # Graded edge maps of prediction and target
    edges_pred = get_edges(pred)
    edges_target = get_edges(target)

    # Only edges that are stronger in the prediction count; edges that the
    # prediction is missing (negative difference) are ignored.
    edges_diff = np.maximum(edges_pred - edges_target, 0.0)

    # Binary map of the pixels the score is restricted to
    focus = get_focus(mask=mask, mask_agg=mask_agg)

    return float(np.sum(edges_diff * focus)), float(np.sum(focus))


def get_edges(x: NDArray[np.float64]) -> NDArray[np.float64]:
    """
    Get the graded Canny edges of a single 2-D image.

    Canny is run with five increasing hysteresis threshold pairs
    ``(0, 100), (100, 200), ..., (400, 500)``. A pixel gets the value
    ``(i + 1) / 5`` of the last pass ``i`` in which it was detected as an
    edge, and ``0`` if it was never detected.

    Parameters
    ----------
    x : NDArray[np.float64]
        The image. It is scaled by 255 before edge detection, so values are
        presumably expected in ``[0, 1]``. NaNs are treated as 0 and values
        outside ``[0, 1]`` are clipped.

    Returns
    -------
    NDArray[np.float64]
        Edge map with the same shape as ``x``, with values in
        ``{0, 0.2, 0.4, 0.6, 0.8, 1.0}``.
    """
    # Replace NaNs and clip before the uint8 cast. Casting out-of-range floats
    # to uint8 gives implementation-defined values (typically wrap-around,
    # e.g. 1.01 -> 257 -> 1), which would create spurious edges.
    img = np.where(np.isnan(x), 0.0, x)
    img = (255 * np.clip(img, 0.0, 1.0)).astype(np.uint8)

    # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed
    edges = np.zeros_like(img, dtype=np.float64)
    for i in range(5):
        canny = cv2.Canny(
            image=img,
            threshold1=i * 100,
            threshold2=(i + 1) * 100,
        )
        # Later (stricter) passes overwrite earlier ones
        edges[canny > 0] = (i + 1) / 5
    return edges


def get_focus(mask: NDArray[np.int_], mask_agg: NDArray[np.int_]) -> NDArray[np.int_]:
    """
    Get the focus boundary.

    The focus consists of the pixels that lie on the boundary of the
    aggregated mask but not on the boundary of the original mask. Boundaries
    are Canny edges computed with both thresholds set to 0.

    Parameters
    ----------
    mask : NDArray[np.int_]
        The original mask, 2-D
    mask_agg : NDArray[np.int_]
        The aggregated mask, 2-D

    Returns
    -------
    NDArray[np.int_]
        ``uint8`` map with the masks' shape: 1 for focus pixels, 0 elsewhere.
    """
    assert len(mask.shape) == len(mask_agg.shape) == 2  # noqa: PLR2004

    boundary, boundary_agg = (
        cv2.Canny(m.astype(np.uint8), threshold1=0, threshold2=0)
        for m in (mask, mask_agg)
    )

    # Canny edge maps contain only 0 and 255: keep the aggregated-boundary
    # pixels that are absent from the original boundary.
    in_focus = (boundary_agg == 255) & (boundary == 0)  # noqa: PLR2004

    # The focus is not enlarged here; it could be widened with e.g.
    # cv2.dilate(focus, np.ones((2, 2), np.uint8), iterations=1)
    return in_focus.astype(np.uint8)
