import glob
import tensorflow as tf

from cropsar.preprocessing.prepare_datastack_rnnfull import *

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)


def get_spark_context(name="DatastackGeneration_Program SPARK", local=False):
    """Return a SparkContext for this program, creating one only if needed.

    :param name: Spark application name.
    :param local: when True, run Spark in-process on a single worker thread
        (``local[1]``); otherwise build a ``SparkSession`` with 4G executor and
        driver memory and the ``spark-tensorflow-connector`` package.
    :return: the active ``pyspark.SparkContext``.
    """
    if local:
        from pyspark import SparkConf, SparkContext
        conf = SparkConf().setMaster('local[1]').setAppName(name)
        # getOrCreate reuses an already-running context instead of raising
        # "Cannot run multiple SparkContexts at once" as the plain constructor
        # does; this mirrors the builder's getOrCreate in the cluster branch.
        return SparkContext.getOrCreate(conf)

    # NOTE(review): SparkSession is not imported in this file; presumably it
    # comes from the star import of prepare_datastack_rnnfull -- confirm.
    spark = SparkSession.builder \
        .appName(name) \
        .config('spark.executor.memory', '4G') \
        .config('spark.driver.memory', '4G') \
        .config("spark.jars.packages", "org.tensorflow:spark-tensorflow-connector_2.11:1.11.0") \
        .getOrCreate()
    return spark.sparkContext


class globals:
    """Namespace of default settings for the RNNfull training program.

    Used as a plain attribute holder; it is never instantiated. (The name
    shadows the ``globals`` builtin but is kept for existing callers.)
    """

    # Directory holding the generated data stacks.
    inputDir = r'/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks/'

    # Sample counts for calibration / validation; not used by the TFRecord path.
    nrObs_VAL = 50000
    nrObs_CAL = 150000

    # Batches per epoch, used by the TFRecord path.
    validation_steps = 500
    steps_per_epoch = 1500

    # Default epoch count.
    epochs = 250

    # Default sequence length; replaced by a training-data-specific value
    # when one is available.
    timesteps = 150

    # Loss label, only written to config.ini.
    lossFunction = 'tilted loss adapted threshold 0.1'

################ FUNCTIONS ############################

# Tilted loss function that is used for quantile regression


def tilted_loss(q, y, f):
    """Pinball (tilted) loss for quantile regression at quantile ``q``.

    With residual ``r = y - f`` each element contributes
    ``max(q * r, (q - 1) * r)``: ``q * r`` when the prediction is below the
    target (``r > 0``) and ``(q - 1) * r`` when it is above. The result is
    averaged over the last axis.

    NOTE(review): ``globals.lossFunction`` is labelled
    'tilted loss adapted threshold 0.1', but no threshold is applied here. A
    thresholded (dead-zone) variant only ever existed as commented-out TF1
    code, which has been removed.

    :param q: target quantile, presumably in [0, 1].
    :param y: ground-truth tensor.
    :param f: predicted tensor, broadcast-compatible with ``y``.
    :return: per-sample loss tensor with the last axis reduced.
    """
    backend = tf.keras.backend
    residual = y - f
    pinball = backend.maximum(q * residual, (q - 1) * residual)
    return backend.mean(pinball, axis=-1)

# General function to initiate a Keras model with a set of chosen parameters


