from cropsar import *
import pandas as pd
import os
import copy
import numpy as np
import logging, sys, glob
from s2_clean import smooth
from matplotlib import pyplot as plt
import time
from pyspark import SparkContext
import subprocess

# Module-level logger, named after this module; the functions below log through it
log = logging.getLogger(__name__)

def mask_ts_fapar(ts_fapar, fapar_date_mask, identifier=''):
    '''
    function to mask a fapar pandas time series based on a list of dates

    Both the presence check and the masking itself compare dates on their 'yyyy-mm-dd' representation, so an
    observation is masked whenever its calendar day is listed, whatever its time of day. Mask dates which do
    not occur in the series are reported with a warning and ignored.

    :param ts_fapar: pandas time series of fapar (with a datetime index)
    :param fapar_date_mask: list of datestrings [yyyy-mm-dd] which will be used to mask the input fapar series
                            (the list itself is left untouched)
    :param identifier: string representing a field identifier, used here only for logging
    :return: masked (deep) copy of ts_fapar, holding NaN on the masked dates
    '''

    dates_to_mask = []
    is_masked = None

    if len(fapar_date_mask) > 0:
        # Format the index only once: it serves both the presence check and the masking
        index_dates = ts_fapar.index.strftime('%Y-%m-%d')
        available_dates = set(index_dates)

        # Check if the dates are actually present in the input FAPAR time series
        for date in fapar_date_mask:
            if date in available_dates:
                dates_to_mask.append(date)
            else:
                log.warning('Provided mask date not in fAPAR series: {}'.format(date))

        # Compare strings with strings: matching date strings against a datetime index through isin()
        # misses observations which are not at midnight and is deprecated in recent pandas versions
        is_masked = index_dates.isin(dates_to_mask)

    log.info('{} - masking {} dates in input FAPAR list'.format(identifier, len(dates_to_mask)))

    # Perform the masking on a copy, so the caller's series is not modified
    S2fapar_masked = ts_fapar.copy(deep=True)
    if is_masked is not None:
        S2fapar_masked.loc[is_masked] = np.nan

    return S2fapar_masked

def get_start_end(S1ascendingVV, S1ascendingVH, S1ascendingIA, S1descendingVV, S1descendingVH,
                  S1descendingIA, S2_fapar_cleaned, S2_fapar_raw):
    '''
    function to find the overall period covered by a set of input time series
    Only the first and the last index value of every series are inspected.
    :return: tuple with the smallest first index value and the largest last index value over all series
    '''
    all_series = (S1ascendingVV, S1ascendingVH, S1ascendingIA, S1descendingVV, S1descendingVH,
                  S1descendingIA, S2_fapar_cleaned, S2_fapar_raw)

    earliest = min(series.index[0] for series in all_series)
    latest = max(series.index[-1] for series in all_series)

    return earliest, latest

def get_input_ts(fieldID, root_dir, processing_version='V004'):
    '''
    function to get all necessary input time series to predict fapar

    The versioned S1 and cleaned S2 files are read from <root_dir>/res/<fieldID>, the raw S2 fapar from
    <root_dir>/raw/<fieldID>/S2_FAPAR_In10m.csv. All series are reindexed to one common daily index spanning
    the period covered by the inputs.

    :param fieldID: unique field identifier corresponding to folder name (and to the csv column to read)
    :param root_dir: path to root directory containing the csv files
    :param processing_version: version tag which is part of the S1 and cleaned S2 file names
    :return: tuple (S1ascendingVV, S1ascendingVH, S1ascendingIA, S1descendingVV, S1descendingVH,
             S1descendingIA, S2_fapar_cleaned, S2_fapar_raw) of pandas time series, or None when one of the
             versioned inputs could not be loaded or when the cleaned fapar series is empty
    :raises: whatever pandas raises when the raw S2 fapar csv cannot be read (that read is not guarded)
    '''
    log.info('{} - Reading all input time series for processing version {} ...'.format(fieldID, processing_version))

    input_dir = os.path.join(root_dir, 'res', fieldID)

    def read_field_series(csv_path):
        # All inputs share the same layout: dates in the first column, one column named after the field
        return pd.read_csv(csv_path, index_col=0, parse_dates=True)[fieldID]

    # File name prefixes of the versioned inputs, in the order in which the series are returned
    versioned_prefixes = ['S1_GRD_SIGMA0_ASCENDING_VV_',
                          'S1_GRD_SIGMA0_ASCENDING_VH_',
                          'S1_GRD_SIGMA0_ASCENDING_ANGLE_',
                          'S1_GRD_SIGMA0_DESCENDING_VV_',
                          'S1_GRD_SIGMA0_DESCENDING_VH_',
                          'S1_GRD_SIGMA0_DESCENDING_ANGLE_',
                          'S2_FAPAR_CLEAN_']

    try:
        input_series = [read_field_series(os.path.join(input_dir, prefix + processing_version + '.csv'))
                        for prefix in versioned_prefixes]
    except Exception:
        # Only real errors are handled here: KeyboardInterrupt and SystemExit are no longer swallowed
        log.warning('{} - One or more input TS could not be loaded -> please check the input folders!'.format(fieldID))
        log.debug('{} - Underlying error while loading the input TS'.format(fieldID), exc_info=True)
        return None

    # Check if we still have FAPAR left, otherwise, we can't process the field
    S2_fapar_cleaned = input_series[-1]
    if len(S2_fapar_cleaned.index) == 0:
        log.warning('{} - No FAPAR left in cleaned TS -> cant process this field!'.format(fieldID))
        return None

    input_series.append(read_field_series(os.path.join(root_dir, 'raw', fieldID, 'S2_FAPAR_In10m.csv')))

    # Find correct start and end date
    start_date, end_date = get_start_end(*input_series)

    # Construct date range
    daterangeIndex = pd.date_range(start=start_date, end=end_date, freq='1D')

    # Reindex the input TS
    return tuple(series.reindex(index=daterangeIndex) for series in input_series)

