# coding: utf-8

import glob
import logging
import configparser
import shutil
from s2_clean import smooth

import numpy as np
from pyspark import SparkConf,StorageLevel
from pyspark.sql.types import *
from pyspark.sql.functions import udf

try:
    import _pickle as pickle
except:
    import pickle

from cropsar.readers import *


log = logging.getLogger(__name__)










def parse_params(S1smoothing, maskAngle, outputResolution, takeOutScale, timesteps, minArea, NRTfraction):
    """Normalise the S1 smoothing window and assemble the datastack name pattern.

    The pattern always contains the output resolution and the number of
    timesteps; every other setting only contributes a tagged fragment when it
    is active.

    :param S1smoothing: S1 smoothing window; cast to ``int`` when truthy,
        otherwise passed through unchanged. Included in the pattern unless it
        is ``None``.
    :param maskAngle: adds ``_maskAngle`` to the pattern when truthy.
    :param outputResolution: always included, followed by ``d``.
    :param takeOutScale: included (with dots removed) unless it is ``None``.
    :param timesteps: always included, followed by ``d``.
    :param minArea: included when truthy.
    :param NRTfraction: included (with dots removed) when truthy.
    :return: tuple ``(S1smoothing, outPattern)``.
    """
    # Only truthy values are cast; the accepted window sizes are not validated here.
    if S1smoothing:
        S1smoothing = int(S1smoothing)

    def dotless(value):
        # Decimal points are dropped so the value can sit inside a file name.
        return str(value).replace(".", "")

    fragments = ['_multiCrop_allFields_DataStack_', str(outputResolution), 'd']
    if S1smoothing is not None:
        fragments.append('_S1smoothing' + str(S1smoothing) + 'd')
    if maskAngle:
        fragments.append('_maskAngle')
    if takeOutScale is not None:
        fragments.append('_takeOutScale' + dotless(takeOutScale))
    if minArea:
        fragments.append('_minArea' + str(minArea))
    if NRTfraction:
        fragments.append('_NRTfraction' + dotless(NRTfraction))
    fragments.extend(['_', str(timesteps), 'd', '_'])

    return S1smoothing, ''.join(fragments)


def load_fields(calval_split, indir):
    """Return the CAL, VAL and TEST field lists for a calibration/validation split.

    :param calval_split: where the split comes from:

        * ``None`` -- read ``croptypes_calvaltest.parquet`` from ``indir`` and
          split its index on the ``CALVALTEST`` column;
        * ``str`` -- path of a pickle file holding a dict with ``'CAL'`` and
          ``'VAL'`` keys and an optional ``'TEST'`` key;
        * ``dict`` -- such a dict, passed directly.
    :param indir: directory holding the parquet file; only used when
        ``calval_split`` is ``None``.
    :return: tuple ``(CAL_fields, VAL_fields, TEST_fields)``. When the split
        comes from a dict without a ``'TEST'`` key, ``TEST_fields`` is ``None``.
    :raises TypeError: if ``calval_split`` is not ``None``, a ``str`` or a ``dict``.
    """
    if calval_split is None:
        fields = pd.read_parquet(os.path.join(indir, 'croptypes_calvaltest.parquet'))
        labels = fields['CALVALTEST']
        CAL_fields = fields[labels == 'CAL'].index.tolist()
        VAL_fields = fields[labels == 'VAL'].index.tolist()
        TEST_fields = fields[labels == 'TEST'].index.tolist()
        return CAL_fields, VAL_fields, TEST_fields

    if isinstance(calval_split, str):
        # NOTE: unpickling can execute arbitrary code; only load split files
        # from a trusted location.
        # The context manager closes the file handle once the split is loaded.
        with open(calval_split, 'rb') as split_file:
            CALVAL = pickle.load(split_file)
    elif isinstance(calval_split, dict):
        CALVAL = calval_split
    else:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise TypeError(
            'calval_split should be None, a path to a pickle file or a dict, got {}'.format(
                type(calval_split).__name__))

    return CALVAL['CAL'], CALVAL['VAL'], CALVAL.get('TEST')


def get_logger():
    """Return the ``"cropsar"`` logger, set to INFO and writing to a stream handler.

    The stream handler is attached only the first time this function runs.
    Loggers are shared per name, so attaching a new handler on every call
    would make each record appear once per call made so far.

    A two-line start-up banner is logged on every call.

    :return: the configured ``logging.Logger`` named ``"cropsar"``.
    """
    # Name used to recognise the handler installed by an earlier call.
    handler_name = "cropsar_datastack_stream"

    logger = logging.getLogger("cropsar")
    logger.setLevel(logging.INFO)

    already_attached = any(handler.get_name() == handler_name for handler in logger.handlers)
    if not already_attached:
        log_formatter = logging.Formatter(
            "%(asctime)s [%(levelname)s - THREAD: %(threadName)s - %(name)s] : %(message)s")
        log_stream_handler = logging.StreamHandler()
        log_stream_handler.set_name(handler_name)
        log_stream_handler.setFormatter(log_formatter)
        logger.addHandler(log_stream_handler)

    logger.info('-' * 50 + "\nDatastack generation program ...")
    logger.info('-' * 50)
    return logger

