from pyspark.sql.types import *
from cropsar.readers import *
from cropsar.preprocessing.utils import *
import configparser
import xarray
import warnings

# Module-level logger shared by the functions in this file.
# NOTE(review): `logging` is not imported explicitly in this file; presumably it
# is re-exported by one of the wildcard imports above - confirm.
log = logging.getLogger(__name__)


def parse_timeseries(fieldID, inputData, s2_variable):
    """
    Extract the S1 and S2 series of one field as column vectors.

    :param fieldID: key selecting the field's column in every input table
    :param inputData: mapping with an 'S1' entry (handed to ``combine_s1_orbits``)
                      and an 'S2' entry that is indexed by ``s2_variable``
    :param s2_variable: name of the S2 layer to extract
    :return: tuple ``(s1_vv, s1_vh, s1_angle, s2_data)`` in which every array got
             an extra trailing axis of length 1, or ``None`` when the S2 series
             holds fewer than 4 finite values
    """

    # Combine S1 orbits
    S1inputData = combine_s1_orbits(inputData['S1'], fieldID)

    def as_column(table):
        # Values of this field with a trailing axis of length 1 appended
        return np.expand_dims(table[fieldID].values, axis=1)

    # Each polarisation is read from its own key (VV from 'VV', VH from 'VH')
    s1_vv = as_column(S1inputData['VV'])
    s1_vh = as_column(S1inputData['VH'])
    s1_angle = as_column(S1inputData['incidenceAngle'])
    s2_data = as_column(inputData['S2'][s2_variable])

    # Check if we have enough valid S2 data in the time series.
    nr_valid_s2 = np.sum(np.isfinite(s2_data.ravel().astype(float)))
    if nr_valid_s2 < 4:
        return None

    return (s1_vv, s1_vh, s1_angle, s2_data)


def convert_row_RNNfull(tuple):
    """
    Turn a ``(field_id, croptype, series)`` record into a ``Row``.

    ``series`` is indexed at positions 0..3 (VV, VH, incidence angle, S2) and
    every entry is converted with ``.tolist()`` before being stored in the Row.

    :param tuple: indexable record; element 2 holds the four series
    :return: Row with fields field_id, croptype, s1_vv, s1_vh, s1_angle, s2_data
    """
    series = tuple[2]
    s1_vv, s1_vh, s1_angle, s2_data = (
        series[position].tolist() for position in range(4))
    return Row(
        field_id=tuple[0],
        croptype=tuple[1],
        s1_vv=s1_vv,
        s1_vh=s1_vh,
        s1_angle=s1_angle,
        s2_data=s2_data,
    )


def datadict_to_tuple(datadict, field, S2layername='FAPAR'):
    '''
    Helper function to transform an input dictionary to a tuple that can be used in "convert_inputs"

    :param datadict: dictionary with an 'S1' entry (keys 'VV', 'VH', 'incidenceAngle')
                     and an 'S2' entry (key ``S2layername``); every value is indexed
                     by ``field`` and must expose ``.values``
    :param field: identifier of the field to extract
    :param S2layername: key of the S2 layer inside ``datadict['S2']``
    :return: ``((vv, vh, angle, s2), (field,))`` - the four series reshaped to
             ``(-1, n)`` with ``n`` the length of the VV series, followed by a
             1-tuple holding the field identifier
    '''
    s1 = datadict['S1']

    # All series are reshaped against the length of the VV series, so a series
    # whose length is not a multiple of it fails loudly in reshape().
    nsteps = len(s1['VV'][field].values)

    def as_row(table):
        return table[field].values.reshape((-1, nsteps))

    series = (
        as_row(s1['VV']),
        as_row(s1['VH']),
        as_row(s1['incidenceAngle']),
        as_row(datadict['S2'][S2layername]),
    )

    # The trailing comma matters: "convert_inputs" reads data[1][0], which must
    # be the whole identifier and not its first character.
    return (series, (field,))


