from typing import Dict, List

import xarray as xr
import numpy as np
from loguru import logger
import pandas as pd

from satio.features import Features
from satio.timeseries import Timeseries


# Default for the `max_gap` argument of FeaturesComputer; the gap checks
# compare it against gaps expressed in days.
MAX_GAP = 60

# Keyword arguments passed to the Felzenszwalb segmentation when the caller
# does not provide explicit segmentation settings.
DEFAULT_SEGM_SETTINGS = dict(scale=1, sigma=0.8, min_size=3)

# Expected [min, max] value range per sensor, used by the input range check.
VALID_SENSOR_RANGES = dict(
    S2=[-0.05, 2.0999],  # Reflectance
    S1=[-85, 15],  # Backscatter in dB
    METEO=[240, 315],  # Temperature in Kelvin
)


def get_sample_features(sampleID,
                        collections,
                        database,
                        featurecomputer,
                        debug,
                        labelattr,
                        idattr):
    """Main method to get the features for one sample. The method
    is defined outside the pipeline class to avoid pickling and
    sparkcontext issues when running on spark.

    Args:
        sampleID: identifier of the sample to process (used in the
            log messages and forwarded to the feature computer).
        collections: forwarded to ``featurecomputer.get_features``.
        database: forwarded to ``featurecomputer.get_features``.
        featurecomputer: object exposing a ``get_features`` method that
            accepts ``(collections, sampleID, database, labelattr, idattr)``.
        debug: if truthy, exceptions are re-raised instead of being
            logged and swallowed.
        labelattr: forwarded to ``featurecomputer.get_features``.
        idattr: name of the ID attribute (used in the log messages and
            forwarded to the feature computer).

    Returns:
        Whatever ``featurecomputer.get_features`` returns, or ``None``
        if the computation failed and ``debug`` is falsy.
    """
    # NOTE(review): the five-argument call below does not match the
    # signature of ``FeaturesComputer.get_features(inputs)`` defined in
    # this file -- presumably a different feature computer is passed in
    # here; verify against the caller.

    logger.info(f'Starting feature computation for {idattr}: {sampleID}')

    try:
        features_df = featurecomputer.get_features(collections,
                                                   sampleID,
                                                   database,
                                                   labelattr,
                                                   idattr)
    except Exception as e:
        if debug:
            # Bare `raise` re-raises the active exception unchanged, so the
            # traceback is not extended with this frame a second time.
            raise
        # Best-effort mode: report the failure and let the caller skip
        # this sample.
        logger.error(f"Error processing {idattr} = {sampleID}: "
                     f"{e}.")
        return None

    logger.info(('Feature computation succesfull '
                 f'for {idattr}: {sampleID}'))

    return features_df


class GapError(Exception):
    """Raised when a timeseries has a gap exceeding the allowed threshold."""


class UnexpectedRangeError(Exception):
    """Raised when too many input values fall outside the expected range."""