def ts_cropsar(S1ascendingVV, S1ascendingVH, S1ascendingIA,
                    S1descendingVV, S1descendingVH, S1descendingIA,
                    S2fapar, fapar_date_mask, identifier='', type='RNN'):
    '''
    function to run cropsar twice on the same S1 inputs: once with the complete fapar series and once with the
    fapar series in which the requested dates are masked. The predictions are requested for the period between
    the first and the last index value of S1ascendingVV.
    :param S1ascendingVV: pandas time series of S1 ascending VV backscatter in dB
    :param S1ascendingVH: pandas time series of S1 ascending VH backscatter in dB
    :param S1ascendingIA: pandas time series of S1 ascending incidence angle in degrees
    :param S1descendingVV: pandas time series of S1 descending VV backscatter in dB
    :param S1descendingVH: pandas time series of S1 descending VH backscatter in dB
    :param S1descendingIA: pandas time series of S1 descending incidence angle in degrees
    :param S2fapar: pandas time series of (cleaned) S2 fapar
    :param fapar_date_mask: list of datestrings [yyyy-mm-dd] which will be used to mask the input fapar series
    :param identifier: string representing a field identifier, used here only for logging
    :param type: CropSAR model type (default=RNN) - get available model types by calling get_available_model_types()
    :return: dictionary with keys 'orig' and 'masked', each holding the list [q10, q50, q90] of predictions
    '''

    log.info('{} - starting CropSAR retrieval'.format(identifier))

    # CropSAR model object for the requested model type
    model = get_model(type=type)

    # Version of the fapar series without the observations on the masked dates
    fapar_without_masked_dates = mask_ts_fapar(S2fapar, fapar_date_mask, identifier=identifier)

    s1_inputs = (S1ascendingVV, S1ascendingVH, S1ascendingIA,
                 S1descendingVV, S1descendingVH, S1descendingIA)

    def predict_quantiles(fapar_series):
        # The requested period is taken from the S1 ascending VV index
        first_day = S1ascendingVV.index[0].strftime('%Y-%m-%d')
        last_day = S1ascendingVV.index[-1].strftime('%Y-%m-%d')
        q10, q50, q90 = model.get_timeseries(*s1_inputs, fapar_series, identifier, first_day, last_day)
        return [q10, q50, q90]

    log.info('{} - running cropsar on original input data ...'.format(identifier))
    quantiles_orig = predict_quantiles(S2fapar)

    log.info('{} - running cropsar on masked input data ...'.format(identifier))
    quantiles_masked = predict_quantiles(fapar_without_masked_dates)

    log.info('{} - cropsar predictions done ...'.format(identifier))

    return {'orig': quantiles_orig, 'masked': quantiles_masked}

