from typing import Dict, List, Tuple
from pathlib import Path
import json
from copy import deepcopy
import glob
import numpy as np
import pandas as pd
import geopandas as gpd
import xarray as xr
import cv2
from loguru import logger
from rasterio.enums import Resampling

try:
    import importlib.resources as pkg_resources
except ImportError:
    # Try backported to PY<37 `importlib_resources`.
    import importlib_resources as pkg_resources

from satio.utils import run_parallel
from satio.features import Features
from satio.collections import (TrainingCollection,
                               L2ATrainingCollection,
                               PatchLabelsTrainingCollection,
                               AgERA5TrainingCollection,
                               DEMCollection,
                               WorldCoverCollection)
from satio.geoloader import load_reproject

from cropclass.collections import (SIGMA0HRLTrainingCollection,
                                   OpenEOTrainingCollection)
from cropclass.utils.seasons import get_processing_dates
from cropclass.features.inference import (FeaturesComputer,
                                          get_sample_features)
from cropclass.tools.trainingdf import (labels_to_str,
                                        select_tsteps,
                                        translate_labels,
                                        select_training_classes,
                                        filter_df,
                                        remove_outliers
                                        )
from cropclass.resources import biomes
from cropclass.features.inference import MAX_GAP


# Number of workers handed to `run_parallel` when a task is processed
# locally (see `ewoc_CIB_Pipeline._run_task_local`).
MAX_WORKERS = 1

# Biome feature name -> raster file name, one entry for each index 1..13.
# The file names are resolved inside the `cropclass.resources.biomes`
# package by the `add_raster_feat` methods.
BIOME_RASTERS = {f"biome{i:02d}": f"biome_{i:02d}.tif"
                 for i in range(1, 14)}


