"""Preprocess the DataFrame by ensuring all features are present and to scale their values."""

from __future__ import annotations

import numpy as np
import pandas as pd
from numpy.typing import NDArray

from vito_crop_classification.constants import BIOME_TAG, DEM_TAG, METEO_TAG, S1_TAG, S2_TAG
from vito_crop_classification.vito_logger import LogLevel, bh_logger

# Columns that are copied through the scaling step unchanged and without a
# warning: the target, location and year columns plus the two DEM_TAG columns.
SCALE_IGNORE = {
    *("target_id", "target_name", "continent", "lat", "lon", "year"),
    *(f"{DEM_TAG}-{feature}-20m" for feature in ("alt", "slo")),
}

# Normalised difference indices to derive, as (index prefix, first band prefix,
# second band prefix); each index is (first - second) / (first + second).
INDICES = [
    (f"{S2_TAG}-{index}-ts", f"{S2_TAG}-{first}-ts", f"{S2_TAG}-{second}-ts")
    for index, first, second in (
        ("ndvi", "B08", "B04"),
        ("ndmi", "B08", "B11"),
        ("ndwi", "B03", "B08"),
        ("ndgi", "B03", "B04"),
        ("ndti", "B11", "B12"),
        ("ndre1", "B08", "B05"),
        ("ndre2", "B08", "B06"),
        ("ndre5", "B07", "B05"),
    )
]


def main(
    df: pd.DataFrame,
    scale_cfg: dict[str, tuple[float, float]] | None = None,
    scale_dynamically: bool = True,
) -> tuple[pd.DataFrame, dict[str, tuple[float, float]]]:
    """
    Process the DataFrame by adding derived index columns and rescaling.

    The rows are renumbered 0..n-1 on a copy, so the DataFrame passed in by the
    caller keeps its own index.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe to process
    scale_cfg : dict[str, tuple[float, float]] | None
        Scaling configuration, if not provided the scaling is computed on the fly
    scale_dynamically : bool
        Update the scaling configuration dynamically for features that have no
        entry in it, raise an exception otherwise

    Returns
    -------
    df : pd.DataFrame
        Processed dataframe
    scale_cfg : dict[str, tuple[float, float]]
        Applied scaling configuration
    """
    # Renumber the rows on a new frame; assigning to ``df.index`` would silently
    # re-index the DataFrame owned by the caller.
    df = df.reset_index(drop=True)

    # add indices
    bh_logger(" - Adding new indices..")
    df = add_indices(df)
    bh_logger(f" - Done! Updated DataFrame shape: {df.shape}")

    # scale df
    bh_logger(" - Scaling features..")
    scale_cfg = scale_cfg or {}
    df, scale_cfg = scale_df(df, scale_cfg=scale_cfg, scale_dynamically=scale_dynamically)
    bh_logger(f" - Done! Updated DataFrame shape: {df.shape}")

    return df, scale_cfg