def ts_whittaker(S2fapar, fapar_date_mask, lmbda=1, passes=3, identifier=''):
    '''
    function to run a standard whittaker smoothing on the complete fapar series and on the fapar series in which
    the requested dates are masked. Both series are put on a daily index before smoothing.
    :param S2fapar: pandas time series of S2 fapar
    :param fapar_date_mask: list of datestrings [yyyy-mm-dd] which will be used to mask the input fapar series
    :param lmbda: passed as first argument to smooth.whittaker_second_differences
    :param passes: passed as 'passes' to smooth.whittaker_second_differences
    :param identifier: string representing a field identifier, used here only for logging
    :return: dictionary with the smoothed daily series for the original ('orig') and the masked ('masked') input
    '''

    log.info('{} - starting whittaker retrieval'.format(identifier))

    # Series without the observations on the masked dates
    masked_input = mask_ts_fapar(S2fapar, fapar_date_mask, identifier=identifier)

    # Put both series on a daily index spanning the input period
    days = pd.date_range(start=S2fapar.index[0], end=S2fapar.index[-1], freq='1D')
    daily_full = S2fapar.reindex(index=days)
    daily_masked = masked_input.reindex(index=days)

    def smooth_daily(series):
        smoothed_values = smooth.whittaker_second_differences(lmbda, series, minimumdatavalue=0,
                                                              maximumdatavalue=1, passes=passes)
        return pd.Series(data=smoothed_values, index=days)

    log.info('{} - running whittaker on original input data ...'.format(identifier))
    smoothed_full = smooth_daily(daily_full)

    log.info('{} - running whittaker on masked input data ...'.format(identifier))
    smoothed_masked = smooth_daily(daily_masked)

    return {'orig': smoothed_full, 'masked': smoothed_masked}