class TrainingFeaturesComputer(FeaturesComputer):
    def __init__(self,
                 *args,
                 nrpixels=1,
                 erode_fields=True,
                 **kwargs):
        """Initialise the parent computer and keep the sampling options.

        Args:
            *args: positional arguments forwarded to `FeaturesComputer`.
            nrpixels: pixel count handed to `sample_pixels` for a sample.
            erode_fields: whether `get_labels` erodes the labels.
            **kwargs: keyword arguments forwarded to `FeaturesComputer`.
        """
        super().__init__(*args, **kwargs)

        # Training-specific options, read back in `get_features`
        self.nrpixels, self.erode_fields = nrpixels, erode_fields

    def get_features(self,
                     collections: Dict[str, TrainingCollection],
                     sampleID: str,
                     database,
                     labelattr,
                     idattr):
        """Compute the training features of a single sample.

        The labels are loaded first so that a sample without any non-zero
        label is skipped before the (more expensive) inputs are loaded.

        Args:
            collections: collections keyed by source name; the 'LABELS'
                entry is required here.
            sampleID: identifier of the sample in `database`.
            database: table holding the sample metadata.
            labelattr: name of the label layer passed to `get_labels`.
            idattr: name of the identifier column in `database`.

        Returns:
            The output of `sample_pixels` (sampled pixels including a
            'LABEL' column, or None), or None when the sample contains no
            non-zero label.
        """
        bounds, epsg = self._get_location_bounds(sampleID, database, idattr)
        collections = self._filter_collections(
            collections, sampleID, bounds, epsg)

        # Labels come first: nothing to do if the patch has no crop pixel
        labels = self.get_labels(
            collections['LABELS'], labelattr, erode=self.erode_fields)
        if np.all(labels.data == 0):
            logger.info('No crop type information in this sample: skipping')
            return None

        # The satio collections are converted to xarray.DataArrays,
        # which is what the parent feature computation works on
        inputs = self._collections_to_inputs(collections, bounds)
        features = super().get_features(inputs)

        # Sample attributes (including lat/lon)
        features = self.add_attributes(
            database, sampleID, idattr, features, bounds=bounds, epsg=epsg)

        # DEM features
        features = self.add_dem(inputs, bounds, epsg, features)

        # Biome features, one raster per biome
        logger.info('Adding biome features...')
        for biome_name, biome_raster in BIOME_RASTERS.items():
            features = self.add_raster_feat(
                features, biome_name, biome_raster,
                bounds, epsg, resolution=10)

        # When an instance segmentation is available, keep only one
        # valid pixel for each instance
        if 'INSTANCE-ID' in features.names:
            instance_ids = features.select(['INSTANCE-ID'])
            labels = self.mask_labels_per_instance(labels, instance_ids)

        feature_df = features.df

        # Flatten the labels (keeping their original dtype) so they line
        # up with the rows of the feature table
        label_cube = labels.data
        flat_labels = label_cube.reshape(
            label_cube.shape[0],
            label_cube.shape[1] * label_cube.shape[2])
        feature_df['LABEL'] = np.squeeze(flat_labels.T)

        # Sample from the available pixels
        return self.sample_pixels(feature_df, self.nrpixels)

    def add_attributes(self, database, sampleID, idattr,
                       features: Features, attributes: List = None,
                       bounds: Tuple = None, epsg: int = None):
        """Attach the sample id, year, ref_id and lat/lon to `features`.

        Args:
            database: table holding the sample metadata; must provide
                the `idattr`, 'year' and 'ref_id' columns.
            sampleID: identifier of the sample in `database`.
            idattr: name of the identifier column.
            features: features to extend.
            attributes: attribute names, only used in the log message.
            bounds: sample bounds, forwarded to `add_latlon`.
            epsg: EPSG code of `bounds`, forwarded to `add_latlon`.

        Returns:
            The extended features.

        Raises:
            AssertionError: if an attribute already exists with a
                different value.
        """

        def _set_or_check(feats, value, name):
            # New attributes are added, existing ones must match
            if name not in feats.attrs:
                return feats.add_attribute(value, name)
            assert value == feats.attrs[name]
            return feats

        if attributes is None:
            attributes = ['label', 'year', 'ref_id']

        logger.info('---- EXTRACTING ATTRIBUTES START')
        logger.info(f'Looking for attributes: {attributes}')

        row = database.set_index(idattr).loc[sampleID]

        features = _set_or_check(features, sampleID, idattr)
        features = _set_or_check(features, int(row['year']), 'year')
        features = _set_or_check(features, row['ref_id'], 'ref_id')

        # lat/lon are added as well
        features = self.add_latlon(features, bounds, epsg)

        nattrs = len(features.attrs)
        logger.info(f'Successfully extracted {nattrs} attributes.')
        logger.info('---- EXTRACTING ATTRIBUTES DONE')

        return features

    def add_dem(self, inputs, bounds, epsg, features):
        """Merge the DEM altitude and slope features into `features`.

        Args:
            inputs: not used by this method.
            bounds: sample bounds used to filter the DEM collection.
            epsg: EPSG code of `bounds`.
            features: features to extend.

        Returns:
            `features` merged with 'DEM-alt-20m' and 'DEM-slo-20m'.
        """
        from satio.collections import DEMCollection
        logger.info('Computing DEM features for sample ...')

        # DEM collection restricted to the sample bounds
        dem_root = '/data/MEP/DEM/COP-DEM_GLO-30_DTED/S2grid_20m'
        dem_collection = DEMCollection(dem_root).filter_bounds(bounds, epsg)

        # Only altitude and slope are kept
        wanted = ['DEM-alt-20m', 'DEM-slo-20m']
        dem_features = Features.from_dem(dem_collection).select(wanted)

        return features.merge(dem_features)

    def add_raster_feat(self,
                        feat: Features,
                        feat_name: str,
                        raster_path: str,
                        bounds,
                        epsg,
                        resolution=10):
        """
        Add a feature read from a raster of the `biomes` resource package.

        The raster is loaded for `bounds`/`epsg` with `load_reproject`
        (bilinear resampling, fill value 0, `border_buff` equal to
        `resolution`). When `feat_name` contains 'realm' the array is
        turned into a 0/1 uint8 mask. The result is added to `feat`
        under the name `feat_name`.
        """
        with pkg_resources.path(biomes, raster_path) as raster_file:
            raster_arr = load_reproject(
                raster_file, bounds, epsg,
                resolution=resolution, border_buff=resolution,
                fill_value=0,
                resampling=Resampling.bilinear)

            is_realm = 'realm' in feat_name
            if is_realm:
                raster_arr = (raster_arr > 0).astype(np.uint8)

            return feat.add(raster_arr, [feat_name])

    def get_labels(self, labels_collection, labelattr,
                   erode=True):
        """Load the `labelattr` labels and wrap them in a `Features` object.

        Args:
            labels_collection: collection whose `load()` result is
                indexed with `labelattr`.
            labelattr: name of the label layer to extract.
            erode: if True the labels go through `erode_labels` first.

        Returns:
            `Features` with a single 'LABEL' layer, built with
            dtype 'uint64'.
        """
        raw_labels = labels_collection.load()[labelattr]
        label_arr = self.erode_labels(raw_labels) if erode else raw_labels

        return Features(data=label_arr, names=['LABEL'], dtype='uint64')

    @staticmethod
    def erode_labels(labelsdata: np.ndarray) -> np.ndarray:
        '''Method to apply a negative buffer to all labels
        in order to avoid border pixel sampling downstream.

        Edges between differently labeled regions are detected with
        `cv2.Canny`, dilated with a 2x2 kernel and the labels under the
        dilated edges are set to 0. `labelsdata` is modified in place and
        returned. Nothing is changed when fewer than 10 pixels are
        labeled.

        NOTE(review): the code accesses `labelsdata.values`, so the
        argument is presumably an xarray.DataArray rather than a bare
        ndarray -- confirm against the labels collection.
        '''
        values = labelsdata.values

        if (values != 0).sum() < 10:
            logger.warning('Too few labeled pixels to perform erosion!')
            return labelsdata
        logger.info('Performing label erosion ...')

        # cv2.Canny needs an 8-bit image, so every label gets a small
        # surrogate id. The ids cycle through 1..255: the previous
        # `i + 1` left the uint8 range as soon as a patch held more than
        # 254 distinct labels (OverflowError on NumPy >= 2, silent
        # wrap-around to the background value 0 before that).
        label_ids = np.zeros_like(values, dtype=np.uint8)
        for i, label in enumerate(np.unique(values)):
            if label != 0:
                label_ids[values == label] = (i % 255) + 1

        # Thresholds of 0 flag every change of id as an edge
        edges = cv2.Canny(label_ids, 0, 0)
        edges_dilated = cv2.dilate(
            edges, np.ones((2, 2), np.uint8), iterations=1)

        # Apply the buffer
        labelsdata.values[edges_dilated == 255] = 0

        return labelsdata

    @staticmethod
    def mask_labels_per_instance(labels: Features, instances: Features):
        """Keep at most one labeled pixel for every instance.

        For each distinct value in `instances.data`:
        - no non-zero label inside the instance: nothing changes;
        - several distinct non-zero labels: all labels of the instance
          are set to 0;
        - exactly one non-zero label: only the first labeled pixel (in
          `np.where` order) keeps its label, the others are set to 0.

        `labels.data` is replaced by the masked copy and `labels` is
        returned.
        """
        logger.info('Selecting one valid pixel per instance ...')
        masked = labels.data.copy()
        instance_ids = instances.data

        for instance_id in np.unique(instance_ids):
            in_instance = instance_ids == instance_id

            # Distinct non-zero labels found inside this instance
            instance_labels = masked[in_instance]
            distinct = np.unique(instance_labels[instance_labels != 0])

            if len(distinct) == 0:
                # Unlabeled instance
                continue

            if len(distinct) > 1:
                # Ambiguous instance: drop it entirely
                masked[in_instance] = 0
                continue

            # Single label: remember the first labeled pixel, clear the
            # instance and restore that one pixel so that no more than
            # one pixel can be sampled from this instance
            labeled = (masked != 0) & in_instance
            keep_idx = tuple(axis_idx[0] for axis_idx in np.where(labeled))
            keep_value = masked[keep_idx]
            masked[in_instance] = 0
            masked[keep_idx] = keep_value

        labels.data = masked

        return labels

    @staticmethod
    def add_latlon(features, bounds, epsg, resolution=10):
        """Add per-pixel 'lat' and 'lon' layers to `features`.

        The pixel centres are laid out in the CRS of `bounds` and then
        transformed to EPSG:4326. Axis 1 of `features.data` is treated as
        rows (running from `ymax` down to `ymin`) and axis 2 as columns
        (running from `xmin` up to `xmax`).

        This replaces an earlier grid that was built in EPSG:3857 with
        end points at `xmax + resolution/2` and `ymax + resolution/2`
        (i.e. outside the patch), applied the metre resolution in
        EPSG:3857 units and only worked for square patches.

        Args:
            features: features whose `data` has the layout
                (layer, row, column).
            bounds: (xmin, ymin, xmax, ymax) of the patch in `epsg`.
            epsg: EPSG code of `bounds`.
            resolution: kept for backward compatibility; the pixel size
                is derived from `bounds` and the grid shape.

        Returns:
            `features` merged with the 'lat' and 'lon' layers.
        """
        from shapely.geometry import Point

        logger.info('Adding lat/lon features for sample ...')

        nrows = features.data.shape[1]
        ncols = features.data.shape[2]
        xmin, ymin, xmax, ymax = bounds

        # Pixel size implied by the bounds and the grid shape
        xstep = (xmax - xmin) / ncols
        ystep = (ymax - ymin) / nrows

        # Pixel centres: half a pixel inside the bounds on every side
        xcenters = xmin + (np.arange(ncols) + 0.5) * xstep
        ycenters = ymax - (np.arange(nrows) + 0.5) * ystep
        xgrid, ygrid = np.meshgrid(xcenters, ycenters)

        points = [Point(x0, y0)
                  for x0, y0 in zip(xgrid.ravel(), ygrid.ravel())]

        gs = gpd.GeoSeries(points, crs=f'EPSG:{epsg}')
        gs = gs.to_crs(epsg=4326)

        lon_mesh = gs.apply(lambda p: p.x).values.reshape((nrows, ncols))
        lat_mesh = gs.apply(lambda p: p.y).values.reshape((nrows, ncols))

        latlonfeatures = Features(np.array([lat_mesh, lon_mesh]),
                                  ['lat', 'lon'])

        return features.merge(latlonfeatures)

    def _collections_to_inputs(self, collections, bounds):
        """Load the inputs of every sensor listed in `feature_settings`.

        Args:
            collections: collections keyed by 'S2', 'S1' and 'METEO'.
            bounds: sample bounds, needed to resize the METEO data.

        Returns:
            Dict with the loaded data for each requested sensor.
        """
        requested = self.feature_settings.keys()
        inputs = dict()

        if 'S2' in requested:
            s2_collection = collections['S2']
            s2_mask = self.preprocessing_settings['S2'].get('mask', None)
            inputs['S2'] = self._load_S2(s2_collection,
                                         masksettings=s2_mask)

        if 'S1' in requested:
            inputs['S1'] = self._load_S1(collections['S1'])

        if 'METEO' in requested:
            inputs['METEO'] = self._load_METEO(collections['METEO'], bounds)

        return inputs

    @staticmethod
    def _get_location_bounds(location_id, trainingdb, idattr):
        trainingdb_row = trainingdb[
            trainingdb[idattr] == location_id].iloc[0]
        epsg = int(trainingdb_row['epsg'])
        bounds = trainingdb_row['bounds']

        if isinstance(bounds, str):
            bounds = eval(bounds)
        bounds = np.array(bounds)

        return bounds, epsg

    @staticmethod
    def _filter_collections(collections, location_id, bounds, epsg):

        new_colls = collections.copy()

        for s in new_colls.keys():
            if s in ['S2', 'S1', 'METEO', 'LABELS']:
                new_colls[s] = (new_colls[s]
                                .filter_location(location_id))

            elif s in ['DEM', 'WorldCover']:
                new_colls[s] = (new_colls[s]
                                .filter_bounds(bounds, epsg))

        return new_colls

    @staticmethod
    def _load_S2(coll, masksettings=None):
        """Load the S2 bands of a sample, optionally masked.

        The 20m bands are upsampled and merged with the 10m bands. When
        `masksettings` is given, an SCL-based mask (optionally refined by
        a multitemporal mask) is applied. The data is converted to
        float32 and zeros are replaced by NaN.

        Args:
            coll: S2 collection providing `load_timeseries`.
            masksettings: optional dict with the keys 'scl_mask_values',
                'erode_r', 'dilate_r' and 'multitemporal'.

        Returns:
            The `to_xarray()` result with its 'time' dimension renamed
            to 't'.
        """
        logger.info('Loading S2 bands ...')

        native_10m = coll.load_timeseries('B02', 'B03', 'B04', 'B08')
        native_20m = coll.load_timeseries('B05', 'B06', 'B07', 'B11', 'B12')
        bands = native_10m.merge(native_20m.upsample())

        if masksettings is not None:

            from cropclass.utils.masking import (scl_mask, multitemporal_mask,
                                                 SCL_MASK_VALUES)

            logger.info('Starting SCL masking...')
            scl = coll.load_timeseries('SCL').upsample()

            # SCL mask settings for S2
            mask_values = masksettings.get('scl_mask_values',
                                           SCL_MASK_VALUES)

            # Point-based extractions (a spatial dimension of size 1)
            # cannot be eroded or dilated
            point_based = ((scl.data.shape[2] == 1)
                           or (scl.data.shape[3] == 1))
            if point_based:
                erode_r, dilate_r = None, None
            else:
                erode_r = masksettings.get('erode_r', None)
                dilate_r = masksettings.get('dilate_r', None)

            # SCL-based mask
            mask = scl_mask(scl.data, mask_values=mask_values,
                            erode_r=erode_r, dilate_r=dilate_r)

            # Optional multitemporal refinement
            if masksettings.get('multitemporal', False):
                mask = multitemporal_mask(bands, prior_mask=mask)

            bands = bands.mask(mask, drop_nodata=False)

        # Float data with NaN instead of 0
        bands.data = bands.data.astype(np.float32)
        bands.data[bands.data == 0] = np.nan

        return bands.to_xarray().rename({'time': 't'})

    @staticmethod
    def _load_S1(coll):
        """Load the upsampled VV/VH time series of a sample.

        Returns the `to_xarray()` result with its 'time' dimension
        renamed to 't'.
        """
        # TODO: SWITCH TO 10M EXTRACTIONS AND DATA LOADING!

        logger.info('Loading S1 bands ... TODO: SWITCH TO 10M')

        upsampled = coll.load_timeseries(
            'VV', 'VH', mask_and_scale=True).upsample()

        return upsampled.to_xarray().rename({'time': 't'})

    @staticmethod
    def _load_METEO(coll, bounds):
        """Load 'temperature_mean' and upsample it towards the patch size.

        The target size is the x-extent of `bounds` divided by 10, and
        `upsample()` is called `int(log2(size))` times.

        NOTE(review): this presumably assumes that each `upsample()`
        doubles the size, starting from a single pixel, and that the
        target size is a power of two -- confirm.

        Returns:
            The `to_xarray()` result with its 'time' dimension renamed
            to 't'.
        """
        logger.info('Loading METEO bands ...')

        meteo = coll.load_timeseries('temperature_mean')

        # Target size in pixels and the number of upsampling passes
        target_size = int((bounds[2] - bounds[0]) / 10)
        n_passes = int(np.log2(target_size))

        # Resize the AgERA5 data
        done = 0
        while done < n_passes:
            meteo = meteo.upsample()
            done += 1

        return meteo.to_xarray().rename({'time': 't'})

    @ staticmethod
    def sample_pixels(df: pd.DataFrame,
                      nrpixels: int,
                      label='LABEL',
                      nodata_value=0,
                      seed=1234):
        '''
        Method to sample one or more valid pixels
        from the training dataframe
        '''
        if df is None:
            return None

        np.random.seed(seed)

        labeldata = df[label].values

        idxvalid = np.where(labeldata != nodata_value)[0]

        if len(idxvalid) == 0:
            return None
        elif len(idxvalid) > nrpixels:

            # Make DataFrame with valid pixels
            labeldf = pd.DataFrame(index=idxvalid, data=labeldata[idxvalid],
                                   columns=['label'])

            # Get the pixel count for smallest class
            smallest = labeldf.value_counts().min()

            # Get the to be sampled pixel amount per class
            sample_pixels = np.min([nrpixels, smallest])

            # Do the class sampling
            sampled = labeldf.groupby('label').sample(sample_pixels)

            # Get the indexes of sampled pixels
            idxchoice = sampled.index.to_list()

            # Get the subset DF
            df_sample = df.iloc[idxchoice, :]
        else:
            df_sample = df.iloc[idxvalid, :]

        return df_sample


