"""Lookup-tables used by the processing functions."""

from __future__ import annotations

from functools import lru_cache
from pathlib import Path

import pandas as pd

from vito_crop_classification.data_format.utils import transform_continents, transform_countries

# Four label-grouping scenarios (SCENARIO_1 .. SCENARIO_4); `get_lut` selects one by its number:

# NOTE: all samples with LC == 13 should be added to crop type 1-5-0-0 !!

# Scenario 1: crop-type ID (key) -> raw label codes grouped under that ID (value).
# Rye and oats are separate cereal classes, the remaining "other cereals" codes are
# deliberately left unmapped, and fruits and nuts are separate classes.
SCENARIO_1 = {
    "1-1-1-0": [1100, 1110, 1120],  # wheat
    "1-1-2-0": [1500, 1510, 1520],  # barley
    "1-1-3-0": [1200],  # maize
    "1-1-4-0": [1300],  # rice
    "1-1-5-1": [1600, 1610, 1620],  # rye
    # "1-1-5-2": [],  # triticale (not present!)
    "1-1-5-3": [1700],  # oats
    # "1-1-5-4": [1400, 1800, 1900, 1910, 1920],  # other cereals (decided not to use anymore!)
    "1-2-1-0": [  # fresh vegetables
        2000,
        2100,
        2110,
        2120,
        2130,
        2140,
        2150,
        2160,
        2170,
        2190,
        2200,
        2210,
        2220,
        2230,
        2240,
        2250,
        2260,
        2290,
        2300,
        2310,
        2320,
        2330,
        2340,
        2350,
        2390,
        2400,
        2900,
    ],  # fresh vegetables
    "1-2-2-0": [7100, 7200, 7300, 7400, 7700, 7800, 7500, 7600],  # dry pulses
    "1-3-1-0": [5100],  # potatoes
    "1-3-2-0": [8100],  # sugar beet
    "1-4-1-0": [4380],  # sunflower
    "1-4-2-0": [4100],  # soybeans
    "1-4-3-0": [4350, 4351, 4352],  # rapeseed
    "1-4-4-0": [9211, 9213],  # flax, cotton and hemp
    "1-5-0-0": [7910, 7920, 9120, 9110, 9100],  # grass and fodder crops
    "2-0-1-0": [3300],  # grapes
    "2-0-2-0": [4420],  # olives
    "2-0-3-1": [  # fruits
        3900,
        3100,
        3110,
        3120,
        3130,
        3140,
        3150,
        3160,
        3170,
        3500,
        3510,
        3520,
        3530,
        3540,
        3550,
        3560,
        3590,
        3200,
        3210,
        3220,
        3230,
        3240,
        3290,
        3400,
        3410,
        3420,
        3430,
        3440,
        3450,
        3460,
        3490,
    ],  # fruits
    "2-0-3-2": [3600, 3610, 3620, 3630, 3640, 3650, 3660, 3690],  # nuts
}