def ts_wig(S2fapar, fapar_date_mask, identifier=''):
    '''
    function to run an emulation of WIG on input data which is masked for certain dates.
    Returns both the predictions on the complete series as well as on the masked series

    Processing steps, applied to the complete and to the masked series: reindex to a daily index, remove
    outliers (wig_remove_outliers below), derive weights with smooth.makeweighttypescube and
    smooth.makesimpleweightscube, and smooth with smooth.whittaker_second_differences.

    :param S2fapar: pandas time series of S2 fapar
    :param fapar_date_mask: list of datestrings [yyyy-mm-dd] which will be used to mask the input fapar series
    :param identifier: string representing a field identifier, used here only for logging
    :return: dictionary with the smoothed daily series for the original ('orig') and the masked ('masked') input
    '''

    log.info('{} - starting wig retrieval'.format(identifier))

    #
    #    wig outlier detection
    #
    def wig_remove_outliers(data, maxdip = 0.01, maxdif = 0.30, maxgap = 50, nrOutlierIterations = 4):
        '''
        emulation of WIG outlier detection code - dated 08 Jun 2018
        https://git.vito.be/projects/BIGGEO/repos/geotrellistimeseries/browse/TimeseriesSmooth/src/main/java/be/vito/eodata/smoothproxy/WhittakerSmoothing.java

        Local minima among the valid (> 0) values are set to NaN when they look like outliers. The log message
        at the end uses 'identifier' of the enclosing ts_wig call.

        :param data: daily in-values
        :param maxdip: threshold on the drop per position (drop towards a neighbour divided by the distance
                       to that neighbour)
        :param maxdif: threshold on the absolute drop towards a neighbour
        :param maxgap: a dip/dif outlier is only removed when both valid neighbours are closer than this
                       number of positions
        :param nrOutlierIterations: number of detection passes
        :return: float numpy array of the same length as data, with the removed values set to NaN
        '''

        # np.array makes a float copy, so the input data itself is not modified
        Yi = np.array(data, dtype = float)
        Ni = len(Yi)

        # Counters per removal reason, only used for logging
        icountdips = 0
        icountdifs = 0
        icountoths = 0
        icountvals = np.count_nonzero(~np.isnan(Yi))

        for out_i in range(nrOutlierIterations) :
            log.debug("outlier detection iteration nr %s" % (out_i))

            # update reduced array of valid values only, and remember their indices in a map
            # (the '> 0' test skips NaN as well as zero and negative values)
            YiShortList         = []
            YiShortListMapToInd = dict()
            cr = 0
            for i in range(Ni):
                if Yi[i] > 0:
                    log.debug("YiShortListMap add value %s at index %s" % (Yi[i], i))
                    YiShortList.append(Yi[i])
                    YiShortListMapToInd[cr] = i
                    cr += 1

            # detect outliers on shortlist
            # NOTE: the shortlist is not updated while values are removed; removals only become visible to the
            # neighbour comparisons in the next iteration
            for i in range(1, len(YiShortList) - 1):
                if (YiShortList[i-1] - YiShortList[i]) > 0:
                    if (YiShortList[i+1] - YiShortList[i]) > 0:
                        # local minimum - check if it is an outlier
                        log.debug("local minimum value %s at index %s" % (YiShortList[i], YiShortListMapToInd[i]))
                        abs_dip_left  = YiShortList[i-1] - YiShortList[i]
                        abs_dip_right = YiShortList[i+1] - YiShortList[i]
                        nr_days_left  = YiShortListMapToInd[i]   - YiShortListMapToInd[i-1]
                        nr_days_right = YiShortListMapToInd[i+1] - YiShortListMapToInd[i]

                        # The three tests are mutually exclusive (if/elif): when the drop per position exceeds
                        # maxdip but a neighbour is too far away, the value is kept and the tests below are
                        # not evaluated
                        if (((abs_dip_left/nr_days_left) > maxdip) or ((abs_dip_right/nr_days_right) > maxdip)):
                            if ((nr_days_left < maxgap) and (nr_days_right < maxgap)):
                                Yi[YiShortListMapToInd[i]] = np.nan # remove value
                                icountdips += 1
                                log.debug("    outlier dip removed at index %s - removed dip(%s) dif(%s) oth(%s) total(%s)" % (YiShortListMapToInd[i], icountdips, icountdifs, icountoths, (icountdips+icountdifs+icountoths)))
                        elif ((abs_dip_left > maxdif) or (abs_dip_right > maxdif)):
                            if ((nr_days_left < maxgap) and (nr_days_right < maxgap)):
                                Yi[YiShortListMapToInd[i]] = np.nan # remove value
                                icountdifs += 1
                                log.debug("    outlier dif removed at index %s - removed dip(%s) dif(%s) oth(%s) total(%s)" % (YiShortListMapToInd[i], icountdips, icountdifs, icountoths, (icountdips+icountdifs+icountoths)))
                        # local minimum between two neighbours above 0.6: removed whatever the gap size
                        elif  ((YiShortList[i-1] > 0.6) and (YiShortList[i+1] > 0.6)):
                            Yi[YiShortListMapToInd[i]] = np.nan # remove value
                            icountoths += 1
                            log.debug("    outlier ??? removed at index %s - removed dip(%s) dif(%s) oth(%s) total(%s)" % (YiShortListMapToInd[i], icountdips, icountdifs, icountoths, (icountdips+icountdifs+icountoths)))
        #
        # done.
        #
        log.info("%s - outlier detection removed dip(%s) dif(%s) oth(%s) total(%s) of %s" % (identifier, icountdips,
                                                                                               icountdifs,
                                                                                             icountoths,
                                                                                             (icountdips+icountdifs+icountoths),
                                                                                             icountvals))
        return Yi


    # Mask the input fapar time series
    S2fapar_masked = mask_ts_fapar(S2fapar, fapar_date_mask, identifier=identifier)

    # Make a daily series from the input to be able to do the smoothing
    daily_index = pd.date_range(start=S2fapar.index[0], end=S2fapar.index[-1], freq='1D')
    S2fapar = S2fapar.reindex(index=daily_index)
    S2fapar_masked = S2fapar_masked.reindex(index=daily_index)

    #
    #    outliers detection parameters (passed to wig_remove_outliers)
    #
    maxdip        = 0.01
    maxdif        = 0.3
    maxgap        = 50
    #
    #    whittaker parameters (passed to smooth.whittaker_second_differences)
    #
    lmbda        = 1
    passes       = 3
    dokeepmaxima = True
    #
    #    weights parameters (passed to smooth.makeweighttypescube / smooth.makesimpleweightscube)
    #    NOTE(review): the meaning of the individual weight values is defined in s2_clean.smooth - confirm there
    #
    aboutequalepsilon = 0.02
    weightvalues = smooth.WeightValues(
        maximum    =  1.5,
        minimum    =  0.005,
        posslope   =  0.5,
        negslope   =  0.02,
        aboutequal =  1.0,
        default    =  1.0)
    #
    #    wig-alike: outlier removal -> weights -> weighted whittaker, first on the complete series
    #
    log.info('{} - running wig emulation on original input data ...'.format(identifier))
    nooutliers_data  = wig_remove_outliers(S2fapar, maxdip = maxdip, maxdif = maxdif, maxgap = maxgap)
    weighttypes      = smooth.makeweighttypescube(nooutliers_data, aboutequalepsilon)
    weights          = smooth.makesimpleweightscube(weighttypes, weightvalues=weightvalues)
    S2_fapar_wh_orig = pd.Series(data=smooth.whittaker_second_differences(lmbda, nooutliers_data, numpyweightscube=weights, minimumdatavalue=0, maximumdatavalue=1, passes=passes, dokeepmaxima=dokeepmaxima), index=daily_index)

    # Same chain on the masked series (the intermediate variables are reused)
    log.info('{} - running wig on masked input data ...'.format(identifier))
    nooutliers_data    = wig_remove_outliers(S2fapar_masked, maxdip = maxdip, maxdif = maxdif, maxgap = maxgap)
    weighttypes        = smooth.makeweighttypescube(nooutliers_data, aboutequalepsilon)
    weights            = smooth.makesimpleweightscube(weighttypes, weightvalues=weightvalues)
    S2_fapar_wh_masked = pd.Series(data=smooth.whittaker_second_differences(lmbda, nooutliers_data, numpyweightscube=weights, minimumdatavalue=0, maximumdatavalue=1, passes=passes, dokeepmaxima=dokeepmaxima), index=daily_index)

    return {'orig': S2_fapar_wh_orig, 'masked': S2_fapar_wh_masked}

