"""Postprocess the DataFrame to put it into a format the models understand."""

from __future__ import annotations

import numpy as np
import pandas as pd
from numpy.typing import NDArray

from vito_crop_classification.constants import DEM_TAG, METEO_TAG, S1_TAG, S2_TAG
from vito_crop_classification.vito_logger import LogLevel, bh_logger

# Ordered (output column name, source column tag) pairs consumed by `main`.
#
# The start of the output name selects how the source columns are handled:
#   - "ts_*": all columns starting with the tag are merged by `process_ts`
#   - "sc_*": all columns starting with the tag are selected by `process_sc`
#   - "mh_*": all columns starting with the tag are thresholded by `process_mh`
#   - "target*" and the remaining plain names: the column named exactly like the tag is copied
#
# NOTE(review): if S2_TAG denotes Sentinel-2, band B02 is blue and B04 is red, so the
# `ts_R` / `ts_B` names look swapped relative to their tags. Confirm whether this is
# intentional before changing it; downstream consumers may rely on the current names.
COL_PREFIX = [
    ("mh_biome", "biome"),
    ("sc_slope", f"{DEM_TAG}-slo"),
    ("sc_altitude", f"{DEM_TAG}-alt"),
    ("sc_lat", "lat"),
    ("sc_lon", "lon"),
    ("sc_location", "location_id"),
    ("sc_sample", "sampleID"),
    ("sc_group", "aez_groupid"),
    ("sc_zone", "aez_zoneid"),
    ("ts_R", f"{S2_TAG}-B02-ts"),
    ("ts_G", f"{S2_TAG}-B03-ts"),
    ("ts_B", f"{S2_TAG}-B04-ts"),
    ("ts_vegred1", f"{S2_TAG}-B05-ts"),
    ("ts_vegred2", f"{S2_TAG}-B06-ts"),
    ("ts_vegred3", f"{S2_TAG}-B07-ts"),
    ("ts_nir", f"{S2_TAG}-B08-ts"),
    ("ts_swir1", f"{S2_TAG}-B11-ts"),
    ("ts_swir2", f"{S2_TAG}-B12-ts"),
    ("ts_ndvi", f"{S2_TAG}-ndvi-ts"),
    ("ts_ndmi", f"{S2_TAG}-ndmi-ts"),
    ("ts_ndwi", f"{S2_TAG}-ndwi-ts"),
    ("ts_ndgi", f"{S2_TAG}-ndgi-ts"),
    ("ts_ndti", f"{S2_TAG}-ndti-ts"),
    ("ts_anir", f"{S2_TAG}-anir-ts"),
    ("ts_ndre1", f"{S2_TAG}-ndre1-ts"),
    ("ts_ndre2", f"{S2_TAG}-ndre2-ts"),
    ("ts_ndre5", f"{S2_TAG}-ndre5-ts"),
    ("ts_sar_vh", f"{S1_TAG}-VH-ts"),
    ("ts_sar_vv", f"{S1_TAG}-VV-ts"),
    ("ts_sar_ratio", f"{S1_TAG}-RATIO-ts"),
    ("ts_meteo", f"{METEO_TAG}-temperature_mean-ts"),
    ("year", "year"),
    ("continent", "continent"),
    ("country", "country"),
    ("ref_id", "ref_id"),
    ("target_id", "target_id"),
    ("target_name", "target_name"),
    ("key_ts", "key_ts"),
    ("key_field", "key_field"),
]


def main(df: pd.DataFrame) -> pd.DataFrame:  # noqa: C901
    """
    Merge the DataFrame columns (e.g. time series) into one array and rename columns.

    Parameters
    ----------
    df : pd.DataFrame
        Raw dataframe whose columns are named after the tags listed in COL_PREFIX

    Returns
    -------
    pd.DataFrame
        New dataframe sharing the index of `df`, holding one column for every COL_PREFIX
        entry that could be extracted; entries whose source columns are missing are
        logged and skipped

    Raises
    ------
    ValueError
        If a COL_PREFIX entry has an output name that is neither sc, ts, mh, target nor
        one of the known pass-through names
    """
    bh_logger(" - Reshaping features..")

    # Output names that are copied one-to-one from the column named exactly like their tag
    passthrough = {"year", "continent", "country", "ref_id", "key_ts", "key_field"}

    # Output-name kind -> extractor merging every column that starts with the given tag
    extractors = {
        "ts": lambda tag: process_ts(df=df, ts_prefix=tag),
        "sc": lambda tag: process_sc(df=df, sc_prefix=tag),
        "mh": lambda tag: process_mh(df=df, mh_prefix=tag, threshold=0.2),
    }

    # Warn about input columns that no COL_PREFIX entry picks up. The test mirrors the
    # extraction rules (prefix match for ts/sc/mh, exact name otherwise) so that a column
    # which merely *contains* a tag somewhere in its name is still reported.
    prefix_tags = tuple(tag for name, tag in COL_PREFIX if name[:2] in extractors)
    exact_tags = {tag for name, tag in COL_PREFIX if name[:2] not in extractors}
    for source in df.columns:
        picked_up = source in exact_tags or (
            isinstance(source, str) and source.startswith(prefix_tags)
        )
        if not picked_up:
            bh_logger(f"   - Feature '{source}' is not getting reshaped!", lvl=LogLevel.WARNING)

    # Collect the new columns
    new_df = pd.DataFrame([], index=df.index)
    for col, prefix in COL_PREFIX:
        kind = col[:2]
        try:
            if kind in extractors:
                new_df[col] = extractors[kind](prefix)
            elif col[:6] == "target" or col in passthrough:
                if prefix in df.columns:
                    new_df[col] = df[prefix]
                else:
                    bh_logger(f"   - Unable to extract '{col}', not found in original dataframe")
            else:
                raise ValueError(
                    f"\nColumn name {col} does not have a valid prefix. Should be either sc, ts, mh, or target."
                )
        except (KeyError, AssertionError):
            # The process_* helpers signal "no matching source column" with AssertionError
            bh_logger(
                f"   - Unable to extract '{col}', column '{prefix}' missing!",
                lvl=LogLevel.WARNING,
            )
    return new_df


