#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 24 07:02:01 2022

@author: bertelsl

Algorithm:
    
    - search for S1, S2 files in the pre-defined input directories:
        
        'fIndir': r'/data/EEA_HRL_VLCC/data/ref/crop_type/',
        
    - use satio to composite and interpolate the data over the full time period 2017/09/01 - 2019/03/31 
   
    - save the pre-processed data to the output directory as netcdf, only for the time period of interest 2018/03/01 - 2018/08/31:
         
        'fOutdir': r'/data/EEA_HRL_VLCC/user/luc/data/LPIS/01_LPIS_preprocessed/',
        
Version: 02/09/2022

"""

import os
import glob
import netCDF4
import json
import warnings
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from satio.timeseries import Timeseries
from satio.features import multitemporal_speckle
from datetime import datetime

def _to_db(pwr):
    """Convert linear power values to decibels (10 * log10(pwr)).

    Zero inputs map to -inf and negative inputs to NaN; the floating-point
    warnings NumPy would emit for those values are suppressed.
    """
    # np.errstate silences only the divide/invalid floating-point warnings of
    # log10, instead of muting every warning category process-wide.
    with np.errstate(divide='ignore', invalid='ignore'):
        return 10 * np.log10(pwr)

def _to_pwr(db):
    """Convert decibel values back to linear power units (10 ** (db / 10))."""
    exponent = db / 10
    return np.power(10, exponent)

#================================================================================================================
class cLPISpreprocessing(object):
#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Store the run configuration and make sure the output directory exists.

        Parameters
        ----------
        Info : dict
            Must provide the keys 'fIndir', 'fOutdir', 'start_date_focus',
            'end_date_focus', 'focus_periods' and 'overwrite'.

        Raises
        ------
        KeyError
            If one of the required keys is missing from ``Info``.
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.start_date_focus = Info['start_date_focus']
        self.end_date_focus = Info['end_date_focus']
        self.focus_periods = Info['focus_periods']
        self.overwrite = Info['overwrite']

        # makedirs also creates missing parent directories and, with
        # exist_ok=True, does not fail when the directory already exists.
        os.makedirs(self.fOutdir, exist_ok=True)
 
#================================================================================================================
    def start_processing(self):
#================================================================================================================
        """Process every sub-directory of ``self.fIndir``.

        For each sub-directory the outputs ``S1_<name>.nc`` and ``S2_<name>.nc``
        are written to ``self.fOutdir``. A sub-directory is skipped when both
        outputs already exist and ``self.overwrite`` is false, or when
        ``read_IDs`` reports that its IDs cannot be read.
        """
        # os.path.join makes the pattern independent of whether fIndir carries
        # a trailing separator; the trailing '/' restricts matches to
        # directories and keeps the separator at the end of every result.
        aSubdirs = glob.glob(os.path.join(self.fIndir, '*/'))

        for Subdir in aSubdirs:

            # Name of the sub-directory itself (the trailing separator is stripped first)
            basename = os.path.basename(os.path.normpath(Subdir))

            fOut_S1 = os.path.join(self.fOutdir, 'S1_{}.nc'.format(basename))
            fOut_S2 = os.path.join(self.fOutdir, 'S2_{}.nc'.format(basename))

            if os.path.isfile(fOut_S1) and os.path.isfile(fOut_S2):
                if not self.overwrite:
                    continue
                # Both outputs exist and must be regenerated
                os.remove(fOut_S1)
                os.remove(fOut_S2)

            if not self.read_IDs(Subdir):
                continue

            print('\r Processing: {}'.format(basename), end='                                                                                                                       ')

            # get_labels must run first: read_data (called by the process_*
            # methods) uses self.aLabels.
            self.get_labels(Subdir)
            self.process_S1_data(Subdir, fOut_S1)
            self.process_S2_data(Subdir, fOut_S2)

#================================================================================================================
    def get_labels(self, Subdir):
#================================================================================================================
        """Load the 'LABEL' variable of the first ``*.nc`` file in ``Subdir``.

        The values are stored in ``self.aLabels``.

        Raises
        ------
        FileNotFoundError
            If ``Subdir`` contains no ``*.nc`` file.
        """
        fIn = glob.glob(os.path.join(Subdir, '*.nc'))
        if not fIn:
            raise FileNotFoundError('No *.nc file found in {}'.format(Subdir))

        # The context manager closes the file handle once the values are in memory.
        with xr.open_dataset(fIn[0]) as xData:
            self.aLabels = xData['LABEL'].values

#================================================================================================================
    def process_S1_data(self, S1dir, fOut_S1):
#================================================================================================================
        """Composite, interpolate and export the S1 series of one sub-directory.

        Reads the first ``TS/S1_*.nc`` file below ``S1dir`` through
        ``read_data``, runs the VV/VH variables through a satio ``Timeseries``
        (mean compositing followed by interpolation, both on linear power
        values) and writes the result, converted to dB, to ``fOut_S1``.
        Labels holding any NaN are dropped before writing.

        Requires ``self.aIDs`` and ``self.aLabels`` to be set beforehand.
        """
        s1_path = glob.glob(os.path.join(S1dir,  'TS/S1_*.nc'))[0]
        xInput = self.read_data('S1', s1_path)

        # Radar variables: every variable whose name contains VV or VH
        band_names = [name for name in xInput.variables if 'VV' in name or 'VH' in name]
        xBands = xInput[band_names]

        # (band, time, feature) float32 array with a trailing singleton axis appended
        cube = xBands.to_array(dim='band').transpose('band', 'time', 'feature').values
        cube = np.expand_dims(cube.astype(np.float32), axis=-1)

        # The Timeseries is built from dB values and then switched back to
        # linear power, so that compositing and interpolation work on power.
        ts = Timeseries(data=_to_db(cube), timestamps=xBands.time.values, bands=band_names)
        ts.data = _to_pwr(ts.data)

        # Compositing: S1-specific frequency and window with a mean operation,
        # spanning the first to the last timestamp of the series
        first_day = pd.to_datetime(ts.timestamps[0]).strftime('%Y-%m-%d')
        last_day = pd.to_datetime(ts.timestamps[-1]).strftime('%Y-%m-%d')
        ts_composited = ts.composite(freq=10, window=20, mode='mean',
                                     start=first_day, end=last_day)

        # Interpolation, with warnings muted
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            S1_ts = ts_composited.interpolate()

        # Back to dB, overriding the data array of the interpolated series
        S1_ts.data = _to_db(S1_ts.data)

        # Bands become dataset variables; size-1 dimensions are dropped
        S1_processed = S1_ts.to_xarray().to_dataset('bands').squeeze(drop=True)

        columns = {}
        for band in ('VH', 'VV'):
            values = S1_processed[band].values
            if len(self.aIDs) == 1:
                # squeeze() also removed the feature axis; restore it as a column
                values = np.reshape(values, (len(values), 1))
            columns[band] = (["date", "labels"], values)
        columns['IDs'] = (["labels"], xInput['IDs'].values)

        # Assemble the output dataset
        xResult = xr.Dataset(
            data_vars=columns,
            coords=dict(
                date=S1_processed['time'].values,
                labels=(self.aLabels),
            ),
            attrs=dict(description="satio cleaned data."),
        )

        xResult.dropna(dim='labels', how='any').to_netcdf(fOut_S1)
        
#================================================================================================================
    def process_S2_data(self, S2dir, fOut_S2):
#================================================================================================================
        """Composite, interpolate and export the S2 series of one sub-directory.

        Reads the first ``TS/S2_*.nc`` file below ``S2dir`` through
        ``read_data``, scales the optical bands by 1/0.0001 into uint16, runs
        them through a satio ``Timeseries`` (median compositing followed by
        interpolation) and writes the result to ``fOut_S2``. Labels holding
        any NaN are dropped before writing.

        Requires ``self.aIDs`` and ``self.aLabels`` to be set beforehand.
        """
        # Bands written to the output file, in output order
        out_bands = ('B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B11', 'B12')

        fS2 = glob.glob(os.path.join(S2dir,  'TS/S2_*.nc'))[0]
        xInput = self.read_data('S2', fS2)

        # Optical variables: every variable whose name starts with 'B'
        optical_bands = [b for b in xInput.variables if b.startswith('B')]
        xdata = xInput[optical_bands]

        # (band, time, feature) float32 array behind the xarray Dataset
        raw_data = xdata.to_array(dim='band').transpose('band', 'time', 'feature').values.astype(np.float32)

        # Append a trailing singleton axis and store the scaled values as uint16
        raw_data_expanded = np.round(np.expand_dims(raw_data, axis=-1) * (1/0.0001)).astype(np.uint16)
        ts = Timeseries(data=raw_data_expanded, timestamps=xdata.time.values, bands=optical_bands)

        # 1) Compositing: median, spanning the first to the last timestamp
        composite_settings = dict(
            freq=10,
            window=20,
            mode='median',
            start=pd.to_datetime(ts.timestamps[0]).strftime('%Y-%m-%d'),
            end=pd.to_datetime(ts.timestamps[-1]).strftime('%Y-%m-%d'),
        )
        ts_composited = ts.composite(**composite_settings)

        # 2) Interpolation
        S2_ts_interpolated = ts_composited.interpolate()

        # Bands become dataset variables; size-1 dimensions are dropped
        S2_processed_data = S2_ts_interpolated.to_xarray().to_dataset('bands').squeeze(drop=True)

        single_feature = len(self.aIDs) == 1
        data_vars = {}
        for band in out_bands:
            values = S2_processed_data[band].values
            if single_feature:
                # squeeze() also removed the feature axis; restore it as a
                # column. Each band is reshaped from its OWN data (previously
                # every band was overwritten with the B02 values here).
                values = np.reshape(values, (len(values), 1))
            data_vars[band] = (["date", "labels"], values)
        data_vars['IDs'] = (["labels"], xInput['IDs'].values)

        # Assemble the output dataset
        xData = xr.Dataset(
            data_vars=data_vars,
            coords=dict(
                date=S2_processed_data['time'].values,
                labels=(self.aLabels),
            ),
            attrs=dict(description="satio cleaned data."),
        )

        xOut = xData.dropna(dim='labels', how='any')
        xOut.to_netcdf(fOut_S2)
  
#================================================================================================================
    def read_IDs(self, Subdir):
#================================================================================================================
        """Load the 'CODE_OBJ' variable of the first ``TS/S1_*.nc`` file in ``Subdir``.

        On success the values are stored in ``self.aIDs``.

        Returns
        -------
        bool
            True when the IDs were read; False when no matching file exists or
            the file could not be read (a warning is issued in the latter case).
        """
        fIDs = glob.glob(os.path.join(Subdir, 'TS', 'S1_*.nc'))
        if not fIDs:
            # Nothing to process for this sub-directory
            return False

        try:
            # The context manager closes the file handle once the values are in memory.
            with xr.open_dataset(fIDs[0]) as xIDs:
                self.aIDs = xIDs['CODE_OBJ'].values
        except Exception as err:
            # Stay best-effort (the caller skips this sub-directory), but do not
            # swallow KeyboardInterrupt/SystemExit and do not hide the reason.
            warnings.warn('Could not read CODE_OBJ from {}: {!r}'.format(fIDs[0], err))
            return False

        return True

#================================================================================================================
    def read_data(self, forS, fIn):
#================================================================================================================
        """Read one S1 or S2 time-series file and return it on the focus time axis.

        The band arrays of ``fIn`` are scattered into zero-filled arrays, using
        the raw values of the file's 'time' variable as row indices. The
        result is then cut to the focus window, which starts at
        ``self.start_date_focus`` (formatted with the year of the last
        timestamp) and spans ``self.focus_periods`` daily steps.
        NOTE(review): this assumes the raw 'time' values are non-negative
        integer day offsets counted from the first timestamp - confirm
        against the input files.

        Parameters
        ----------
        forS : str
            'S1' (bands VH, VV) or 'S2' (bands B02..B12). NaNs are replaced by
            0 for S2 only.
        fIn : str
            Path of the netCDF file to read.

        Returns
        -------
        xarray.Dataset
            The bands on dims (time, feature), plus an 'IDs' variable on
            'feature'. Requires ``self.aLabels`` and ``self.aIDs`` to be set.

        Raises
        ------
        ValueError
            If ``forS`` is neither 'S1' nor 'S2'.
        """
        band_names = {
            'S1': ('VH', 'VV'),
            'S2': ('B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B11', 'B12'),
        }
        if forS not in band_names:
            raise ValueError("forS must be 'S1' or 'S2', got {!r}".format(forS))
        bands = band_names[forS]

        # Raw (undecoded) values of the time variable; used as row indices below
        with netCDF4.Dataset(fIn, 'r') as file2read:
            at_date = np.ma.getdata(file2read.variables['time'][:])
        full_period = np.max(at_date) + 1

        # Pull everything needed into memory so the file can be closed right away
        with xr.open_dataset(fIn) as xInput:
            aTimes = xInput['time'].values
            raw_bands = {band: xInput[band].values for band in bands}

        # Daily axis starting at the first timestamp and covering every raw index
        fullT = pd.date_range(aTimes[0], periods=full_period)

        # Focus window: start date formatted with the year of the last timestamp
        year = str(aTimes[-1].astype('datetime64[D]')).split('-')[0]
        start_date_focus = np.datetime64(self.start_date_focus.format(year))
        focusT = pd.date_range(start_date_focus, periods=self.focus_periods)

        expanded = {}
        for band, values in raw_bands.items():
            if forS == 'S2':
                # Missing optical observations are stored as 0
                values = np.where(np.isnan(values), 0, values)
            padded = np.zeros(shape=(full_period, np.shape(values)[1]))
            padded[at_date, :] = values
            expanded[band] = (["time", "feature"], padded)

        # Dataset on the full daily axis
        xFull = xr.Dataset(
            data_vars=expanded,
            coords=dict(
                time=fullT,
                feature=self.aLabels,
                IDs=(["feature"], self.aIDs),
            ),
            attrs=dict(description="satio cleaned data."),
        )

        xNew = xFull.sel(time=slice(focusT[0], focusT[-1]), drop=True)

        # Rebuild on the focus axis, with IDs as a data variable instead of a coordinate
        focused = {band: (["time", "feature"], xNew[band].values) for band in bands}
        focused['IDs'] = (["feature"], xNew['IDs'].values)

        xData = xr.Dataset(
            data_vars=focused,
            coords=dict(
                time=focusT,
                feature=xNew['feature'].values,
            ),
            attrs=dict(description="satio cleaned data."),
        )

        return xData
                  
#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Run configuration handed to cLPISpreprocessing
    Info = dict(
        fIndir=r'/data/EEA_HRL_VLCC/data/ref/crop_type_demeter_extract/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LPIS/02_LPIS_preprocessed/',
        # '{}' is filled in with a year at run time
        start_date_focus='{}-03-01',
        end_date_focus='{}-08-31',
        # number of daily steps: 03/01 up to and including 08/30
        focus_periods=183,
        overwrite=False,
    )

    oLPISpreprocessing = cLPISpreprocessing(Info)
    oLPISpreprocessing.start_processing()
    
    


