"""Jitter resolution code."""

from __future__ import annotations

import cv2
import numpy as np
from numpy.typing import NDArray

from vito_cropsar.constants import PRECISION_FLOAT_NP


def align_stack(
    stack: NDArray[PRECISION_FLOAT_NP],
    repr_e: NDArray[PRECISION_FLOAT_NP],
    bands: list[int] | None = None,
) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Align every time step of a stack onto the representative edge image.

    Each time step is corrected independently with `correct_jitter_im`.

    Parameters
    ----------
    stack : NDArray[PRECISION_FLOAT_NP]
        The stack of images to align, one image per time step
        Shape: (T, C, H, W)
    repr_e : NDArray[PRECISION_FLOAT_NP]
        The representative edge image shared by all time steps
        Shape: (H, W)
    bands : list[int] | None
        The bands used to estimate the offset of each time step
        Note: All bands are used if None

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        The aligned stack
        Shape: (T, C, H, W)
    """
    aligned_frames = [
        correct_jitter_im(frame, repr_e=repr_e, bands=bands) for frame in stack
    ]
    return np.stack(aligned_frames)


def correct_jitter_im(
    im: NDArray[PRECISION_FLOAT_NP],
    repr_e: NDArray[PRECISION_FLOAT_NP],
    bands: list[int] | None = None,
) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Align the provided image with the reference image.

    An (dx, dy) offset is estimated for every selected band. The offsets are
    averaged, ignoring bands without a valid estimate, and every channel is
    translated by that single shared offset.

    Parameters
    ----------
    im : NDArray[PRECISION_FLOAT_NP]
        The image to align
        Shape: (C, H, W)
    repr_e : NDArray[PRECISION_FLOAT_NP]
        The representative edge image
        Shape: (H, W)
    bands : list[int] | None
        The bands to use to align the image
        Note: All bands are used if None (or an empty list)

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        The aligned image; pixels shifted in from outside the frame are NaN
        Shape: (C, H, W)
    """
    n_channels, height, width = im.shape

    # Get the offset per selected band; fully-NaN bands yield (NaN, NaN)
    offsets = np.array(
        [
            git_jitter_band(im_b=im[b_idx], repr_e=repr_e)
            for b_idx in bands or range(n_channels)
        ]
    )
    # Average the valid estimates, no shift at all if none of the bands gave one
    dx, dy = np.nanmean(offsets, axis=0) if (~np.isnan(offsets)).any() else (0.0, 0.0)

    # Move each channel with the same offset
    #  Note: OpenCV expects dsize as (width, height), the reverse of numpy's (H, W),
    #  passing im.shape[1:] directly breaks every non-square image
    m = np.float32([[1, 0, dx], [0, 1, dy]])
    return np.stack(
        [
            cv2.warpAffine(im[i], m, (width, height), borderValue=np.nan)
            for i in range(n_channels)
        ]
    )


def git_jitter_band(
    im_b: NDArray[PRECISION_FLOAT_NP],
    repr_e: NDArray[PRECISION_FLOAT_NP],
    max_shift: float = 5.0,
) -> tuple[float, float]:
    """
    Get the jitter offset for the specific image.

    Note: "git" is presumably a typo for "get"; the name is kept so existing
    callers keep working.

    Parameters
    ----------
    im_b : NDArray[PRECISION_FLOAT_NP]
        The image of one specific band to align
        Shape: (H, W)
    repr_e : NDArray[PRECISION_FLOAT_NP]
        The representative edge image (not modified)
        Shape: (H, W)
    max_shift : float
        Estimated shifts of at least this many pixels (in x or y) are considered
        unreliable and replaced by no shift
        Default: 5.0

    Returns
    -------
    tuple[float, float]
        The offset in x and y direction
        Note: Can be (NaN, NaN) if image is completely NaN
        Note: Is (0.0, 0.0) if the estimated shift reaches `max_shift`
    """
    # Assure no NaNs are present, return NaN offsets if full NaN
    if np.isnan(im_b).all():
        return np.nan, np.nan

    # Get the edges, ignore masked areas in both the image and the reference
    #  Note: cv2.phaseCorrelate requires both inputs to share the same float type;
    #  get_edges yields float32, so the reference is cast too (this also copies it,
    #  keeping the caller's array untouched)
    im_e = get_edges(im_b).astype(np.float32, copy=False)
    ref_e = repr_e.astype(np.float32)
    is_masked = np.isnan(im_e)
    ref_e[is_masked] = 0.0
    im_e[is_masked] = 0.0

    # Extract the offset in x and y direction, no shift if unreasonably high
    #  Note: Unreasonably high shifts can be introduced by high noise from clouds
    (dx, dy), _ = cv2.phaseCorrelate(im_e, ref_e)
    if abs(dx) >= max_shift or abs(dy) >= max_shift:
        return 0.0, 0.0
    return dx, dy


