from typing import Dict, List
from pathlib import Path
import json
from copy import deepcopy
import glob

import xarray as xr
import numpy as np
import geopandas as gpd
from loguru import logger
import pandas as pd
import geopandas as gpd
from satio.features import Features
from satio.timeseries import Timeseries
from satio.utils import run_parallel
from satio.collections import (TrainingCollection,
                               L2ATrainingCollection,
                               PatchLabelsTrainingCollection,
                               SIGMA0TrainingCollection,
                               AgERA5TrainingCollection,
                               DEMCollection,
                               WorldCoverCollection)

from cropclass.utils.seasons import get_processing_dates


# Number of workers passed to run_parallel in Pipeline._run_task_local.
MAX_WORKERS = 1
# Default maximum tolerated gap (in days) used by the time series gap
# checks that run before feature computation.
MAX_GAP = 60


def get_sample_features(sampleID,
                        collections,
                        database,
                        featurecomputer,
                        debug,
                        labelattr,
                        idattr):
    """Main method to get the features for one sample.

    The method is defined outside the pipeline class to avoid pickling
    and sparkcontext issues when running on spark.

    Args:
        sampleID: value of ``idattr`` identifying the sample.
        collections: collections forwarded to
            ``featurecomputer.get_features``.
        database: sample database forwarded to
            ``featurecomputer.get_features``.
        featurecomputer: object whose ``get_features(collections,
            sampleID, database, labelattr, idattr)`` does the work.
        debug: when True, exceptions are re-raised instead of logged.
        labelattr: label attribute name, forwarded.
        idattr: sample id attribute name, forwarded and used in logging.

    Returns:
        Whatever ``featurecomputer.get_features`` returns, or ``None`` if
        it raised and ``debug`` is False.
    """

    logger.info(f'Starting feature computation for {idattr}: {sampleID}')

    try:
        features_df = featurecomputer.get_features(collections,
                                                   sampleID,
                                                   database,
                                                   labelattr,
                                                   idattr)
    except Exception as e:
        if debug:
            # Bare raise keeps the original traceback untouched.
            raise
        # Best-effort mode: a failing sample must not abort the whole task.
        logger.error(f"Error processing {idattr} = {sampleID}: "
                     f"{e}.")
        return None

    logger.info(('Feature computation successful '
                 f'for {idattr}: {sampleID}'))

    return features_df


class GapError(Exception):
    """Raised when an input time series is too incomplete.

    This covers a missing start or end of the period, or an internal gap
    that exceeds the allowed maximum number of days.
    """