def convert_row(tuple):
    """Convert one ``(field_id, (inputs, output))`` record into a ``Row``.

    :param tuple: record whose first element is the field id and whose second
        element holds the input array at index 0 and the output value at
        index 1. The input array is sliced on two axes -- presumably a 2-D
        numpy array with at least three columns; TODO confirm against
        ``getDataStacks``.
    :return: ``Row`` with ``field_id``, ``outputs`` (a float), ``s1_data``
        (columns 0-1 as nested lists) and ``s2_data`` (column 2 as nested lists).
    """
    # The parameter name shadows the builtin; it is kept for interface
    # compatibility and aliased right away.
    record = tuple
    payload = record[1]
    features = payload[0]
    return Row(
        field_id=record[0],
        outputs=float(payload[1]),
        s1_data=features[:, 0:2].tolist(),
        s2_data=features[:, 2:3].tolist(),
    )

def convert_row_RNNfull(tuple):
    """Convert one ``(field_id, (inputs, outputs, weights))`` record into a ``Row``.

    :param tuple: record whose first element is the field id and whose second
        element holds, by index: 0 the input array, 1 the output array and
        2 the weights. Inputs and outputs are sliced on two axes -- presumably
        2-D numpy arrays (inputs with at least six columns); TODO confirm
        against ``getDataStacks``.
    :return: ``Row`` with ``field_id``, ``outputs`` (column 0 of the output
        array), ``s1_data`` (input columns 0-3), ``s2_data`` (input column 4),
        ``delta_obs`` (input column 5) and ``weights``, all as (nested) lists.
    """
    # The parameter name shadows the builtin; it is kept for interface
    # compatibility and aliased right away.
    record = tuple
    payload = record[1]
    features = payload[0]
    return Row(
        field_id=record[0],
        outputs=payload[1][:, 0:1].tolist(),
        s1_data=features[:, 0:4].tolist(),
        s2_data=features[:, 4:5].tolist(),
        delta_obs=features[:, 5:6].tolist(),
        weights=payload[2].tolist(),
    )

