"""Data loaders."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

from vito_crop_classification.constants import get_data_folder
from vito_crop_classification.data_io.outliers import detect_and_prune_outliers
from vito_crop_classification.data_process import transform_df
from vito_crop_classification.vito_logger import bh_logger


def load_data(data_f: str = "hybrid_dataset", data_tag: str = "df.parquet") -> pd.DataFrame:
    """
    Read a parquet-stored DataFrame from the data folder.

    Parameters
    ----------
    data_f : str
        Name of the subfolder, inside the data folder, that holds the file
    data_tag : str
        File name of the parquet file to read

    Returns
    -------
    pd.DataFrame
        DataFrame read from ``<data folder>/<data_f>/<data_tag>``
    """
    parquet_path = get_data_folder() / data_f / data_tag
    return pd.read_parquet(parquet_path)


def load_feature_settings(
    data_f: str = "hybrid_dataset", settings_name: str = "Feature_settings.json"
) -> dict[str, Any] | None:
    """
    Load feature settings from data folder.

    Parameters
    ----------
    data_f : str
        Subfolder to look for the requested data
    settings_name : str
        Name of the settings file to load

    Returns
    -------
    dict[str, Any] | None
        Parsed content of the settings json file, or None if the file does not exist
    """
    path = get_data_folder() / data_f / settings_name
    if not path.exists():
        # A missing settings file is not an error: callers receive None instead
        return None
    # Decode explicitly as UTF-8 instead of relying on the platform's default encoding
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_datasets(
    dataset: str | None = None,
    df: pd.DataFrame | None = None,
    r_test: float | None = 0.1,
    r_val: float | None = 0.1,
    dataset_ratio: float = 1.0,
    random_state: int | None = 42,
    scale_cfg: dict[str, tuple[float, float]] | None = None,
    outliers_perc: float = 0.0,
    drop_classes: list[str] = ["0-0-0-0", "1-6-0-0"],
) -> dict[str, Any]:
    """
    Load in a train and test split of the dataset.

    Parameters
    ----------
    dataset : str, optional
        Dataset, specified by its tag, to load
    df : pd.DataFrame, optional
        Raw dataframe to transform
    r_test : float, optional
        Optional ratio of the test-set
    r_val : float, optional
        Optional ratio of the validation-set (cut off after test-cut), not created if None
    dataset_ratio : float, optional
        Ratio of complete dataset to be used (use only for experimentation)
    random_state : int, optional
        Optional random state for reproducibility
    scale_cfg : dict[str, tuple[float, float]], optional
        Configuration used to scale the values
    outliers_perc : float, optional
        Percentage of outliers to prune from training set
    drop_classes : list[str], optional
        Drop certain classes from dataset

    Returns
    -------
    results : dict[str, Any]
        Results including
         - df_train : pd.DataFrame
            Training DataFrame
         - df_val : pd.DataFrame, optional
            Validation DataFrame, if requested
         - df_test : pd.DataFrame, optional
            Testing DataFrame, if requested
         - scale_cfg : dict[str, tuple[float, float]]
            Scaling configuration applied on the transformed DataFrames
    """
    # Loading and transforming data
    assert (dataset is None) != (df is None)
    df = df if (dataset is None) else load_data(data_f=dataset)  # Ignored, assert checks this
    if dataset_ratio < 1.0:
        df = df.sample(frac=dataset_ratio, random_state=random_state)
    df, scale_cfg = transform_df(df, scale_cfg=scale_cfg, drop_classes=drop_classes)

    # creating results
    results = {
        "df_train": df,
        "scale_cfg": scale_cfg,
        "feat_cfg": load_feature_settings(data_f=dataset) if type(dataset) == Path else None,
    }

    # Start dataset splits
    bh_logger(f"Splitting in train, test ({r_test}), and validation ({r_val})..")
    if r_test is not None:
        bh_logger(" - Splitting test data..")
        results["df_train"], results["df_test"] = _stratified_split(
            df=results["df_train"],
            ratio=r_test,
            random_state=random_state,
        )

    # Split the validation-dataset, if r_val is provided
    if r_val is not None:
        bh_logger(" - Splitting validation data..")
        results["df_train"], results["df_val"] = _stratified_split(
            df=results["df_train"],
            ratio=r_val,
            random_state=random_state,
        )

    # report splits dimentions
    if r_val or r_test:
        _print_splits(results, message="Computed dataset splits")

    # run outlier detection and pruning
    if outliers_perc > 0.0:
        assert outliers_perc <= 1.0
        bh_logger(f" - Removing top {outliers_perc}% outliers from training split..")
        results["df_train"] = detect_and_prune_outliers(results["df_train"], outliers_perc)
        _print_splits(results, "New splits")

    return results


def load_datasets_from_cfg(
    data_f: str, data_cfg: str
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Rebuild the dataset splits described by a data configuration saved on disk.

    Parameters
    ----------
    data_f : str
        Subfolder of the data folder holding both the dataset and the configuration file
    data_cfg : str
        Tag of the configuration to load (see load_splits_cfg)

    Returns
    -------
    results : dict[str, Any]
        Results including
         - df_train : pd.DataFrame
            Rows whose key_ts is listed under "train" in the configuration
         - df_val : pd.DataFrame
            Rows whose key_ts is listed under "val" in the configuration
         - df_test : pd.DataFrame
            Rows whose key_ts is listed under "test" in the configuration
         - scale_cfg : dict[str, tuple[float, float]]
            Scaling configuration returned by the transformation
    splits_cfg : dict[str, Any]
        The loaded configuration itself
    """
    # Keep the parsed configuration under its own name instead of rebinding the str parameter
    splits_cfg = load_splits_cfg(data_f, data_cfg)

    data = load_data(data_f=data_f)
    # No classes are dropped here; the scaling stored in the configuration is reused
    df, scale_cfg = transform_df(data, drop_classes=[], scale_cfg=splits_cfg["scale_cfg"])

    bh_logger("Extracting splits..")
    # Every split is selected from the full transformed frame, so their order does not matter
    results = {
        "df_train": df[df.key_ts.isin(splits_cfg["train"])],
        "scale_cfg": scale_cfg,
        "df_test": df[df.key_ts.isin(splits_cfg["test"])],
        "df_val": df[df.key_ts.isin(splits_cfg["val"])],
    }
    _print_splits(results, message="Computed dataset splits")

    return results, splits_cfg


