"""Series evaluation plot."""

from __future__ import annotations

from pathlib import Path
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray

from vito_cropsar.constants import PRECISION_FLOAT_NP, PRECISION_INT_NP, S2_SCALES

# Sentinel-2 band names composing the true-colour image, in display (R, G, B) order.
RGB = [
    "s2_b04",  # red channel
    "s2_b03",  # green channel
    "s2_b02",  # blue channel
]


def plot_series(  # noqa: C901
    target: NDArray[PRECISION_FLOAT_NP],
    inp: NDArray[PRECISION_FLOAT_NP],
    pred: NDArray[PRECISION_FLOAT_NP],
    mask: NDArray[PRECISION_INT_NP],
    s2_bands: list[str],
    save_f: Path | None = None,
    writer_function: Callable | None = None,
) -> None:
    """
    Plot the complete series.

    One column is drawn per time step, with five rows: target, mask, input,
    prediction and absolute error. When all RGB bands (``s2_b04``, ``s2_b03``,
    ``s2_b02``) are present they are shown as a colour composite; otherwise
    ``s2_fapar`` is shown if present, else the first band.

    Parameters
    ----------
    target : NDArray[PRECISION_FLOAT_NP]
        Target value, indexed as (time, band, height, width)
    inp : NDArray[PRECISION_FLOAT_NP]
        Augmented input (target with mask_merged applied), same layout as ``target``
    pred : NDArray[PRECISION_FLOAT_NP]
        Model inpainted prediction, same layout as ``target``
    mask : NDArray[PRECISION_INT_NP]
        Mask indicating the missing values of S2-data; ``mask[t]`` is drawn
        as a 2-D image for every time step ``t``
    s2_bands : list[str]
        List of Sentinel-2 bands, in the order of the band axis
    save_f : Path | None
        PNG file under which the result should get saved (not saved if None)
    writer_function : Callable | None
        Function to write the figure to tensorboard; called as
        ``writer_function(figure=fig)``

    Raises
    ------
    AssertionError
        If ``save_f`` is not a ``.png`` path, or the band axes of ``target``,
        ``inp``, ``pred`` and ``s2_bands`` do not agree in length.
    """
    if save_f is not None:
        assert save_f.suffix == ".png", "Only PNG files are supported"
    assert (
        target.shape[1] == inp.shape[1] == pred.shape[1] == len(s2_bands)
    ), f"The number of bands should be the same for all arrays (target: {target.shape}, inp: {inp.shape}, pred: {pred.shape}, s2_bands: {len(s2_bands)})"

    # Select band indexes: RGB composite if available, else fAPAR, else the first band
    idx = [0]
    if set(RGB).issubset(set(s2_bands)):
        idx = [s2_bands.index(x) for x in RGB]
    elif "s2_fapar" in s2_bands:
        idx = [s2_bands.index("s2_fapar")]

    # Absolute error on the raw (not range-corrected) values, averaged over the displayed bands
    error = np.clip(np.abs(target - pred), a_min=0.0, a_max=1.0)
    error = error[:, idx].mean(axis=1) if len(idx) > 1 else error[:, idx[0]]

    def _correct_range(x: NDArray[PRECISION_FLOAT_NP]) -> NDArray[PRECISION_FLOAT_NP]:
        """Return a copy with every band clipped to its S2_SCALES range and rescaled to [0, 1]."""
        x = x.copy()
        for i, b in enumerate(s2_bands):
            lo, hi = S2_SCALES[b][0], S2_SCALES[b][1]
            x[:, i] = np.clip(x[:, i], lo, hi)
            x[:, i] = (x[:, i] - lo) / (hi - lo)
        return x

    # Correct range of bands with expert rules
    target = _correct_range(target)
    inp = _correct_range(inp)
    pred = _correct_range(pred)

    def _format(x: NDArray[PRECISION_FLOAT_NP], t: int) -> NDArray[PRECISION_FLOAT_NP]:
        """Return the selected bands of time step ``t`` as an (H, W, C) image in [0, 1]."""
        x = x[t, idx, ...].transpose(1, 2, 0)  # H, W, C (fancy indexing yields a copy)
        if len(idx) == 3:  # noqa: PLR2004
            # Paint pixels with a missing value in any channel white
            is_nan = np.isnan(x).any(axis=-1)
            x[is_nan] = 1.0
        return x.clip(0.0, 1.0)

    n_steps = len(target)
    # squeeze=False keeps `axs` 2-D even for a single time step, so axs[row, col] always works
    fig, axs = plt.subplots(nrows=5, ncols=n_steps, figsize=(n_steps, 5), squeeze=False)
    try:
        fig.suptitle(save_f.with_suffix("").name if save_f is not None else "Series evaluation")

        # Fill in the subplots
        plt.setp(axs, xticks=[], yticks=[])
        axs[0, 0].set_ylabel("target")
        axs[1, 0].set_ylabel("mask")
        axs[2, 0].set_ylabel("input")
        axs[3, 0].set_ylabel("prediction")
        axs[4, 0].set_ylabel("error (0 - 0.5)")
        for t in range(n_steps):
            axs[0, t].imshow(_format(target, t=t), vmin=0.0, vmax=1.0)
            axs[1, t].imshow(mask[t], vmin=0.0, vmax=1.0)
            axs[2, t].imshow(_format(inp, t=t), vmin=0.0, vmax=1.0)
            axs[3, t].imshow(_format(pred, t=t), vmin=0.0, vmax=1.0)
            axs[4, t].imshow(error[t], vmin=0.0, vmax=0.5)
            axs[4, t].set_xlabel(f"t_{t:02d}")

        # General formatting and save image
        fig.tight_layout()
        if save_f is not None:
            fig.savefig(save_f, dpi=200)
        if writer_function is not None:
            writer_function(figure=fig)
    finally:
        # Close this figure explicitly (a bare plt.close() closes whichever figure is
        # current, which may not be ours if writer_function already closed it) and
        # release it even when saving or writing fails.
        plt.close(fig)
