import geopandas as gpd
import pandas as pd
import json
import glob
from loguru import logger
import pandas as pd
import numpy as np
import h3.api.basic_int as h3
from pathlib import Path
from datetime import datetime
import fire
from shutil import copy2
import xarray as xr
import os
import warnings

from cropclass.trainingdata.datasets import get_start_end_from_valtime
from cropclass.openeo.preprocessing import cropclass_preprocessed_inputs
from cropclass.config.settings import get_job_options

from openeo_classification.connection import (
    openeo_platform, creo, terrascope_dev, openeo_prod)
from openeo.extra.job_management import MultiBackendJobManager


def _split_dataset(infile, json_dir, source_name, start_date,
                   end_date, ct_attr='CT_fin', id_attr='sampleID',
                   add_attributes=None):
    """Split a sample file into spatially grouped parts and write them out.

    The samples in ``infile`` are indexed by the H3 cell (resolution 10) of
    their centroid, sorted by that index and cut into groups wherever two
    consecutive cells are further apart than a fixed threshold. Every group
    is written to ``json_dir`` as ``{source_name}_group_{name}.json`` and an
    overview file ``{source_name}_split_overview.json`` is written next to
    them, holding one row per group.

    Args:
        infile: Path of the vector file to read with ``gpd.read_file``.
        json_dir: Output directory; must support the ``/`` operator
            (e.g. a ``pathlib.Path``).
        source_name: Name used as prefix for all output files. The overview
            attribute ``ref_id`` is this name with its last
            underscore-separated token removed.
        start_date: Value stored in the ``start_date`` overview column.
        end_date: Value stored in the ``end_date`` overview column.
        ct_attr: Name of the label column to retain.
        id_attr: Name of the sample identifier column to retain.
        add_attributes: Optional mapping ``{column: value}``; each value is
            repeated for every row of the overview.

    Returns:
        None. All results are written to disk.
    """

    logger.info('Reading shapefile...')
    df = gpd.read_file(infile)

    logger.info('Preparing geometry...')
    # H3 index (resolution 10) of each centroid. Despite the column name
    # 's2cell', the values are H3 cells, not S2 cells.
    # NOTE(review): centroid coordinates are passed as (lat=y, lon=x), which
    # presumably assumes a geographic CRS -- confirm against the input data.
    cellsh3 = df.geometry.centroid.apply(lambda x: h3.geo_to_h3(x.y, x.x, 10))
    df['s2cell'] = cellsh3
    df.index = cellsh3
    df.sort_index(inplace=True)

    logger.info('Only retaining crucial attributes...')
    df = df[[ct_attr, id_attr, 's2cell', 'geometry']]

    logger.info('Splitting up dataset...')

    # Distance between the last and first H3 cell of a rolling window.
    # NOTE(review): the window is cast back to int64 and 1 is subtracted
    # before the lookup; this looks like a workaround for precision loss
    # when the integer cell ids pass through the rolling window -- confirm.
    def dist(x):
        return h3.point_dist(h3.h3_to_geo(x.astype(np.int64)[-1] - 1),
                             h3.h3_to_geo(x.astype(np.int64)[0] - 1))

    # Distance between each sample's cell and the previous one (in index
    # order); the first entry has no predecessor.
    distance = df.s2cell.rolling(window=2).apply(lambda x: dist(x), raw=True)
    # Threshold in the default unit of h3.point_dist (presumably km -- TODO
    # confirm).
    distance_thr = 200
    # Cells at which a new spatial cluster starts.
    breakpoints = df.s2cell[distance > distance_thr]
    # Disabled alternative: lower the threshold until more than one
    # breakpoint is found.
    # while (len(breakpoints) <= 1) :
    #     distance_thr -= 50
    #     breakpoints = df.s2cell[distance > distance_thr]

    # Disabled alternative: quantile-based binning.
    # out,bins = pd.qcut(df.s2cell,q=int(len(cells)/500),retbins=True)
    if len(breakpoints) > 0:
        # Use the breakpoint cells (shifted by one) as left-closed bin edges.
        # NOTE(review): samples whose index lies outside the outermost edges
        # fall in no bin and would not end up in any group; a single
        # breakpoint yields no complete bin at all -- verify this is intended.
        cuts = breakpoints.values
        cuts.sort()
        categories = pd.cut(df.index, bins=(cuts - 1), right=False)
        grouped = df.groupby(categories)
        # Number of samples per group.
        count = grouped[id_attr].agg('count')
        count.name = "COUNT"
        count = count.reset_index().COUNT
        # Footprint of each group: convex hull of the union of its geometries.
        union = grouped.geometry.aggregate(lambda s: s.unary_union)
        polys = union.convex_hull
        polys.name = "GEOM"

        # Write every group to its own file; the group key (the bin) is
        # embedded in the file name.
        filenames = []
        for name, group in grouped:
            f = json_dir / f"{source_name}_group_{name}.json"
            filenames.append(str(f))
            group.to_file(f)

        logger.info(f'Dataset to be split into {len(filenames)} parts!')

    else:
        # No large jumps between consecutive cells: keep everything in one
        # file named '..._group_all.json'.
        logger.info('No splitting required for this one!')
        count = pd.Series(data=[len(df)], index=[0],
                          dtype=np.int64, name="COUNT")
        name = 'all'
        filenames = [str(json_dir / f"{source_name}_group_{name}.json")]
        df.to_file(filenames[0])
        # Intended as the union of all geometries; its convex hull is the
        # footprint of the single group.
        union = df.geometry.aggregate(lambda s: s.unary_union)
        polys = pd.Series(data=[union.convex_hull],
                          index=[0], name="GEOM")

    logger.info('Creating splits dataframe...')
    # One row per group: sample count, file path and footprint geometry.
    splits_frame = gpd.GeoDataFrame({"COUNT": count, "FILENAME": pd.Series(
        filenames)}, geometry=polys.reset_index().GEOM)

    logger.info('Adding some attributes...')
    # ref_id is the source name without its last '_'-separated token.
    ref_id = '_'.join(source_name.split('_')[0:-1])
    splits_frame['ref_id'] = [ref_id] * len(splits_frame)
    splits_frame['start_date'] = [start_date] * len(splits_frame)
    splits_frame['end_date'] = [end_date] * len(splits_frame)
    if add_attributes is not None:
        for title, value in add_attributes.items():
            splits_frame[title] = [value] * len(splits_frame)

    logger.info('Saving split dataframe...')
    splits_frame.to_file(
        str(json_dir / f"{source_name}_split_overview.json"), index=False)
    logger.info('Split dataframe saved!')


