"""Biome filter."""

from __future__ import annotations

from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray

from vito_crop_classification.filters.base import BaseFilter


class BiomeFilter(BaseFilter):
    """
    Filter that zeroes predictions for classes not observed in a sample's biome(s).

    The filter table (``self._filter``) is indexed by biome name (one column per entry of
    ``biomes``); every column holds one 0/1 value per class. Samples are expected to carry a
    multi-hot biome encoding in the ``mh_biome`` column, where position ``i`` of the encoding
    refers to ``biomes[i]``.
    """

    def __init__(
        self,
        classes: NDArray[np.str_],
        biomes: NDArray[np.str_],
        filter_f: Path | None = None,
    ) -> None:
        """
        Initialize BiomeFilter.

        Parameters
        ----------
        classes : NDArray[np.str_]
            Classes used for training
        biomes : NDArray[np.str_]
            Biomes used to filter over, ordered like the positions of the ``mh_biome`` encoding
        filter_f : Path | None
            Folder to store the filters in, get_filters_folder() by default
        """
        # Attributes are assigned before the base initialiser runs, as in the original code.
        # NOTE(review): the thresholds are always reset to 0.0 here, so explain() reports 0.00%
        # until train() has been called on this instance -- confirm whether that is intended.
        self._biomes = biomes
        self._min_freq_sample = 0.0
        self._min_freq_biome = 0.0
        super().__init__(
            classes=classes,
            columns=list(self._biomes),
            filter_f=filter_f,
        )

    def forward_process(self, data: pd.DataFrame, preds: torch.Tensor) -> torch.Tensor:
        """
        Apply filter over the predictions.

        Parameters
        ----------
        data : pd.DataFrame
            Samples to filter, must contain the ``mh_biome`` column
        preds : torch.Tensor
            Predictions of shape (n_samples, n_classes), modified in place

        Returns
        -------
        torch.Tensor
            The same ``preds`` tensor, with every class that is not allowed in any of the
            sample's active biomes set to 0.0
        """
        # Nothing to filter; also avoids inverting an empty (float) array, which would raise
        if len(data) == 0:
            return preds

        # Samples sharing the same active biomes share the same mask, so compute it only once
        mask_cache: dict[tuple[int, ...], NDArray[np.bool_]] = {}
        allowed_rows = []
        for encoding in data["mh_biome"]:
            # A biome is active when its entry is positive, the same criterion train() uses
            active = tuple(np.flatnonzero(np.asarray(encoding) > 0).tolist())
            if active not in mask_cache:
                # Look the columns up by biome name, as train() and explain() do
                columns = [str(self._biomes[i]) for i in active]
                # A class is allowed if at least one of the active biomes allows it; a sample
                # without any active biome therefore has all its predictions zeroed
                mask_cache[active] = np.asarray(
                    self._filter[columns].sum(axis=1) > 0, dtype=bool
                )
            allowed_rows.append(mask_cache[active])

        blocked = ~np.array(allowed_rows, dtype=bool)
        preds[torch.as_tensor(blocked, device=preds.device)] = 0.0
        return preds

    def train(
        self,
        df: pd.DataFrame,
        min_freq_sample: float = 0.01,
        min_freq_biome: float = 0.01,
    ) -> None:
        """
        Train the filter on the provided dataset.

        Parameters
        ----------
        df : pd.DataFrame
            Dataset to extract patterns from, must contain the ``mh_biome`` and ``target_id``
            columns
        min_freq_sample : float
            Minimum sample frequency within a biome, in order for a class to belong to this biome
        min_freq_biome : float
            Minimum biome frequency before filtering on this biome
        """
        self._min_freq_sample = min_freq_sample
        self._min_freq_biome = min_freq_biome
        self._filter = self.create_new()

        encodings = list(df["mh_biome"])
        targets = df["target_id"]

        # Number of samples per biome, keyed by position in the encoding so that the count does
        # not depend on how the biomes are named
        biome_freq = Counter(i for enc in encodings for i, v in enumerate(enc) if v > 0)
        min_biome_count = self._min_freq_biome * len(df)

        for biome_i, biome in enumerate(self._biomes):
            # Ignore if the biome is not frequent enough; its column keeps the values
            # provided by create_new()
            n_biome = biome_freq[biome_i]
            if n_biome < min_biome_count:
                continue

            # Check for every class if it is frequent enough within the biome
            in_biome = np.fromiter(
                (enc[biome_i] > 0 for enc in encodings),
                dtype=bool,
                count=len(encodings),
            )
            biome_targets = targets[in_biome]
            min_class_count = self._min_freq_sample * n_biome
            self._filter[biome] = np.array(
                [(biome_targets == cls).sum() >= min_class_count for cls in self._classes],
                dtype=float,
            )

    def explain(self) -> None:
        """Explain filter verbally by printing its thresholds and the class-by-biome table."""
        print(f"Explanation -> {self}")
        print(" ---")
        print(" This filter checks in which biomes a certain crop can grow, and discards")
        print(" all crops that can not grow at the specified input biome. A crop is considered")
        print(
            " plausible growing in a biome if it accounts for at least "
            f"{100 * self._min_freq_sample:.2f}%"
        )
        print(" of the training samples of that biome. Biomes that make up less than")
        print(
            f" {100 * self._min_freq_biome:.2f}% of the training set "
            "are skipped during training."
        )
        print(" ---")
        print(" The following biome filters have been derived from the training set:")
        classes_joint = " | ".join([f"{x:^7s}" for x in self._classes])
        print(f"  - <biomes> -> {classes_joint}")
        for b in self._biomes:
            # An X marks a class that is allowed in this biome
            joint = " | ".join(["   X   " if x > 0 else "       " for x in self._filter[b]])
            print(f"  - {b:^7s} -> {joint}")


if __name__ == "__main__":
    # Manual smoke test: train a filter on a small dataset sample with random targets, save
    # it, create a second filter instance, run it on the validation set and print the
    # explanation.
    import random

    from vito_crop_classification.data import load_datasets

    class_labels = ("1", "2", "3")

    datasets = load_datasets(dataset="hybrid_dataset", dataset_ratio=0.01)
    train_df = datasets["df_train"]
    val_df = datasets["df_val"]

    # Assign a random class to every training sample
    random_targets = []
    for _ in range(len(train_df)):
        random_targets.append(str(random.randint(1, 3)))
    train_df["target_id"] = random_targets

    # One biome name per position of the multi-hot biome encoding
    first_encoding = train_df["mh_biome"].iloc[0]
    biome_names = np.asarray([f"biome_{idx:02d}" for idx in range(len(first_encoding))])

    # Create filter
    print("\nCreating filter..")
    biome_filter = BiomeFilter(classes=np.asarray(class_labels), biomes=biome_names)
    biome_filter.train(train_df, min_freq_biome=0.01, min_freq_sample=0.01)
    biome_filter.save()
    print(" - Filter:", biome_filter)

    # Load in a new filter
    print("\nLoading in old filter..")
    biome_filter = BiomeFilter(classes=np.asarray(class_labels), biomes=biome_names)
    print(" - Filter:", biome_filter)

    # Run inference on the filter
    print("\nRunning the filter..")
    random_preds = torch.rand(len(val_df), len(class_labels))
    biome_filter(val_df, random_preds)
    biome_filter.explain()