def getFullDataStack_RNNfull(fieldID, inputData,
                             S2layername='FAPAR', useAfter=True, endDate=None, S1var='gamma'):
    """
    Generate dataStacks as full time series for RNNfull input, APPLICATION VERSION

    The caller's ``inputData`` is never modified: orbit combination and the
    optional cut-off are applied to copies.

    :param fieldID: identifier of the field to process
    :param inputData: dictionary with 'S1' and 'S2' entries, each mapping a variable
                      name to a table holding one column per field
    :param useAfter: whether or not to use available observations past specified endDate; for gap-filling application,
                    this can be useful; for testing purposes, we want to end the time series and not use future observations
    :param endDate: end date of the returned dataStack; required when useAfter is False
    :param S2layername: layername of S2 data [FAPAR, FCOVER]
    :param S1var: which variable the backscatter represents [gamma, sigma]
    :return: tuple ``(inputs, outputs)``; ``inputs`` is the S1 and S2 input arrays of
             "convert_inputs" concatenated along axis 2, ``outputs`` is its list of
             output arrays
    :raises ValueError: if useAfter is False while no endDate is given
    """

    # Without an end date the slice below would be ``loc[None:]``, which spans
    # every row and would silently blank the complete time series.
    if not useAfter and endDate is None:
        raise ValueError(
            'An "endDate" is required when "useAfter" is disabled.')

    # Shallow copy, so rebinding 'S1'/'S2' below stays local to this call
    stackData = dict(inputData)

    # If Sentinel-1 time series are delivered in separate orbit passes, we first need to combine them
    # todo: currently, the training data is still provided as separate orbits, so for now, we need this for compatibility
    if 'ASCENDING' in stackData['S1'].keys():
        stackData['S1'] = combine_s1_orbits(stackData['S1'], fieldID)

    if not useAfter:
        log.warning(
            '"useAfter" parameter disabled: cutting off data after {} ...'.format(endDate))
        for sensor in ('S1', 'S2'):
            trimmed = {}
            for variable in stackData[sensor].keys():
                # Blank a copy; the label-based slice starts at endDate itself
                series = stackData[sensor][variable].copy()
                series.loc[endDate:] = np.nan
                trimmed[variable] = series
            stackData[sensor] = trimmed

    # Transform the input dictionary to a tuple to use in "convert_inputs"
    inputtuple = datadict_to_tuple(stackData, fieldID, S2layername)

    # Transform the inputs to a stack that RNNfull can use
    inputs, outputs, _ = convert_inputs(
        inputtuple, S2layername=S2layername, S1var=S1var)

    log.info('Field {} fully processed!'.format(fieldID))

    # Stack the S1 and S2 inputs along the last (feature) axis
    inputs = np.concatenate([inputs[0], inputs[1]], axis=2)

    return (inputs, outputs)