# Scenario 2: crop-type ID (key) -> raw label codes grouped under that ID (value).
# Rye, oats and the remaining cereal codes are merged into one "other cereals" class
# ("1-1-5-0"); fruits and nuts are separate classes.
SCENARIO_2 = {
    "1-1-1-0": [1100, 1110, 1120],  # wheat
    "1-1-2-0": [1500, 1510, 1520],  # barley
    "1-1-3-0": [1200],  # maize
    "1-1-4-0": [1300],  # rice
    "1-1-5-0": [1400, 1800, 1900, 1910, 1920, 1600, 1610, 1620, 1700],  # other cereals
    "1-2-1-0": [  # vegetables
        2000,
        2100,
        2110,
        2120,
        2130,
        2140,
        2150,
        2160,
        2170,
        2190,
        2200,
        2210,
        2220,
        2230,
        2240,
        2250,
        2260,
        2290,
        2300,
        2310,
        2320,
        2330,
        2340,
        2350,
        2390,
        2400,
        2900,
    ],  # vegetables
    "1-2-2-0": [7100, 7200, 7300, 7400, 7700, 7800, 7500, 7600],  # dry pulses
    "1-3-1-0": [5100],  # potatoes
    "1-3-2-0": [8100],  # sugar beet
    "1-4-1-0": [4380],  # sunflower
    "1-4-2-0": [4100],  # soybeans
    "1-4-3-0": [4350, 4351, 4352],  # rapeseed
    "1-4-4-0": [9211, 9213],  # flax, cotton and hemp
    "1-5-0-0": [7910, 7920, 9120, 9110, 9100],  # grass and fodder crops
    "2-0-1-0": [3300],  # grapes
    "2-0-2-0": [4420],  # olives
    "2-0-3-1": [  # fruits
        3900,
        3100,
        3110,
        3120,
        3130,
        3140,
        3150,
        3160,
        3170,
        3500,
        3510,
        3520,
        3530,
        3540,
        3550,
        3560,
        3590,
        3200,
        3210,
        3220,
        3230,
        3240,
        3290,
        3400,
        3410,
        3420,
        3430,
        3440,
        3450,
        3460,
        3490,
    ],  # fruits
    "2-0-3-2": [3600, 3610, 3620, 3630, 3640, 3650, 3660, 3690],  # nuts
}


# Scenario 3: crop-type ID (key) -> raw label codes grouped under that ID (value).
# Rye, oats and the remaining cereal codes are merged into one "other cereals" class
# ("1-1-5-0"), and fruits and nuts are merged into a single class ("2-0-3-0").
SCENARIO_3 = {
    "1-1-1-0": [1100, 1110, 1120],  # wheat
    "1-1-2-0": [1500, 1510, 1520],  # barley
    "1-1-3-0": [1200],  # maize
    "1-1-4-0": [1300],  # rice
    "1-1-5-0": [1400, 1800, 1900, 1910, 1920, 1600, 1610, 1620, 1700],  # other cereals
    "1-2-1-0": [  # vegetables
        2000,
        2100,
        2110,
        2120,
        2130,
        2140,
        2150,
        2160,
        2170,
        2190,
        2200,
        2210,
        2220,
        2230,
        2240,
        2250,
        2260,
        2290,
        2300,
        2310,
        2320,
        2330,
        2340,
        2350,
        2390,
        2400,
        2900,
    ],  # vegetables
    "1-2-2-0": [7100, 7200, 7300, 7400, 7700, 7800, 7500, 7600],  # dry pulses
    "1-3-1-0": [5100],  # potatoes
    "1-3-2-0": [8100],  # sugar beet
    "1-4-1-0": [4380],  # sunflower
    "1-4-2-0": [4100],  # soybeans
    "1-4-3-0": [4350, 4351, 4352],  # rapeseed
    "1-4-4-0": [9211, 9213],  # flax, cotton and hemp
    "1-5-0-0": [7910, 7920, 9120, 9110, 9100],  # grass and fodder crops
    "2-0-1-0": [3300],  # grapes
    "2-0-2-0": [4420],  # olives
    "2-0-3-0": [  # fruits and nuts
        3900,
        3100,
        3110,
        3120,
        3130,
        3140,
        3150,
        3160,
        3170,
        3500,
        3510,
        3520,
        3530,
        3540,
        3550,
        3560,
        3590,
        3200,
        3210,
        3220,
        3230,
        3240,
        3290,
        3400,
        3410,
        3420,
        3430,
        3440,
        3450,
        3460,
        3490,
        3600,
        3610,
        3620,
        3630,
        3640,
        3650,
        3660,
        3690,
    ],  # fruits and nuts
}


