"""
Code to do extractions from the copernicus climate store
"""

# import needed packages
from loguru import logger
import cdsapi
import glob
import geopandas as gpd
import numpy as np
import os
import cfgrib
import xarray as xr
from cropcarbon.openeo.extractions import (
     split_batch_jobs,
     check_availabilty)
from cropcarbon.utils.timeseries import nc_to_pandas


def _compute_daily_stats(xr_merged, DATASET):
    """Reduce merged ERA5-Land GRIB data to daily statistics.

    Returns an ``xarray.Dataset`` with one variable per statistic
    (``<var>_mean``, ``<var>_sum``, ``<var>_min``, ``<var>_max``). The
    longitude/latitude dims are stacked into ``pos`` and the daily time axis
    is renamed to ``t``.
    """
    daily_stats = xr.Dataset()
    if DATASET == 'ERA5-LAND_COLLECTION':
        # [NOTE] in this GRIB layout 'step' is a dimension next to 'time', so
        # values are grouped per date first and then reduced over 'step'
        variables_mean = ['pev', 'e', 't2m']
        variables_sum = ['ssr', 'ssrd']
        variables_min_max = ['t2m']
        for var in variables_mean:
            daily_stats[f'{var}_mean'] = (
                xr_merged[var].groupby('time.date').mean().mean(dim='step'))
        # the radiation is already summed, so the last value on each day is
        # taken as the daily sum
        for var in variables_sum:
            daily_stats[f'{var}_sum'] = (
                xr_merged[var].groupby('time.date').last().max(dim='step'))
        for var in variables_min_max:
            per_date = xr_merged[var].groupby('time.date').mean()
            daily_stats[f'{var}_min'] = per_date.min(dim='step')
            daily_stats[f'{var}_max'] = per_date.max(dim='step')
        time_dim = 'date'
    elif DATASET == 'ERA5-LAND_COLLECTION_SWC':
        # [NOTE] for swvl1 'step' is not a dimension: resample over time
        for var in ['swvl1']:
            daily_stats[f'{var}_mean'] = (
                xr_merged[var].resample(time='1D').mean(dim='time'))
        time_dim = 'time'
    else:
        raise AttributeError(f'EXTRACT FROM {DATASET} '
                             'not yet supported')

    daily_stats = daily_stats.stack(pos=("longitude", "latitude"))
    # rename time axis to ensure consistency between collections
    return daily_stats.rename({time_dim: 't'})


def retrieve_request(c, DATASET, years, outdir, config,
                     job_id, outname_tmp, area_reformat):
    """Download ERA5-Land data from the CDS and aggregate it to daily statistics.

    Each year is requested in four quarterly chunks (Jan-Mar, Apr-Jun,
    Jul-Sep, Oct-Dec), so no single request spans more than three months.
    Each chunk is written as a separate GRIB file in ``outdir``. The chunks
    are then merged and reduced to daily values.

    Args:
        c: CDS API client; only ``c.retrieve(name, request, target)`` is used.
        DATASET: collection key, 'ERA5-LAND_COLLECTION' or
            'ERA5-LAND_COLLECTION_SWC'.
        years: iterable of years to retrieve.
        outdir: folder where the intermediate GRIB files are written.
        config: extraction config;
            ``config[job_id]['DATASETS'][DATASET]['BANDS']`` gives the
            requested variables.
        job_id: key of the job in ``config``.
        outname_tmp: template file name ending in '.grib'; chunk files are
            named ``<stem>_<i>.grib``.
        area_reformat: list passed as-is as the ``area`` of the CDS request.

    Returns:
        xarray.Dataset of daily statistics, with the spatial dims stacked
        into ``pos`` and the time dim named ``t``. The data is loaded into
        memory, so the GRIB files can be removed afterwards.

    Raises:
        AttributeError: if ``DATASET`` is not supported.
    """
    logger.info(f'starting {DATASET} ')
    if DATASET not in ('ERA5-LAND_COLLECTION', 'ERA5-LAND_COLLECTION_SWC'):
        # trailing space added: the message previously read "...not yet"
        # glued directly onto the dataset name
        raise AttributeError(f'EXTRACT FROM {DATASET} '
                             'not yet supported')
    dataset_name_eds = 'reanalysis-era5-land'

    bands = config.get(job_id).get('DATASETS').get(DATASET).get('BANDS')

    # (start, end) month-day suffixes of the four quarters of a year
    quarters = [('-01-01', '-03-31'),
                ('-04-01', '-06-30'),
                ('-07-01', '-09-30'),
                ('-10-01', '-12-31')]
    stem = outname_tmp.split('.grib')[0]

    # store all the requested data locations in a list
    lst_loc_data = []
    iteration = 0
    for year in years:
        logger.info(f'starting {year} ')
        for start_suffix, end_suffix in quarters:
            target = os.path.join(outdir, f'{stem}_{iteration}.grib')
            request = {
                "variable": bands,
                "product_type": "reanalysis",
                "date": f'{year}{start_suffix}/{year}{end_suffix}',
                "time": [f'{t:02d}:00' for t in range(24)],
                "area": area_reformat,
                "format": "grib",
            }
            c.retrieve(dataset_name_eds, request, target)
            iteration += 1
            logger.info(f'split {iteration} ')
            lst_loc_data.append(target)

    # merge all chunks and compute the daily statistics; the source datasets
    # are always closed because the caller deletes these files afterwards
    lst_ds = []
    try:
        for path in lst_loc_data:
            lst_ds.append(cfgrib.open_dataset(path))
        xr_merged = xr.merge(lst_ds, compat="no_conflicts")
        # load into memory so nothing still depends on the GRIB files
        daily_stats = _compute_daily_stats(xr_merged, DATASET).load()
    finally:
        for opened in lst_ds:
            opened.close()
    return daily_stats



