import sys, os, glob
import time
import shutil
import traceback
import argparse
import numpy as np
import configparser
import logging
import pandas as pd
from itertools import product
from cropsar.readers import get_spark_context
import tensorflow as tf
from cropsar.validation.evaluate_model import createScatterplot
from cropsar.preprocessing.utils import minmaxunscaler
from cropsar.preprocessing.prepare_datastack_rnn import getFullDataStacks, ts_to_dict

# Module-level logger named after this module; used by the functions below.
log = logging.getLogger(__name__)

class globals:
    """Container for the default settings shared by the training entry points."""

    # Directory holding the pre-generated data stacks
    inputDir = r'/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks/'

    # Number of calibration / validation training points
    # (not used in the TFrecord version)
    nrObs_CAL = 150000
    nrObs_VAL = 50000

    # Batches per epoch for training / validation (used in the TFrecord version)
    steps_per_epoch = 1500
    validation_steps = 500

    # Only used in config.ini
    lossFunction = 'tilted loss'

    # Default value, will be replaced by training data specific value if available
    timesteps = 150

    # Default number of epochs to run
    epochs = 150

################ FUNCTIONS ############################

def get_logger():
    """
    Return the shared "cropsar" logger, configured to log at INFO level to a stream handler.

    The stream handler is attached only when the logger has no handler yet, so calling
    this function repeatedly does not duplicate every log line.

    :return: the configured ``logging.Logger`` instance
    """
    logger = logging.getLogger("cropsar")
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object on every call; adding a handler
    # unconditionally would make each message appear once per call of get_logger().
    if not logger.handlers:
        log_formatter = logging.Formatter("%(asctime)s [%(levelname)s - THREAD: %(threadName)s - %(name)s] : %(message)s")
        log_stream_handler = logging.StreamHandler()
        log_stream_handler.setFormatter(log_formatter)
        logger.addHandler(log_stream_handler)

    logger.info('-' * 50 + "\nCropSAR RNN training program ...")
    logger.info('-' * 50)
    return logger

# Save settings to ini file
def save_settings(cal_paths, val_paths, datastack_settings_data, datastack_settings_file, outDir, experimentName,
                  experiment, modelName, steps_per_epoch, validation_steps,
                  maxEpochs, lossFunction, outfile, batch_size=128, rnnLayer='GRU', nrObsCal=None, nrObsVal=None):
    """
    Write the model, training and datastack settings of one run to an ini file.

    :param cal_paths: calibration data path; also used to derive the test path by replacing "CAL" with "TEST"
    :param val_paths: validation data path
    :param datastack_settings_data: mapping with 'Parameters', 'NrFields' and 'Data' sections of the datastack
    :param datastack_settings_file: path of the ini file the datastack settings came from
    :param outDir: directory where the model files are stored
    :param experimentName: name of the experiment
    :param experiment: mapping providing 'nodes' and 'dropoutFraction'
    :param modelName: model name, used to build the model and weights file names
    :param steps_per_epoch: number of training steps per epoch
    :param validation_steps: number of validation steps per epoch
    :param maxEpochs: maximum number of epochs
    :param lossFunction: loss function name; written as-is, so it has to be a string
    :param outfile: path of the ini file to write (an existing file is removed first)
    :param batch_size: batch size used for training
    :param rnnLayer: name of the recurrent layer type
    :param nrObsCal: number of calibration observations
    :param nrObsVal: number of validation observations
    :return: None
    """
    config = configparser.ConfigParser()

    # Create all sections up front so they end up in the file in this order
    for section_name in ('ModelData', 'ModelParameters', 'TrainingSettings', 'DataStackSettings', 'TrainingData'):
        config.add_section(section_name)

    def _store(section_name, options):
        # Write the (option, value) pairs into the section, one by one and in order
        section = config[section_name]
        for option, value in options:
            section[option] = value

    _store('ModelData', [
        ('modelDir', str(outDir)),
        ('experiment', str(experimentName)),
        ('modelName', str(modelName)),
        ('modelFile', str(os.path.join(outDir, modelName + '.h5'))),
        ('weightsFile', str(os.path.join(outDir, modelName + '_weights.h5'))),
    ])

    _store('ModelParameters', [
        ('nodes', str(experiment['nodes'])),
        ('dropoutFraction', str(experiment['dropoutFraction'])),
        ('bidirectional', 'True'),
        ('rnn_layer', str(rnnLayer)),
    ])

    _store('TrainingSettings', [
        ('nrObsCal', str(nrObsCal)),
        ('nrObsVal', str(nrObsVal)),
        ('maxEpochs', str(maxEpochs)),
        ('lossFunction', lossFunction),  # deliberately not converted to str
        ('steps_per_epoch', str(steps_per_epoch)),
        ('validation_steps', str(validation_steps)),
        ('batch_size', str(batch_size)),
    ])

    _store('DataStackSettings', [
        ('s1smoothing', str(datastack_settings_data['Parameters']['s1smoothing'])),
        ('takeoutRate', str(datastack_settings_data['Parameters']['takeoutRate'])),
        ('timesteps', str(datastack_settings_data['Parameters']['timesteps'])),
        ('outputresolution', str(datastack_settings_data['Parameters']['outputresolution'])),
    ])

    _store('TrainingData', [
        ('calpath', str(cal_paths)),
        ('valpath', str(val_paths)),
        ('testpath', str(cal_paths).replace("CAL", "TEST")),
        ('CALfields', str(datastack_settings_data['NrFields']['cal'])),
        ('VALfields', str(datastack_settings_data['NrFields']['val'])),
        ('year', str(datastack_settings_data['Data']['year'])),
        ('inifile', str(datastack_settings_file)),
    ])

    # Replace any previous settings file
    if os.path.exists(outfile):
        os.remove(outfile)
    with open(outfile, 'w') as handle:
        config.write(handle)
    print('config.ini file saved to: {}'.format(outfile))