# Scenario 4: crop-type ID (key) -> raw label codes grouped under that ID (value).
# Rye, oats and "other cereals" ("1-1-5-4") are separate cereal classes, while fruits
# and nuts are merged into a single class ("2-0-3-0").
SCENARIO_4 = {
    "1-1-1-0": [1100, 1110, 1120],  # wheat
    "1-1-2-0": [1500, 1510, 1520],  # barley
    "1-1-3-0": [1200],  # maize
    "1-1-4-0": [1300],  # rice
    "1-1-5-1": [1600, 1610, 1620],  # rye
    # "1-1-5-2": [],  # triticale (not present!)
    "1-1-5-3": [1700],  # oats
    "1-1-5-4": [1400, 1800, 1900, 1910, 1920],  # other cereals
    "1-2-1-0": [  # vegetables
        2000,
        2100,
        2110,
        2120,
        2130,
        2140,
        2150,
        2160,
        2170,
        2190,
        2200,
        2210,
        2220,
        2230,
        2240,
        2250,
        2260,
        2290,
        2300,
        2310,
        2320,
        2330,
        2340,
        2350,
        2390,
        2400,
        2900,
    ],  # vegetables
    "1-2-2-0": [7100, 7200, 7300, 7400, 7700, 7800, 7500, 7600],  # dry pulses
    "1-3-1-0": [5100],  # potatoes
    "1-3-2-0": [8100],  # sugar beet
    "1-4-1-0": [4380],  # sunflower
    "1-4-2-0": [4100],  # soybeans
    "1-4-3-0": [4350, 4351, 4352],  # rapeseed
    "1-4-4-0": [9211, 9213],  # flax, cotton and hemp
    "1-5-0-0": [7910, 7920, 9120, 9110, 9100],  # grass and fodder crops
    "2-0-1-0": [3300],  # grapes
    "2-0-2-0": [4420],  # olives
    "2-0-3-0": [  # fruits and nuts
        3900,
        3100,
        3110,
        3120,
        3130,
        3140,
        3150,
        3160,
        3170,
        3500,
        3510,
        3520,
        3530,
        3540,
        3550,
        3560,
        3590,
        3200,
        3210,
        3220,
        3230,
        3240,
        3290,
        3400,
        3410,
        3420,
        3430,
        3440,
        3450,
        3460,
        3490,
        3600,
        3610,
        3620,
        3630,
        3640,
        3650,
        3660,
        3690,
    ],  # fruits and nuts
}


# REMAINING QUESTION: how to deal with crops not included in the lists above?

# Option 1: we just ignore them during training

# Option 2: we create 2 "garbage" classes and add these to the training:
# "Other crops" variant 1 (selected with other_opt=1 in `get_lut`): codes not covered by
# the scenarios are split over two extra classes, "1-6-0-0" and "2-0-4-0".
# NOTE(review): code 6229 is listed under BOTH keys. `get_lut_rev` builds its mapping with
# a dict comprehension, so the later key ("2-0-4-0") wins - confirm which class is intended.
OTHER_CROPS_V1 = {
    "1-6-0-0": [
        5200,
        5300,
        5400,
        5900,
        6200,
        6211,
        6212,
        6219,
        6221,
        6222,
        6223,
        6224,
        6225,
        6226,
        6229,
        4360,
        4340,
        4330,
        4320,
        4310,
        4300,
        4200,
        7900,
        8200,
        8300,
        8900,
        9219,
        9310,
        9510,
        9600,
        9910,
    ],
    "2-0-4-0": [
        6000,
        6100,
        6110,
        6120,
        6130,
        6140,
        6190,
        9920,
        9520,
        9400,
        9320,
        9220,
        6229,
        4490,
        4430,
        4410,
        4400,
    ],
}