class FeaturesComputer():
    """Class that computes classification features.

    Given one ``xr.DataArray`` per sensor, the computer converts the
    inputs to satio ``Timeseries``, checks them for temporal gaps,
    optionally preprocesses them, checks their value ranges and computes
    the features described by ``feature_settings``. Optionally the
    resulting features are averaged per Felzenszwalb segment.
    """

    def __init__(self,
                 feature_settings,
                 start_date,
                 end_date,
                 preprocessing_settings=None,
                 resolution=10,
                 max_gap=MAX_GAP,
                 segmentation: bool = False,
                 segm_settings=None):
        """Store the configuration; no computation is done here.

        Args:
            feature_settings: mapping of sensor name to the feature
                definitions passed to ``Timeseries.features_from_dict``.
            start_date: start of the expected period (parsed with
                ``pd.to_datetime`` during the gap checks).
            end_date: end of the expected period (parsed with
                ``pd.to_datetime`` during the gap checks).
            preprocessing_settings: optional mapping of sensor name to
                preprocessing settings. If None, preprocessing is skipped.
            resolution: passed to ``Timeseries.features_from_dict``.
            max_gap: largest tolerated gap in days.
            segmentation: whether to average features per segment.
            segm_settings: keyword arguments for the Felzenszwalb
                segmentation. If None, DEFAULT_SEGM_SETTINGS is used.
        """
        self.feature_settings = feature_settings
        self.preprocessing_settings = preprocessing_settings
        self.start_date = start_date
        self.end_date = end_date
        self.resolution = resolution
        self.max_gap = max_gap
        self.segmentation = segmentation
        self.segm_settings = segm_settings

    def get_features(self, inputs: Dict[str, xr.DataArray]) -> Features:
        """Run the full pipeline and return the computed features.

        Sets ``self.sensors`` to the keys of ``inputs`` as a side effect;
        ``preprocess`` and ``_compute_features`` rely on that attribute.

        Args:
            inputs: mapping of sensor name to a DataArray that has the
                dimensions `bands`, `t`, `x` and `y`.

        Returns:
            The merged features of all sensors, averaged per segment if
            segmentation was requested.

        Raises:
            KeyError: if a required dimension is missing, or if a sensor
                has no entry in VALID_SENSOR_RANGES.
            GapError: if a timeseries exceeds the allowed gap.
            UnexpectedRangeError: if too many values are out of range.
            RuntimeError: if a sensor has no feature settings.
        """

        logger.info('-' * 50)
        logger.info('STARTING FEATURE COMPUTATION')
        logger.info('-' * 50)

        self.sensors = list(inputs.keys())
        logger.info(f'Got data for following sensors: {self.sensors}')

        # Do check on inputs
        self._check_inputs(inputs)

        # Transform to satio Timeseries
        ts = self._to_satio_timeseries(inputs)

        # Perform sensor-specific gap checks
        self._perform_gap_checks(ts,
                                 self.start_date,
                                 self.end_date,
                                 self.max_gap)

        # Preprocess the timeseries
        if self.preprocessing_settings is not None:
            ts = self.preprocess(ts)

        # Do range checks on the preprocessed inputs
        ts = self._check_input_ranges(ts)

        # Compute the features
        features = self._compute_features(ts)

        # Average pixels per computed segment
        # if requested
        if self.segmentation:
            features = self.average_per_segment(features,
                                                settings=self.segm_settings)

        return features

    def preprocess(self, ts: Dict[str, Timeseries]):
        """Preprocess every sensor listed in ``self.sensors``.

        Sensors without an entry in ``self.preprocessing_settings`` are
        passed through untouched (a warning is logged).

        Args:
            ts: mapping of sensor name to Timeseries.

        Returns:
            A new dict mapping sensor name to the (preprocessed)
            Timeseries.
        """
        logger.info('---- PREPROCESSING INPUTS START')

        ts_proc = {}

        for sensor in self.sensors:
            if sensor not in self.preprocessing_settings:
                logger.warning(('No preprocessing settings found for: '
                                f'{sensor} -> skipping!'))
                ts_proc[sensor] = ts[sensor]
            else:
                logger.info(f'Preprocessing: {sensor}')
                ts_proc[sensor] = self._preprocess_sensor(
                    ts[sensor],
                    self.preprocessing_settings[sensor])

        logger.info('---- PREPROCESSING INPUTS DONE')

        return ts_proc

    def _compute_features(self, ts: Dict[str, Timeseries]):
        """Compute the features of every sensor and merge them.

        Args:
            ts: mapping of sensor name to Timeseries.

        Returns:
            A single Features object built from all per-sensor features.

        Raises:
            RuntimeError: if a sensor in ``self.sensors`` has no entry in
                ``self.feature_settings``.
        """
        logger.info('---- COMPUTING FEATURES START')

        features = {}

        for sensor in self.sensors:
            if sensor not in self.feature_settings:
                raise RuntimeError(
                    ('No feature settings found for: '
                     f'{sensor} -> cannot compute features!'))

            logger.info(f'Feature computation: {sensor}')
            features[sensor] = self._compute_features_sensor(
                sensor,
                ts[sensor],
                self.feature_settings[sensor])

        # If there is more than one sensor we need to merge the features
        features = list(features.values())
        features = Features.from_features(*features)

        logger.info((f'Successfully extracted '
                     f'{len(features.names)} features.'))
        logger.info('---- COMPUTING FEATURES DONE')

        return features

    def average_per_segment(self, features: Features,
                            segm_features: List[str] = None,
                            settings=None) -> Features:
        """Average all features per Felzenszwalb segment.

        NDVI is computed from the `S2-B08-ts{i}-10m` and `S2-B04-ts{i}-10m`
        features at six evenly spaced timesteps, reduced with PCA to three
        components and segmented. Every feature is then replaced by its
        NaN-ignoring mean within each segment, and the segment labels are
        appended as feature `INSTANCE-ID`.

        Note that ``features.data`` of the input object is overwritten
        with the averaged data.

        Args:
            features: features to average. Feature names are expected to
                contain the band as second `-`-separated token.
            segm_features: currently ignored; the segmentation always
                runs on the NDVI timesteps. Kept for interface
                compatibility.
            settings: keyword arguments for
                ``skimage.segmentation.felzenszwalb``. If None,
                DEFAULT_SEGM_SETTINGS is used.

        Returns:
            Features with per-segment averages plus `INSTANCE-ID`.

        Raises:
            RuntimeError: if no B04 or no B08 feature is present.
            AssertionError: if a required timestep feature is missing.
        """
        from skimage.segmentation import felzenszwalb

        logger.info('---- FEATURE AVERAGING OVER SEGMENTS START')

        bands = [n.split('-')[1] for n in features.names]
        if 'B04' not in bands or 'B08' not in bands:
            raise RuntimeError(
                'Both B04 and B08 are required for segmentation.')
        nr_tsteps = bands.count('B04')
        tsteps = np.linspace(0, nr_tsteps - 1, 6).astype(int)
        logger.info(f'Computing segmentation on tsteps: {tsteps}')

        # Use a dedicated local name so the `segm_features` argument is
        # not silently shadowed.
        ndvi_features = []
        logger.info('Retrieving NDVI tsteps ...')
        for tstep in tsteps:
            b04_ft = f'S2-B04-ts{tstep}-10m'
            assert b04_ft in features.names, f'Required feature {b04_ft} not in features!'  # NOQA
            b04 = features.select([b04_ft])
            b08_ft = f'S2-B08-ts{tstep}-10m'
            assert b08_ft in features.names, f'Required feature {b08_ft} not in features!'  # NOQA
            b08 = features.select([b08_ft])
            ndvi_ft = f'S2-NDVI-ts{tstep}-10m'
            ndvi = np.divide((b08.data - b04.data),
                             (b08.data + b04.data))
            ndvi_features.append(Features(ndvi, [ndvi_ft]))

        ndvi_features = Features.from_features(*ndvi_features)

        # Replace all non-finite values by zeros: next to NaN (0/0) the
        # division also yields +/-inf when only the denominator is zero,
        # and those must not reach the PCA / segmentation either.
        ndvi_features.data[~np.isfinite(ndvi_features.data)] = 0

        # run PCA if more than 3 features selected
        if len(ndvi_features.names) > 3:
            logger.info('Running PCA ...')
            ndvi_features = ndvi_features.pca(num_components=3)

        # transform data into right shape (feature axis last)
        inputs = np.moveaxis(ndvi_features.data, 0, -1)

        # get felzenszwalb settings
        if settings is None:
            settings = DEFAULT_SEGM_SETTINGS
        # apply segmentation algorithms
        logger.info('Running Felzenszwalb segmentation ...')
        segments_fz = felzenszwalb(inputs, **settings)

        logger.info('Segmentation done!')
        logger.info('Computing average features per segment')
        newdata = features.data.copy()
        for segm in np.unique(segments_fz):
            msk = segments_fz == segm
            # The (n_features, 1) mean broadcasts over all pixels of
            # the segment.
            newdata[:, msk] = np.nanmean(
                features.data[:, msk], axis=1, keepdims=True)
        features.data = newdata

        # Add the segments as additional feature
        features = features.merge(Features(data=segments_fz,
                                           names=['INSTANCE-ID'],
                                           dtype='uint16'))
        logger.info('Done.')
        logger.info('---- FEATURE AVERAGING OVER SEGMENTS DONE')

        return features

    def _preprocess_sensor(self, ts: Timeseries, settings: Dict):
        """Preprocess the timeseries of one sensor.

        The steps run in this order: dtype cast, nodata replacement,
        band selection, `pre_func`, compositing, interpolation and
        `post_func`. The data of the incoming ``ts`` is modified by the
        cast and the nodata replacement.

        Args:
            ts: timeseries of one sensor.
            settings: preprocessing settings. `dtype` is required;
                `src_nodata`, `dst_nodata`, `bands`, `pre_func`,
                `composite`, `interpolate` and `post_func` are optional.

        Returns:
            The preprocessed Timeseries.

        Raises:
            KeyError: if `dtype` is missing from ``settings``.
        """

        # Cast to proper dtype
        logger.info(f'Casting to dtype: {settings["dtype"]}')
        ts.data = ts.data.astype(settings['dtype'])

        # Properly set nodata
        src_nodata = settings.get('src_nodata', None)
        if src_nodata is not None:
            logger.info(f'Applying source nodata value: {src_nodata}')
            # Integer arrays cannot hold NaN, so they default to 0.
            # Inspect the dtype of the cast array rather than the raw
            # setting, so that string dtypes (e.g. 'uint16') and other
            # integer types are handled as well.
            is_integer = np.issubdtype(ts.data.dtype, np.integer)
            default_dst_nodata = 0 if is_integer else np.nan
            dst_nodata = settings.get('dst_nodata', default_dst_nodata)
            logger.info(f'Setting nodata to: {dst_nodata}')
            ts.data[ts.data == src_nodata] = dst_nodata

        band_list = settings.get('bands', None)
        if band_list is not None:
            # Got a list of bands: make the selection first
            logger.info(f'Band selection: {band_list}')
            ts = ts.select_bands(band_list)

        pre_func = settings.get('pre_func', None)
        if pre_func is not None:
            # We got function we need to apply before starting
            # preprocessing
            logger.info(f'Applying pre_func: {pre_func}')
            ts.data = pre_func(ts.data)

        if 'composite' in settings:
            logger.info('Compositing ...')
            ts = ts.composite(**settings['composite'])

        if settings.get('interpolate', False):
            logger.info('Interpolating ...')
            ts = ts.interpolate()

        post_func = settings.get('post_func', None)
        if post_func is not None:
            # We got function we need to apply after
            # preprocessing
            logger.info(f'Applying post_func: {post_func}')
            ts.data = post_func(ts.data)

        return ts

    def _compute_features_sensor(self, sensor: str, ts: Timeseries,
                                 settings: Dict):
        """Compute the features of one sensor.

        Args:
            sensor: sensor name, used as prefix of the feature names.
            ts: timeseries of the sensor.
            settings: feature definitions passed as `features_meta` to
                ``Timeseries.features_from_dict``.

        Returns:
            Features whose names are prefixed with `<sensor>-`.
        """

        features = ts.features_from_dict(
            self.resolution,
            features_meta=settings,
            chunk_size=None)

        # prepend source to feature names
        features.names = [sensor + '-' + f for f in features.names]

        return features

    @staticmethod
    def _check_inputs(inputs: Dict[str, xr.DataArray]):
        """Check that every input has the required dimensions.

        Args:
            inputs: mapping of sensor name to DataArray.

        Raises:
            KeyError: if one of `bands`, `t`, `x`, `y` is missing.
        """
        req_dims = ['bands', 't', 'x', 'y']

        for sensor in inputs.keys():
            for dim in req_dims:
                if dim not in inputs[sensor].dims:
                    raise KeyError(
                        (f'Dimension `{dim}` not found '
                         f'in input DataArray of sensor `{sensor}`'))

    @staticmethod
    def _check_input_ranges(ts: Dict[str, Timeseries],
                            tolerance_ratio: float = 0.01):
        """Method to check whether timeseries have expected range
        of values. Otherwise an Exception will be raised.

        NaN values are ignored when determining the min and max.

        Args:
            ts (Dict[str, Timeseries]): input dict with timeseries
                    to check
            tolerance_ratio (float): relative amount of pixel values
                    that can be out of range before failing. If below
                    this threshold, out of range values are clamped
                    in place to the min/max range.

        Returns:
            The input dict, with out-of-range values clamped.

        Raises:
            UnexpectedRangeError: if values exceed the expected range
                    for the sensor
            KeyError: if a sensor has no entry in VALID_SENSOR_RANGES
        """
        # (kind, index into the valid range, NaN-aware reducer, comparison)
        checks = (('min', 0, np.nanmin, np.less),
                  ('max', 1, np.nanmax, np.greater))

        for sensor in ts.keys():
            logger.info(f'Checking input range validity for {sensor}')
            data = ts[sensor].data

            for kind, idx, reducer, is_beyond in checks:
                limit = VALID_SENSOR_RANGES[sensor][idx]

                # A plain max/min would return NaN as soon as one NaN is
                # present, which silently disables the check.
                extreme = reducer(data)
                if not is_beyond(extreme, limit):
                    continue

                # Report on the issue
                outliers = is_beyond(data, limit)
                invalid_values = data[outliers]
                unique, counts = np.unique(
                    invalid_values, return_counts=True
                )
                unvalid_fraction = (invalid_values.size / data.size) * 100.

                # Construct log message
                msg = (f'Got an unexpected {kind} value for `{sensor}`: '
                       f'{extreme} with values/counts '
                       f'{unique} :: {counts}. In patch '
                       f'{unvalid_fraction:.1f}% of values are invalid.')

                if unvalid_fraction >= 100 * tolerance_ratio:
                    # Too many pixel values out of range: raise error
                    raise UnexpectedRangeError(msg)

                # Below tolerance: log warning and clamp to the limit
                logger.warning(msg)
                logger.warning((f'Less than {tolerance_ratio * 100} '
                                r'% of values out of range: setting '
                                f'values to {kind} value.'))
                data[outliers] = limit

        return ts

    @staticmethod
    def _to_satio_timeseries(inputs: Dict[str, xr.DataArray]) -> Dict[str, Timeseries]:  # NOQA
        """Method to transform input DataArray to satio Timeseries.

        Args:
            inputs: mapping of sensor name to DataArray.

        Returns:
            Mapping of sensor name to Timeseries, with the data ordered
            as (bands, t, x, y).
        """

        logger.info('Transforming xr.DataArray into satio.Timeseries ...')

        satio_ts = {}

        for sensor in inputs.keys():

            da = inputs[sensor]

            # Make sure we have a controlled dimension order
            da = da.transpose('bands', 't', 'x', 'y')

            # Transform the DataArray to satio Timeseries
            ts = Timeseries(
                data=da.values,
                timestamps=list(da.coords['t'].values),
                bands=list(da.coords['bands'].values)
            )

            satio_ts[sensor] = ts

        return satio_ts

    @staticmethod
    def _perform_gap_checks(ts: Dict[str, Timeseries],
                            start_date: str,
                            end_date: str,
                            max_gap: int):
        """Check every timeseries for gaps larger than ``max_gap`` days.

        Three gaps are checked per sensor: between ``start_date`` and the
        first observation, between the last observation and ``end_date``,
        and the largest difference between consecutive timestamps.

        Args:
            ts: mapping of sensor name to Timeseries.
            start_date: start of the expected period.
            end_date: end of the expected period.
            max_gap: largest tolerated gap in days.

        Raises:
            GapError: if one of the gaps exceeds ``max_gap``.
        """

        def raise_gap_failure(gap, gapkind, threshold, sensor):
            msg = (f'Incomplete timeseries `{sensor}`: '
                   f'got a value of {gap} days for `{gapkind}` '
                   f'which exceeds the threshold of {threshold}.')
            logger.error(msg)
            raise GapError(msg)

        start_date = pd.to_datetime(start_date)
        end_date = pd.to_datetime(end_date)

        for sensor in ts.keys():
            ts_sensor = ts[sensor]
            ts_start = ts_sensor.timestamps.min()
            ts_end = ts_sensor.timestamps.max()

            # Check timeseries start; observations before the requested
            # start do not count as a gap
            gapstart = ts_start - start_date
            gapstart = gapstart.days
            gapstart = 0 if gapstart < 0 else gapstart

            # Check timeseries end; observations after the requested
            # end do not count as a gap
            gapend = end_date - ts_end
            gapend = gapend.days
            gapend = 0 if gapend < 0 else gapend

            # Check max gap between consecutive timestamps
            maxgap = pd.Series(ts_sensor.timestamps).diff().max().days

            # Report on completeness
            logger.info(f'{sensor} first obs: {gapstart}')
            logger.info(f'{sensor} last obs: {gapend}')
            logger.info(f'{sensor} largest gap: {maxgap}')

            # Fail processing if timeseries is incomplete
            if gapstart > max_gap:
                raise_gap_failure(gapstart, 'gapstart', max_gap, sensor)
            if gapend > max_gap:
                raise_gap_failure(gapend, 'gapend', max_gap, sensor)
            if maxgap > max_gap:
                raise_gap_failure(maxgap, 'maxgap', max_gap, sensor)