class GEETrainingFeaturesComputer(TrainingFeaturesComputer):
    """Training features computer with its own S1 loader.

    Everything is inherited from `TrainingFeaturesComputer` apart from
    `_load_S1`, which returns the VV/VH time series without calling
    `upsample()`.
    """

    def __init__(self,
                 *args,
                 **kwargs):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(*args, **kwargs)

    @staticmethod
    def _load_S1(coll):
        """Load VV/VH and rename the 'time' dimension to 't'."""
        logger.info('Loading S1 bands ...')
        timeseries = coll.load_timeseries('VV', 'VH', mask_and_scale=True)

        return timeseries.to_xarray().rename({'time': 't'})


class OpenEOTrainingFeaturesComputer(FeaturesComputer):

    def __init__(self,
                 *args,
                 nrpixels=1,
                 erode_fields=False,
                 **kwargs):
        """Initialise the parent computer and keep the sampling options.

        `nrpixels` and `erode_fields` are stored on the instance; all
        other arguments are forwarded to `FeaturesComputer`.
        """
        super().__init__(*args, **kwargs)
        self.nrpixels, self.erode_fields = nrpixels, erode_fields

    def _load_data(self, collection, bands):
        """Load `bands` from `collection` as float32 with NaN for zeros.

        Returns the `to_xarray()` result with its 'time' dimension
        renamed to 't'.
        """
        timeseries = collection.load_timeseries(*bands)

        # Float data with NaN instead of 0
        timeseries.data = timeseries.data.astype(np.float32)
        timeseries.data[timeseries.data == 0] = np.nan

        # The time dimension is called 't' downstream
        return timeseries.to_xarray().rename({'time': 't'})

    def _collections_to_inputs(self, collections):
        """Load the inputs of every sensor listed in `feature_settings`.

        All sensors are read from the single 'OpenEO' collection. The
        bands of a sensor come from the 'bands' entry of its
        preprocessing settings, with the defaults listed below.

        Returns:
            Dict with the loaded data for each requested sensor.
        """
        # (sensor, default bands), in loading order
        default_bands = (
            ('S2', ["B02", "B03", "B04", "B05",
                    "B06", "B07", "B08", "B11", "B12"]),
            ('S1', ['VV', 'VH']),
            ('METEO', ['temperature_mean']),
        )

        coll = collections['OpenEO']
        inputs = dict()
        for sensor, fallback in default_bands:
            if sensor not in self.feature_settings.keys():
                continue
            sensor_bands = self.preprocessing_settings[sensor].get(
                'bands', fallback)
            inputs[sensor] = self._load_data(coll, sensor_bands)

        return inputs

    def _filter_collections(self,
                            collections: Dict[str, TrainingCollection],
                            location_id: str):
        """Return a copy of `collections` restricted to one location.

        Only the 'OpenEO' entry is filtered; a warning is logged for any
        other entry, which is copied unchanged.
        """
        filtered = collections.copy()

        for name in filtered.keys():
            if name != 'OpenEO':
                logger.warning(f'Collection {name} not supported!')
                continue
            filtered[name] = filtered[name].filter_location(location_id)

        return filtered

    def get_bounds_epsg(self, database, sampleID,
                        idattr):
        """Derive the bounds and EPSG code of a sample.

        The sample is joined with the S2 tiling grid to find the EPSG
        code of an intersecting tile. Its geometry is then reprojected
        to that CRS and buffered by 50 units with square caps.

        Args:
            database: GeoDataFrame holding the samples.
            sampleID: value to look for in the `idattr` column.
            idattr: name of the identifier column.

        Returns:
            Tuple `(bounds, epsg)` with bounds as a numpy array
            (minx, miny, maxx, maxy) in the CRS given by `epsg`. The
            first joined row is used.
        """
        # Find S2 tile and matching epsg code
        from satio import layers as satio_layers
        s2grid = satio_layers.load('s2grid')
        sample = database.loc[database[idattr] == sampleID]

        # 'intersects' is the default spatial predicate of `sjoin`. The
        # keyword selecting it was renamed from `op` to `predicate`
        # between geopandas releases, so it is left out to work with
        # either name.
        sample = gpd.sjoin(sample,
                           s2grid,
                           how='left')
        epsg = sample.epsg.values[0]

        buffered = sample.to_crs(f'EPSG:{epsg}').geometry.buffer(
            50, cap_style=3)
        bounds = buffered.bounds
        bounds = np.array([bounds.minx.iloc[0], bounds.miny.iloc[0],
                           bounds.maxx.iloc[0], bounds.maxy.iloc[0]])

        return bounds, epsg

    def add_dem(self, bounds, epsg, features):
        """Merge single-pixel DEM altitude and slope into `features`.

        The DEM features are computed for `bounds`/`epsg` and reduced to
        the pixel at row 5, column 5 before being merged.

        NOTE(review): index 5 is hard-coded and meant to be the centre
        pixel -- confirm it matches the patch size produced by
        `get_bounds_epsg`.
        """
        from satio.collections import DEMCollection
        logger.info('Computing DEM features for sample ...')

        # DEM collection restricted to the sample bounds
        dem_root = '/data/MEP/DEM/COP-DEM_GLO-30_DTED/S2grid_20m'
        dem_collection = DEMCollection(dem_root).filter_bounds(bounds, epsg)

        # Altitude and slope only
        wanted = ['DEM-alt-20m', 'DEM-slo-20m']
        dem_features = Features.from_dem(dem_collection).select(wanted)

        # Single pixel, reshaped to (layer, 1, 1)
        pixel_values = dem_features.data[:, 5, 5]
        pixel_values = pixel_values[:, np.newaxis, np.newaxis]
        single_pixel = Features(data=pixel_values,
                                names=dem_features.names)

        return features.merge(single_pixel)

    def add_attributes(self, database, sampleID, idattr,
                       labelattr, features: Features,
                       attributes: List = None):
        """Attach the sample id, ref_id, lat, lon and label to `features`.

        - ref_id is the stem of the second parent directory of the
          sample's `samplefile`;
        - the label is the `labelattr` value with '-' removed, converted
          to int;
        - lat/lon are the y/x of the sample geometry.

        Args:
            database: table holding the samples.
            sampleID: identifier of the sample in `database`.
            idattr: name of the identifier column.
            labelattr: name of the label column.
            features: features to extend.
            attributes: attribute names, only used in the log message.

        Returns:
            The extended features.

        Raises:
            AssertionError: if an attribute already exists with a
                different value.
        """

        def _set_or_check(feats, value, name):
            # New attributes are added, existing ones must match
            if name not in feats.attrs:
                return feats.add_attribute(value, name)
            assert value == feats.attrs[name]
            return feats

        if attributes is None:
            attributes = ['ref_id', labelattr, idattr, 'lat', 'lon']

        logger.info('---- EXTRACTING ATTRIBUTES START')
        logger.info(f'Looking for attributes: {attributes}')

        row = database.set_index(idattr).loc[sampleID]

        # ref_id is inferred from the location of the sample file
        ref_id = Path(row.samplefile).parents[1].stem

        # The label becomes an int to stay compatible with the GEE
        # workflow
        label = int(row[labelattr].replace('-', ''))

        # lat/lon come from the geometry
        lat, lon = row.geometry.y, row.geometry.x

        features = _set_or_check(features, sampleID, idattr)
        features = _set_or_check(features, ref_id, 'ref_id')
        features = _set_or_check(features, lat, 'lat')
        features = _set_or_check(features, lon, 'lon')
        features = _set_or_check(features, label, 'LABEL')

        nattrs = len(features.attrs)
        logger.info(f'Successfully extracted {nattrs} attributes.')
        logger.info('---- EXTRACTING ATTRIBUTES DONE')

        return features

    def add_raster_feat(self,
                        feat: Features,
                        feat_name: str,
                        raster_path: str,
                        bounds,
                        epsg,
                        resolution=10):
        """
        Add the centre-pixel value of a raster of the `biomes` package.

        The raster is loaded for `bounds`/`epsg` with `load_reproject`
        (bilinear resampling, fill value 0). When `feat_name` contains
        'realm' the array is turned into a 0/1 uint8 mask. Only the pixel
        at index `shape[0] // 2` on both axes is kept and added to `feat`
        as a (1, 1) array named `feat_name`.
        """
        with pkg_resources.path(biomes, raster_path) as raster_file:
            raster_arr = load_reproject(
                raster_file, bounds, epsg,
                resolution=resolution,
                fill_value=0,
                resampling=Resampling.bilinear)

            is_realm = 'realm' in feat_name
            if is_realm:
                raster_arr = (raster_arr > 0).astype(np.uint8)

            # Centre pixel only
            mid = raster_arr.shape[0] // 2
            centre_value = np.array([raster_arr[mid, mid]])
            centre_value = np.expand_dims(centre_value, axis=1)

            return feat.add(centre_value, [feat_name])

    def get_features(self,
                     collections: Dict[str, TrainingCollection],
                     sampleID: str,
                     database,
                     labelattr,
                     idattr):
        """Compute the training features of one OpenEO sample.

        Args:
            collections: collections keyed by source name; 'OpenEO' is
                the entry that is read.
            sampleID: identifier of the sample in `database`.
            database: table holding the samples.
            labelattr: name of the label column.
            idattr: name of the identifier column.

        Returns:
            The `df` of the computed features, extended with the sample
            attributes, the DEM features and the biome features.
        """
        collections = self._filter_collections(collections, sampleID)
        inputs = self._collections_to_inputs(collections)

        # S2 values above 20000 are treated as spurious and set to NaN
        # (clouds seem to show up as such high values -- TODO confirm)
        if 'S2' in inputs.keys():
            s2 = inputs['S2']
            inputs['S2'] = xr.where(s2 > 20000, np.nan, s2)

        # Features from the parent method
        features = super().get_features(inputs)

        # Sample attributes
        features = self.add_attributes(
            database, sampleID, idattr, labelattr, features)

        # Bounds and epsg are needed for the DEM and biome features
        bounds, epsg = self.get_bounds_epsg(database, sampleID, idattr)

        # DEM features
        features = self.add_dem(bounds, epsg, features)

        # Biome features, one raster per biome
        logger.info('Adding biome features...')
        for biome_name, biome_raster in BIOME_RASTERS.items():
            features = self.add_raster_feat(
                features, biome_name, biome_raster,
                bounds, epsg, resolution=10)

        # Hand back the features as a dataframe
        return features.df