def split_dataset(basedir, dataset, settingsfile):
    """Split all not-yet-split sample files of a dataset for openeo.

    Sample files are looked up as ``<basedir>/<dataset>/samples/<dataset>_*.gpkg``.
    For each file without an existing
    ``<basedir>/<dataset>/extractions/jsons/<stem>_split_overview.json``,
    ``_split_dataset`` is called to create the parts and the overview.

    Args:
        basedir: Root directory; must support the ``/`` operator
            (e.g. a ``pathlib.Path``).
        dataset: Name of the dataset (sub-directory of ``basedir``).
        settingsfile: Path to a JSON settings file. Optional keys
            ``start_date`` and ``end_date`` define the extraction period;
            whichever is missing is inferred from the sample file through
            ``get_start_end_from_valtime``.

    Returns:
        None.
    """

    logger.info(f'Preparing {dataset} for openeo extractions...')

    samplesdir = basedir / dataset / 'samples'
    jsons_dir = basedir / dataset / 'extractions' / 'jsons'
    jsons_dir.mkdir(exist_ok=True, parents=True)

    # get all sample files
    samplefiles = glob.glob(str(samplesdir / f'{dataset}_*.gpkg'))

    # open settings (context manager so the handle is always released)
    with open(settingsfile, 'r') as f:
        settings = json.load(f)

    # The path is written into the overview file as a plain attribute,
    # so make sure it is a string even when a Path object was passed.
    add_attributes = {'settingsfile': str(settingsfile)}

    # split the ones which haven't been split yet
    for samplefile in samplefiles:
        sname = Path(samplefile).stem
        split_file = jsons_dir / f'{sname}_split_overview.json'
        if split_file.exists():
            continue

        logger.info(f'Splitting samplefile {sname} '
                    'into manageable chuncks for openeo...')

        # get start and end date, preferring the configured values
        start_date = settings.get('start_date', None)
        end_date = settings.get('end_date', None)
        if (start_date is None) or (end_date is None):
            logger.warning('Inferring start date and/or '
                           'end date from validity time...')
            sdate, edate = get_start_end_from_valtime(samplefile)
            if start_date is None:
                start_date = sdate
            if end_date is None:
                end_date = edate
        logger.info('Date range for extractions: '
                    f'{start_date} - {end_date}')

        _split_dataset(samplefile, jsons_dir, sname,
                       start_date, end_date,
                       add_attributes=add_attributes)

        logger.info(f'{sname} split successfully!')

    return


