"""
Script to prepare a CSV/GeoJSON that defines which output tiles need to be generated.

TODO: the basic building blocks are here and working, but they still need to be turned into a more convenient script

"""

import json
from pathlib import Path

import geopandas as gpd
import pandas as pd

import openeo
from openeo.processes import array_create, mean, count

# Country names used to select the area of interest. They are matched against
# the ``name`` column of the Natural Earth low-resolution dataset further down.
# NOTE(review): the list holds 25 names; Malta and Cyprus are not included --
# presumably because they cannot be matched in that dataset/continent filter.
# Confirm before relying on this being the full EU27.
EU27 = [
    "France",
    "Sweden",
    "Poland",
    "Austria",
    "Hungary",
    "Romania",
    "Lithuania",
    "Latvia",
    "Estonia",
    "Germany",
    "Bulgaria",
    "Greece",
    "Croatia",
    "Luxembourg",
    "Belgium",
    "Netherlands",
    "Portugal",
    "Spain",
    "Ireland",
    "Italy",
    "Denmark",
    "Slovenia",
    "Finland",
    "Slovakia",
    "Czechia",
]

# Year used in the catalogue queries and in the generated file names.
year = 2021

def LAEA_20km() -> gpd.GeoDataFrame:
    """Load the LAEA 20 km grid, masked by the countries listed in ``EU27``.

    The country geometries come from the Natural Earth low-resolution dataset
    bundled with geopandas: rows whose ``continent`` is "Europe" and whose
    ``name`` appears in ``EU27``. They are passed as ``mask`` when reading the
    remote grid GeoPackage.

    Returns:
        The GeoDataFrame produced by ``gpd.read_file`` for the masked grid.
    """
    # NOTE(review): ``gpd.datasets`` is no longer shipped with geopandas >= 1.0;
    # confirm the geopandas version this script is expected to run with.
    world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
    in_europe = world.continent == "Europe"
    in_country_list = world.name.isin(EU27)
    selected_countries = world[in_europe & in_country_list]
    return gpd.read_file(
        "https://artifactory.vgt.vito.be/auxdata-public/grids/LAEA-20km.gpkg",
        mask=selected_countries,
    )

def UTM_100km_EU27() -> gpd.GeoDataFrame:
    """Load the UTM 100 km tiling grid, masked by the countries in ``EU27``.

    Country geometries are taken from the Natural Earth low-resolution dataset
    (continent "Europe", name in ``EU27``) and used as ``mask`` when reading
    the remote grid GeoPackage. The result is then cut to a coarse
    longitude/latitude window.

    Returns:
        The masked and cropped grid as a GeoDataFrame.
    """
    grid_url = "https://artifactory.vgt.vito.be/auxdata-public/grids/utm-tiling-grid-100km.gpkg"

    natural_earth = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
    european = natural_earth[natural_earth.continent == "Europe"]
    selected_countries = european[european.name.isin(EU27)]

    tiles = gpd.read_file(grid_url, mask=selected_countries)
    # Rough EU bounding box (x: -14..35, y: 33..72) to get rid of overseas areas.
    return tiles.cx[-14:35, 33:72]