def plot_test(inputs, outputs, predictions, outname):
    """
    Plot predictions against observations for three randomly chosen samples and save the figure.

    The top row shows the q10-q90 band, the q50 line, the independent outputs and the
    assimilated S2 input; the bottom row shows the two channels of the first input.

    :param inputs: sequence whose first element is indexed as [sample, time, channel] and whose
                   second element holds the assimilated S2 input (zeros are not plotted)
    :param outputs: sequence of which element 1 is plotted as the independent observations
                    (zeros are not plotted)
    :param predictions: sequence with the q10, q50 and q90 predictions at index 0, 1 and 2
    :param outname: file name of the figure to save
    :return: None
    """
    import matplotlib.pyplot as plt

    # Zeros are replaced by NaN so that they are left out of the plot
    outputs = np.copy(outputs[1])
    outputs[outputs == 0] = np.nan
    independent = outputs.squeeze()

    s2assimilated = np.copy(inputs[1].squeeze())
    s2assimilated[s2assimilated == 0] = np.nan

    q10 = predictions[0].squeeze()
    q50 = predictions[1].squeeze()
    q90 = predictions[2].squeeze()

    s1inputs = inputs[0].squeeze()

    fieldidx = list(np.random.randint(inputs[0].shape[0], size=3))

    f, ax = plt.subplots(2, 3, figsize=(19, 9))
    plt.tight_layout()

    for i, field in enumerate(fieldidx):
        ax[0, i].fill_between(np.arange(q50.shape[1]), q10[field, :], q90[field, :], alpha=0.2, color='blue', edgecolor=None)
        ax[0, i].plot(q50[field, :], linewidth=1, color='blue')
        ax[0, i].grid(linestyle='--')

        ax[0, i].plot(independent[field, :], 'ro', label='independent')
        ax[0, i].plot(s2assimilated[field, :], 'ko', label='assimilated')
        ax[0, i].legend(loc='upper left')
        ax[0, i].set_ylim([-1, 1])

        ax[1, i].plot(s1inputs[field, :, 0], label='VV')
        ax[1, i].plot(s1inputs[field, :, 1], label='VH')
        ax[1, i].grid(linestyle='--')
        ax[1, i].legend(loc='lower left')
        ax[1, i].set_ylim([-0.8, 0.8])

    # Save and close once, after all columns are drawn (this used to happen inside
    # the loop, which rewrote the file per column and closed the figure too early).
    f.savefig(outname)
    plt.close(f)

def plot_metrics(cal_loss, val_loss, outname):
    """
    Plot the calibration and validation loss curves and save the figure.

    :param cal_loss: sequence of calibration loss values
    :param val_loss: sequence of validation loss values
    :param outname: file name of the figure to save
    :return: None
    """
    import matplotlib.pyplot as plt

    print('\nCreating training progress graph ...')
    plt.figure(figsize=(12, 12))
    plt.plot(np.array(cal_loss), linewidth=2, label='cal loss')
    plt.plot(np.array(val_loss), linewidth=2, label='val loss')
    plt.grid(linestyle='--')
    # The curves carry labels; without a legend they cannot be told apart
    plt.legend()

    plt.savefig(outname)
    plt.close()

def plot_validation_ts(valfile, model, outDir, epoch, S1smoothing, outputResolution, timesteps, S2layername, margin_in_days):
    '''
    Plot the time series predicted by the model for one validation field and save it as png.

    The figure is written to <outDir>/testfields/<field id>_epoch_<epoch>.png, where the field id
    is the part of the csv file name before the first underscore.

    :param valfile: path to csv file containing the necessary inputs (columns s1_clean_vv,
                    s1_clean_vh, s1_clean_angle and fapar_clean, indexed by date)
    :param model: model used to do the predictions; its predict() output is indexed 0/1/2 for q10/q50/q90
    :param outDir: base output directory where to save the figure
    :param epoch: number of the epoch that is just finished (used in the file name)
    :param S1smoothing: passed on to getFullDataStacks
    :param outputResolution: passed on to getFullDataStacks
    :param timesteps: passed on to getFullDataStacks
    :param S2layername: which S2 layer is being used; used for unscaling and in the plot labels
    :param margin_in_days: passed on to ts_to_dict
    :return: None
    '''
    import matplotlib.pyplot as plt

    # Load the field time series; the field id is taken from the file name
    df = pd.read_csv(valfile, index_col=0, parse_dates=True)
    field_id = str(os.path.basename(valfile).split('_')[0])
    startdate = pd.to_datetime(df.index[0])
    enddate = pd.to_datetime(df.index[-1])

    # NOTE(review): the 'fapar_clean' column is read regardless of S2layername - confirm this is intended
    inputData = ts_to_dict(df['s1_clean_vv'], df['s1_clean_vh'], df['s1_clean_angle'],
                           df['fapar_clean'], field_id, startdate, enddate,
                           S2layername=S2layername, margin_in_days=margin_in_days)

    # Build the full datastack for this field
    inputDataStack, assimilatedFapar, index = getFullDataStacks(field_id, inputData, startdate, enddate, timesteps,
                                                                outputResolution=outputResolution,
                                                                S1smoothing=S1smoothing, S2layername=S2layername)

    # Channels 0-1 feed the first model input, channel 2 the second
    predictions = model.predict([inputDataStack[:, :, 0:2], inputDataStack[:, :, 2:3]])

    def _unscaled_series(raw):
        # Unscale one model output and put it on the datastack's time index
        return pd.Series(index=index, data=minmaxunscaler(raw.ravel(), 's2_' + S2layername.lower()))

    q10 = _unscaled_series(predictions[0])
    q50 = _unscaled_series(predictions[1])
    q90 = _unscaled_series(predictions[2])

    # Median prediction with the q10-q90 band around it
    band_color = '#4298f5'
    plt.figure(figsize=(12, 6))
    plt.plot(q50, linewidth=3, color=band_color, label='CropSAR RNN')
    plt.fill_between(x=q50.index, y1=q90, y2=q10,
                     color=band_color, linewidth=0, alpha=0.4)

    # The S2 values that were assimilated, unscaled as well
    plt.plot(minmaxunscaler(assimilatedFapar, 's2_' + S2layername.lower()),
             'o', color='#9de33b', label='Assimilated {}'.format(S2layername))

    # Styling
    plt.ylim(0, 1)
    plt.ylabel(S2layername)
    plt.grid()
    plt.legend(loc='upper left')
    plt.title(field_id, fontsize=14, pad=14)
    plt.tight_layout()

    figure_dir = os.path.join(outDir, 'testfields')
    os.makedirs(figure_dir, exist_ok=True)
    plt.savefig(os.path.join(figure_dir, '{}_epoch_{}.png'.format(field_id, epoch)))
    plt.close()

