# -*- coding: utf-8 -*-
import geopandas as gpd
from pathlib import Path
from loguru import logger
import glob
import xarray as xr
import os


def _find_first(directory, pattern):
    """Return the first file (in sorted order) in *directory* matching *pattern*.

    Sorting makes the choice deterministic when several files match, since
    directory listings come back in arbitrary filesystem order.

    Raises:
        FileNotFoundError: if no file matches *pattern*.
    """
    matches = sorted(Path(directory).glob(pattern))
    if not matches:
        raise FileNotFoundError(
            f'No file matching {pattern!r} found in {directory}')
    if len(matches) > 1:
        logger.warning(f'Multiple files match {pattern!r} in {directory}; '
                       f'using {matches[0]}')
    return matches[0]


def _parse_period(file_name):
    """Return the second and third ``_``-separated fields of *file_name*.

    These are used as the start/end tags of the output file name.

    Raises:
        ValueError: if *file_name* has fewer than three ``_``-separated fields.
    """
    parts = file_name.split('_')
    if len(parts) < 3:
        raise ValueError(
            f'Cannot extract start/end period from file name: {file_name}')
    return parts[1], parts[2]


def main(datasets, indir, *, label_column='CTfin', id_column='sampleID'):
    """Build one combined inputs/outputs NetCDF file per dataset.

    For every dataset name, the S2 and S1 time-series files found in
    ``<indir>/<dataset>/TS`` are merged with the labels read from
    ``<indir>/<dataset>.gpkg``. Only objects with a non-null label are kept.
    The result is written to
    ``<indir>/<dataset>/<dataset>_<start>_<end>_InputsOutputs_crops.nc``,
    where ``start``/``end`` are the 2nd/3rd ``_``-separated fields of the S2
    file name.

    Args:
        datasets: iterable of dataset names, processed in order.
        indir: root directory holding the datasets (``Path`` or ``str``).
        label_column: GeoPackage column stored as the ``LABEL`` variable.
        id_column: GeoPackage column used as the ``CODE_OBJ`` coordinate;
            it must match the ``CODE_OBJ`` coordinate of the time series.

    Raises:
        ValueError: if a dataset's GeoPackage is missing or the S2 file name
            has too few ``_``-separated fields.
        FileNotFoundError: if no S2 or S1 time-series file is found.
    """
    indir = Path(indir)

    for dataset in datasets:

        logger.info(f'Working on dataset: {dataset}')
        tsdir = indir / dataset / 'TS'
        shapefile = indir / f'{dataset}.gpkg'

        if not shapefile.exists():
            raise ValueError(f'Shapefile not found: {shapefile}')

        s2_file = _find_first(tsdir, 'S2*.nc')
        s1_file = _find_first(tsdir, 'S1*.nc')
        start, end = _parse_period(s2_file.name)

        # read labels, keyed by object ID so they align with the time series
        outputdata = gpd.read_file(str(shapefile))
        output_da = xr.DataArray(data=outputdata[label_column],
                                 dims=['CODE_OBJ'],
                                 coords={'CODE_OBJ': outputdata[id_column]},
                                 name='LABEL')

        outfile = str(indir / dataset /
                      f'{dataset}_{start}_{end}_InputsOutputs_crops.nc')

        # The context managers close the source files once the result has
        # been written; until then the merged dataset may read from them.
        with xr.open_dataset(s2_file) as ds_s2, \
                xr.open_dataset(s1_file) as ds_s1:
            merged_ds = xr.merge([ds_s2, ds_s1]).merge(output_da)

            # Keep only objects with a ground-truth label. A boolean isel
            # (instead of where(drop=True)) does not promote integer
            # variables to float, and it does not broadcast the mask onto
            # variables that lack a CODE_OBJ dimension.
            has_label = merged_ds['LABEL'].notnull().values
            merged_ds = merged_ds.isel(CODE_OBJ=has_label)

            logger.info(f'Creating: {outfile}')
            merged_ds.to_netcdf(outfile, engine='h5netcdf')


if __name__ == '__main__':

    # Datasets to process, in order. Each one needs a `<name>.gpkg` label
    # file and a `<name>/TS` time-series directory under the root below.
    dataset_names = [
        '2018_BE_LPIS-Flanders_POLY_110',
        '2019_BE_LPIS-Flanders_POLY_110',
        '2021_BE_LPIS-Flanders_POLY_110',
        '2018_FR_LPIS_POLY_110',
        '2019_FR_LPIS_POLY_110',
        '2020_FR_LPIS_POLY_110',
        '2019_LV_LPIS_POLY_110',
        '2021_LV_LPIS_POLY_110',
        '2018_AT_LPIS_POLY_110',
        '2019_AT_LPIS_POLY_110',
        '2020_AT_LPIS_POLY_110',
        '2021_AT_LPIS_POLY_110',
        '2019_ESP_ESYRCE_POLY_111',
        '2020_ESP_ESYRCE_POLY_111',
        '2021_ESP_ESYRCE_POLY_111',
        '2019_ES_SIGPAC-Catalunya_POLY_111',
        '2019_UA_ARABLE-LAND_POLY_110',
    ]

    # Root folder holding the per-dataset training data.
    training_root = (Path("/data/users/Public/jeroendegerickx/Demeter")
                     / "Croptype_classification"
                     / "training_data")

    main(dataset_names, training_root)