def extract_cds(config, outdir):
    """Run the CDS extractions described in ``config`` and save yearly CSVs.

    For every job, each requested dataset is downloaded and aggregated with
    ``retrieve_request``, then converted to dataframes with ``nc_to_pandas``.
    The dataframes are written to
    ``<outdir>/<site_id>/extractions/<collection>/<year>/``. The intermediate
    files of each dataset are removed once it has been converted.

    Args:
        config: dict ``job_id -> {'PERIOD', 'GEOM', 'DATASETS'}``, as built by
            ``define_batch_jobs``.
        outdir: root output folder (str or path-like).
    """
    # connect the cds backend
    c = cdsapi.Client()
    # band names as they appear in the data returned by retrieve_request
    # (after calculating mean, sum, min and max)
    band_names_dic = {
        'ERA5-LAND_COLLECTION': ['pev_mean', 'e_mean', 't2m_mean', 'ssr_sum',
                                 'ssrd_sum', 't2m_min', 't2m_max'],
        'ERA5-LAND_COLLECTION_SWC': ['swvl1_mean'],
    }
    # collection name -> name of the output folder
    translate_col_outfold = {
        "ERA5-LAND_COLLECTION": 'ERA5-LAND',
        "ERA5-LAND_COLLECTION_SWC": 'ERA5-LAND-SWC',
    }

    for job_id, job_config in config.items():
        site_id = '_'.join(job_id.split('_')[:-1])
        # per-job folder in a separate variable: reassigning ``outdir`` made
        # every subsequent job nest inside the previous job's folder
        job_outdir = os.path.join(outdir, site_id, 'extractions')
        os.makedirs(job_outdir, exist_ok=True)
        tmp_stem = f'{job_id}_out_cds'
        outname_tmp = f'{tmp_stem}.grib'

        period = job_config.get('PERIOD')
        yrs_range = [int(item[0:4]) for item in period]
        years = list(range(yrs_range[0], yrs_range[-1] + 1))

        # GEOM.bounds reordered as [b[1], b[0], b[3], b[2]]
        # NOTE(review): with shapely bounds (minx, miny, maxx, maxy) this is
        # [S, W, N, E], while the CDS documents ``area`` as [N, W, S, E];
        # confirm the backend accepts this ordering.
        area = list(job_config.get('GEOM').bounds)
        area_reformat = [area[1], area[0], area[-1], area[2]]

        dict_df_collections = {}
        for DATASET in job_config.get('DATASETS'):
            ds = retrieve_request(c, DATASET, years, job_outdir,
                                  config, job_id, outname_tmp,
                                  area_reformat)
            dict_df_collections = nc_to_pandas(ds,
                                               band_names_dic.get(DATASET, []),
                                               years[0], years[-1],
                                               dict_df_collections,
                                               translate_col_outfold.get(DATASET),
                                               clip_year=True)
            # remove all files whose name starts with the temporary stem
            pattern = os.path.join(glob.escape(job_outdir),
                                   f'{glob.escape(tmp_stem)}_*')
            for item in glob.glob(pattern):
                os.unlink(item)

        # keys are split on their last '_' into collection name and year
        for dataset, df in dict_df_collections.items():
            collection_name, _, year_dataset = dataset.rpartition('_')
            outfold = os.path.join(job_outdir, collection_name, year_dataset)
            os.makedirs(outfold, exist_ok=True)
            outname = f'{site_id}_{collection_name}_{year_dataset}.csv'
            df.to_csv(os.path.join(outfold, outname), index=True)
            logger.info(f'{dataset} dataframes saved!')

    logger.info('Final dataframes saved!')


