"""Group filter."""

from __future__ import annotations

from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray

from vito_crop_classification.filters.base import BaseFilter


class GroupFilter(BaseFilter):
    """
    Filter that masks class predictions based on the group of each sample.

    During training, the filter records for every known group which classes occur
    often enough within that group. During inference, predictions of classes that
    were not recorded as plausible for the sample's group are set to zero.
    """

    def __init__(
        self,
        classes: NDArray[np.str_],
        groups: NDArray[np.int64],
        filter_f: Path | None = None,
    ) -> None:
        """
        Initialize GroupFilter.

        Parameters
        ----------
        classes : NDArray[np.str_]
            Classes used for training
        groups : NDArray[np.int64]
            Groups used to filter over
        filter_f : Path | None
            Folder to store the filters in, get_filters_folder() by default
        """
        # Groups are kept as strings, since they double as the filter's column keys
        self._groups = [f"{g}" for g in groups]
        # Thresholds are only known after train(); kept at 0.0 until then.
        # NOTE(review): assigned before the base initializer runs - keep this order
        # unless BaseFilter.__init__ is confirmed not to rely on these attributes.
        self._min_freq_sample = 0.0
        self._min_freq_group = 0.0
        super().__init__(
            classes=classes,
            columns=list(self._groups),
            filter_f=filter_f,
        )

    def forward_process(self, data: pd.DataFrame, preds: torch.Tensor) -> torch.Tensor:
        """
        Apply filter over the predictions.

        Parameters
        ----------
        data : pd.DataFrame
            Samples to filter, must contain a "sc_group" column
        preds : torch.Tensor
            Predictions to filter, one row per sample in data; modified in place

        Returns
        -------
        torch.Tensor
            The same tensor as preds, with the filtered-out entries set to 0.0
        """
        # Without samples, the mask below would be an empty float array on which
        # the boolean inversion fails, so there is nothing to filter
        if len(data) == 0:
            return preds

        # One boolean row per sample, marking the entries that may be kept
        pred_filter = np.array(
            [self._filter[str(g)] != 0 for g in data["sc_group"]],
            dtype=bool,
        )
        preds[~pred_filter] = 0.0
        return preds

    def train(
        self,
        df: pd.DataFrame,
        min_freq_sample: float = 0.01,
        min_freq_group: float = 0.01,
    ) -> None:
        """
        Train the filter on the provided dataset.

        Groups that do not reach min_freq_group are skipped, which leaves their
        entry as initialised by create_new().

        Parameters
        ----------
        df : pd.DataFrame
            Dataset to extract patterns from, must contain the columns "sc_group" and "target_id"
        min_freq_sample : float
            Minimum sample frequency within a group, in order for a class to belong to this group
        min_freq_group : float
            Minimum group frequency before filtering on this group
        """
        self._min_freq_sample = min_freq_sample
        self._min_freq_group = min_freq_group
        self._filter = self.create_new()

        # Normalise the group labels to strings once, so that counting the groups and
        # selecting their rows always agree, whatever the dtype of the column is
        sample_groups = df["sc_group"].astype(str)
        group_freq = Counter(sample_groups)
        min_group_size = self._min_freq_group * len(df)
        for group in self._groups:
            # Ignore if the group is not frequent enough
            n_group = group_freq[group]
            if n_group < min_group_size:
                continue

            # Check if the sample is frequent enough in the group
            targets = df.loc[sample_groups == group, "target_id"]
            min_samples = self._min_freq_sample * n_group
            self._filter[group] = np.array(
                [(targets == cls).sum() >= min_samples for cls in self._classes],
                dtype=float,
            )

    def explain(self) -> None:
        """Explain filter verbally by printing its thresholds and the derived group filters."""
        print(f"Explanation -> {self}")
        print(" ---")
        print(" This filter checks in which groupID a certain crop can grow, and discards")
        print(" all crops that can not grow at the specified input groupID. A crop is considered")
        print(
            f" plausible growing in a groupID if it makes up at least {100*self._min_freq_sample:.2f}%"
        )
        print(" (min_freq_sample) of the training samples of that groupID. A groupID is only")
        print(
            f" filtered on if it has a presence of at least {100*self._min_freq_group:.2f}%"
        )
        print(" (min_freq_group) in the training set.")
        print(" ---")
        print(" The following groupID filters have been derived from the training set:")
        classes_joint = " | ".join([f"{x:^7s}" for x in self._classes])
        print(f"  - <groups> -> {classes_joint}")
        for g in self._groups:
            # Mark every class that is kept for this group with an X
            joint = " | ".join(["   X   " if x > 0 else "       " for x in self._filter[g]])
            print(f"  - {g:^7s} -> {joint}")


if __name__ == "__main__":
    from random import randint

    from vito_crop_classification.data_io import load_datasets

    # Demo configuration shared by every filter instance created below
    demo_classes = ["1", "2", "3"]
    demo_groups = [4, 5, 6]

    # Load the datasets and overwrite targets and groups with random demo values
    datasets = load_datasets(dataset="hybrid_dataset", dataset_ratio=0.01)
    train_df, val_df = datasets["df_train"], datasets["df_val"]
    n_train, n_val = len(train_df), len(val_df)
    train_df["target_id"] = [str(randint(1, 3)) for _ in range(n_train)]
    train_df["sc_group"] = [randint(4, 6) for _ in range(n_train)]
    val_df["sc_group"] = [randint(4, 6) for _ in range(n_val)]

    # Train a fresh filter on the randomised training set and save it
    print("\nCreating filter..")
    group_filter = GroupFilter(
        classes=np.asarray(demo_classes),
        groups=np.asarray(demo_groups),
    )
    group_filter.train(train_df, min_freq_group=0.01, min_freq_sample=0.01)
    group_filter.save()
    print(" - Filter:", group_filter)

    # Instantiate the filter again with the same configuration
    # NOTE(review): presumably this picks up the filter saved above - confirm in BaseFilter
    print("\nLoading in old filter..")
    group_filter = GroupFilter(
        classes=np.asarray(demo_classes),
        groups=np.asarray(demo_groups),
    )
    print(" - Filter:", group_filter)

    # Apply the filter to random predictions for the validation samples
    print("\nRunning the filter..")
    random_preds = torch.rand(n_val, 3)
    group_filter(val_df, random_preds)
    group_filter.explain()