def load_splits_cfg(data_f: str, data_cfg: str) -> dict[str, list[int]]:
    """
    Load a splits configuration from disk.

    Parameters
    ----------
    data_f : str
        Subfolder of the data folder holding the configuration file
    data_cfg : str
        Tag of the configuration, the file read is ``cfg_<data_cfg>.json``

    Returns
    -------
    dict[str, list[int]]
        Parsed content of the configuration file
    """
    cfg_path = get_data_folder() / data_f / f"cfg_{data_cfg}.json"
    # Decode explicitly as UTF-8 instead of relying on the platform's default encoding
    with open(cfg_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_time_folds() -> list[pd.DataFrame]:
    """
    Load in as many datasets as there are years in the dataset.

    The year of a row is the part of its ``ref_id`` that precedes the first underscore.

    Returns
    -------
    list[pd.DataFrame]
        List of K time-varying dataframes, one per year and ordered by year,
        K depends on number of years
    """
    df = load_data()
    df["year"] = [ref_id.split("_")[0] for ref_id in df["ref_id"].to_list()]
    # Sort the years: iterating a set of strings directly gives an order that changes between
    # interpreter runs, which would make the order of the folds non-reproducible
    years = sorted(set(df["year"]))
    results = [df[df["year"] == year] for year in years]
    bh_logger(f"Loaded in {len(results)} time folds!")
    return results


def _stratified_split(
    df: pd.DataFrame,
    ratio: float,
    random_state: int | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Manual implementation of the stratified split.

    This implementation does not allow classes to not appear in all splits
    """
    split1, split2 = [], []
    for target in set(df.target_id):
        gss = GroupShuffleSplit(
            n_splits=1,
            test_size=ratio,
            random_state=random_state,
        )
        df1_idx, df2_idx = next(
            gss.split(
                X=df[df.target_id == target],
                groups=df[df.target_id == target]["key_field"],
            )
        )
        df1 = df[df.target_id == target].iloc[df1_idx]
        df2 = df[df.target_id == target].iloc[df2_idx]
        split1.append(df1)
        split2.append(df2)
    return pd.concat(split1, axis=0), pd.concat(split2, axis=0)


def _print_splits(results: dict[str, Any], message: str) -> None:
    """Log the number of samples and of distinct classes in each split of the results."""

    def _dimensions(key: str) -> tuple[int, int]:
        # A split missing from the results is reported as an empty one
        split = results.get(key, pd.DataFrame([], columns=["target_id"]))
        return len(split), len(set(split.target_id))

    bh_logger(f" - {message}:")
    n_train, c_train = _dimensions("df_train")
    bh_logger(f"   -   Training: {n_train} samples ({c_train} classes)")
    n_val, c_val = _dimensions("df_val")
    bh_logger(f"   - Validation: {n_val} samples ({c_val} classes)")
    n_test, c_test = _dimensions("df_test")
    bh_logger(f"   -    Testing: {n_test} samples ({c_test} classes)\n")


if __name__ == "__main__":
    # Manual smoke run: rebuild the splits from a stored configuration, then compute fresh
    # splits on a 20% sample and show the resulting scaling configuration
    cfg_result = load_datasets_from_cfg("hybrid-18ts", "test")
    fresh_result = load_datasets(dataset="hybrid-18ts", dataset_ratio=0.2)
    print("Extracted configuration:")
    extracted_cfg = fresh_result.get("scale_cfg", {})
    for name in sorted(extracted_cfg):
        bounds = extracted_cfg[name]
        print(f" - {name}: [{bounds[0]:.5f}, {bounds[1]:.5f}]")