def create_model(nodes, dropoutFraction, learning_rate=0.0001):
    """Build and compile the two-input bidirectional-GRU quantile model.

    ``S1input`` (2 features per step) and ``S2input`` (1 feature per step)
    take variable-length sequences. Each goes through its own bidirectional
    GRU (``nodes`` units per direction, concatenated) followed by dropout.
    The two sequences are then:

    - concatenated feature-wise and passed through a second bidirectional GRU
      with ``2 * nodes`` units per direction plus dropout;
    - smoothed with an 11-step stride-1 'same' average pooling;
    - projected by a 128-unit time-distributed dense layer with LeakyReLU.

    Three linear time-distributed heads ``q10``, ``q50`` and ``q90`` are
    trained with the tilted loss for quantiles 0.1, 0.5 and 0.9, equally
    weighted.

    :param nodes: GRU units per direction in the per-sensor branches.
    :param dropoutFraction: dropout rate applied after each bidirectional GRU.
    :param learning_rate: Adam learning rate (default 1e-4, the value that
        used to be hard-coded).
    :return: compiled ``tf.keras.Model`` with inputs ``[S1input, S2input]``
        and outputs ``[q10, q50, q90]``.
    """
    from tensorflow.keras.layers import (AveragePooling1D, Bidirectional, Dense, Dropout, GRU,
                                         Input, LeakyReLU, TimeDistributed, concatenate)
    from tensorflow.keras.models import Model
    from tensorflow.keras.optimizers import Adam

    S1input = Input(shape=(None, 2), name="S1input")
    S2input = Input(shape=(None, 1), name="S2input")

    # Layer construction order is kept identical to the original so that
    # auto-generated layer names (and therefore saved weight files) match.
    gruS1 = Dropout(dropoutFraction)(Bidirectional(
        GRU(nodes, return_sequences=True), merge_mode='concat')(S1input))
    gruS2 = Dropout(dropoutFraction)(Bidirectional(
        GRU(nodes, return_sequences=True), merge_mode='concat')(S2input))

    S1_S2_merged = concatenate([gruS1, gruS2])
    RNN_merged = Dropout(dropoutFraction)(Bidirectional(
        GRU(nodes*2, return_sequences=True), merge_mode='concat')(S1_S2_merged))
    # Moving-average smoothing along time; stride 1 + 'same' keeps the length.
    RNN_merged_smooth = AveragePooling1D(
        pool_size=11, strides=1, padding='same')(RNN_merged)

    denseMerged = LeakyReLU()(TimeDistributed(
        Dense(128, activation='linear'))(RNN_merged_smooth))

    # One linear head per predicted quantile.
    final1 = TimeDistributed(
        Dense(1, activation='linear'), name='q10')(denseMerged)
    final2 = TimeDistributed(
        Dense(1, activation='linear'), name='q50')(denseMerged)
    final3 = TimeDistributed(
        Dense(1, activation='linear'), name='q90')(denseMerged)

    customLoss = {'q10': lambda y, f: tilted_loss(0.1, y, f),
                  'q50': lambda y, f: tilted_loss(0.5, y, f),
                  'q90': lambda y, f: tilted_loss(0.9, y, f)}
    customLossWeights = {'q10': 1,
                         'q50': 1,
                         'q90': 1}
    model = Model(inputs=[S1input, S2input], outputs=[final1, final2, final3])
    # `learning_rate`, not the deprecated `lr`: newer Keras optimizers drop
    # `lr` with only a warning, which silently trains at the 1e-3 default.
    adam = Adam(learning_rate=learning_rate)
    # "temporal" sample weighting: per-timestep weights are expected.
    model.compile(optimizer=adam, loss_weights=customLossWeights,
                  loss=customLoss, sample_weight_mode="temporal", metrics=['mae'])

    return model


def parse_record(example_proto):
    """Decode one serialized ``tf.train.SequenceExample`` from the data stacks.

    Sequence features ``s1_vv``, ``s1_vh``, ``s1_angle`` and ``s2_data`` are
    float32 with a per-step shape of ``[1]``. None may be missing. Context
    features ``field_id`` and ``croptype`` are scalar strings.

    :param example_proto: scalar string tensor holding the serialized record.
    :return: ``((s1_vv, s1_vh, s1_angle, s2_data), (field_id, croptype))``.
    """
    sequence_keys = ('s1_vv', 's1_vh', 's1_angle', 's2_data')
    sequence_spec = {
        key: tf.io.FixedLenSequenceFeature([1], dtype=tf.float32, allow_missing=False)
        for key in sequence_keys
    }
    context_spec = {
        'field_id': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
        'croptype': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    }

    parsed = tf.io.parse_single_sequence_example(
        example_proto, context_features=context_spec, sequence_features=sequence_spec)
    context, sequences = parsed[0], parsed[1]

    features = tuple(sequences[key] for key in sequence_keys)
    metadata = (context['field_id'], context['croptype'])
    return features, metadata