# Tilted loss function that is used for quantile regression
def tilted_loss(q, y, f):
    """
    Tilted loss for quantile q: mean over the last axis of max(q * e, (q - 1) * e) with e = y - f.

    :param q: quantile weight (the callers in this file pass 0.1, 0.5 and 0.9)
    :param y: target values
    :param f: predicted values
    :return: loss reduced over the last axis
    """
    # An earlier variant ignoring errors below a 0.1 threshold was tried here and disabled.
    backend = tf.keras.backend
    residual = y - f
    weighted_above = q * residual
    weighted_below = (q - 1) * residual
    return backend.mean(backend.maximum(weighted_above, weighted_below), axis=-1)

# General function to initiate a Keras model with a set of chosen parameters
def create_model(nodes, dropoutFraction):
    """
    Build and compile the two-branch recurrent network with three quantile outputs.

    A bidirectional LSTM branch processes the S1 input (2 features per time step) and another
    one the S2 input (1 feature per time step). Both branches are concatenated, passed through
    two more bidirectional LSTM layers (with an average pooling over 21 steps in between) and
    two dense layers, and end in three linear outputs named 'q10', 'q50' and 'q90', each
    trained with the tilted loss of the corresponding quantile.

    NOTE(review): the layers are LSTMs, while other parts of this file label the model 'GRU' -
    confirm which one is intended.

    :param nodes: number of units of every recurrent and dense hidden layer
    :param dropoutFraction: dropout rate applied after the input branches and after the merged RNN
    :return: the compiled tf.keras model
    """
    log.info('Using Tensorflow implementation of Keras ...')
    from tensorflow.keras.layers import Input, Dropout, Dense, Bidirectional, concatenate, LSTM
    from tensorflow.keras.models import Model
    from tensorflow.keras.layers import BatchNormalization, AveragePooling1D
    from tensorflow.keras.optimizers import Adam

    # One branch for the Sentinel-1 input
    S1_input = Input(shape=(None, 2), name='S1_input')

    # One branch for Sentinel-2 input
    S2_input = Input(shape=(None, 1), name='S2_input')

    # RNNs for both radar and optical inputs
    RNN_S1_1 = BatchNormalization()(Bidirectional(LSTM(nodes, return_sequences=True))(S1_input))
    RNN_S2_1 = BatchNormalization()(Bidirectional(LSTM(nodes, return_sequences=True))(S2_input))

    dropout_S1 = Dropout(dropoutFraction)(RNN_S1_1)
    dropout_S2 = Dropout(dropoutFraction)(RNN_S2_1)

    # Merged S1 and S2 inputs
    merged = concatenate([dropout_S1, dropout_S2])
    RNN_merged_1 = BatchNormalization()(Bidirectional(LSTM(nodes, return_sequences=True))(merged))
    RNN_merged_1_smooth = AveragePooling1D(pool_size=21, strides=1, padding='same')(RNN_merged_1)
    # No return_sequences here: the sequence is reduced to a single vector per sample
    RNN_merged_2 = BatchNormalization()(Bidirectional(LSTM(nodes))(RNN_merged_1_smooth))

    dropout_merged = Dropout(dropoutFraction)(RNN_merged_2)
    denseMerged1 = Dense(nodes, activation='relu')(dropout_merged)
    denseMerged2 = Dense(nodes, activation='relu')(denseMerged1)

    # One final layer towards one output per quantile
    final1 = Dense(1, activation='linear', name='q10')(denseMerged2)
    final2 = Dense(1, activation='linear', name='q50')(denseMerged2)
    final3 = Dense(1, activation='linear', name='q90')(denseMerged2)

    # Setup a custom loss function
    customLoss = {'q10': lambda y,f: tilted_loss(0.1, y, f),
                  'q50': lambda y,f: tilted_loss(0.5, y, f),
                  'q90': lambda y,f: tilted_loss(0.9, y, f)}
    customLossWeights = {'q10': 1,
                         'q50': 1,
                         'q90': 1}

    # Compile the network
    model = Model(inputs=[S1_input, S2_input], outputs=[final1, final2, final3])
    # 'learning_rate' replaces the deprecated 'lr' alias; 'decay=0.0' and 'epsilon=None' only
    # restated the defaults and are rejected by newer optimizer implementations.
    adam = Adam(learning_rate=0.0005, beta_1=0.9, beta_2=0.999, amsgrad=False)
    model.compile(optimizer=adam, loss=customLoss, loss_weights=customLossWeights, metrics=['mae'])

    return model

