"""Cache operations."""

from __future__ import annotations

from shutil import rmtree

import numpy as np
from tqdm import tqdm

from vito_cropsar.constants import SPLITS, get_data_folder
from vito_cropsar.data.io import list_tiles, load_sample


def fill_cache(
    tag: str,
    bands_s1: list[str] | None = None,
    bands_s2: list[str] | None = None,
    align: bool = True,
    smooth_s1: bool = True,
    overwrite: bool = False,
) -> None:
    """
    Fill the cache with one ``.npz`` file per tile, for every split.

    For each split in ``SPLITS`` the folder ``<data_folder>/<split>_<tag>`` is
    created if needed, and every tile returned by ``list_tiles`` is loaded with
    ``load_sample`` and saved as ``<tile>.npz`` in that folder.

    Parameters
    ----------
    tag : str
        Cache postfix tag
    bands_s1 : list[str] | None
        The bands for S1 to load in, by default all
    bands_s2 : list[str] | None
        The bands for S2 to load in, by default all
    align : bool
        Alignment flag passed through to ``load_sample`` (see that function
        for exactly which stacks are aligned)
    smooth_s1 : bool
        Whether to smooth the S1 images (Speckle filter)
    overwrite : bool
        Whether to overwrite previously cached files. When False, tiles that
        already have a cached file are skipped without being loaded.
    """
    data_folder = get_data_folder()
    for split in SPLITS:
        cache_folder = data_folder / f"{split}_{tag}"
        cache_folder.mkdir(exist_ok=True, parents=True)
        tiles = list_tiles(split=split)
        for tile in tqdm(tiles, desc=f"Filling cache for {split}.."):
            cache_file = cache_folder / f"{tile}.npz"
            # Check before loading so already-cached tiles are not read again.
            if not overwrite and cache_file.is_file():
                continue
            sample = load_sample(
                split=split,
                tile=tile,
                bands_s1=bands_s1,
                bands_s2=bands_s2,
                smooth_s1=smooth_s1,
                align=align,
                verbose=False,
            )
            # Write to a temporary file first and move it into place afterwards.
            # Otherwise a file truncated by an interrupted run would pass the
            # ``is_file`` check above and be skipped on every later run.
            # The ".tmp" suffix keeps it out of "*.npz" globs.
            tmp_file = cache_folder / f"{tile}.npz.tmp"
            with open(tmp_file, "wb") as fh:
                np.savez(fh, **sample)
            tmp_file.replace(cache_file)


def push_to_cache(
    npz: dict[str, np.ndarray],
    split: str,
    tile: str,
    tag: str,
    overwrite: bool = False,
) -> None:
    """
    Push a sample to the cache.

    The sample is saved as ``<data_folder>/<split>_<tag>/<tile>.npz``. The
    folder is created if it does not exist yet.

    Parameters
    ----------
    npz : dict[str, np.ndarray]
        The sample to push. Each key becomes an array name in the ``.npz``
        file (the mapping is unpacked into ``np.savez``).
    split : str
        The split to push to
    tile : str
        The tile to push to
    tag : str
        Cache postfix tag
    overwrite : bool
        Whether to overwrite previously cached files. When False and the file
        already exists, the function returns without writing.
    """
    data_f = get_data_folder() / f"{split}_{tag}"
    data_f.mkdir(exist_ok=True, parents=True)
    cache_file = data_f / f"{tile}.npz"
    if cache_file.is_file() and not overwrite:
        return
    # Write to a temporary file and move it into place, so an interrupted
    # write never leaves a truncated file that passes the ``is_file`` check.
    tmp_file = data_f / f"{tile}.npz.tmp"
    with open(tmp_file, "wb") as fh:
        np.savez(fh, **npz)
    tmp_file.replace(cache_file)


def clean_cache(tag: str) -> None:
    """
    Remove the cache directories belonging to ``tag``.

    For every split in ``SPLITS``, the directory ``<data_folder>/<split>_<tag>``
    is deleted recursively if it exists. Splits without such a directory are
    skipped.

    Note: This will remove entire cache directories.

    Parameters
    ----------
    tag : str
        Cache postfix tag
    """
    for split_name in tqdm(SPLITS, desc="Cleaning cache.."):
        cache_dir = get_data_folder() / f"{split_name}_{tag}"
        if not cache_dir.is_dir():
            continue
        rmtree(cache_dir)


if __name__ == "__main__":
    # Build a cache restricted to the ``s2_fapar`` S2 band. Uncomment the
    # ``clean_cache`` call to remove that same cache instead.
    _cache_tag = "48_256_256_fapar"
    fill_cache(_cache_tag, bands_s2=["s2_fapar"])
    # clean_cache(_cache_tag)