class CustomJobManager(MultiBackendJobManager):
    """Job manager that post-processes the results of finished jobs."""

    def on_job_done(self, job, row):
        """Download the job results and attach the sample metadata.

        The results are downloaded to a directory named after the job title,
        located two levels above the sample file ``row['FILENAME']``. The job
        metadata and a copy of the sample file are stored alongside. When a
        ``timeseries.nc`` result is present and holds exactly one feature per
        sample, the coordinates ``refID``, ``sampleID`` and ``label`` are
        added and the result is saved as ``timeseries_fin.nc``. Otherwise
        the file is renamed to ``timeseries_INCOMPLETE.nc``.

        Args:
            job: The finished openeo batch job.
            row: Row of the jobs dataframe; the keys ``FILENAME`` and
                ``ref_id`` are read.
        """

        fnp = row['FILENAME']

        job_metadata = job.describe_job()
        target_dir = Path(fnp).parents[1] / job_metadata['title']
        target_dir.mkdir(exist_ok=True, parents=True)
        job.get_results().download_files(target=target_dir)

        with open(target_dir / f'job_{job.job_id}.json', 'w') as f:
            json.dump(job_metadata, f, ensure_ascii=False)

        # copy geometry to result directory (best effort: a failed copy
        # must not prevent the post-processing below)
        try:
            copy2(fnp, target_dir)
        except OSError as exc:
            logger.warning(f'COPY ERROR {fnp} {target_dir}: {exc}')

        # locate extractions netcdf
        infile = target_dir / 'timeseries.nc'
        if not infile.exists():
            logger.warning('No extractions found!')
            return

        # open json containing the geometries
        polys = gpd.read_file(fnp)
        n_samples = len(polys)

        # Only touch the name, never the directory part of the path.
        outfile = infile.with_name('timeseries_fin.nc')
        incomplete_file = infile.with_name('timeseries_INCOMPLETE.nc')

        # The context manager guarantees the file handle is released before
        # the file is possibly renamed below.
        with xr.open_dataset(str(infile)) as data:
            # do check on completeness
            complete = data.feature.shape[0] == n_samples

            if complete:
                # add ref_id, sampleID and label per feature
                annotated = data.assign_coords(
                    refID=("feature", [row['ref_id']] * n_samples))
                annotated = annotated.assign_coords(
                    sampleID=("feature", list(polys['sampleID'].values)))
                annotated = annotated.assign_coords(
                    label=("feature", list(polys['CT_fin'].values)))
                # remove feature_names coordinate
                annotated = annotated.reset_coords(names=['feature_names'],
                                                   drop=True)

                # save netcdf
                annotated.to_netcdf(path=str(outfile))

        if not complete:
            logger.warning('Extractions incomplete, ignoring these!')
            os.rename(str(infile), str(incomplete_file))
            return

        logger.info('Final netcdf saved!')