# Function to subset data on desired number of obs
def subsetTrainingData(inputs, outputs, nrObs):
    """
    Randomly keep nrObs samples (drawn without replacement) of the inputs and matching outputs.

    :param inputs: 3-dimensional array with the samples along the first axis
    :param outputs: array with one entry per sample
    :param nrObs: number of samples to keep
    :return: (inputs, outputs) restricted to the drawn samples, or the unchanged arguments
             when fewer than nrObs samples are available
    """
    available = inputs.shape[0]

    if available >= nrObs:
        picked = np.random.choice(np.arange(available), size=nrObs, replace=False)
        return inputs[picked, :, :], outputs[picked]

    # Not enough samples to subset: hand everything back untouched
    return inputs, outputs

def parse_record(example_proto):
    """
    Parse one serialized sequence example into the (inputs, targets) structure used for training.

    :param example_proto: serialized sequence example with sequence features 's1_data' (2 floats
                          per step) and 's2_data' (1 float per step), and context features
                          'field_id' (string) and 'outputs' (float)
    :return: ((s1_data, s2_data), (outputs, outputs, outputs)); the target is repeated three times
    """
    sequence_spec = {
        's1_data': tf.io.FixedLenSequenceFeature([2], dtype=tf.float32, allow_missing=False),
        's2_data': tf.io.FixedLenSequenceFeature([1], dtype=tf.float32, allow_missing=False)
    }
    context_spec = {
        'field_id': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
        'outputs': tf.io.FixedLenFeature(shape=[], dtype=tf.float32)
    }

    parsed = tf.io.parse_single_sequence_example(example_proto,
                                                 sequence_features=sequence_spec,
                                                 context_features=context_spec)
    context = parsed[0]
    sequences = parsed[1]

    target = context['outputs']
    model_inputs = (sequences['s1_data'], sequences['s2_data'])
    return model_inputs, (target, target, target)