class ewoc_CIB_Pipeline:
    def __init__(self,
                 database: Path,
                 features_dir: Path,
                 start_month: int,
                 end_month: int,
                 feature_settings: Dict,
                 preproc_settings: Dict = None,
                 scenario: str = None,
                 overwrite: bool = False,
                 debug: bool = False,
                 sparkcontext=None,
                 nrpixels=1,
                 segmentation=False,
                 segm_settings=None,
                 erode_fields=True):
        """Store the pipeline configuration.

        Args:
            database: database file, read by `load_database`.
            features_dir: output directory; `scenario`, when given, is
                appended as a sub-directory.
            start_month: first month handed to `get_processing_dates`.
            end_month: last month handed to `get_processing_dates`.
            feature_settings: settings passed to the feature computer.
            preproc_settings: optional preprocessing settings passed to
                the feature computer.
            scenario: optional sub-directory name for the outputs.
            overwrite: recompute existing outputs instead of skipping.
            debug: process only a small subset of tasks and samples.
            sparkcontext: when not None, samples are processed through
                this Spark context instead of locally.
            nrpixels: passed to the feature computer.
            segmentation: passed to the feature computer.
            segm_settings: passed to the feature computer.
            erode_fields: stored on the instance.
        """
        self.database = database
        # Always a Path, so that `run()` can call `.mkdir()` on it even
        # when a plain string was given and no scenario is set
        self.features_dir = Path(features_dir)
        self.preprocessing_settings = preproc_settings
        self.feature_settings = feature_settings
        self.start_month = start_month
        self.end_month = end_month
        self._debug = debug
        self._overwrite = overwrite
        self._sc = sparkcontext
        self.labelattr = 'CT'
        self.idattr = 'location_id'
        self.nrpixels = nrpixels
        self.segmentation = segmentation
        self.segm_settings = segm_settings
        self.erode_fields = erode_fields
        self.featcomputer = TrainingFeaturesComputer

        if scenario is not None:
            self.features_dir = self.features_dir / scenario

    def run(self, merge: bool = True):
        """Run all processing tasks one after the other.

        Creates the output directory, loads the database, saves the
        settings and calls `run_task` for every task. In debug mode the
        loop stops after the second task.

        Args:
            merge: if True, the outputs of the successful tasks are
                handed to `_merge` at the end.
        """
        # Create the output directory. `parents=True` is needed because
        # the directory can be a scenario sub-directory of a directory
        # that does not exist yet.
        Path(self.features_dir).mkdir(parents=True, exist_ok=True)

        # Load the database
        self.load_database(self.database)

        # Get processing tasks
        self._tasks = self.get_processing_tasks()

        # Save the settings
        self._save_settings(
            self.features_dir,
            preprocessing_settings=self.preprocessing_settings,
            feature_settings=self.feature_settings
        )

        # Process tasks sequentially
        ntasks = len(self._tasks)
        feature_df_files = []
        for i, task in enumerate(self._tasks):
            logger.info(f'Launching task: {task} ({i+1}/{ntasks})')
            feature_df_files.append(self.run_task(task))

            if self._debug and i == 1:
                logger.debug('Interrupting loop in debug mode.')
                break

        # Filter out any failed tasks (`run_task` returns None for them)
        feature_df_files = [f for f in feature_df_files if f is not None]

        if merge:
            self._merge(feature_df_files, self.features_dir)

    def run_task(self, taskname: str):
        """Compute and save the features of one task (one ref_id).

        The year is taken from the part of `taskname` before the first
        underscore and used to derive the processing dates.

        Args:
            taskname: ref_id of the dataset to process.

        Returns:
            Path of the parquet file holding the features (also when it
            already existed and was skipped), or None when no features
            could be computed.
        """
        # Create the output directory (may be a nested scenario
        # directory, hence `parents=True`)
        features_dir = Path(self.features_dir)
        features_dir.mkdir(parents=True, exist_ok=True)

        # Construct outfile path
        outfile = features_dir / f'{taskname}_features_df.parquet'

        if outfile.is_file() and not self._overwrite:
            logger.info(f'Output file `{outfile}` exists -> skipping')
            return outfile

        # infer the year of the dataset from task
        year = taskname.split('_')[0]

        # Get processing dates
        start_date, end_date = get_processing_dates(self.start_month,
                                                    self.end_month,
                                                    year)
        logger.info(f'Start date: {start_date}; '
                    f'End date: {end_date}')

        # Create training collections
        task_df = self.df[self.df['ref_id'] == taskname]
        task_db = self.db[self.db['ref_id'] == taskname]
        collections = self.create_collections(task_df)

        # Get local copies of the settings, so the dates imputed below
        # do not leak into the settings shared between tasks
        preproc_settings = deepcopy(self.preprocessing_settings)
        feature_settings = deepcopy(self.feature_settings)

        # Impute processing dates into preprocessing settings
        if preproc_settings is not None:
            for sensor in preproc_settings.keys():
                preproc_settings[sensor]['composite']['start'] = start_date  # NOQA
                preproc_settings[sensor]['composite']['end'] = end_date

        # Initialize FeatureComputer. `erode_fields` is forwarded as
        # well: it used to be stored on the pipeline without ever
        # reaching the computer, so `erode_fields=False` had no effect.
        featurecomputer = self.featcomputer(
            feature_settings,
            start_date,
            end_date,
            preproc_settings,
            nrpixels=self.nrpixels,
            erode_fields=self.erode_fields,
            segmentation=self.segmentation,
            segm_settings=self.segm_settings)

        # Run feature computation
        if self._sc is None:
            logger.info('Running feature computation in serial.')
            features = self._run_task_local(taskname, collections,
                                            task_db,
                                            featurecomputer)

        else:
            logger.info('Running feature computation on executors.')
            features = self._run_task_spark(taskname, collections,
                                            task_db,
                                            featurecomputer)

        # Save the result
        if features is None:
            logger.warning('Could not compute features for this task!')
            return None

        self.save_features(features, outfile)
        return outfile

    def save_features(self, features, outfile):
        """Write `features` to `outfile` as a parquet file."""
        nfeatures = len(features)
        logger.info(f'Saving {nfeatures} features to: {outfile}')
        features.to_parquet(outfile)

    def _run_task_local(self, task, collections, task_db,
                        featurecomputer):
        """Compute the features of all samples of a task locally.

        The samples are processed through `run_parallel` with
        `MAX_WORKERS` workers. In debug mode only the first five samples
        are processed.

        Args:
            task: ref_id of the task.
            collections: training collections of the task.
            task_db: database rows belonging to the task.
            featurecomputer: initialised feature computer.

        Returns:
            The concatenated features of all successful samples, or None
            when no sample produced features.
        """
        def get_feat_df(sampleID):
            '''Wrapper function to allow to use run_parallel
            '''
            return get_sample_features(sampleID,
                                       collections,
                                       task_db,
                                       featurecomputer,
                                       self._debug,
                                       self.labelattr,
                                       self.idattr)

        sampleIDs = self.db[self.db['ref_id'] == task][self.idattr].to_list()

        if self._debug:
            sampleIDs = sampleIDs[:5]

        results = run_parallel(get_feat_df,
                               sampleIDs,
                               max_workers=MAX_WORKERS)
        results = [result for result in results if result is not None]

        # `pd.concat` raises on an empty list. Return None instead, like
        # the Spark code path does, so that `run_task` reports the task
        # as failed rather than crashing.
        if not results:
            logger.warning('No samples left after feature computation '
                           '-> aborting job')
            return None

        return pd.concat(results, ignore_index=True)

    def _run_task_spark(self, task, collections,
                        task_db, featurecomputer):
        """Compute the features of all samples of a task through Spark.

        Every sample gets its own partition. In debug mode only the
        first five samples are processed.

        Args:
            task: ref_id of the task.
            collections: training collections of the task.
            task_db: database rows belonging to the task.
            featurecomputer: initialised feature computer.

        Returns:
            The concatenated features of all successful samples, or None
            when there is no sample or no sample produced features.
        """
        sampleIDs = self.db[self.db['ref_id'] == task][self.idattr].to_list()

        # Plain local names keep `self` out of the closure that is
        # shipped to the executors
        debug = self._debug
        labelattr = self.labelattr
        idattr = self.idattr

        if debug:
            sampleIDs = sampleIDs[:5]

        if not sampleIDs:
            # Nothing to distribute; an RDD with zero partitions cannot
            # be created
            logger.warning('No samples left in RDD -> aborting job')
            return None

        rdd = self._sc.parallelize(sampleIDs,
                                   len(sampleIDs)).map(
            lambda x: get_sample_features(
                x, collections, task_db,
                featurecomputer, debug,
                labelattr, idattr)
        ).filter(lambda t: t is not None)

        # Collect once and check the result afterwards. The RDD is not
        # cached, so an `isEmpty()` call before `collect()` would run the
        # feature computation a second time for part of the samples.
        results = rdd.collect()

        if not results:
            logger.warning('No samples left in RDD -> aborting job')
            return None

        return pd.concat(results, ignore_index=True)

    @staticmethod
    def _get_df(features):
        '''Return `features.df`, or None when `features` itself is None.'''
        return None if features is None else features.df

    def _save_settings(self, outdir: Path, **settings):
        '''Write the settings to `Feature_settings.json` or verify them.

        When the file already exists and overwriting is disabled, the
        stored settings are compared with the current ones. Otherwise the
        file is (re)written.

        :param outdir: directory holding `Feature_settings.json`
        :param settings: named groups of settings to store
        :raises AssertionError: if the stored settings differ from the
            current ones
        '''

        def _make_serial(d):
            '''Recursively turn `d` into something json can dump.'''
            if isinstance(d, dict):
                return {_make_serial(k): _make_serial(v)
                        for k, v in d.items()}
            elif isinstance(d, (list, tuple)):
                # json stores tuples as lists anyway; recursing makes sure
                # their items get converted as well
                return [_make_serial(i) for i in d]
            elif callable(d):
                return d.__name__
            elif (isinstance(d, float)) and (np.isnan(d)):
                return 'NaN'
            else:
                return d

        settings_file = Path(outdir) / 'Feature_settings.json'
        logger.info(f'Saving settings to: {settings_file}')

        # Serialize before touching the file, so that a failing dump can
        # never leave a truncated settings file behind.
        serialized = json.dumps(_make_serial(settings), indent=4)

        if settings_file.is_file() and not self._overwrite:
            # Need to check if settings have not changed
            with open(settings_file) as f:
                old_settings = json.load(f)
            # Compare with the json round-trip of the current settings:
            # that is the form in which the old ones were stored.
            if old_settings != json.loads(serialized):
                # Raised explicitly (instead of `assert`) so that the check
                # also runs under `python -O`.
                raise AssertionError(
                    f'Settings in {settings_file} differ from the '
                    'current settings!')
        else:
            with open(settings_file, 'w') as f:
                f.write(serialized)

    def get_processing_tasks(self):
        '''Return the processing tasks: one per unique `ref_id` in `self.db`.
        '''
        unique_ref_ids = self.db.ref_id.unique()
        tasks = [*unique_ref_ids]
        logger.info(f'Found {len(tasks)} processing tasks ...')

        return tasks

    def load_database(self, file: Path):
        '''Read the sample database and store it on the instance.

        Only rows whose `contenttype` is '110' or '111' are kept. The
        filtered table is stored as `self.db`, and a column subset of it
        as `self.df`.

        :param file: database file, read with `gpd.read_file`
        '''
        logger.info(f'Loading database file: {file}')
        full_db = gpd.read_file(file)
        wanted_content = full_db['contenttype'].isin(['110', '111'])
        db = full_db[wanted_content]
        logger.info(f'Database successfully opened. Found {len(db)} samples.')

        # Columns making up the SATIO format
        # TODO: maybe not all of these are needed!
        satio_columns = [self.idattr, 'tile', 'epsg', 'path', 'split',
                         'ref_id', 'year', 'start_date', 'end_date',
                         'bounds']

        self.db = db
        self.df = db[satio_columns]

    def create_collections(self, df, settings=None):
        '''Build the dict of collections used for feature computation.

        :param df: samples table; its `self.idattr` column is renamed to
            `location_id` when the two names differ
        :param settings: optional dict; only its `S1_orbits` entry is read
        :return: dict with the keys LABELS, S2, S1, METEO, DEM, WorldCover
        '''
        if self.idattr != 'location_id':
            df = df.rename(columns={self.idattr: 'location_id'})

        s1_orbits = (None if settings is None
                     else settings.get('S1_orbits', None))

        # Collections we always need
        return {
            'LABELS': PatchLabelsTrainingCollection(
                df, dataformat='worldcereal'),
            'S2': L2ATrainingCollection(
                df, dataformat='worldcereal'),
            'S1': SIGMA0HRLTrainingCollection(
                df, dataformat='worldcereal', orbits=s1_orbits),
            'METEO': AgERA5TrainingCollection(
                df, dataformat='worldcereal'),
            'DEM': DEMCollection(
                folder='/data/MEP/DEM/COP-DEM_GLO-30_DTED/S2grid_20m'),
            'WorldCover': WorldCoverCollection(
                folder='/data/worldcereal/auxdata/WORLDCOVER/2020/'),
        }

    def _load_arrays(self, start_date: str, end_date: str, **inputfiles):
        '''Load the per-sensor datasets and split them into per-sample arrays.

        :param start_date: start of the period of interest
        :param end_date: end of the period of interest
        :param inputfiles: one keyword per sensor, holding the file to open
            with `xr.open_dataset`
        :return: dict mapping every sampleID found in all files to a dict
            `{sensor: array}`; empty when no input file was given
        '''
        if not inputfiles:
            # Without files there are no common sampleIDs to subset on
            logger.warning('No input files provided, nothing to load!')
            return {}

        # Required time period (+ buffer of 1 month for compositing
        # purposes); identical for all sensors, so computed only once
        start_date_buf = pd.to_datetime(start_date) - pd.Timedelta(days=30)
        end_date_buf = pd.to_datetime(end_date) + pd.Timedelta(days=30)

        arrays = {}
        common_sampleIDs = None

        for sensor, inputfile in inputfiles.items():
            logger.info(f'Loading {sensor} file: {inputfile}')

            # Open dataset. NOTE: lock=False ensures we can pickle
            # the dataset later on for spark processing!
            ds = xr.open_dataset(inputfile, lock=False)
            # NOTE(review): reads the `sampleID` variable directly, while the
            # selection further down goes through `self.idattr` -- confirm
            # that both always refer to the same variable.
            current_sampleIDs = list(ds.sampleID.values)

            if common_sampleIDs is None:
                common_sampleIDs = current_sampleIDs
            else:
                common_sampleIDs = list(set(
                    current_sampleIDs).intersection(common_sampleIDs))

            # These are time series features: need to add 1D x and y
            # for feature computation to work
            ds = ds.expand_dims({'x': 1, 'y': 1})

            # Filter data to keep only the buffered time period
            ds = ds.sel(t=slice(start_date_buf, end_date_buf))

            arrays[sensor] = ds.to_array(dim='bands')

        # All files loaded, now subset on common sampleIDs
        common_sampleIDs = list(set(common_sampleIDs))
        logger.info(f'Found {len(common_sampleIDs)} common sampleIDs.')

        common_arrays = {}

        logger.info('Creating sample-specific arrays ...')
        for sampleID in common_sampleIDs:
            common_arrays[sampleID] = {}

            for sensor in inputfiles:
                self._check_sampleID_uniqueness(arrays[sensor], sampleID)
                common_arrays[sampleID][sensor] = arrays[sensor].sel(
                    feature=arrays[sensor][
                        self.idattr] == sampleID).isel(
                    feature=0)  # Select first one after subsetting

        return common_arrays

    def _merge(self, feature_df_files, outdir):
        '''Concatenate feature files into `merged_features_df.parquet`.

        :param feature_df_files: parquet files to merge (str or Path); the
            merged file itself is skipped when it is part of the list
        :param outdir: directory in which the merged file is written
        '''

        outfile = str(Path(outdir) / 'merged_features_df.parquet')
        merged_path = Path(outfile)
        # make sure the merged df is not part of dfs to merge; compared as
        # Path objects so that both str and Path inputs are recognised
        dfs = [pd.read_parquet(f) for f in feature_df_files
               if Path(f) != merged_path]

        # pd.concat refuses an empty list, so stop here if nothing is left
        if not dfs:
            logger.warning('No feature DFs to merge, skipping merge!')
            return

        logger.info(f'Merging {len(dfs)} feature DFs ...')
        merged_df = pd.concat(dfs)
        merged_df.index = list(range(len(merged_df)))

        self.save_features(merged_df, outfile)
        logger.success(('Successfully merged features of '
                        f'{len(merged_df)} samples!'))

    def _check_sampleID_uniqueness(self, array, sampleID, fail=False):
        '''Check that `sampleID` occurs at most once in `array.sampleID`.

        :param array: object whose `sampleID` attribute is compared
            element-wise with `sampleID`
        :param sampleID: the ID to look for
        :param fail: raise on duplicates instead of only logging a warning
        :raises RuntimeError: if `fail` is set and the ID occurs more
            than once
        '''
        matching_samples = (array.sampleID == sampleID).values.sum()
        if matching_samples <= 1:
            return

        # Only a duplicated ID may trigger the failure; a unique ID must
        # pass, whatever the value of `fail`.
        if fail:
            raise RuntimeError(
                f'{self.idattr} not unique which is not allowed.')

        logger.warning((f'{self.idattr} `{sampleID}` occurred '
                        f'{matching_samples} times in array: '
                        'taking first occurrence!'))