def ts_moving_average(S2fapar, fapar_date_mask, moving_average_window, identifier=''):
    '''
    function to run a moving average on the complete fapar series and on the fapar series in which the requested
    dates are masked. Both series are put on a daily index and passed through smooth.linearinterpolation before
    smooth.movingaverage is applied.
    :param S2fapar: pandas time series of S2 fapar
    :param fapar_date_mask: list of datestrings [yyyy-mm-dd] which will be used to mask the input fapar series
    :param moving_average_window: window in days to use for calculating moving average
    :param identifier: string representing a field identifier, used here only for logging

    :return: dictionary with the original ('orig') and masked ('masked') moving averages as daily series
    '''

    log.info('{} - starting moving average retrieval'.format(identifier))

    # Series without the observations on the masked dates
    masked_input = mask_ts_fapar(S2fapar, fapar_date_mask, identifier=identifier)

    # Put both series on a daily index spanning the input period
    days = pd.date_range(start=S2fapar.index[0], end=S2fapar.index[-1], freq='1D')
    daily_full = S2fapar.reindex(index=days)
    daily_masked = masked_input.reindex(index=days)

    def interpolate_and_average(series):
        interpolated = smooth.linearinterpolation(series)
        return pd.Series(data=smooth.movingaverage(interpolated, moving_average_window), index=days)

    log.info('{} - running moving average on original input data ...'.format(identifier))
    averaged_full = interpolate_and_average(daily_full)

    log.info('{} - running moving average on masked input data ...'.format(identifier))
    averaged_masked = interpolate_and_average(daily_masked)

    return {'orig': averaged_full, 'masked': averaged_masked}

