"""Lookup-tables used by the processing functions."""

from __future__ import annotations

import re
from functools import lru_cache
from pathlib import Path
from typing import Callable

import numpy as np
import pandas as pd

from vito_crop_classification.data_format.utils import transform_continents, transform_countries

# Default lookup table: target label id -> regex patterns of the raw label codes it covers.
# get_lut_rev() flips it into {pattern: target id} and translate_label() walks those patterns in
# insertion order with re.search, returning the first hit, so the order of entries matters.
# Every key below has a readable name in CROP_LABEL_DESCR (repeated here as trailing comments).
LUT = {
    "1-1-1-0": ["^11-01-01-.*"],  # wheat
    "1-1-2-0": ["^11-01-02-.*"],  # barley
    "1-1-3-0": ["^11-01-06-.*"],  # maize
    "1-1-4-0": ["^11-01-08-.*"],  # rice
    "1-1-5-1": ["^11-01-03-.*"],  # rye
    "1-1-5-2": ["^11-01-05-.*"],  # triticale
    "1-1-5-3": ["^11-01-04-.*"],  # oats
    "1-2-1-0": ["^11-03-.*"],  # vegetables
    "1-2-2-0": ["^11-05-.*"],  # dry pulses
    "1-3-1-0": ["^11-07-00-001-.*"],  # potatoes
    "1-3-2-0": ["^11-07-00-003-.*"],  # sugar beet
    "1-4-1-0": ["^11-06-00-001-0"],  # sunflower
    "1-4-2-0": ["^11-06-00-002-0"],  # soybeans
    "1-4-3-0": ["^11-06-00-003-.*"],  # rapeseed
    "1-4-4-0": ["^11-08-00-.*", "^11-08-01-.*", "^11-09-00-053-.*"],  # flax, cotton and hemp
    "1-5-0-0": ["^11-11-.*", "^13-.*"],  # grass and fodder crops
    "2-0-1-0": ["^12-01-00-001-.*"],  # grapes
    "2-0-2-0": ["^12-03-00-001-.*"],  # olives
    # fruits
    "2-0-3-1": ["^12-01-01-.*", "^12-01-02-.*", "^12-01-03-.*", "^12-01-05-.*", "^12-01-06-.*"],
    "2-0-3-2": ["^12-01-04-.*"],  # nuts
}

# Alternative, finer-grained lookup table with the same layout as LUT
# (target label id -> regex patterns of raw label codes); select it with get_lut("LUT2").
# The "w"/"s" ids split one raw code family on its last digit (1 -> "w", 2 -> "s");
# presumably winter/spring variants -- TODO confirm.
# NOTE: several ids used here (e.g. "1-1-1-1s") have no entry in CROP_LABEL_DESCR.
# All patterns are anchored with "^" so they only match at the start of the raw code.
LUT2 = {
    "1-1-1-1s": ["^11-01-01-...-2"],
    "1-1-1-1w": ["^11-01-01-...-1"],
    "1-1-1-2w": ["^11-01-02-...-1"],
    "1-1-1-2s": ["^11-01-02-...-2"],
    "1-1-1-3w": ["^11-01-03-...-1"],
    "1-1-1-3s": ["^11-01-03-...-2"],
    "1-1-1-4w": ["^11-01-04-...-1"],
    "1-1-1-4s": ["^11-01-04-...-2"],
    "1-1-1-5w": ["^11-01-05-...-1"],
    "1-1-1-5s": ["^11-01-05-...-2"],
    "1-1-1-11": ["^11-01-07-001-0", "^11-01-07-002-0"],
    "1-1-1-12": ["^11-01-09-000-0"],
    "1-1-1-13": ["^11-01-10-000-0"],
    "1-1-2-1": ["^11-01-06-.*"],
    "1-1-2-2": ["^11-01-07-003-0", "^11-01-07-004-0"],
    "1-1-3-0": ["^11-01-08-.*"],
    "1-2-1-0": ["^11-03-.*"],
    "1-2-2-1": ["^11-05-01-.*"],
    "1-2-2-2": ["^11-05-00-.*"],
    "1-3-1-0": ["^11-07-00-001-.*"],
    "1-3-2-0": ["^11-07-00-003-.*"],
    "1-4-1-0": ["^11-06-00-001-0"],
    "1-4-2-0": ["^11-06-00-002-0"],
    "1-4-3-0": ["^11-06-00-003-.*"],
    "1-4-4-0": ["^11-08-00-.*", "^11-08-01-.*", "^11-09-00-053-.*"],
    "1-5-1-0": ["^11-11-00-001-0", "^11-11-01-.*", "^13-.*"],
    "1-5-2-0": ["^11-11-02-.*", "^11-11-04-.*"],
    "2-1-0-0": ["^12-01-00-001-.*"],
    "2-2-0-0": ["^12-03-00-001-.*"],
    "2-3-1-1": ["^12-01-01-.*"],
    "2-3-1-2": ["^12-01-03-.*"],
    "2-3-1-3": ["^12-01-05-.*"],
    "2-3-2-0": ["^12-01-04-.*"],
}

