# -*- coding: utf-8 -*-
"""
Created on Thursday 14 December 2017
Description: This script makes use of the Google Earth Engine API to extract time series of Sentinel-1 backscatter,
averaged over a collection of polygons (1 TS per polygon). Each Sentinel-1 image is exported as a separate CSV to
Google Drive; the downloaded CSVs can then be combined into pickled pandas DataFrames.
@author: Dr. Kristof Van Tricht (VITO)
"""
import glob, os
import ee
import datetime
import logging
import pandas as pd
import numpy as np
import _pickle as pickle
from cropsar.preprocessing.utils import to_db, to_linear

log = logging.getLogger(__name__)

# Function to shrink a feature's geometry inward (10 m) while keeping its properties
def bufferFeature(feature):
    """Return a copy of `feature` whose geometry is buffered 10 m inward.

    A negative buffer distance shrinks the geometry (GEE buffer distances are in
    metres). All properties of the original feature are copied onto the result.
    """
    return ee.Feature(feature.geometry().buffer(-10), feature.toDictionary())

# Mask out image edges that contain corrupt values
def maskEdge(img):
    """Return `img` with its corrupt border pixels masked.

    The first band is stretched from [-25, 5] to a byte image, and connected regions
    of equal value are labelled with a 3x3 neighbourhood and a maximum object size
    of 120 pixels. The label band is then used as the new mask. (Per the GEE docs,
    objects larger than maxSize are treated as background and masked.)
    """
    as_bytes = img.select(0).unitScale(-25, 5).multiply(255).toByte()
    neighbourhood = ee.Kernel.rectangle(1, 1)
    labelled = as_bytes.connectedComponents(neighbourhood, 120)
    return img.updateMask(labelled.select(0))

# Convert the dB backscatter bands of an image to linear (natural) values
def toNaturalGEE(img):
    """Return an image with VV/VH converted from dB to linear values.

    The result has bands ['VV', 'VH', 'angle']: VV and VH become 10 ** (dB / 10),
    and 'angle' is copied unchanged. Only 'system:time_start' and 'system:time_end'
    are copied from the input's properties.
    """
    time_properties = ['system:time_start', 'system:time_end']
    linear = ee.Image(10.0).pow(img.select(['VV', 'VH']).divide(10.0))
    linear = ee.Image(linear.copyProperties(img, time_properties))
    stacked = linear.addBands(img.select('angle'))
    return stacked.rename(['VV', 'VH', 'angle'])

# Function to add the acquisition date/time of an image as 'date' property
def addDate(img):
    """Return `img` with its acquisition time stored in a 'date' property.

    The value is formatted server-side with the Joda pattern 'Y-M-d H:m:s'. This
    includes the time of day, and fields are not zero-padded (e.g. '2017-1-5 6:3:2').
    """
    return img.set('date', img.date().format('Y-M-d H:m:s'))

# Return the acquisition date/time of an image (e.g. for mapping over an ee.List)
def getDate(item):
    """Return the acquisition time of `item` as a server-side formatted string.

    `item` is anything ee.Image accepts. The Joda pattern 'Y-M-d H:m:s' is used,
    so fields are not zero-padded.
    """
    acquisition = ee.Image(item).date()
    return acquisition.format('Y-M-d H:m:s')

# Return the relative orbit number of an image (e.g. for mapping over an ee.List)
def getRO(item):
    """Return the 'relativeOrbitNumber_start' property of `item` (a server-side value)."""
    image = ee.Image(item)
    return image.get('relativeOrbitNumber_start')

# Return the platform number (A/B) of an image
def getPlatformNumber(item):
    """Return the 'platform_number' property of `item` (a server-side value)."""
    image = ee.Image(item)
    return image.get('platform_number')


