from matplotlib import pyplot as plt
from scipy import stats
import configparser
from pathlib import Path, PurePath
import platform as pltf

from cropsar.training.train_RNNfullmodel import *
from cropsar.preprocessing.prepare_data_stack import *

def convert_windows_path(purepath, root_unix='/data/CropSAR/', root_windows='O:/'):
    """
    Re-root a path from the unix data root onto the Windows drive root
    :param purepath: pathlib path object located below root_unix
    :param root_unix: unix root that is stripped from purepath
    :param root_windows: Windows root that is put in front of the remainder
    :return: Path made of root_windows followed by the part of purepath below root_unix
    :raises ValueError: raised by relative_to() when purepath is not located below root_unix
    """
    relative_part = purepath.relative_to(root_unix)
    return Path(root_windows).joinpath(Path(relative_part))

def createScatterplot(outputs, predictions, outname, outDir, experiment, modelWeights):
    """
    Function to create a density-coloured scatterplot between observed and predicted biopars

    The figure is written to <outDir>/scatterplots/<outname>; the 'scatterplots'
    directory is created when it does not exist yet.

    :param outputs: numpy array with observed values (x-axis); expected to be 1-D,
                    the caller in this file passes ravelled arrays
    :param predictions: numpy array with predicted values (y-axis), same length as outputs
    :param outname: outputname of the scatterplot
    :param outDir: output directory to save the scatterplot
    :param experiment: experiment name, shown as first line of the plot title
    :param modelWeights: path of the model weights; its basename is the second line of the title
    :return: None
    """

    # Regress predicted on observed so that the fitted line lives in the same
    # (x = observed, y = predicted) axes as the scatter points drawn below.
    slope, intercept, r_value, _p_value, _std_err = stats.linregress(outputs, predictions)
    regressX = np.arange(0, 1, 0.05)
    predRegress = regressX * slope + intercept

    # Calculate the point density
    xy = np.vstack([outputs, predictions])
    z = stats.gaussian_kde(xy)(xy)

    # Mean absolute error between observed and predicted values
    MAE = np.mean(np.abs(outputs - predictions))

    plotDir = os.path.join(outDir, 'scatterplots')
    plotPath = os.path.join(plotDir, outname)
    textStyle = {'fontsize': 16, 'color': 'red'}

    plt.figure(figsize=(10, 10))
    try:
        # 'none' is the documented way to draw markers without an edge
        plt.scatter(outputs, predictions, c=z, s=50, edgecolor='none')
        plt.plot(regressX, predRegress, color='red', linewidth=3)
        # 1:1 reference line
        plt.plot([0, 1], [0, 1], color='black')
        plt.grid()
        plt.xlabel('Observed FAPAR')
        plt.ylabel('Predicted FAPAR')
        plt.xlim(0, 1.1)
        plt.ylim(0, 1.1)

        plt.text(0.05, 1, 'Slope: {}'.format(np.round(slope, 2)), textStyle)
        # The value is a mean absolute error, so label it as such
        plt.text(0.05, 0.95, 'MAE: {}'.format(np.round(MAE, 2)), textStyle)
        plt.text(0.05, 0.9, 'r: {}'.format(np.round(r_value, 2)), textStyle)
        plt.title(experiment + '\n' + os.path.basename(modelWeights))

        # exist_ok avoids the check-then-create race and also creates missing parents
        os.makedirs(plotDir, exist_ok=True)
        plt.savefig(plotPath)
    finally:
        # Always release the figure, also when plotting or saving failed
        plt.close()

    print('Scatterplot saved to: {}'.format(plotPath))

def _model_tag(modelPath):
    # Build '<parent directory name>_<weights file name without extension>', used to name per-model outputs.
    modelPath = str(modelPath)
    return os.path.basename(os.path.dirname(modelPath)) + '_' + os.path.splitext(os.path.basename(modelPath))[0]