def process_ts(df: pd.DataFrame, ts_prefix: str) -> list[NDArray[np.float64]]:
    """
    Process the requested time series.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe from which to extract the time series; each column starting with
        `ts_prefix` holds one time step and each row yields one time series
    ts_prefix : str
        Prefix of the columns in the dataframe that make up the time series

    Note
    ----
    The matching columns are ordered by (name length, name), so e.g. 'ts2' precedes 'ts10'.
    NaN values are filled by linear interpolation along the time axis of each row, i.e.
    only from that row's own neighbouring time steps. Leading and trailing gaps take the
    nearest valid value of the row; a row without any valid value stays NaN.

    Returns
    -------
    list[NDArray[np.float64]]
        One array per row of `df`, containing that row's interpolated time series

    Raises
    ------
    AssertionError
        If no column of `df` starts with `ts_prefix`
    """
    columns = sorted(
        (c for c in df.columns if isinstance(c, str) and c.startswith(ts_prefix)),
        key=lambda name: (len(name), name),
    )
    # Raised explicitly (instead of `assert`) so the check is not stripped under `python -O`
    if not columns:
        raise AssertionError(f"No columns found for time series '{ts_prefix}'!")
    # axis=1: the time steps are spread over the columns, so gaps must be filled within
    # each row rather than from the rows above/below (which belong to other series)
    df_interp = df[columns].interpolate(axis=1, limit_direction="both")
    return list(df_interp.to_numpy())


def process_sc(df: pd.DataFrame, sc_prefix: str) -> pd.DataFrame:
    """
    Process the requested scalar value.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe from which to extract the scalar value
    sc_prefix : str
        Prefix of the columns in the dataframe to look for

    Note
    ----
    Every column starting with `sc_prefix` is selected, so the result has more than one
    column when several names share the prefix (e.g. 'lat' and 'lat_extra').

    Returns
    -------
    pd.DataFrame
        Sub-frame of `df` with the matching columns, ordered by (name length, name)

    Raises
    ------
    AssertionError
        If no column of `df` starts with `sc_prefix`
    """
    columns = sorted(
        (c for c in df.columns if isinstance(c, str) and c.startswith(sc_prefix)),
        key=lambda name: (len(name), name),
    )
    # Raised explicitly (instead of `assert`) so the check is not stripped under `python -O`
    if not columns:
        raise AssertionError(f"No columns found for scalar values '{sc_prefix}'!")
    return df[columns]


def process_mh(
    df: pd.DataFrame, mh_prefix: str, threshold: float = 0.1
) -> list[NDArray[np.float64]]:
    """Process the requested columns as multihot encoding.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe from which to extract the raw values
    mh_prefix : str
        Prefix of the columns in the dataframe to look for
    threshold : float
        Values greater than or equal to this threshold become 1, all others become 0

    Returns
    -------
    list[NDArray[np.float64]]
        One array per row of `df`, holding the 0/1 encoding of the matching columns
        ordered by (name length, name)

    Raises
    ------
    AssertionError
        If no column of `df` starts with `mh_prefix`
    """
    columns = sorted(
        (c for c in df.columns if isinstance(c, str) and c.startswith(mh_prefix)),
        key=lambda name: (len(name), name),
    )
    # Raised explicitly (instead of `assert`) so the check is not stripped under `python -O`
    if not columns:
        raise AssertionError(f"No columns found for multi-hot encoding '{mh_prefix}'!")
    # Threshold the whole matrix at once, then split it into one array per row
    encoded = (df[columns].to_numpy() >= threshold).astype(float)
    return list(encoded)
