"""Transform the provided class into a new standard."""

from __future__ import annotations

from typing import Callable

import numpy as np
import pandas as pd


def main(
    df: pd.DataFrame,
    process_f: Callable[..., pd.DataFrame],
    col_output: str = "LABEL",
    col_lc: str = "LC",
    continents: list[str] | None = None,
) -> pd.DataFrame:
    """
    Process the provided DataFrame.

    Pipeline:
     1. Drop every column whose name contains ``"ndvi"`` (case-sensitive).
     2. Apply the dataset-specific transformation ``process_f``.
     3. Drop the rows whose ``target_id`` equals ``"0-0-0-0"``.
     4. Renumber the rows as ``0 .. len(df) - 1``.

    The input DataFrame is left untouched; a new DataFrame is returned.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame that needs to get transformed
    process_f : Callable[..., pd.DataFrame]
        Dataset-specific transformation that converts the original class into a
        unified class label. It is called with the keyword arguments ``df``,
        ``continents``, ``col_output`` and ``col_lc`` and must return a
        DataFrame containing a ``target_id`` column.
    col_output : str
        Column name for the output; only forwarded to ``process_f``
    col_lc : str
        Column name for the LC; only forwarded to ``process_f``
    continents : list[str], optional
        Continents to keep; only forwarded to ``process_f``
        Options: 'europe', 'north_america', 'south_america', 'africa', 'asia', 'australia'

    Returns
    -------
    df : pd.DataFrame
        A transformed dataframe with a fresh ``0 .. n-1`` integer index. Its
        columns are whatever ``process_f`` produces (always including
        ``target_id``), e.g. (depends on ``process_f``):
         - Input:  df[col0, ..., colN]
         - Output: df[col0, ..., colN, target_id, target_name]

    Raises
    ------
    KeyError
        Re-raised unchanged if ``process_f`` raises one, after printing all
        column names of the DataFrame that was passed to it. Also raised if the
        result of ``process_f`` has no ``target_id`` column.
    """
    print(f"Processing hybrid DataFrame of shape {df.shape}:")

    # Remove NDVI, if present. ``drop`` without ``inplace`` returns a new frame,
    # so neither the columns nor the index of the caller's DataFrame change.
    print(" - Removing NDVI column..")
    ndvi_cols = [c for c in df.columns if "ndvi" in c]
    df = df.drop(columns=ndvi_cols)
    df.index = np.arange(len(df))
    print(f"   - Done! {df.shape}")

    # Transform specific to the dataset
    try:
        df = process_f(
            df=df,
            continents=continents,
            col_output=col_output,
            col_lc=col_lc,
        )
    except KeyError:
        # Presumably process_f looked up a column that is absent; list what exists
        print("All DataFrame columns:")
        print(df.columns.to_list())
        raise  # bare raise keeps the original exception and traceback

    # Drop the 0-0-0-0 class with a positional boolean mask instead of by index
    # label: process_f may return duplicated labels, and dropping by label would
    # then also discard valid rows that share a label with a 0-0-0-0 row.
    print(" - Removing classes not in the current scenario..")
    keep = (df["target_id"] != "0-0-0-0").to_numpy()
    df = df.loc[keep]
    print(f"   - Done! {df.shape}")

    # Reindex DataFrame again
    df.index = np.arange(len(df))
    print(f" - Final DataFrame is of shape {df.shape}")
    return df


if __name__ == "__main__":
    from vito_crop_classification.constants import get_data_folder
    from vito_crop_classification.data_format.dataset_hybrid import process_hybrid

    # Read the raw hybrid export
    frame = pd.read_parquet(get_data_folder() / "hybrid_patch1/raw.parquet")

    # Run the generic pipeline with the hybrid-specific transformation; the
    # continent filter is forwarded to process_hybrid
    frame = main(
        frame,
        process_hybrid,
        continents=["europe", "north_america"],
    )

    # Store the processed result next to the raw input
    frame.to_parquet(get_data_folder() / "hybrid_patch1/df.parquet")