def compute_and_write_tfrecords(field_subset, timeseries_df, outputResolution, output_basename,
                                timesteps, S1smoothing, maskAngle,takeOutScale, NRTfraction,
                                S2layername='FAPAR', S1var='gamma'):
    """Build the datastacks of the selected fields and write them as gzipped tfrecords.

    The rows of ``timeseries_df`` whose ``name`` matches one of the requested
    fields are passed through ``getDataStacks``; every (input, output) sample
    it yields becomes one record of the written dataset.

    :param field_subset: identifiers of the fields to process. They are
        converted to strings before matching. The argument itself is not modified.
    :param timeseries_df: Spark DataFrame with a ``name`` column holding the field id.
    :param outputResolution: forwarded to ``getDataStacks``.
    :param output_basename: destination passed to the tfrecords writer; existing
        output is overwritten.
    :param timesteps: forwarded to ``getDataStacks`` (with ``timestepUnit='days'``).
    :param S1smoothing: forwarded to ``getDataStacks``.
    :param maskAngle: forwarded to ``getDataStacks``.
    :param takeOutScale: forwarded to ``getDataStacks``.
    :param NRTfraction: forwarded to ``getDataStacks``.
    :param S2layername: forwarded to ``getDataStacks``.
    :param S1var: forwarded to ``getDataStacks``.
    :return: None
    """
    # sorted() works on a copy, so the caller's collection keeps its order and
    # any sortable iterable is accepted, not only lists.
    field_names = [str(field) for field in sorted(field_subset)]

    spark = SparkSession.builder.getOrCreate()
    fields_df = spark.createDataFrame(field_names, schema=StringType())
    fields_df = fields_df.orderBy(fields_df.value.desc())

    # Inner join keeps only the timeseries of the requested fields.
    filtered = fields_df.join(timeseries_df, fields_df.value == timeseries_df.name, 'inner').drop(fields_df.value)
    calibration_timeseries = filtered.repartition(max(2, len(field_names) // 200))

    def build_stack(row):
        # Pair the field id with whatever getDataStacks returns for that field.
        stack = getDataStacks(row.name, df_row_to_dict(row.name, row), timesteps,
                              timestepUnit='days', outputResolution=outputResolution,
                              S1smoothing=S1smoothing, maskAngle=maskAngle,
                              accum=None, takeOutScale=takeOutScale, NRTfraction=NRTfraction,
                              S2layername=S2layername, S1var=S1var)
        return row.name, stack

    def split_samples(input_output_tuple):
        # One (input, output) pair per sample, so every sample becomes its own record.
        return list(zip(input_output_tuple[0], input_output_tuple[1]))

    input_output_rdd = calibration_timeseries.rdd \
        .map(build_stack) \
        .filter(lambda t: t[1][0] is not None) \
        .flatMapValues(split_samples) \
        .map(convert_row)
    output_df = input_output_rdd.toDF()

    # Write to output files
    output_df.repartition(200).write.format("tfrecords").mode('overwrite').option("recordType", "SequenceExample") \
        .option("codec", "org.apache.hadoop.io.compress.GzipCodec").save(output_basename)

def compute_and_write_tfrecords_RNNfull(field_subset, timeseries_df, outputResolution, output_basename, scalers,
                                timesteps, S1smoothing, maskAngle,takeOutScale,scalingMethod, stratified,
                                        S2layername='FAPAR'):
    """Build the RNNfull datastacks of the selected fields and write them as tfrecords.

    The rows of ``timeseries_df`` whose ``name`` matches one of the requested
    fields are passed through ``getDataStacks``; every (input, output, weights)
    sample it yields becomes one record of the written dataset.

    :param field_subset: identifiers of the fields to process. They are
        converted to strings before matching. The argument itself is not modified.
    :param timeseries_df: Spark DataFrame with a ``name`` column holding the field id.
    :param outputResolution: forwarded to ``getDataStacks``.
    :param output_basename: destination passed to the tfrecords writer; existing
        output is overwritten.
    :param scalers: forwarded to ``getDataStacks`` as its fourth positional argument.
    :param timesteps: forwarded to ``getDataStacks``.
    :param S1smoothing: forwarded to ``getDataStacks``.
    :param maskAngle: forwarded to ``getDataStacks``.
    :param takeOutScale: forwarded to ``getDataStacks``.
    :param scalingMethod: forwarded to ``getDataStacks``.
    :param stratified: stratified sampling is not implemented; a truthy value
        terminates the program through ``sys.exit``.
    :param S2layername: forwarded to ``getDataStacks``.
    :return: None
    """
    if stratified:
        # Reject the unsupported option before any Spark work is set up
        # (toDF() without a schema has to look at the data to infer one).
        import sys
        sys.exit('Stratified sampling not implemented for RNNfull datastacks!!')

    # sorted() works on a copy, so the caller's collection keeps its order and
    # any sortable iterable is accepted, not only lists.
    field_names = [str(field) for field in sorted(field_subset)]

    spark = SparkSession.builder.getOrCreate()
    fields_df = spark.createDataFrame(field_names, schema=StringType())
    fields_df = fields_df.orderBy(fields_df.value.desc())

    # Inner join keeps only the timeseries of the requested fields.
    filtered = fields_df.join(timeseries_df, fields_df.value == timeseries_df.name, 'inner').drop(fields_df.value)
    calibration_timeseries = filtered.repartition(max(2, len(field_names) // 200))

    def build_stack(row):
        # Pair the field id with whatever getDataStacks returns for that field.
        stack = getDataStacks(row.name, df_row_to_dict(row.name, row), timesteps, scalers,
                              scalingMethod=scalingMethod, S1smoothing=S1smoothing, maskAngle=maskAngle,
                              takeOutScale=takeOutScale, outputResolution=outputResolution,
                              S2layername=S2layername)
        return row.name, stack

    def split_samples(input_output_tuple):
        # One (input, output, weights) triple per sample, so every sample becomes its own record.
        return list(zip(input_output_tuple[0], input_output_tuple[1], input_output_tuple[2]))

    input_output_rdd = calibration_timeseries.rdd \
        .map(build_stack) \
        .filter(lambda t: t[1][0] is not None) \
        .flatMapValues(split_samples) \
        .map(convert_row_RNNfull)
    output_df = input_output_rdd.toDF()

    # Write to normal unstratified output files
    output_df.repartition(200).write.format("tfrecords").mode('overwrite').option("recordType", "SequenceExample").save(output_basename)

def compute_and_write(field_subset, timeseries_df, outputResolution, output_basename,
                      timesteps, S1smoothing, maskAngle,takeOutScale, S2layername='FAPAR', S1var='gamma'):
    """
    Deprecated: tfrecords will be used

    Compute the input stacks and the outputs of the selected fields on Spark,
    gather them on the driver into two numpy arrays and save both arrays.

    :param field_subset: field names to keep; only rows of ``timeseries_df``
        whose ``name`` is in this collection are processed
    :param timeseries_df: Spark DataFrame with a ``name`` column
    :param outputResolution: forwarded to ``getDataStacks``; together with
        ``timesteps`` it also sizes the second axis of the input array
    :param output_basename: path prefix of the saved arrays; ``'Inputs'`` and
        ``'Outputs'`` are appended to it
    :param timesteps: forwarded to ``getDataStacks`` (with ``timestepUnit='days'``)
    :param S1smoothing: forwarded to ``getDataStacks``
    :param maskAngle: forwarded to ``getDataStacks``
    :param takeOutScale: forwarded to ``getDataStacks``
    :param S2layername: forwarded to ``getDataStacks``
    :param S1var: forwarded to ``getDataStacks``
    :return: tuple ``(inputs, outputs)`` holding the two arrays that were saved
    """

    # Keep only the requested fields.
    calibration_timeseries = timeseries_df.filter(timeseries_df.name.isin(field_subset)).repartition(max(2,int(len(field_subset)/200)))

    # Generate outputs. The RDD is cached because it is traversed twice:
    # once to count the samples and once to collect them on the driver.
    cal_outputs_rdd = generate_outputs(calibration_timeseries).cache()
    total_output_count = cal_outputs_rdd.map(lambda t: len(t[1])).sum()
    cal_outputs_dict = cal_outputs_rdd.collectAsMap()
    cal_outputs_rdd.unpersist()

    # Generate inputs: only element 0 of the getDataStacks result is kept,
    # and fields for which it is None are dropped.
    cal_inputs_rdd = calibration_timeseries.rdd \
        .map(lambda row: (row.name, getDataStacks(row.name, df_row_to_dict(row.name, row), timesteps,
                                                  timestepUnit='days', outputResolution=outputResolution,
                                                  S1smoothing=S1smoothing, maskAngle=maskAngle,
                                                  accum=None, takeOutScale=takeOutScale,
                                                  S2layername=S2layername,
                                                  S1var=S1var)[0])).filter(lambda t:t[1] is not None)

    # Retrieve input partitions one by one, and save.
    # The driver-side arrays are sized from the output sample count.
    # NOTE(review): np.empty leaves memory uninitialised; if a field counted in
    # total_output_count is removed by the None filter above, the trailing rows
    # are never written -- confirm against generate_outputs/getDataStacks.
    window_size = int(timesteps * 2 / outputResolution) + 1
    inputs = np.empty((total_output_count, window_size, 3))
    outputs = np.empty((total_output_count,))
    cal_inputs_rdd = cal_inputs_rdd.persist(StorageLevel.MEMORY_AND_DISK)
    # Force computing all partitions in parallel: count() materialises the
    # persisted RDD before it is read back through the local iterator below.
    log.info(cal_inputs_rdd.count())

    iterator = cal_inputs_rdd.toLocalIterator()
    # Index of the next free row in the driver-side arrays.
    obs = 0
    for currentInputData in iterator:
        fieldID = currentInputData[0]
        # Copy this field's samples into the next free block of rows.
        inputs[obs:obs + currentInputData[1].shape[0], :, :] = currentInputData[1]
        currentOutputData = cal_outputs_dict[fieldID]
        outputs[obs:obs + currentInputData[1].shape[0], ] = currentOutputData
        obs += currentInputData[1].shape[0]

    cal_inputs_rdd.unpersist()
    np.save(output_basename + 'Inputs', inputs)
    np.save(output_basename + 'Outputs', outputs)
    return (inputs,outputs)

def write_to_temp(tempDir, fieldID, input_output_tuple):
    """Save the inputs and outputs of one field as numpy files (used in test functions).

    :param tempDir: directory the two files are written to.
    :param fieldID: field identifier; it is concatenated with ``'_inputs'`` /
        ``'_outputs'`` to form the file names.
    :param input_output_tuple: sequence whose element 0 is saved as the inputs
        and element 1 as the outputs.
    :return: None
    """
    log.info('Datastack built, saving to numpy arrays ...')
    for position, suffix in enumerate(('_inputs', '_outputs')):
        target = os.path.join(tempDir, fieldID + suffix)
        np.save(target, input_output_tuple[position])

if __name__ == '__main__':

    # Hard-coded settings of the run started when this file is executed as a script.
    run_settings = dict(
        indir='/data/CropSAR/tmp/dataStackGenerationSPARK/data/parquetFiles/S1GEE_S2CLEANED_V005/201620172018/',
        outDir='/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks/S1GEE_S2CLEANED_V005/RNN/201620172018/',
        year='201620172018',
        timesteps=150,
        outputResolution=1,
        S1smoothing=13,
        maskAngle=False,
        takeOutScale=None,
        minArea=10000,
        NRTfraction=None,
        stackType='RNN',
        S2layername='FAPAR',
        S1var='sigma',
    )
    main_parquet(**run_settings)