def define_batch_jobs(outdir, dataset, settings):
    """Build the CDS extraction config for one dataset (site).

    The period between ``settings['START']`` and ``settings['END']`` is split
    into batch jobs with ``split_batch_jobs``. Jobs are then pruned for data
    that already exists with ``check_availabilty``. Each remaining job gets
    the geometry from the single shapefile ``<outdir>/<dataset>/<dataset>_*.shp``.

    Args:
        outdir: ``pathlib.Path`` root folder (``/`` is used on it).
        dataset: name of the dataset/site folder under ``outdir``.
        settings: dict read for 'START', 'END' and 'DATASETS' (plus whatever
            the two helpers read).

    Returns:
        dict ``jobid -> {'PERIOD', 'GEOM', 'DATASETS'}``. The dict is empty
        when everything was already extracted.

    Raises:
        ValueError: if no shapefile, or more than one, matches.
    """
    logger.info(f'Preparing {dataset} for extractions...')
    datasetdir = outdir / dataset
    extractionsdir = datasetdir / 'extractions'
    start_date = settings.get('START', None)
    end_date = settings.get('END', None)
    # Split the time range into batch jobs of consecutive years for which
    # data should be extracted (e.g. only years with in-situ data).
    batch_job_ids, periods = split_batch_jobs(settings, start_date, end_date,
                                              dataset)
    # Drop or shrink periods whose data is already available, and find out
    # which collections still have to be processed per period.
    periods, collection_info_periods, batch_job_ids = check_availabilty(
        settings, periods, batch_job_ids, extractionsdir.as_posix())
    if not periods:
        logger.info(f'EXTRACTION FOR {dataset} ALREADY DONE --> SKIP')
        return {}

    # get the geometry information for extraction
    samplefiles = glob.glob(str(datasetdir / f'{dataset}_*.shp'))
    if not samplefiles:
        raise ValueError(f'NO SHAPE INFORMATION AVAILABLE FOR {dataset}')
    if len(samplefiles) > 1:
        raise ValueError(f'TOO MANY SHAPE INFORMATION AVAILABLE FOR {dataset}')
    geom = gpd.read_file(samplefiles[0]).geometry.values[0]

    # hoisted out of the loops: neither depends on the job
    datasets_settings = settings.get('DATASETS')
    period_keys = list(collection_info_periods.keys())

    dict_config = {}
    for i, jobid in enumerate(batch_job_ids):
        # jobs and period keys are matched by position
        collections_extract = collection_info_periods.get(period_keys[i])
        dict_col_extract_info = {}
        # NOTE(review): as before, any collection that is still pending pulls
        # in *all* configured DATASETS, not only the pending one. Confirm
        # whether only ``col`` should be added.
        if any(collections_extract.get(col) for col in collections_extract):
            dict_col_extract_info.update(datasets_settings)
        dict_config[jobid] = {'PERIOD': periods[i],
                              'GEOM': geom,
                              'DATASETS': dict_col_extract_info}
    return dict_config



def main(outdir, datasets, settings):
    """Run the CDS extraction pipeline for every dataset in ``datasets``.

    A batch-job config is built for each dataset with ``define_batch_jobs``.
    Datasets with an empty config are skipped; the others are extracted with
    ``extract_cds``.
    """
    logger.info('GETTING ALL INPUTS FOR EXTRACTIONS')
    for name in datasets:
        job_config = define_batch_jobs(outdir, name, settings)
        if job_config:
            logger.info(f'STARTING PROCESSING FOR {name}')
            extract_cds(job_config, outdir)