"""Data loaders."""

from __future__ import annotations

from typing import Any

import pandas as pd
from sklearn.model_selection import train_test_split

from vito_crop_classification.constants import get_data_folder
from vito_crop_classification.data.process import transform_df
from vito_crop_classification.vito_logger import bh_logger


def load_data(data_f: str = "hybrid_dataset", data_tag: str = "df.parquet") -> pd.DataFrame:
    """
    Read a parquet file from the project's data folder.

    The file is looked up at ``<data folder>/<data_f>/<data_tag>``, where the data folder is whatever
    ``get_data_folder()`` returns.

    Parameters
    ----------
    data_f : str
        Name of the subfolder (inside the data folder) that holds the file
    data_tag : str
        File name of the parquet file to read

    Returns
    -------
    pd.DataFrame
        Content of the parquet file
    """
    parquet_path = get_data_folder() / data_f / data_tag
    return pd.read_parquet(parquet_path)


def load_datasets(
    dataset: str | None = None,
    df: pd.DataFrame | None = None,
    r_test: float | None = 0.1,
    r_val: float | None = 0.1,
    dataset_ratio: float = 1.0,
    random_state: int | None = 42,
    scale_cfg: dict[str, tuple[float, float]] | None = None,
) -> dict[str, Any]:
    """
    Load, transform, and split a dataset into train, validation, and test parts.

    Exactly one of ``dataset`` and ``df`` must be provided. The data is optionally subsampled, passed
    through ``transform_df``, and then split per target class with ``stratified_split``.

    Parameters
    ----------
    dataset : str, optional
        Dataset, specified by its tag, to load from the data folder
    df : pd.DataFrame, optional
        Raw dataframe to transform, used instead of loading ``dataset``
    r_test : float, optional
        Ratio of the test-set, no test-set is created if None
    r_val : float, optional
        Ratio of the validation-set, no validation-set is created if None. The ratio is applied to
        what remains of the training data after the test-set has been cut off
    dataset_ratio : float
        Ratio of complete dataset to be used (use only for experimentation), rows are only sampled
        when the ratio is below 1.0
    random_state : int, optional
        Optional random state for reproducibility, used for both the subsampling and the splits
    scale_cfg : dict[str, tuple[float, float]], optional
        Configuration used to scale the values, forwarded to ``transform_df``

    Returns
    -------
    results : dict[str, Any]
        Results including
         - df_train : pd.DataFrame
            Training DataFrame
         - df_val : pd.DataFrame, optional
            Validation DataFrame, only present if ``r_val`` is not None
         - df_test : pd.DataFrame, optional
            Testing DataFrame, only present if ``r_test`` is not None
         - scale_cfg : dict[str, tuple[float, float]]
            Scaling configuration as returned by ``transform_df``

    Raises
    ------
    ValueError
        If both or neither of ``dataset`` and ``df`` are provided
    """
    # Validate explicitly instead of with `assert`, which is stripped when Python runs with -O
    if (dataset is None) == (df is None):
        raise ValueError("Exactly one of 'dataset' and 'df' must be provided.")

    # Loading and transforming data
    if dataset is not None:
        df = load_data(data_f=dataset)
    if dataset_ratio < 1.0:
        df = df.sample(frac=dataset_ratio, random_state=random_state)
    df, scale_cfg = transform_df(df, scale_cfg=scale_cfg)
    results: dict[str, Any] = {
        "df_train": df,
        "scale_cfg": scale_cfg,
    }

    # Split off the test-dataset, if r_test is provided
    bh_logger(f"Splitting in train, test ({r_test}), and validation ({r_val})..")
    if r_test is not None:
        bh_logger(" - Splitting test data..")
        results["df_train"], results["df_test"] = stratified_split(
            df=results["df_train"],
            ratio=r_test,
            random_state=random_state,
        )

    # Split off the validation-dataset from the remaining training data, if r_val is provided
    if r_val is not None:
        bh_logger(" - Splitting validation data..")
        results["df_train"], results["df_val"] = stratified_split(
            df=results["df_train"],
            ratio=r_val,
            random_state=random_state,
        )

    # Report the split sizes; splits that were not requested are reported as empty
    bh_logger(" - Loaded in datasets of size:")
    df_train = results["df_train"]
    bh_logger(f"   -   Training: {len(df_train)} samples ({len(set(df_train.target_id))} classes)")
    df_val = results.get("df_val", pd.DataFrame([], columns=df.columns))
    bh_logger(f"   - Validation: {len(df_val)} samples ({len(set(df_val.target_id))} classes)")
    df_test = results.get("df_test", pd.DataFrame([], columns=df.columns))
    bh_logger(f"   -    Testing: {len(df_test)} samples ({len(set(df_test.target_id))} classes)\n")
    return results