def evaluate_model(modelPath, outDir, cleaningVersion, experiment, nrTestBatches=350, scatterPlot=True, fields=True, TFrecordPath=None,
                   testFieldData=None, startDate=None, endDate=None, useAfter=False, overwrite=True):
    """
    Evaluate a trained model: optionally create a scatterplot on the TFrecord test data
    and/or plot the time series of the test fields and write summary statistics.

    :param modelPath: path to the model weights that are loaded into the model
    :param outDir: output directory; results are written to its 'scatterplots',
                   'testFields/<experiment>' and 'stats' subdirectories
    :param cleaningVersion: string that is only used in the file name of the field plots
    :param experiment: experiment name, used in the scatterplot title and as field plot subdirectory
    :param nrTestBatches: number of test batches handed to tf_dataset_to_array for the scatterplot
    :param scatterPlot: whether or not to create the scatterplot (requires TFrecordPath)
    :param fields: whether or not to create the field time series and statistics
                   (requires testFieldData, startDate and endDate)
    :param TFrecordPath: location of the TFrecord test data, handed to create_dataset
    :param testFieldData: nested data structure with the test field data; this function reads
                          ['S2']['FAPAR'], ['S2']['FAPARindependent'] and ['croptypes']
    :param startDate: start of the plotted field time series (string, also used in the plot file name)
    :param endDate: end of the plotted field time series (string, also used in the plot file name);
                    independent observations after this date are ignored
    :param useAfter: currently not used inside this function
    :param overwrite: when False, plots that already exist on disk are not created again;
                      the field statistics are always recomputed and rewritten
    :return: None
    """

    # Consistency check
    if scatterPlot and TFrecordPath is None: sys.exit('A TFrecordPath is needed to create a scatterplot!')
    if fields and testFieldData is None: sys.exit('A testFieldsPath is needed to create a field time series!')
    if fields and (startDate is None or endDate is None): sys.exit('A start and end date is needed to create a field time series!')

    # Create the model infrastructure
    print('Initializing model ...')
    model = create_model(nodes=32, dropoutFraction=0.2)

    # Load the model weights
    print('Loading model weights ...')
    model.load_weights(str(modelPath))
    print('Model successfully loaded!')

    modelTag = _model_tag(modelPath)

    if scatterPlot:

        # Set outname for scatterplot
        outname = modelTag + '.png'

        # Only create if necessary
        if not os.path.exists(os.path.join(outDir, 'scatterplots', outname)) or overwrite:

            print('-' * 50)
            print('SCATTERPLOT')
            print('-' * 50)

            # Get inputs and outputs from TFrecord files
            testdataset = create_dataset(str(TFrecordPath), batchsize=128)
            inputs, outputs, weights = tf_dataset_to_array(testdataset, nrTestBatches, removalratios=0.3, cutoff=False)

            print('Making predictions ...')
            # Element 1 of the model output is used as the central prediction
            predictions = minmaxunscaler(model.predict(inputs)[1].ravel(), 's2_fapar')

            # Unscale the outputs
            outputs = minmaxunscaler(outputs.reshape((-1, 1)).ravel(), 's2_fapar')

            # Remove the points without valid fapar (weight equal to zero)
            weights = weights.ravel()
            idx = np.where(weights != 0)[0]
            predictions = predictions[idx]
            outputs = outputs[idx]

            # Create the scatterplot
            createScatterplot(outputs, predictions, outname, outDir, experiment, str(modelPath))

    if fields:

        print('-' * 50)
        print('FIELD TIME SERIES')
        print('-' * 50)

        # makedirs creates 'testFields' and the experiment subdirectory in one go
        fieldPlotDir = os.path.join(outDir, 'testFields', experiment)
        os.makedirs(fieldPlotDir, exist_ok=True)

        # Initialize some arrays to calculate statistics (np.append flattens them to 1-D)
        FAPARerrors = np.empty((0, 1))
        predictionSpreadMean = np.empty((0, 1))
        predictionSpreadMedian = np.empty((0, 1))

        # Get the FieldIDs
        fieldIDs = testFieldData['S2']['FAPAR'].columns.tolist()

        index = testFieldData['S2']['FAPAR'].index.tolist()

        for i, field in enumerate(fieldIDs):

            croptype = testFieldData['croptypes'][i]

            # Build full datastack
            data = datadict_to_tuple(testFieldData, field)
            inputs, outputs, _weights = convert_inputs(data, tfdataset=False)

            # Get the independent FAPAR observations that were not used for gap-filling
            independentFAPAR = testFieldData['S2']['FAPARindependent'][field]

            # Make predictions for entire TS
            predictions = model.predict(inputs)

            # Unscale the mean (median) prediction
            meanPrediction = pd.Series(index=index, data=minmaxunscaler(predictions[1], 's2_fapar').ravel())

            # Unscale the quantiles
            q10 = pd.Series(index=index, data=minmaxunscaler(predictions[0], 's2_fapar').ravel())
            q90 = pd.Series(index=index, data=minmaxunscaler(predictions[2], 's2_fapar').ravel())

            # Get the assimilated FAPAR TS and unscale (only the non-missing values)
            assimilatedFAPAR = pd.Series(index=index, data=np.copy(outputs[1].ravel()))
            null_index = assimilatedFAPAR.isnull()
            assimilatedFAPAR.loc[~null_index] = minmaxunscaler(
                assimilatedFAPAR.loc[~null_index].values.reshape((-1, 1)), 's2_fapar').ravel()

            # No independent observations after "endDate" are taken into account
            if endDate is not None: independentFAPAR = independentFAPAR.loc[:endDate]

            # Calculate a 5d running mean
            meanPredictionRunningMean = meanPrediction.rolling(window=5, win_type='hamming', center=True).mean()

            # Derive some statistics on the dates where both an independent observation and a prediction exist
            valid = independentFAPAR.notnull() & meanPrediction.notnull()
            FAPARerrorsCurrent = meanPrediction.loc[valid] - independentFAPAR.loc[valid]
            FAPARerrors = np.append(FAPARerrors, FAPARerrorsCurrent.values)
            fieldSpread = np.mean(q90 - q10)
            predictionSpreadMean = np.append(predictionSpreadMean, fieldSpread)
            # NOTE(review): the per-field value is the mean spread for both arrays; only the
            # aggregation over the fields differs (mean vs median) -- confirm this is intended
            predictionSpreadMedian = np.append(predictionSpreadMedian, fieldSpread)

            # Plot the result and save the figure
            outname = os.path.join(fieldPlotDir,
                                   str(field) + '_' + croptype + '_' + startDate + '_' + endDate
                                   + '_GapFilling_Cleaned_' + cleaningVersion + '_TF.png')
            if not os.path.exists(outname) or overwrite:

                plt.figure(figsize=(12, 8))
                plt.plot(meanPredictionRunningMean, label='CropSAR')
                plt.plot(assimilatedFAPAR, 'bo', label='Assimilated')
                plt.plot(independentFAPAR, 'ro', label='Independent')
                plt.grid()
                plt.xlim(startDate, endDate)
                plt.ylim(0, 1)

                # Smoothed q10-q90 band around the prediction
                plt.fill_between(q10.index, q10.rolling(window=5, win_type='hamming', center=True).mean(),
                                 q90.rolling(window=5, win_type='hamming', center=True).mean(), alpha=0.5, label='q10-q90')

                plt.legend()
                plt.title('Croptype: {} ({})'.format(croptype, field))
                plt.savefig(outname)
                plt.close()

        # Report the stats
        meanFAPARerror = np.mean(FAPARerrors)
        meanFAPARabsoluteError = np.mean(np.abs(FAPARerrors))
        medianFAPARerror = np.median(FAPARerrors)
        medianFAPARabsoluteError = np.median(np.abs(FAPARerrors))
        predictionSpreadMean = np.mean(predictionSpreadMean)
        predictionSpreadMedian = np.median(predictionSpreadMedian)

        # The 'stats' directory is not created anywhere else, so make sure it exists before writing
        statsDir = os.path.join(outDir, 'stats')
        os.makedirs(statsDir, exist_ok=True)
        outFile = modelTag + '_STATS.txt'

        with open(os.path.join(statsDir, outFile), 'w') as f:
            f.write('mean FAPAR error: {}\n'.format(np.round(meanFAPARerror, 2)))
            f.write('median FAPAR error: {}\n'.format(np.round(medianFAPARerror, 2)))
            f.write('mean FAPAR absolute error: {}\n'.format(np.round(meanFAPARabsoluteError, 2)))
            f.write('median FAPAR absolute error: {}\n'.format(np.round(medianFAPARabsoluteError, 2)))
            f.write('mean prediction spread: {}\n'.format(np.round(predictionSpreadMean, 2)))
            f.write('median prediction spread: {}\n'.format(np.round(predictionSpreadMedian, 2)))

    return