# Option 3: we just put them all together in one "garbage" class and add it to the training:
# "Other crops" variant 2 (selected with other_opt=2 in `get_lut`): the same codes as
# variant 1, all grouped under the single class "3-0-0-0".
# NOTE(review): code 6229 appears twice in this list; this does not affect the reversed
# lookup (same key both times) but the duplicate is probably unintended.
OTHER_CROPS_V2 = {
    "3-0-0-0": [
        5200,
        5300,
        5400,
        5900,
        6200,
        6211,
        6212,
        6219,
        6221,
        6222,
        6223,
        6224,
        6225,
        6226,
        6229,
        4360,
        4340,
        4330,
        4320,
        4310,
        4300,
        4200,
        7900,
        8200,
        8300,
        8900,
        9219,
        9310,
        9510,
        9600,
        9910,
        6000,
        6100,
        6110,
        6120,
        6130,
        6140,
        6190,
        9920,
        9520,
        9400,
        9320,
        9220,
        6229,
        4490,
        4430,
        4410,
        4400,
    ]
}

# Human-readable labels for all crop-type IDs, to be used for plotting results.
# `process_hybrid` also uses this table to fill its "target_name" column, so every ID a
# LUT can produce must have an entry here. "0-0-0-0" is the fallback ID that
# `_process_hybrid_row` returns for codes missing from the LUT.
# Several IDs intentionally or accidentally share a name ("other cereals", "dry pulses",
# "fruits"), so names are not unique identifiers.
CROP_TYPE_LABELS = {
    "0-0-0-0": "other",
    "1-1-0-9": "grouped cereals",
    "1-1-1-0": "wheat",
    "1-1-2-0": "barley",
    "1-1-3-0": "maize",
    "1-1-4-0": "rice",
    "1-1-5-0": "other cereals",
    "1-1-5-1": "rye",
    "1-1-5-2": "triticale",
    "1-1-5-3": "oats",
    "1-1-5-4": "other cereals",
    "1-2-0-9": "dry pulses",
    "1-2-1-0": "vegetables",
    "1-2-2-0": "dry pulses",
    "1-3-1-0": "potatoes",
    "1-3-2-0": "sugar beet",
    "1-4-1-0": "sunflower",
    "1-4-2-0": "soybeans",
    "1-4-3-0": "rapeseed",
    "1-4-4-0": "flax, cotton and hemp",
    "1-5-0-0": "grass and fodder crops",
    "2-0-0-0": "fruits",
    "2-0-1-0": "grapes",
    "2-0-2-0": "olives",
    "2-0-3-0": "fruits and nuts",
    "2-0-3-1": "fruits",
    "2-0-3-2": "nuts",
    "1-6-0-0": "other arable crops",
    "2-0-4-0": "other permanent crops",
    "3-0-0-0": "other crops",
}


@lru_cache
def get_lut(scenario: int = 1, other_opt: int = 0) -> dict[str, list[int]]:
    """
    Get dictionary containing LUT for specified scenario.

    Parameters
    ----------
    scenario : int
        Number of the scenario table to use (looked up as the module-level ``SCENARIO_<n>``)
    other_opt : int
        Which "other crops" table to append: 1 or 2 adds ``OTHER_CROPS_V<n>``, any other
        value adds nothing

    Returns
    -------
    dict[str, list[int]]
        Mapping from crop-type ID to the raw label codes grouped under it.
        The result is cached per argument combination, so treat it as read-only.

    Raises
    ------
    KeyError
        If no module-level ``SCENARIO_<scenario>`` table exists
    """
    scenario_table: dict[str, list[int]] = globals()[f"SCENARIO_{scenario}"]
    # Build a fresh dict (and fresh lists) instead of writing into the module-level
    # scenario table: adding the "other" classes to the global itself would leak them
    # into every later call for the same scenario, whatever its `other_opt`.
    res: dict[str, list[int]] = {label: list(codes) for label, codes in scenario_table.items()}
    if other_opt in {1, 2}:
        other_table: dict[str, list[int]] = globals()[f"OTHER_CROPS_V{other_opt}"]
        for label, codes in other_table.items():
            res[label] = list(codes)
    return res


