"""Patch specific functionality."""

from __future__ import annotations

import numpy as np
import pandas as pd
from numpy.typing import NDArray

# Divisors for the cells of a pixel's 3x3 neighbourhood, as used by ``smooth_patch``.
# Each entry is the Manhattan distance from the centre plus one. So after division the
# pixel itself is weighted 1, its edge neighbours 1/2 and its corner neighbours 1/3.
DISTANCE = np.array(
    [
        [3, 2, 3],
        [2, 1, 2],
        [3, 2, 3],
    ]
)


# TODO: Add masking!
def smooth_patch(
    df: pd.DataFrame,
    shape: tuple[int, int],
    class_ids: NDArray[np.str_],
    class_names: NDArray[np.str_],
    distance: NDArray[np.number] | None = None,
) -> pd.DataFrame:
    """
    Smooth per-pixel class probabilities with a distance-weighted sum over each neighbourhood.

    Each pixel's probability vector is replaced by the sum of the probability vectors in its
    neighbourhood. Each neighbour's vector is divided by the matching entry of ``distance``
    (by default ``DISTANCE``: centre 1, edge neighbours 2, corner neighbours 3). The sum is then
    renormalised to add up to one. Cells outside the patch contribute zero, so border pixels
    are smoothed over fewer neighbours.

    Parameters
    ----------
    df : pd.DataFrame
        Prediction dataframe with one row per pixel, in row-major order. Its ``probabilities``
        column holds one probability vector per row.
    shape : tuple[int, int]
        Shape of the corresponding patch; ``shape[0] * shape[1]`` must equal ``len(df)``.
    class_ids : NDArray[np.str_]
        Class IDs used by the dataframe (corresponds with probabilities)
    class_names : NDArray[np.str_]
        Class names used by the dataframe (corresponds with probabilities)
    distance : NDArray[np.number] | None, optional
        2-D array of divisors with odd side lengths, centred on the pixel being smoothed.
        Defaults to the module-level ``DISTANCE``.

    Returns
    -------
    df : pd.DataFrame
        The same dataframe, modified in place. The ``probability``, ``prediction_id`` and
        ``prediction_name`` columns are (over)written with the arg-max of the smoothed
        probabilities.

    Raises
    ------
    ValueError
        If ``distance`` is not 2-D with odd sides, if ``class_ids`` and ``class_names`` differ
        in length, if the row count does not match ``shape``, or if the probability vectors do
        not have one entry per class.
    """
    # Only look up the module constant when no kernel is supplied.
    kernel = np.asarray(DISTANCE if distance is None else distance)
    if kernel.ndim != 2 or kernel.shape[0] % 2 == 0 or kernel.shape[1] % 2 == 0:
        raise ValueError(f"distance must be a 2-D array with odd sides, got shape {kernel.shape}")
    if len(class_ids) != len(class_names):
        raise ValueError(
            f"class_ids ({len(class_ids)}) and class_names ({len(class_names)}) differ in length"
        )
    height, width = shape
    if len(df) != height * width:
        raise ValueError(
            f"dataframe has {len(df)} rows but patch shape {shape} needs {height * width}"
        )

    if len(df) == 0:
        # Nothing to smooth (np.stack would fail on zero rows); still add the output columns.
        df["probability"] = np.empty(0, dtype=np.float64)
        df["prediction_id"] = []
        df["prediction_name"] = []
        return df

    n_classes = len(class_ids)
    stacked = np.stack(list(df["probabilities"]))
    if stacked.ndim != 2 or stacked.shape[1] != n_classes:
        raise ValueError(
            f"expected one probability vector of length {n_classes} per row, "
            f"got array of shape {stacked.shape}"
        )

    # Rows are laid out row-major, matching the default C-order reshape.
    prob_map = stacked.reshape((height, width, n_classes))
    probabilities = _smooth_probabilities(prob_map, kernel).reshape((len(df), n_classes))

    # Only the arg-max summary is stored; the smoothed vectors themselves are not written
    # back to the ``probabilities`` column.
    best_indices = np.argmax(probabilities, axis=1)
    df["probability"] = probabilities[np.arange(len(probabilities)), best_indices]
    df["prediction_id"] = [class_ids[i] for i in best_indices]
    df["prediction_name"] = [class_names[i] for i in best_indices]
    return df


def _smooth_probabilities(
    prob_map: NDArray[np.number], kernel: NDArray[np.number]
) -> NDArray[np.float64]:
    """Return the kernel-weighted, renormalised neighbourhood sum of an (H, W, C) map."""
    height, width, n_classes = prob_map.shape
    pad_x, pad_y = kernel.shape[0] // 2, kernel.shape[1] // 2
    # Work in float64 so integer inputs (e.g. one-hot vectors) are not truncated.
    # Constant zero padding means cells outside the patch contribute nothing.
    padded = np.pad(
        prob_map.astype(np.float64), ((pad_x, pad_x), (pad_y, pad_y), (0, 0))
    )
    summed = np.zeros((height, width, n_classes), dtype=np.float64)
    # One shifted copy of the whole map per kernel cell, instead of a per-pixel Python loop.
    for (dx, dy), weight in np.ndenumerate(kernel):
        summed += padded[dx : dx + height, dy : dy + width] / weight
    # A neighbourhood summing to zero yields NaN here (0/0, with a RuntimeWarning).
    return summed / summed.sum(axis=-1, keepdims=True)
