# %%
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from cropclass import get_ct_classes, get_lc_classes
from cropclass.training import get_label
from cropclass import LUT_LC


# Apply seaborn's default plot styling to every figure drawn below.
sns.set()

# %%
# Location of the LC training-feature table on the shared storage.
training_parquet = ('/vitodata/CropSAR/cropmap/NEXTLAND/worldcereal/features/'
                    'features_catboost_5pix-EURUSDA/custom_01-01_10-31_CIB/'
                    'TEST/training_df_LC.parquet')
df = pd.read_parquet(training_parquet)

# %%
# Translate LC/CT to output values: get_label is evaluated once per row
# with the 'OUTPUT' target and the LUT_LC lookup table.
def _to_output_label(row):
    """Return the output label that ``get_label`` assigns to one sample row."""
    return get_label(row, 'OUTPUT', LUT_LC)


df['LABEL'] = df.apply(_to_output_label, axis=1)

# Keep only the samples with a known label (a label of 0 means "unknown").
is_known = df['LABEL'] != 0
df = df.loc[is_known].copy()

# %%
def _expand(source, index, stats, resolution):
    """Return ``<source>-<index>-<stat>-<resolution>`` names, one per stat.

    Names are emitted in the same order as *stats*.
    """
    return ['-'.join((source, index, stat, resolution)) for stat in stats]


# Statistics shared by most per-index feature groups below.
_PCTL_STATS = ('p10', 'p50', 'p90', 'iqr')

# Feature columns kept for the cropland analysis. The order is significant:
# it becomes the column order of any frame selected with this list.
BANDS_CROPLAND = (
    # L2A features at 10 m
    _expand('L2A', 'ndvi', _PCTL_STATS + ('std', 'skew', 'kurt'), '10m')
    + _expand('L2A', 'ndwi', _PCTL_STATS + ('std',), '10m')
    + _expand('L2A', 'rgbBR', _PCTL_STATS, '10m')
    + ['L2A-evi-nSeas-10m']
    # L2A features at 20 m
    + _expand('L2A', 'anir', _PCTL_STATS, '20m')
    + _expand('L2A', 'ndmi', _PCTL_STATS, '20m')
    + _expand('L2A', 'brightness', _PCTL_STATS + ('std',), '20m')
    + _expand('L2A', 'B11', _PCTL_STATS, '20m')
    + _expand('L2A', 'B12', _PCTL_STATS, '20m')
    + ['L2A-B11-std-20m', 'L2A-B12-std-20m']
    # SIGMA0 features at 20 m
    + _expand('SIGMA0', 'VV', _PCTL_STATS, '20m')
    + _expand('SIGMA0', 'VH', _PCTL_STATS, '20m')
    + ['SIGMA0-VV-std-20m', 'SIGMA0-VH-std-20m', 'SIGMA0-vh_vv-std-20m']
    # AgERA5 features at 1000 m
    + ['AgERA5-precipitation_flux-sum-1000m']
    + _expand('AgERA5', 'temperature_mean', ('p10', 'p50', 'p90'), '1000m')
    # DEM features at 20 m
    + ['DEM-alt-20m', 'DEM-slo-20m']
    # biome01 ... biome13
    + ['biome%02d' % number for number in range(1, 14)]
)

df = df[BANDS_CROPLAND + ['LABEL']]

# %%
# Reshape to long format -- one row per (sample, feature) pair -- so that
# FacetGrid can draw one panel per feature with LABEL on the x axis.
# Feature columns are selected by name rather than by position, so this
# does not depend on LABEL being the last column.
feature_cols = [col for col in df.columns if col != 'LABEL']
df_temp = pd.melt(df, id_vars='LABEL', value_vars=feature_cols,
                  var_name="Feature", value_name="Value")

# One shared x-axis category order for every panel. Without an explicit
# ``order``, each facet infers its categories from its own subset, which can
# silently misalign boxes/violins between panels (seaborn warns about this
# when a categorical plot function is passed to FacetGrid.map).
# NOTE(review): assumes LABEL values are mutually sortable (e.g. all ints).
label_order = sorted(df_temp['LABEL'].unique())

# ``height`` replaces the ``size`` argument, which seaborn deprecated in 0.9
# and no longer accepts. Each grid saves its own figure (FacetGrid.savefig
# uses a tight bounding box) and is closed before the next one is drawn.
for plot_func, out_file in ((sns.boxplot, 'temp1.png'),
                            (sns.violinplot, 'temp2.png')):
    g = sns.FacetGrid(data=df_temp, col="Feature",
                      col_wrap=4, height=4.5, sharey=False)
    g.map(plot_func, "LABEL", "Value", order=label_order)
    g.savefig(out_file)
    plt.close(g.fig)