class Pipeline:
    def __init__(self,
                 extractions_dir: Path,
                 features_dir: Path,
                 start_month: int,
                 end_month: int,
                 feature_settings: Dict,
                 preproc_settings: Dict = None,
                 collection_settings=None,
                 postprocessing_settings=None,
                 scenario: str = None,
                 overwrite: bool = False,
                 debug: bool = False,
                 sparkcontext=None,
                 nrpixels=1,
                 segmentation=False,
                 segm_settings=None,
                 erode_fields=True,
                 filter_tasks=None
                 ):
        '''Store the configuration of the pipeline.

        :param extractions_dir: directory searched for
            `<task>/database/database.json` files
        :param features_dir: directory in which features are written
        :param start_month: passed to `get_processing_dates`
        :param end_month: passed to `get_processing_dates`
        :param feature_settings: passed to the feature computer
        :param preproc_settings: per-sensor preprocessing settings
        :param collection_settings: settings for the collections; a
            `max_gap` entry is added when missing
        :param postprocessing_settings: enables post-processing when given
        :param scenario: optional sub-directory of `features_dir`
        :param overwrite: replace existing outputs instead of reusing them
        :param debug: process only a small part of the data
        :param sparkcontext: run on Spark when given, locally otherwise
        :param nrpixels: passed to the feature computer
        :param segmentation: passed to the feature computer
        :param segm_settings: passed to the feature computer
        :param erode_fields: passed to the feature computer
        :param filter_tasks: when given, only these tasks are processed
        '''

        # Stored as Path objects: they are later used with `/` and `.mkdir`
        self.extractions_dir = Path(extractions_dir)
        self.features_dir = Path(features_dir)
        self.preprocessing_settings = preproc_settings
        self.feature_settings = feature_settings
        self.start_month = start_month
        self.end_month = end_month
        self._debug = debug
        self._overwrite = overwrite
        self._sc = sparkcontext
        self.labelattr = 'CT'
        self.idattr = 'sampleID'
        self.nrpixels = nrpixels
        self.segmentation = segmentation
        self.segm_settings = segm_settings
        self.erode_fields = erode_fields
        self.filter_tasks = filter_tasks
        self.postprocessing_settings = postprocessing_settings
        # Work on a copy: the default (None) must be accepted and the
        # dict of the caller must not be modified.
        collection_settings = dict(collection_settings or {})
        collection_settings.setdefault('max_gap', MAX_GAP)
        self.collection_settings = collection_settings
        self.featcomputer = TrainingFeaturesComputer

        if scenario is not None:
            self.features_dir = self.features_dir / scenario

    def save_features(self, features, outfile):
        '''Write `features` to `outfile` as a parquet file.

        :param features: table-like object supporting `len()` and
            `to_parquet()`
        :param outfile: destination of the parquet file
        '''
        nr_entries = len(features)
        logger.info(f'Saving {nr_entries} features to: {outfile}')
        features.to_parquet(outfile)

    @ staticmethod
    def _get_df(features):
        if features is None:
            return None
        else:
            return features.df

    def _save_settings(self, outdir: Path, **settings):
        '''Write the settings to `Feature_settings.json` or verify them.

        When the file already exists and overwriting is disabled, the
        stored settings are compared with the current ones. Otherwise the
        file is (re)written.

        :param outdir: directory holding `Feature_settings.json`
        :param settings: named groups of settings to store
        :raises AssertionError: if the stored settings differ from the
            current ones
        '''

        def _make_serial(d):
            '''Recursively turn `d` into something json can dump.'''
            if isinstance(d, dict):
                return {_make_serial(k): _make_serial(v)
                        for k, v in d.items()}
            elif isinstance(d, (list, tuple)):
                # json stores tuples as lists anyway; recursing makes sure
                # their items get converted as well
                return [_make_serial(i) for i in d]
            elif callable(d):
                return d.__name__
            elif (isinstance(d, float)) and (np.isnan(d)):
                return 'NaN'
            else:
                return d

        settings_file = Path(outdir) / 'Feature_settings.json'
        logger.info(f'Saving settings to: {settings_file}')

        # Serialize before touching the file, so that a failing dump can
        # never leave a truncated settings file behind.
        serialized = json.dumps(_make_serial(settings), indent=4)

        if settings_file.is_file() and not self._overwrite:
            # Need to check if settings have not changed
            with open(settings_file) as f:
                old_settings = json.load(f)
            # Compare with the json round-trip of the current settings:
            # that is the form in which the old ones were stored.
            if old_settings != json.loads(serialized):
                # Raised explicitly (instead of `assert`) so that the check
                # also runs under `python -O`.
                raise AssertionError(
                    f'Settings in {settings_file} differ from the '
                    'current settings!')
        else:
            with open(settings_file, 'w') as f:
                f.write(serialized)

    def _merge(self, feature_df_files, outdir):
        '''Concatenate feature files into `merged_features_df.parquet`.

        :param feature_df_files: parquet files to merge (str or Path); the
            merged file itself is skipped when it is part of the list
        :param outdir: directory in which the merged file is written
        '''

        outfile = str(Path(outdir) / 'merged_features_df.parquet')
        merged_path = Path(outfile)
        # make sure the merged df is not part of dfs to merge; compared as
        # Path objects so that both str and Path inputs are recognised
        dfs = [pd.read_parquet(f) for f in feature_df_files
               if Path(f) != merged_path]

        # pd.concat refuses an empty list, so stop here if nothing is left
        if not dfs:
            logger.warning('No feature DFs to merge, skipping merge!')
            return

        logger.info(f'Merging {len(dfs)} feature DFs ...')
        merged_df = pd.concat(dfs)
        merged_df.index = list(range(len(merged_df)))

        self.save_features(merged_df, outfile)
        logger.success(('Successfully merged features of '
                        f'{len(merged_df)} samples!'))

    def get_processing_tasks(self):
        # One processing task per ref_id

        # check for which ref ids we have a database.json file...
        db_paths = glob.glob(
            str(self.extractions_dir / '*' / 'database' / 'database.json'))
        if len(db_paths) == 0:
            tasks = None
        else:
            logger.info(f'Found {len(db_paths)} processing tasks ...')
            tasks = [Path(x).parents[1].name for x in db_paths]

        # filter tasks if necessary
        if self.filter_tasks is not None:
            tasks = [t for t in tasks if t in self.filter_tasks]

        return tasks

    def run(self, merge: bool = True):
        '''Run all processing tasks one after the other.

        Saves the settings, calls `run_task` for every task and optionally
        merges all feature files of the output directory.

        :param merge: merge the feature files into one file at the end
        '''
        # Create the output directory
        self.features_dir.mkdir(exist_ok=True, parents=True)

        # Get processing tasks
        self._tasks = self.get_processing_tasks()
        if self._tasks is None:
            logger.info('No tasks to process, aborting!')
            return

        # Settings shared by the features dir and the final dir
        general_settings = dict(
            start_month=self.start_month,
            end_month=self.end_month,
            nrpixels_per_class=self.nrpixels,
            erode_fields=self.erode_fields,
            segmentation=self.segmentation,
            segm_settings=self.segm_settings,
        )
        base_settings = dict(
            general=general_settings,
            preprocessing=self.preprocessing_settings,
            feature=self.feature_settings,
            collection=self.collection_settings,
        )
        self._save_settings(self.features_dir, **base_settings)

        # With post-processing, the settings also go to the final outdir
        if self.postprocessing_settings is not None:
            tag = self.postprocessing_settings.get('tag')
            self.final_dir = self.features_dir / f'Postprocess_{tag}'
            self.final_dir.mkdir(exist_ok=True, parents=True)
            self._save_settings(
                self.final_dir,
                **base_settings,
                postprocessing=self.postprocessing_settings)

        # Process tasks sequentially
        nr_tasks = len(self._tasks)
        task_outputs = []
        for task_nr, task in enumerate(self._tasks, start=1):
            logger.info((f'Launching task: {task} '
                         f'({task_nr}/{nr_tasks})'))
            task_outputs.append(self.run_task(task))

            # In debug mode, stop after the second task
            if self._debug and task_nr == 2:
                logger.debug('Interrupting loop in debug mode.')
                break

        # Drop failed tasks; stop if nothing is left or no merge is wanted
        feature_df_files = [out for out in task_outputs if out is not None]
        if not feature_df_files or not merge:
            return

        outdir = Path(feature_df_files[0]).parent
        # make sure to merge all feature dataframes in the folder!
        all_feature_files = glob.glob(
            str(outdir / '*_features_df.parquet'))
        self._merge(all_feature_files, outdir)

    def run_task(self, taskname: str):
        '''Compute, save and optionally post-process the features of a task.

        Samples already present in an existing output file are skipped,
        unless overwriting is enabled: then the file is deleted and all
        samples are processed again.

        :param taskname: name of the task directory inside
            `extractions_dir`; its first `_`-separated token is parsed as
            the year
        :return: path of the written features file (the post-processed
            file when post-processing is configured), or None when no
            features are available for this task
        :raises FileNotFoundError: if the task has no `database.json`
        '''

        # Create the output directory (parents included, so that this
        # method also works when called without `run`)
        self.features_dir.mkdir(exist_ok=True, parents=True)

        # Construct outfile path
        outfile = Path(self.features_dir) / f'{taskname}_features_df.parquet'

        # Prepare df in the correct format for feature computation
        db_file = (self.extractions_dir / taskname / 'database' /
                   'database.json')
        if not db_file.is_file():
            raise FileNotFoundError(
                f'Database json {db_file} not found, cannot continue!')
        task_df, task_db = self.load_database(db_file)

        # if outfile already exists and overwrite False,
        # check which sampleIDs have been processed and remove from db and df
        process = True
        features_df = None
        if outfile.is_file():
            if not self._overwrite:
                features_df = pd.read_parquet(outfile)
                processed = list(features_df.sampleID.unique())
                logger.info(f'Found {len(processed)} sampleIDs '
                            'which were already processed!')
                task_db = task_db.loc[~task_db[self.idattr].isin(processed)]
                task_df = task_df.loc[~task_df[self.idattr].isin(processed)]
                if len(task_db) == 0:
                    logger.info('No more samples, skip processing!')
                    process = False
                else:
                    logger.info(f'{len(task_db)} remaining samples to process')
            else:
                # if we need to overwrite, delete existing!
                logger.info('Features df already existed, deleting!')
                outfile.unlink()

        # start processing
        if process:
            # infer the year of the dataset from task
            year = int(taskname.split('_')[0])

            # Get processing dates
            start_date, end_date = get_processing_dates(self.start_month,
                                                        self.end_month,
                                                        year)
            logger.info(f'Start date: {start_date}; '
                        f'End date: {end_date}')

            # Create training collections
            collection_settings = deepcopy(self.collection_settings)
            collections = self.create_collections(task_df, collection_settings)

            # Get local copies of the settings
            preproc_settings = deepcopy(self.preprocessing_settings)
            feature_settings = deepcopy(self.feature_settings)
            segmentation = deepcopy(self.segmentation)
            segm_settings = deepcopy(self.segm_settings)
            erode_fields = deepcopy(self.erode_fields)
            nrpixels = deepcopy(self.nrpixels)

            # Impute processing dates into preprocessing settings
            if preproc_settings is not None:
                for sensor in preproc_settings.keys():
                    preproc_settings[sensor]['composite']['start'] = start_date  # NOQA
                    preproc_settings[sensor]['composite']['end'] = end_date

            # Initialize FeatureComputer
            featurecomputer = self.featcomputer(
                feature_settings,
                start_date,
                end_date,
                preproc_settings,
                nrpixels=nrpixels,
                segmentation=segmentation,
                segm_settings=segm_settings,
                erode_fields=erode_fields,
                max_gap=collection_settings.get('max_gap'))

            # Run feature computation
            if self._sc is None:
                logger.info('Running feature computation in serial.')
                features = self._run_task_local(collections,
                                                task_db,
                                                featurecomputer)

            else:
                logger.info('Running feature computation on executors.')
                features = self._run_task_spark(collections,
                                                task_db,
                                                featurecomputer)

            # Translate the labels back to strings
            # and merge with previously computed features
            if features is not None:
                logger.info('Converting labels to strings...')
                features['LABEL'] = labels_to_str(
                    list(features['LABEL'].values))
                if features_df is not None:
                    logger.info(
                        'Merging features with previous results for this task...')
                    features_df = pd.concat([features_df, features], axis=0)
                    features_df.index = list(range(len(features_df)))
                else:
                    features_df = features.copy()
            else:
                logger.warning('Could not compute new features for this task!')

        # Neither new nor previously computed features: no file has been
        # written, so report the task as failed instead of returning the
        # path of a file that does not exist.
        if features_df is None:
            return None

        # Save the result
        self.save_features(features_df, outfile)

        # postprocessing of features...
        if self.postprocessing_settings is not None:
            outfile = self._postprocessing(features_df, taskname)

        return outfile

    def _run_task_local(self, collections, task_db,
                        featurecomputer):
        '''Compute the features of all samples in `task_db` locally.

        :param collections: passed through to `get_sample_features`
        :param task_db: table whose `self.idattr` column lists the samples
        :param featurecomputer: passed through to `get_sample_features`
        :return: the per-sample results concatenated into one DataFrame,
            or None when no sample returned a result
        '''
        def get_feat_df(sampleID):
            '''Wrapper function to allow to use run_parallel
            '''
            return get_sample_features(sampleID,
                                       collections,
                                       task_db,
                                       featurecomputer,
                                       self._debug,
                                       self.labelattr,
                                       self.idattr)

        sampleIDs = task_db[self.idattr].to_list()

        if self._debug:
            logger.info('Debugging mode: limiting ourselves to 5 samples')
            sampleIDs = sampleIDs[:5]

        outputs = run_parallel(get_feat_df,
                               sampleIDs,
                               max_workers=MAX_WORKERS)

        # Samples without a result come back as None: drop those
        valid_outputs = [out for out in outputs if out is not None]
        if not valid_outputs:
            return None

        return pd.concat(valid_outputs, ignore_index=True)

    def _run_task_spark(self, collections,
                        task_db, featurecomputer):
        '''Compute the features of all samples in `task_db` on Spark.

        :param collections: passed through to `get_sample_features`
        :param task_db: table whose `self.idattr` column lists the samples
        :param featurecomputer: passed through to `get_sample_features`
        :return: the per-sample results concatenated into one DataFrame,
            or None when there is nothing to process or nothing is left
            after dropping failed samples
        '''

        sampleIDs = task_db[self.idattr].to_list()
        # Plain locals for the lambda below (presumably so that it does not
        # need to capture `self` -- confirm before changing).
        debug = self._debug
        labelattr = self.labelattr
        idattr = self.idattr

        if debug:
            sampleIDs = sampleIDs[:5]

        # One slice per sample is requested below, so an empty list would
        # ask Spark for zero slices.
        if not sampleIDs:
            logger.warning('No samples to process -> aborting job')
            return None

        # Collect once: the RDD is not cached, so testing isEmpty() first
        # and collecting afterwards would run the computation twice.
        results = self._sc.parallelize(sampleIDs,
                                       len(sampleIDs)).map(
            lambda x: get_sample_features(
                x, collections, task_db,
                featurecomputer, debug,
                labelattr, idattr)
        ).filter(lambda t: t is not None).collect()

        # Check if we actually have something left
        if not results:
            logger.warning('No samples left in RDD -> aborting job')
            return None

        return pd.concat(results, ignore_index=True)

    def _postprocessing(self, df, taskname):
        '''Apply the configured post-processing steps and save the result.

        The steps are driven by `self.postprocessing_settings` and run in
        this order: time step selection, label translation, selection of
        training classes, filtering and outlier removal. Steps without a
        setting are skipped.

        :param df: features of the task
        :param taskname: name of the task, used in the output file name
        :return: path of the post-processed parquet file
        '''

        logger.info('START POST-PROCESSING')

        # `final_dir` is normally prepared by `run`; derive it in the same
        # way here so that `run_task` also works when called on its own
        final_dir = getattr(self, 'final_dir', None)
        if final_dir is None:
            tag = self.postprocessing_settings.get('tag')
            final_dir = self.features_dir / f'Postprocess_{tag}'
            self.final_dir = final_dir
        final_dir.mkdir(exist_ok=True, parents=True)

        # remove post-processed features if they already exist
        # as post-processing always needs to be done from scratch
        # (e.g. outlier detection)
        outfile = final_dir / f'{taskname}_features_df.parquet'
        if outfile.is_file():
            logger.info('Features df already existed, deleting!')
            outfile.unlink()

        # if in debugging mode, limit ourselves to max 100 samples
        if self._debug and len(df) > 100:
            logger.info('Limiting to 100 samples in debug mode!')
            df = df.iloc[0:100]

        # Step one: select tsteps
        tstep_selection = self.postprocessing_settings.get(
            'tstep_selection', None)
        if tstep_selection is not None:
            month_start = tstep_selection[0]
            month_end = tstep_selection[1]
            df = select_tsteps(df, month_start, month_end)

        # Step two: translate crop type labels to relevant classes
        translation_key = self.postprocessing_settings.get(
            'translation_key', None)
        if translation_key is not None:
            df = translate_labels(df, translation_key)

        # Step three: Selection of training classes
        classes_key = self.postprocessing_settings.get(
            'training_classes', None)
        if classes_key is not None:
            df = select_training_classes(df, classes_key)

        # Step four: apply filters to get rid of obviously
        # erroneous training data
        filters = self.postprocessing_settings.get(
            'filters', None)
        if filters is not None:
            df = filter_df(df, filters)

        # Step five: additional outlier detection
        outlier_detection = self.postprocessing_settings.get(
            'outlier_detection', False)
        if outlier_detection:
            outlier_inputs = self.postprocessing_settings.get(
                'outlier_inputs')
            df = remove_outliers(df, outlier_inputs)

        # save features to file
        df.to_parquet(outfile)
        logger.success(f'Final result written to: {outfile}!')

        return outfile


