"""Custom U-Net 2D loss functions."""

from __future__ import annotations

import torch

from vito_cropsar.constants import PRECISION_FLOAT


@torch.no_grad()
def calculate_metrics(
    target: torch.Tensor,
    pred: torch.Tensor,
    mask: torch.Tensor,
    acc_thr: float = 0.1,
) -> dict[str, torch.Tensor]:
    """
    Calculate all the desired metrics.

    Metrics are first computed per sample and then averaged over the batch.
    A sample without any valid (non-NaN target) pixel in the unmasked or the
    masked region produces NaN for the corresponding metrics; such samples are
    skipped during aggregation (``torch.nanmean``) so that one sample without,
    e.g., masked pixels does not turn the whole batch metric into NaN. If no
    sample contributes to a metric, that metric is NaN.

    Note that, as a consequence, a sample whose MAE is NaN because of NaN
    predictions is skipped as well rather than poisoning the batch average.

    Parameters
    ----------
    target : torch.Tensor
        Target tensor
        Shape: (batch, time, bands, width, height)
    pred : torch.Tensor
        Predicted tensor
        Shape: (batch, time, bands, width, height)
    mask : torch.Tensor
        Mask tensor
        Shape: (batch, time, width, height)
    acc_thr : float
        Accuracy threshold (correct if error <= threshold)

    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary mapping "mae_unmasked", "mae_masked", "acc_unmasked" and
        "acc_masked" to a 0-dim tensor holding the batch-averaged metric
    """
    # Calculate sample-wise metrics
    results: dict[str, list[torch.Tensor]] = {
        k: [] for k in ("mae_unmasked", "mae_masked", "acc_unmasked", "acc_masked")
    }
    for t, p, m in zip(target, pred, mask):
        for k, v in _metrics_sample(t=t, p=p, m=m, acc_thr=acc_thr).items():
            results[k].append(v)

    # Return aggregated results; nanmean ignores samples whose region was empty
    return {k: torch.nanmean(torch.stack(v)) for k, v in results.items()}


def _metrics_sample(
    t: torch.Tensor,
    p: torch.Tensor,
    m: torch.Tensor,
    acc_thr: float = 0.1,
) -> dict[str, PRECISION_FLOAT]:
    """
    Compute MAE and threshold accuracy for a single sample.

    Both metrics are evaluated separately on the unmasked and the masked
    region as returned by ``get_mask_errors``. A region without any pixels
    yields NaN for its metrics (mean of an empty tensor, 0 / 0 accuracy).
    """
    errors_unmasked, errors_masked = get_mask_errors(t=t, p=p, m=m)

    def _within_threshold(errors: torch.Tensor) -> torch.Tensor:
        # Share of errors that do not exceed the accuracy threshold
        n_hits = (errors <= acc_thr).sum()
        return n_hits / errors.numel()

    metrics = {}
    metrics["mae_unmasked"] = errors_unmasked.mean()
    metrics["mae_masked"] = errors_masked.mean()
    metrics["acc_unmasked"] = _within_threshold(errors_unmasked)
    metrics["acc_masked"] = _within_threshold(errors_masked)
    return metrics


def get_mask_errors(
    t: torch.Tensor,
    p: torch.Tensor,
    m: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate the unmasked and masked errors for the provided sample.

    Only pixels with a valid (non-NaN) target are considered. A pixel belongs
    to the unmasked region where ``m == 1`` and to the masked region where
    ``m == 0``; pixels with any other mask value are in neither region. The
    same mask is applied to every band.

    Parameters
    ----------
    t : torch.Tensor
        Target tensor
        Shape: (time, bands, width, height)
    p : torch.Tensor
        Predicted tensor
        Shape: (time, bands, width, height)
    m : torch.Tensor
        Mask tensor
        Shape: (time, width, height)

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        1-D tensors with the absolute errors of the unmasked and the masked
        pixels respectively, in row-major order of the error tensor
    """
    # Calculate the absolute error
    e = torch.abs(t - p)

    # Broadcast the mask over the band dimension; expand_as returns a view
    # instead of materialising one copy of the mask per band like repeat()
    m = m.unsqueeze(1).expand_as(e)

    # Pixels where a target value exists (computed once, used for both regions)
    valid = ~torch.isnan(t)

    # Get the error of the unmasked area (target exists and not masked out)
    e_unm = e[valid & (m == 1)]

    # Get the error of the masked area (target exists but masked out)
    e_m = e[valid & (m == 0)]

    return e_unm, e_m
