"""Data IO operations."""

from __future__ import annotations

from functools import lru_cache
from pathlib import Path

import numpy as np
from numpy.typing import NDArray

from vito_cropsar.constants import (
    PRECISION_FLOAT_NP,
    PRECISION_INT_NP,
    S1,
    S2,
    SPLITS,
    get_data_folder,
)
from vito_cropsar.data.jitter import align_stack, get_repr
from vito_cropsar.data.utils import apply_mask, format_cloud, process_s1, process_s2
from vito_cropsar.vito_logger import LogLevel, bh_logger


@lru_cache
def list_tiles(
    split: str,
    data_f: Path | None = None,
    cache_tag: str | None = None,
) -> list[str]:
    """
    List all the tile identifiers available for a split.

    Parameters
    ----------
    split : str
        Split to list, must be one of ``SPLITS``
    data_f : Path | None
        Root data folder, by default ``get_data_folder()``
    cache_tag : str | None
        When given, list ``<data_f>/<split>_<cache_tag>`` instead of
        ``<data_f>/<split>``

    Returns
    -------
    list[str]
        Sorted, de-duplicated ``.npz`` file names without their suffix

    Notes
    -----
    The result is memoised with ``lru_cache``: every caller receives the same
    list object (do not mutate it), and files added after the first call are
    not picked up until ``list_tiles.cache_clear()`` is called.
    """
    root = data_f or get_data_folder()
    assert split in SPLITS, f"Split should be part of {', '.join(SPLITS)}"

    # Cached samples live in a sibling folder suffixed with the cache tag
    folder_name = split if cache_tag is None else f"{split}_{cache_tag}"

    names: set[str] = set()
    for npz_path in (root / folder_name).glob("*.npz"):
        names.add(npz_path.with_suffix("").name)
    return sorted(names)


@lru_cache
def load_masks(data_f: Path | None = None) -> NDArray[PRECISION_INT_NP]:
    """
    Load all supported masks from ``<data_f>/masks.npy``.

    Parameters
    ----------
    data_f : Path | None
        Root data folder, by default ``get_data_folder()``

    Returns
    -------
    NDArray[PRECISION_INT_NP]
        The stored masks cast to ``PRECISION_INT_NP``

    Notes
    -----
    The result is memoised with ``lru_cache``, so every caller receives the
    same array object; copy it before modifying it in place.
    """
    masks_file = (data_f or get_data_folder()) / "masks.npy"
    raw_masks = np.load(masks_file)
    return raw_masks.astype(PRECISION_INT_NP)


