from cropsar.readers import *
import _pickle as pickle
import os, sys
import pandas as pd
import geopandas as gpd
import numpy as np
import logging
import pprint

log = logging.getLogger(__name__)

def save_as_parquet(filename, df, dest=CROPSAR_BASE / "tmp/dataStackGenerationSPARK/data/parquetFiles", overwrite=False):
    """Write `df` to ``dest/filename`` as parquet unless that file already exists.

    Column labels are converted with ``astype(str)`` before writing (pandas'
    parquet writer rejects non-string column names). The conversion is done on
    a shallow copy, so the caller's DataFrame keeps its original column labels.

    Args:
        filename: Name of the parquet file to create inside `dest`.
        df: DataFrame to write.
        dest: Destination directory (str or path-like).
        overwrite: When False (default), an existing file is left untouched.

    Returns:
        str: Full path of the parquet file, whether or not it was (re)written.
    """
    target = os.path.join(dest, filename)

    if os.path.exists(target) and not overwrite:
        log.info('File exists: {} -> skipping'.format(target))
        return target

    log.info('Writing to parquet file: {}'.format(target))
    # Relabel a shallow copy instead of mutating the caller's frame in place.
    out = df.copy(deep=False)
    out.columns = df.columns.astype(str)
    out.to_parquet(target)

    return target

def _shp_to_calvaltestpickle(shp, year, dest=CROPSAR_BASE / "data/ref/pickle", seed=None):
    """Convert a parcel shapefile into a croptypes pickle and a random CAL/VAL/TEST split.

    Two pickle files are written into `dest`:

    * ``<year>_croptypes.p``: DataFrame indexed by ``fieldID`` holding the
      shapefile's ``croptype`` and ``area`` columns.
    * ``<year>_CALVALTEST.p``: dict with keys 'CAL', 'VAL' and 'TEST', each a
      ``pd.Series`` of field IDs. 60% of the fields are drawn for CAL, 85% of the
      remaining fields for VAL, and the rest form TEST.

    Args:
        shp: Path to the shapefile (must contain 'fieldID', 'croptype', 'area').
        year: Year used as the output filename prefix (str or int).
        dest: Output directory; created if it does not exist.
        seed: Optional seed for a reproducible split. None (default) draws from
            NumPy's global random state, exactly like before.

    Exits the process via ``sys.exit`` when the shapefile does not exist.
    """
    year = str(year)

    log.info('Reading original shapefile: {}'.format(shp))
    if not os.path.exists(shp):
        sys.exit('Shapefile not found! ({})'.format(shp))
    outdf = gpd.read_file(shp).set_index('fieldID')[['croptype', 'area']]

    os.makedirs(str(dest), exist_ok=True)

    outfile = os.path.join(dest, year + '_croptypes.p')
    log.info('Writing to pickle file: {}'.format(outfile))
    with open(outfile, 'wb') as fh:
        pickle.dump(outdf, fh)

    # Split in CAL/VAL/TEST
    log.info('Splitting cal/val/test ...')
    rng = np.random if seed is None else np.random.RandomState(seed)
    fieldIDs = outdf.index.tolist()
    CAL = list(rng.choice(fieldIDs, size=int(0.6*len(fieldIDs)), replace=False))
    # Sorted so that, for a given seed, the second draw does not depend on set iteration order.
    remainingfields = sorted(set(fieldIDs) - set(CAL))
    VAL = list(rng.choice(remainingfields, size=int(0.85*len(remainingfields)), replace=False))
    TEST = sorted(set(remainingfields) - set(VAL))
    log.info('Nr. of CAL fields: {}'.format(len(CAL)))
    log.info('Nr. of VAL fields: {}'.format(len(VAL)))
    log.info('Nr. of TEST fields: {}'.format(len(TEST)))

    calvaltest = {
        'CAL':  pd.Series(data=CAL),
        'VAL':  pd.Series(data=VAL),
        'TEST': pd.Series(data=TEST)
    }

    calvaltestfile = os.path.join(dest, year + '_CALVALTEST.p')
    log.info('Saving to pickle file: {}'.format(calvaltestfile))
    with open(calvaltestfile, 'wb') as fh:
        pickle.dump(calvaltest, fh)