# Human-readable crop name for each target label id.
# Covers every key of LUT plus "1-6-0-0" ("other"), the id translate_label() falls back to when
# no pattern matches. Ids that only exist in LUT2 (e.g. "1-1-1-1s") have no entry here.
# process_openeo() uses this table to fill the "target_name" column.
CROP_LABEL_DESCR = {
    "1-6-0-0": "other",
    "1-1-1-0": "wheat",
    "1-1-2-0": "barley",
    "1-1-3-0": "maize",
    "1-1-4-0": "rice",
    "1-1-5-1": "rye",
    "1-1-5-2": "triticale",
    "1-1-5-3": "oats",
    "1-2-1-0": "vegetables",
    "1-2-2-0": "dry pulses",
    "1-3-1-0": "potatoes",
    "1-3-2-0": "sugar beet",
    "1-4-1-0": "sunflower",
    "1-4-2-0": "soybeans",
    "1-4-3-0": "rapeseed",
    "1-4-4-0": "flax, cotton and hemp",
    "1-5-0-0": "grass and fodder crops",
    "2-0-1-0": "grapes",
    "2-0-2-0": "olives",
    "2-0-3-1": "fruits",
    "2-0-3-2": "nuts",
}


@lru_cache
def get_lut(lut_name: str = "LUT") -> dict[str, list[str]]:
    """Get dictionary containing LUT for specified scenario.

    Parameters
    ----------
    lut_name : str, optional
        Name of a module-level lookup table (e.g. "LUT" or "LUT2"), by default "LUT"

    Returns
    -------
    dict[str, list[str]]
        Mapping of target label id to its list of regex patterns. This is the module-level
        object itself, not a copy, so callers should not mutate it.

    Raises
    ------
    KeyError
        If this module has no dictionary called `lut_name`
    """
    table = globals().get(lut_name)
    # The lookup goes through globals(), so guard against names that resolve to
    # something that is not a table at all (imported modules, functions, ...).
    if not isinstance(table, dict):
        raise KeyError(f"Unknown lookup table: {lut_name!r}")
    return table


@lru_cache
def get_lut_rev(lut_name: str = "LUT") -> dict[str, str]:
    """Build the pattern -> target label id view of a lookup table.

    Parameters
    ----------
    lut_name : str, optional
        Name of the lookup table handed to get_lut(), by default "LUT"

    Returns
    -------
    dict[str, str]
        One entry per regex pattern, in the order the patterns appear in the table.
        If a pattern is listed under several ids, the last id wins.
    """
    reverted: dict[str, str] = {}
    for label_id, patterns in get_lut(lut_name=lut_name).items():
        for pattern in patterns:
            reverted[pattern] = label_id
    return reverted


def translate_label(x: str, lut: dict[str, str], default: str = "1-6-0-0") -> str:
    """Search x in the look up table lut.

    Parameters
    ----------
    x : str
        Raw label code to translate
    lut : dict[str, str]
        Reverted lookup table mapping regex pattern to target label id
    default : str, optional
        Label id returned when no pattern matches, by default "1-6-0-0"

    Returns
    -------
    str
        Target label id of the first pattern (in dictionary order) found in `x` by re.search,
        or `default` if none matches
    """
    for pattern, label_id in lut.items():
        if re.search(pattern, x) is not None:
            return label_id
    return default


