"""Data splits specific operations."""

from __future__ import annotations

import json
from pathlib import Path
from shutil import copy

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
from tqdm import tqdm


def _holdout_fold(
    items: np.ndarray, labels: np.ndarray, groups: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(kept_ids, held_out_ids)`` from the first fold of a StratifiedGroupKFold.

    The number of folds is ``max(len(items) // 100, 10)`` and the shuffle is
    seeded with 42, so for a fixed input order the result is deterministic.
    Samples sharing a ``groups`` value never straddle the two index sets.
    """
    splitter = StratifiedGroupKFold(
        n_splits=max(len(items) // 100, 10), random_state=42, shuffle=True
    )
    return next(splitter.split(items, labels, groups))


def create_splits(
    data_f: Path,
    create_copy: bool = True,
) -> dict[str, list[str]]:
    """Create training/validation/testing splits from a given folder.

    Tiles are stratified on their region (the file-name prefix before the
    first ``_``). They are grouped on region plus coordinates, taken from the
    next underscore-separated fields of the file name, so tiles of the same
    group always land in the same split. One fold is held out as the test
    set, then one fold of the remainder as the validation set. The files are
    copied (or moved) into ``data_f.parent / <split>``. The resulting mapping
    is also written to ``data_f.parent / "splits.json"``.

    Parameters
    ----------
    data_f : Path
        Path to the folder containing the raw data to split (searched
        recursively for ``*.npz`` files).
    create_copy : bool, optional
        Create a copy of the data rather than moving it, by default True

    Returns
    -------
    dict[str, list[str]]
        Mapping from split name ("training", "validation", "testing") to the
        tile names (file name up to the first ".") placed in that split.

    Raises
    ------
    ValueError
        If no ``.npz`` file is found under ``data_f``.
    """
    # Sort: glob order is filesystem-dependent, and the seeded shuffle is only
    # reproducible if the input order is too.
    npzs = np.array(sorted(data_f.glob("**/*.npz")))
    if npzs.size == 0:
        raise ValueError(f"No .npz files found under {data_f}")
    regions = np.array([p.name.split("_")[0] for p in npzs])
    coords = np.array(["_".join(p.name.split("_")[1:5]) for p in npzs])
    groups = np.array([f"{r}_{c}" for r, c in zip(regions, coords)])
    print(f"Found {len(npzs)} tiles")

    # get train and test splits
    train_ids, test_ids = _holdout_fold(npzs, regions, groups)
    test_npzs = npzs[test_ids]
    train_npzs = npzs[train_ids]

    # get train and validation splits (from the non-test tiles only)
    fit_ids, val_ids = _holdout_fold(
        train_npzs, regions[train_ids], groups[train_ids]
    )
    val_npzs = train_npzs[val_ids]
    train_npzs = train_npzs[fit_ids]

    # generate splits and create folders
    print("Generating splits...")
    print(f" - Training: {len(train_npzs)}")
    print(f" - Validation: {len(val_npzs)}")
    print(f" - Testing: {len(test_npzs)}")

    parent_f = data_f.parent
    splits: dict[str, list[str]] = {"training": [], "validation": [], "testing": []}
    for split, paths in (
        ("training", train_npzs),
        ("validation", val_npzs),
        ("testing", test_npzs),
    ):
        split_f = parent_f / split
        split_f.mkdir(exist_ok=True, parents=True)
        for p in tqdm(paths, desc=split):
            splits[split].append(p.name.split(".")[0])
            if create_copy:
                copy(p, split_f / p.name)
            else:
                p.rename(split_f / p.name)
    print("Done!\n")

    # save splits file as json
    with open(parent_f / "splits.json", "w", encoding="utf-8") as fh:
        json.dump(splits, fh, indent=4)

    return splits


def read_split_cfg(
    data_f: Path, cfg_name: str = "splits.json"
) -> dict[str, list[str]]:
    """Read the splits configuration file.

    Parameters
    ----------
    data_f : Path
        Folder containing the configuration file.
    cfg_name : str, optional
        Name of the JSON file inside ``data_f``, by default "splits.json".

    Returns
    -------
    dict[str, list[str]]
        The parsed JSON: split name mapped to the list of tile names.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    return json.loads((data_f / cfg_name).read_text(encoding="utf-8"))


def extract_split_cfg(data_f: Path, cfg_name: str = "splits.json") -> None:
    """Rebuild the splits configuration file from existing split folders.

    Every ``.npz`` file found recursively under ``data_f`` whose immediate
    parent folder is named ``training``, ``validation`` or ``testing`` is
    recorded under that split by its file name up to the first ".". Files in
    any other folder are ignored.

    Parameters
    ----------
    data_f : Path
        Root folder to search; the configuration file is written here too.
    cfg_name : str, optional
        Name of the JSON file written inside ``data_f``, by default
        "splits.json".
    """
    config: dict[str, list[str]] = {"training": [], "validation": [], "testing": []}
    # Sort so the written file does not depend on filesystem listing order.
    for npz in sorted(data_f.glob("**/*.npz")):
        names = config.get(npz.parent.name)
        if names is not None:
            names.append(npz.name.split(".")[0])

    # save splits
    with open(data_f / cfg_name, "w", encoding="utf-8") as fh:
        json.dump(config, fh, indent=4)


if __name__ == "__main__":
    # Regenerate the splits file from the on-disk split folders, then read it
    # back to check that it parses.
    root_dir = Path("data/data")
    extract_split_cfg(root_dir)
    read_split_cfg(root_dir)
