"""Calculate the evalution scores."""

from __future__ import annotations

import numpy as np
from numpy.typing import NDArray

from vito_cropsar.constants import PRECISION_FLOAT_NP, PRECISION_INT_NP


def calculate_results(
    target: NDArray[PRECISION_FLOAT_NP],
    pred: NDArray[PRECISION_FLOAT_NP],
    mask: NDArray[PRECISION_INT_NP],
    acc_thr: float = 0.05,
) -> dict[str, float]:
    """
    Summarise the prediction quality as a small set of scalar metrics.

    Notes
    -----
     - Model inputs are expected to be scaled from 0..1
     - A metric computed over an empty set of pixels is reported as 0.0

    Parameters
    ----------
    target : NDArray[PRECISION_FLOAT_NP]
        Target inpainting (partly obscured)
        Shape: (time, channels, width, height)
    pred : NDArray[PRECISION_FLOAT_NP]
        Predicted inpainting
        Shape: (time, channels, width, height)
    mask : NDArray[PRECISION_INT_NP]
        Applied mask (real + generated)
        Shape: (time, width, height)
    acc_thr : float
        Accuracy threshold; an absolute error counts as correct if it is
        smaller than or equal to this value
        Note: This comes down to 5%, since the input is scaled from 0..1

    Returns
    -------
    result : dict[str, float]
        Generated evaluation metrics, keyed by:
         - mae_unmasked : mean absolute error over the unmasked errors
         - mae_masked : mean absolute error over the masked errors
         - acc_unmasked : fraction of unmasked errors <= acc_thr
         - acc_masked : fraction of masked errors <= acc_thr
    """

    def _mean_abs_error(errors: NDArray[PRECISION_FLOAT_NP]) -> float:
        """Return the mean of the errors, or 0.0 when there are none."""
        if errors.size == 0:
            return 0.0
        return float(errors.mean())

    def _accuracy(errors: NDArray[PRECISION_FLOAT_NP]) -> float:
        """Return the fraction of errors that stay within the threshold."""
        n_correct = (errors <= acc_thr).sum()
        # Clamp the denominator to one so an empty selection yields 0.0
        return float(n_correct / max(errors.size, 1))

    # Split the absolute errors in an unmasked and a masked selection
    unmasked_err, masked_err = _get_errors(target=target, pred=pred, mask=mask)
    unmasked_err = unmasked_err.reshape(-1)
    masked_err = masked_err.reshape(-1)

    # Calculate the metrics and return
    return {
        "mae_unmasked": _mean_abs_error(unmasked_err),
        "mae_masked": _mean_abs_error(masked_err),
        "acc_unmasked": _accuracy(unmasked_err),
        "acc_masked": _accuracy(masked_err),
    }


def _get_errors(
    target: NDArray[PRECISION_FLOAT_NP],
    pred: NDArray[PRECISION_FLOAT_NP],
    mask: NDArray[PRECISION_INT_NP],
) -> tuple[NDArray[PRECISION_FLOAT_NP], NDArray[PRECISION_FLOAT_NP]]:
    """
    Split the absolute prediction error in an unmasked and a masked part.

    Positions where the target is NaN are left out of both parts.

    Returns
    -------
    unmasked : NDArray[PRECISION_FLOAT_NP]
        Flat array of absolute errors where the target exists and mask == 1
    masked : NDArray[PRECISION_FLOAT_NP]
        Flat array of absolute errors where the target exists and mask == 0
    """
    abs_error = np.abs(target - pred)

    # Give the mask an extra axis at position 1 and copy it along that axis
    # until it matches the second axis of the error (presumably the channels)
    n_repeats = abs_error.shape[1]
    expanded_mask = np.repeat(mask[:, np.newaxis], n_repeats, axis=1)

    # Only positions with an actual (non-NaN) target are evaluated
    has_target = np.logical_not(np.isnan(target))

    # Error of the unmasked area (target exists and not masked out)
    unmasked = abs_error[has_target & (expanded_mask == 1)]

    # Error of the masked area (target exists but masked out)
    masked = abs_error[has_target & (expanded_mask == 0)]

    return unmasked, masked
