"""Set of metrics functions."""

from __future__ import annotations

import torch


def _get_coverage_and_overlap(
    gt: torch.Tensor, pr: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute coverage and overlapping matrices."""
    gt_ids = [x for x in torch.unique(gt) if x > 0]
    pr_ids = [x for x in torch.unique(pr) if x > 0]

    # compute overlap and coverage matrices
    overlap_m = torch.zeros(len(pr_ids), len(gt_ids))
    coverage_m = torch.zeros(len(gt_ids), len(pr_ids))
    for p, p_id in enumerate(pr_ids):
        for t, t_id in enumerate(gt_ids):
            t_arr = gt == t_id
            p_arr = pr == p_id
            # how much is the prediction overlapping the target in respect to its area
            overlap_m[p, t] = (t_arr & p_arr).sum() / (p_arr).sum()

            # how much is the target covered by the prediction
            coverage_m[t, p] = (t_arr & p_arr).sum() / (t_arr).sum()
    return coverage_m, overlap_m


def _merge_oversegmentation(
    pr: torch.Tensor,
    coverage_m: torch.Tensor,
    overlap_m: torch.Tensor,
    overlapping_thr: float = 0.8,
) -> torch.Tensor:
    """Merge oversegmented prediction ids."""
    pr_ids = [x for x in torch.unique(pr) if x > 0]

    # filter bad overlapping fields from coverage
    coverage_m_ = coverage_m.clone()
    coverage_m_[(overlap_m < overlapping_thr).T] = 0
    coverage_score = coverage_m_.sum(1)

    # merge ids
    for t, cvg_scr in enumerate(coverage_score):
        if cvg_scr > 0:
            p_ids = torch.Tensor(pr_ids)[torch.where(coverage_m_[t] > 0)[0]]
            for p_id in p_ids:
                pr[pr == p_id] = p_ids[0]
    return pr


def _get_overflows(
    pr: torch.Tensor, coverage_m: torch.Tensor, overflow_thr: float = 0.4
) -> torch.Tensor:
    """Get overflowing prediction ids."""
    pr_ids = [x for x in torch.unique(pr) if x > 0]
    overflowing_mask = (coverage_m > overflow_thr).sum(0) > 1
    return torch.Tensor(pr_ids)[overflowing_mask]


def get_pr(gt: torch.Tensor, pr: torch.Tensor) -> dict[str, torch.Tensor]:
    """Calculate per-instance precision and recall from IoU matching.

    Instance ids are the strictly positive values of ``gt`` and ``pr``.
    Overflowing prediction ids are detected first, using ``_get_overflows`` on
    the original labels. Over-segmented predictions are then merged with
    ``_merge_oversegmentation``. Finally, an IoU matrix of shape
    ``(n_gt, n_pr)`` is built over the merged labels. Every column whose id
    was flagged as overflowing is left at zero.

    Args:
        gt: ground-truth label tensor.
        pr: predicted label tensor, compared element-wise with ``gt``.

    Returns:
        A dict with:

        - ``gt_ids`` / ``pr_ids``: instance ids (``pr_ids`` after merging).
        - ``precision``: best IoU of each prediction over all targets
          (zeros when there is no target).
        - ``recall``: best IoU of each target over all predictions
          (zeros when there is no prediction).
        - ``gt_size`` / ``pr_size``: number of elements labelled with each id.
    """
    # Overflows are detected on the original labels, before ids are merged.
    coverage_m, overlap_m = _get_coverage_and_overlap(gt=gt, pr=pr)
    overflowing_ids = _get_overflows(pr=pr, coverage_m=coverage_m)
    pr = _merge_oversegmentation(pr=pr, coverage_m=coverage_m, overlap_m=overlap_m)

    gt_ids = [x for x in torch.unique(gt) if x > 0]
    pr_ids = [x for x in torch.unique(pr) if x > 0]

    # IoU matrix: rows are targets, columns are predictions.
    iou_m = torch.zeros(len(gt_ids), len(pr_ids))
    for j, p_id in enumerate(pr_ids):
        # Overflow status depends only on the prediction: test it once per
        # column and leave the whole column at 0 without building any mask.
        if int(p_id) in overflowing_ids:
            continue
        p_arr = pr == p_id
        for i, t_id in enumerate(gt_ids):
            t_arr = gt == t_id
            iou_m[i, j] = (t_arr & p_arr).sum() / (t_arr | p_arr).sum()

    has_gt = iou_m.shape[0] > 0
    has_pr = iou_m.shape[1] > 0
    return {
        "gt_ids": torch.stack(gt_ids) if gt_ids else torch.IntTensor([]),
        "pr_ids": torch.stack(pr_ids) if pr_ids else torch.IntTensor([]),
        "precision": iou_m.max(dim=0).values if has_gt else torch.zeros(iou_m.shape[1]),
        "recall": iou_m.max(dim=1).values if has_pr else torch.zeros(iou_m.shape[0]),
        "gt_size": torch.IntTensor([(gt == t_id).sum() for t_id in gt_ids]),
        "pr_size": torch.IntTensor([(pr == p_id).sum() for p_id in pr_ids]),
    }
