"""Main script to prepare and remove data folders."""

from __future__ import annotations

import json
import os
from functools import partial
from multiprocessing import Pool, cpu_count
from pathlib import Path
from shutil import rmtree
from time import perf_counter, time

from tqdm import tqdm

from vito_cropsar.constants import S1, S2, SPLITS, get_data_folder
from vito_cropsar.data.cache import push_to_cache
from vito_cropsar.data.faulty import check_faulty_sample
from vito_cropsar.data.io import load_sample
from vito_cropsar.data.masks import (
    clean_cloud_mask,
    extract_masks,
    inject_equidistant_mask,
)
from vito_cropsar.data.s3 import pull_from_s3, push_to_s3
from vito_cropsar.data.splits import extract_split_cfg, read_split_cfg
from vito_cropsar.vito_logger import bh_logger


def preprocess(
    cache_tag: str,
    data_f: Path | None = None,
    bands_s1: list[str] | None = None,
    bands_s2: list[str] | None = None,
    align: bool = True,
    smooth_s1: bool = True,
    overwrite_cache: bool = False,
    overwrite_masks: bool = False,
) -> None:
    """
    Preprocess dataset given a folder containing trainin, validation and testing splits.

    This function will:
     - Create fixed obscurence masks for testing and validation
     - Cache the results

    Parameters
    ----------
    cache_tag : str
        Cache postfix tag
    data_f : Path | None
        Path to folder containing training, testing and validation splits, by default None
    bands_s1 : list[str]
        The bands for S1 to load into cache, by default all
    bands_s2 : list[str]
        The bands for S2 to load into cache, by default all
    align: bool
        Whether to align the S1 and S2 images
    smooth_s1: bool
        Whether to smooth the S1 images (Speckle filter)
    overwrite_cache : bool
        Whether to overwrite previously cached files
    overwrite_masks : bool
        Whether to overwrite previously cached masks
    """
    # Get data folder
    data_f = get_data_folder() if data_f is None else data_f

    # Extract splits configuration
    print("Extracting split file..")
    try:
        split_cfg = read_split_cfg(data_f=data_f)
    except FileNotFoundError:
        print("Split configuration not found, creating it from the dataset splits..")
        extract_split_cfg(data_f=data_f)
        split_cfg = read_split_cfg(data_f=data_f)
    print("Done!\n")

    # Extract masks
    print("Extracting masks from training set..")
    extract_masks(data_f=data_f)
    print("Done!\n")

    # get available cores
    try:
        cpu_cores = len(os.sched_getaffinity(0))
    except AttributeError:
        cpu_cores = min(16, cpu_count() - 2)

    # Save the settings
    with open(data_f / f"settings_{cache_tag}.json", "w") as f:
        json.dump(
            {
                "cache_tag": cache_tag,
                "bands_s1": bands_s1 or S1,
                "bands_s2": bands_s2 or S2,
                "align": align,
                "smooth_s1": smooth_s1,
                "overwrite_cache": overwrite_cache,
                "overwrite_masks": overwrite_masks,
            },
            f,
            indent=2,
            sort_keys=True,
        )

    # iterate over samples
    bh_logger(f"Processing samples using {cpu_cores} processes..")
    datapoints = [(split, tile) for split in split_cfg for tile in split_cfg[split]]
    with Pool(cpu_cores) as pool:
        process_sample_ = partial(
            _process_sample,
            data_f=data_f,
            overwrite_cache=overwrite_cache,
            overwrite_masks=overwrite_masks,
            cache_tag=cache_tag,
            bands_s1=bands_s1,
            bands_s2=bands_s2,
            align=align,
            smooth_s1=smooth_s1,
        )
        _ = list(
            tqdm(
                pool.imap_unordered(process_sample_, datapoints),
                total=len(datapoints),
                desc="Processing samples..",
                smoothing=0.1,
            )
        )


