"""Create pixel-level evolution plots for the provided tile."""

from __future__ import annotations

from pathlib import Path
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray

from vito_cropsar.constants import PRECISION_FLOAT_NP

# Labels of the five tracked pixels, in the order `_get_variance_indices`
# returns them: one per image quadrant, then the pixel ranked at 90% on variance.
# Keep this fixed
TITLES = [
    *("Quadrant 1", "Quadrant 2", "Quadrant 3", "Quadrant 4"),  # one per quadrant
    "90% variance",  # image-wide 90% variance-ranked pixel
]


def plot_pixel_evolution(
    target: NDArray[PRECISION_FLOAT_NP],
    inp: NDArray[PRECISION_FLOAT_NP],
    pred: NDArray[PRECISION_FLOAT_NP],
    s2_bands: list[str],
    save_f: Path | None = None,
    writer_function: Callable | None = None,
    band: str = "s2_fapar",
) -> None:
    """
    Create the pixel-level evolution plot.

    The left panel shows a representative frame of ``band`` with five marked
    pixels (see ``TITLES``). Each right-hand panel shows, for one marked pixel,
    the prediction (dashed line), the target (circles) and the input (crosses)
    over time.

    Parameters
    ----------
    target : NDArray[PRECISION_FLOAT_NP]
        Target values, shaped (time, band, height, width).
    inp : NDArray[PRECISION_FLOAT_NP]
        Augmented input (target with mask_merged applied), same layout as target.
    pred : NDArray[PRECISION_FLOAT_NP]
        Model inpainted prediction, same layout as target.
    s2_bands : list[str]
        List of Sentinel-2 bands; its length must match axis 1 of the arrays.
    save_f : Path | None
        PNG file under which the result should get saved (not saved if None).
    writer_function : Callable | None
        Function to write the figure to tensorboard, called as
        ``writer_function(figure=fig)``.
    band : str
        Sentinel-2 band to visualise in left-side image and over time-series
        Note: Can only be one band, and only this band will be plotted
        Note: Should be present in s2_bands
    """
    if save_f is not None:
        assert save_f.suffix == ".png", "Only PNG files are supported"
    assert band in s2_bands, f"Band {band} not in s2_bands: {s2_bands}"
    assert (
        target.shape[1] == inp.shape[1] == pred.shape[1] == len(s2_bands)
    ), f"The number of bands should be the same for all arrays (target: {target.shape}, inp: {inp.shape}, pred: {pred.shape}, s2_bands: {len(s2_bands)})"

    # Select the band to visualise -> arrays become (time, height, width)
    b_idx = s2_bands.index(band)
    target = target[:, b_idx]
    inp = inp[:, b_idx]
    pred = pred[:, b_idx]

    # Flat (row-major) indices of the pixels to follow, decoded to (row, col).
    # Using the real image width instead of a hard-coded tile size keeps the
    # markers and time series correct for any (also non-square) tile.
    indices = _get_variance_indices(target)
    n_cols = target.shape[-1]
    pixels = [divmod(int(idx), n_cols) for idx in indices]

    n_steps = len(target)
    v_min, v_max = np.nanmin(target), np.nanmax(target)
    # Guard against division by zero when the band is constant over the series
    v_range = v_max - v_min if v_max > v_min else 1.0

    def _plot_series(ax: plt.Axes, title: str, row: int, col: int) -> None:
        """Plot prediction, target and input of a single pixel over time."""
        ax.plot(
            pred[:, row, col],
            linestyle="--",
            linewidth=0.5,
            color="b",
            label=title,
            zorder=2,
        )
        ax.scatter(
            range(n_steps),
            target[:, row, col],
            linewidth=1.5,
            marker="o",
            color="r",
            facecolors="none",
            label="Target",
            zorder=3,
        )
        ax.scatter(
            range(len(inp)),
            inp[:, row, col],
            linewidth=1,
            marker="x",
            color="g",
            label="Input",
            zorder=4,
        )
        ax.set_xticks(range(n_steps), ["" for _ in range(n_steps)])
        ax.set_xlim(-0.5, n_steps - 0.5)
        ax.set_ylim(v_min - 0.1, v_max + 0.1)
        ax.yaxis.tick_right()
        ax.grid(axis="x", linestyle="--", linewidth=0.5, zorder=1)
        ax.legend(fontsize=6)

    fig, axs = plt.subplot_mosaic(
        "AAAAABBBBBB;AAAAACCCCCC;AAAAADDDDDD;AAAAAEEEEEE;AAAAAFFFFFF",
        figsize=(11, 5),
    )
    try:
        if save_f is not None:
            fig.suptitle(save_f.with_suffix("").name)
        else:
            fig.suptitle("Pixel-level evolution")

        # Main image, rescaled to [0, 1] with the range of the whole band series
        im = _get_image(target)
        im = np.clip((im - v_min) / v_range, a_min=0, a_max=1)
        axs["A"].imshow(im, vmin=0, vmax=1)
        axs["A"].set_axis_off()
        for title, (row, col) in zip(TITLES, pixels):
            # imshow puts columns on the x-axis and rows on the y-axis
            axs["A"].plot(col, row, "ro")
            axs["A"].text(col + 2, row - 1, title, color="k", fontsize=8)

        # One time-series panel per tracked pixel
        for ax_key, title, (row, col) in zip("BCDEF", TITLES, pixels):
            _plot_series(axs[ax_key], title, row, col)

        # General formatting and save image
        fig.tight_layout()
        if save_f is not None:
            fig.savefig(save_f, dpi=200)
        if writer_function is not None:
            writer_function(figure=fig)
    finally:
        # Always release the figure, even if saving or writing fails
        plt.close(fig)