def process_field(root_dir, field, cropsar_model_type, processing_version):
    '''
    function to run all methods (cropsar, WIG emulation, moving average) on one field and to write the
    predictions on the original and on the masked fapar series as csv files into the field folder

    The field is skipped (with a log message) when its 'dates_to_mask.csv' is missing or when its input time
    series cannot be loaded.

    :param root_dir: path to root directory containing the input csv files (see get_input_ts)
    :param field: path to the field folder of the removal scenario; its base name is used as field identifier
                  and the result files are written into it
    :param cropsar_model_type: iterable of CropSAR model types to run
    :param processing_version: version tag of the input files, also used in the names of the result files
    :return: None
    '''
    start = time.time()

    # --------------------------------------------------------------------------------------------------------------
    fieldID = os.path.basename(field)
    log.info('{} - START PROCESSING FIELD'.format(fieldID))

    # Get the list of dates to mask
    mask_file = os.path.join(field, 'dates_to_mask.csv')
    if not os.path.exists(mask_file):
        log.warning('{} - No list of dates to be removed available for this field: should first generate {}'.format(
            fieldID, mask_file))
        return
    fapar_date_mask = pd.read_csv(mask_file, index_col=0,
                                  parse_dates=True).index.strftime('%Y-%m-%d').tolist()

    # Get all the input time series. A field without usable inputs is skipped, but no longer silently:
    # unexpected errors are logged, and KeyboardInterrupt / SystemExit are not swallowed anymore
    try:
        input_ts = get_input_ts(fieldID, root_dir, processing_version=processing_version)
    except Exception:
        log.warning('{} - Input time series could not be loaded -> skipping this field'.format(fieldID),
                    exc_info=True)
        return
    if input_ts is None:
        # get_input_ts has already logged why this field cannot be processed
        return
    (S1ascendingVV, S1ascendingVH, S1ascendingIA, S1descendingVV, S1descendingVH, S1descendingIA,
     S2_fapar_cleaned, S2_fapar_raw) = input_ts

    def write_result(series, filename):
        # Every result is written into the field folder with the field identifier as column header
        series.to_csv(os.path.join(field, filename), header=[fieldID])

    # Key in the result dictionaries -> suffix used in the result file names
    variants = (('orig', 'original'), ('masked', 'masked'))

    # ------------
    # CROPSAR
    # ------------
    for currentModel in cropsar_model_type:
        log.info('{} - working on cropsar model type: {}'.format(fieldID, currentModel))
        cropsar_test_result = ts_cropsar(S1ascendingVV, S1ascendingVH, S1ascendingIA,
                                         S1descendingVV, S1descendingVH, S1descendingIA,
                                         S2_fapar_cleaned, fapar_date_mask, type=currentModel,
                                         identifier=fieldID)
        log.info('{} - Writing cropsar results to csv files ...'.format(fieldID))

        for key, suffix in variants:
            for quantile_series, quantile in zip(cropsar_test_result[key], ('q10', 'q50', 'q90')):
                write_result(quantile_series, 'cropsar_' + currentModel + '_' + processing_version +
                             '_cleaned_' + quantile + '_' + suffix + '.csv')

    # ------------
    # WIG
    # ------------

    wig_test_result_raw = ts_wig(S2_fapar_raw, fapar_date_mask, identifier=fieldID)
    wig_test_result_cleaned = ts_wig(S2_fapar_cleaned, fapar_date_mask, identifier=fieldID)
    log.info('{} - Writing WIG results to csv files ...'.format(fieldID))
    for key, suffix in variants:
        write_result(wig_test_result_cleaned[key], 'wig_' + processing_version + '_cleaned_' + suffix + '.csv')
    for key, suffix in variants:
        write_result(wig_test_result_raw[key], 'wig_raw_' + suffix + '.csv')

    # ------------
    # MOVING AVERAGE
    # ------------

    # Run a 50-day moving average on original and masked fapar time series, for raw and cleaned FAPAR
    moving_average_window = 50
    moving_average_test_result_raw = ts_moving_average(S2_fapar_raw, fapar_date_mask, moving_average_window,
                                                       identifier=fieldID)
    moving_average_test_result_cleaned = ts_moving_average(S2_fapar_cleaned, fapar_date_mask, moving_average_window,
                                                           identifier=fieldID)

    log.info('{} - Writing moving average results to csv files ...'.format(fieldID))
    moving_average_prefix = 'movingaverage' + str(moving_average_window)
    for key, suffix in variants:
        write_result(moving_average_test_result_cleaned[key],
                     moving_average_prefix + '_' + processing_version + '_cleaned_' + suffix + '.csv')
    for key, suffix in variants:
        write_result(moving_average_test_result_raw[key], moving_average_prefix + '_raw_' + suffix + '.csv')

    end = time.time()

    log.info('{} - Field took {} seconds to process ...'.format(fieldID, int(end - start)))

    return

