"""Outliers detection utils."""

from __future__ import annotations

import numpy as np
import pandas as pd
from scipy.stats import pearsonr

from vito_crop_classification.vito_logger import bh_logger


def detect_and_prune_outliers(
    df: pd.DataFrame, outliers_perc: float, cols: list[str] | None = None
) -> pd.DataFrame:
    """Detect and prune outliers according to outliers score.

    Every row is scored against the median time series of its ``target_id``
    class; the rows with the lowest ``outlier_score`` (most outlying) are
    dropped.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe over which to detect outliers. It is not modified.
    outliers_perc : float
        Fraction of rows to prune, in [0, 1] (e.g. 0.05 prunes 5% of rows).
    cols : list[str] | None
        List of columns to use for outliers detection.
        If None (or empty), the default spectral-index columns are used,
        by default None

    Returns
    -------
    pd.DataFrame
        Pruned copy of ``df``, sorted by ascending ``outlier_score`` and
        extended with the metric columns (``mean_absolute_error``,
        ``mean_squared_error``, ``standard_deviation``,
        ``pearson_correlation``, ``outlier_score``).
        An empty ``df`` is returned as an empty copy.

    Raises
    ------
    ValueError
        If ``outliers_perc`` is not within [0, 1] (or is NaN).
    """
    # Without this check a negative fraction would turn the final ``iloc``
    # into a negative slice (keeping only the *last* rows), and a value > 1
    # would silently return an empty frame.
    if not 0.0 <= outliers_perc <= 1.0:
        raise ValueError(f"outliers_perc must be a fraction in [0, 1], got {outliers_perc!r}")

    df_ = df.copy()
    if len(df_) == 0:
        # Nothing to score; the metric helpers cannot stack an empty selection.
        return df_

    cols = cols or [
        "ts_ndvi",
        "ts_ndmi",
        "ts_ndwi",
        "ts_ndgi",
        "ts_ndti",
        "ts_anir",
        "ts_ndre1",
        "ts_ndre2",
        "ts_ndre5",
    ]

    bh_logger("   - Compute outlier metrics..")
    df_ = _compute_outliers_metrics(df_, cols)

    bh_logger("   - Sorting by outlier score..")
    # Stable sort: rows with tied scores keep their original relative order,
    # so which tied rows get pruned is predictable.
    df_ = df_.sort_values(["outlier_score"], kind="stable")

    bh_logger("   - Pruning outliers..")
    # Lowest scores are the most outlying samples: drop the leading fraction.
    n_pruned = int(len(df_) * outliers_perc)
    return df_.iloc[n_pruned:]


def _compute_outliers_metrics(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
    """Compute 4 different outliers metrics using the specified columns.

    Each row's time series (the ``cols`` cells flattened into one vector) is
    compared with the element-wise median vector of its ``target_id`` class.

    Metrics (one new column each):
    - mean_absolute_error: mean of |sample - class median|
    - mean_squared_error: mean of (sample - class median) ** 2
    - standard_deviation: std of |sample - class median|
    - pearson_correlation: Pearson's r between sample and class median

    The metrics are merged by rank into ``outlier_score``. Larger errors lower
    the score and stronger correlation raises it, so the lowest scores mark
    the most outlying rows.

    Parameters
    ----------
    df : pd.DataFrame
        Samples with a ``target_id`` column and the ``cols`` columns.
    cols : list[str]
        Time-series columns used to build each sample vector.

    Returns
    -------
    pd.DataFrame
        The same ``df`` object, modified in place with the five new columns.
    """
    bh_logger("     - Extract target medians..")
    target_medians = _extract_target_medians(df, cols)
    # Row-aligned matrix of class-median vectors (one per sample).
    target_tss = np.stack(df["target_id"].apply(lambda x: target_medians[x]))

    bh_logger("     - Extract samples..")
    sample_tss = _extract_tss(df, cols)

    # Optimize the metrics calculation by computing the subtraction once.
    subtraction = sample_tss - target_tss
    bh_logger("     - Compute mean absolute error..")
    df["mean_absolute_error"] = np.mean(np.abs(subtraction), axis=1)
    bh_logger("     - Compute mean squared error..")
    df["mean_squared_error"] = np.mean(subtraction**2, axis=1)
    bh_logger("     - Compute standard deviation..")
    # NOTE(review): this is the std of the *absolute* deviation, not of the raw
    # deviation; kept as-is — confirm it is intended.
    df["standard_deviation"] = np.std(np.abs(subtraction), axis=1)
    bh_logger("     - Compute pearson correlation..")
    df["pearson_correlation"] = [pearsonr(x, y)[0] for x, y in zip(sample_tss, target_tss)]

    bh_logger("     - Combining metrics..")
    # Pearson's r is undefined (NaN) when either series is constant. A NaN rank
    # would make the whole score NaN. NaN sorts last, so such rows could never
    # be pruned. Rank NaN correlations as the weakest (lowest) instead.
    df["outlier_score"] = (
        3 * len(df)
        - df["standard_deviation"].rank()
        - df["mean_absolute_error"].rank()
        - df["mean_squared_error"].rank()
        + df["pearson_correlation"].rank(na_option="top")
    )

    return df


def _extract_tss(df: pd.DataFrame, ts_cols: list[str]) -> list[np.ndarray[float]]:
    """Extract tss vectors from a dataset and a set of columns."""
    tss = np.stack(df[ts_cols].values)
    tss = [np.concatenate(np.stack(ts)) for ts in tss]
    return tss


def _extract_target_medians(df: pd.DataFrame, cols: list[str]) -> dict[str, np.ndarray[float]]:
    """Extract target medians for a set of targets and columns.

    For every distinct ``target_id`` in ``df``, the flattened time series of
    its rows (built by ``_extract_tss`` from ``cols``) are reduced to their
    element-wise median.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing target samples, with a ``target_id`` column
    cols : list[str]
        Columns over which compute median

    Returns
    -------
    dict[str, np.ndarray[float]]
        Mapping from each target_id present in df to its median vector, in
        order of first appearance. Rows with a missing target_id get no entry.
    """
    # A single groupby pass replaces one full boolean scan of ``df`` per target
    # (O(n) instead of O(n * n_targets)). ``sort=False`` keeps first-appearance
    # order like ``unique()``. ``observed=True`` skips unused categories of a
    # categorical ``target_id``, which would otherwise yield empty groups.
    grouped = df.groupby("target_id", sort=False, observed=True)
    return {
        target: np.median(_extract_tss(df_target, cols), axis=0)
        for target, df_target in grouped
    }
