"""Main pre-processing pipeline."""

from __future__ import annotations

import warnings
from functools import partial
from math import ceil
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from multiprocess import Pool, cpu_count
from sklearn.model_selection import GroupShuffleSplit
from tqdm import tqdm

from vito_lot_delineation.data_prepr.utils import (
    clean_instance,
    compute_extent,
    compute_ndvi,
    fill_and_interpolate,
    list_npz,
    rerange,
    scale,
    transform,
)

warnings.filterwarnings("ignore")


def split_into_dfs(
    paths: list[Path],
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Build one table of patches and split it into train, val and test frames.

    Each file name (text before the first ".") is split on "_" into the
    columns ``country, xmin, xmax, ymin, ymax``. The ``year`` column is taken
    from the parent directory name: the part before the first "-", then the
    last "_"-separated token of that part.

    Every patch also receives a ``key`` made of the country code and the number
    of grid lines lying strictly below its third and fifth file-name fields
    (the ``xmax`` and ``ymax`` columns). Grid lines sit every ``128 * 3`` units
    from 0 up to, but not including, ``128 * 1000``. The key is used as the
    group label for both splits, so patches sharing a key always land in the
    same frame.

    Args:
        paths: Paths of the raw patch files.

    Returns:
        ``(df_train, df_val, df_test)``. The test frame is held out from the
        full table with ``test_size=0.1``; the validation frame is then held
        out from the remainder with the same settings.
    """
    # grid lines at multiples of 128 * 3, shared by both axes
    grid = np.array(range(0, 128 * 1000, 128 * 3))

    # one record per file: name fields, then year, key and the path itself
    records = []
    for path in paths:
        fields = path.name.split(".")[0].split("_")
        cell_x = (grid < int(fields[2])).sum()
        cell_y = (grid < int(fields[4])).sum()
        year = path.parent.name.split("-")[0].split("_")[-1]
        records.append([*fields, year, f"{fields[0]}_{cell_x}_{cell_y}", path])

    # preparing df
    df = pd.DataFrame(
        records,
        columns=["country", "xmin", "xmax", "ymin", "ymax", "year", "key", "path"],
    )

    def hold_out(frame: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Return (kept, held-out) rows of ``frame``, grouped by ``key``."""
        splitter = GroupShuffleSplit(n_splits=2, test_size=0.1, random_state=0)
        kept_ix, held_ix = next(splitter.split(frame, groups=frame.key))
        return frame.iloc[kept_ix], frame.iloc[held_ix]

    # first carve out the test set, then the validation set from what is left
    df_rest, df_test = hold_out(df)
    df_train, df_val = hold_out(df_rest)

    return df_train, df_val, df_test


def preprocess_data(npz: np.ndarray, save_dir: Path) -> None:
    """Run the band pipeline on one patch and write the result to disk.

    Args:
        npz: Container of the raw patch arrays. It is iterated for its keys and
            indexed by key, and must hold a ``parcelids`` entry.
            NOTE(review): the annotation says ``np.ndarray`` but the caller
            passes the object returned by ``np.load`` - confirm and widen.
        save_dir: Target handed to ``np.savez`` as the output file. Despite
            the name, the caller passes a file path (directory / patch name).

    The written archive holds ``instance`` (cleaned parcel ids), ``extent``
    and one entry per processed band.
    """
    # every entry except the two listed here is treated as a band
    excluded = ["parcelids", "CT"]
    bands = transform({key: npz[key] for key in npz if key not in excluded})

    # remaining band steps, applied in this fixed order
    for step in (compute_ndvi, scale, fill_and_interpolate, rerange):
        bands = step(bands)

    # replace any NaN / inf left over after the steps above
    bands = {key: np.nan_to_num(bands[key]) for key in bands}

    # clean the parcel-id raster, then derive the extent from the cleaned ids
    parcel_ids = torch.Tensor(npz["parcelids"].astype(np.int32))
    instance = clean_instance(parcel_ids, min_size=10)
    extent = compute_extent(instance, min_size=10)

    # save numpy
    np.savez(
        save_dir,
        instance=instance.numpy(),
        extent=extent.numpy(),
        **bands,
    )


def preprocess_chunk(chunk: pd.DataFrame, save_dir: Path) -> None:
    """Preprocess every patch listed in ``chunk`` and save it under ``save_dir``.

    Args:
        chunk: One row per patch. The last column is read as the path of the
            raw file; all columns except the last two are joined with "_" to
            form the output file name, so those columns must hold strings.
        save_dir: Directory the processed patches are written to.

    Files that cannot be opened are reported on stdout and skipped. Errors
    raised while processing an opened file are not caught.
    """
    for _, row in tqdm(chunk.iterrows(), total=len(chunk)):
        # Use explicit positional access: integer keys on a label-indexed
        # Series (``row[-1]``) are deprecated in pandas and will be treated as
        # labels in a future version.
        path = row.iloc[-1]
        name = "_".join(row.iloc[:-2])

        # read in data
        # NOTE(review): allow_pickle=True lets a crafted file run arbitrary
        # code on load; keep it only while the raw data is from a trusted source.
        try:
            npz = np.load(path, allow_pickle=True)
        except Exception:
            # Exception (not a bare except) so Ctrl-C and SystemExit still
            # stop the worker instead of being reported as a corrupted file.
            print("Found corrupted file: ", path)
            continue

        try:
            preprocess_data(npz, save_dir / name)
        finally:
            # An opened .npz archive keeps a file handle until it is closed;
            # release it per patch rather than waiting for garbage collection.
            if hasattr(npz, "close"):
                npz.close()


def _chunk_dataframe(df: pd.DataFrame, n_chunks: int) -> list[pd.DataFrame]:
    """Split ``df`` row-wise into at most ``n_chunks`` consecutive pieces.

    Args:
        df: Frame to split; row order is preserved.
        n_chunks: Upper bound on the number of pieces, at least 1.

    Returns:
        Consecutive slices of ``df`` of ``ceil(len(df) / n_chunks)`` rows each
        (the last one may be shorter). An empty frame yields an empty list.

    Raises:
        ValueError: If ``n_chunks`` is smaller than 1.
    """
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be at least 1, got {n_chunks}")
    if len(df) == 0:
        # a chunk size of 0 would be an invalid range() step
        return []
    chunk_size = ceil(len(df) / n_chunks)
    # iloc slicing clamps at the end of the frame, no manual bound needed
    return [df.iloc[start : start + chunk_size] for start in range(0, len(df), chunk_size)]


# TODO
#  - Add quantile scalers
if __name__ == "__main__":
    print(" - Reading data..")
    paths = []
    for raw_dir in Path("data/data_raw").glob("E*"):
        paths += list_npz(raw_dir)

    print(" - Splitting data into train, test, val..")
    df_train, df_val, df_test = split_into_dfs(paths)
    print("   - Obtained 3 datasets:")
    print(f"     - train : {len(df_train)}")
    print(f"     -   val : {len(df_val)}")
    print(f"     -  test : {len(df_test)}")

    # Leave two cores free, but never go below one worker: on machines with
    # two or fewer cores ``cpu_count() - 2`` is zero or negative.
    n_workers = max(1, cpu_count() - 2)

    for split, df in zip(
        ["training", "validation", "testing"],
        [df_train, df_val, df_test],
    ):
        print(f" - Processing {split}..")

        # create save directory
        save_dir = Path(f"data/data_new/{split}")
        save_dir.mkdir(parents=True, exist_ok=True)

        # split data into chunks, one per worker at most
        chunks = _chunk_dataframe(df, n_workers)
        if not chunks:
            # nothing to process for this split
            continue

        worker = partial(preprocess_chunk, save_dir=save_dir)
        with Pool(n_workers) as pool:
            # list() drains the lazy imap iterator so every chunk is processed
            _ = list(pool.imap(worker, chunks))
