"""Utils functions for encoders."""

from __future__ import annotations

from typing import Callable

import numpy as np
import pandas as pd
import torch

from vito_crop_classification.utils import sc_to_tensor, ts_to_tensor


def extract_ts(df: pd.DataFrame, cols: list[str] | None = None) -> torch.FloatTensor:
    """Extract ``ts_``-prefixed (time-series) columns of ``df`` as a tensor.

    :param df: DataFrame holding the feature columns.
    :param cols: Columns to extract; when ``None``, every column whose name starts with ``ts_``.
    :return: Tensor built by ``ts_to_tensor``, which decides how the selected columns are
        combined (stacked or concatenated).
    """
    selected = cols
    if selected is None:
        selected = [name for name in df.columns if name[:3] == "ts_"]
    return ts_to_tensor(df=df, cols=selected)


def extract_sc(df: pd.DataFrame, cols: list[str] | None = None) -> torch.FloatTensor:
    """Extract ``sc_``-prefixed feature columns of ``df`` as a tensor.

    :param df: DataFrame holding the feature columns.
    :param cols: Columns to extract; when ``None``, every column whose name starts with ``sc_``.
    :return: Tensor built by ``sc_to_tensor`` from the selected columns.
    """
    selected = cols
    if selected is None:
        selected = [name for name in df.columns if name[:3] == "sc_"]
    return sc_to_tensor(df=df, cols=selected)


def extract_mh(df: pd.DataFrame, cols: list[str] | None = None) -> torch.FloatTensor:
    """Extract specified mh columns."""
    cols = cols if cols is not None else [c for c in df.columns if c[:3] == "mh_"]
    assert len(cols) == 1, "Multiple multihotencoded extraction not supported yet"
    return torch.Tensor(np.vstack([x[0] for x in df[cols].values]))


def extract_any(df: pd.DataFrame, cols: list[str] | None = None) -> torch.FloatTensor:
    """Extract any combination of ts, sc and mh columns into a single tensor.

    Each column is routed by its prefix: ``ts_`` columns go to ``extract_ts``, ``sc_`` columns
    to ``extract_sc`` and ``mh_`` columns to ``extract_mh``.

    When ts columns are present, the sc and mh features are repeated along dim 1 (the time axis
    of the ts tensor), and everything is concatenated along dim 2. Otherwise, the sc and mh
    features are concatenated along dim 1 (or the single one requested is returned as is).

    :param df: DataFrame holding the feature columns.
    :param cols: Columns to extract; when ``None``, every ``ts_``/``sc_``/``mh_`` column of ``df``.
    :return: The combined feature tensor.
    :raises ValueError: if no column is selected, or if a column has none of the three prefixes.
    """
    prefixes = ("ts_", "sc_", "mh_")
    if cols is None:
        cols = [c for c in df.columns if isinstance(c, str) and c.startswith(prefixes)]
    # Fail loudly instead of silently dropping requested columns (which would change the
    # feature width) or returning None.
    unknown = [c for c in cols if not c.startswith(prefixes)]
    if unknown:
        raise ValueError(f"Cannot extract columns without a ts_/sc_/mh_ prefix: {unknown}")
    if not cols:
        raise ValueError("No ts_/sc_/mh_ columns to extract")

    # Prefix (not substring) matching, so each column is handed to exactly one extractor.
    tss_cols = [c for c in cols if c.startswith("ts_")]
    scs_cols = [c for c in cols if c.startswith("sc_")]
    mhs_cols = [c for c in cols if c.startswith("mh_")]
    tss = extract_ts(df, tss_cols) if tss_cols else None
    scs = extract_sc(df, scs_cols) if scs_cols else None
    mhs = extract_mh(df, mhs_cols) if mhs_cols else None

    statics = [t for t in (scs, mhs) if t is not None]

    # If tss is extracted then reshape everything as timeseries.
    if tss is not None:
        n_steps = tss.shape[1]
        # Repeat each per-sample feature over every time step: (n, f) -> (n, n_steps, f).
        repeated = [t.unsqueeze(1).repeat(1, n_steps, 1) for t in statics]
        return torch.concat((tss, *repeated), dim=2) if repeated else tss

    # Only sc and/or mh requested: combine them if both are present, else return the one requested.
    return torch.concat(statics, dim=1) if len(statics) > 1 else statics[0]


def get_extract_function(name: str) -> Callable[[pd.DataFrame], torch.FloatTensor]:
    """Return the extract function of this module called ``name``.

    Only the module-level ``extract_*`` callables (e.g. ``"extract_ts"``, ``"extract_any"``)
    can be looked up. Other module globals, such as imported modules or helpers, are rejected
    rather than returned as if they were extractors.

    :param name: Name of the extract function.
    :return: The matching function.
    :raises KeyError: if ``name`` is not an ``extract_*`` function of this module.
    """
    extractors = {
        key: value
        for key, value in globals().items()
        if key.startswith("extract_") and callable(value)
    }
    try:
        return extractors[name]
    except KeyError:
        raise KeyError(
            f"Unknown extract function {name!r}; available: {sorted(extractors)}"
        ) from None


if __name__ == "__main__":
    # Smoke test: load a small fraction of a dataset and print the shape of several extractions.
    from vito_crop_classification.data_io import load_datasets

    result = load_datasets(dataset="24ts-DEFAULT-nofilters", dataset_ratio=0.01)
    my_df = result["df_train"]

    print("\nExtracting biomes")
    print(extract_mh(my_df, cols=["mh_biome"]).shape)

    print("\nExtracting RGB")
    print(extract_ts(my_df, cols=["ts_R", "ts_G", "ts_B"]).shape)

    print("\nExtracting slope")
    print(extract_sc(my_df, cols=["sc_slope"]).shape)

    print("\nExtracting altitude and slope")
    print(extract_sc(my_df, cols=["sc_altitude", "sc_slope"]).shape)

    # The label must describe the columns actually requested (RGB + biome + slope).
    print("\nExtracting RGB, biomes, and slope")
    print(extract_any(my_df, cols=["ts_R", "ts_G", "ts_B", "mh_biome", "sc_slope"]).shape)

    print("\nExtracting RGB, and biomes")
    print(extract_any(my_df, cols=["ts_R", "ts_G", "ts_B", "mh_biome"]).shape)

    print("\nExtracting altitude, slope, and biomes")
    print(extract_any(my_df, cols=["sc_altitude", "sc_slope", "mh_biome"]).shape)