def create_dataset(paths_string:str, batchsize):
    """
    Create an endlessly repeating, shuffled and batched dataset from TFrecord files.

    :param paths_string: glob pattern of the TFrecord files; a pattern ending in ".gz" is read as GZIP
    :param batchsize: number of parsed records per batch
    :return: dataset yielding batches of the structure produced by parse_record
    """
    nrFiles = len(glob.glob(paths_string))
    compression_type = "GZIP" if paths_string.endswith(".gz") else ""

    def _open_records(path):
        # One record reader per file
        return tf.data.TFRecordDataset(path, compression_type=compression_type)

    shuffled_files = tf.data.Dataset.list_files(paths_string).shuffle(nrFiles)

    # increase cycle length to read multiple files in parallel
    records = shuffled_files.interleave(
            _open_records,
            cycle_length=4, block_length=1, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    examples = records.map(parse_record, num_parallel_calls=nrFiles)
    endless = examples.repeat().shuffle(5000)
    return endless.batch(batchsize).prefetch(1)

def tf_datastack_to_array(tf_path, nrBatches = 350, batchsize = 128):
    '''
    Function to transform input data from TFrecord files into numpy arrays that can also be fed into network

    The batches are collected in lists and concatenated once at the end; growing the arrays
    with np.concatenate inside the loop copied all previous data for every new batch.

    :param tf_path: path where the TFrecord file(s) are located
    :param nrBatches: the number of batches to generate into arrays
    :param batchsize: number of samples per batch (default 128)
    :return: numpy arrays containing inputs and outputs, as ((s1inputs, s2inputs), outputs)
    '''

    # Create the TFrecord dataset generator
    testDataset = create_dataset(tf_path, batchsize=batchsize)

    # Collect the requested number of batches
    print('Getting test input and output data ...')
    s1_batches, s2_batches, output_batches = [], [], []
    for data in testDataset.take(nrBatches):
        s1_batches.append(data[0][0].numpy())
        s2_batches.append(data[0][1].numpy())
        # The three targets are identical, the first one is used
        output_batches.append(data[1][0].numpy())

    if not s1_batches:
        # No batch requested: return empty arrays that still carry the sequence/feature
        # dimensions of the data, which requires reading one batch to learn the shapes.
        for data in testDataset.take(1):
            s1_shape = data[0][0].numpy().shape
            s2_shape = data[0][1].numpy().shape
        return (np.empty((0, s1_shape[1], s1_shape[2])),
                np.empty((0, s2_shape[1], s2_shape[2]))), np.empty((0, ))

    # Single concatenation; float64 keeps the dtype the arrays had before this change
    s1inputs = np.concatenate(s1_batches, axis=0).astype(np.float64, copy=False)
    s2inputs = np.concatenate(s2_batches, axis=0).astype(np.float64, copy=False)
    outputs = np.concatenate(output_batches, axis=0).astype(np.float64, copy=False)

    return (s1inputs, s2inputs), outputs

# Function that will run on the executors and performs all necessary tasks
def runModel_tfrecords(experiment, outDir, logDir,
                       epochs=globals.epochs, steps_per_epoch=globals.steps_per_epoch,
                       validation_steps=globals.validation_steps, lossFunction=globals.lossFunction,
                       batch_size=128,
                       S2layername='FAPAR',
                       datastack_settings="/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks_tfrecords_Kristof/GEE_Cleaned_V002/2017/2017_multiCrop_allFields_DataStack_1d_150d_.ini",
                       cal_paths="/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks_tfrecords_Kristof/GEE_Cleaned_V002/2017/2017_multiCrop_allFields_DataStack_1d_150d_CAL_/part*",
                       val_paths="/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks_tfrecords_Kristof/GEE_Cleaned_V002/2017/2017_multiCrop_allFields_DataStack_1d_150d_VAL_/part*",
                       test_paths="/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks_tfrecords_Kristof/GEE_Cleaned_V002/2017/2017_multiCrop_allFields_DataStack_1d_150d_TEST_/part*"):
    """
    Run the model using the tensorflow.data api to read the input data.

    :param experiment: mapping with the model settings ('nodes', 'dropoutFraction')
    :param outDir: directory where the model, weights and settings file are written
    :param logDir: directory for the TensorBoard logs and the architecture log
    :param epochs: maximum number of epochs
    :param steps_per_epoch: number of training batches per epoch
    :param validation_steps: number of validation batches per epoch
    :param lossFunction: loss function name, only written to the settings file
    :param batch_size: batch size of the calibration and validation datasets
    :param S2layername: name of the S2 layer, passed on to the callbacks
    :param datastack_settings: path to the ini file describing the datastack; its 'Parameters'
                               section provides timesteps, outputresolution and S1smoothing
    :param cal_paths: glob pattern of the calibration TFrecord files
    :param val_paths: glob pattern of the validation TFrecord files
    :param test_paths: glob pattern of the test TFrecord files, passed on to the callbacks
    :return: None
    """

    # Read parameters of the input datastack (the file is closed again right after reading)
    datastack_settings_data = configparser.ConfigParser()
    with open(str(datastack_settings), 'r') as settings_file:
        datastack_settings_data.read_file(settings_file)
    parameters = datastack_settings_data['Parameters']
    timesteps = parameters['timesteps']
    outputResolution = parameters['outputresolution']

    # configparser returns strings, so "no smoothing" arrives as the text 'None' (or an
    # empty value) rather than as the None object; int() is only applied to real numbers.
    raw_smoothing = parameters['S1smoothing']
    if raw_smoothing is None or raw_smoothing.strip().lower() in ('', 'none'):
        S1smoothing = None
    else:
        S1smoothing = int(raw_smoothing)

    model, modelName = experiment_to_model(experiment, timesteps, outputResolution, outDir, logDir)
    experimentName = os.path.basename(outDir)

    CAL_dataset = create_dataset(cal_paths, batch_size)
    VAL_dataset = create_dataset(val_paths, batch_size)

    # Define callbacks
    os.umask(0o000)
    callbacks = setup_callbacks(logDir, modelName, outDir, S1smoothing, outputResolution, timesteps, S2layername, test_paths)

    # Save all settings to ini file
    iniOutfile = os.path.join(outDir, os.path.splitext(modelName)[0] + '.ini')
    save_settings(cal_paths, val_paths, datastack_settings_data, datastack_settings, outDir, experimentName, experiment,
                  modelName, steps_per_epoch, validation_steps, epochs,
                  lossFunction, iniOutfile, batch_size=batch_size, rnnLayer='GRU')

    # Train the model without sample weights
    print('Start model training ...')
    model.fit(CAL_dataset,
              steps_per_epoch = steps_per_epoch,
              validation_data=VAL_dataset,
              validation_steps = validation_steps,
              epochs=epochs,
              verbose=2,  # One line per epoch
              callbacks=callbacks)

# Function that will run on the executors and performs all necessary tasks
def runModel(experiment, outDir, logDir, CALdata, VALdata, outputResolution,
             epochs=250, timesteps=globals.timesteps):
    """
    Train one model on in-memory calibration and validation arrays.

    :param experiment: mapping with the model settings ('nodes', 'dropoutFraction')
    :param outDir: directory where the model and weights are written
    :param logDir: directory for the TensorBoard logs and the architecture log
    :param CALdata: (inputs, outputs) for calibration; the outputs are used for all three model outputs
    :param VALdata: (inputs, outputs) for validation, same layout as CALdata
    :param outputResolution: output resolution, used in the model name and passed to the callbacks
    :param epochs: maximum number of epochs
    :param timesteps: number of time steps, used in the model name and passed to the callbacks
    :return: None
    """
    model, modelName = experiment_to_model(experiment, timesteps, outputResolution,outDir, logDir)

    # Define callbacks
    os.umask(0o000)
    S1smoothing = 13
    # The timesteps argument is passed on as given; it used to be overwritten with 150 here,
    # which could disagree with the value the model name was built from.
    callbacks = setup_callbacks(logDir, modelName, outDir, S1smoothing, outputResolution,
                                timesteps, S2layername='FAPAR')

    # Train the model
    log.info('Start model training ...')
    model.fit(x=CALdata[0], y=[CALdata[1], CALdata[1], CALdata[1]],
              batch_size=128,
              validation_data=(VALdata[0], [VALdata[1], VALdata[1], VALdata[1]]),
              epochs=epochs,
              verbose=2,  # One line per epoch
              callbacks=callbacks)

def experiment_to_model(experiment, timesteps, outputResolution, outDir, logDir):
    """
    Create the model for one experiment and derive its name from the settings.

    Existing model and weights files of the same name in outDir are removed, the model
    summary is logged and also written to <logDir>/<modelName>_Architecture.log.

    :param experiment: mapping with the model settings; 'nodes' and 'dropoutFraction' are required
    :param timesteps: number of time steps, only used in the model name
    :param outputResolution: output resolution, only used in the model name
    :param outDir: directory of the model files
    :param logDir: directory where the architecture log is written
    :return: tuple (model, modelName)
    """
    log.info('-' * 50)
    log.info('Model settings of current experiment:')
    log.info('-' * 50)
    for key in experiment.keys():
        log.info('{}: {}'.format(key, experiment[key]))
    log.info('Experiment directory: {}'.format(outDir))
    # Get all settings for this run
    nodes = experiment['nodes']
    dropoutFraction = experiment['dropoutFraction']
    # Construct model name
    modelType = 'BiDirGRU'
    modelName = '_'.join(
        [modelType, str(timesteps) + 'days', str(outputResolution) + 'd', str(nodes) + 'nodes',
         str(dropoutFraction).replace('.', '') + 'dropout'])

    log.info('Model name: {}'.format(modelName))
    # Delete any existing models of this name in output location. The model file lives in
    # outDir with an '.h5' extension; checking the bare model name never matched it.
    for stale_file in (os.path.join(outDir, modelName + '.h5'),
                       os.path.join(outDir, modelName + '_weights.h5')):
        if os.path.exists(stale_file):
            os.remove(stale_file)

    # Create a model with the specified settings
    model = create_model(nodes, dropoutFraction)

    # Capture the summary text: model.summary() itself returns None, so logging its
    # return value only produced the line "None".
    summary_lines = []
    model.summary(print_fn=summary_lines.append)
    log.info('\n'.join(summary_lines))

    # Also save the network architecture to a text file
    with open(os.path.join(logDir, modelName + '_Architecture.log'), 'w') as f:
        f.write('--- ' + modelName + ' ---' + '\n\n\n')
        f.writelines(line + '\n' for line in summary_lines)
    return model, modelName


def setup_callbacks(logDir, modelName, outDir, S1smoothing, outputResolution, timesteps, S2layername, testpath=None):
    """
    Assemble the list of Keras callbacks used during training.

    Always included: best-model checkpoint, per-epoch weights checkpoint, TensorBoard logging
    (an existing log directory of this model is removed first), learning-rate reduction on
    plateau and early stopping. When testpath is given, an evaluation callback is added that
    creates a scatter plot and validation time series plots after every epoch.

    :param logDir: base directory of the TensorBoard logs
    :param modelName: model name, used for file and directory names
    :param outDir: directory where checkpoints and plots are written
    :param S1smoothing: passed on to plot_validation_ts by the evaluation callback
    :param outputResolution: converted to int and passed on to plot_validation_ts
    :param timesteps: converted to int and passed on to plot_validation_ts
    :param S2layername: name of the S2 layer, used for unscaling in the evaluation callback
    :param testpath: TFrecord path of the test data; None disables the evaluation callback
    :return: list of callbacks
    """
    from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
    from tensorflow.keras.callbacks import ReduceLROnPlateau

    best_model_file = os.path.join(outDir, modelName + '.h5')
    epoch_weights_file = os.path.join(outDir, modelName + '_epoch_{epoch:02d}-valloss_{val_loss:.2f}_weights.h5')
    tensorboard_dir = os.path.join(logDir, modelName)

    checkpointer = ModelCheckpoint(filepath=best_model_file, save_best_only=True, verbose=1)
    earlyStopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=8, verbose=1, mode='auto')
    reduceLR = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1, mode='auto', cooldown=0,
                                 min_lr=0)
    checkpointSaveWeights = ModelCheckpoint(filepath=epoch_weights_file,
                                            save_weights_only=True, verbose=1)

    # Start from a clean TensorBoard log directory for this model
    if os.path.exists(tensorboard_dir):
        shutil.rmtree(tensorboard_dir)
    tensorboard = TensorBoard(log_dir=tensorboard_dir)

    leading_callbacks = [checkpointer, checkpointSaveWeights, tensorboard, reduceLR]

    # Without test data there is nothing to evaluate after each epoch
    if testpath is None:
        return leading_callbacks + [earlyStopping]

    os.makedirs(os.path.join(outDir, 'scatterplots'), exist_ok=True)

    class ModelEvaluation(tf.keras.callbacks.Callback):
        """Plots the current performance after every epoch."""

        def on_epoch_end(self, epoch, logs=None):

            log.info('Creating scatter plot ...')

            # Get inputs and outputs from TFrecord files
            (S1inputs, S2inputs), outputs = tf_datastack_to_array(testpath, 150)

            # Element 1 of the predictions is the q50 output; unscale it together with the targets
            median_prediction = self.model.predict([S1inputs, S2inputs])[1]
            predictions = minmaxunscaler(median_prediction.ravel(), 's2_' + S2layername.lower())
            outputs = minmaxunscaler(outputs.reshape((-1, 1)).ravel(), 's2_' + S2layername.lower())

            createScatterplot(outputs, predictions, 'epoch_{}'.format(epoch), outDir, str(epoch), modelName)

            # Plot some validation time series
            valdir = '/data/CropSAR/tmp/kristof/cropsar_ts/testtraining/'
            valfiles = glob.glob(os.path.join(valdir, '*_S2_FAPAR_shub.csv'))
            print('Found {} files in {}'.format(len(valfiles), valdir))
            log.info('Creating test time series plots ...')
            for valfile in valfiles:
                plot_validation_ts(valfile, self.model, outDir, str(epoch),
                                   S1smoothing, int(outputResolution), int(timesteps),
                                   S2layername, 150)

    return leading_callbacks + [ModelEvaluation(), earlyStopping]