def stratified_split(
    df: pd.DataFrame,
    ratio: float,
    random_state: int | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Manual implementation of the stratified split.

    The rows of every ``target_id`` value are split separately with ``train_test_split`` and the
    per-target parts are concatenated again. This implementation enforces at least one sample of
    each target in each split: if a target cannot be split (e.g. it has only one sample), its
    samples are duplicated over both splits and a message is logged.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame to split, must contain a ``target_id`` column
    ratio : float
        Ratio of each target's samples that goes to the second split (passed as ``test_size``)
    random_state : int, optional
        Optional random state for reproducibility

    Returns
    -------
    tuple[pd.DataFrame, pd.DataFrame]
        The first (remaining) split and the second (``ratio``-sized) split
    """
    # Nothing to split, and pd.concat raises on an empty list of parts
    if df.empty:
        return df.copy(), df.copy()

    split1, split2 = [], []
    # unique() keeps the order of first appearance, so the row order of the result is reproducible
    for target in df["target_id"].unique():
        df_target = df[df["target_id"] == target]
        if len(df_target) < 2:
            # A single sample cannot be divided; train_test_split raises a ValueError here because
            # one side would end up empty, so share the sample between both splits instead
            df1 = df2 = df_target
            could_not_split = True
        else:
            df1, df2 = train_test_split(
                df_target,
                test_size=ratio,
                random_state=random_state,
            )
            # Defensive fallback, in case train_test_split hands back an empty side instead of raising
            could_not_split = df1.shape[0] == 0 or df2.shape[0] == 0
            if df1.shape[0] == 0:
                df1 = df2
            elif df2.shape[0] == 0:
                df2 = df1
        if could_not_split:
            bh_logger(
                f"   - Not possible to split over target '{target}', "
                "duplicating sample over both train and test splits!"
            )
        split1.append(df1)
        split2.append(df2)
    return pd.concat(split1, axis=0), pd.concat(split2, axis=0)


def load_time_folds() -> list[pd.DataFrame]:
    """
    Load in as many datasets as there are years in the default dataset.

    The year of a row is the part of its ``ref_id`` that precedes the first underscore. It is
    stored in an extra ``year`` column, which is also present in the returned folds.

    Returns
    -------
    list[pd.DataFrame]
        List of K time-varying dataframes, one per year and ordered by year, K depends on number
        of years
    """
    df = load_data()
    df["year"] = [ref_id.split("_")[0] for ref_id in df["ref_id"].to_list()]
    # Sort the years: iterating a set of strings gives an order that can differ between runs
    years = sorted(set(df["year"]))
    results = [df[df["year"] == year] for year in years]
    bh_logger(f"Loaded in {len(results)} time folds!")
    return results


if __name__ == "__main__":
    # load_datasets needs exactly one of `dataset` / `df`; use the default subfolder of load_data
    my_result = load_datasets(dataset="hybrid_dataset", dataset_ratio=0.01)
    print("Extracted configuration:")
    for k, v in sorted(my_result.get("scale_cfg", {}).items()):
        print(f" - {k}: [{v[0]:.5f}, {v[1]:.5f}]")