def load_sample(
    split: str,
    tile: str,
    data_f: Path | None = None,
    has_mask_: bool = True,
    bands_s1: list[str] | None = None,
    bands_s2: list[str] | None = None,
    cache_tag: str | None = None,
    smooth_s1: bool = True,
    align: bool = False,
    verbose: bool = True,
) -> dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]:
    """
    Load one sample.

    Parameters
    ----------
    split : str
        The split to extract the tile from
    tile : str
        The tile's unique identifier
    data_f : Path | None
        Folder where data is located, by default ``get_data_folder()``
    has_mask_ : bool
        Whether the artificial mask (``s2_mask_``) must exist; when it is
        missing a ValueError is raised for every split except 'training'
    bands_s1 : list[str] | None
        The bands for S1 to load in, by default all
    bands_s2 : list[str] | None
        The bands for S2 to load in, by default all
    cache_tag : str | None
        Cache postfix; if ``<data_f>/<split>_<cache_tag>/<tile>.npz`` exists,
        that cached sample is returned
    smooth_s1 : bool
        Whether to smooth the S1 data
        Note: The speckle filter is used to do this
    align : bool
        Whether to align the S2 stack (via ``align_stack``); S1 is left as-is
    verbose : bool
        Whether to print out logger messages

    Notes
    -----
    When a cached sample is found, it is returned as stored (only the band
    arrays are converted to lists): ``bands_s1``, ``bands_s2``, ``smooth_s1``,
    ``align`` and ``has_mask_`` are not applied to it.

    Returns
    -------
    result : dict[str, NDArray | list[str]]
        s1 : NDArray[PRECISION_FLOAT_NP]
            S1 satellite values
            Shape: (time, channels_s1, width, height)
        s2 : NDArray[PRECISION_FLOAT_NP]
            S2 satellite values, can be np.nan
            Shape: (time, channels_s2, width, height)
        bands_s1 : list[str]
            Names of the S1 channels, in channel order
        bands_s2 : list[str]
            Names of the S2 channels, in channel order
        mask : NDArray[PRECISION_INT_NP]
            Mask values, which reflects the s2's np.nan values
            Note: 1 means cloud-free, 0 means cloud (masked out)
            Shape: (time, width, height)
        mask_ : NDArray[PRECISION_INT_NP], optional
            Optionally fixed artificial mask to be applied on the inputs
            Note: 1 means cloud-free, 0 means cloud (masked out)
            Shape: (time, width, height)

    Raises
    ------
    ValueError
        If ``has_mask_`` is set, ``split`` is not 'training' and the tile has
        no artificial mask.
    """
    # define data folder
    data_f = data_f or get_data_folder()

    # Check the inputs. The sorters are lru-cached and return the same list
    # object on every call, so copy before handing it out in the sample.
    bands_s1 = list(_sort_bands_s1(tuple(bands_s1) if bands_s1 else ()))
    bands_s2 = list(_sort_bands_s2(tuple(bands_s2) if bands_s2 else ()))

    # Check if cache exists
    # NOTE: allow_pickle=True can execute code from a crafted file; only load
    # trusted data folders.
    if cache_tag:
        cache_file = data_f / f"{split}_{cache_tag}" / f"{tile}.npz"
        if cache_file.is_file():
            # Read every array inside the `with` so the file handle is closed
            with np.load(cache_file, allow_pickle=True) as npz:
                cached = {key: npz[key] for key in npz.files}
            cached["bands_s1"] = list(cached["bands_s1"])
            cached["bands_s2"] = list(cached["bands_s2"])
            return cached
    if cache_tag and verbose:
        bh_logger(f"Cache incomplete for split={split}!", lvl=LogLevel.WARNING)
    elif verbose:
        bh_logger(f"Cache not used for split={split}!", lvl=LogLevel.WARNING)

    # Load the raw data (file handle closed on leaving the `with`)
    with np.load(data_f / split / f"{tile}.npz", allow_pickle=True) as npz:
        x = {key: npz[key] for key in npz.files}

    # process s1 and s2
    s1 = process_s1(x, bands=bands_s1, speckle=smooth_s1)
    s2 = process_s2(x, bands=bands_s2)

    # Optionally align the S2 stack against a reference representation
    if align:
        bands = [bands_s2.index("s2_fapar")] if "s2_fapar" in bands_s2 else None
        repr_e = get_repr(s2, bands=bands)
        # S1 is not re-aligned (considered well aligned on its own)
        s2 = align_stack(s2, repr_e=repr_e, bands=bands)

    # Ensure mask is 0 where s2 is np.nan, and s2 is np.nan where mask is 0
    mask = format_cloud(x["s2_mask"])  # n_ts, width, height
    mask[np.isnan(s2).any(axis=1)] = 0
    s2 = apply_mask(s2, mask=mask)

    # Create a sample (dictionary)
    sample = {
        "s1": s1,
        "s2": s2,
        "bands_s1": bands_s1,
        "bands_s2": bands_s2,
        "mask": mask,
    }

    # Optionally add the artificial mask
    if "s2_mask_" in x:
        # Ensure mask_ is 0 where s2 is np.nan
        mask_ = format_cloud(x["s2_mask_"])  # n_ts, width, height
        mask_[np.isnan(s2).any(axis=1)] = 0
        sample["mask_"] = mask_  # n_ts, width, height
    elif has_mask_ and (split != "training"):
        raise ValueError(f"No artificial mask found for split={split} and tile={tile}")
    return sample


def _sort_bands(
    bands: tuple[str, ...], valid: list[str] | tuple[str, ...]
) -> list[str]:
    """
    Sort ``bands`` (or all of ``valid`` when empty) and check they are valid.

    Bands are ordered on their "_"-separated parts read right-to-left, i.e.
    primarily on the last part.
    """
    ordered = sorted(bands or valid, key=lambda b: tuple(reversed(b.split("_"))))
    assert all(b in valid for b in ordered), f"Band should be part of {', '.join(valid)}"
    return ordered