# Function to define all experiments to be run
def getExperiments():
    """
    Build the list of experiments as the cartesian product of all setting values.

    :return: list of dicts, one per combination, mapping each setting name to one of its values
    """
    # Settings to test: every name maps to the list of values to try
    grid = {
        'nodes': [64],
        'dropoutFraction': [0.2],
    }

    names = tuple(grid.keys())
    combinations = product(*grid.values())

    return [dict(zip(names, combination)) for combination in combinations]

############### MAIN ###################################
def main(args):
    """
    Train all experiments on data stacks stored as numpy arrays.

    :param args: parsed command line arguments. Used attributes: 'experiment' (experiment name),
                 'jobs' (number of parallel Spark jobs), 'bSpark' or 'spark' (run in a Spark
                 context, default False) and 'mode' ('cpu' or 'gpu', default 'cpu').
    :raises ValueError: when Spark is requested with a mode other than 'cpu' or 'gpu'
    :raises AttributeError: when globals.year or globals.outputResolution is not defined
    :return: None
    """
    log.info('-' * 50 + "\nCROPSAR training program ...")

    # The command line parser of this file stores the Spark flag as 'spark' and has no
    # 'mode' option; accept both spellings and fall back to CPU mode.
    use_spark = getattr(args, 'bSpark', getattr(args, 'spark', False))
    mode = getattr(args, 'mode', 'cpu')
    if use_spark and mode not in ('cpu', 'gpu'):
        # Previously an unknown mode trained nothing and still reported success
        raise ValueError("Unknown mode '{}': expected 'cpu' or 'gpu'".format(mode))

    # The input file name is built from these two settings; fail early and clearly
    # when they are missing instead of after the Spark context has been started.
    g = globals
    year = getattr(g, 'year', None)
    outputResolution = getattr(g, 'outputResolution', None)
    if year is None or outputResolution is None:
        raise AttributeError("globals.year and globals.outputResolution must be set "
                             "to locate the input data stacks")

    ## For working in SPARK context first initiate it
    if use_spark:
        # set if working in local mode or in cluster
        sc = get_spark_context("CropSAR_" + args.experiment)
    ## End of ini SPARK context

    print('Tensorflow version: {}'.format(tf.__version__))

    ############################################################
    # Set some settings
    ############################################################
    outDir = os.path.join('/data/CropSAR/tmp/NN/models/', args.experiment)

    nrJobs = args.jobs # Total number of parallellized jobs; should be the number of executors we require in SPARK submit
    np.random.seed(5) # for reproducability

    logDir = os.path.join(outDir, 'logs') # where to save the TensorBoard logs of the experiments
    os.makedirs(outDir, exist_ok=True)
    os.makedirs(logDir, exist_ok=True)

    ############################################################
    # Load the input data
    ############################################################

    # LOAD THE DATA
    log.info('Loading input data ...')
    inputDataSource = os.path.join(g.inputDir, str(year) + '_multiCrop_100000Fields_DataStack_takeOutRate015_cleaned_' + str(
                                       outputResolution) + 'd_')
    log.info('Input data source: {}'.format(inputDataSource))
    inputsCAL = np.load(inputDataSource + 'CAL_Inputs.npy')
    outputsCAL = np.load(inputDataSource + 'CAL_Outputs.npy')
    inputsVAL = np.load(inputDataSource + 'VAL_Inputs.npy')
    outputsVAL = np.load(inputDataSource + 'VAL_Outputs.npy')
    log.info('Input data loaded!')

    # Take subset
    inputsCAL, outputsCAL = subsetTrainingData(inputsCAL, outputsCAL, g.nrObs_CAL)
    inputsVAL, outputsVAL = subsetTrainingData(inputsVAL, outputsVAL, g.nrObs_VAL)

    ############################################################
    # Prepare the model runs
    ############################################################

    # Get the experiments to be run
    experiments = getExperiments()
    log.info('{} models need to be trained ...'.format(len(experiments)))

    # Make sure input data is in right format to be fed to Keras
    CALdata = ([inputsCAL[:, :, 0:2], inputsCAL[:, :, 2:3]], outputsCAL)
    VALdata = ([inputsVAL[:, :, 0:2], inputsVAL[:, :, 2:3]], outputsVAL)

    ############################################################
    # Train the networks on the executors
    ############################################################

    # runModel's seventh positional parameter is 'epochs'; an application name used to be
    # passed in that position. It is left out so that runModel's own default applies.
    if use_spark and mode == 'cpu':
        log.info('Working in CPU mode --> sending experiments to executors ...')
        sc.parallelize(experiments, nrJobs).foreach(
                lambda experiment: runModel(experiment, outDir, logDir, CALdata, VALdata, outputResolution))
    else:
        if use_spark:
            log.info('Working in GPU mode --> processing all experiments in serial ...')
        else:
            log.info('Training models in serial...')
        for experiment in experiments:
            runModel(experiment, outDir, logDir, CALdata, VALdata, outputResolution)

    log.info('-' * 50 + '\nAll done!')

    log.info('-' * 80)
    log.info('DISTRIBUTED DNN TRAINING PROGRAM SUCCESSFULLY FINISHED!')
    log.info('-' * 80)