def get_cropclass_features(connection, provider, gdf, start, end,
                           settings):
    """Build the datacube of cropclass inputs aggregated over the samples.

    Args:
        connection: openeo connection handed to
            ``cropclass_preprocessed_inputs``.
        provider: One of ``'terrascope'``, ``'sentinelhub'`` or
            ``'creodias'``; selects the Sentinel-1/2 collections.
        gdf: GeoDataFrame with the sample geometries; converted to GeoJSON
            for the spatial aggregation.
        start: Start of the extraction period.
        end: End of the extraction period.
        settings: Mapping with the optional keys ``get_S1``,
            ``DEM_collection``, ``METEO_collection``,
            ``WORLDCOVER_collection`` and ``processing_options``.

    Returns:
        The result of ``aggregate_spatial`` (mean reducer) on the
        preprocessed inputs.

    Raises:
        ValueError: If ``provider`` is not supported.
    """

    # Select the right collections depending on provider
    if provider == 'terrascope':
        S2_collection = 'TERRASCOPE_S2_TOC_V2'
        S1_collection = 'SENTINEL1_GRD_SIGMA0'
    elif provider in ('sentinelhub', 'creodias'):
        S2_collection = 'SENTINEL2_L2A'
        S1_collection = 'SENTINEL1_GRD'
    else:
        raise ValueError(f'Provider {provider} not supported!')

    # get processing settings from settings file
    if not settings.get('get_S1', True):
        S1_collection = None
    demcol = settings.get('DEM_collection', 'COPERNICUS_30')
    meteocol = settings.get('METEO_collection', 'AGERA5')
    worldcover = settings.get('WORLDCOVER_collection', None)
    # A missing (or null) entry means "no extra options"; unpacking None
    # with ** would raise a TypeError.
    processing_options = settings.get('processing_options') or {}

    # geometry determined by sample points
    bbox = None

    # call actual function computing the inputs
    datacube = cropclass_preprocessed_inputs(
        connection, bbox, start, end,
        S2_collection=S2_collection,
        S1_collection=S1_collection,
        DEM_collection=demcol,
        METEO_collection=meteocol,
        WORLDCOVER_collection=worldcover,
        preprocess=True,
        **processing_options)

    # now only retain the data for the points of interest
    geo = json.loads(gdf.to_json())
    return datacube.aggregate_spatial(geo, reducer='mean')


def extraction_function(row, connection_provider, connection,
                        provider):
    """Create the openeo batch job extracting features for one row.

    Args:
        row: Row of the jobs dataframe; the keys ``start_date``,
            ``end_date``, ``ref_id``, ``FILENAME`` and ``settingsfile``
            are read.
        connection_provider: Not used by this function; part of the
            signature the job manager calls.
        connection: openeo connection used to build the process graph.
        provider: Name of the backend provider.

    Returns:
        The batch job created through ``create_job``.
    """

    # get settings for the row
    start = row['start_date']
    end = row['end_date']
    dataset_name = row['ref_id']
    fnp = row['FILENAME']
    filename = str(Path(fnp).name)
    with open(row['settingsfile'], 'r') as f:
        settings = json.load(f)

    # Read samples
    pols = gpd.read_file(fnp)
    # make sure we extract points
    pols["geometry"] = pols["geometry"].centroid

    # get cropclass features for these points
    features = get_cropclass_features(connection, provider,
                                      pols, start, end, settings)

    job_options = get_job_options(task='extractions',
                                  provider=provider)

    job = features.create_job(
        title=f"Cropclass_TrainingFeat_{dataset_name}_{provider}",
        description=f"Cropclass training extractions. Source of samples: {filename}",
        out_format="NetCDF",
        job_options=job_options)

    logger.info(f'Created job: {job}')

    return job


def extract_samples(dataframe, status_file,
                    provider="terrascope",
                    parallel_jobs=1):
    """Run the extraction jobs described by ``dataframe`` on one backend.

    The ``creo`` connection is used when ``provider`` equals "CREODIAS"
    (case-insensitive); every other provider runs on ``terrascope_dev``.
    Job progress is tracked in ``status_file``.
    """

    job_manager = CustomJobManager()

    # Alternative default connection: openeo_prod
    wants_creodias = provider.upper() == "CREODIAS"
    backend_connection = creo if wants_creodias else terrascope_dev

    job_manager.add_backend(
        provider,
        connection=backend_connection,
        parallel_jobs=parallel_jobs,
    )

    job_manager.run_jobs(
        df=dataframe,
        start_job=extraction_function,
        output_file=Path(status_file),
    )


def run_extractions(status_file,
                    dataframe,
                    provider="terrascope"):
    """Enable debug logging and run the extractions with one job at a time."""
    import logging

    logging.basicConfig(level=logging.DEBUG)

    extract_samples(
        dataframe,
        status_file,
        provider=provider,
        parallel_jobs=1,
    )