def _get_image(
    arr: NDArray[PRECISION_FLOAT_NP],
    thr: float = 0.1,
) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Get a good image representative, located in the middle of the series.

    Parameters
    ----------
    arr : NDArray[PRECISION_FLOAT_NP]
        The array for which to extract the visible parts
    thr : float
        Threshold for the maximum percentage of NaNs in a pixel

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        The representative image
    """
    visible = [
        i for i, x in enumerate(arr) if (np.isnan(x).sum() / len(x.flatten())) < thr
    ]
    return arr[visible[len(visible) // 2]].copy() if visible else arr[0]


def _get_variance_indices(arr: NDArray[PRECISION_FLOAT_NP]) -> list[int]:
    """
    Calculate the indices with high variance to plot.

    Notes
    -----
     - This function plots the index with most variance in each quadrant of the image (first four indices)
     - This function plots the 90% index ranked on  variance in the image (fifth index)

    Parameters
    ----------
    arr : NDArray[PRECISION_FLOAT_NP]
        The array for which to extract the indices

    Returns
    -------
    list[int]
        The indices with high variance
    """
    assert (
        arr.ndim == 3  # noqa: PLR2004
    ), f"Array should be 3-dimensional, but is {arr.ndim}-dimensional"

    # Get the global variance
    t, w, h = arr.shape
    arr_ = arr.reshape(t, w * h)
    arr_ = np.swapaxes(  # Better way of doing this?
        np.stack(
            [
                np.interp(
                    np.arange(len(x)),
                    np.arange(len(x))[~np.isnan(x)],
                    x[~np.isnan(x)],
                )
                if (~np.isnan(x)).any()
                else np.zeros_like(x)
                for x in np.swapaxes(arr_, 0, 1)
            ]
        ),
        0,
        1,
    )
    var = np.nanvar(arr_, axis=0)  # Variance over time
    args = var.argsort()

    # Get the variance in each quadrant
    w, h = arr.shape[1:]
    half_w, half_h = w // 2, h // 2

    def _is_in_quadrant(i: int, q: int, offset: int = 5) -> bool:
        """Check if the index is in the quadrant."""
        x, y = i // w, i % w
        # Take buffer from edge into account
        if (x < offset) or (y < offset) or (x > w - offset) or (y > h - offset):
            return False

        # Check if the index is in the quadrant
        return ((x < half_w) == (q in {0, 1})) and ((y < half_h) == (q in {0, 2}))

    def _get_quadrant_variance(q: int) -> float:
        """Get the variance in the quadrant."""
        args_ = [a for a in args if _is_in_quadrant(a, q=q)]
        assert 0 <= len(args_) <= (len(args) // 4)  # At most quandrant size
        return args_[-1]

    # Calculate the indices with most variance in each quadrant
    indices = [_get_quadrant_variance(q) for q in range(4)]

    # Get the 90% index ranked on variance in the image
    indices.append(args[int(len(args) * 0.9)])
    return indices