def _process_sample(
    split_tile: tuple[str, str],
    data_f: Path,
    overwrite_cache: bool,
    overwrite_masks: bool,
    cache_tag: str | None = None,
    bands_s1: list[str] | None = None,
    bands_s2: list[str] | None = None,
    align: bool = True,
    smooth_s1: bool = True,
) -> None:
    """
    Load, validate, mask and cache a single (split, tile) sample.

    Missing tiles are logged and skipped, and samples flagged by ``check_faulty_sample``
    are skipped as well. Validation and testing samples additionally receive an
    equidistant mask before being pushed to the cache.

    Parameters
    ----------
    split_tile: tuple[str, str]
        The (split, tile) pair to process
    data_f : Path
        The data folder
    overwrite_cache : bool
        Whether to overwrite an existing cache entry
    overwrite_masks : bool
        Whether to overwrite existing injected masks
    cache_tag : str | None
        The cache tag, by default None
    bands_s1 : list[str] | None
        The bands for S1 to load, by default all
    bands_s2 : list[str] | None
        The bands for S2 to load, by default all
    align: bool
        Whether to align the S1 and S2 images
    smooth_s1: bool
        Whether to smooth the S1 images (Speckle filter)
    """
    split_name, tile_name = split_tile
    sample_path = data_f / split_name / f"{tile_name}.npz"

    try:
        loaded = load_sample(
            split=split_name,
            tile=tile_name,
            data_f=data_f,
            has_mask_=False,
            bands_s1=bands_s1,
            bands_s2=bands_s2,
            align=align,
            smooth_s1=smooth_s1,
            verbose=False,
        )
    except FileNotFoundError:
        bh_logger(f"Tile {tile_name} ({split_name}) not found!")
        return

    # Copy every entry into a plain, mutable dict that the cleaning steps can modify
    sample = {name: loaded[name] for name in loaded}

    # Faulty samples are dropped entirely
    is_faulty = check_faulty_sample(
        npz=sample,
        tile_path=sample_path,
        writer=tqdm.write,
    )
    if is_faulty:
        return

    clean_cloud_mask(sample)

    # Held-out splits get a fixed obscuration mask
    if split_name in ("validation", "testing"):
        inject_equidistant_mask(
            npz=sample,
            tile_path=sample_path,
            overwrite=overwrite_masks,
        )

    push_to_cache(
        npz=sample,
        split=split_name,
        tile=tile_name,
        tag=cache_tag,
        overwrite=overwrite_cache,
    )


def push_datasets(
    cache_tag: str | None = None,
    overwrite: bool = False,
    splits: list[str] | None = None,
) -> None:
    """
    Push the specified datasets to S3 and report how long it took.

    Parameters
    ----------
    cache_tag : str | None
        Cache postfix tag, raw datasets if None
    overwrite : bool
        Whether to overwrite previously pushed files (forwarded to ``push_to_s3``)
    splits : list[str] | None
        The splits to push, by default all
        Options: training, validation, testing
    """
    print("\nPushing data to S3..")
    # perf_counter is monotonic, unlike time(), so the reported duration cannot go negative
    start = perf_counter()
    push_to_s3(
        cache_tag=cache_tag,
        overwrite=overwrite,
        splits=splits,
    )
    print(f" --> Done! ({perf_counter() - start:.2f}s)")


def pull_datasets(
    cache_tag: str | None = None,
    overwrite: bool = False,
    splits: list[str] | None = None,
) -> None:
    """
    Pull the specified datasets from S3 and report how long it took.

    Parameters
    ----------
    cache_tag : str | None
        Cache postfix tag, raw datasets if None
    overwrite : bool
        Whether to overwrite previously pulled files (forwarded to ``pull_from_s3``)
    splits : list[str] | None
        The splits to pull, by default all
        Options: training, validation, testing
    """
    print("\nPulling data from S3..")
    # perf_counter is monotonic, unlike time(), so the reported duration cannot go negative
    start = perf_counter()
    pull_from_s3(
        cache_tag=cache_tag,
        overwrite=overwrite,
        splits=splits,
    )
    print(f" --> Done! ({perf_counter() - start:.2f}s)")


def clean_cache(data_f: Path | None = None) -> None:
    """
    Clean all the caches: remove every ``<split>_*`` folder in the data folder.

    Parameters
    ----------
    data_f : Path | None
        Folder containing the caches, by default the project data folder
    """
    data_f = get_data_folder() if data_f is None else Path(data_f)
    for split in SPLITS:
        # ``Path.glob`` already yields paths prefixed with ``data_f``, so they are used as-is
        # (re-joining them with the data folder doubles the prefix for relative folders).
        # Filter on is_dir() explicitly: a trailing "/" in the pattern only restricts the
        # matches to directories on Python >= 3.11.
        cache_folders = sorted(p for p in data_f.glob(f"{split}_*") if p.is_dir())
        if not cache_folders:
            print(f"No cache found for {split}! Skipping..")
            continue
        for folder in cache_folders:
            print(f"Removing {folder}..")
            start = perf_counter()
            rmtree(folder)
            print(f" --> Done! ({perf_counter() - start:.2f}s)")


if __name__ == "__main__":
    # Build the "fapar_rgb" cache: all S1 bands plus the S2 FAPAR and B02/B03/B04 bands,
    # without alignment, regenerating both the cache and the validation/testing masks.
    # The split configuration is created by `preprocess` itself when it is missing.
    preprocess(
        cache_tag="fapar_rgb",
        bands_s1=S1,
        bands_s2=["s2_fapar", "s2_b02", "s2_b03", "s2_b04"],
        align=False,
        smooth_s1=True,
        overwrite_cache=True,
        overwrite_masks=True,
    )