def gee_s1_extract(start_date, end_date, localOutDir, fc, identifier, outpattern):
    '''
    Main function to submit jobs to GEE for polygon-averaged extraction of S1 time series.

    For every Sentinel-1 IW image (ascending and descending passes) that intersects the
    Flanders region and carries both VV and VH, one export task is started. Each task writes
    the polygon means of linear VV, VH and the incidence angle to a CSV in the Google Drive
    folder 'GEE_TimeSeries'. Images whose CSV is already present in `localOutDir` are skipped.
    Any error propagates to the caller; there is no automatic retry, because re-running would
    resubmit every task whose CSV has not been downloaded yet.

    :param start_date: desired start date of TS (yyyy-mm-dd)
    :param end_date: desired end date of TS (yyyy-mm-dd)
    :param localOutDir: local directory where TS are stored (used for checking if file can be skipped)
    :param fc: remote ID of the featurecollection with polygons
    :param identifier: attribute name in the feature collection to be used for identifying the polygons
    :param outpattern: pattern to be used in the output files
    :raises SystemExit: if `identifier` is not a property of the first feature of `fc`
    :return: None
    '''

    # Initialize the Earth Engine object, using the authentication credentials.
    log.info('Initializing Earth Engine API...')
    ee.Initialize()

    # Define the FC that holds the area of interest
    region = ee.FeatureCollection("users/geesigma/SHP/FlandersShapefile")

    # Every submitted export task
    all_tasks = []

    sentinel1 = ee.ImageCollection("COPERNICUS/S1_GRD")
    features = ee.FeatureCollection(fc)

    # First check if the identifier exists as feature property, otherwise we'll process a lot for nothing!
    existingproperties = features.first().propertyNames().getInfo()
    if identifier not in existingproperties:
        import sys
        sys.exit('Property "{}" not found in featureCollection -> cannot continue!'.format(identifier))

    # Check how many features we need to process
    nr_features = features.size().getInfo()
    log.info('Backscatter for {} features will be retrieved ...'.format(nr_features))

    # Only the identifier property needs to end up in the exported tables
    id_features = features.select([identifier])
    epoch = datetime.datetime(1970, 1, 1)

    for mode in ['ASCENDING', 'DESCENDING']:
        log.info('Extracting Sentinel-1 data in %s mode' % (mode))
        # Filter S1 by metadata properties.
        sentinel1_filtered = (sentinel1.filterBounds(region)
                              .filterDate(start_date, end_date)
                              .filter(ee.Filter.eq('orbitProperties_pass', mode))
                              .filter(ee.Filter.eq('instrumentMode', 'IW'))
                              .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
                              .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH')))

        # A single getInfo() returns the id and metadata of every matching image,
        # so no further per-image round trips to the server are needed below.
        image_infos = sentinel1_filtered.getInfo()['features']
        log.info('{} Sentinel-1 images match the request ...'.format(len(image_infos)))

        for image_info in image_infos:
            properties = image_info['properties']
            current_sentinel_img = ee.Image(str(image_info['id']))
            current_relative_orbit_number = int(properties['relativeOrbitNumber_start'])
            current_platform_number = properties['platform_number']
            # system:time_start is milliseconds since the Unix epoch (UTC); sub-second
            # precision is dropped by the format below.
            acquisition = epoch + datetime.timedelta(milliseconds=properties['system:time_start'])
            current_date = acquisition.strftime('%Y-%m-%d_%H%M%S')

            taskname = ('S1' + str(current_platform_number) + '_' + current_date + '_' + mode + '_RO'
                        + str(current_relative_orbit_number) + '_' + outpattern)

            # Check if output exists
            if glob.glob(os.path.join(localOutDir, '*' + taskname + '*.csv')):
                log.info('{} exists -> skipping'.format(taskname))
                continue

            # The actual function: a reduceRegions on the original (nonDB values)
            results = ee.Image(toNaturalGEE(current_sentinel_img)).reduceRegions(
                collection=id_features.filterBounds(current_sentinel_img.geometry()),
                reducer=ee.Reducer.mean(),
                scale=10)

            log.info('Starting new task: ' + taskname)

            # NOTE(review): Export.table(collection, description, config) is the legacy
            # export signature; newer ee clients expose Export.table.toDrive instead.
            task = ee.batch.Export.table(results, taskname,
                                         {'driveFolder': 'GEE_TimeSeries',
                                          'driveFileNamePrefix': taskname,
                                          'selectors': ','.join([identifier, 'VV', 'VH', 'angle']),
                                          'fileFormat': 'CSV'})
            all_tasks.append(task)
            task.start()

            log.info('%s tasks submitted...' % (str(len(all_tasks))))

def _orbit_from_filename(path):
    """Return the relative orbit number from the 'RO<n>' token of an S1 file name.

    File names are expected to look like
    'S1<platform>_<yyyy-mm-dd>_<HHMMSS>_<MODE>_RO<orbit>_<pattern>...'.
    """
    return int(os.path.basename(path).split('_')[4][2:])


def _read_s1_csv(path, identifier):
    """Read one per-image CSV into single-row DataFrames of VV (dB), VH (dB) and angle.

    Each frame has one column per polygon (taken from the `identifier` column) and is
    indexed by the acquisition timestamp parsed from the file name. Returns None if the
    CSV cannot be read (e.g. an empty export), so that the caller can skip it.
    """
    name = os.path.basename(path)
    log.info('Reading: {}'.format(name))
    parts = name.split('_')
    acquisition = pd.to_datetime(parts[1] + ' ' + parts[2])
    try:
        table = pd.read_csv(path)
    except (OSError, ValueError) as err:
        # pandas' EmptyDataError and ParserError are ValueError subclasses
        log.warning('Could not read {} ({}) -> skipping'.format(name, err))
        return None

    columns = table[identifier]

    def as_row(values):
        # (n_polygons,) -> (1, n_polygons): one row per acquisition
        return np.asarray(values).reshape(1, -1)

    frame_VV = pd.DataFrame(columns=columns, data=to_db(as_row(table['VV'])), index=[acquisition])
    frame_VH = pd.DataFrame(columns=columns, data=to_db(as_row(table['VH'])), index=[acquisition])
    frame_angle = pd.DataFrame(columns=columns, data=as_row(table['angle']), index=[acquisition])
    return frame_VV, frame_VH, frame_angle


def _combine_orbit(indir, processed_dir, start_date, end_date, mode, orbit, identifier, inpattern):
    """Combine all CSVs of one pass mode and relative orbit into three pickle files.

    Writes '..._VV.p', '..._VH.p' and '..._angle.p' to `processed_dir`. Each file holds a
    dict with the 'original' acquisitions and a 'daily' series reindexed to every day from
    `start_date` to `end_date`. VV/VH are stored in dB but averaged in linear units.
    """
    log.info('Working on relative orbit: {}'.format(orbit))

    subfiles = sorted(glob.glob(os.path.join(
        indir, 'S1*' + mode + '_RO' + str(orbit) + '_*' + inpattern + '*.csv')))
    log.info('Found {} images that match the specifications'.format(len(subfiles)))

    frames_VV, frames_VH, frames_angle = [], [], []
    for subfile in subfiles:
        frames = _read_s1_csv(subfile, identifier)
        if frames is None:
            continue
        frames_VV.append(frames[0])
        frames_VH.append(frames[1])
        frames_angle.append(frames[2])

    if not frames_VV:
        log.warning('No readable files for {} RO{} -> skipping'.format(mode, orbit))
        return

    log.info('Concatenating dataframes ...')
    dataDF_VV = pd.concat(frames_VV, axis=0, join='outer', sort=False).sort_index()
    dataDF_VH = pd.concat(frames_VH, axis=0, join='outer', sort=False).sort_index()
    dataDF_angle = pd.concat(frames_angle, axis=0, join='outer', sort=False).sort_index()

    log.info('Removing false values ...')
    dataDF_VV[dataDF_VV == 990] = np.nan  # No data value
    dataDF_VH[dataDF_VH == 990] = np.nan  # No data value

    # Need to go back to linear values to average correctly
    dataDF_VVnat = to_linear(dataDF_VV)
    dataDF_VHnat = to_linear(dataDF_VH)

    # Resample to daily to be able to correctly specify start and end
    log.info('Resampling to daily ...')
    ix = pd.date_range(start=pd.to_datetime(start_date), end=pd.to_datetime(end_date), freq='D')

    TS_daily_VV = to_db(dataDF_VVnat.resample('1D').mean().loc[start_date:end_date].reindex(ix))
    TS_daily_VH = to_db(dataDF_VHnat.resample('1D').mean().loc[start_date:end_date].reindex(ix))
    TS_daily_angle = dataDF_angle.resample('1D').mean().loc[start_date:end_date].reindex(ix)

    outputs = [('VV', dataDF_VV, TS_daily_VV),
               ('VH', dataDF_VH, TS_daily_VH),
               ('angle', dataDF_angle, TS_daily_angle)]

    log.info('Saving to files ...')
    for suffix, original, daily in outputs:
        outfile = os.path.join(processed_dir, '_'.join(
            ['S1', start_date, end_date, mode, 'RO' + str(orbit), inpattern + '_' + suffix + '.p']))
        with open(outfile, 'wb') as f:
            pickle.dump({'original': original, 'daily': daily}, f)


def _merge_orbits(processed_dir, start_date, end_date, mode, variable, inpattern):
    """Merge the per-orbit daily pickles of one mode/variable into single daily frames.

    Returns a (backscatter in dB, incidence angle) pair of DataFrames. The angles are set
    to NaN wherever the backscatter is missing. Returns (None, None) when no per-orbit
    pickle exists for this mode/variable.

    :raises ValueError: if two relative orbits have observations on the same day
    """
    subfiles = glob.glob(os.path.join(
        processed_dir,
        'S1_' + start_date + '_' + end_date + '_' + mode + '*' + inpattern + '*' + variable + '.p'))
    if not subfiles:
        log.warning('No processed files for {} {} -> skipping'.format(mode, variable))
        return None, None

    # First make sure no two orbits observed the same day: the daily grouping
    # below relies on at most one observation per day.
    seen_days = set()
    for subfile in subfiles:
        with open(subfile, 'rb') as f:
            original = pickle.load(f)['original']
        observed = to_db(to_linear(original).groupby(pd.Grouper(freq='D')).mean()).dropna(axis=0, how='all')
        overlap = seen_days.intersection(observed.index)
        if overlap:
            raise ValueError('Overlapping observation dates in {}: {}'.format(
                os.path.basename(subfile), sorted(overlap)))
        seen_days.update(observed.index)

    data_frames, angle_frames = [], []
    for subfile in subfiles:
        log.info('Reading: {}'.format(os.path.basename(subfile)))

        with open(subfile, 'rb') as f:
            daily = pickle.load(f)['daily']
        daily = to_db(to_linear(daily).groupby(daily.index).mean())
        data_frames.append(daily)

        # Incidence angles of the same orbit: '..._<variable>.p' -> '..._angle.p'
        angle_file = subfile[:-len(variable + '.p')] + 'angle.p'
        angles = pd.read_pickle(angle_file)['daily']
        # Set the angles to NaN if we didn't have an overpass
        angles[np.isnan(daily)] = np.nan
        angle_frames.append(angles)

    # Group the merged data to daily values (no overlapping observations, checked above)
    log.info('Grouping data ...')
    merged = pd.concat(data_frames, join='outer', sort=False).sort_index()
    merged_data = to_db(to_linear(merged).groupby(merged.index).mean())
    merged_angles = pd.concat(angle_frames, join='outer', sort=False).sort_index()
    merged_angles = merged_angles.groupby(merged_angles.index).mean()
    return merged_data, merged_angles


def combine_gee_s1_timeseries(indir, outdir, start_date, end_date, identifier, inpattern):
    '''
    Function to process all resulting CSV time series files into output pickle files with dataframes.

    Step 1: for each pass mode and relative orbit, the per-image CSVs in `indir` are combined
    into per-orbit pickles in '<indir>/processed' (created if missing).
    Step 2: the orbits of each mode are merged. The resulting dict
    {mode: {'VV': ..., 'VH': ..., 'incidenceAngle': ...}} is pickled to `outdir`.

    :param indir: directory where input files are stored
    :param outdir: directory where output files are stored
    :param start_date: desired start date of TS (yyyy-mm-dd)
    :param end_date: desired end date of TS (yyyy-mm-dd)
    :param identifier: attribute name in the feature collection to be used for identifying the polygons
    :param inpattern:  pattern used to find matching input files
    :raises ValueError: if two relative orbits of the same mode observed the same day
    :return: None
    '''

    modes = ['ASCENDING', 'DESCENDING']
    variables = ['VV', 'VH']

    processed_dir = os.path.join(indir, 'processed')
    os.makedirs(processed_dir, exist_ok=True)

    for mode in modes:
        log.info('Working on mode: {}'.format(mode))

        # Let's see how many images we have for this region and overpass mode
        images = sorted(glob.glob(os.path.join(indir, 'S1*' + mode + '*' + inpattern + '*.csv')))

        # Each relative orbit gets its own set of per-orbit pickles
        for current_orbit in sorted({_orbit_from_filename(image) for image in images}):
            _combine_orbit(indir, processed_dir, start_date, end_date, mode, current_orbit,
                           identifier, inpattern)

    log.info('-'*75)
    log.info('Starting concatenation ...')

    # Now concatenate all different orbits and add the incidence angles
    final_S1_data = {mode: dict() for mode in modes}  # This dictionary will hold the final processed data

    for variable in variables:
        log.info('Working on variable: {}'.format(variable))

        for mode in modes:
            log.info('Working on mode: {}'.format(mode))

            data, angles = _merge_orbits(processed_dir, start_date, end_date, mode, variable, inpattern)
            if data is None:
                continue

            final_S1_data[mode][variable] = data
            # Overwritten for every variable: the stored angles carry the NaN mask of the last variable
            final_S1_data[mode]['incidenceAngle'] = angles

    # Finally, save the data to pickle file
    finaloutfile = os.path.join(outdir, inpattern + '_' + start_date + '_' + end_date + '_S1_backscatter.p')
    log.info('Saving to output file: {}'.format(finaloutfile))
    with open(finaloutfile, 'wb') as f:
        pickle.dump(final_S1_data, f)