def create_dataset(paths_string: str, batchsize, shuffle_buffer=5000, prefetch_buffer=50):
    """Build an endless, shuffled, batched ``tf.data`` pipeline over TFRecord files.

    Files matching ``paths_string`` are read in shuffled order, four at a time
    (interleaved record by record). GZIP decompression is used when the
    pattern ends in ``.gz``. Records are decoded with :func:`parse_record`,
    shuffled, repeated forever, batched and prefetched.

    :param paths_string: file glob, e.g. ``.../part*.gz``.
    :param batchsize: number of records per batch.
    :param shuffle_buffer: record-level shuffle buffer size (default 5000).
    :param prefetch_buffer: number of batches to prefetch (default 50).
    :return: an infinite ``tf.data.Dataset`` of parsed, batched records.
    :raises FileNotFoundError: if no file matches ``paths_string``.
    """
    autotune = tf.data.experimental.AUTOTUNE

    # tf.io.gfile.glob understands every filesystem tf.data can read (local,
    # HDFS, GCS, ...); glob.glob only sees local paths and would report zero
    # matches for remote patterns.
    if not tf.io.gfile.glob(paths_string):
        raise FileNotFoundError('No files match {}'.format(paths_string))

    compression_type = "GZIP" if paths_string.endswith(".gz") else ""

    # list_files already shuffles the file names over the full file list.
    files = tf.data.Dataset.list_files(paths_string, shuffle=True)

    # increase cycle length to read multiple files in parallel
    dataset = files.interleave(
        lambda path: tf.data.TFRecordDataset(
            path, compression_type=compression_type),
        cycle_length=4, block_length=1, num_parallel_calls=autotune)

    # Parsing parallelism is autotuned rather than tied to the number of files,
    # which could spawn hundreds of parsing threads.
    return (dataset
            .map(parse_record, num_parallel_calls=autotune)
            .shuffle(shuffle_buffer)
            .repeat()
            .batch(batchsize)
            .prefetch(prefetch_buffer))


def tf_dataset_to_array(dataset, nrbatches, removalratios=None, cutoff=False):
    """Materialise the first ``nrbatches`` batches of ``dataset`` as numpy arrays.

    Every batch goes through ``convert_inputs(..., tfdataset=True)``. Only the
    first element of the converted outputs and weights is kept (presumably
    the targets/weights of the first model output -- confirm in
    ``convert_inputs``).

    :param dataset: batched ``tf.data.Dataset`` as produced by :func:`create_dataset`.
    :param nrbatches: number of batches to take.
    :param removalratios: forwarded to ``convert_inputs``.
    :param cutoff: forwarded to ``convert_inputs``.
    :return: ``([s1inputs, s2inputs], outputs, weights)`` concatenated along axis 0.
    :raises ValueError: if ``dataset`` yields no batch at all.
    """
    # Probe one batch only to learn the trailing shapes, so the result keeps
    # the right rank even when nrbatches == 0.
    inputs = outputs = weights = None
    for data in dataset.take(1):
        inputs, outputs, weights = convert_inputs(data, tfdataset=True)
    if inputs is None:
        raise ValueError('dataset yielded no batches')

    # Seed each list with the same empty float64 array as before, so the
    # dtype promotion and the empty-result shape are unchanged.
    s1_parts = [np.empty((0, inputs[0].shape[1], inputs[0].shape[2]))]
    s2_parts = [np.empty((0, inputs[1].shape[1], inputs[1].shape[2]))]
    output_parts = [np.empty((0, outputs[0].shape[1], outputs[0].shape[2]))]
    weight_parts = [np.empty((0, weights[0].shape[1]))]

    for data in dataset.take(nrbatches):
        currentinputs, currentoutputs, currentweights = convert_inputs(data, removalratios=removalratios,
                                                                       cutoff=cutoff, tfdataset=True)
        s1_parts.append(currentinputs[0])
        s2_parts.append(currentinputs[1])
        output_parts.append(currentoutputs[0])
        weight_parts.append(currentweights[0])

    # One concatenate per array instead of one per batch: linear copying
    # instead of quadratic.
    return ([np.concatenate(s1_parts, axis=0), np.concatenate(s2_parts, axis=0)],
            np.concatenate(output_parts, axis=0),
            np.concatenate(weight_parts, axis=0))