def main(experiment, modelweights, tfpath, useAfter=False, plotFields=True, makeScatterplot=True, nrBatches=350, cleaningVersion='V005',
         startDate='2016-09-01', endDate='2017-10-31', overwrite=True):
    '''
    main function to perform a CropSAR model evaluation

    :param experiment: name of the experiment, used in plot titles and as output subdirectory of the field plots
    :param modelweights: path to the model weights file, given as a unix path; when not running on Linux
                         it is converted with convert_windows_path (so it has to be located below /data/CropSAR/)
    :param tfpath: path (pattern) of the TFrecord test files, same path convention as modelweights
    :param useAfter: whether or not to use available observations AFTER the specified endDate (default=False)
    :param plotFields: whether or not to plot time series of the independent fields (default=True)
    :param makeScatterplot: whether or not to make a scatterplot of network performance with its test data (default=True)
    :param nrBatches: total number of test batches (128 observations each) to generate for scatter plot(default=350)
    :param cleaningVersion: only used in reporting on the plots, does not affect function itself
    :param startDate: when to start the field time series (format: yyyy-mm-dd)
    :param endDate: when to end the field time series (format: yyyy-mm-dd)
    :param overwrite: whether or not to recreate plots that already exist (default=True)
    :return:
    '''

    # Model and input training data location and output directory
    if pltf.system() == 'Linux':
        modelPath = str(Path(modelweights))
        tf_path = Path(tfpath)
        outDir = r'/data/CropSAR/tmp/NN/evaluation'
    else:
        # Every other platform is treated as the Windows setup with the data share mounted as O:
        modelPath = str(convert_windows_path(PurePath(modelweights)))
        tf_path = convert_windows_path(PurePath(tfpath))
        outDir = r'O:\tmp\NN\evaluation'

    # Independent field data
    # NOTE: unpickling is acceptable here because the resource ships inside the cropsar package;
    # never point this at data from an untrusted source.
    testFieldData = pickle.loads(
        pkgutil.get_data("cropsar.validation", "resources/2017_multiCrop_TESTfields_S1_S2_Cleaned_V002_TS.p"))

    # Also creates missing parent directories and does not fail when the directory already exists
    os.makedirs(outDir, exist_ok=True)

    # Run the evaluation function
    evaluate_model(modelPath, outDir, cleaningVersion, experiment, nrTestBatches=nrBatches, scatterPlot=makeScatterplot,
                   fields=plotFields, TFrecordPath=tf_path, testFieldData=testFieldData, startDate=startDate, endDate=endDate,
                   useAfter=useAfter, overwrite=overwrite)


if __name__ == "__main__":

    # Weights of the model to evaluate and the TFrecord test data
    # (unix-style paths; main() converts them when not running on Linux)
    modelweights = '/data/CropSAR/tmp/kristof/cropsar_ts/models/cropsar_ts_gru_gpu_pooling41/cropsar_ts_gru_gpu_pooling41_weights.h5'
    testtfpath = '/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks/S1GEE_S2CLEANED_V005/RNNfull/201620172018/201620172018_datastack_TEST_/part*.gz'

    # Name of this evaluation run
    experiment = 'gru_gpu_pooling41'

    # Evaluation settings, forwarded to main() as keyword arguments
    settings = dict(
        useAfter=False,
        plotFields=True,
        makeScatterplot=False,
        nrBatches=10,
        cleaningVersion='V005',
        # Start and end date to process
        startDate='2016-08-01',
        endDate='2017-11-30',
    )

    main(experiment, modelweights, testtfpath, **settings)