@lru_cache
def get_lut_rev(scenario: int = 1, other_opt: int = 0) -> dict[int, str]:
    """
    Get dictionary containing reverted LUT for specified scenario.

    Every raw label code of the forward LUT becomes a key that points to its crop-type
    ID. When a code is listed under several IDs, the ID that comes last in the forward
    LUT is the one that is kept.
    """
    reverted: dict[int, str] = {}
    forward = get_lut(scenario=scenario, other_opt=other_opt)
    for crop_id, codes in forward.items():
        for code in codes:
            reverted[code] = crop_id
    return reverted


def process_hybrid(
    df: pd.DataFrame,
    scenario: int = 1,
    other_opt: int = 0,
    col_output: str = "LABEL",
    col_lc: str = "LC",
    continents: list[str] | None = None,
    countries: list[str] | None = None,
    ambiguous_codes: list[int] | None = None,
) -> pd.DataFrame:
    """
    Process the hybrid dataset to the standardised format.

    The steps are, in order: re-indexing, country selection, continent selection, removal
    of ambiguous label codes, removal of obsolete label columns, remapping of the raw
    labels to crop-type IDs and names, and de-duplication on the time-series columns.

    Note that ``df`` itself is modified in place (index replaced, rows dropped, columns
    added) by every step except the final de-duplication; only the returned frame has
    the duplicates removed.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame to process
    scenario : int
        Hybrid dataset scenario to transform labels to
    other_opt : int
        Hybrid dataset other option
    col_output : str
        Column name for the output
    col_lc : str
        Column name for the LC
    continents : list[str], optional
        Continents to keep
        Options: 'europe', 'north_america', 'south_america', 'africa', 'asia', 'australia'
    countries : list[str], optional
        countries to keep
        Options: All country codes in the world, if 'Other' in countries, everything not specified is considered as 'Other'
    ambiguous_codes : list[int], optional
        Raw label codes (values of ``col_output``) whose rows are removed before remapping.
        ``None`` (the default) selects ``[991, 9998, 1900, 1910, 1920]``; pass an empty
        list to keep all rows.

    Returns
    -------
    pd.DataFrame
        The processed frame with the added columns ``target_id``, ``target_name``,
        ``key_ts`` and ``key_field``, without rows that duplicate an earlier ``key_ts``.
    """
    # Resolve the default here rather than in the signature, so that no mutable list is
    # shared between calls and an explicit `None` no longer breaks `Series.isin` below.
    if ambiguous_codes is None:
        ambiguous_codes = [991, 9998, 1900, 1910, 1920]

    print(" - Processing hybrid dataset with configuration:")
    print(f"   - Scenario: {scenario}")
    print(f"   - Other option: {other_opt}")
    print(f"   - col_output: {col_output}")
    print(f"   - col_lc: {col_lc}")
    print(f"   - countries: {', '.join(countries or [])}")
    print(f"   - continents: {', '.join(continents or [])}")
    print(f"   - ambiguous_codes: {ambiguous_codes}")

    print(" - Re-indexing dataframe")
    df.index = list(range(len(df)))
    print(f"   - Done! {df.shape}")

    print(" - Extract countries..")
    # NOTE(review): the return value is ignored, so this helper presumably adds the
    # "country" column to `df` in place - confirm in data_format.utils.
    transform_countries(df)
    if countries is not None:
        print(f"   - Selecting countries {', '.join(countries)}..")
        if "Other" in countries:
            # Collapse every country that was not requested into the "Other" bucket, so
            # that those rows survive the selection below.
            df["country"] = df["country"].apply(lambda x: x if x in countries else "Other")
        df.drop(df[~df["country"].isin(countries)].index, inplace=True)
    print(f"   - Done! {df.shape}")

    # select continents
    print(" - Transforming zone-IDs to continents..")
    # NOTE(review): presumably adds the "continent" column in place, like above - confirm.
    transform_continents(df)
    if continents is not None:
        print(f"   - Selecting continents {', '.join(continents)}..")
        df.drop(df[~df["continent"].isin(continents)].index, inplace=True)
    print(f"   - Done! {df.shape}")

    # Remove ambiguous classes
    print(" - Removing ambiguous codes..")
    df.drop(df[df[col_output].isin(ambiguous_codes)].index, inplace=True)
    print(f"   - Done! {df.shape}")

    # Drop old/unnecessary labels
    print(" - Removing unneccesary labels..")
    df.drop(columns=["WORLDCOVER-LABEL-10m", "POTAPOV-LABEL-10m"], errors="ignore", inplace=True)
    print(f"   - Done! {df.shape}")

    # Transform targets column
    print(f" - Remapping output to scenario {scenario} with `other option` v{other_opt}..")
    table = get_lut_rev(scenario=scenario, other_opt=other_opt)
    if col_lc in df.columns:
        cols = [col_output, col_lc]
    else:
        # Without the LC column the remapping relies on the label code only
        cols = [col_output]
        print(
            f"   - Warning: col_lc='{col_lc}' not in columns! ({', '.join([c for c in df.columns if all(x not in c for x in ('-ts', '-TS'))])})"
        )
    targets = df[cols].apply(
        lambda x: _process_hybrid_row(x=x, table=table, col_output=col_output, col_lc=col_lc),
        axis=1,
    )
    df["target_id"] = targets
    df["target_name"] = [CROP_TYPE_LABELS[x] for x in targets]
    print(f"   - Done! {df.shape}")

    print(" - Generating unique keys per sample and fields..")
    # Every column whose name contains "ts" contributes to the per-sample key
    cols = [c for c in df.columns if "ts" in c]
    df["key_ts"] = [hash(tuple(vector)) for vector in df[cols].values]
    df["key_field"] = df[["sampleID", "target_name"]].apply(
        lambda x: x["sampleID"] + x["target_name"], axis=1
    )

    print("   - Dropping duplicates..")
    # Not in place: from here on `df` is a new frame, the caller's frame keeps its duplicates
    df = df.drop_duplicates(["key_ts"])
    print(f"   - Done! {df.shape}")

    return df