def _merge_extraction_parts(parts, outfile):
    """Concatenate the netCDF parts, drop duplicate samples and save them.

    Returns the merged dataset, loaded in memory. The part files are closed
    before returning.
    """
    with xr.open_mfdataset(parts, concat_dim='feature',
                           combine='nested', parallel=True) as ds_parts:

        # Check if there's observations
        if ds_parts.to_array().notnull().sum() == 0:
            raise ValueError(
                'The merged dataframe contains no valid observations!')

        # Remove duplicate samples
        n_samples = len(ds_parts.sampleID.values)
        logger.info(f'{n_samples} samples in dataset')

        _, idx = np.unique(ds_parts.sampleID, return_index=True)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            ds_merged = ds_parts.isel(feature=idx)
            n_duplicates = n_samples - len(ds_merged.sampleID.values)
            logger.info(f'{n_duplicates} duplicate samples removed!')

            # Cast datatypes
            ds_merged['refID'] = ds_merged['refID'].astype('U75')
            ds_merged['sampleID'] = ds_merged['sampleID'].astype('U75')
            ds_merged['label'] = ds_merged['label'].astype('U75')

            # Write to output file
            # Note that load() dramatically speeds up the process
            # should memory issues occur, this can be removed to lower
            # memory burden, but it will go much slower. The loaded data
            # also stays usable after the part files are closed.
            logger.info(f'Saving to: {outfile}')
            ds_merged = ds_merged.load()
            ds_merged.to_netcdf(outfile, engine='h5netcdf')
            logger.info('File saved!')

    return ds_merged


def _read_all_samples(samplefiles):
    """Read all sample files into one frame indexed by ``sampleID``.

    Returns None when ``samplefiles`` is empty.
    """
    frames = []
    for samplefile in samplefiles:
        gdf = gpd.read_file(samplefile).set_index('sampleID')
        # add source of samples
        gdf['samplefile'] = [samplefile] * len(gdf)
        frames.append(gdf)
    if not frames:
        return None
    # A single concatenation avoids re-copying the growing frame per file.
    return pd.concat(frames, axis=0)


def merge_extractions(basedir, ref_ids):
    """Merge the extraction results per ref_id and update its database.

    For every ref_id, all ``<basedir>/<ref_id>/extractions/*/*_fin.nc``
    parts are merged into ``MERGED_<ref_id>_extractions.nc``. Afterwards
    ``<basedir>/<ref_id>/database/database.json`` is (re)written with the
    samples that have extractions, their period, the path to the merged
    file and their centroid as geometry.

    Args:
        basedir: Root directory; must support the ``/`` operator
            (e.g. a ``pathlib.Path``).
        ref_ids: Iterable (with a length) of ref_ids to process.

    Raises:
        ValueError: If the merged data of a ref_id holds no valid
            observation at all.
    """

    logger.info(f'Found {len(ref_ids)} ref ids with new extractions')
    for ref_id in ref_ids:
        logger.info(f'Processing: {ref_id}')
        extractionsdir = basedir / ref_id / 'extractions'
        outfile = str(extractionsdir / f'MERGED_{ref_id}_extractions.nc')

        # Find the individual parts
        parts = [Path(x) for x in glob.glob(
            str(extractionsdir / '*' / '*_fin.nc'))]
        if len(parts) == 0:
            logger.info('Found no results, skipping!')
            continue
        logger.info(f'Found {len(parts)} individual parts')

        # merge the files
        ds_merged = _merge_extraction_parts(parts, outfile)

        # now create or update the database for this ref_id
        logger.info('Creating/updating sample database...')

        # get all samples for this ref_id
        samplefiles = glob.glob(str(basedir / ref_id / 'samples' / '*.gpkg'))
        logger.info(f'Found {len(samplefiles)} sample files for this ref id')
        gdf_all = _read_all_samples(samplefiles)
        if gdf_all is None:
            # Without sample files there is nothing to put in the database.
            logger.warning(
                f'No sample files for {ref_id}, skipping database update!')
            continue

        # only retain the samples which have extractions available
        gdf_all = gdf_all.loc[gdf_all.index.isin(ds_merged['sampleID'].values)]
        logger.info(f'{len(gdf_all)} samples available for {ref_id}')
        # add start and end dates
        start = pd.to_datetime(ds_merged.t.values.min()
                               ).strftime(format='%Y-%m-%d')
        end = pd.to_datetime(ds_merged.t.values.max()
                             ).strftime(format='%Y-%m-%d')
        gdf_all['start_date'] = [start] * len(gdf_all)
        gdf_all['end_date'] = [end] * len(gdf_all)
        # add path to extractions
        gdf_all['data_path'] = [outfile] * len(gdf_all)
        # simplify geometry to centroid
        gdf_all['geometry'] = gdf_all['geometry'].centroid
        # save as json
        dbfile = basedir / ref_id / 'database' / 'database.json'
        dbfile.parent.mkdir(exist_ok=True, parents=True)
        gdf_all.to_file(dbfile, driver='GeoJSON')
        logger.info(f'Database saved to {dbfile}')


