"""
Script to prepare a CSV/GeoJSON that defines which output tiles need to be generated.

TODO: the basic building blocks are here and working, but they still need to be turned into a more convenient script

"""

import json
import os
from pathlib import Path

import geopandas as gpd
import pandas as pd

import openeo
from openeo.processes import array_create, mean, count


# Iceland is also included (this will not have a large impact once the BVL is applied)

# Two-letter NUTS_ID codes used to select countries from the NUTS level-0 layer.
# The 28 entries follow the same order as the country names in EU27; the last
# one ("IS") is Iceland.
EU27_NUTS_ID = (
    "FR SE PL AT HU RO LT LV EE DE BG EL HR LU "
    "BE NL PT ES IE IT DK SI FI SK CZ MT CY IS"
).split()

# Country names, in the same order as the codes in EU27_NUTS_ID.
# Despite the name, the list has 28 entries: Iceland is appended at the end.
EU27 = [
    "France", "Sweden", "Poland", "Austria", "Hungary", "Romania", "Lithuania",
    "Latvia", "Estonia", "Germany", "Bulgaria", "Greece", "Croatia", "Luxembourg",
    "Belgium", "Netherlands", "Portugal", "Spain", "Ireland", "Italy", "Denmark",
    "Slovenia", "Finland", "Slovakia", "Czechia", "Malta", "Cyprus",
    "Iceland",
]
# Years to generate tiles for: 2017 up to and including 2021 (the range end is exclusive).
years = range(2017, 2022)

def LAEA_20km() -> gpd.GeoDataFrame:
    """
    Load the LAEA 20km grid tiles for the countries listed in EU27_NUTS_ID.

    The NUTS level-0 shapefile is filtered on its NUTS_ID column, and the remaining
    country rows are handed to geopandas as `mask` when the grid GeoPackage is read
    from artifactory.

    Returns: the GeoDataFrame read from the grid file
    """
    nuts_path = "/vitodata/EEA_HRL_VLCC/data/ref/AOI/NUTS/NUTS_RG_01M_2021_3035_LEVL_0.shp"
    grid_url = "https://artifactory.vgt.vito.be/auxdata-public/grids/LAEA-20km.gpkg"

    nuts = gpd.read_file(nuts_path)
    in_scope = nuts.NUTS_ID.isin(EU27_NUTS_ID)
    return gpd.read_file(grid_url, mask=nuts[in_scope])

def UTM_100km_EU27() -> gpd.GeoDataFrame:
    """
    Load the UTM 100km tiling grid for the countries listed in EU27_NUTS_ID.

    The NUTS level-0 shapefile is filtered on its NUTS_ID column, and the remaining
    country rows are handed to geopandas as `mask` when the grid GeoPackage is read
    from artifactory. The result is then cut to a rough European bounding box.

    Returns: the GeoDataFrame of the remaining grid tiles
    """
    nuts_path = "/vitodata/EEA_HRL_VLCC/data/ref/AOI/NUTS/NUTS_RG_01M_2021_3035_LEVL_0.shp"
    grid_url = "https://artifactory.vgt.vito.be/auxdata-public/grids/utm-tiling-grid-100km.gpkg"

    nuts = gpd.read_file(nuts_path)
    in_scope = nuts.NUTS_ID.isin(EU27_NUTS_ID)
    tiles = gpd.read_file(grid_url, mask=nuts[in_scope])

    # Coordinate-based selection (x: -14..35, y: 33..72) that drops overseas areas.
    # NOTE(review): the bounds look like lon/lat degrees -- confirm the grid CRS.
    return tiles.cx[-14:35, 33:72]