def plot_test(inputs, outputs, predictions, weights, outname):
    """Save a 2x3 diagnostic figure for three randomly drawn samples of a batch.

    Top row, per sample:
    - the q10-q90 band and the q50 line;
    - ``outputs[1]`` as red dots ("independent"), with zeros hidden;
    - ``inputs[1]`` as black dots ("assimilated"), with zeros hidden.

    Bottom row: channels 0 and 1 of ``inputs[0]``, labelled VV and VH.
    Samples are drawn with replacement, so a sample can appear twice.

    :param inputs: ``[s1, s2]`` arrays; ``s1`` is indexed as
        ``[sample, step, channel]``.
    :param outputs: sequence of target arrays; only ``outputs[1]`` is plotted
        (presumably the q50 target -- confirm).
    :param predictions: the three model outputs (q10, q50, q90), as numpy
        arrays or eager tensors.
    :param weights: currently unused.
    :param outname: path of the image file to write.
    """
    import matplotlib
    import matplotlib.pyplot as plt

    # Float copies so zeros can be masked as NaN without touching the inputs.
    targets = np.array(outputs[1], dtype=float)
    targets[targets == 0] = np.nan
    targets = targets.squeeze()

    s2assimilated = np.copy(inputs[1].squeeze())
    s2assimilated[s2assimilated == 0] = np.nan

    s1 = np.asarray(inputs[0]).squeeze()

    # np.asarray accepts both numpy arrays and eager tensors, so this works
    # whichever one predict_on_batch returns for the installed TF version.
    q10 = np.asarray(predictions[0]).squeeze()
    q50 = np.asarray(predictions[1]).squeeze()
    q90 = np.asarray(predictions[2]).squeeze()

    fieldidx = list(np.random.randint(inputs[0].shape[0], size=3))

    backend = matplotlib.get_backend()
    try:
        plt.switch_backend('Agg')

        fig, ax = plt.subplots(2, 3, figsize=(19, 9))
        try:
            steps = np.arange(q50.shape[1])
            for col, field in enumerate(fieldidx):
                top, bottom = ax[0, col], ax[1, col]

                top.fill_between(steps, q10[field, :], q90[field, :],
                                 alpha=0.2, color='blue', edgecolor=None)
                top.plot(q50[field, :], linewidth=1, color='blue')
                top.grid(linestyle='--')
                top.plot(targets[field, :], 'ro', label='independent')
                top.plot(s2assimilated[field, :], 'ko', label='assimilated')
                top.legend(loc='upper left')
                top.set_ylim([-2, 1])

                bottom.plot(s1[field, :, 0], label='VV')
                bottom.plot(s1[field, :, 1], label='VH')
                bottom.grid(linestyle='--')
                bottom.legend(loc='lower left')
                bottom.set_ylim([-2, 1])

            # Lay out and save once, after all panels are drawn (previously
            # this happened inside the loop, saving 3 times and drawing on a
            # closed figure).
            fig.tight_layout()
            fig.savefig(outname)
        finally:
            plt.close(fig)
    finally:
        plt.switch_backend(backend)