def convert_inputs(data, removalratios=None, cutoff=False, training=False, tfdataset=False, addnoise=False,
                   forecast=False, S2layername='FAPAR', S1var='gamma'):
    '''
    Actually the core function of RNNfull, as it transforms the original input series into the format that the
    RNNfull network accepts.

    The S1 series are converted with ``to_linear``, optionally turned into gamma0,
    linearly interpolated, smoothed with a centered 13-step moving average and
    converted back with ``to_db``. S1 and S2 are then scaled with ``minmaxscaler``
    and all remaining NaN values are replaced by -2.
    NOTE(review): ``to_linear``, ``to_db``, ``sigma_to_gamma`` and ``minmaxscaler``
    are defined outside this file; presumably dB/linear conversions and a fixed-range
    scaler - confirm.

    :param data: nested sequence; ``data[0]`` holds the VV, VH, incidence angle and S2
                 arrays (in that order) and ``data[1][0]`` the field identifier(s).
                 With ``tfdataset`` every element must expose ``.numpy()``.
    :param removalratios: float or sequence of floats; one ratio is drawn at random and
                          that fraction of the valid S2 inputs is set to NaN
    :param cutoff: if True, keep a random window per row and right-align it; needs a
                   time series longer than 560 steps
    :param training: if True, blank the last valid S2 input for half of the rows and
                     triple the weight of that position
    :param tfdataset: if True, the elements of ``data`` are converted with ``.numpy()``
    :param addnoise: if True, add Gaussian noise (scale 0.01) to the valid S2 inputs
    :param forecast: if True, mask the tail (up to 150 steps) of the inputs with -2 for
                     about 20% of the rows
    :param S2layername: S2 layer name, used to select the scaler ('s2_<name>')
    :param S1var: 'sigma' to convert the backscatter to gamma0 first; any other value
                  uses the backscatter as is
    :return: tuple of three lists: ``[s1_inputs, s2_inputs]``, the output array
             repeated three times and the weights array repeated three times
    '''

    # isinstance (instead of an exact type comparison) also accepts float
    # subclasses such as numpy.float64
    if isinstance(removalratios, float):
        removalratios = [removalratios]

    batchsize = data[0][0].shape[0] if not tfdataset else data[0][0].numpy(
    ).shape[0]

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', r'Mean of empty slice')

        if tfdataset:

            # Convert to xarray
            s1_inputs_dict = {
                's1_vv': (['fieldID', 'time'], to_linear(data[0][0].numpy().reshape((batchsize, -1)))),
                's1_vh': (['fieldID', 'time'], to_linear(data[0][1].numpy().reshape((batchsize, -1)))),
                's1_angle': (['fieldID', 'time'], data[0][2].numpy().reshape((batchsize, -1)))
            }

            fieldIDs = data[1][0].numpy().astype(str)
            s1_inputs = xarray.Dataset(
                s1_inputs_dict, coords={'fieldID': fieldIDs})
            s2_inputs = xarray.Dataset({'s2_data': (['fieldID', 'time'], data[0][3].numpy().reshape((batchsize, -1)))},
                                       coords={'fieldID': fieldIDs})

        else:
            # Convert to xarray
            s1_inputs_dict = {
                's1_vv': (['fieldID', 'time'], to_linear(data[0][0]).reshape((batchsize, -1))),
                's1_vh': (['fieldID', 'time'], to_linear(data[0][1]).reshape((batchsize, -1))),
                's1_angle': (['fieldID', 'time'], data[0][2].reshape((batchsize, -1)))
            }

            # A single identifier is wrapped, so this branch handles one field at a time
            fieldIDs = np.array([data[1][0]]).astype(str)
            s1_inputs = xarray.Dataset(
                s1_inputs_dict, coords={'fieldID': fieldIDs})
            s2_inputs = xarray.Dataset({'s2_data': (['fieldID', 'time'], data[0][3].reshape((batchsize, -1)))},
                                       coords={'fieldID': fieldIDs})

        # Calculate gamma0
        if S1var == 'sigma':
            s1_inputs_gamma = xarray.Dataset({
                's1_vv_gamma': (['fieldID', 'time'], sigma_to_gamma(s1_inputs['s1_vv'], s1_inputs['s1_angle'])),
                's1_vh_gamma': (['fieldID', 'time'], sigma_to_gamma(s1_inputs['s1_vh'], s1_inputs['s1_angle']))
            }, coords={'fieldID': fieldIDs})
        else:
            s1_inputs_gamma = xarray.Dataset({
                's1_vv_gamma': (['fieldID', 'time'], s1_inputs['s1_vv'].data),
                's1_vh_gamma': (['fieldID', 'time'], s1_inputs['s1_vh'].data)
            }, coords={'fieldID': fieldIDs})

        # Linear interpolate NaN values
        s1_inputs_gamma_interp = s1_inputs_gamma.interpolate_na(
            'time', method='linear')

        # Calculate moving average
        ma_window = 13
        s1_inputs_gamma_smooth = s1_inputs_gamma_interp.rolling(
            {'time': ma_window}, center=True, min_periods=1).mean()

        # S1 back to db
        s1_inputs_gamma_smooth = to_db(s1_inputs_gamma_smooth)

        # Convert to arrays of shape (fieldID, time, variable)
        s1array = s1_inputs_gamma_smooth[[
            's1_vv_gamma', 's1_vh_gamma']].to_array().data.transpose((1, 2, 0))
        s2array = s2_inputs.to_array().data.transpose((1, 2, 0))

        # Scale the input data
        s1arrayscaled = np.empty_like(s1array)
        s2arrayscaled = np.empty_like(s2array)
        s1arrayscaled[:, :, 0] = minmaxscaler(s1array[:, :, 0], 's1_vv_smooth')
        s1arrayscaled[:, :, 1] = minmaxscaler(s1array[:, :, 1], 's1_vh_smooth')
        s2arrayscaled[:, :, 0] = minmaxscaler(
            s2array[:, :, 0], 's2_{}'.format(S2layername.lower()))

        # Create a copy of S2 to be the outputs
        outputarrayscaled = np.copy(s2arrayscaled)

        # Find valid observations
        idxvalid_x, idxvalid_y, _ = np.where(np.isfinite(outputarrayscaled))

        # Remove some valid obs for training (inputs only, outputs keep them)
        if removalratios is not None:
            currentremovalratio = np.random.choice(removalratios)
            removalidx = np.random.choice(np.arange(len(idxvalid_x)),
                                          size=int(currentremovalratio * len(idxvalid_x)), replace=False)
            s2arrayscaled[idxvalid_x[removalidx],
                          idxvalid_y[removalidx], :] = np.nan

        # Add some noise to input S2 during training
        if addnoise:
            noise = np.random.normal(
                scale=0.01, size=s2arrayscaled.size).reshape(s2arrayscaled.shape)
            s2arrayscaled = s2arrayscaled + np.isfinite(s2arrayscaled) * noise

        TSlength = s1arrayscaled.shape[1]

        if cutoff:
            # Start/Stop the time series at a random moment
            maxStart = 500
            minLength = 60

            tsStart = np.random.choice(
                np.arange(0, maxStart), size=s1arrayscaled.shape[0])
            tsStop = np.random.choice(
                np.arange(maxStart + minLength, TSlength), size=s1arrayscaled.shape[0])
            tsSize = tsStop - tsStart

            s1arrayscaledCut = np.zeros_like(s1arrayscaled)*np.nan
            s2arrayscaledCut = np.zeros_like(s2arrayscaled)*np.nan
            outputarrayscaledCut = np.zeros_like(outputarrayscaled)*np.nan

            # m selects the [tsStart, tsStop) window of every row, m2 the same
            # number of positions at the end of the row: the window is right-aligned
            r = np.arange(s1arrayscaled.shape[1])
            m = (tsStart[:, None] <= r) & (tsStop[:, None] > r)
            s = s1arrayscaled.shape[1] - tsSize
            m2 = s[:, None] <= r
            s1arrayscaledCut[m2] = s1arrayscaled[m]
            s2arrayscaledCut[m2] = s2arrayscaled[m]
            outputarrayscaledCut[m2] = outputarrayscaled[m]

        else:
            s1arrayscaledCut = s1arrayscaled
            s2arrayscaledCut = s2arrayscaled
            outputarrayscaledCut = outputarrayscaled

        if training:
            # In some cases, remove the last valid fapar input, so we're sure we train well on predicting that value
            validoutput = np.isfinite(outputarrayscaledCut)
            # Position of the last valid output per row (argmax on the reversed time axis)
            idx_y = np.array(
                validoutput.shape[1] - validoutput[:, ::-1].argmax(1) - 1).ravel()
            idx_x = np.arange(batchsize)
            removalchance = 0.5
            removalidx = np.random.choice(idx_x, size=int(
                removalchance*batchsize), replace=False)
            s2arrayscaledCut[idx_x[removalidx], idx_y[removalidx], :] = np.nan

        # Determine the weights
        # weights are determined based on the relative amount of valid observations in a row
        nrvalidobs = np.sum(np.isfinite(outputarrayscaledCut),
                            axis=1).reshape((batchsize, -1))
        # Rows without a single valid observation get a dummy count of 1 to avoid
        # dividing by zero; their weights are zeroed by the valid mask below anyway
        weightsperrow = np.array(
            outputarrayscaledCut.shape[1]/np.maximum(nrvalidobs, 1), dtype=int)
        weightsperrowrepeated = np.repeat(np.expand_dims(weightsperrow, axis=1),
                                          repeats=outputarrayscaledCut.shape[1], axis=1).reshape((batchsize, -1))
        validmask = np.isfinite(outputarrayscaledCut.reshape((batchsize, -1)))
        weightsarray = validmask * weightsperrowrepeated

        if training:
            # Triple the weights on the removed last valid fapars
            weightsarray[idx_x[removalidx], idx_y[removalidx]
                         ] = weightsarray[idx_x[removalidx], idx_y[removalidx]] * 3

        # Replace the NaN values of the outputs by the -2 fill value
        outputarrayscaledCut[np.isnan(outputarrayscaledCut)] = -2.

        # Final final step: add forecast functionality
        if forecast:
            r = np.arange(s1arrayscaled.shape[1])
            inputstop = np.random.randint(TSlength - 30 * 5, TSlength,
                                          size=batchsize)  # random up to five months prediction
            inputstop[np.random.choice(np.arange(batchsize), size=int(0.80 * batchsize),
                                       replace=False)] = TSlength  # In most of the cases we don't do this
            m = inputstop[:, None] < r
            s1arrayscaledCut[m] = -2.
            s2arrayscaledCut[m] = -2.

        # Replace the NaN values of the inputs by the -2 fill value
        s2arrayscaledCut[np.isnan(s2arrayscaledCut)] = -2.
        s1arrayscaledCut[np.isnan(s1arrayscaledCut)] = -2.

    return ([s1arrayscaledCut, s2arrayscaledCut],
            [outputarrayscaledCut, outputarrayscaledCut, outputarrayscaledCut],
            [weightsarray, weightsarray, weightsarray])


