"""Model utilities."""

from __future__ import annotations

import cv2
import numpy as np
from numpy.typing import NDArray

from vito_cropsar.constants import PRECISION_FLOAT_NP


def get_repr(x: NDArray[PRECISION_FLOAT_NP]) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Build a single edge map that summarises the complete time stack (4D array).

    Time steps in which fewer than 20% of the values are valid (non-NaN) are
    ignored. The edge maps of the remaining steps are summed and divided, per
    pixel, by the number of kept steps in which that pixel has at least one
    valid channel. The result is rescaled with its 0% and 99% quantiles,
    clipped to 0..1 and squared.

    Parameters
    ----------
    x : NDArray[PRECISION_FLOAT_NP]
        The array to get the edges from
        Shape: (T, C, H, W)

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        The edge representative, repeated for every time step
        Shape: (T, H, W)
        Values: range 0..1, all zeros if no time step has enough valid values
    """
    # Flag the time steps that carry enough valid (non-NaN) values
    usable = []
    for frame in x:
        valid_fraction = np.mean(~np.isnan(frame))
        usable.append(valid_fraction >= 0.2)  # noqa: PLR2004

    # Nothing usable: return an all-zero representative
    if not any(usable):
        return np.zeros((x.shape[0], x.shape[2], x.shape[3]))
    kept = x[usable]

    # Sum the per-step edge maps over time
    per_step = [_get_edges_ts(frame) for frame in kept]
    edge_total = np.stack(per_step).sum(axis=0)

    # Per pixel: number of kept steps with at least one valid channel
    seen = (~np.isnan(kept)).any(axis=1).sum(axis=0)
    ratio = edge_total / np.clip(seen, a_min=1, a_max=None)

    # Rescale with the 0% / 99% quantiles (guarding against a zero range)
    low, high = np.quantile(ratio, (0.00, 0.99))
    span = max(high - low, 1e-6)
    ratio = np.clip((ratio - low) / span, 0, 1) ** 2

    # Duplicate the representative over the time dimension
    return np.repeat(ratio[np.newaxis], x.shape[0], axis=0)


def _get_edges_ch(x: NDArray[PRECISION_FLOAT_NP]) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Detect the edges of one channel at one time step (2D array).

    The valid pixels are min-max scaled, converted to an 8-bit image and passed
    through the Canny detector. Edges on or directly next to NaN pixels are
    removed afterwards.

    Parameters
    ----------
    x : NDArray[PRECISION_FLOAT_NP]
        The array to get the edges from
        Shape: (H, W)

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        The edges of the array
        Shape: (H, W)
        Values: 0 if no edge (or masked out), 1 if edge
    """
    missing = np.isnan(x)
    if missing.all():
        # Nothing to detect on a fully masked image
        return np.zeros_like(x)

    # Min-max scale the valid pixels to [0, 1]; NaN pixels are set to 0
    lowest = np.nanmin(x)
    spread = max(np.nanmax(x) - lowest, 1e-6)
    scaled = (x - lowest) / spread
    scaled[missing] = 0.0

    # Run Canny on the 8-bit image and bring its output back to the 0..1 range
    canny = cv2.Canny(
        image=(255 * scaled).astype(np.uint8),
        threshold1=0,
        threshold2=100,
    )
    edges = (canny / 255.0).astype(np.float32)

    # Dilate the NaN mask with a 3x3 kernel so that the borders of masked
    # regions (e.g. the sides of clouds) are not reported as edges
    grown = cv2.dilate(
        missing.astype(np.uint8),
        kernel=np.ones((3, 3), np.uint8),
        iterations=1,
    )
    edges[grown.astype(np.bool_)] = 0.0
    return edges


def _get_edges_ts(x: NDArray[PRECISION_FLOAT_NP]) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Detect the edges of one time step by averaging the channel edges (3D array).

    Parameters
    ----------
    x : NDArray[PRECISION_FLOAT_NP]
        The array to get the edges from
        Shape: (C, H, W)

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        The edges of the array
        Shape: (H, W)
        Values: 0 if no edge, 1 if always edge, range 0..1 if sometimes edge
    """
    # One edge map per channel, then the mean over the channel axis
    channel_edges = [_get_edges_ch(x[ch, :, :]) for ch in range(x.shape[0])]
    averaged = np.stack(channel_edges).mean(axis=0)

    # Fall back to an all-zero map if the average contains any NaN
    if np.isnan(averaged).any():
        return np.zeros_like(averaged)
    return averaged