def plot_datastack(s1inputs, s2inputs, outputs, weights, outname):
    """Save a two-panel figure of one sample's data stack.

    Top panel:
    - ``outputs`` as red dots, with ``-2`` values hidden;
    - column 0 of ``s2inputs`` as a black line;
    - the non-zero ``weights`` rescaled to a maximum of 1, as green crosses.

    Bottom panel: columns 0 and 1 of ``s1inputs``, labelled VV and VH.

    :param s1inputs: 2-D array indexed as ``[:, column]``.
    :param s2inputs: 2-D array indexed as ``[:, column]``.
    :param outputs: target values; ``-2`` marks entries that are not plotted.
    :param weights: sample weights; zeros are not plotted.
    :param outname: path of the image file to write.
    """
    import matplotlib
    import matplotlib.pyplot as plt

    # Work on float copies: NaN can be stored whatever the incoming dtype, and
    # the caller's arrays are not modified (``outputs`` used to be overwritten
    # in place).
    outputs = np.array(outputs, dtype=float)
    weights = np.array(weights, dtype=float)

    outputs[outputs == -2.] = np.nan
    weights[weights == 0] = np.nan

    backend = matplotlib.get_backend()
    try:
        plt.switch_backend('Agg')

        fig, ax = plt.subplots(2, 1, figsize=(15, 9))
        try:
            # Sentinel-2
            ax[0].grid(linestyle='--')
            ax[0].plot(outputs.squeeze(),
                       'ro', label='outputs')
            ax[0].plot(s2inputs[:, 0], 'k', label='S2 inputs')
            ax[0].plot(weights/np.nanmax(weights), marker='x', color='green',
                       linewidth=0, label='non-zero weights')
            ax[0].legend(loc='upper left')
            ax[0].set_ylim([-1, 1.2])

            # Sentinel-1
            ax[1].plot(s1inputs[:, 0], label='VV')
            ax[1].plot(s1inputs[:, 1], label='VH')
            ax[1].grid(linestyle='--')
            ax[1].legend(loc='lower left')
            ax[1].set_ylim([-1, 1])

            # Lay out after drawing so legends and ticks are accounted for.
            fig.tight_layout()
            fig.savefig(outname)
        finally:
            # Close this figure even if plotting fails, to avoid leaking it.
            plt.close(fig)

    finally:
        plt.switch_backend(backend)


def plot_metrics(cal_loss, val_loss, outname):
    """Save a line plot comparing calibration and validation loss curves.

    :param cal_loss: sequence of calibration loss values, plotted against index.
    :param val_loss: sequence of validation loss values, plotted against index.
    :param outname: path of the image file to write.
    """
    import matplotlib
    import matplotlib.pyplot as plt

    backend = matplotlib.get_backend()
    try:
        plt.switch_backend('Agg')

        print('\nCreating training progress graph ...')
        fig, ax = plt.subplots(figsize=(12, 12))
        try:
            ax.plot(np.array(cal_loss), linewidth=2, label='cal loss')
            ax.plot(np.array(val_loss), linewidth=2, label='val loss')
            ax.grid(linestyle='--')
            # The curves carry labels; without a legend they were
            # indistinguishable in the saved figure.
            ax.legend()

            fig.savefig(outname)
        finally:
            plt.close(fig)

    finally:
        plt.switch_backend(backend)


def _convert_training_batch(inputbatch, removalratios, cutoff, addnoise):
    """Run convert_inputs on one raw tf.data batch with the settings shared by all training phases."""
    return convert_inputs(inputbatch, removalratios=removalratios, cutoff=cutoff,
                          tfdataset=True, training=True, addnoise=addnoise, forecast=True,
                          S1var='sigma')


def _plot_test_batch(model, test_iter, removalratios, figdir, epoch, step):
    """Predict on the next test batch and save figdir/epoch_<epoch+1>/step_<step+1>.png."""
    print('\nGenerating test data for plotting ...')
    testinputs, testoutputs, testweights = _convert_training_batch(
        next(test_iter), removalratios, cutoff=False, addnoise=False)

    testpredictions = model.predict_on_batch(testinputs)
    currenttestdir = os.path.join(figdir, 'epoch_{}'.format(epoch + 1))
    os.makedirs(currenttestdir, exist_ok=True)
    outname = os.path.join(currenttestdir, 'step_{}.png'.format(step + 1))

    plot_test(testinputs, testoutputs, testpredictions, testweights, outname)


