#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Aug  4 06:55:24 2022

@author: bertelsl


The LUCAS dataset (S1, S2) was collected for the period 2017/09/01 till 2019/03/31 and contains the
following entries:

    Dimensions:                (feature: 712, t: 528)
    Coordinates:
      * t                      (t) datetime64[ns] 2017-09-01 ... 2019-03-31
        lat                    (feature) float64 58.54 58.62 58.62 ... 57.38 57.35
        lon                    (feature) float64 14.37 14.42 14.35 ... 12.74 12.74
        feature_names          (feature) object 'feature_0' ... 'feature_711'
    Dimensions without coordinates: feature
    Data variables:
        VH                     (feature, t) float64 -15.87 -23.53 nan ... -15.33 nan
        VV                     (feature, t) float64 -9.958 -10.49 nan ... -8.22 nan
        local_incidence_angle  (feature, t) float64 15.59 16.27 nan ... 15.89 nan
        B01                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B02                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B03                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B04                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B05                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B06                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B07                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B08                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B09                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B11                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B12                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B8A                    (feature, t) float64 nan nan nan nan ... nan nan nan
        SCL                    (feature, t) float64 nan nan nan nan ... nan nan nan
    Attributes:
        Conventions:  CF-1.8
        source:       Aggregated timeseries generated by openEO GeoPySpark backend.
        
"""
import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from satio.timeseries import Timeseries
from satio.features import multitemporal_speckle


def _to_db(pwr):
    """Convert linear power values to decibels, i.e. ``10 * log10(pwr)``."""
    log_power = np.log10(pwr)
    return 10 * log_power

def _to_pwr(db):
    """Convert decibel values to linear power, i.e. ``10 ** (db / 10)``."""
    exponent = db / 10
    return np.power(10, exponent)

#================================================================================================================
class cLUCASpreprocessing(object):
    """Composite and gap-fill the LUCAS S1/S2 point time series with satio.

    For every ``Lucas *`` entry found under ``fBasedir`` the ``timeseries.nc``
    inside it is read, the optical bands and the VV/VH bands are composited
    and interpolated, and the result is written to
    ``<fOutdir>/<entry name>.nc``. The ``LC1_LABEL`` property of each feature
    in the accompanying JSON feature file becomes the ``labels`` coordinate.
    """

    # Optical bands written to the output file, in output order.
    _S2_OUTPUT_BANDS = ('B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B11', 'B12')
    # Radar bands written to the output file, in output order.
    _S1_OUTPUT_BANDS = ('VH', 'VV')
    # Optical bands that are present in the input but left out of the processing.
    _S2_SKIPPED_BANDS = ('B01', 'B09')

#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Store the configuration and make sure the output directory exists.

        Parameters
        ----------
        Info : dict
            Must provide the keys ``fBasedir`` (directory holding the
            ``Lucas *`` entries), ``fSplitdir`` (directory with fallback
            ``*.json`` feature files), ``fOutdir`` (output directory) and
            ``periods`` (length of the daily time axis ``newT``).
        """
        self.fBasedir = Info['fBasedir']
        self.fSplitdir = Info['fSplitdir']
        self.fOutdir = Info['fOutdir']
        self.periods = Info['periods']
        # Daily time axis starting on 2017-09-01; kept as a public attribute,
        # it is not used by the methods of this class.
        self.newT = pd.date_range("2017-09-01", periods=Info['periods'])

        # makedirs also creates missing parent directories and is a no-op
        # when the directory is already there.
        os.makedirs(self.fOutdir, exist_ok=True)

#================================================================================================================
    def start_processing(self):
#================================================================================================================
        """Process every ``Lucas *`` entry that has no output file yet.

        Entries whose ``timeseries.nc`` or feature file cannot be found are
        reported on stdout and skipped, so one incomplete entry does not
        abort the whole batch.
        """
        # os.path.join works whether or not fBasedir ends with a separator;
        # sorting makes the processing order reproducible.
        aInDirs = sorted(glob.glob(os.path.join(self.fBasedir, 'Lucas *')))
        aSplit = glob.glob(os.path.join(self.fSplitdir, '*.json'))

        for Nr, fSubdir in enumerate(aInDirs, start=1):

            basename = os.path.basename(fSubdir).split('.')[0]
            # extract_LUCAS_data() writes to self.fOut.
            self.fOut = os.path.join(self.fOutdir, basename + '.nc')

            print('\r * Processing: {} / {} - {}'.format(Nr, len(aInDirs), basename))

            # Skip entries whose output file is already there.
            if os.path.isfile(self.fOut):
                continue

            fTimeseries = os.path.join(fSubdir, 'timeseries.nc')
            if not os.path.isfile(fTimeseries):
                print('*** ERROR: file not found: {}'.format(fTimeseries))
                continue

            fInfo = self._find_feature_file(fSubdir, basename, aSplit)
            if fInfo is None:
                print('*** ERROR: no feature file found for: {}'.format(basename))
                continue
            if not os.path.isfile(fInfo):
                print('*** ERROR: file not found: {}'.format(fInfo))
                continue

            with open(fInfo, encoding='utf-8') as fh:
                data = json.load(fh)

            aLabels = [feature['properties']['LC1_LABEL'] for feature in data["features"]]

            self.extract_LUCAS_data(fTimeseries, aLabels, basename)

#================================================================================================================
    @staticmethod
    def _find_feature_file(fSubdir, basename, aSplit):