############### END of MAIN ###########################

############### MAIN ###################################
def main_tfrecords(experiment = 'GRU_150timesteps_BiDir_QuantileRegr_10000TP_CPU',
    datastack_settings = "/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks_tfrecords_Kristof/GEE_Cleaned_V002/2017/2017_multiCrop_allFields_DataStack_1d_150d_.ini",
    cal_paths = "/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks_tfrecords_Kristof/GEE_Cleaned_V002/2017/2017_multiCrop_allFields_DataStack_1d_150d_CAL_/part*",
    val_paths = "/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks_tfrecords_Kristof/GEE_Cleaned_V002/2017/2017_multiCrop_allFields_DataStack_1d_150d_VAL_/part*",
    test_paths = None,
    out='/data/CropSAR/tmp/NN/models/', mode='cpu', S2layername='FAPAR', nrJobs = 2, epochs=globals.epochs, batch_size=128,
    spark=True, steps_per_epoch=globals.steps_per_epoch, validation_steps=globals.validation_steps):
    """
    Train all experiments on data stacks stored as TFrecord files.

    :param experiment: experiment name; the models are written to <out>/<experiment>
    :param datastack_settings: path to training ini file that contains the settings based on which datastack was generated
    :param cal_paths: glob pattern of the calibration TFrecord files
    :param val_paths: glob pattern of the validation TFrecord files
    :param test_paths: glob pattern of the test TFrecord files, or None to skip the per-epoch evaluation
    :param out: base output directory
    :param mode: cpu or gpu; only looked at when spark is True
    :param S2layername: name of the S2 layer that is being used
    :param nrJobs: Total number of parallellized jobs; should be the number of executors we require in SPARK submit
    :param epochs: maximum number of epochs
    :param batch_size: batch size used for training
    :param spark: whether to start a Spark context; in cpu mode the experiments are sent to the executors
    :param steps_per_epoch: number of training batches per epoch
    :param validation_steps: number of validation batches per epoch
    :raises ValueError: when spark is True and mode is neither 'cpu' nor 'gpu'
    :return: None
    """

    log = get_logger()

    log.info('-' * 50 + "\nCROPSAR training program ...")

    # An unknown mode used to train nothing and still report success
    if spark and mode not in ('cpu', 'gpu'):
        raise ValueError("Unknown mode '{}': expected 'cpu' or 'gpu'".format(mode))

    ## For working in SPARK context first initiate it
    if spark:
        sc = get_spark_context("CROPSAR_" + experiment)

    ## End of ini SPARK context

    print('Tensorflow version: {}'.format(tf.__version__))

    ############################################################
    # Set some settings
    ############################################################

    # Define the model output directory
    outDir = os.path.join(out, experiment)

    np.random.seed(5)  # for reproducability

    logDir = os.path.join(outDir, 'logs')  # where to save the TensorBoard logs of the experiments
    os.makedirs(outDir, exist_ok=True)
    os.makedirs(logDir, exist_ok=True)

    ############################################################
    # Prepare the model runs
    ############################################################

    # Get the experiments to be run
    experiments = getExperiments()
    log.info('{} models need to be trained ...'.format(len(experiments)))

    ############################################################
    # Train the networks on the executors
    ############################################################

    def _train(model_settings):
        # Train a single experiment. Only plain settings are referenced here (no logger, no
        # Spark context), as this function is shipped to the executors in Spark CPU mode.
        runModel_tfrecords(model_settings, outDir, logDir,
                           cal_paths=cal_paths, val_paths=val_paths, test_paths=test_paths,
                           epochs=epochs, datastack_settings=datastack_settings,
                           steps_per_epoch=steps_per_epoch, validation_steps=validation_steps,
                           batch_size=batch_size, S2layername=S2layername)

    if spark and mode == 'cpu':
        log.info('Working in CPU mode --> sending experiments to executors ...')
        sc.parallelize(experiments, nrJobs).foreach(_train)
    else:
        if spark:
            log.info('Working in GPU mode --> processing all experiments in serial ...')
        else:
            log.info('Training models in serial...')
        for model_settings in experiments:
            _train(model_settings)

    log.info('-' * 50 + '\nAll done!')

    log.info('-' * 80)
    log.info('DISTRIBUTED DNN TRAINING PROGRAM SUCCESSFULLY FINISHED!')
    log.info('-' * 80)