class FeaturesComputer():
    '''Class that computes classification features from per-sensor inputs.

    Workflow of :meth:`get_features`:

    1. check that every input ``xr.DataArray`` has the dimensions
       ``bands``, ``t``, ``x`` and ``y``;
    2. convert the arrays into satio ``Timeseries`` objects;
    3. check each time series for gaps at the start, at the end and
       between consecutive observations (raises :class:`GapError`);
    4. optionally preprocess each time series;
    5. compute the features per sensor and merge them.
    '''

    def __init__(self,
                 feature_settings,
                 start_date,
                 end_date,
                 preprocessing_settings=None,
                 resolution=10,
                 max_gap=MAX_GAP):
        """
        Args:
            feature_settings: dict keyed by sensor name; each value is
                passed as ``features_meta`` to
                ``Timeseries.features_from_dict``.
            start_date: start of the processing period (any value accepted
                by ``pd.to_datetime``).
            end_date: end of the processing period.
            preprocessing_settings: optional dict keyed by sensor name
                with the settings consumed by ``_preprocess_sensor``.
                When ``None``, preprocessing is skipped entirely.
            resolution: resolution passed to ``features_from_dict``.
            max_gap: maximum tolerated gap in days for the gap checks.
        """
        self.feature_settings = feature_settings
        self.preprocessing_settings = preprocessing_settings
        self.start_date = start_date
        self.end_date = end_date
        self.resolution = resolution
        self.max_gap = max_gap

    def get_features(self, inputs: Dict[str, xr.DataArray]):
        """Compute features for all sensors in ``inputs``.

        Args:
            inputs: dict mapping sensor name to an ``xr.DataArray`` with
                the dimensions ``bands``, ``t``, ``x`` and ``y``.

        Returns:
            satio ``Features`` whose names are prefixed with the sensor.

        Raises:
            KeyError: if an input lacks one of the required dimensions.
            GapError: if a time series is too incomplete.
            RuntimeError: if a sensor has no feature settings.
        """

        logger.info('-' * 50)
        logger.info('STARTING FEATURE COMPUTATION')
        logger.info('-' * 50)

        # Stored on the instance: preprocess() and _compute_features()
        # iterate over this list.
        self.sensors = list(inputs.keys())
        logger.info(f'Got data for following sensors: {self.sensors}')

        # Do check on inputs
        self._check_inputs(inputs)

        # Transform to satio Timeseries
        ts = self._to_satio_timeseries(inputs)

        # Perform sensor-specific gap checks
        self._perform_gap_checks(ts,
                                 self.start_date,
                                 self.end_date,
                                 self.max_gap)

        # Preprocess the timeseries
        if self.preprocessing_settings is not None:
            ts = self.preprocess(ts)

        # Compute the features
        return self._compute_features(ts)

    def preprocess(self, ts: Dict[str, Timeseries]):
        """Preprocess each sensor time series with its own settings.

        Sensors without an entry in ``preprocessing_settings`` are passed
        through unchanged (with a warning).
        """
        logger.info('---- PREPROCESSING INPUTS START')

        ts_proc = {}

        for sensor in self.sensors:
            if sensor not in self.preprocessing_settings:
                logger.warning(('No preprocessing settings found for: '
                                f'{sensor} -> skipping!'))
                ts_proc[sensor] = ts[sensor]
            else:
                logger.info(f'Preprocessing: {sensor}')
                ts_proc[sensor] = self._preprocess_sensor(
                    ts[sensor],
                    self.preprocessing_settings[sensor])

        logger.info('---- PREPROCESSING INPUTS DONE')

        return ts_proc

    def _compute_features(self, ts: Dict[str, Timeseries]):
        """Compute per-sensor features and merge them into one object.

        Raises:
            RuntimeError: if a sensor has no entry in feature_settings.
        """
        logger.info('---- COMPUTING FEATURES START')

        features = {}

        for sensor in self.sensors:
            if sensor not in self.feature_settings:
                raise RuntimeError(
                    ('No feature settings found for: '
                     f'{sensor} -> cannot compute features!'))
            logger.info(f'Feature computation: {sensor}')
            features[sensor] = self._compute_features_sensor(
                sensor,
                ts[sensor],
                self.feature_settings[sensor])

        # If there is more than one sensor we need to merge the features
        features = Features.from_features(*features.values())

        logger.info((f'Successfully extracted '
                     f'{len(features.names)} features.'))
        logger.info('---- COMPUTING FEATURES DONE')

        return features

    def _preprocess_sensor(self, ts: Timeseries, settings: Dict):
        """Apply the preprocessing steps described by ``settings``.

        Recognised keys, applied in this order:

        - ``dtype`` (required): dtype the data is cast to;
        - ``bands``: list of bands to keep;
        - ``pre_func``: callable applied to the data array first;
        - ``composite``: kwargs for ``Timeseries.composite``;
        - ``interpolate``: if truthy, ``Timeseries.interpolate`` is called;
        - ``post_func``: callable applied to the data array last.
        """

        # Cast to proper dtype
        logger.info(f'Casting to dtype: {settings["dtype"]}')
        ts.data = ts.data.astype(settings['dtype'])

        if settings.get('bands', None) is not None:
            # Got a list of bands: make the selection first
            band_list = settings.get('bands')
            logger.info(f'Band selection: {band_list}')
            ts = ts.select_bands(band_list)

        if settings.get('pre_func', None) is not None:
            # Function to apply before the actual preprocessing
            pre_func = settings['pre_func']
            logger.info(f'Applying pre_func: {pre_func}')
            ts.data = pre_func(ts.data)

        if 'composite' in settings:
            logger.info('Compositing ...')
            ts = ts.composite(**settings['composite'])

        if settings.get('interpolate', False):
            logger.info('Interpolating ...')
            ts = ts.interpolate()

        if settings.get('post_func', None) is not None:
            # Function to apply after the preprocessing
            post_func = settings['post_func']
            logger.info(f'Applying post_func: {post_func}')
            ts.data = post_func(ts.data)

        return ts

    def _compute_features_sensor(self, sensor: str, ts: Timeseries,
                                 settings: Dict):
        """Compute features for one sensor and prefix names with it."""

        features = ts.features_from_dict(
            self.resolution,
            features_meta=settings,
            chunk_size=None)

        # Prefix with the source so names stay unique after merging
        features.names = [sensor + '-' + f for f in features.names]

        return features

    @staticmethod
    def _check_inputs(inputs: Dict[str, xr.DataArray]):
        """Raise KeyError if an input lacks a required dimension."""
        req_dims = ['bands', 't', 'x', 'y']

        for sensor in inputs.keys():
            for dim in req_dims:
                if dim not in inputs[sensor].dims:
                    raise KeyError(
                        (f'Dimension `{dim}` not found '
                         f'in input DataArray of sensor `{sensor}`'))

    @staticmethod
    def _to_satio_timeseries(inputs: Dict[str, xr.DataArray]) -> Dict[str, Timeseries]:  # NOQA
        '''Method to transform input DataArray to satio Timeseries
        '''

        logger.info('Transforming xr.DataArray into satio.Timeseries ...')

        satio_ts = {}

        for sensor in inputs.keys():

            da = inputs[sensor]

            # Make sure we have a controlled dimension order
            da = da.transpose('bands', 't', 'x', 'y')

            # Transform the DataArray to satio Timeseries
            ts = Timeseries(
                data=da.values,
                timestamps=list(da.coords['t'].values),
                bands=list(da.coords['bands'].values)
            )

            satio_ts[sensor] = ts

        return satio_ts

    @staticmethod
    def _perform_gap_checks(ts: Dict[str, Timeseries],
                            start_date: str,
                            end_date: str,
                            max_gap: int):
        """Fail when a sensor time series is too incomplete.

        For every sensor three gaps (in whole days) are computed:

        - ``gapstart``: days between ``start_date`` and the first
          observation (0 if the series starts before ``start_date``);
        - ``gapend``: days between the last observation and ``end_date``
          (0 if the series ends after ``end_date``);
        - ``maxgap``: largest interval between consecutive observations
          (0 when there are fewer than two observations).

        Raises:
            GapError: if any of these gaps exceeds ``max_gap``.
        """

        def raise_gap_failure(gap, gapkind, threshold, sensor):
            msg = (f'Incomplete timeseries `{sensor}`: '
                   f'got a value of {gap} days for `{gapkind}` '
                   f'which exceeds the threshold of {threshold}.')
            logger.error(msg)
            raise GapError(msg)

        start_date = pd.to_datetime(start_date)
        end_date = pd.to_datetime(end_date)

        for sensor, ts_sensor in ts.items():
            # Sort so the first/last/diff computations do not depend on
            # the order in which timestamps are stored.
            timestamps = pd.Series(
                pd.to_datetime(ts_sensor.timestamps)
            ).sort_values().reset_index(drop=True)
            ts_start = timestamps.iloc[0]
            ts_end = timestamps.iloc[-1]

            # Gap between requested start and first observation
            gapstart = max((ts_start - start_date).days, 0)

            # Gap between last observation and requested end
            gapend = max((end_date - ts_end).days, 0)

            # Largest interval between consecutive observations; a single
            # observation has no internal gap (diff() would give NaT).
            if len(timestamps) > 1:
                maxgap = timestamps.diff().max().days
            else:
                maxgap = 0

            # Report on completeness
            logger.info(f'{sensor} gap before first obs (days): {gapstart}')
            logger.info(f'{sensor} gap after last obs (days): {gapend}')
            logger.info(f'{sensor} largest gap (days): {maxgap}')

            # Fail processing if timeseries is incomplete
            if gapstart > max_gap:
                raise_gap_failure(gapstart, 'gapstart', max_gap, sensor)
            if gapend > max_gap:
                raise_gap_failure(gapend, 'gapend', max_gap, sensor)
            if maxgap > max_gap:
                raise_gap_failure(maxgap, 'maxgap', max_gap, sensor)