class GEE_Pipeline(Pipeline):
    '''Pipeline using `label` as label attribute and
    `GEETrainingFeaturesComputer` as feature computer.
    '''

    def __init__(self, *args, **kwargs):
        '''Forward all arguments to `Pipeline` and set the GEE specifics.
        '''
        super().__init__(*args, **kwargs)

        self.labelattr = 'label'
        self.featcomputer = GEETrainingFeaturesComputer

    def load_database(self, file: Path):
        '''Read the sample database of a task.

        :param file: database file, read with `gpd.read_file`
        :return: tuple `(df, db)`: a column subset of the database and
            the full database
        '''
        logger.info(f'Loading database file: {file}')
        db = gpd.read_file(file)
        logger.info(f'Database successfully opened. Found {len(db)} samples.')

        # Columns making up the SATIO format
        # TODO: maybe not all of these are needed!
        satio_columns = [self.idattr, 'tile', 'epsg', 'path',
                         'ref_id', 'year', 'start_date', 'end_date',
                         'bounds']

        return db[satio_columns], db

    def create_collections(self, df, settings=None):
        '''Build the dict of collections used for feature computation.

        :param df: samples table; its `self.idattr` column is renamed to
            `location_id` when the two names differ
        :param settings: optional dict; only its `S1_orbits` entry is read
        :return: dict with the keys LABELS, S2, S1, METEO, DEM, WorldCover
        '''
        if self.idattr != 'location_id':
            df = df.rename(columns={self.idattr: 'location_id'})

        s1_orbits = (None if settings is None
                     else settings.get('S1_orbits', None))

        # Collections we always need
        return {
            'LABELS': PatchLabelsTrainingCollection(
                df, dataformat='worldcereal'),
            'S2': L2ATrainingCollection(
                df, dataformat='worldcereal'),
            'S1': SIGMA0HRLTrainingCollection(
                df, dataformat='worldcereal', orbits=s1_orbits),
            'METEO': AgERA5TrainingCollection(
                df, dataformat='worldcereal'),
            'DEM': DEMCollection(
                folder='/data/MEP/DEM/COP-DEM_GLO-30_DTED/S2grid_20m'),
            'WorldCover': WorldCoverCollection(
                folder='/data/worldcereal/auxdata/WORLDCOVER/2020/'),
        }