############### END of MAIN ###########################

# Main code
if __name__ == "__main__":

    try:
        # check if right Python version is available. An explicit check is used because
        # assert statements are skipped when Python runs with -O.
        if sys.version_info[0:2] < (3, 5):
            raise RuntimeError("You need at minimum python 3.5 to execute this script.")
        start_time = time.time()

        # set up parser and parse input arguments
        parser = argparse.ArgumentParser()

        parser.add_argument("experiment", type=str, help="Experiment name")
        parser.add_argument("-v", "--verbose", help="increase output verbosity", action="store_true")
        parser.add_argument("-l", "--local", help='Set -l to tell script to run in local mode, hence 1 Spark executor. Should run together with -s argument.', action="store_true")
        # Stored as 'bSpark', the attribute name main() reads
        parser.add_argument("-s", "--spark", dest="bSpark", help='Set -s if you want to work in the spark context', action="store_true")
        parser.add_argument("-j", "--jobs", action="store", type=int, default=25)
        # main() selects the training path on args.mode, which had no command line option
        parser.add_argument("-m", "--mode", choices=["cpu", "gpu"], default="cpu",
                            help='Processing mode when working in the spark context')

        args = parser.parse_args()

        if args.verbose: log.info(time.asctime())

        # run the main part of the script
        main(args)

        if args.verbose:
            log.info(time.asctime())
            log.info('TOTAL TIME for Fusion model training IN MINUTES:')
            log.info("{:10.4f}".format((time.time() - start_time) / 60))

        sys.exit(0)
    except KeyboardInterrupt:  # Ctrl-C
        raise
    except SystemExit:  # sys.exit()
        raise
    except Exception as e:
        log.error('ERROR, UNEXPECTED EXCEPTION')
        log.error(str(e))
        traceback.print_exc()
        # Exit immediately with an error code, without running the normal interpreter shutdown
        os._exit(1)