class TrainingFeaturesComputer(FeaturesComputer):
    '''FeaturesComputer that starts from satio training collections.

    For a single sample, :meth:`get_features` does the following:

    - filters the collections to the sample's location;
    - loads the labels and the S2/S1/METEO inputs;
    - computes features with :meth:`FeaturesComputer.get_features`;
    - adds labels, sample attributes and DEM features;
    - samples a subset of the labelled pixels.
    '''

    # Folder of the DEM collection used by `add_dem`.
    dem_folder = '/data/MEP/DEM/COP-DEM_GLO-30_DTED/S2grid_20m'

    def __init__(self,
                 feature_settings,
                 start_date,
                 end_date,
                 preprocessing_settings=None,
                 resolution=10,
                 max_gap=MAX_GAP,
                 nrpixels=1):
        """Same arguments as ``FeaturesComputer`` plus:

        Args:
            nrpixels: maximum number of pixels sampled per label class
                (see :meth:`sample_pixels`).
        """

        super().__init__(feature_settings,
                         start_date,
                         end_date,
                         preprocessing_settings=preprocessing_settings,
                         resolution=resolution,
                         max_gap=max_gap)

        self.nrpixels = nrpixels

    def get_features(self,
                     collections: Dict[str, TrainingCollection],
                     sampleID: str,
                     database,
                     labelattr,
                     idattr):
        """Compute the training features of one sample.

        Args:
            collections: satio collections with keys 'S2', 'S1', 'METEO'
                and 'LABELS' (optionally also 'DEM' and 'WorldCover').
            sampleID: value of ``idattr`` identifying the sample.
            database: DataFrame with at least the columns ``idattr``,
                'epsg', 'bounds', 'year', 'ref_id', 'round_lat' and
                'round_lon'.
            labelattr: forwarded to :meth:`get_labels`.
            idattr: name of the sample id column in ``database``.

        Returns:
            ``pd.DataFrame`` with the sampled pixels, or ``None`` when the
            sample has no non-zero label pixel.
        """

        bounds, epsg = self._get_location_bounds(sampleID, database, idattr)
        collections = self._filter_collections(collections, sampleID,
                                               bounds, epsg)

        # Get labels first: no point loading inputs for unlabelled samples
        labels = self.get_labels(collections['LABELS'],
                                 labelattr)

        if not np.any(labels.data != 0):
            # Sample has no crop pixel!
            logger.info('No crop type information in this sample: skipping')
            return None

        # In TrainingFeatures we start from satio collections
        # which need to be transformed to xarray.DataArrays
        inputs = self._collections_to_inputs(collections, bounds)

        # Get features from main method
        features = super().get_features(inputs)

        # Add the labels
        features = features.merge(labels)

        # Add several attributes
        features = self.add_attributes(database, sampleID,
                                       idattr, features)

        # Add DEM features
        features = self.add_dem(inputs, bounds, epsg, features)

        # Sample from available pixels
        return self.sample_pixels(features.df, self.nrpixels)

    def add_attributes(self, database, sampleID, idattr,
                       features: Features, attributes: List = None):
        """Attach sample-level attributes to ``features``.

        Adds ``idattr`` (the sample id), 'year', 'ref_id', 'lat' and
        'lon', taken from the first ``database`` row whose ``idattr``
        equals ``sampleID``.

        Raises:
            KeyError: if ``sampleID`` is not present in ``database``.
            ValueError: if an attribute that is already present on
                ``features`` has a different value.
        """

        def _add_attr(features, attr_value, attr_name):
            if attr_name not in features.attrs:
                return features.add_attribute(attr_value, attr_name)
            # Already present: must agree with the database value.
            # (Explicit raise instead of assert, which -O would strip.)
            if not attr_value == features.attrs[attr_name]:
                raise ValueError(
                    f'Attribute `{attr_name}` mismatch: got {attr_value} '
                    f'but features already have '
                    f'{features.attrs[attr_name]}.')
            return features

        # NOTE(review): `attributes` is only logged; the set of attributes
        # added below is fixed and does not follow this list.
        attributes = attributes or ['label', 'year', 'ref_id']

        logger.info('---- EXTRACTING ATTRIBUTES START')
        logger.info(f'Looking for attributes: {attributes}')

        matches = database[database[idattr] == sampleID]
        if matches.empty:
            raise KeyError(sampleID)
        # First matching row, consistent with _get_location_bounds; an
        # index lookup would return a DataFrame for duplicated ids.
        sample = matches.iloc[0]

        # Add attributes
        features = _add_attr(features, sampleID, idattr)
        features = _add_attr(features, int(sample['year']), 'year')
        features = _add_attr(features, sample['ref_id'], 'ref_id')

        # Add lat/lon as attributes
        lat, lon = self._get_lat_lon(sample)
        features = features.add_attribute(lat, 'lat')
        features = features.add_attribute(lon, 'lon')

        logger.info((f'Successfully extracted '
                     f'{len(features.attrs)} attributes.'))
        logger.info('---- EXTRACTING ATTRIBUTES DONE')

        return features

    def add_dem(self, inputs, bounds, epsg, features):
        """Merge the 'DEM-alt-20m' and 'DEM-slo-20m' features into features.

        These are presumably altitude and slope. They are computed from
        the DEM collection in ``dem_folder`` filtered on ``bounds``/
        ``epsg``. ``inputs`` is unused and kept for interface
        compatibility.
        """
        logger.info('Computing DEM features for sample ...')

        # Create and filter DEM collection
        demcoll = DEMCollection(self.dem_folder)
        demcoll = demcoll.filter_bounds(bounds, epsg)

        # Compute DEM features
        dem_feat = Features.from_dem(demcoll).select(
            ['DEM-alt-20m', 'DEM-slo-20m'])

        # Merge with other features
        return features.merge(dem_feat)

    def get_labels(self, labels_collection, labelattr):
        """Load the 'label' layer as a uint64 Features named 'LABEL'.

        ``labelattr`` is currently unused.
        """

        labelsdata = labels_collection.load()

        return Features(data=labelsdata['label'], names=['LABEL'],
                        dtype='uint64')

    @staticmethod
    def _get_lat_lon(sample):
        """Return the sample's ('round_lat', 'round_lon') values."""
        return sample['round_lat'], sample['round_lon']

    def _collections_to_inputs(self, collections, bounds):
        """Load S2, S1 and METEO into a dict of xr.DataArrays."""

        inputs = dict()
        inputs['S2'] = self._load_S2(collections['S2'])
        inputs['S1'] = self._load_S1(collections['S1'])
        inputs['METEO'] = self._load_METEO(collections['METEO'], bounds)

        return inputs

    @staticmethod
    def _get_location_bounds(location_id, trainingdb, idattr):
        """Return (bounds as np.array, epsg as int) of the first match."""
        trainingdb_row = trainingdb[
            trainingdb[idattr] == location_id].iloc[0]
        epsg = int(trainingdb_row['epsg'])
        bounds = trainingdb_row['bounds']

        if isinstance(bounds, str):
            # SECURITY NOTE(review): eval() executes arbitrary code stored
            # in the database. If bounds are always plain list/tuple
            # literals, ast.literal_eval would be a safe replacement.
            bounds = eval(bounds)
        bounds = np.array(bounds)

        return bounds, epsg

    @staticmethod
    def _filter_collections(collections, location_id, bounds, epsg):
        """Filter location-based collections on id, raster ones on bounds.

        Returns a shallow copy of ``collections``; the input dict is not
        modified.
        """

        new_colls = collections.copy()

        for s in new_colls.keys():
            if s in ['S2', 'S1', 'METEO', 'LABELS']:
                new_colls[s] = new_colls[s].filter_location(location_id)
            elif s in ['DEM', 'WorldCover']:
                new_colls[s] = new_colls[s].filter_bounds(bounds, epsg)

        return new_colls

    @staticmethod
    def _load_S2(coll):
        """Load S2 bands, mask them with SCL and set zeros to NaN."""
        from cropclass.utils.masking import scl_mask

        logger.info('Loading S2 bands ...')

        bands_10m = coll.load_timeseries('B02', 'B03', 'B04', 'B08')
        bands_20m = coll.load_timeseries('B05', 'B06', 'B07', 'B11', 'B12')
        bands = bands_10m.merge(bands_20m.upsample())

        scl = coll.load_timeseries('SCL').upsample()
        mask = scl_mask(scl.data)

        bands = bands.mask(mask, drop_nodata=False)
        bands.data = bands.data.astype(np.float32)
        # 0 is treated as nodata
        bands.data[bands.data == 0] = np.nan

        return bands.to_xarray().rename({'time': 't'})

    @staticmethod
    def _load_S1(coll):
        """Load VV/VH (20 m extractions) and upsample them once."""
        # TODO: SWITCH TO 10M EXTRACTIONS AND DATA LOADING!

        logger.info('Loading S1 bands ... TODO: SWITCH TO 10M')

        bands_20m = coll.load_timeseries('VV', 'VH', mask_and_scale=True)
        bands = bands_20m.upsample()
        return bands.to_xarray().rename({'time': 't'})

    @staticmethod
    def _load_METEO(coll, bounds):
        """Load AgERA5 'temperature_mean' and upsample it towards 10 m.

        The target size is the bounds width divided by 10 (assuming the
        bounds are in metres). The data is upsampled ``int(log2(size))``
        times.

        Raises:
            ValueError: if the bounds are narrower than one 10 m pixel.
        """

        logger.info('Loading METEO bands ...')

        bands = coll.load_timeseries('temperature_mean')

        # Get desired output shape
        outdim = int((bounds[2] - bounds[0]) / 10)
        if outdim < 1:
            # log2(0) would be -inf and int(-inf) an OverflowError
            raise ValueError(
                f'Bounds {list(bounds)} are narrower than one 10 m pixel: '
                'cannot determine METEO output size.')
        # NOTE(review): each upsample() is assumed to double the size; if
        # outdim is not a power of two the result is smaller than outdim.
        upsampling = int(np.log2(outdim))
        for _ in range(upsampling):
            bands = bands.upsample()

        return bands.to_xarray().rename({'time': 't'})

    @staticmethod
    def sample_pixels(df: pd.DataFrame,
                      nrpixels: int,
                      label='LABEL',
                      nodata_value=0,
                      seed=1234):
        '''Sample one or more valid pixels from the training dataframe.

        Pixels whose ``label`` equals ``nodata_value`` are invalid. If
        more than ``nrpixels`` valid pixels exist, every label class gets
        ``min(nrpixels, size of smallest class)`` randomly sampled
        pixels; otherwise all valid pixels are returned.

        Returns:
            The subset of ``df`` (selected by position), or ``None`` when
            ``df`` is None or has no valid pixel.
        '''
        if df is None:
            return None

        labeldata = df[label].values

        idxvalid = np.where(labeldata != nodata_value)[0]

        if len(idxvalid) == 0:
            return None
        if len(idxvalid) <= nrpixels:
            return df.iloc[idxvalid, :]

        # DataFrame of valid pixels, indexed by their position in df
        labeldf = pd.DataFrame(index=idxvalid, data=labeldata[idxvalid],
                               columns=['label'])

        # Pixels per class, bounded by the size of the smallest class
        smallest = labeldf['label'].value_counts().min()
        n_per_class = min(nrpixels, int(smallest))

        # Seeded local RNG: reproducible without reseeding NumPy's global
        # random state (draws are identical to seeding np.random first).
        sampled = labeldf.groupby('label').sample(n_per_class,
                                                  random_state=seed)

        return df.iloc[sampled.index.to_list(), :]