class OpenEO_Pipeline(Pipeline):
    '''Pipeline using `CT_fin` as label attribute and
    `OpenEOTrainingFeaturesComputer` as feature computer.
    '''

    def __init__(self, *args, **kwargs):
        '''Forward all arguments to `Pipeline` and set the OpenEO specifics.
        '''
        super().__init__(*args, **kwargs)
        self.labelattr = 'CT_fin'
        self.featcomputer = OpenEOTrainingFeaturesComputer

    def load_database(self, file: Path):
        '''Read the sample database of a task.

        :param file: database file, read with `gpd.read_file`
        :return: tuple `(df, db)`: a column subset of the database and
            the full database
        '''
        logger.info(f'Loading database file: {file}')
        db = gpd.read_file(file)
        logger.info(f'Database successfully opened. Found {len(db)} samples.')

        # Columns making up the SATIO format
        satio_columns = [self.idattr, 'samplefile', 'data_path',
                         'validityTi', 'start_date', 'end_date']

        return db[satio_columns], db

    def create_collections(self, df, settings=None):
        '''Build the dict of collections: a single `OpenEO` entry.

        :param df: samples table; its `self.idattr` column is renamed to
            `location_id` when the two names differ
        :param settings: accepted for interface compatibility, not used
        :return: dict with the key OpenEO
        '''
        if self.idattr != 'location_id':
            df = df.rename(columns={self.idattr: 'location_id'})

        return {'OpenEO': OpenEOTrainingCollection(df)}