def grid_statistics_mean():
    """Attach the WorldCover cropland percentage to every LAEA 20 km grid cell.

    The per-cell statistics are cached in ``cropland_mean_laea_2021.json``.
    When that file is missing, an openEO batch job is run that aggregates the
    boolean mask ``MAP == 40`` of the WorldCover 2021 collection over the
    LAEA-20km-EU27 geometries and downloads the result to that file.

    The first value of every per-geometry result is then taken as the mean
    (0.0 when the result is empty), scaled to a percentage and stored in a new
    ``cropland_perc`` column of the grid read from
    ``jobsplit_laea20km_{year}_all.geojson``.

    Side effects:
        Writes ``jobsplit_laea20km_{year}_cropland.geojson`` and overwrites
        ``jobsplit_laea20km_{year}_all.csv`` with the extended grid.

    Returns:
        The grid GeoDataFrame including the ``cropland_perc`` column.
    """
    statsfile = "cropland_mean_laea_2021.json"
    if not Path(statsfile).exists():
        # Connect and authenticate only when the statistics really have to be
        # computed, so a cached run needs neither network access nor a login.
        c = openeo.connect("openeo-dev.vito.be").authenticate_oidc()
        wc = c.load_collection(
            "ESA_WORLDCOVER_10M_2021_V2",
            bands="MAP",
            temporal_extent=["2020-12-30", "2022-01-01"],
        )
        cropland_mask = wc.band("MAP") == 40
        cropland_mask.aggregate_spatial(
            "https://artifactory.vgt.vito.be/auxdata-public/grids/LAEA-20km-EU27.geojson",
            reducer=lambda x: array_create(mean(x), count(x)),
        ).execute_batch(
            statsfile,
            title="Worldcover cropland stats LAEA",
            job_options={"executor-memory": "4G", "executor-memoryOverhead": "2G"},
        )

    with open(statsfile, 'r') as f:
        stats = json.load(f)

    # The result is keyed by timestamp; only this single date is read.
    # Only the first element of every per-geometry list is used below.
    mean_list = [
        bands[0] if len(bands) > 0 else 0.0
        for bands in stats["2021-12-31T00:00:00Z"]
    ]

    # Source grid: https://artifactory.vgt.vito.be/auxdata-public/grids/
    # NOTE(review): the assignment below is purely positional, so it assumes the
    # rows of this file are in the same order as the geometries that were
    # aggregated -- confirm both files describe the same grid.
    grid_df = gpd.read_file(f"jobsplit_laea20km_{year}_all.geojson")
    grid_df["cropland_perc"] = mean_list
    grid_df["cropland_perc"] = 100.0 * grid_df["cropland_perc"]  # fraction -> percent
    grid_df.to_file(f"jobsplit_laea20km_{year}_cropland.geojson", driver='GeoJSON')
    grid_df.to_csv(f"jobsplit_laea20km_{year}_all.csv", index=False)
    return grid_df


def count_products():
    """Count Sentinel-1/Sentinel-2 products per grid cell and pick a provider.

    Starts from ``jobsplit_laea20km_{year}_all.geojson`` when it exists,
    otherwise from a fresh ``LAEA_20km()`` grid with placeholder columns.
    For every cell whose ``sentinel1count`` is ``<= 0`` the catalogue is asked
    for the number of Sentinel-1 products within the cell's bounding box for
    the configured year. Only when that count exceeds 80 is the Sentinel-2
    count queried too; a cell is marked ``"terrascope"`` when the Sentinel-2
    count exceeds 100, otherwise ``"not_terrascope"``.

    The CSV and GeoJSON files are rewritten after every queried cell, so an
    interrupted run picks up from the cells that were already counted.
    """
    output_json = f"jobsplit_laea20km_{year}_all.geojson"

    if not Path(output_json).exists():
        grid = LAEA_20km()
        grid["provider"] = "notset"
        grid["sentinel1count"] = -1
        grid["sentinel2count"] = -1
    else:
        grid = gpd.read_file(output_json)

    from terracatalogueclient import Catalogue
    catalogue = Catalogue()

    # Same query window for both collections.
    period = {"start": str(year) + "-01-01", "end": str(year) + "-12-31"}

    for cell_id in grid.index:
        cell = grid.loc[cell_id]
        bounds = cell.geometry.bounds

        print(cell)
        print(bounds)

        # Skip cells that already carry a positive Sentinel-1 count.
        if not (cell['sentinel1count'] <= 0):
            continue

        s1_count = catalogue.get_product_count(
            "urn:eop:VITO:CGS_S1_GRD_SIGMA0_L1",
            bbox=list(bounds),
            **period,
        )
        print(s1_count)
        grid.loc[cell_id, "sentinel1count"] = s1_count

        # Default; upgraded below when both collections have enough products.
        grid.loc[cell_id, "provider"] = "not_terrascope"
        if s1_count > 80:
            s2_count = catalogue.get_product_count(
                "urn:eop:VITO:TERRASCOPE_S2_TOC_V2",
                bbox=list(bounds),
                **period,
            )
            grid.loc[cell_id, "sentinel2count"] = s2_count
            if s2_count > 100:
                grid.loc[cell_id, "provider"] = "terrascope"

        # Checkpoint the progress after each queried cell.
        grid.to_csv(f"jobsplit_laea20km_{year}_all.csv", index=False)
        grid.to_file(output_json, driver='GeoJSON')