def _process_hybrid_row(
    x: pd.Series,
    table: dict[int, str],
    col_output: str = "OUTPUT",
    col_lc: str = "LC",
) -> str:
    """
    Process a hybrid column's row.

    Parameters
    ----------
    x : pd.Series
        Row holding the raw label code under ``col_output`` and, optionally, the LC value
        under ``col_lc``
    table : dict[int, str]
        Reverted LUT, mapping raw label codes to crop-type IDs
    col_output : str
        Name of the entry holding the raw label code
    col_lc : str
        Name of the entry holding the LC value

    Returns
    -------
    str
        The crop-type ID for the row, or "0-0-0-0" when the code is not in ``table``
    """
    # Multiprocessing slows down since function is so simple
    output = int(x[col_output])
    # Use the normalised integer code for the membership test as well as for the lookup,
    # so that a code stored as e.g. the string "1100" maps identically whether or not
    # the LC entry is present.
    if output not in table:
        return "0-0-0-0"
    # Code 9120 combined with LC 13 is always assigned to "1-5-0-0", whatever the table says
    if col_lc in x.index and output == 9120 and x[col_lc] == 13:
        return "1-5-0-0"
    return table[output]


if __name__ == "__main__":
    # Ad-hoc run: read the raw hybrid dataset from the user's home directory, process it
    # for a fixed selection of European countries and store the result next to the input.
    home = Path.home()
    raw_path = home / "data/vito/crop_classification/data/hybrid_dataset/raw.parquet"
    out_path = home / "data/vito/crop_classification/data/hybrid_dataset/df.parquet"

    processed = process_hybrid(
        df=pd.read_parquet(raw_path),
        continents=["europe"],
        countries=["LV", "FR", "BE", "GB", "FI", "ES", "AT", "Other"],
    )
    processed.to_parquet(out_path)