#================================================================================================================
        """Return the JSON feature file belonging to an entry, or None.

        A ``group_*.json`` inside ``fSubdir`` takes precedence. Otherwise the
        first path in ``aSplit`` that contains the second space-separated
        token of ``basename`` is used.
        """
        candidates = glob.glob(os.path.join(fSubdir, 'group_*.json'))
        if candidates:
            return candidates[0]

        tokens = basename.split(' ')
        if len(tokens) < 2 or not tokens[1]:
            return None

        # NOTE(review): this is a substring test against the full path, so a
        # short token can match an unrelated split file - confirm the naming
        # scheme of the split files.
        matches = [s for s in aSplit if tokens[1] in s]
        return matches[0] if matches else None

#================================================================================================================
    @staticmethod
    def _composite_and_interpolate(raw_data, timestamps, band_names, mode, from_db=False):
#================================================================================================================
        """Composite and interpolate a (band, t, feature) array with satio.

        Parameters
        ----------
        raw_data : numpy.ndarray
            Array ordered as (band, t, feature); a trailing dummy axis is
            appended before it is handed to ``Timeseries``.
        timestamps : array-like
            Timestamps belonging to the ``t`` axis.
        band_names : list
            Names belonging to the ``band`` axis.
        mode : str
            Compositing mode passed on to ``Timeseries.composite``.
        from_db : bool
            When True the data are converted from dB to power units before
            compositing and back to dB after interpolation.

        Returns
        -------
        The interpolated series converted with
        ``to_xarray().to_dataset('bands').squeeze(drop=True)``.
        """
        ts = Timeseries(data=np.expand_dims(raw_data, axis=-1),
                        timestamps=timestamps,
                        bands=band_names)

        if from_db:
            # Averaging has to happen in linear units, not in dB.
            ts.data = _to_pwr(ts.data)

        composite_settings = dict(
            freq=10,
            window=20,
            mode=mode,
            start=pd.to_datetime(ts.timestamps[0]).strftime('%Y-%m-%d'),
            end=pd.to_datetime(ts.timestamps[-1]).strftime('%Y-%m-%d'),
        )
        ts_interpolated = ts.composite(**composite_settings).interpolate()

        if from_db:
            ts_interpolated.data = _to_db(ts_interpolated.data)

        # NOTE(review): squeeze(drop=True) removes every length-1 dimension,
        # so an input with a single feature would presumably lose its feature
        # axis as well - confirm before using this on one-point files.
        return ts_interpolated.to_xarray().to_dataset('bands').squeeze(drop=True)

#================================================================================================================
    def extract_LUCAS_data(self, fTimeseries, aLabels, basename):
#================================================================================================================
        """Pre-process one ``timeseries.nc`` and write it to ``self.fOut``.

        Parameters
        ----------
        fTimeseries : str
            Path of the input netCDF file.
        aLabels : list
            One label per feature, in the feature order of the input file.
        basename : str
            Name of the entry being processed (not used here, kept for
            interface compatibility).

        Raises
        ------
        ValueError
            If the number of labels differs from the number of features.
        """
        print(' * Extracting LUCAS data')

        # The context manager closes the input file once everything needed
        # has been read into memory.
        with xr.open_dataset(fTimeseries) as xInput:

            # Fail before the expensive processing instead of at dataset
            # construction time.
            nFeatures = xInput.sizes['feature']
            if len(aLabels) != nFeatures:
                raise ValueError(
                    'number of labels ({}) does not match number of features ({}) in {}'.format(
                        len(aLabels), nFeatures, fTimeseries))

            # --- Optical bands: median compositing ---
            optical_bands = [b for b in xInput.variables
                             if b.startswith('B') and b not in self._S2_SKIPPED_BANDS]
            xdata = xInput[optical_bands]
            optical = xdata.to_array(dim='band').transpose('band', 't', 'feature').values
            # NaN/inf have no uint16 representation and casting them is
            # undefined, so missing values are mapped to 0 explicitly.
            # NOTE(review): assumes satio treats 0 as nodata for the optical
            # bands - confirm.
            optical = np.where(np.isfinite(optical), optical, 0).astype(np.uint16)
            S2_processed_data = self._composite_and_interpolate(
                optical, xdata.t.values, optical_bands, mode='median')

            # --- Radar bands (assumed to be in dB): mean compositing ---
            radar_bands = [b for b in xInput.variables if 'VV' in b or 'VH' in b]
            band_data = xInput[radar_bands]
            radar = band_data.to_array(dim='band').transpose('band', 't', 'feature').values.astype(np.float32)
            S1_processed_data = self._composite_and_interpolate(
                radar, band_data.t.values, radar_bands, mode='mean', from_db=True)

            lat = xInput['lat'].values
            lon = xInput['lon'].values

        # --- Assemble the output dataset ---
        data_vars = {}
        for band in self._S2_OUTPUT_BANDS:
            data_vars[band] = (["date", "labels"], S2_processed_data[band].values)
        for band in self._S1_OUTPUT_BANDS:
            data_vars[band] = (["date", "labels"], S1_processed_data[band].values)
        data_vars['lat'] = (["labels"], lat)
        data_vars['lon'] = (["labels"], lon)

        new_data = xr.Dataset(
            data_vars=data_vars,
            coords=dict(
                date=S2_processed_data['time'].values,
                labels=aLabels,
            ),
            attrs=dict(description="satio cleaned data."),
        )

        new_data.to_netcdf(self.fOut)



#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Input, split-file and output locations of the LUCAS 2018 extraction.
    settings = dict(
        fBasedir=r'/data/EEA_HRL_VLCC/data/ref/lucas/LUCAS2018/2018_EU_LUCAS2018_POINT_JD/',
        fSplitdir=r'/data/EEA_HRL_VLCC/data/ref/lucas/LUCAS2018/2018_EU_LUCAS2018_POINT_JD/lucas_split/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/1_satio_preprocessed/',
        periods=577,  # daily steps covering 2017/09/01 till 2019/03/31
    )

    # Build the pre-processor and run it over all entries.
    cLUCASpreprocessing(settings).start_processing()
    
    
    
    
    
    