def _train_one_epoch(model, cal_iter, test_iter, removalratios, num_steps, figdir, epoch):
    """Train on num_steps calibration batches (plotting a test batch every 200 steps); return per-batch losses."""
    metric = 'mean epoch cal loss'
    # The reported value is already a running mean; declaring it stateful
    # makes Progbar show it as-is instead of averaging it a second time.
    progbar = tf.keras.utils.Progbar(num_steps, stateful_metrics=[metric])
    losses = []
    for step in range(num_steps):
        inputs, output, weights = _convert_training_batch(
            next(cal_iter), removalratios, cutoff=True, addnoise=True)
        s1_input, s2_input = inputs[0], inputs[1]

        losses.append(model.train_on_batch(
            x=[s1_input, s2_input], y=output, sample_weight=weights)[0])

        # Progbar counts completed steps; also update on the last step so
        # the bar finishes.
        if step % 25 == 0 or step == num_steps - 1:
            progbar.update(step + 1, values=[(metric, np.mean(np.array(losses)))])

        if step % 200 == 0:
            _plot_test_batch(model, test_iter, removalratios, figdir, epoch, step)
    return losses


def _validate(model, val_iter, removalratios, num_val_steps):
    """Evaluate num_val_steps validation batches; return per-batch losses."""
    metric = 'mean epoch val loss'
    progbar = tf.keras.utils.Progbar(num_val_steps, stateful_metrics=[metric])
    losses = []
    for step in range(num_val_steps):
        inputs, output, weights = _convert_training_batch(
            next(val_iter), removalratios, cutoff=True, addnoise=False)
        s1_input, s2_input = inputs[0], inputs[1]

        losses.append(model.test_on_batch([s1_input, s2_input], y=output, sample_weight=weights,
                                          reset_metrics=True)[0])

        if step % 5 == 0 or step == num_val_steps - 1:
            progbar.update(step + 1, values=[(metric, np.mean(np.array(losses)))])
    return losses