def get_repr(
    arr: NDArray[PRECISION_FLOAT_NP],
    thr: float = 0.2,
    bands: list[int] | None = None,
) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Get the best representative of the provided array.

    Edges are detected on every valid time step for every selected band and
    counted per pixel. The counts are normalized by the maximum count and
    squared, which suppresses rarely detected edges relative to frequent ones.

    Parameters
    ----------
    arr : NDArray[PRECISION_FLOAT_NP]
        Array to get the representative of
        Shape: (T, C, H, W)
    thr : float
        Minimum fraction of non-NaN values for a time step to be valid
        Note: Only valid images are used to determine the representative (attempt to avoid clouds)
    bands : list[int] | None
        The bands to use to extract the representative edges
        Note: All bands are used if None (or an empty list)

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        Edge representative of the array (float32)
        Shape: (H, W)
        Values: in [0, 1], the squared edge count relative to the maximum edge count
        Note: All zeros if no edge was detected at all

    Raises
    ------
    ValueError
        If no time step has at least a fraction `thr` of non-NaN values
    """
    # Keep only the time steps with enough valid (non-NaN) pixels
    #  Note: the check must be on the number of *valid* steps, otherwise an
    #  all-cloudy input silently produces a NaN representative (0 / 0 below)
    is_valid = np.array([np.mean(~np.isnan(x)) >= thr for x in arr], dtype=np.bool_)
    if not is_valid.any():
        raise ValueError("No valid images found")
    arr_ = arr[is_valid]

    # Count the edges per band over all the valid time steps
    edges = np.zeros(arr_.shape[1:])
    for b_idx in bands or range(arr_.shape[1]):
        edges[b_idx] = np.nansum([get_edges(x[b_idx]) for x in arr_], axis=0)

    # Flatten on channel-level
    edges = np.nansum(edges, axis=0)

    # Normalize by the maximum edge count and square; guard against edge-free input
    max_count = edges.max()
    if max_count <= 0:
        return np.zeros(edges.shape, dtype=np.float32)
    return ((edges / max_count) ** 2).astype(np.float32)


def get_edges(x: NDArray[PRECISION_FLOAT_NP]) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Detect the edges in a single 2D band using the Canny detector.

    Parameters
    ----------
    x : NDArray[PRECISION_FLOAT_NP]
        The 2D array to extract edges from, may contain NaNs
        Shape: (H, W)

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        The edge map
        Shape: (H, W)
        Values: 1 on an edge, 0 elsewhere, NaN on (and directly next to) NaN pixels
        Note: An all-zero array of the input dtype is returned for a fully-NaN input
    """
    nan_mask = np.isnan(x)
    if nan_mask.all():
        return np.zeros_like(x)

    # Rescale the valid values to [0, 1] (the floor avoids dividing by zero on a
    # constant image), NaNs are set to 0 so Canny gets a complete image
    lo = np.nanmin(x)
    hi = np.nanmax(x)
    scaled = (x - lo) / max(hi - lo, 1e-5)
    scaled[nan_mask] = 0.0

    # Canny operates on 8-bit images and marks edge pixels with 255
    canny = cv2.Canny(
        image=(255 * scaled).astype(np.uint8),
        threshold1=0,
        threshold2=100,
    )
    edges = (canny / 255.0).astype(np.float32)

    # Grow the NaN mask by one pixel so the borders of masked areas (e.g. clouds)
    # do not produce spurious edges
    grown_mask = cv2.dilate(
        nan_mask.astype(np.uint8),
        kernel=np.ones((3, 3), np.uint8),
        iterations=1,
    ).astype(np.bool_)
    edges[grown_mask] = np.nan
    return edges


if __name__ == "__main__":
    from time import time

    import matplotlib.pyplot as plt

    from vito_cropsar.data.io import load_sample

    # Demo: load one training tile, build the edge reference from S2, align both stacks
    sample = load_sample(
        split="training",
        tile="31UFR_10240_10496_5888_6144_2019-06-16",
        bands_s2=["s2_fapar", "s2_b02", "s2_b03", "s2_b04"],
    )
    stack_s1 = sample["s1"]
    stack_s2 = sample["s2"]

    # Time the extraction of the representative edge image
    t0 = time()
    repr_ = get_repr(stack_s2)
    print(f"Found representative in: {time() - t0:.3f}s")

    # Time the alignment of the S1 stack
    t0 = time()
    aligned = align_stack(stack_s1, repr_e=repr_)
    print(f"Aligned the full S1 stack in: {time() - t0:.3f}s")
    print(" -> Output shape:", aligned.shape)

    # Time the alignment of the S2 stack
    t0 = time()
    aligned = align_stack(stack_s2, repr_e=repr_)
    print(f"Aligned the full S2 stack in: {time() - t0:.3f}s")
    print(" -> Output shape:", aligned.shape)

    # Visual check of the representative edge image
    plt.figure()
    plt.imshow(repr_, vmin=0, vmax=1)
    plt.show()
