"""Utils functions for encoders."""

from __future__ import annotations

from typing import Callable

import pandas as pd
import torch

from vito_crop_classification.utils import sc_to_tensor, ts_to_tensor


def extract_ts(
    df: pd.DataFrame,
    cols: list[str] | None = None,
) -> torch.FloatTensor:
    """Extract the requested time-series columns as a tensor.

    Args:
        df: Frame holding the columns to extract.
        cols: Columns to extract. When None, every column whose label is a
            string starting with "ts_" is used.

    Returns:
        The result of `ts_to_tensor` applied to the selected columns.
    """
    if cols is None:
        # Column labels are not guaranteed to be strings (e.g. integer labels);
        # such labels can never carry the prefix, so skip them instead of
        # failing while slicing them.
        cols = [c for c in df.columns if isinstance(c, str) and c.startswith("ts_")]
    return ts_to_tensor(df=df, cols=cols)


def extract_sc(df: pd.DataFrame, cols: list[str] | None = None) -> torch.FloatTensor:
    """Extract the requested sc columns as a tensor.

    Args:
        df: Frame holding the columns to extract.
        cols: Columns to extract. When None, every column whose label is a
            string starting with "sc_" is used.

    Returns:
        The result of `sc_to_tensor` applied to the selected columns.
    """
    if cols is None:
        # Column labels are not guaranteed to be strings (e.g. integer labels);
        # such labels can never carry the prefix, so skip them instead of
        # failing while slicing them.
        cols = [c for c in df.columns if isinstance(c, str) and c.startswith("sc_")]
    return sc_to_tensor(df=df, cols=cols)


def get_extract_function(name: str) -> Callable[[pd.DataFrame], torch.FloatTensor]:
    """Return the extract function of this module registered under `name`.

    Only module-level callables whose name starts with "extract_" are eligible,
    so imported modules, dunder globals and unrelated helpers are never returned.

    Args:
        name: Name of the extract function, e.g. "extract_ts".

    Returns:
        The matching extract function.

    Raises:
        KeyError: If `name` does not refer to an extract function of this module.
    """
    # Resolved at call time so extract functions defined later in the module
    # are picked up as well.
    available = {
        key: value
        for key, value in globals().items()
        if key.startswith("extract_") and callable(value)
    }
    try:
        return available[name]
    except KeyError:
        # Same exception type as a plain globals() lookup, but with an
        # actionable message listing the valid names.
        raise KeyError(
            f"Unknown extract function {name!r}; available: {sorted(available)}"
        ) from None


if __name__ == "__main__":
    # Manual smoke check: load a small slice of the data and print the shape of
    # the tensor produced by each extract helper.
    from vito_crop_classification.data import load_datasets

    result = load_datasets(dataset_ratio=0.01)
    my_df = result["df_train"]

    # extract_ts exposes no stacking option, so a single RGB extraction is shown
    # (the former "stacking False"/"stacking True" runs were identical calls).
    print("\nExtracting RGB")
    print(extract_ts(my_df, cols=["ts_R", "ts_G", "ts_B"]).shape)

    print("\nExtracting slope")
    print(extract_sc(my_df, cols=["sc_slope"]).shape)

    print("\nExtracting altitude and slope")
    print(extract_sc(my_df, cols=["sc_altitude", "sc_slope"]).shape)