def read_s1_training(year, source=CROPSAR_BASE / "data/training/S1_GEE"):
    """Load the pickled Sentinel-1 GEE training data for `year`.

    Reads
    ``<source>/<year>_GroundTruthFlanders_<year-1>-06-01_<year>-12-31_S1_backscatter.p``.

    Args:
        year: Target year (str or int).
        source: Directory containing the S1 pickle files.

    Returns:
        The unpickled object. `main` indexes it as ``obj[orbitpass][variable]``
        and treats the result as a DataFrame whose columns are field IDs.
    """
    year = str(year)
    filename = '_'.join([year,
                         "GroundTruthFlanders",
                         str(int(year)-1) + '-06-01',
                         year + '-12-31',
                         "S1_backscatter.p"])

    filepath = os.path.join(source, filename)

    log.info('Reading S1 data: {}'.format(filepath))

    # NOTE: unpickling can execute arbitrary code -- only read trusted, locally produced files.
    with open(filepath, 'rb') as fh:
        df = pickle.load(fh)

    log.info('Done')

    return df

def read_s2_training(year, variable, version, source=CROPSAR_BASE / "data/training"):
    """Load the cleaned Sentinel-2 training CSV for `year`.

    Reads
    ``<source>/<version>/Croptype_ALL/S2_<variable>_<year>_<year-1>0601_<year>1231_1ha_10m.csv``
    with its first column as a parsed-date index.

    Args:
        year: Target year (str or int).
        variable: S2 variable, either "FAPAR" or "NDVI".
        version: Cleaning-version subdirectory name (e.g. 'V005_SHUB').
        source: Root directory of the training data.

    Returns:
        pd.DataFrame: the CSV content; `main` treats its columns as field IDs.

    Exits the process via ``sys.exit`` when `variable` is not recognised.
    """
    if variable not in ["FAPAR", "NDVI"]:
        sys.exit('S2 training variable not recognized: {}'.format(variable))

    year = str(year)
    filename = '_'.join(["S2_" + variable,
                         year,
                         str(int(year) - 1) + '0601',
                         year + '1231',
                         "1ha_10m.csv"])

    filepath = os.path.join(source, version, "Croptype_ALL", filename)

    log.info('Reading S2 data: {}'.format(filepath))

    df = pd.read_csv(filepath, index_col=0, parse_dates=True)

    log.info('Done')

    return df

def read_croptypes(year, source=CROPSAR_BASE / "data/ref/pickle"):
    """Load the croptypes table and CAL/VAL/TEST split for `year`.

    If either pickle is missing from `source`, both are (re)generated from
    ``CROPSAR_BASE/data/ref/shp/<year>_CroptypesFlemishParcels_DeptLV.shp`` via
    `_shp_to_calvaltestpickle`. They are written into `source` itself, so they
    are found by the reads below.

    Args:
        year: Target year (str or int).
        source: Directory holding ``<year>_croptypes.p`` and ``<year>_CALVALTEST.p``.

    Returns:
        tuple: ``(croptypes, calvaltest)``, the two unpickled objects.
    """
    year = str(year)
    filepathcroptypes = os.path.join(source, year + '_croptypes.p')
    filepathcalvaltest = os.path.join(source, year + '_CALVALTEST.p')

    if not (os.path.exists(filepathcroptypes) and os.path.exists(filepathcalvaltest)):
        shp = str(CROPSAR_BASE / "data/ref/shp/{}_CroptypesFlemishParcels_DeptLV.shp".format(year))
        # Write into `source`; the original wrote to the default dir and then read from `source`.
        _shp_to_calvaltestpickle(shp, year, dest=source)

    log.info('Reading croptypes: {}'.format(filepathcroptypes))

    # Open the croptypes
    with open(filepathcroptypes, 'rb') as fh:
        croptypes = pickle.load(fh)

    log.info('Reading CAL/VAL/TEST: {}'.format(filepathcalvaltest))

    # Open CAL/VAL/TEST
    with open(filepathcalvaltest, 'rb') as fh:
        calvaltest = pickle.load(fh)

    return croptypes, calvaltest