def run_tests_to_csv(root_dir, removal_scenario, cropsar_model_type=None, processing_version='V004',
                     spark=False, debug=False):
    '''
    function to process all fields of a removal scenario (see process_field), either in parallel on spark or
    one by one in the current process

    :param root_dir: path to root directory; the fields of the scenario are the entries of
                     <root_dir>/tests/<removal_scenario>
    :param removal_scenario: name of the scenario folder
    :param cropsar_model_type: CropSAR model type or list of model types to run; by default (None) all types
                               returned by get_available_model_types()
    :param processing_version: version tag of the input files
    :param spark: if True, distribute the fields over spark executors (one partition per field)
    :param debug: if True, only the first field is processed (only applies when spark is False)
    :return: None
    :raises SystemExit: with exit status 1 when the scenario folder does not exist or contains no fields
    '''

    # Setup logging
    logging.basicConfig(level=logging.INFO,
                        format='{asctime} {levelname} {name}: {message}',
                        style='{',
                        datefmt='%Y-%m-%d %H:%M:%S')

    if isinstance(cropsar_model_type, str): cropsar_model_type = [cropsar_model_type]
    if cropsar_model_type is None: cropsar_model_type = get_available_model_types()

    scenario_dir = os.path.join(root_dir, 'tests', removal_scenario)

    # Check if scenario exists; the non-zero exit status signals the failure to the calling process
    if not os.path.isdir(scenario_dir):
        log.error(
            'Removal scenario folder does not exist: {}'.format(scenario_dir))
        sys.exit(1)

    # List the available fields
    fields = glob.glob(os.path.join(scenario_dir, '*'))
    if not fields:
        log.error(
            'No fields found in scenario folder: {}'.format(scenario_dir))
        sys.exit(1)

    log.info('Found {} fields to process ...'.format(len(fields)))

    # Run the fields either in parallel on the MEP, or serial locally
    if spark:
        # The context is only created once there is work to do, and always stopped afterwards, so a failed
        # or repeated run does not leave an active context behind
        sc = SparkContext()
        try:
            log.info('Sending the fields to the executors ...')
            sc.parallelize(fields, len(fields)).foreach(
                lambda field: process_field(root_dir, field, cropsar_model_type, processing_version))
        finally:
            sc.stop()

        # Files created on the mep are not accessible by default on windows
        # (argument list instead of a shell string: paths with spaces or shell characters are handled safely)
        subprocess.call(['chmod', '-R', '777'] + glob.glob(os.path.join(scenario_dir, '*')))

    else:
        log.info('Processing fields locally in serial ...')
        for field in fields:
            process_field(root_dir, field, cropsar_model_type, processing_version)
            if debug: break

    return

def plot_orig_vs_masked(predicted_fapar_orig, predicted_fapar_masked, fapar_orig, fapar_date_mask, identifier):
    '''
    plot function to compare the prediction based on the original fapar series with the prediction based on the
    masked fapar series, with the fapar observations drawn on top

    :param predicted_fapar_orig: predicted fapar time series based on original fapar
    :param predicted_fapar_masked: predicted fapar time series based on masked fapar
    :param fapar_orig: original fapar pandas time series
    :param fapar_date_mask: list of datestring defining the mask
    :param identifier: string to use as title of the plot
    :return: matplotlib figure instance
    '''

    # Observations which remain after masking
    used_obs = mask_ts_fapar(fapar_orig, fapar_date_mask, identifier=identifier)

    figure = plt.figure(figsize=(12, 8))

    # Both predictions as thick lines
    for prediction, legend_label in ((predicted_fapar_orig, 'Predicted on original'),
                                     (predicted_fapar_masked, 'Predicted on masked')):
        plt.plot(prediction, label=legend_label, linewidth=3)

    # S2 FAPAR on top: original values at the positions which are NaN in the masked series (red squares),
    # and the remaining observations (blue dots)
    removed_obs = fapar_orig.loc[used_obs.isnull()]
    plt.plot(removed_obs, 'rs', label='Masked FAPAR obs')
    plt.plot(used_obs, 'bo', label='Used FAPAR obs')

    plt.grid()

    # The x-axis spans the period of the prediction on the original series
    prediction_dates = predicted_fapar_orig.index.tolist()
    plt.xlim(prediction_dates[0], prediction_dates[-1])
    plt.ylim(0, 1)
    plt.ylabel('FAPAR')

    plt.legend()
    plt.tight_layout()
    plt.title(identifier)

    return figure

def example_cropsar():
    '''
    example: compare the CropSAR (RNN) q50 predictions on the complete and on a randomly thinned fapar series
    of one hard-coded field, and show the result in a plot
    '''

    data_root = r'O:\data\ref\field_selection\test_fields_sample\2017_TEST'
    field_name = '000028044A31C991'

    # Load the inputs: the six S1 series come first, then the cleaned and the raw S2 fapar (raw is not used)
    *s1_series, cleaned_fapar, _ = get_input_ts(field_name, data_root)

    # Randomly select 40% of the dates in the fapar index; the fixed seed keeps the selection reproducible
    np.random.seed(3)
    candidate_dates = cleaned_fapar.index.strftime('%Y-%m-%d').tolist()
    nr_to_mask = int(0.4*len(candidate_dates))
    dates_to_mask = list(np.random.choice(candidate_dates, size=nr_to_mask, replace=False))

    # Predict with the complete and with the masked fapar series
    predictions = ts_cropsar(*s1_series, cleaned_fapar, dates_to_mask,
                             identifier=field_name, type='RNN')

    # Plot the q50 prediction (second element) of both runs
    plot_orig_vs_masked(predictions['orig'][1], predictions['masked'][1],
                        cleaned_fapar, dates_to_mask, 'CropSAR ' + field_name)

    plt.show()