def process_openeo(
    df: pd.DataFrame,
    lut_name: str | None = None,
    col_output: str | None = None,
    tranform_f: Callable | None = None,
    continents: list[str] | None = None,
    countries: list[str] | None = None,
    filter_biomes: bool = True,
    min_biome_samples: int = 500,
) -> pd.DataFrame:
    """Process Openeo dataset to the standardised format.

    Steps, in order: re-index, extract/select countries, extract/select continents, apply the
    optional transform, translate raw labels via the LUT, filter rare biomes, add the
    `key_ts` / `key_field` columns and drop rows with a duplicated `key_ts`.

    Parameters
    ----------
    df : pd.DataFrame
        Raw dataset. Must contain a `sampleID` column, and a `target_name` column unless
        `lut_name` is given (it is created in that case).
        NOTE: the frame is modified in place by the first steps (new index, added columns,
        dropped rows); use the returned frame as the result.
    lut_name : str | None, optional
        Name of the LUT to use (see get_lut). If None, labels are not translated.
        By default None
    col_output : str | None, optional
        Name of the column holding the raw label codes, only used when `lut_name` is given.
        If None, "LABEL" is used. By default None
    tranform_f : Callable | None, optional
        Transform function to apply to input dataframe, function input is the dataframe only
        and its return value replaces the dataframe.
        Make use of partial() if other parameters are needed. By default None
        (The parameter name is spelled `tranform_f`; kept for backward compatibility.)
    continents : list[str], optional
        Continents to keep
        Options: 'europe', 'north_america', 'south_america', 'africa', 'asia', 'australia'
    countries : list[str], optional
        countries to keep
        Options: All country codes in the world, if 'Other' in countries, everything not
        specified is considered as 'Other'
    filter_biomes : bool, optional
        If True, drop the samples of every biome column that has some, but fewer than
        `min_biome_samples`, positive samples, then drop all biome columns that are left with
        fewer than `min_biome_samples` positive samples. By default True
    min_biome_samples : int, optional
        Threshold used by `filter_biomes`, by default 500

    Returns
    -------
    pd.DataFrame
        Formatted dataset
    """
    if lut_name is not None and col_output is None:
        # Without a column name the translation step would fail on df[None].
        col_output = "LABEL"

    print(" - Processing openeo dataset with configuration:")
    print(f"   - Original shape: {df.shape}")
    if lut_name is not None:
        print(f"   - lut: {lut_name}")
        print(f"   - col_output: {col_output}")

    print(" - Re-indexing dataframe..")
    df.index = list(range(len(df)))

    # select countries
    print(" - Extract countries..")
    # Called for its side effect; the "country" column is read right below.
    transform_countries(df)
    if countries is not None:
        print(f"   - Selecting countries {', '.join(countries)}..")
        if "Other" in countries:
            df["country"] = df["country"].apply(lambda x: x if x in countries else "Other")
        df.drop(df[~df["country"].isin(countries)].index, inplace=True)
    print(f"   - Done! {df.shape}")

    # select continents
    print(" - Extract continents..")
    # Called for its side effect; the "continent" column is read right below.
    transform_continents(df)
    if continents is not None:
        print(f"   - Selecting continents {', '.join(continents)}..")
        df.drop(df[~df["continent"].isin(continents)].index, inplace=True)
    print(f"   - Done! {df.shape}")

    if tranform_f is not None:
        print(" - Applying transform function..")
        df = tranform_f(df)

    if lut_name is not None:
        print(f" - Transform {col_output} column to correct format..")
        # Use the requested table (it used to be hard-coded to "LUT").
        lut = get_lut_rev(lut_name=lut_name)
        target_id = df[col_output].apply(lambda x: translate_label(x, lut))
        df["target_id"] = target_id
        # Not every table has a description for all of its ids; fall back to the id itself.
        df["target_name"] = [CROP_LABEL_DESCR.get(x, x) for x in target_id]
        print(f"   - Done! {df.shape}")

    if filter_biomes:
        print(" - Filter biomes which are not well represented..")
        biomes = np.array([c for c in df.columns if "biome" in c])
        counts = (df[biomes] > 0).sum().values
        n_total = len(df)
        rare_biomes = []
        for biome, count in zip(biomes, counts):
            if 0 < count < min_biome_samples:
                print(
                    f"   - Removing {biome}, too few samples: "
                    f"{count}/{n_total} ({count / n_total:.6%})"
                )
                rare_biomes.append(biome)
        if rare_biomes:
            # Single pass: drop every sample that belongs to at least one rare biome.
            df = df[~(df[rare_biomes] > 0).any(axis=1)]
        print(f"   - Done! {df.shape}")

        print("   - Removing biome columns..")
        # Counts are recomputed on the filtered rows; this also drops biomes without samples.
        mask = (df[biomes] > 0).sum().values < min_biome_samples
        df = df.drop(columns=biomes[mask])
        print(f"   - Done! {df.shape}")

    print(" - Generating unique keys per sample..")
    cols = [c for c in df.columns if "ts" in c]
    # key_ts: hash over the values of all columns with "ts" in their name.
    df["key_ts"] = [hash(tuple(vector)) for vector in df[cols].values]
    df["key_field"] = df["sampleID"] + df["target_name"]

    print("   - Dropping duplicates..")
    df = df.drop_duplicates(["key_ts"])
    print(f"   - Done! {df.shape}")

    return df


if __name__ == "__main__":
    # Ad-hoc run: read a raw parquet export from the home directory, process it while keeping
    # the listed countries (everything else is relabelled "Other"), and store the result.
    source_file = Path.home() / "Downloads/HRL_GEE-DEFAULT-nofilters/24raw.parquet"
    target_file = (
        Path.home() / "data/vito/crop_classification/data/24ts-DEFAULT-nofilters/df.parquet"
    )
    kept_countries = ["LV", "FR", "BE", "FI", "ES", "AT", "PT", "SE"]
    kept_countries += ["DK", "HR", "SK", "LT", "EE", "SI", "DE", "Other"]

    df = process_openeo(df=pd.read_parquet(source_file), countries=kept_countries)
    df.to_parquet(target_file)