def _align_to_reference_index(df, index):
    """Relabel the columns of `df` positionally with the date axis `index`.

    When `df` has exactly one column more than `index`, its last column is
    dropped first. This presumably happens when the data year is a leap year,
    since a ``<year-1>-06-01 .. <year>-12-31`` window then contains Feb 29.
    Assuming daily columns, dates from Feb 29 onward end up shifted by one day
    after relabeling. Any other length mismatch raises a ValueError.
    """
    if len(df.columns) == len(index) + 1:
        df = df.iloc[:, :-1]
    return df.set_axis(index, axis=1)


def _concat_parquet_timeseries(parquetfiles, index):
    """Read per-year parquet time series and stack them row-wise on the shared `index` axis."""
    frames = [_align_to_reference_index(pd.read_parquet(f), index) for f in parquetfiles]
    return pd.concat(frames, sort=False)


def main(years, s2variable, s2cleaningversion):
    """Build parquet stacks of S1, S2 and croptype/CAL/VAL/TEST data per year.

    For every year, the fields present in S1 (ascending VV), S2, the croptypes
    table and the CAL/VAL/TEST split are written to
    ``CROPSAR_BASE/tmp/dataStackGenerationSPARK/data/parquetFiles/S1GEE_S2CLEANED_<version>/<year>/``
    with one row per field:

    * ``croptypes_calvaltest.parquet``: croptypes plus a CALVALTEST label column.
    * ``S2_<s2variable>.parquet``
    * ``S1_<ASCENDING|DESCENDING>_<VV|VH|incidenceAngle>.parquet``

    When more than one year is given, the per-year files are also stacked into a
    ``<year1><year2>...`` directory. All time series share a 2017-06-01..2018-12-31
    date axis there.

    Args:
        years: Iterable of years (str or int), e.g. ['2018', '2019'].
        s2variable: S2 variable to read ("FAPAR" or "NDVI").
        s2cleaningversion: S2 cleaning version (e.g. 'V005_SHUB').
    """
    years = [str(year) for year in years]
    orbitpasses = ['ASCENDING', 'DESCENDING']
    s1variables = ['VV', 'VH', 'incidenceAngle']
    parquetroot = (CROPSAR_BASE / "tmp/dataStackGenerationSPARK/data/parquetFiles"
                   / "S1GEE_S2CLEANED_{}".format(s2cleaningversion))

    s2parquetfiles = []
    croptypescalvaltestparquetfiles = []
    s1parquetfiles = {orbitpass: {variable: [] for variable in s1variables} for orbitpass in orbitpasses}

    for year in years:
        log.info('-'*70)
        log.info('Working on: {}'.format(year))

        s1data = read_s1_training(year)
        s2data = read_s2_training(year, s2variable, s2cleaningversion)
        croptypes, calvaltest = read_croptypes(year)

        # All fields that received a CAL, VAL or TEST label
        allFieldIDs = set(calvaltest['CAL']) | set(calvaltest['VAL']) | set(calvaltest['TEST'])

        # Fields present in S1 (ascending VV as reference), S2 and the croptypes table
        fieldIDsCommon = (set(s1data['ASCENDING']['VV'].columns)
                          & set(s2data.columns)
                          & set(croptypes.index))

        # Keep only labelled fields; sorted so row order is reproducible between runs
        finalfieldIDs = sorted(fieldIDsCommon & allFieldIDs)
        log.info('Amount of fields to extract: {}'.format(len(finalfieldIDs)))

        #------------------------------------------------------------------------------------------
        # Subset all time series on these fieldIDs and save the resulting df to parquet file,
        # and do the same for the croptypes and CAL/VAL/TEST labels

        parquetdest = parquetroot / year
        os.makedirs(str(parquetdest), exist_ok=True)

        # Croptypes and CAL/VAL/TEST (.copy() so the new column is not set on a slice)
        croptypes = croptypes.loc[finalfieldIDs].copy()
        calvaltestseries = pd.Series(index=finalfieldIDs, dtype=object)
        for label in ('CAL', 'VAL', 'TEST'):
            labelled = list(set(calvaltest[label]).intersection(finalfieldIDs))
            calvaltestseries.loc[labelled] = label
        croptypes['CALVALTEST'] = calvaltestseries

        croptypescalvaltestparquetfile = save_as_parquet('{}.parquet'.format('croptypes_calvaltest'), croptypes,
                                                         dest=parquetdest, overwrite=True)
        croptypescalvaltestparquetfiles.append(croptypescalvaltestparquetfile)

        # NOTE(review): the S2/S1 files below are not overwritten when they already exist, while the
        # croptypes file above always is. If the selected field set changes between runs, the cached
        # files can hold a different set of fields than croptypes_calvaltest.parquet.

        # Sentinel-2 (transposed: one row per field)
        s2subset = s2data[finalfieldIDs].T
        s2parquetfile = save_as_parquet('{}.parquet'.format('S2_' + s2variable), s2subset,
                                        dest=parquetdest, overwrite=False)
        s2parquetfiles.append(s2parquetfile)

        # Sentinel-1 (transposed: one row per field)
        for orbitpass in orbitpasses:
            for variable in s1variables:
                s1dataSubset = s1data[orbitpass][variable][finalfieldIDs].T
                s1parquetfile = save_as_parquet('S1_{}_{}.parquet'.format(orbitpass, variable),
                                                s1dataSubset, dest=parquetdest)
                s1parquetfiles[orbitpass][variable].append(s1parquetfile)

    # ---------------------------------------------------------------------------------------------------------------
    # Finally merge all parquet files from the different years
    log.info('-'*70)
    if len(years) > 1:
        parquetdest = parquetroot / ''.join(years)
        os.makedirs(str(parquetdest), exist_ok=True)

        # HACK: to combine years, all data is relabeled as if it came from the same season (2017-06-01..2018-12-31)
        index = pd.date_range('2017-06-01', '2018-12-31')

        # Croptypes / calvaltest
        log.info('Concatenating croptypes and CAL/VAL/TEST dataframes ...')
        croptypescompletedf = pd.concat([pd.read_parquet(f) for f in croptypescalvaltestparquetfiles], sort=False)
        save_as_parquet('{}.parquet'.format('croptypes_calvaltest'), croptypescompletedf,
                        dest=parquetdest, overwrite=False)

        # Sentinel-2
        log.info('Concatenating S2 dataframes ...')
        s2completedf = _concat_parquet_timeseries(s2parquetfiles, index)
        save_as_parquet('{}.parquet'.format('S2_' + s2variable), s2completedf, dest=parquetdest, overwrite=False)

        # Sentinel-1
        log.info('Concatenating S1 dataframes ...')
        for orbitpass in orbitpasses:
            for variable in s1variables:
                s1completedf = _concat_parquet_timeseries(s1parquetfiles[orbitpass][variable], index)
                save_as_parquet('S1_{}_{}.parquet'.format(orbitpass, variable), s1completedf,
                                dest=parquetdest, overwrite=False)

if __name__ == '__main__':
    # Command-line entry point. Run without arguments, it processes 2019 FAPAR with V005_SHUB cleaning,
    # the same parameters that were previously hard-coded here.
    import argparse

    logging.basicConfig(level=logging.INFO,
                        format='{asctime} {levelname} {name}: {message}',
                        style='{',
                        datefmt='%Y-%m-%d %H:%M:%S')

    log.setLevel(logging.INFO)

    parser = argparse.ArgumentParser(
        description='Generate S1/S2/croptype parquet data stacks for one or more years.')
    parser.add_argument('--years', nargs='+', default=['2019'],
                        help='Years to process; more than one also produces a merged stack (default: 2019).')
    parser.add_argument('--s2variable', default='FAPAR', choices=['FAPAR', 'NDVI'],
                        help='Sentinel-2 variable to read (default: FAPAR).')
    parser.add_argument('--s2cleaningversion', default='V005_SHUB',
                        help='Sentinel-2 cleaning version subdirectory (default: V005_SHUB).')
    args = parser.parse_args()

    main(args.years, args.s2variable, args.s2cleaningversion)