def example_moving_average():
    '''
    example: compare a 90-day moving average on the complete and on a randomly thinned fapar series of one
    hard-coded field, and show the result in a plot
    '''

    data_root = r'O:\data\ref\field_selection\test_fields_sample\2017_TEST'
    field_name = '000028044A31C991'

    # Only the cleaned S2 fapar (second to last series returned) is needed here
    *_, cleaned_fapar, _ = get_input_ts(field_name, data_root)

    # Randomly select 40% of the dates in the fapar index; the fixed seed keeps the selection reproducible
    np.random.seed(3)
    candidate_dates = cleaned_fapar.index.strftime('%Y-%m-%d').tolist()
    nr_to_mask = int(0.4*len(candidate_dates))
    dates_to_mask = list(np.random.choice(candidate_dates, size=nr_to_mask, replace=False))

    # 90-day moving average on the complete and on the masked fapar series
    averages = ts_moving_average(cleaned_fapar, dates_to_mask, 90)

    # Plot both results together with the observations
    plot_orig_vs_masked(averages['orig'], averages['masked'],
                        cleaned_fapar, dates_to_mask, '90-day moving average ' + field_name)

    plt.show()

def example_whittaker():
    '''
    example: compare an (unweighted) 2nd order whittaker smoothing on the complete and on a randomly thinned
    fapar series of one hard-coded field, and show the result in a plot
    '''

    data_root = r'O:\data\ref\field_selection\test_fields_sample\2017_TEST'
    field_name = '000028044A31C991'

    # Only the cleaned S2 fapar (second to last series returned) is needed here
    *_, cleaned_fapar, _ = get_input_ts(field_name, data_root)

    # Randomly select 40% of the dates in the fapar index; the fixed seed keeps the selection reproducible
    np.random.seed(3)
    candidate_dates = cleaned_fapar.index.strftime('%Y-%m-%d').tolist()
    nr_to_mask = int(0.4*len(candidate_dates))
    dates_to_mask = list(np.random.choice(candidate_dates, size=nr_to_mask, replace=False))

    # Whittaker settings, also reported in the plot title
    lmbda, passes = 1, 3
    smoothed = ts_whittaker(cleaned_fapar, dates_to_mask, lmbda=lmbda, passes=passes)

    # Plot both results together with the observations
    plot_title = 'whittaker lambda(%s), passes(%s) %s' % (lmbda, passes, field_name)
    plot_orig_vs_masked(smoothed['orig'], smoothed['masked'], cleaned_fapar, dates_to_mask, plot_title)

    plt.show()

def example_wig():
    '''
    example: compare the WIG emulation on the complete and on a randomly thinned RAW fapar series of one
    hard-coded field, and show the result in a plot
    '''

    data_root = r'O:\tmp\validation\data\example_1\wicappdata'
    field_name = '0000280464E91831'

    # Only the raw S2 fapar (last series returned) is needed here
    *_, raw_fapar = get_input_ts(field_name, data_root)

    # Randomly select 40% of the dates in the RAW fapar index; the fixed seed keeps the selection reproducible
    np.random.seed(3)
    candidate_dates = raw_fapar.index.strftime('%Y-%m-%d').tolist()
    nr_to_mask = int(0.4*len(candidate_dates))
    dates_to_mask = list(np.random.choice(candidate_dates, size=nr_to_mask, replace=False))

    # WIG emulation on the complete and on the masked fapar series
    wig_result = ts_wig(raw_fapar, dates_to_mask)

    # Plot both results together with the observations
    plot_orig_vs_masked(wig_result['orig'], wig_result['masked'],
                        raw_fapar, dates_to_mask, 'WIG-alike %s' % (field_name,))

    plt.show()

def example_run_tests_to_csv():
    '''
    example: process the 's01_no_sos' removal scenario locally in debug mode (only the first field is processed)
    '''
    run_tests_to_csv(root_dir='/data/CropSAR/data/ref/field_selection/test_fields_sample/2017_TEST/',
                     removal_scenario='s01_no_sos',
                     debug=True)


if __name__ == '__main__':

    # When executed as a script, run the plotting examples one after the other
    for run_example in (example_cropsar, example_moving_average, example_whittaker, example_wig):
        run_example()
    #example_run_tests_to_csv()