def grid_statistics_mean(grid_df):
    """
    Add a WorldCover cropland percentage column to grid_df and write it to disk.

    The statistics are cached in "cropland_mean_laea_2021.json" in the working directory.
    When that file is missing, an openEO batch job computes, per geometry of the
    LAEA-20km-EU27+IS GeoJSON, the mean and the count of the boolean mask `MAP == 40`
    of the ESA WorldCover 2021 collection. The openEO connection (and OIDC
    authentication) is only set up in that case, so a cached run needs no login.

    The first value per geometry (the mean of the mask, i.e. a fraction) is multiplied
    by 100 and stored in the new 'cropland_perc' column; geometries without values get 0.0.

    NOTE(review): the statistics are computed for the geometries of the fixed GeoJSON URL,
    not for grid_df itself, so grid_df is assumed to hold the same tiles in the same
    order -- confirm with the caller.

    Side effects: modifies grid_df in place and writes
    "jobsplit_laea20km_base_cropland.geojson" and "jobsplit_laea20km_base_all.csv".

    Returns: grid_df
    """
    statsfile = "cropland_mean_laea_2021.json"

    if not Path(statsfile).exists():
        # Connect lazily: authentication is pointless when the cached result is reused.
        c = openeo.connect("openeo.vito.be").authenticate_oidc()
        wc = c.load_collection("ESA_WORLDCOVER_10M_2021_V2", bands="MAP",
                               temporal_extent=["2020-12-30", "2022-01-01"])

        cropland_mask = wc.band("MAP") == 40
        cropland_stats = cropland_mask.aggregate_spatial(
            "https://artifactory.vgt.vito.be/auxdata-public/grids/LAEA-20km-EU27+IS.geojson",
            reducer=lambda x: array_create(mean(x), count(x)),
        )
        cropland_stats.execute_batch(
            statsfile,
            title="Worldcover cropland stats LAEA",
            job_options={"executor-memory": "4G", "executor-memoryOverhead": "2G"},
        )

    with open(statsfile, 'r') as f:
        # The result is keyed by timestamp; each entry is the [mean, count] list of one geometry.
        per_geometry = json.load(f)["2021-12-31T00:00:00Z"]
    mean_list = [bands[0] if len(bands) > 0 else 0.0 for bands in per_geometry]

    # Fraction of cropland pixels -> percentage.
    grid_df["cropland_perc"] = mean_list
    grid_df["cropland_perc"] = 100.0 * grid_df["cropland_perc"]

    grid_df.to_file("jobsplit_laea20km_base_cropland.geojson", driver='GeoJSON')
    grid_df.to_csv("jobsplit_laea20km_base_all.csv", index=False)
    return grid_df


def count_products(year=None):
    """
    Count the Sentinel-1 and Sentinel-2 catalogue products per LAEA 20km tile for one year.

    Progress is kept in "jobsplit_laea20km_<year>_all.geojson": when that file exists it
    is loaded and only tiles whose 'sentinel1count' is still <= 0 are queried; otherwise
    a fresh grid is loaded with LAEA_20km() and initialised with placeholder values.

    For every queried tile the Sentinel-1 count is stored in 'sentinel1count'. Only when
    that count exceeds 80 is the Sentinel-2 count queried and stored in 'sentinel2count'.
    'provider' becomes "terrascope" when both counts exceed their threshold (80 for
    Sentinel-1, 100 for Sentinel-2) and "not_terrascope" otherwise.

    Args:
        year: year to count products for. When omitted, a module-level `year` variable is
            used if one is defined (the previous version of this function relied on it).

    Raises:
        ValueError: if no year is given and no module-level `year` is defined.

    Side effects: prints progress and rewrites the CSV and GeoJSON output after every
    queried tile.
    """
    if year is None:
        # Backward compatibility with the former implicit global; fail clearly otherwise
        # instead of raising a NameError halfway through.
        year = globals().get("year")
        if year is None:
            raise ValueError("count_products() needs a year, e.g. count_products(year=2021)")

    min_s1_products = 80
    min_s2_products = 100

    output_json = f"jobsplit_laea20km_{year}_all.geojson"
    output_csv = f"jobsplit_laea20km_{year}_all.csv"

    if Path(output_json).exists():
        grid = gpd.read_file(output_json)
    else:
        grid = LAEA_20km()
        grid["provider"] = "notset"
        grid["sentinel1count"] = -1
        grid["sentinel2count"] = -1

    # Local import: the catalogue client is only needed by this function.
    from terracatalogueclient import Catalogue
    catalogue = Catalogue()

    start = str(year) + "-01-01"
    end = str(year) + "-12-31"

    for i in grid.index:
        row = grid.loc[i]
        box = row.geometry.bounds

        print(row)
        print(box)
        if row['sentinel1count'] > 0:
            # Already counted in an earlier run.
            continue

        s1_count = catalogue.get_product_count(
            "urn:eop:VITO:CGS_S1_GRD_SIGMA0_L1",
            start=start,
            end=end,
            bbox=list(box),
        )
        print(s1_count)
        grid.loc[i, "sentinel1count"] = s1_count
        grid.loc[i, "provider"] = "not_terrascope"

        if s1_count > min_s1_products:
            s2_count = catalogue.get_product_count(
                "urn:eop:VITO:TERRASCOPE_S2_TOC_V2",
                start=start,
                end=end,
                bbox=list(box),
            )
            grid.loc[i, "sentinel2count"] = s2_count
            if s2_count > min_s2_products:
                grid.loc[i, "provider"] = "terrascope"

        # Save after every tile so an interrupted run can be resumed.
        grid.to_csv(output_csv, index=False)
        grid.to_file(output_json, driver='GeoJSON')

