import copy
from typing import Dict
import numpy as np
import xarray as xr
import pandas as pd
from loguru import logger
from skimage.transform import resize

import satio
from satio.features import Features, BaseFeaturesProcessor
from satio.features import multitemporal_speckle
from satio.timeseries import Timeseries
from satio.utils.resample import downsample_n
from worldcereal.fp import apply_gdd_normalization

from worldcereal.utils.masking import (scl_mask,
                                       binary_mask,
                                       SCL_MASK_VALUES,
                                       flaglocalminima,
                                       dilate_mask,
                                       erode_mask)
from worldcereal.utils.resize import AgERA5_resize
from worldcereal.features.feat_dem import elev_from_dem
from worldcereal.features.feat_irr import (theoretical_boundaries,
                                           soil_moisture)

# Sentinel-2 L2A bands handled by the S2 processor; 'SCL' is the band
# load_mask reads to build the cloud/shadow mask.
L2A_BANDS = [
    'B02', 'B03', 'B04', 'B05', 'B06', 'B07',
    'B08', 'B8A', 'B11', 'B12', 'SCL',
]
# Supported bands keyed by resolution (everything is handled at 10 m)
L2A_BANDS_DICT = {10: L2A_BANDS}
# Sentinel-1 polarisations handled by the S1 processor, keyed by resolution
S1_BANDS_DICT = {10: ['VV', 'VH']}


