#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Sep  5 12:56:29 2022

@author: bertelsl
"""
import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd


#================================================================================================================
class cMerge_LUCAS_LPIS_crops(object):
#================================================================================================================
    def __init__(self, Info):        
#================================================================================================================
        
        self.fLPIS_dir = Info['fLPIS_dir']
        self.fLUCAS_dir = Info['fLUCAS_dir']
        self.fCroptypes = Info['fCroptypes']
        self.fOutdir = Info['fOutdir']
        self.overwrite = Info['overwrite']

        if not os.path.isdir(self.fOutdir):
            os.mkdir(self.fOutdir)        

#================================================================================================================
    def start_processing(self):        
#================================================================================================================
   
        aLPIS_crops = self.get_unique_LPIS_croptypes(self.fLPIS_dir)
        aLUCAS_crops = self.get_unique_LUCAS_croptypes(self.fLUCAS_dir)
        
        '''  file looks like:
        crop_type;LPIS1;LUCAS1;LUCAS2;LUCAS3
        Barley;Barley;Barley;None;None
        Beet;Beet;Sugar beet;None;None
        Grass;Grass;Grassland with sparse tree_shrub cover;Grassland without tree_shrub cover;Temporary grasslands
        '''

        df = pd.read_csv(self.fCroptypes, sep=';')
        df.set_index('crop_type', inplace=True)
        
        for crop, row in df.iterrows():
            
            afS1 = list()
            afS2 = list()

            fOut_S1 = os.path.join(self.fOutdir, 'S1_{}.nc'.format(crop))            
            fOut_S2 = os.path.join(self.fOutdir, 'S2_{}.nc'.format(crop))
            
            if os.path.isfile(fOut_S1) and os.path.isfile(fOut_S2):
                if self.overwrite:
                    os.remove(fOut_S1)
                    os.remove(fOut_S2)
                else:
                    continue

            print('\r * Writing data for {}'.format(crop), 
                  end='                                                                                                                                               ')

            if row.LPIS1 != 'None':
                afS1 = afS1 + list(glob.glob(self.fLPIS_dir + 'S1_*{}*.nc'.format(row.LPIS1)))
                afS2 = afS2 + list(glob.glob(self.fLPIS_dir + 'S2_*{}*.nc'.format(row.LPIS1)))
            if row.LUCAS1 != 'None':
                afS1 = afS1 + list(glob.glob(self.fLUCAS_dir + 'S1_*{}*.nc'.format(row.LUCAS1)))
                afS2 = afS2 + list(glob.glob(self.fLUCAS_dir + 'S2_*{}*.nc'.format(row.LUCAS1)))
            if row.LUCAS2 != 'None':
                afS1 = afS1 + list(glob.glob(self.fLUCAS_dir + 'S1_*{}*.nc'.format(row.LUCAS2)))
                afS2 = afS2 + list(glob.glob(self.fLUCAS_dir + 'S2_*{}*.nc'.format(row.LUCAS2)))
            if row.LUCAS3 != 'None':
                afS1 = afS1 + list(glob.glob(self.fLUCAS_dir + 'S1_*{}*.nc'.format(row.LUCAS3)))
                afS2 = afS2 + list(glob.glob(self.fLUCAS_dir + 'S2_*{}*.nc'.format(row.LUCAS3)))

            ''' Process S1 files'''
            init = True
                
            for fS1 in afS1:
                xInput = xr.open_dataset(fS1)
                aDate = xInput['date'].values
                xInput = xInput.drop('date', dim=None)
                
                if init:
                    init = False
                    aXinput = xInput                    
                else:
                    aXinput = xr.concat([aXinput, xInput], dim='labels')
               
            ''' Because LUCAS IDs were float numbers'''                  
            aIDs = []       
            
            for ID in aXinput['IDs'].values:
                aIDs.append(str(ID))
                    
            '''create the new dataset'''
            xOut = xr.Dataset(
                    data_vars=dict(
                        VH=(["date", "IDs"], aXinput['VH'].values),
                        VV=(["date", "IDs"], aXinput['VV'].values)                    
                    ),
                    coords=dict(
                        date = aDate,
                        IDs = aIDs,   
                    ),
                    attrs=dict(description="merged LPIS and LUCAS data."),
                )

            xOut = xOut.where(xOut.apply(np.isfinite), drop=True)      
            xOut.to_netcdf(fOut_S1)
            
            ''' Process S2 files'''
            init = True
                
            for fS2 in afS2:
                xInput = xr.open_dataset(fS2)
                aDate = xInput['date'].values
                xInput = xInput.drop('date', dim=None)
                
                if init:
                    init = False
                    aXinput = xInput                    
                else:
                    aXinput = xr.concat([aXinput, xInput], dim='labels')
 
            ''' Because LUCAS IDs were float numbers'''                  
            aIDs = []       
            
            for ID in aXinput['IDs'].values:
                aIDs.append(str(ID))
                                                   
            '''create the new dataset'''
            xOut = xr.Dataset(
                    data_vars=dict(
                        B02=(["date", "IDs"], aXinput['B02'].values),
                        B03=(["date", "IDs"], aXinput['B03'].values),
                        B04=(["date", "IDs"], aXinput['B04'].values),
                        B05=(["date", "IDs"], aXinput['B05'].values),
                        B06=(["date", "IDs"], aXinput['B06'].values),
                        B07=(["date", "IDs"], aXinput['B07'].values),
                        B08=(["date", "IDs"], aXinput['B08'].values),
                        B8A=(["date", "IDs"], aXinput['B8A'].values),
                        B11=(["date", "IDs"], aXinput['B11'].values),
                        B12=(["date", "IDs"], aXinput['B12'].values),
                    ),
                    coords=dict(
                        IDs = aIDs,
                        date = aDate
                    ),
                    attrs=dict(description="merged LPIS and LUCAS data."),
                )
            
            xOut = xOut.where(xOut.apply(np.isfinite), drop=True)                
            xOut = xOut.where(xOut.B04 != 0, drop=True)      
            
            xOut.to_netcdf(fOut_S2)

#================================================================================================================
    def get_unique_LUCAS_croptypes(self, fIndir):        
#================================================================================================================
   
        afCrops = glob.glob(os.path.join(fIndir, '*.nc'))
        afCrops.sort()
        
        aUnique_CropTypes = []

        for fCrops in afCrops:          
            crop = os.path.basename(fCrops).split('.')[0]
            crop = crop.split('S1_')[-1]
            crop = crop.split('S2_')[-1]

            aUnique_CropTypes.append(crop)

        return np.unique(aUnique_CropTypes)

#================================================================================================================
    def get_unique_LPIS_croptypes(self, fIndir):        
#================================================================================================================
   
        afCrops = glob.glob(os.path.join(fIndir, '*.nc'))
        afCrops.sort()
        
        aUnique_CropTypes = []

        for fCrops in afCrops:          
            crop = os.path.basename(fCrops).split('.')[0]
            crop = crop.split('_')[2:]

            aUnique_CropTypes.append(crop)

        return np.unique(aUnique_CropTypes)

#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    Info = {
        'fLUCAS_dir': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/03_LUCAS_crops/',
        'fLPIS_dir': r'/data/EEA_HRL_VLCC/user/luc/data/LPIS/04_LPIS_crops/',
        'fCroptypes': '/data/EEA_HRL_VLCC/user/luc/xlsx/LPIS_vs_LUCAS_crops.csv',
        'fOutdir': r'/data/EEA_HRL_VLCC/user/luc/data/04_LPIS_LUCAS_merged_crops/',
        'overwrite': False
        }

    oMerge_LUCAS_LPIS_crops = cMerge_LUCAS_LPIS_crops(Info)
    oMerge_LUCAS_LPIS_crops.start_processing()