def assign_mgrs_tiles(df):
    """
    Add MGRS tile id's, purely for information.

    Every row of ``df`` is spatially joined (left join) with the UTM 100 km
    grid; the ids of all matching tiles are collected per ``name`` into a new
    ``mgrs_ids`` column holding a list.

    Note that ``df`` is modified in place: its index is replaced by its
    ``name`` column and the ``mgrs_ids`` column is added.

    Args:
        df: GeoDataFrame with a ``name`` column identifying each row.

    Returns:
        The same ``df`` object, indexed by ``name`` and extended with
        ``mgrs_ids``.
    """
    utm_grid = UTM_100km_EU27()
    # Remove the literal "_0_0" suffix. ``str.rstrip("_0_0")`` would instead
    # strip any trailing run of the characters "_" and "0", which can eat into
    # the id itself.
    utm_grid['tile_id'] = utm_grid.name.str.replace(r"_0_0$", "", regex=True)
    name_only = utm_grid[['tile_id', 'geometry']]

    joined = df.sjoin(name_only, how="left")
    list_of_names = joined.groupby('name')['tile_id'].apply(list)

    # Index by name so the per-name lists align on assignment.
    df.index = df.name
    df['mgrs_ids'] = list_of_names

    return df

def assign_meteo(df, year):
    """
    Add meteo file names for the given year.

    A column ``meteo_<year>`` is added to ``df`` (in place) whose values are
    ``METEO-<name>-<year>``, built from the ``name`` column.

    Args:
        df: DataFrame with a string ``name`` column.
        year: Year used both in the column name and in the generated values.

    Returns: df
    """
    # Use the requested year in the value too, so column name and content agree.
    df[f'meteo_{year}'] = 'METEO-' + df['name'] + f'-{year}'

    return df



# Pipeline: cropland statistics -> MGRS tile ids -> meteo names -> final export.
# count_products() writes the jobsplit_laea20km_{year}_all.geojson file that
# grid_statistics_mean() reads; enable it when that file has to be (re)built.
#count_products()
df = grid_statistics_mean()
grid = assign_meteo(assign_mgrs_tiles(df), year)

# Cells with the highest cropland share come first.
grid.sort_values(by=['cropland_perc'], ascending=False, inplace=True)

final_prefix = f"jobsplit_laea20km_{year}_all_final"
output_json = f"{final_prefix}.geojson"
grid.to_csv(f"{final_prefix}.csv", index=False)

# The GeoJSON is serialised with to_json() and written by hand rather than
# through to_file().
json_collection = grid.to_json()
with open(output_json, 'w') as f:
    f.write(json_collection)
#grid.to_file(output_json, driver='GeoJSON', index=False)

# Now create separate files for the different job runs.

# Cropland percentage formatted with two decimals, shared by title and description.
cropland_label = grid["cropland_perc"].apply(lambda perc: f"{perc:,.2f}")
grid["title"] = "Cropclass-Classification-" + grid["name"] + "-" + cropland_label
grid["description"] = (
    "Croptype map generated for openEO platform CCN \n" + grid["name"] + "-" + cropland_label
)

# Cells served by terrascope: CSV plus GeoJSON.
terra_selection = grid[grid.provider == "terrascope"]
terra_selection.to_csv("croptype2021_terrascope.csv")
json_collection = terra_selection.to_json()
with open("croptype2021_terrascope.geojson", 'w') as f:
    f.write(json_collection)

# All remaining cells: GeoJSON only.
other_selection = grid[grid.provider != "terrascope"]
json_collection = other_selection.to_json()
with open("croptype2021.geojson", 'w') as f:
    f.write(json_collection)
"""
These jobs are for test runs on the different backends.
Start from tiles with more than 10% cropland, to avoid processing irrelevant areas.
"""

# MGRS tiles from which the backend test cells are drawn.
test_ids = [
    "33TWN", "35VME", "34VFN", "35WMM", "31UDP", "31TCJ", "33TWJ", "33TXL",
    "30TUM", "30SVH", "30UFG", "29SNC", "29TPG", "33VWG", "33TWM", "33UYP",
]

tiles_with_crops = grid[grid.cropland_perc > 10]
subset = tiles_with_crops.mgrs_ids.apply(
    lambda tile_ids: any(tile_id in test_ids for tile_id in tile_ids)
)
subset31UFS = tiles_with_crops.mgrs_ids.apply(lambda tile_ids: '31UFS' in tile_ids)

# 5 random cells touching tile 31UFS plus 100 random cells from the test tiles.
# sample() raises when fewer rows than requested are available.
cells_31ufs = tiles_with_crops[subset31UFS].sample(5)
cells_test_tiles = tiles_with_crops[subset].sample(100)
sample = pd.concat([cells_31ufs, cells_test_tiles])

# Write the same sample once per backend, differing only in the provider column.
for provider_name, target_file in (
    ("terrascope", "terrascope_test.geojson"),
    ("creodias", "creo_test.geojson"),
    ("shub", "shub_test.geojson"),
):
    sample["provider"] = provider_name
    json_collection = sample.to_json()
    with open(target_file, 'w') as f:
        f.write(json_collection)