# The only behavioral difference from TrainingFeaturesComputer is `_load_S1`:
# the cropclass extractions provide S1 data at 10 m resolution, exactly like
# the other bands, so no upsampling is needed.
class GEETrainingFeaturesComputer(TrainingFeaturesComputer):
    '''TrainingFeaturesComputer for the GEE/cropclass extractions.

    Identical to the parent class except that S1 is loaded without
    upsampling. The constructor is inherited unchanged (same signature,
    including ``nrpixels``); the previous override only duplicated it.
    '''

    @staticmethod
    def _load_S1(coll):
        """Load VV/VH at native resolution (no upsampling)."""
        logger.info('Loading S1 bands ...')
        bands = coll.load_timeseries('VV', 'VH', mask_and_scale=True)
        return bands.to_xarray().rename({'time': 't'})


class Pipeline:
    '''Training feature extraction pipeline.

    One processing task is run per ``ref_id`` in the database. For every
    task the samples are processed with a ``TrainingFeaturesComputer``,
    either serially or on Spark. The result is written to
    ``<features_dir>/<ref_id>_features_df.parquet``, and the task outputs
    can optionally be merged into ``merged_features_df.parquet``.
    '''

    def __init__(self,
                 database: Path,
                 features_dir: Path,
                 start_doy: int,
                 end_doy: int,
                 feature_settings: Dict,
                 preproc_settings: Dict = None,
                 scenario: str = None,
                 overwrite: bool = False,
                 debug: bool = False,
                 sparkcontext=None):
        """
        Args:
            database: database file, read with ``gpd.read_file``.
            features_dir: output directory for the feature files.
            start_doy: start of the period (presumably day of year),
                passed to ``get_processing_dates``.
            end_doy: end of the period, passed to ``get_processing_dates``.
            feature_settings: per-sensor feature settings.
            preproc_settings: optional per-sensor preprocessing settings.
            scenario: optional subdirectory of ``features_dir`` to use.
            overwrite: recompute existing outputs and rewrite settings.
            debug: only process a few tasks/samples and re-raise errors.
            sparkcontext: when given, samples are processed on Spark.
        """
        self.database = database
        # Always a Path so .mkdir() and `/` also work for str arguments.
        self.features_dir = Path(features_dir)
        self.preprocessing_settings = preproc_settings
        self.feature_settings = feature_settings
        self.start_doy = start_doy
        self.end_doy = end_doy
        self._debug = debug
        self._overwrite = overwrite
        self._sc = sparkcontext
        self.labelattr = 'CT'
        self.idattr = 'location_id'

        if scenario is not None:
            self.features_dir = self.features_dir / scenario

    def run(self, merge: bool = True):
        """Run all processing tasks and optionally merge their outputs."""

        # Create the output directory (parents needed for a scenario subdir)
        Path(self.features_dir).mkdir(parents=True, exist_ok=True)

        # Load the database
        self.load_database(self.database)

        # Get processing tasks
        self._tasks = self.get_processing_tasks()

        # Save the settings
        self._save_settings(
            self.features_dir,
            preprocessing_settings=self.preprocessing_settings,
            feature_settings=self.feature_settings
        )

        # Process tasks sequentially
        feature_df_files = []
        for i, task in enumerate(self._tasks):
            logger.info((f'Launching task: {task} '
                         f'({i+1}/{len(self._tasks)})'))
            feature_df_files.append(self.run_task(task))

            if self._debug and i == 1:
                logger.debug('Interrupting loop in debug mode.')
                break

        # Filter out any failed tasks
        feature_df_files = [f for f in feature_df_files if f is not None]

        if merge:
            self._merge(feature_df_files, self.features_dir)

    def run_task(self, taskname: str):
        """Compute and save the features of one task (``ref_id``).

        Returns:
            Path of the output parquet file, or ``None`` if no features
            could be computed.
        """
        features_dir = Path(self.features_dir)

        # Create the output directory
        features_dir.mkdir(parents=True, exist_ok=True)

        # Construct outfile path
        outfile = features_dir / f'{taskname}_features_df.parquet'

        if outfile.is_file() and not self._overwrite:
            logger.info(f'Output file `{outfile}` exists -> skipping')
            return outfile

        # infer the year of the dataset from task
        year = taskname.split('_')[0]

        # Get processing dates
        start_date, end_date = get_processing_dates(self.start_doy,
                                                    self.end_doy,
                                                    year)

        # Create training collections
        task_df = self.df[self.df['ref_id'] == taskname]
        task_db = self.db[self.db['ref_id'] == taskname]
        collections = self.create_collections(task_df)

        # Local copies: the dates below must not leak into self's settings
        preproc_settings = deepcopy(self.preprocessing_settings)
        feature_settings = deepcopy(self.feature_settings)

        # Impute processing dates into the compositing settings; sensors
        # without a 'composite' step are left untouched.
        if preproc_settings is not None:
            for sensor_settings in preproc_settings.values():
                composite = sensor_settings.get('composite')
                if composite is not None:
                    composite['start'] = start_date
                    composite['end'] = end_date

        # Initialize FeatureComputer
        featurecomputer = TrainingFeaturesComputer(feature_settings,
                                                   start_date,
                                                   end_date,
                                                   preproc_settings)

        # Run feature computation
        if self._sc is None:
            logger.info('Running feature computation in serial.')
            features = self._run_task_local(taskname, collections,
                                            task_db,
                                            featurecomputer)
        else:
            logger.info('Running feature computation on executors.')
            features = self._run_task_spark(taskname, collections,
                                            task_db,
                                            featurecomputer)

        # Save the result
        if features is None:
            logger.warning('Could not compute features for this task!')
            return None

        self.save_features(features, outfile)
        return outfile

    def save_features(self, features, outfile):
        """Write a features DataFrame to parquet."""
        logger.info(f'Saving features df to: {outfile}')
        features.to_parquet(outfile)

    def _run_task_local(self, task, collections, task_db,
                        featurecomputer):
        """Compute the task's samples locally.

        Returns:
            The concatenated DataFrame, or ``None`` if nothing succeeded.
        """
        def get_feat_df(sampleID):
            '''Wrapper function to allow to use run_parallel
            '''
            return get_sample_features(sampleID,
                                       collections,
                                       task_db,
                                       featurecomputer,
                                       self._debug,
                                       self.labelattr,
                                       self.idattr)

        sampleIDs = self.db[self.db['ref_id'] == task][self.idattr].to_list()

        if self._debug:
            sampleIDs = sampleIDs[:5]

        if not sampleIDs:
            logger.warning('No samples found for this task.')
            return None

        results = run_parallel(get_feat_df,
                               sampleIDs,
                               max_workers=MAX_WORKERS)
        results = [result for result in results if result is not None]

        # pd.concat raises on an empty list
        if not results:
            logger.warning('No samples left after processing.')
            return None

        return pd.concat(results, ignore_index=True)

    def _run_task_spark(self, task, collections,
                        task_db, featurecomputer):
        """Compute the task's samples on Spark, one partition per sample.

        Returns:
            The concatenated DataFrame, or ``None`` if nothing succeeded.
        """

        sampleIDs = self.db[self.db['ref_id'] == task][self.idattr].to_list()

        # Bind everything the closure needs to plain locals: referencing
        # `self` inside the lambda would make Spark serialize the whole
        # Pipeline, including the (unpicklable) SparkContext.
        debug = self._debug
        labelattr = self.labelattr
        idattr = self.idattr

        if debug:
            sampleIDs = sampleIDs[:5]

        if not sampleIDs:
            logger.warning('No samples left in RDD -> aborting job')
            return None

        rdd = self._sc.parallelize(sampleIDs,
                                   len(sampleIDs)).map(
            lambda x: get_sample_features(
                x, collections, task_db,
                featurecomputer, debug,
                labelattr, idattr)
        ).filter(lambda t: t is not None)

        # Collect once: a separate isEmpty() would trigger the (expensive)
        # feature computation a second time.
        results = rdd.collect()

        if not results:
            logger.warning('No samples left in RDD -> aborting job')
            return None

        return pd.concat(results, ignore_index=True)

    @staticmethod
    def _get_df(features):
        """Return ``features.df``, or ``None`` when features is None."""
        if features is None:
            return None
        return features.df

    def _save_settings(self, outdir: Path, **settings):
        """Write each settings dict to ``<outdir>/<name>.json``.

        Callables are stored by their ``__name__``. If the file already
        exists and ``overwrite`` is False, the stored settings are compared
        with the current ones instead of being rewritten.

        Raises:
            RuntimeError: if existing settings differ from the current ones.
        """

        def _make_serial(d):
            if isinstance(d, dict):
                return {_make_serial(k): _make_serial(v)
                        for k, v in d.items()}
            elif isinstance(d, (list, tuple)):
                return [_make_serial(i) for i in d]
            elif callable(d):
                return d.__name__
            else:
                return d

        logger.info(f'Saving settings to: {outdir}')

        for settings_name, settings_values in settings.items():
            settings_file = Path(outdir) / f'{settings_name}.json'
            # Round-trip through JSON so the comparison sees exactly what
            # json.load returns (tuples -> lists, non-str keys -> str).
            serial = json.loads(json.dumps(_make_serial(settings_values)))
            if settings_file.is_file() and not self._overwrite:
                # Need to check if settings have not changed
                with open(settings_file) as f:
                    old_settings = json.load(f)
                if old_settings != serial:
                    raise RuntimeError(
                        f'Settings in `{settings_file}` differ from the '
                        f'current `{settings_name}`; use overwrite=True '
                        'or a different output directory.')
            else:
                with open(settings_file, 'w') as f:
                    json.dump(serial, f, indent=4)

    def get_processing_tasks(self):
        """Return one processing task per unique ``ref_id``."""
        tasks = list(self.db.ref_id.unique())
        logger.info(f'Found {len(tasks)} processing tasks ...')

        return tasks

    def load_database(self, file: Path):
        """Read the database and store it as ``self.db``/``self.df``.

        Only rows with 'contenttype' '110' or '111' are kept. ``self.df``
        holds the subset of columns used to build the satio collections.
        """
        logger.info(f'Loading database file: {file}')
        db = gpd.read_file(file)
        db = db[db['contenttype'].isin(['110', '111'])]
        logger.info(f'Database successfully opened. Found {len(db)} samples.')

        # Convert into SATIO format
        # TODO: maybe not all of these are needed!
        df = db[[self.idattr, 'tile', 'epsg', 'path', 'split',
                 'ref_id', 'year', 'start_date', 'end_date', 'bounds']]

        self.db = db
        self.df = df

    def create_collections(self, df):
        """Build the satio training collections for the samples in ``df``."""

        collections = {}

        # The collections expect a `location_id` column
        if self.idattr != 'location_id':
            df = df.rename(columns={self.idattr: 'location_id'})

        # Collections we always need
        collections['LABELS'] = PatchLabelsTrainingCollection(
            df, dataformat='worldcereal')
        collections['S2'] = L2ATrainingCollection(
            df, dataformat='worldcereal')
        collections['S1'] = SIGMA0TrainingCollection(
            df, dataformat='worldcereal')
        collections['METEO'] = AgERA5TrainingCollection(
            df, dataformat='worldcereal')
        collections['DEM'] = DEMCollection(
            folder='/data/MEP/DEM/COP-DEM_GLO-30_DTED/S2grid_20m')
        collections['WorldCover'] = WorldCoverCollection(
            folder='/data/worldcereal/auxdata/WORLDCOVER/2020/')

        return collections

    def _load_arrays(self, start_date: str, end_date: str, **inputfiles):
        """Load per-sensor netCDF files and split them per common sample.

        Returns:
            ``{sampleID: {sensor: DataArray}}`` for the sampleIDs present
            in every input file (empty dict when no files are given).
        """

        arrays = {}
        common_sampleIDs = None

        for sensor in inputfiles:
            logger.info(f'Loading {sensor} file: {inputfiles[sensor]}')

            # Open dataset. NOTE: lock=False ensures we can pickle
            # the dataset later on for spark processing!
            ds = xr.open_dataset(inputfiles[sensor], lock=False)
            # NOTE(review): the ids are read from `sampleID` here but the
            # per-sample selection below uses `self.idattr`.
            current_sampleIDs = list(ds.sampleID.values)

            if common_sampleIDs is None:
                common_sampleIDs = current_sampleIDs
            else:
                common_sampleIDs = list(set(
                    current_sampleIDs).intersection(common_sampleIDs))

            # These are time series features: need to add 1D x and y
            # for feature computation to work
            ds = ds.expand_dims({'x': 1, 'y': 1})

            # Filter data to keep only the required time period
            # (+ buffer of 1 month for compositing purposes)
            start_date_buf = pd.to_datetime(start_date) - pd.Timedelta(days=30)
            end_date_buf = pd.to_datetime(end_date) + pd.Timedelta(days=30)
            ds = ds.sel(t=slice(start_date_buf, end_date_buf))

            arrays[sensor] = ds.to_array(dim='bands')

        if common_sampleIDs is None:
            # No input files given: nothing to subset
            return {}

        # All files loaded, now subset on common sampleIDs
        common_sampleIDs = list(set(common_sampleIDs))
        logger.info(f'Found {len(common_sampleIDs)} common sampleIDs.')

        common_arrays = {}

        logger.info('Creating sample-specific arrays ...')
        for sampleID in common_sampleIDs:
            common_arrays[sampleID] = {}

            for sensor in inputfiles:
                self._check_sampleID_uniqueness(arrays[sensor], sampleID)
                common_arrays[sampleID][sensor] = arrays[sensor].sel(
                    feature=arrays[sensor][
                        self.idattr] == sampleID).isel(
                    feature=0)  # Select first one after subsetting

        return common_arrays

    def _merge(self, feature_df_files, outdir):
        """Concatenate per-task parquet files into one merged file."""
        if not feature_df_files:
            # pd.concat would raise on an empty list
            logger.warning('No feature DFs to merge -> skipping merge.')
            return
        logger.info(f'Merging {len(feature_df_files)} feature DFs ...')
        dfs = [pd.read_parquet(f) for f in feature_df_files]
        merged_df = pd.concat(dfs)
        outfile = Path(outdir) / 'merged_features_df.parquet'
        self.save_features(merged_df, outfile)
        logger.success(('Successfully merged features of '
                        f'{len(merged_df)} samples!'))

    def _check_sampleID_uniqueness(self, array, sampleID, fail=False):
        """Warn (or raise when ``fail``) if ``sampleID`` occurs repeatedly.

        Raises:
            RuntimeError: if ``fail`` is True and the id is not unique.
        """
        matching_samples = (array.sampleID == sampleID).values.sum()
        if matching_samples > 1:
            if fail:
                raise RuntimeError(
                    f'{self.idattr} not unique which is not allowed.')
            logger.warning((f'{self.idattr} `{sampleID}` occurred '
                            f'{matching_samples} times in array: '
                            'taking first occurrence!'))