@lru_cache
def _sort_bands_s1(bands: tuple[str, ...]) -> list[str]:
    """
    Sort the S1 bands by their order (all of ``S1`` when ``bands`` is empty).

    Memoised with ``lru_cache``: the same list object is returned for equal
    inputs, so callers must copy it before mutating or exposing it.
    """
    return _sort_bands(bands, S1)


@lru_cache
def _sort_bands_s2(bands: tuple[str, ...]) -> list[str]:
    """
    Sort the S2 bands by their order (all of ``S2`` when ``bands`` is empty).

    Memoised with ``lru_cache``: the same list object is returned for equal
    inputs, so callers must copy it before mutating or exposing it.
    """
    return _sort_bands(bands, S2)


if __name__ == "__main__":
    # Demo: load one tile, stress-test loading, and plot its full time series
    from time import time

    import matplotlib.pyplot as plt
    from tqdm import tqdm

    SPLIT = "testing"
    IDX = 1

    tiles = list_tiles(split=SPLIT)
    print(f"Loaded total of {len(tiles)} tiles, example tile: {tiles[IDX]}")
    sample = load_sample(split=SPLIT, tile=tiles[IDX])
    print(f"S1 shape: {sample['s1'].shape}")
    print(f"S1 bands: {sample['bands_s1']}")
    print(f"S2 shape: {sample['s2'].shape}")
    print(f"S2 bands: {sample['bands_s2']}")
    print(f"Mask shape: {sample['mask'].shape}")
    for k in ("s1", "s2", "mask"):
        v = sample[k]
        finite = v[~np.isnan(v)]  # drop NaNs before taking min/max
        print(f"{k}: shape={v.shape}, min={finite.min()}, max={finite.max()}")

    # Stress-test the dataloading
    print("\nStart stress-testing (no cache)..")
    start, N = time(), 20
    assert len(tiles) >= N
    for i in tqdm(range(N), "Loading.."):
        _ = load_sample(split=SPLIT, tile=tiles[i])
    print(f" --> Loading takes {(time() - start) / N:.5f} s/tile on average")

    # Plot the whole time series
    print("\nPlotting one whole time series..")
    ts, ch_s1, _, _ = sample["s1"].shape
    _, ch_s2, _, _ = sample["s2"].shape
    ch = ch_s1 + ch_s2 + 1 + 1 + 1  # S1, S2, RGB, mask, mask_
    # squeeze=False keeps `axs` 2-D even when there is a single time step
    _, axs = plt.subplots(ts, ch, figsize=(ch, ts), squeeze=False)

    # Channel indices of the RGB bands (identical for every time step)
    R = sample["bands_s2"].index("s2_b04")
    G = sample["bands_s2"].index("s2_b03")
    B = sample["bands_s2"].index("s2_b02")

    for t in range(ts):
        axs[t, 0].set_ylabel(f"t={t}")

        # Show S1
        for c in range(ch_s1):
            axs[t, c].imshow(
                np.clip(sample["s1"][t, c], a_min=-50, a_max=5),
                vmin=-50,
                vmax=5,
            )
            if t == 0:
                axs[t, c].set_title(sample["bands_s1"][c])

        # Show S2
        for c in range(ch_s2):
            axs[t, c + ch_s1].imshow(
                np.clip(sample["s2"][t, c], a_min=0, a_max=1),
                vmin=0,
                vmax=1,
            )
            if t == 0:
                axs[t, c + ch_s1].set_title(sample["bands_s2"][c])

        # Show RGB (stretched to [0, 0.3], NaN rendered white)
        rgb = sample["s2"][t, [R, G, B]]
        rgb = np.clip(rgb, a_min=0, a_max=0.3) / 0.3
        rgb = np.nan_to_num(rgb, nan=1)
        axs[t, -3].imshow(rgb.transpose(1, 2, 0))
        if t == 0:
            axs[t, -3].set_title("RGB")

        # Show the mask
        axs[t, -2].imshow(sample["mask"][t], vmin=0, vmax=1)
        if t == 0:
            axs[t, -2].set_title("Mask")

        # Show the artificial mask
        if "mask_" in sample:
            axs[t, -1].imshow(sample["mask_"][t], vmin=0, vmax=1)
        if t == 0:
            axs[t, -1].set_title("Mask_")

    # Disable axis
    for ax in axs.ravel():
        ax.set_xticks([])
        ax.set_yticks([])

    # Create the plot
    plt.tight_layout()
    plt.savefig("example_stack.png")
    plt.close()