def compute_and_write_tfrecords(field_subset, timeseries_df, output_basename,
                                nroutputfiles=500, s2_variable='FAPAR'):
    """
    Join a field subset with the time series and write the result as TFRecords.

    :param field_subset: table of fields; its index is reset before it is turned
                         into a Spark DataFrame that is joined on ``fieldID``
    :param timeseries_df: Spark DataFrame with the time series, joined on ``name``
    :param output_basename: target location handed to the Spark writer
    :param nroutputfiles: number of partitions used before processing and before writing
    :param s2_variable: name of the S2 layer to extract
    :return: No result
    """
    spark = SparkSession.builder.getOrCreate()
    fields_df = spark.createDataFrame(field_subset.reset_index())

    # Keep only the time series of the requested fields
    same_field = fields_df.fieldID == timeseries_df.name
    joined = fields_df.join(
        timeseries_df, same_field, 'inner').drop(fields_df.fieldID)

    # Repartition in partitions (it's a lot of data)
    partitioned = joined.repartition(nroutputfiles)

    def to_parsed_record(row):
        # (name, croptype, parsed series or None)
        row_dict = df_row_to_dict(row.name, row, s2_variable)
        parsed = parse_timeseries(row.name, row_dict, s2_variable=s2_variable)
        return (row.name, row.croptype, parsed)

    def has_series(record):
        # parse_timeseries returns None for fields without enough valid S2 data
        return record[2] is not None

    # Generate Inputs
    rows_rdd = (
        partitioned.rdd
        .map(to_parsed_record)
        .filter(has_series)
        .map(convert_row_RNNfull)
    )
    output_df = rows_rdd.toDF()

    # Write to  output files
    writer = output_df.repartition(nroutputfiles).write
    writer = writer.format("tfrecords").mode('overwrite')
    writer = writer.option("recordType", "SequenceExample")
    writer = writer.option("codec", "org.apache.hadoop.io.compress.GzipCodec")
    writer.save(output_basename)