class GEE_Pipeline(Pipeline):
    """Feature-computation pipeline operating on GEE extractions.

    One processing task is created per sub-folder of ``self.database``
    that contains a ``<task>/database/database.json`` file. The features
    of each task are written to
    ``<features_dir>/<task>_features_df.parquet``. They can optionally be
    merged into a single parquet file afterwards.
    """

    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

        # Root folder holding one sub-folder per processing task
        self.extractions_dir = self.database
        self.labelattr = 'CT_fin'
        self.idattr = 'sampleID'

    def get_processing_tasks(self):
        """Return the names of the tasks (one per ref_id) to process.

        A task qualifies when
        ``<extractions_dir>/<task>/database/database.json`` exists.

        Returns:
            Sorted list of task folder names, or None when no database
            file was found.
        """
        # Sorted so tasks are processed in a reproducible order
        db_paths = sorted(glob.glob(
            str(self.extractions_dir / '*' / 'database' / 'database.json')))
        if not db_paths:
            return None

        logger.info(f'Found {len(db_paths)} processing tasks ...')
        # <task>/database/database.json -> parents[1] is the task folder
        return [Path(x).parents[1].name for x in db_paths]

    def load_database(self, file: Path):
        """Read a task database file.

        Args:
            file: path to the task's ``database.json``.

        Returns:
            Tuple ``(df, db)``: ``db`` is the full GeoDataFrame as read from
            ``file``, ``df`` is its subset of columns used for building the
            training collections.
        """
        logger.info(f'Loading database file: {file}')
        db = gpd.read_file(file)
        logger.info(f'Database successfully opened. Found {len(db)} samples.')

        # Convert into SATIO format
        # TODO: maybe not all of these are needed!
        df = db[[self.idattr, 'tile', 'epsg', 'path',
                 'ref_id', 'year', 'start_date', 'end_date', 'bounds']]

        return df, db

    def run(self, merge: bool = True):
        """Compute features for all processing tasks, one after another.

        Args:
            merge: when True, merge the per-task feature files into one
                file in ``features_dir``.
        """
        # Create the output directory
        self.features_dir.mkdir(exist_ok=True)

        # Get processing tasks
        self._tasks = self.get_processing_tasks()

        if self._tasks is None:
            logger.info('No tasks to process, aborting!')
            return

        # Save the settings
        self._save_settings(
            self.features_dir,
            preprocessing_settings=self.preprocessing_settings,
            feature_settings=self.feature_settings
        )

        # Process tasks sequentially
        feature_df_files = []
        for i, task in enumerate(self._tasks):
            logger.info((f'Launching task: {task} '
                         f'({i+1}/{len(self._tasks)})'))
            feature_df_files.append(self.run_task(task))

            if self._debug and i == 1:
                logger.debug('Interrupting loop in debug mode.')
                break

        # Filter out any failed tasks
        feature_df_files = [f for f in feature_df_files if f is not None]

        if merge:
            if feature_df_files:
                self._merge(feature_df_files, self.features_dir)
            else:
                logger.warning('No feature files produced, nothing to merge!')

    def run_task(self, taskname: str):
        """Compute the features of a single task.

        If the task's output file already exists and overwriting is
        disabled, only the samples missing from it are computed. The new
        features are appended to the existing ones.

        Args:
            taskname: name of the task folder in ``extractions_dir``; its
                first ``_``-separated token is used as the year.

        Returns:
            Path of the task's features parquet file, or None when no
            features are available for this task.

        Raises:
            FileNotFoundError: if the task's ``database.json`` is missing.
        """
        # Create the output directory
        self.features_dir.mkdir(exist_ok=True)

        # Construct outfile path
        outfile = Path(self.features_dir) / f'{taskname}_features_df.parquet'

        # Prepare df in the correct format for feature computation
        db_file = (self.extractions_dir / taskname / 'database' /
                   'database.json')
        if not db_file.is_file():
            raise FileNotFoundError(
                f'Database json {db_file} not found, cannot continue!')
        task_df, task_db = self.load_database(db_file)

        # If outfile already exists and overwrite is False, drop the
        # already processed samples from this task's db and df.
        previous_features = None
        if outfile.is_file():
            if not self._overwrite:
                previous_features = pd.read_parquet(outfile)
                processed = list(previous_features[self.idattr].unique())
                logger.info(f'Found {len(processed)} sampleIDs '
                            'which were already processed!')
                task_db = task_db.loc[~task_db[self.idattr].isin(processed)]
                task_df = task_df.loc[~task_df[self.idattr].isin(processed)]
                if len(task_db) == 0:
                    logger.info('No more samples to process, aborting!')
                    # The existing file is complete: report it so it is
                    # still picked up by the merge step.
                    return outfile
                logger.info(f'{len(task_db)} remaining samples to process')
            else:
                # if we need to overwrite, delete existing!
                logger.info('Features df already existed, deleting!')
                outfile.unlink()

        # infer the year of the dataset from task
        year = taskname.split('_')[0]

        # Get processing dates
        start_date, end_date = get_processing_dates(self.start_doy,
                                                    self.end_doy,
                                                    year)

        # Create training collections
        collections = self.create_collections(task_df)

        # Get local copies of the settings
        preproc_settings = deepcopy(self.preprocessing_settings)
        feature_settings = deepcopy(self.feature_settings)

        # Impute processing dates into preprocessing settings
        if preproc_settings is not None:
            for sensor in preproc_settings.keys():
                preproc_settings[sensor]['composite']['start'] = start_date  # NOQA
                preproc_settings[sensor]['composite']['end'] = end_date

        # Initialize FeatureComputer
        featurecomputer = GEETrainingFeaturesComputer(feature_settings,
                                                      start_date,
                                                      end_date,
                                                      preproc_settings)

        # Run feature computation
        if self._sc is None:
            logger.info('Running feature computation in serial.')
            features = self._run_task_local(collections,
                                            task_db,
                                            featurecomputer)

        else:
            logger.info('Running feature computation on executors.')
            features = self._run_task_spark(collections,
                                            task_db,
                                            featurecomputer)

        if features is None:
            logger.warning('Could not compute features for this task!')
            # Results of an earlier run remain valid on disk
            return outfile if previous_features is not None else None

        if previous_features is not None:
            # Keep the samples processed by the earlier run
            features = pd.concat([previous_features, features],
                                 ignore_index=True)
            outfile.unlink()

        # Save the result
        self.save_features(features, outfile)
        return outfile

    def _run_task_local(self, collections, task_db,
                        featurecomputer):
        """Compute the features of all samples in ``task_db`` locally.

        Returns:
            DataFrame with the features of all successful samples, or None
            when no sample could be processed.
        """
        def get_feat_df(sampleID):
            '''Wrapper function to allow to use run_parallel
            '''
            return get_sample_features(sampleID,
                                       collections,
                                       task_db,
                                       featurecomputer,
                                       self._debug,
                                       self.labelattr,
                                       self.idattr)

        sampleIDs = task_db[self.idattr].to_list()

        if self._debug:
            sampleIDs = sampleIDs[:5]

        results = run_parallel(get_feat_df,
                               sampleIDs,
                               max_workers=MAX_WORKERS)
        results = [result for result in results if result is not None]

        if not results:
            # pd.concat raises on an empty list; mirror the spark path
            logger.warning('No features computed for any sample!')
            return None

        return pd.concat(results, ignore_index=True)

    def _run_task_spark(self, collections,
                        task_db, featurecomputer):
        """Compute the features of all samples in ``task_db`` on Spark.

        Returns:
            DataFrame with the features of all successful samples, or None
            when no sample could be processed.
        """
        sampleIDs = task_db[self.idattr].to_list()
        debug = self._debug

        if debug:
            sampleIDs = sampleIDs[:5]

        if not sampleIDs:
            logger.warning('No samples to process -> aborting job')
            return None

        # Copy attributes into locals so the lambda below does not capture
        # `self`, which holds the SparkContext (`self._sc`) and must not be
        # shipped to the executors.
        labelattr = self.labelattr
        idattr = self.idattr

        rdd = self._sc.parallelize(sampleIDs,
                                   len(sampleIDs)).map(
            lambda x: get_sample_features(
                x, collections, task_db,
                featurecomputer, debug,
                labelattr, idattr)
        ).filter(lambda t: t is not None)

        # Check if we actually have something left
        if rdd.isEmpty():
            logger.warning('No samples left in RDD -> aborting job')
            return None

        results = rdd.collect()
        df = pd.concat(results, ignore_index=True)

        return df