class OpenEOS2FeaturesProcessor(BaseFeaturesProcessor):
    """Feature processor for Sentinel-2 L2A data from an OpenEO collection.

    Handles band loading, SCL-based (and optionally multitemporal)
    cloud/shadow masking, pre-processing (masking, compositing, optional
    GDD normalization, interpolation), RSI computation, optional season
    detection with phenology features, texture features and observation
    meta features. Everything is processed at 10 m.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def supported_bands(self):
        """Bands this processor can load, keyed by resolution (10 m)."""
        return L2A_BANDS_DICT

    @property
    def supported_rsis(self):
        """RSIs this processor can compute, keyed by resolution.

        Every RSI in ``self._rsi_meta`` is computed at 10 m. The result is
        cached in ``self._supported_rsis`` on first access.
        """
        if self._supported_rsis is None:
            # Reading 'native_res' keeps the requirement that every RSI
            # meta entry defines it (KeyError otherwise).
            rsi_res = {r: self._rsi_meta[r]['native_res']
                       for r in self._rsi_meta.keys()}
            self._supported_rsis = {10: list(rsi_res)}

        return self._supported_rsis

    @property
    def _reflectance(self):
        """S2 band features are scaled to reflectance (see /10000)."""
        return True

    @staticmethod
    def _merge_optional(first, second):
        """Merge two optional Features objects, skipping None entries."""
        if first is None:
            return second
        if second is None:
            return first
        return first.merge(second)

    @staticmethod
    def _apply_in_tiles(cube, func, step=256):
        """Apply `func` in place to spatial tiles of a 3D array.

        `cube` is indexed as (time, x, y). Each ``step`` x ``step`` spatial
        tile, covering all time steps, is replaced by ``func(tile)``.
        Working per tile limits the memory needed per call.
        """
        for idx in range(0, cube.shape[1], step):
            for idy in range(0, cube.shape[2], step):
                logger.debug((f"idx: {idx} - {idx+step} "
                              f"| idy: {idy} - {idy+step}"))
                tile = (slice(None),
                        slice(idx, idx + step),
                        slice(idy, idy + step))
                cube[tile] = func(cube[tile])

    def preprocess_data(self,
                        timeseries: 'satio.timeseries.Timeseries',
                        resolution: int,
                        mask: np.ndarray = None,
                        composite: bool = True,
                        interpolate: bool = True,
                        settings_override: Dict = None):
        """
        Pre-processing of loaded timeseries object. Includes masking,
        compositing, optional GDD normalization and interpolation.

        Every band is processed on its own; the results are merged back
        into a single Timeseries.

        Args:
            timeseries: Timeseries to pre-process.
            resolution: resolution used to pick the mask when `mask` is a
                dict keyed by resolution.
            mask: optional validity mask (False = invalid) or a dict
                {resolution: mask}.
            composite: composite each band if the settings contain
                'composite' options.
            interpolate: interpolate each band.
            settings_override: settings used instead of self.settings.

        Returns:
            The pre-processed Timeseries (None if `timeseries` has no
            bands).
        """
        settings = settings_override or self.settings

        # Resolve a per-resolution mask dict once instead of per band
        if isinstance(mask, dict):
            mask = mask[resolution]

        composite_settings = settings.get('composite')
        do_composite = composite_settings is not None and composite
        do_gdd = 'normalize_gdd' in settings

        newtimeseries = None
        for band in timeseries.bands:
            band_ts = timeseries.select_bands([band])
            if mask is not None:
                # mask values that are False are marked as invalid
                # (for uint16 data 0 is the nodata value)
                band_ts = band_ts.mask(mask)

            # drop all no data frames
            band_ts = band_ts.drop_nodata()

            if do_composite:
                logger.info(f"{band}: compositing")
                band_ts = band_ts.composite(**composite_settings)

            # If required, we do here the GDD normalization
            if do_gdd:
                logger.info(f'{band}: applying GDD normalization')
                band_ts = apply_gdd_normalization(band_ts, settings)

            if interpolate:
                logger.info(f'{band}: interpolating')
                band_ts = band_ts.interpolate()

            if newtimeseries is None:
                newtimeseries = band_ts
            else:
                newtimeseries = newtimeseries.merge(band_ts)

        return newtimeseries

    def load_data(self,
                  resolution,
                  timeseries=None,
                  dtype=np.uint16):
        """
        Load Timeseries from the collection and merge with `timeseries` if
        given.

        Only bands of `resolution` not yet present in `timeseries` are
        loaded. NaN values are replaced by 0 before the optional cast.

        Args:
            resolution: resolution whose bands should be loaded.
            timeseries: optional already-loaded Timeseries to extend.
            dtype: optional explicit dtype for the loaded data
                (None keeps the collection dtype).

        Returns:
            The (merged) Timeseries, or `timeseries` unchanged when there
            was nothing to load.
        """
        collection = self.collection

        loaded_bands = timeseries.bands if timeseries is not None else []
        bands_to_load = [b for b in self.bands[resolution]
                         if b not in loaded_bands]

        if not bands_to_load:
            logger.info("Did not find bands to "
                        f"load for resolution: {resolution}")
            return timeseries

        logger.info(f'Loading bands: {bands_to_load}')
        data = collection.sel(bands=bands_to_load)
        bands_ts = Timeseries(data=data.values,
                              timestamps=list(data.t.values),
                              bands=list(data.bands.values),
                              attrs={'sensor': collection.sensor})

        bands_ts.data[np.isnan(bands_ts.data)] = 0

        if dtype is not None:
            bands_ts.data = bands_ts.data.astype(dtype)

        if timeseries is None:
            return bands_ts
        return timeseries.merge(bands_ts)

    def load_mask(self):
        """Load the SCL band and derive the mask from it with `scl_mask`.

        Returns:
            Tuple ``(mask, obs, valid_before, valid_after)`` as returned by
            `scl_mask`, called with the ``self.settings['mask']`` options.
        """
        logger.info('SCL: loading')
        scl_data = self.collection.sel(bands='SCL')
        scl_ts = Timeseries(
            data=np.expand_dims(scl_data.values.astype(np.uint16), axis=0),
            timestamps=list(scl_data.t.values),
            bands=['SCL'])

        logger.info("SCL: preparing mask")
        mask, obs, valid_before, valid_after = scl_mask(
            scl_ts.data, **self.settings['mask'])

        return mask, obs, valid_before, valid_after

    def load_multitemporal_mask(self, prior_mask=None):
        '''
        Method to flag undetected clouds/shadows using multitemporal
        gap search approach.
        Cfr. work of Dominique Haesen (VITO)

        NDVI is computed from B04/B08 (after applying `prior_mask` if
        given), reindexed to a daily series and scanned for local dips with
        `flaglocalminima`. Missing or flagged observations become invalid.

        Args:
            prior_mask: optional validity mask (False = invalid) applied
                first, or a dict keyed by resolution (10 m entry used).

        Returns:
            Tuple ``(mask, ts)``. `mask` is a boolean (time, x, y) array on
            the original timestamps with True = valid. `ts` is the raw
            uint16 B04/B08 Timeseries, returned so the bands need not be
            loaded again.
        '''
        logger.info(('Performing multitemporal '
                     'cloud/shadow filtering ...'))

        # Load the raw bands needed to compute NDVI
        logger.info("Loading bands: ['B04', 'B08']")
        data = self.collection.sel(bands=['B04', 'B08'])
        ts = Timeseries(
            data=data.values,
            timestamps=list(data.t.values),
            bands=['B04', 'B08'],
            attrs={'sensor': self.collection.sensor})

        # Make sure data is in uint16
        ts.data = ts.data.astype(np.uint16)

        # Accept a per-resolution mask dict, like preprocess_data does
        if isinstance(prior_mask, dict):
            prior_mask = prior_mask[10]

        # If a prior mask is provided, apply it to the timeseries
        if prior_mask is not None:
            ts_masked = ts.mask(prior_mask, drop_nodata=False)
        else:
            ts_masked = ts

        # Work on float32 copies with missing values (0) set to NaN, so the
        # raw `ts` returned below is never altered (without a prior mask
        # `ts_masked` is the very same object as `ts`).
        red = ts_masked['B04'].data[0].astype(np.float32)
        nir = ts_masked['B08'].data[0].astype(np.float32)
        red[red == 0] = np.nan
        nir[nir == 0] = np.nan

        # Compute NDVI -> (time, x, y)
        ndvi = (nir - red) / (nir + red)

        # Make a DataArray for easy daily resampling
        ndvi_da = xr.DataArray(data=ndvi,
                               coords={'time': ts_masked.timestamps},
                               dims=['time', 'x', 'y'])

        # Resample to daily, missing data will be NaN
        daily_daterange = pd.date_range(
            ts_masked.timestamps[0],
            ts_masked.timestamps[-1] + pd.Timedelta(days=1),
            freq='D').floor('D')
        ndvi_daily = ndvi_da.reindex(time=daily_daterange,
                                     method='bfill', tolerance='1D')

        # Run multitemporal dip detection in spatial tiles (memory);
        # flagged observations end up as NaN (see mask extraction below)
        self._apply_in_tiles(
            ndvi_daily.values,
            lambda tile: flaglocalminima(tile,
                                         maxdip=0.01,
                                         maxdif=0.1,
                                         maxgap=60,
                                         maxpasses=5))

        # Subset on the original timestamps
        ndvi_cleaned = ndvi_daily.sel(time=ts.timestamps,
                                      method='ffill',
                                      tolerance='1D')

        # NaN = missing or flagged -> invalid. Inverted so that True means
        # valid, matching the convention of Timeseries.mask.
        # NOTE: eroding/dilating this mask (erode_mask/dilate_mask) was
        # tried but is disabled because it removed far too much data.
        mask = ~np.isnan(ndvi_cleaned.values)

        return mask, ts

    def compute_features(self,
                         chunck_size=None):
        """Compute all configured S2 features at 10 m.

        Texture and phenology settings are split off, bands/RSIs no
        feature needs are pruned, the (optionally multitemporal) mask is
        built, band/RSI features are computed, followed by optional
        season/phenology and texture features. Finally the meta features
        'l2a_obs' and 'l2a_obs_percentvalid' are appended.

        Note: texture/pheno entries are removed from
        ``self._features_meta`` and unneeded bands/RSIs from
        ``self.settings`` in place, so the helper methods only see the
        remaining feature definitions.

        Args:
            chunck_size: chunk size forwarded to `features_from_dict`
                (name kept for backward compatibility).

        Returns:
            Features object with all computed features.

        Raises:
            ValueError: if no feature other than texture is configured, or
                texture is requested without any other computed feature.
        """
        settings = self.settings
        features_meta = self._features_meta

        # check if texture features need to be computed
        # and store features_meta of those separately
        text_feat = 'texture' in features_meta
        if text_feat:
            text_feat_meta = features_meta.pop('texture')

        # if no other features remain at this point -> abort!
        if not features_meta:
            raise ValueError('At least one other feature required '
                             'other than texture. Aborting...')

        # check if pheno features need to be computed
        # and store features_meta of those in separate variable
        pheno_feat_meta = {}
        for p_key in ('pheno_mult_season', 'pheno_single_season'):
            if p_key in features_meta:
                pheno_feat_meta[p_key] = features_meta.pop(p_key)
        seas_feat = bool(pheno_feat_meta)

        # If every remaining feature lists its bands explicitly, drop the
        # bands and rsis no feature needs (lists are edited in place).
        if all(value.get("bands", None) is not None
               for value in features_meta.values()):
            feat_bands = []
            for value in features_meta.values():
                feat_bands.extend(value.get("bands", []))
            for b in [b for b in settings['bands'] if b not in feat_bands]:
                settings['bands'].remove(b)
            for r in [r for r in settings['rsis'] if r not in feat_bands]:
                settings['rsis'].remove(r)

        # if seasons need to be detected: make sure the RSI required for
        # this is included in the rsi list to be computed!
        seasonSettings = settings.get('seasons') or {}
        detect_seasons = bool(seasonSettings) or seas_feat
        if detect_seasons:
            rsi_seas = seasonSettings.get('rsi', 'evi')
            if rsi_seas not in settings['rsis']:
                settings['rsis'].append(rsi_seas)

        # get mask
        mask, obs, valid_before, valid_after = self.load_mask()

        #  ---------------------------------
        #  get advanced multitemporal mask
        #  ---------------------------------
        ts = None
        if settings['mask'].get('multitemporal', False):
            # Also returns the raw B04/B08 timeseries (no reloading needed)
            mask, ts = self.load_multitemporal_mask(prior_mask=mask)

            # Adjust valid_after for the additionally masked observations:
            # percentage of valid observations per pixel. The mask is a
            # (time, x, y) array, so sum over time; pixels without any
            # observation get 0 instead of an undefined int cast of NaN.
            with np.errstate(divide='ignore', invalid='ignore'):
                valid_after = np.where(
                    obs > 0, mask.sum(axis=0) / obs * 100, 0).astype(int)

        #  ---------------------------
        #  Handle 10m data
        #  ---------------------------
        resolution = 10
        ts = self.load_data(resolution, timeseries=ts)

        features_10m = None
        if ts is not None:
            features_10m = self.compute_features_10m(ts,
                                                     chunck_size,
                                                     mask)

            # season detection and pheno features, also at 10 m
            if detect_seasons:
                seasons = self.get_seasons(ts,
                                           mask, seasonSettings,
                                           resolution)
                phen_feat_10m = self.compute_phen_features(seasons,
                                                           pheno_feat_meta,
                                                           resolution)
                features_10m = self._merge_optional(features_10m,
                                                    phen_feat_10m)

        #  ---------------------------
        #  Texture
        #  ---------------------------
        # optionally compute texture features based on computed features
        if text_feat:
            if features_10m is None:
                raise ValueError('Texture features require at least one '
                                 'other computed feature.')
            logger.info('Computing texture features')
            inputFeat = features_10m.select(text_feat_meta['features'])
            params = text_feat_meta['parameters']
            # if desired, run PCA first
            if 'pca' in text_feat_meta:
                inputFeat = inputFeat.pca(
                    text_feat_meta['pca'],
                    scaling_range=params.get('scaling_range', None))

            text_features = inputFeat.texture(
                win=params.get('win', 2),
                d=params.get('d', [1]),
                theta=params.get('theta', [0, np.pi/4]),
                levels=params.get('levels', 256),
                metrics=params.get('metrics', ('contrast',)),
                avg=params.get('avg', True),
                scaling_range=params.get('scaling_range', {}))

            features_10m = features_10m.merge(text_features)

        #  ---------------------------
        #  Meta
        #  ---------------------------
        meta_features = Features(np.array([obs, valid_after]),
                                 names=['l2a_obs',
                                        'l2a_obs_percentvalid'])

        return self._merge_optional(features_10m, meta_features)

    def compute_features_10m(self, ts, chunk_size, mask):
        """Compute band and RSI features at 10 m.

        RSIs are computed from the raw timeseries. Bands and RSIs are then
        pre-processed (masked, composited, interpolated) and turned into
        features with `features_from_dict`. Band features are divided by
        10000 to scale the uint16 bands to reflectance.

        Returns:
            Features object, or None if neither bands nor RSIs are set.
        """
        resolution = 10
        features_meta = self._features_meta
        rsi_meta = self._rsi_meta
        rsis = self.rsis[resolution]
        bands = self.bands[resolution]

        # first, we compute RSI's
        ts_rsi = None
        if len(rsis) > 0:
            logger.info(f"{resolution}m: computing rsis")
            ts_rsi = ts.compute_rsis(*rsis, rsi_meta=rsi_meta)

        # now pre-process the timeseries
        # (only needed if band features required)
        if len(bands) > 0:
            ts_proc = self.preprocess_data(ts.select_bands(bands),
                                           resolution, mask=mask)
        # pre-process computed rsis
        if len(rsis) > 0:
            ts_rsi = self.preprocess_data(ts_rsi,
                                          resolution, mask=mask)

        # compute 10m band features and scale to reflectance
        features_10m = None
        if len(bands) > 0:
            logger.info(f"{resolution}m: computing bands features")
            features_10m = ts_proc.select_bands(bands).features_from_dict(
                resolution,
                features_meta=features_meta,
                chunk_size=chunk_size)

            # band features come from uint16 bands: scale to reflectance
            # NOTE(review): fft features should not be scaled, but all
            # band features are scaled here.
            features_10m.data /= 10000

        # add RSI features
        logger.info(f"{resolution}m: computing rsi features")
        features_10m_rsi = None
        for rn in rsis:
            rsi_features = ts_rsi.select_bands([rn]).features_from_dict(
                resolution,
                features_meta=features_meta,
                chunk_size=chunk_size)
            features_10m_rsi = self._merge_optional(features_10m_rsi,
                                                    rsi_features)

        # merge with band features
        return self._merge_optional(features_10m, features_10m_rsi)

    def get_seasons(self, ts, mask, seasonSettings, resolution):
        """Detect growing seasons on a smoothed RSI timeseries.

        Season detection uses dedicated settings: only the 'composite'
        options of self.settings are applied (no GDD normalization, no
        interpolation), so smoothing and detection happen on the actual
        time axis.

        Args:
            ts: raw Timeseries from which the RSI is computed.
            mask: validity mask passed to preprocess_data.
            seasonSettings: dict with optional keys 'rsi' (default 'evi'),
                'max_seasons', 'amp_thr1', 'amp_thr2', 'min_window',
                'max_window' and 'partial'.
            resolution: resolution used for masking and logging.

        Returns:
            The result of ``detect_seasons`` on the smoothed RSI.
        """
        rsi_meta = self._rsi_meta
        rsi_seas = seasonSettings.get('rsi', 'evi')

        logger.info(f"{resolution}m: detecting growing seasons")

        # composite is optional, consistent with preprocess_data
        season_override_settings = {
            'composite': self.settings.get('composite')
        }

        # compute required rsi ...
        rsi = ts.compute_rsis(rsi_seas, rsi_meta=rsi_meta)
        rsi = self.preprocess_data(
            rsi, resolution, mask=mask,
            interpolate=False,
            settings_override=season_override_settings)

        # RSI 0 means no data
        rsi.data[rsi.data == 0] = np.nan

        # EXPERIMENTAL Whittaker smoothing (remove these two statements to
        # disable); the smoothed values replace the RSI data
        smoothed_ts = self.smooth_whittaker(rsi)
        rsi.data = np.expand_dims(smoothed_ts.values, axis=0)

        # run season detection
        return rsi.detect_seasons(
            max_seasons=seasonSettings.get('max_seasons', 5),
            amp_thr1=seasonSettings.get('amp_thr1', 0.1),
            amp_thr2=seasonSettings.get('amp_thr2', 0.35),
            min_window=seasonSettings.get('min_window', 10),
            max_window=seasonSettings.get('max_window', 185),
            partial=seasonSettings.get('partial', True))

    def smooth_whittaker(self, rsi, lmbda=100, passes=3, dokeepmaxima=True):
        '''Method to smooth a RSI using Whittaker smoother

        The RSI is reindexed to a daily series, smoothed per spatial tile
        and sampled back on the original timestamps.

        Args:
            rsi: single-band Timeseries; the smoother is called with a
                data range of [0, 1].
            lmbda: smoothing strength (default 100: very strong smoothing
                for season detection).
            passes: number of smoother passes.
            dokeepmaxima: forwarded to the smoother.

        Returns:
            xarray DataArray with dims (time, x, y) on rsi.timestamps.
        '''
        from worldcereal.utils.masking import whittaker

        # Smooth the RSI before detecting seasons
        logger.info(('Performing Whittaker smoothing ...'))

        # Make a DataArray for easy daily resampling
        rsi_da = xr.DataArray(data=rsi.data[0, ...],
                              coords={'time': rsi.timestamps},
                              dims=['time', 'x', 'y'])

        # Resample to daily
        daily_daterange = pd.date_range(
            rsi.timestamps[0],
            rsi.timestamps[-1] + pd.Timedelta(days=1),
            freq='D').floor('D')
        rsi_daily = rsi_da.reindex(time=daily_daterange,
                                   method='bfill', tolerance='1D')

        # Run whittaker smoother in spatial tiles (memory)
        self._apply_in_tiles(
            rsi_daily.values,
            lambda tile: whittaker(lmbda=lmbda,
                                   npdatacube=tile,
                                   minimumdatavalue=0,
                                   maximumdatavalue=1,
                                   passes=passes,
                                   dokeepmaxima=dokeepmaxima))

        # Subset on the original timestamps
        return rsi_daily.sel(time=rsi.timestamps,
                             method='ffill',
                             tolerance='1D')

    def compute_phen_features(self, seasons, pheno_feat_meta, resolution):
        """Compute phenology features from detected seasons.

        Args:
            seasons: result of get_seasons.
            pheno_feat_meta: dict with optional 'pheno_mult_season' and
                'pheno_single_season' entries; the latter needs
                ['select_season']['mode'] and ['select_season']['param'].
            resolution: resolution passed to the feature functions.

        Returns:
            Features object, or None if no pheno features are requested.
        """
        # TODO: for now, all pheno features are only computed on the RSI
        # from which the seasons were derived. This should be extended
        # towards multiple bands or rsis, but requires some more thinking...
        # The idea is that you specify in the features_meta for which bands
        # the single_season features need to be computed and you make sure
        # these timeseries are imported in this function
        logger.info(f"{resolution}m: computing pheno features")

        phen_feat = None
        if 'pheno_mult_season' in pheno_feat_meta:
            phen_feat = seasons.pheno_mult_season_features(resolution)

        if 'pheno_single_season' in pheno_feat_meta:
            select_season = pheno_feat_meta['pheno_single_season'][
                'select_season']
            single_feat = seasons.pheno_single_season_features(
                select_season['mode'],
                select_season['param'],
                resolution)
            phen_feat = self._merge_optional(phen_feat, single_feat)

        return phen_feat


class OpenEOS1FeaturesProcessor(BaseFeaturesProcessor):
    """Feature processor for Sentinel-1 backscatter from an OpenEO
    collection.

    Data is loaded in linear scale. Pre-processing runs speckle filtering,
    compositing, optional GDD normalization and interpolation in linear
    scale, then converts to dB. Band, RSI, optional texture and the
    'sigma0_obs' meta features are computed at 10 m.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def supported_bands(self):
        """Bands this processor can load, keyed by resolution (10 m)."""
        return S1_BANDS_DICT

    @property
    def supported_rsis(self):
        """RSIs this processor can compute, keyed by resolution.

        Every RSI in ``self._rsi_meta`` is computed at 10 m. The result is
        cached in ``self._supported_rsis`` on first access.
        """
        if self._supported_rsis is None:
            # Reading 'native_res' keeps the requirement that every RSI
            # meta entry defines it (KeyError otherwise).
            rsi_res = {r: self._rsi_meta[r]['native_res']
                       for r in self._rsi_meta.keys()}
            self._supported_rsis = {10: list(rsi_res)}

        return self._supported_rsis

    @property
    def _reflectance(self):
        """S1 backscatter is not reflectance."""
        return False

    @staticmethod
    def _merge_optional(first, second):
        """Merge two optional Features objects, skipping None entries."""
        if first is None:
            return second
        if second is None:
            return first
        return first.merge(second)

    def preprocess_data(self,
                        timeseries: 'satio.timeseries.Timeseries',
                        resolution: int = 10,
                        speckle_filter: bool = True,
                        composite: bool = True,
                        interpolate: bool = True):
        """
        Pre-processing of loaded timeseries object. Includes speckle
        filtering, compositing, optional GDD normalization, interpolation
        and the final conversion from linear scale to dB.

        Side effect: ``self._obs`` is reset and set to the number of finite
        observations per pixel of the first band (counted before nodata
        frames are dropped).

        Args:
            timeseries: Timeseries in linear scale.
            resolution: not used by this method; kept for a signature
                consistent with the other processors.
            speckle_filter: apply `multitemporal_speckle` to each band.
            composite: composite each band if the settings contain
                'composite' options.
            interpolate: interpolate each band.

        Returns:
            The pre-processed Timeseries in dB (None if no bands).
        """
        settings = self.settings
        composite_settings = settings.get('composite')
        do_composite = composite_settings is not None and composite
        do_gdd = 'normalize_gdd' in settings

        newtimeseries = None

        # number of obs needs to be calculated here
        self._obs = None
        self._obs_notmasked = None

        for band in timeseries.bands:
            band_ts = timeseries.select_bands([band])

            # get number of observations for first band loaded
            if self._obs is None:
                self._obs = np.isfinite(band_ts.data[0]).sum(axis=0)

            # drop all no data frames
            band_ts = band_ts.drop_nodata()

            # Data is in linear already
            data_lin = band_ts.data

            if speckle_filter:
                logger.info(f"{band}: speckle filtering")
                # Speckle filter uses 0 as nodata: swap NaN for 0 and
                # restore the NaNs afterwards
                idx_nodata = np.isnan(data_lin)
                data_lin[idx_nodata] = 0
                for band_idx in range(data_lin.shape[0]):
                    data_lin[band_idx] = multitemporal_speckle(
                        data_lin[band_idx])
                data_lin[idx_nodata] = np.nan

            band_ts.data = data_lin

            if do_composite:
                logger.info(f"{band}: compositing")
                band_ts = band_ts.composite(**composite_settings)

            # If required, we do here the GDD normalization
            if do_gdd:
                logger.info(f'{band}: applying GDD normalization')
                band_ts = apply_gdd_normalization(band_ts, settings)

            if interpolate:
                logger.info(f'{band}: interpolating')
                band_ts = band_ts.interpolate()

            # Finally convert to dB
            band_ts.data = 10 * np.log10(band_ts.data)

            if newtimeseries is None:
                newtimeseries = band_ts
            else:
                newtimeseries = newtimeseries.merge(band_ts)

        return newtimeseries

    def load_data(self,
                  resolution,
                  timeseries=None,
                  dtype=np.float32):
        """
        Load Timeseries from the collection and merge with `timeseries` if
        given.

        Only bands of `resolution` not yet present in `timeseries` are
        loaded.

        Args:
            resolution: resolution whose bands should be loaded.
            timeseries: optional already-loaded Timeseries to extend.
            dtype: optional explicit dtype for the loaded data
                (None keeps the collection dtype).

        Returns:
            The (merged) Timeseries, or `timeseries` unchanged when there
            was nothing to load.
        """
        collection = self.collection

        loaded_bands = timeseries.bands if timeseries is not None else []
        bands_to_load = [b for b in self.bands[resolution]
                         if b not in loaded_bands]

        if not bands_to_load:
            logger.info("Did not find bands to "
                        f"load for resolution: {resolution}")
            return timeseries

        logger.info(f'Loading bands: {bands_to_load}')
        data = collection.sel(bands=bands_to_load)
        bands_ts = Timeseries(data=data.values,
                              timestamps=list(data.t.values),
                              bands=list(data.bands.values),
                              attrs={'sensor': collection.sensor})

        if dtype is not None:
            bands_ts.data = bands_ts.data.astype(dtype)

        if timeseries is None:
            return bands_ts
        return timeseries.merge(bands_ts)

    def compute_features(self,
                         chunk_size=None):
        """Compute all configured S1 features at 10 m.

        Texture settings are split off (removed from
        ``self._features_meta`` in place), band and RSI features are
        computed, optional texture features are added and the 'sigma0_obs'
        meta feature is appended.

        Args:
            chunk_size: chunk size forwarded to `features_from_dict`.

        Returns:
            Features object with all computed features.

        Raises:
            ValueError: if no feature other than texture is configured, or
                texture is requested without any other computed feature.
        """
        features_meta = self._features_meta
        rsi_meta = self._rsi_meta

        # check if texture features need to be computed
        # and store features_meta of those separately
        text_feat = 'texture' in features_meta
        if text_feat:
            text_feat_meta = features_meta.pop('texture')

        # if no other features remain at this point -> abort!
        if not features_meta:
            raise ValueError('At least one other feature required '
                             'other than texture. Aborting...')

        # 10m processing
        resolution = 10
        ts = self.load_data(resolution)

        features = None
        if ts is not None:
            rsis = self.rsis[resolution]
            bands = self.bands[resolution]

            # first we compute rsis
            ts_rsi = None
            if len(rsis) > 0:
                logger.info(f"{resolution}m: computing rsis")
                ts_rsi = ts.compute_rsis(*rsis,
                                         rsi_meta=rsi_meta,
                                         bands_scaling=1)

            # now pre-process the timeseries
            # (only needed if band features required)
            if len(bands) > 0:
                ts_proc = self.preprocess_data(ts.select_bands(bands),
                                               resolution)
            # pre-process computed rsis
            if len(rsis) > 0:
                ts_rsi = self.preprocess_data(ts_rsi, resolution)

            # start feature calculation
            if len(bands) > 0:
                logger.info(f"{resolution}m: computing bands features")
                features = ts_proc.select_bands(bands).features_from_dict(
                    resolution,
                    features_meta=features_meta,
                    chunk_size=chunk_size)

            # add RSI features
            logger.info(f"{resolution}m: computing rsi features")
            features_rsi = None
            for rn in rsis:
                rsi_features = ts_rsi.select_bands([rn]).features_from_dict(
                    resolution,
                    features_meta=features_meta,
                    chunk_size=chunk_size)
                features_rsi = self._merge_optional(features_rsi,
                                                    rsi_features)

            # merge band and rsi features
            features = self._merge_optional(features, features_rsi)

            # optionally compute texture features based on
            # computed features
            if text_feat:
                if features is None:
                    raise ValueError('Texture features require at least '
                                     'one other computed feature.')
                logger.info('Computing texture features')
                # A timer is not guaranteed to be configured: use it only
                # when present
                timer = getattr(self, 'timer', None)
                if timer is not None:
                    timer.text_features[resolution].start()
                inputFeat = features.select(text_feat_meta['features'])
                params = text_feat_meta['parameters']
                # if desired, run PCA first
                if 'pca' in text_feat_meta:
                    inputFeat = inputFeat.pca(
                        text_feat_meta['pca'],
                        scaling_range=params.get('scaling_range', None))

                text_features = inputFeat.texture(
                    win=params.get('win', 2),
                    d=params.get('d', [1]),
                    theta=params.get('theta', [0, np.pi/4]),
                    levels=params.get('levels', 256),
                    metrics=params.get('metrics', ('contrast',)),
                    avg=params.get('avg', True),
                    scaling_range=params.get('scaling_range', {}))

                features = features.merge(text_features)

                if timer is not None:
                    timer.text_features[resolution].stop()

        # add meta features (self._obs is set by preprocess_data)
        obs_features = Features(np.array([self._obs]),
                                names=['sigma0_obs'])

        return self._merge_optional(features, obs_features)