def _strip_file_scheme(path):
    """Return ``path`` without a leading ``file://`` URI scheme, if it has one."""
    scheme = 'file://'
    return path[len(scheme):] if path.startswith(scheme) else path


def main_parquet(indir='/data/CropSAR/tmp/dataStackGenerationSPARK/data/parquetFiles/S1GEE_S2CLEANED_V005/',
                 outDir='/data/CropSAR/tmp/dataStackGenerationSPARK/data/dataStacks/S1GEE_S2CLEANED_V005/RNNfull/',
                 s2_path=None, year='2018', minArea=None, S2layername='FAPAR', logFile=None, nroutputfiles=500):
    """
    Convert timeseries parquet files into a datastack that can be used for training the RNNfull.
    Calibration, validation and test stacks are written next to an ini file
    that records the settings and the number of fields per set.

    :param indir:  Input directory containing Parquet files
    :param outDir: Output directory where to write
    :param s2_path: optional S2 location, forwarded to ``read_full_timeseries_input``
    :param year: The year for which to generate data stacks
    :param minArea: if specified, this is the minimum area in m² for a parcel to be included in the datastack
    :param S2layername: layername of S2 data [FAPAR, FCOVER]
    :param logFile: filename handed to ``get_logger``
    :param nroutputfiles: number of partitions used when writing every set
    :return: No result
    """

    # Make sure pathlib.Path objects are converted to strings if needed
    indir = str(indir)
    outDir = str(outDir)
    s2_path = str(s2_path) if s2_path is not None else None

    year = str(year)

    # Local filesystem calls need the path without the "file://" scheme;
    # the original strings are kept for the Spark readers and writers.
    localOutDir = _strip_file_scheme(outDir)

    # Initiate ini with settings
    config = configparser.ConfigParser()
    config['Data'] = {
        'inputDir': _strip_file_scheme(indir),
        'outputDir': localOutDir,
        'year': year,
    }

    # Create outdir if it doesn't exist
    os.makedirs(localOutDir, exist_ok=True)

    # Initiate spark (the context itself is not used directly in this function)
    sc = get_spark_context()

    # Initiate logging
    log = get_logger(name="cropsar_datastacks", filename=logFile)

    log.info('Creating datastacks for the year: {}'.format(year))
    log.info('DATASTACK TYPE: RNNFULL')

    fields = pd.read_parquet(os.path.join(
        indir, 'croptypes_calvaltest.parquet'))
    CAL_fields = fields[fields['CALVALTEST'] == 'CAL']
    VAL_fields = fields[fields['CALVALTEST'] == 'VAL']
    TEST_fields = fields[fields['CALVALTEST'] == 'TEST']

    # Subset the fields on minimum area if requested
    if minArea:
        log.info('Subsetting fields on minimum area of {} m² ...'.format(minArea))
        CAL_fields = CAL_fields.loc[CAL_fields['area'] >= minArea]
        VAL_fields = VAL_fields.loc[VAL_fields['area'] >= minArea]
        TEST_fields = TEST_fields.loc[TEST_fields['area'] >= minArea]

    log.info('Nr of CAL fields: {}'.format(len(CAL_fields)))
    log.info('Nr of VAL fields: {}'.format(len(VAL_fields)))
    log.info('Nr of TEST fields: {}'.format(len(TEST_fields)))

    config['NrFields'] = {
        'CAL': len(CAL_fields),
        'VAL': len(VAL_fields),
        'TEST': len(TEST_fields)
    }

    # Read all of the timeseries input data, as a PySpark Dataframe
    log.info('Reading parquet data...')
    timeseries_df = read_full_timeseries_input(
        indir, s2_path=s2_path, S2layername=S2layername)

    log.info('Done reading data.')
    output_basename = os.path.join(outDir, year + '_data_')
    log.info('Basename: {}'.format(output_basename))

    # Save the ini file
    configFile = _strip_file_scheme(output_basename) + '.ini'
    if os.path.exists(configFile):
        os.remove(configFile)
    with open(configFile, 'w') as f:
        config.write(f)
    log.info('config.ini file saved to: {}'.format(configFile))

    # (label used in the log message, basename suffix, fields of the set)
    datasets = (
        ('calibration', 'CAL_', CAL_fields),
        ('validation', 'VAL_', VAL_fields),
        ('test', 'TEST_', TEST_fields),
    )

    timeseries_df.persist()
    try:
        for label, suffix, field_subset in datasets:
            log.info('-' * 75)
            log.info('Creating {} timeseries.'.format(label))
            compute_and_write_tfrecords(field_subset, timeseries_df, output_basename + suffix,
                                        nroutputfiles=nroutputfiles, s2_variable=S2layername)
    finally:
        # Release the cached data even when one of the writes fails
        timeseries_df.unpersist()

    log.info('-' * 80)
    log.info('DATASTACK GENERATION PROGRAM SUCCESSFULLY FINISHED!')
    log.info('-' * 80)


# Script entry point: only runs when the file is executed directly.
if __name__ == '__main__':

    # `fire` is imported here, so it is only needed for command-line use;
    # main_parquet is handed to fire.Fire, which builds the command line from it.
    import fire
    fire.Fire(main_parquet)