def assign_mgrs_tiles(df):
    """
    Add MGRS tile ids, purely for information.

    The UTM 100km grid from UTM_100km_EU27() is spatially joined (left join) onto df.
    For every value of df's 'name' column the 'tile_id' values of the joined UTM tiles
    are collected into a list, which is stored in the new 'mgrs_ids' column.

    Side effects: df is modified in place and its index is replaced by its 'name'
    column, which is what aligns the per-name lists with the rows.

    Returns: the updated df
    """
    utm_grid = UTM_100km_EU27()
    # Remove the literal "_0_0" suffix. str.rstrip("_0_0") would instead strip any
    # trailing run of '_' and '0' characters, which only works by accident.
    utm_grid['tile_id'] = utm_grid['name'].str.replace(r"_0_0$", "", regex=True)
    name_only = utm_grid[['tile_id', 'geometry']]

    joined = df.sjoin(name_only, how="left")
    list_of_names = joined.groupby('name')['tile_id'].apply(list)

    # list_of_names is indexed by tile name, so df needs the same index for the assignment.
    df.index = df['name']
    df['mgrs_ids'] = list_of_names

    return df

def assign_meteo(df, year):
    """
    Add meteo file names for the given year.

    Creates (or overwrites) the column 'meteo_<year>' with the value
    'METEO-<name>-<year>', built from df's 'name' column. The year in the value now
    follows the `year` argument instead of being fixed to 2021.

    Args:
        df: frame with a string 'name' column; it is modified in place.
        year: year used in both the column name and the file name.

    Returns: df
    """
    df[f'meteo_{year}'] = 'METEO-' + df['name'] + f'-{year}'

    return df



# Optional step, disabled: count_products()

# Base grid: one row per LAEA 20km tile.
grid = LAEA_20km()
# The 100km tile id is built from characters 0-2 and 4-6 of the tile name.
grid['100km_tile'] = grid['name'].map(lambda tile_name: tile_name[:3] + tile_name[4:7])

# Attach the priorities; the merge uses the columns both frames have in common.
prio = pd.read_csv(
    '/vitodata/EEA_HRL_VLCC/data/ref/AOI/priority/priority_EU27.csv',
    sep=';',
)
grid = grid.merge(prio, how='left')
# Tiles without a priority get the lowest priority.
missing_priority = grid['priority'].isna()
grid.loc[missing_priority, 'priority'] = 18

# Optional step, disabled: df = grid_statistics_mean(grid)
grid = assign_mgrs_tiles(grid)
# Optional step, disabled (only if a new processing list for meteo is needed):
# grid = assign_meteo(grid, year)

# Repeat the grid once per year; the n-th occurrence of a tile name gets the n-th year.
grid = pd.concat([grid for _ in range(len(years))], ignore_index=True)
occurrence = grid.groupby('name').cumcount()
grid['year'] = (occurrence + years[0]).astype(int).astype(str)
grid['version'] = 'V100'

grid["title"] = (
    'CROP_' + grid["year"] + '_' + grid["name"]
    + '-03035-010m_' + grid['version'] + '_RAW.tif'
)
grid["description"] = (
    "Croptype map generated for HR VLCC " + grid["year"]
    + " generated with cropclass SOFTWARE_VERSION and hrl vlcc croptype model MODEL_TAG."
)

# Lowest 'priority' value first, so the tiles without an explicit priority come last.
grid.sort_values(by=['priority'], inplace=True, ascending=True)

output_json = "/vitodata/EEA_HRL_VLCC/data/production/jobsplit_laea20km_all_final.geojson"
json_collection = grid.to_json()
# The overwrite guard is switched off by the leading `False and`, so an existing file is
# always replaced. Remove `False and` to enable it; it now raises a real exception
# (raising a plain string is itself a TypeError).
if False and os.path.exists(output_json):
    raise FileExistsError("ooh wait, this file already exists. do you really want to overwrite it?")
else:
    with open(output_json, 'w') as f:
        f.write(json_collection)