def add_indices(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add newly calculated indices.

    Columns that already hold one of the derived features (the indices listed in
    INDICES, 'anir' and the S1 'RATIO') are discarded and recomputed from the
    band columns. A feature that cannot be computed is skipped with a warning.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe holding the band time series columns; it is not modified

    Returns
    -------
    df : pd.DataFrame
        The remaining input columns followed by the derived feature columns
    """
    # Discard stale derived columns on a new frame (no inplace drop), so the
    # caller's DataFrame keeps all of its columns.
    targets = [x for x, _, _ in INDICES] + [f"{S2_TAG}-anir-ts", f"{S1_TAG}-RATIO-ts"]
    df = df.drop(columns=[c for c in df.columns if any(x in c for x in targets)])
    dataframes = [df]

    # Add the normalised difference indices
    for prefix, b1, b2 in INDICES:
        try:
            dataframes.append(_create_index(df, prefix=prefix, prefix_b1=b1, prefix_b2=b2))
        except (KeyError, AssertionError):
            bh_logger(
                f"   - Unable to generate '{prefix}', either '{b1}' or '{b2}' missing!",
                lvl=LogLevel.WARNING,
            )

    # Add the other indices
    try:
        dataframes.append(_compute_anir(df))
    except (KeyError, AssertionError):
        # Missing or mismatching bands are reported as an AssertionError, which
        # has to be caught here for the warning below to ever be reached.
        bh_logger(
            "   - Unable to generate 'anir', either 'S2-B04-ts' or 'S2-B08-ts' or 'S2-B11-ts' missing!",
            lvl=LogLevel.WARNING,
        )
    try:
        dataframes.append(_compute_sar_ratio(df))
    except Exception:
        # Deliberately broad: any failure only skips this optional feature.
        bh_logger(
            "   - Unable to generate 'sar_ratio', either 'S1-VV-ts' or 'S1-VH-ts' missing!",
            lvl=LogLevel.WARNING,
        )

    # Return the aggregated DataFrame
    return pd.concat(dataframes, axis=1)


def _create_index(df: pd.DataFrame, prefix: str, prefix_b1: str, prefix_b2: str) -> pd.DataFrame:
    """
    Create normalised difference indices using the provided bands.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe holding the band columns
    prefix : str
        Column name prefix of the index to create
    prefix_b1 : str
        Column name prefix of the first band
    prefix_b2 : str
        Column name prefix of the second band

    Returns
    -------
    df : pd.DataFrame
        One column per timestamp, named ``prefix`` + timestamp postfix, holding
        the normalised difference of the first and second band

    Raises
    ------
    AssertionError
        If the first band has no columns, or if the timestamp postfixes of both
        bands differ
    """

    def _band_columns(band_prefix: str) -> list[str]:
        # Shorter names sort first, so e.g. 'ts2' is placed before 'ts10'.
        return sorted(
            (c for c in df.columns if c.startswith(band_prefix)), key=lambda x: (len(x), x)
        )

    cols1 = _band_columns(prefix_b1)
    cols2 = _band_columns(prefix_b2)

    # The postfix is the text between the band prefix and the next '-'
    postfixes = [x[len(prefix_b1) :].split("-")[0] for x in cols1]

    # Raised explicitly (not via ``assert``) so the checks survive ``python -O``;
    # the exception type is kept because callers catch AssertionError.
    if not postfixes:
        raise AssertionError("No timestamps found!")
    if postfixes != [x[len(prefix_b2) :].split("-")[0] for x in cols2]:
        raise AssertionError(
            f"Postfixes for bands '{prefix_b1}' and '{prefix_b2}' don't match! ({cols1} vs {cols2})"
        )

    # Create the normalised difference index
    return pd.DataFrame(
        _ndi(a=df[cols1].to_numpy(), b=df[cols2].to_numpy()),
        columns=[f"{prefix}{postfix}" for postfix in postfixes],
        index=df.index,
    )


def _ndi(a: NDArray[np.float64], b: NDArray[np.float64]) -> NDArray[np.float64]:
    """
    Calculate the normalised difference index ``(a - b) / (a + b)`` element-wise.

    Parameters
    ----------
    a : NDArray[np.float64]
        Values of the first band
    b : NDArray[np.float64]
        Values of the second band

    Returns
    -------
    ndi : NDArray[np.float64]
        Normalised difference of both bands

    Note
    ----
     - Inputs that are not floating point (e.g. unsigned integers) are promoted
       to float64 first. Otherwise ``a - b`` can wrap around and the zero guard
       below would be truncated back to 0.
     - Where ``a + b`` is exactly 0 the denominator is replaced by 1e-12, so the
       index is 0 when both bands are 0.
    """
    a, b = np.asarray(a), np.asarray(b)

    # Keep the floating dtype the inputs would promote to, fall back to float64
    dtype = np.result_type(a.dtype, b.dtype)
    if not np.issubdtype(dtype, np.floating):
        dtype = np.dtype(np.float64)
    a = a.astype(dtype, copy=False)
    b = b.astype(dtype, copy=False)

    v_diff = a - b
    v_sum = a + b
    v_sum = np.where(v_sum == 0, 1e-12, v_sum)
    return v_diff / v_sum


def _compute_anir(df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute ANIR combination of S2 data.

    For every timestamp the bands B04, B08 and B11 are treated as three points
    (weight, band value). ANIR is the angle at the B08 point of the triangle
    they form, obtained with the law of cosines and divided by pi.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe holding the S2 B04, B08 and B11 time series columns

    Returns
    -------
    df : pd.DataFrame
        One ``anir`` column per timestamp, values in [0, 1]

    Raises
    ------
    AssertionError
        If B04 has no columns, or if the timestamp postfixes of the three bands
        differ
    """
    # Define the bands and weights to use
    # NOTE(review): the weights look like band centre wavelengths in micrometre - confirm
    prefix_b1, w_b1 = f"{S2_TAG}-B04-ts", 0.6646
    prefix_b2, w_b2 = f"{S2_TAG}-B08-ts", 0.8328
    prefix_b3, w_b3 = f"{S2_TAG}-B11-ts", 1.610

    def _band_columns(band_prefix: str) -> list[str]:
        # Shorter names sort first, so e.g. 'ts2' is placed before 'ts10'.
        return sorted(
            (c for c in df.columns if c.startswith(band_prefix)), key=lambda x: (len(x), x)
        )

    def _band_values(cols: list[str]) -> np.ndarray:
        # Integer band values would wrap around or overflow in the differences
        # and squares below, so anything non-floating is promoted to float64.
        values = df[cols].to_numpy()
        if not np.issubdtype(values.dtype, np.floating):
            values = values.astype(np.float64)
        return values

    # See if columns match
    cols1 = _band_columns(prefix_b1)
    cols2 = _band_columns(prefix_b2)
    cols3 = _band_columns(prefix_b3)
    postfixes = [x[len(prefix_b1) :].split("-")[0] for x in cols1]

    # Raised explicitly (not via ``assert``) so the checks survive ``python -O``
    if not postfixes:
        raise AssertionError("No timestamps found!")
    if postfixes != [x[len(prefix_b2) :].split("-")[0] for x in cols2]:
        raise AssertionError(
            f"Postfixes for bands '{prefix_b1}' and '{prefix_b2}' don't match! ({cols1} vs {cols2})"
        )
    if postfixes != [x[len(prefix_b3) :].split("-")[0] for x in cols3]:
        raise AssertionError(
            f"Postfixes for bands '{prefix_b1}' and '{prefix_b3}' don't match! ({cols1} vs {cols3})"
        )

    v1, v2, v3 = _band_values(cols1), _band_values(cols2), _band_values(cols3)

    # Side lengths of the triangle: a = B04-B08, b = B08-B11, c = B04-B11
    a = np.sqrt(np.square(w_b2 - w_b1) + np.square(v2 - v1))
    b = np.sqrt(np.square(w_b3 - w_b2) + np.square(v3 - v2))
    c = np.sqrt(np.square(w_b3 - w_b1) + np.square(v3 - v1))

    # calculate angle with NIR as reference (ANIR); the cosine is clipped to
    # [-1, 1] to absorb rounding errors before taking the arccos
    cos_angle = np.clip((np.square(a) + np.square(b) - np.square(c)) / (2 * a * b), -1, 1)
    return pd.DataFrame(
        1.0 / np.pi * np.arccos(cos_angle),
        columns=[f"{S2_TAG}-anir-ts{postfix}" for postfix in postfixes],
        index=df.index,
    )


def _compute_sar_ratio(df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute S1 RATIO combination of S1 data.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe holding the S1 VV and VH time series columns

    Returns
    -------
    df : pd.DataFrame
        One ``RATIO`` column per timestamp, holding VV divided by VH

    Raises
    ------
    AssertionError
        If VV has no columns, or if the timestamp postfixes of VV and VH differ
    """
    prefix_vv = f"{S1_TAG}-VV-ts"
    prefix_vh = f"{S1_TAG}-VH-ts"

    def _band_columns(band_prefix: str) -> list[str]:
        # Shorter names sort first, so e.g. 'ts2' is placed before 'ts10'.
        return sorted(
            (c for c in df.columns if c.startswith(band_prefix)), key=lambda x: (len(x), x)
        )

    cols_vv = _band_columns(prefix_vv)
    cols_vh = _band_columns(prefix_vh)
    postfixes = [x[len(prefix_vv) :].split("-")[0] for x in cols_vv]

    # Raised explicitly (not via ``assert``) so the checks survive ``python -O``
    if not postfixes:
        raise AssertionError("No timestamps found!")
    if postfixes != [x[len(prefix_vh) :].split("-")[0] for x in cols_vh]:
        raise AssertionError(
            f"Postfixes for bands '{prefix_vv}' and '{prefix_vh}' don't match! ({cols_vv} vs {cols_vh})"
        )

    # NOTE(review): a VH value of 0 yields inf/NaN here - confirm this is acceptable
    sar_ratio = df[cols_vv].to_numpy() / df[cols_vh].to_numpy()
    return pd.DataFrame(
        sar_ratio,
        columns=[f"{S1_TAG}-RATIO-ts{postfix}" for postfix in postfixes],
        index=df.index,
    )


def scale_df(
    df: pd.DataFrame,
    scale_cfg: dict[str, tuple[float, float]],
    scale_dynamically: bool = True,
) -> tuple[pd.DataFrame, dict[str, tuple[float, float]]]:
    """
    Scale the time series columns in the dataset.

    Columns are grouped by the part of their name before '-ts'. Groups starting
    with the S1, S2 or METEO tag are rescaled with their scaling boundaries,
    biome columns are divided by 100, and all other columns are copied as is
    (with a warning when they are not listed in SCALE_IGNORE).

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe to scale
    scale_cfg : dict[str, tuple[float, float]]
        Predefined scaling boundaries, extended in place with dynamically
        computed boundaries
    scale_dynamically : bool
        Update the scaling configuration dynamically, raise an exception otherwise

    Returns
    -------
    df : pd.DataFrame
        Scaled dataframe, groups ordered by their first appearance in the input
    scale_cfg : dict[str, tuple[float, float]]
        Applied scaling boundaries

    Raises
    ------
    Exception
        If a group has no scaling boundaries and scale_dynamically is False
    """
    results = []
    biome_scaled = False

    # dict.fromkeys removes duplicates like a set does, but keeps the order of
    # first appearance, so the output column order is the same on every run.
    for col in dict.fromkeys(x.split("-ts")[0] for x in df.columns):
        if col.startswith((S1_TAG, S2_TAG, METEO_TAG)):
            # Compute the scaling range if not provided
            if col not in scale_cfg:
                if not scale_dynamically:
                    raise Exception(f"No scale range provided for column '{col}'!")
                scale_cfg[col] = _compute_range(df=df, key=col)

            # Rescale the column
            new_df = _rescale_ts(
                df=df,
                key=col,
                scale_range=scale_cfg[col],
            )
        elif BIOME_TAG in col:
            # _rescale_biome returns every biome column at once; adding it for
            # each biome key would duplicate those columns in the result.
            if biome_scaled:
                continue
            new_df = _rescale_biome(df=df)
            biome_scaled = True
        elif col in SCALE_IGNORE:
            new_df = df[[col]]
        else:
            bh_logger(f"   - Not scaling column '{col}'", lvl=LogLevel.WARNING)
            new_df = df[[col]]
        results.append(new_df)
    return pd.concat(results, axis=1), scale_cfg


def _compute_range(
    df: pd.DataFrame,
    key: str,
    min_quantile: float = 0.01,
    max_quantile: float = 0.99,
) -> tuple[float, float]:
    """
    Derive the scaling boundaries for the provided key from the data.

    The values of all columns whose name contains ``key`` are pooled; the
    boundaries are the ``min_quantile`` and ``max_quantile`` quantiles of that
    pool, ignoring NaNs. The chosen boundaries are logged as a warning.
    """
    pooled = df[[name for name in df.columns if key in name]].to_numpy()

    # Lower and upper boundary, in that order
    min_v, max_v = (np.nanquantile(pooled, q) for q in (min_quantile, max_quantile))

    bh_logger(
        f"   - No scale range defined for '{key}', setting to bounds [{min_v:.5f}, {max_v:.5f}]",
        lvl=LogLevel.WARNING,
    )
    return min_v, max_v


def _rescale_ts(
    df: pd.DataFrame,
    key: str,
    scale_range: tuple[float, float],
    clip_values: bool = False,
) -> pd.DataFrame:
    """
    Rescale the columns of a dataframe whose name contains the key.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe to scale
    key : str
        Key to select the columns to scale
    scale_range : tuple[float, float]
        Predefined scaling boundaries (min, max)
    clip_values : bool, optional
        Clip the values after rescaling, off by default

    Returns
    -------
    df : pd.DataFrame
        Dataframe containing only the scaled columns

    Note
    ----
     - Scaling maps the range [min, max] onto [-1, 1].
     - Values outside [min, max] end up outside [-1, 1] if clip_values=False.
     - A degenerate range (min == max) is treated as a span of 1, so values
       equal to the boundary map to -1 instead of NaN.
    """
    min_v, max_v = scale_range

    # Guard against a division by zero for constant features
    span = max_v - min_v
    if span == 0:
        span = 1.0

    # Get the proper columns
    cols = [c for c in df.columns if key in c]
    x = df[cols].to_numpy()

    # Rescale the columns between [-1, 1]
    x = 2 * ((x - min_v) / span) - 1

    # Clip the values if requested
    if clip_values:
        x = np.clip(x, -1, 1)

    return pd.DataFrame(x, columns=cols, index=df.index)


def _rescale_biome(df: pd.DataFrame) -> pd.DataFrame:
    """
    Rescale the biome columns by dividing their values by 100.

    Selects every column whose name contains BIOME_TAG and returns only those
    columns. The result lies in the [0, 1] range provided the inputs are
    percentages (assumed, not checked here).
    """
    biome_cols = [name for name in df.columns if BIOME_TAG in name]
    fractions = df[biome_cols].to_numpy() / 100
    return pd.DataFrame(fractions, columns=biome_cols, index=df.index)