def main(basedir, provider):
    """Collect new datasets, run pending extractions and merge the results.

    Steps:
      1. Every ``<basedir>/*/extractions/jsons/*_split_overview.json``
         without a matching ``*_dummy.json`` marker is added to a new batch
         file ``<basedir>/extraction_status/<timestamp>_df.json``; the
         marker is then created so the dataset is not picked up again.
      2. For every batch file, the jobs are launched unless its status file
         reports all jobs as ``finished``.
      3. The extractions of all ref_ids that were (re)launched are merged.

    Args:
        basedir: Root directory; must support the ``/`` operator
            (e.g. a ``pathlib.Path``).
        provider: Name of the backend provider passed to the extractions.

    Returns:
        None.
    """

    logger.info('GETTING ALL INPUTS FOR EXTRACTIONS')

    # set output dir
    status_dir = basedir / 'extraction_status'
    status_dir.mkdir(exist_ok=True, parents=True)

    # screen directory of jsons to check for datasets to be added
    # to extractions
    overview_files = glob.glob(str(basedir / '*' / 'extractions' / 'jsons' /
                                   '*_split_overview.json'))

    # if extractions were already started previously, skip these
    pending = []
    for overview in overview_files:
        marker = Path(overview).parent / f'{Path(overview).stem}_dummy.json'
        if not marker.is_file():
            pending.append((overview, marker))

    if pending:

        logger.info(f'Found {len(pending)} datasets for extractions')
        logger.info('Merging datasets for this batch...')
        df = pd.concat([gpd.read_file(overview) for overview, _ in pending],
                       axis=0, ignore_index=True)

        # save dataframe first: if this fails, no marker has been written
        # yet and the datasets are picked up again on the next run
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        df_file = str(status_dir / f'{timestamp}_df.json')
        df.to_file(df_file, driver='GeoJSON')

        # create dummy files for these datasets so they do not end up
        # in next round of extractions
        for _, marker in pending:
            with open(marker, 'w') as fp:
                json.dump({}, fp)

        logger.info('Datasets added!')

    logger.info('STARTING ACTUAL EXTRACTIONS...')
    ref_id_new = []
    for df_file in glob.glob(str(status_dir / '*_df.json')):
        df = gpd.read_file(df_file)
        status_file = df_file.replace('df.json', 'status.csv')

        # check if all extractions have been done or not
        if Path(status_file).is_file():
            status_df = pd.read_csv(status_file, header=0)
            status_df_todo = status_df.loc[status_df["status"] != 'finished']
            if len(status_df_todo) == 0:
                continue
            ref_id_new.extend(status_df_todo.ref_id.unique().tolist())
        else:
            # Fresh batch without status file: every ref_id in it gets
            # new extractions and therefore needs merging afterwards.
            ref_id_new.extend(df.ref_id.unique().tolist())

        logger.info(f'LAUNCHING {len(df)} JOBS FROM {df_file}!')
        run_extractions(status_file, df, provider)

        # TODO: when all extractions are done, the script should stop
        # and move on, which is not the case if there are extractions
        # remaining with status "error"

    logger.info('Merging extractions per ref_id...')
    # remove duplicates in list of ref_ids to check
    ref_id_new = list(dict.fromkeys(ref_id_new))
    # call function that merges the data per ref_id
    merge_extractions(basedir, ref_id_new)

    logger.success('All done!')
    return