def train(modelname, modelrootdir, trainingdir, figdir, year='201620172018', batchsize=32,
          num_epochs=250, num_steps=2000, num_val_steps=1000, patience=7):
    """Train the RNNfull model on TFRecord data stacks, with early stopping.

    Data is read from ``<trainingdir>/<year>_data_{CAL,VAL,TEST}_/part*.gz``.
    Each epoch:
    - runs ``num_steps`` calibration batches, saving a test-batch figure
      under ``figdir/epoch_<n>/`` at steps 0, 200, 400, ...;
    - runs ``num_val_steps`` validation batches;
    - refreshes ``figdir/trainingprogress.png``.

    Weights are written to ``<modelrootdir>/<modelname>/<modelname>_weights.h5``
    whenever the mean validation loss improves.

    :param modelname: name of the model (sub-directory and weight file prefix).
    :param modelrootdir: root directory for model output.
    :param trainingdir: directory containing the CAL/VAL/TEST TFRecord folders.
    :param figdir: directory for diagnostic figures.
    :param year: year tag used in the data folder names.
    :param batchsize: records per batch.
    :param num_epochs: maximum number of epochs.
    :param num_steps: calibration batches per epoch.
    :param num_val_steps: validation batches per epoch.
    :param patience: stop once more than this many consecutive epochs fail to
        improve the validation loss (default 7, i.e. on the 8th, as before).
    """
    os.makedirs(os.path.join(modelrootdir, modelname), exist_ok=True)
    os.makedirs(figdir, exist_ok=True)

    # Sampling distribution handed to convert_inputs; presumably the fraction
    # of observations removed per sample -- confirm in convert_inputs.
    removalratios = [0,
                     0.1, 0.1,
                     0.2, 0.2, 0.2, 0.2,
                     0.3, 0.3, 0.3, 0.3, 0.3, 0.3,
                     0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4,
                     0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
                     0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6,
                     0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
                     0.8, 0.8, 0.8, 0.8,
                     0.9, 0.9,
                     1]

    print('Initializing tensorflow datasets ...')
    datasets = [create_dataset(os.path.join(trainingdir, '{}_data_{}_'.format(year, split), 'part*.gz'),
                               batchsize=batchsize)
                for split in ('CAL', 'VAL', 'TEST')]

    print('Creating dataset iterators ...')
    cal_iter, val_iter, test_iter = (iter(dataset) for dataset in datasets)

    print('Setting up model ...')
    model = create_model(64, dropoutFraction=0.2)

    print('Starting training ...')
    weights_path = os.path.join(modelrootdir, modelname, modelname + '_weights.h5')
    lowest_val_loss = np.inf
    epochs_since_improvement = 0
    mean_epoch_cal_loss = []
    mean_epoch_val_loss = []

    for epoch in range(num_epochs):
        print('Epoch: {}'.format(epoch + 1))
        all_cal_loss = _train_one_epoch(model, cal_iter, test_iter, removalratios,
                                        num_steps, figdir, epoch)

        print('\nEpoch finished')
        print('Running validation ...')
        all_val_loss = _validate(model, val_iter, removalratios, num_val_steps)

        epoch_val_loss = np.mean(np.array(all_val_loss))
        mean_epoch_cal_loss.append(np.mean(np.array(all_cal_loss)))
        mean_epoch_val_loss.append(epoch_val_loss)

        plot_metrics(mean_epoch_cal_loss, mean_epoch_val_loss,
                     os.path.join(figdir, 'trainingprogress.png'))

        if epoch_val_loss < lowest_val_loss:
            print('Model is performing better -> saving')
            lowest_val_loss = epoch_val_loss
            model.save_weights(weights_path)
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement > patience:
                print("It's been {} epochs since loss decreased -> early stopping!".format(
                    epochs_since_improvement))
                break
        print('-' * 70)

    print('Training finished!')

############### MAIN ###################################


def main_tfrecords(modelname, modelrootdir, trainingdir, figdir, local=True):
    """Entry point: start Spark, then train the RNNfull model from TFRecords.

    :param modelname: desired output model name
    :param modelrootdir: location where the model will be stored
    :param trainingdir: directory where the training files are located
    :param figdir: location where figures will be stored
    :param local: whether to use a local (in-process) Spark session
    :return: None
    """
    log.info('-' * 50 + "\nCROPSAR training program ...")
    log.info('Tensorflow version: {}'.format(tf.__version__))

    # NOTE(review): the SparkContext is not passed to train(), which reads
    # data through tf.data; presumably it is kept for cluster resource
    # allocation -- confirm before removing.
    get_spark_context(name='RNNfull CROPSAR training', local=local)

    # Ignore numpy "invalid value" floating-point warnings only while
    # training, instead of permanently changing process-wide numpy state with
    # np.seterr.
    with np.errstate(invalid='ignore'):
        train(modelname, modelrootdir, trainingdir, figdir)

    log.info('-' * 50 + '\nAll done!')
    log.info('-' * 80)
    log.info('CROPSAR RNNFULL TRAINING PROGRAM SUCCESSFULLY FINISHED!')
    log.info('-' * 80)


if __name__ == '__main__':
    # Settings for one concrete training run; edit before launching.
    model_name = 'cropsar_ts_gru_pooling11_withcutoff_lastvalueremoved05_64nodes_withprediction_-2'
    model_root_dir = '/data/CropSAR/tmp/kristof/cropsar_ts/models/RNNfull'
    training_dir = '/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks/S1GEE_S2CLEANED_V005/RNNfull/201620172018/'
    run_local = False

    main_tfrecords(
        model_name,
        model_root_dir,
        training_dir,
        os.path.join(model_root_dir, model_name, 'figs'),
        local=run_local,
    )
