import numbers
import numpy as np
import xarray as xr
import pandas as pd
from skimage.morphology import footprints, binary_erosion
from scipy.ndimage import (binary_dilation, binary_erosion)
from loguru import logger
from satio.timeseries import Timeseries

# SCL codes that scl_mask treats as invalid by default.
# Class names below assume the standard Sentinel-2 L2A SCL legend
# (NOTE(review): confirm against the product version in use).
SCL_MASK_VALUES = [
    0,   # no data
    1,   # saturated / defective
    3,   # cloud shadows
    8,   # cloud, medium probability
    9,   # cloud, high probability
    10,  # thin cirrus
]


def dilate_mask(mask, dilate_r):
    """
    Binary-dilate every 2D frame of ``mask`` with a disk of radius ``dilate_r``.

    Frames along axis 0 are processed one at a time with
    ``binary_dilation`` and a ``footprints.disk(dilate_r)`` structuring
    element. Results are written back into ``mask``, so the array is
    modified in place; the same object is also returned.
    """
    structuring_element = footprints.disk(dilate_r)
    n_frames = mask.shape[0]
    for frame_idx in range(n_frames):
        dilated = binary_dilation(mask[frame_idx], structuring_element)
        mask[frame_idx] = dilated
    return mask


def erode_mask(mask, erode_r):
    """
    Binary-erode every 2D frame of ``mask`` with a disk of radius ``erode_r``.

    Frames along axis 0 are processed one at a time with
    ``binary_erosion`` and a ``footprints.disk(erode_r)`` structuring
    element. In this module ``binary_erosion`` is bound by the later
    ``scipy.ndimage`` import, which shadows the skimage one. Results are
    written back into ``mask`` in place, and the same object is returned.
    """
    structuring_element = footprints.disk(erode_r)
    n_frames = mask.shape[0]
    for frame_idx in range(n_frames):
        eroded = binary_erosion(mask[frame_idx], structuring_element)
        mask[frame_idx] = eroded
    return mask


def scl_mask(scl_data,
             *,
             mask_values=SCL_MASK_VALUES,
             erode_r=None,
             dilate_r=None,
             nodata=0,
             max_invalid_ratio=None,
             **kwargs):
    """
    From a timeseries (t, y, x) returns a binary mask False for the
    given mask_values and True elsewhere.

    Parameters:
    -----------
    scl_data: array
        Input array for computing the mask. It is squeezed first, so
        singleton axes are dropped. A single time step that squeezes down
        to (y, x) is still processed as a one-frame series and is
        returned as a 2D mask.

    mask_values: list
        values to set to False in the mask

    erode_r : int
        Radius of the disk used to erode the invalid-pixel mask, frame by
        frame and before dilation. None or <= 0 skips erosion.

    dilate_r : int
        Radius of the disk used to dilate the invalid-pixel mask, frame by
        frame and after erosion. None or <= 0 skips dilation.

    nodata : int
        Nodata value used to count observations

    max_invalid_ratio : float
        Pixels whose fraction of invalid observations is strictly greater
        than this value are set to True (valid) in every frame. The
        fraction is computed after erosion/dilation, over observations
        != nodata only. Pixels with no observation have a fraction of 0.

    **kwargs
        Accepted for call compatibility and ignored.

    Returns:
    --------
    mask : bool array
        mask True for valid pixels, False for invalid

    """
    scl_data = np.squeeze(scl_data)

    # np.squeeze also drops the time axis of a single acquisition. Restore
    # it so that axis 0 is always time for the per-pixel statistics and the
    # frame-by-frame morphology, and remove it again before returning.
    single_frame = scl_data.ndim == 2
    if single_frame:
        scl_data = scl_data[np.newaxis, ...]

    # True where an actual observation (not nodata) is available
    ts_obs = scl_data != nodata

    # At this stage True marks invalid pixels; inverted further below
    mask = np.isin(scl_data, mask_values)

    if erode_r is not None and erode_r > 0:
        mask = erode_mask(mask, erode_r)

    if dilate_r is not None and dilate_r > 0:
        mask = dilate_mask(mask, dilate_r)

    too_often_invalid = None
    if max_invalid_ratio is not None:
        obs = ts_obs.sum(axis=0)
        n_invalid = (mask & ts_obs).sum(axis=0)
        # Compare the exact fraction (no integer truncation) and avoid 0/0
        # for pixels that were never observed: their fraction stays 0.
        invalid_ratio = np.divide(n_invalid, obs,
                                  out=np.zeros(obs.shape, dtype=float),
                                  where=obs > 0)
        too_often_invalid = invalid_ratio > max_invalid_ratio

    # invert values to have True for valid pixels and False for clouds
    mask = ~mask

    if too_often_invalid is not None:
        # (y, x) broadcasts over the time axis
        mask = mask | too_often_invalid

    if single_frame:
        mask = mask[0]

    return mask


def multitemporal_mask(ts: Timeseries, prior_mask: np.ndarray = None):
    '''
    Method to flag undetected clouds/shadows using multitemporal
    gap search approach.
    Cfr. work of Dominique Haesen (VITO)

    NDVI is computed from bands B04 and B08, treating zero values as
    missing. It is then reindexed onto a daily grid and cleaned with
    ``flaglocalminima``. Observations whose NDVI survives are valid.

    Parameters
    ----------
    ts : Timeseries
        Input timeseries holding at least bands 'B04' and 'B08'.
        ``ts.data`` is temporarily cast to uint16 and is reassigned to the
        caller's original array before this function returns, also on
        error.
    prior_mask : np.ndarray, optional
        If given, it is applied with ``ts.mask(prior_mask,
        drop_nodata=False)`` before NDVI is computed.

    Returns
    -------
    mask : np.ndarray of bool
        True for valid observations and False for missing or flagged ones,
        sampled at ``ts.timestamps``.

    Raises
    ------
    ValueError
        If 'B04' or 'B08' is not in ``ts.bands``.
    '''

    logger.info(('Performing multitemporal '
                 'cloud/shadow filtering ...'))

    # Check that the raw bands needed to compute NDVI are present.
    # An explicit raise (not assert) so the check also runs under -O.
    logger.info("Loading bands: ['B04', 'B08']")
    for b in ['B04', 'B08']:
        if b not in ts.bands:
            raise ValueError(f'{b} not found!')

    # The masking works on uint16 data. Cast temporarily, but always put
    # the caller's array back so the input timeseries is not left altered.
    original_data = ts.data
    try:
        ts.data = original_data.astype(np.uint16)

        # If a prior mask is provided, apply it to the timeseries
        if prior_mask is not None:
            ts_masked = ts.mask(prior_mask, drop_nodata=False)
        else:
            ts_masked = ts

        timestamps = ts_masked.timestamps
        # Only the two NDVI bands are copied to float32
        red = ts_masked['B04'].data.astype(np.float32)
        nir = ts_masked['B08'].data.astype(np.float32)
    finally:
        ts.data = original_data

    # Replace missing values (0) with NaN
    red[red == 0] = np.nan
    nir[nir == 0] = np.nan

    # Compute NDVI (drop the leading single-band axis)
    ndvi = ((nir - red) / (nir + red))[0, ...]

    # Make a DataArray for easy daily resampling
    ndvi_da = xr.DataArray(data=ndvi,
                           coords={'time': timestamps},
                           dims=['time', 'x', 'y'])

    # Resample to daily, missing data will be NaN
    daily_daterange = pd.date_range(
        timestamps[0],
        timestamps[-1] + pd.Timedelta(days=1),
        freq='D').floor('D')
    ndvi_daily = ndvi_da.reindex(time=daily_daterange,
                                 method='bfill', tolerance='1D')

    # Run multitemporal dip detection
    # Need to do it in slices, to avoid memory issues
    step = 256
    _, nrows, ncols = ndvi_daily.values.shape
    for idx in range(0, nrows, step):
        for idy in range(0, ncols, step):
            logger.debug((f"idx: {idx} - {idx+step} "
                          f"| idy: {idy} - {idy+step}"))
            ndvi_daily.values[
                :, idx:idx+step, idy:idy+step] = flaglocalminima(
                ndvi_daily.values[:, idx:idx+step, idy:idy+step],
                maxdip=0.01,
                maxdif=0.1,
                maxgap=60,
                maxpasses=5)

    # Subset on the original timestamps
    ndvi_cleaned = ndvi_daily.sel(time=ts.timestamps,
                                  method='ffill',
                                  tolerance='1D')

    # NaN (missing or flagged) is invalid; return True for valid pixels
    return ~np.isnan(ndvi_cleaned.values)


def flaglocalminima(npdatacube, maxdip=None, maxdif=None,
                    maxgap=None, maxpasses=1, verbose=True):
    '''
    Remove dips and difs (replace by np.nan) from the input npdatacube.

    Axis 0 of npdatacube is the series axis. Every other index is an
    independent series. With xi the value at position i:

    dip on position i:
        (xn - xi) > (i - n) * maxdip AND (xm - xi) > (m - i) * maxdip

    dif on position i:
        (xn - xi) > (i - n) * maxdif OR (xm - xi) > (m - i) * maxdif

    where
        n is the nearest position before i holding a non-NaN value
        m is the nearest position after i holding a non-NaN value

    A side without such a neighbour never satisfies its condition. When
    maxgap is given, a neighbour more than maxgap positions away also
    never satisfies it. maxdip/maxdif set to None disables that test.

    Up to maxpasses passes are made. When maxpasses > 1, processing stops
    early once a pass removes nothing. npdatacube is modified in place
    and returned. A ValueError is raised for non-positive maxdip/maxdif
    or a non-positive / non-integer maxgap.
    '''

    return _flaglocalextrema_ct(npdatacube, maxdip, maxdif,
                                maxgap=maxgap, maxpasses=maxpasses,
                                doflagmaxima=False, verbose=verbose)


def flaglocalmaxima(npdatacube, maxdip=None, maxdif=None,
                    maxgap=None, maxpasses=1, verbose=True):
    '''
    Counterpart of ``flaglocalminima`` that flags local maxima instead.

    Delegates to ``_flaglocalextrema_ct`` with ``doflagmaxima=True``. All
    other arguments are passed through unchanged. ``npdatacube`` is
    modified in place and returned.
    '''
    return _flaglocalextrema_ct(
        npdatacube,
        maxdip,
        maxdif,
        doflagmaxima=True,
        maxgap=maxgap,
        maxpasses=maxpasses,
        verbose=verbose,
    )


def _flaglocalextrema_ct(npdatacube, maxdip, maxdif, maxgap=None,
                         maxpasses=1, doflagmaxima=False, verbose=True):
    #
    #
    #
    def slopeprev(npdatacube, maxgap):
        """
        """
        shiftedval = np.full_like(npdatacube, np.nan, dtype=float)
        shifteddis = np.full_like(npdatacube,         1, dtype=int)
        numberofrasters = npdatacube.shape[0]
        shiftedval[1:numberofrasters, ...] = npdatacube[0:numberofrasters-1, ...]

        if np.isscalar(npdatacube[0]):
            nans = np.isnan(npdatacube)
            for iIdx in range(1, numberofrasters):
                if nans[iIdx-1]:
                    # can still be nan in case series started with nan
                    shiftedval[iIdx] = shiftedval[iIdx-1]
                    shifteddis[iIdx] = shifteddis[iIdx-1] + 1

        else:
            for iIdx in range(1, numberofrasters):
                nans = np.isnan(npdatacube[iIdx-1])
                shiftedval[iIdx][nans] = shiftedval[iIdx-1][nans]
                shifteddis[iIdx][nans] = shifteddis[iIdx-1][nans] + 1

        slopetoprev = (shiftedval-npdatacube)/shifteddis
        comparable = ~np.isnan(slopetoprev)
        if maxgap is not None:
            comparable &= shifteddis <= maxgap

        return slopetoprev, comparable

    def slopenext(npdatacube, maxgap):
        """
        """
        shiftedval = np.full_like(npdatacube, np.nan, dtype=float)
        shifteddis = np.full_like(npdatacube,         1, dtype=int)
        numberofrasters = npdatacube.shape[0]
        shiftedval[0:numberofrasters -
                   1, ...] = npdatacube[1:numberofrasters, ...]

        if np.isscalar(npdatacube[0]):
            nans = np.isnan(npdatacube)
            for iIdx in range(numberofrasters-2, -1, -1):
                if nans[iIdx+1]:
                    # can still be nan in case series started with nan
                    shiftedval[iIdx] = shiftedval[iIdx+1]
                    shifteddis[iIdx] = shifteddis[iIdx+1] + 1

        else:
            for iIdx in range(numberofrasters-2, -1, -1):
                nans = np.isnan(npdatacube[iIdx+1])
                shiftedval[iIdx][nans] = shiftedval[iIdx+1][nans]
                shifteddis[iIdx][nans] = shifteddis[iIdx+1][nans] + 1

        slopetonext = (shiftedval-npdatacube)/shifteddis
        comparable = ~np.isnan(slopetonext)
        if maxgap is not None:
            comparable &= shifteddis <= maxgap

        return slopetonext, comparable

    #
    #
    #
    def masklocalminima(slopesraster, thresholdvalue):
        return slopesraster > thresholdvalue

    def masklocalmaxima(slopesraster, thresholdvalue):
        return slopesraster < thresholdvalue
    if doflagmaxima:
        maskextrema = masklocalmaxima
    else:
        maskextrema = masklocalminima

    #
    #
    #
    if maxdip is not None and (not isinstance(maxdip, numbers.Real) or (float(maxdip) != maxdip) or (maxdip <= 0)):
        raise ValueError("maxdip must be positive number or None")
    if maxdif is not None and (not isinstance(maxdif, numbers.Real) or (float(maxdif) != maxdif) or (maxdif <= 0)):
        raise ValueError("maxdif must be positive number or None")
    if maxgap is not None and (not isinstance(maxgap, numbers.Real) or (int(maxgap) != maxgap) or (maxgap <= 0)):
        raise ValueError("maxgap must be positive integer or None")

    #
    #
    #
    initialnumberofvalues = np.sum(~np.isnan(npdatacube))
    previousnumberofvalues = initialnumberofvalues
    for iteration in range(maxpasses):
        #
        #
        #
        prevslope, prevcomparable = slopeprev(npdatacube, maxgap)
        nextslope, nextcomparable = slopenext(npdatacube, maxgap)
        #
        #
        #
        isdip = None
        if maxdip is not None:
            isdip = prevcomparable & nextcomparable
            isdip[isdip] = isdip[isdip] & maskextrema(prevslope[isdip], maxdip)
            isdip[isdip] = isdip[isdip] & maskextrema(nextslope[isdip], maxdip)

        isdif = None
        if maxdif is not None:
            isdif = np.full_like(npdatacube, False, dtype=bool)
            isdif[prevcomparable] = isdif[prevcomparable] | maskextrema(
                prevslope[prevcomparable], maxdif)
            isdif[nextcomparable] = isdif[nextcomparable] | maskextrema(
                nextslope[nextcomparable], maxdif)

        if isdip is not None:
            npdatacube[isdip] = np.nan
        if isdif is not None:
            npdatacube[isdif] = np.nan

        #
        #
        #
        remainingnumberofvalues = np.sum(~np.isnan(npdatacube))
        removednumberofvalues = previousnumberofvalues - remainingnumberofvalues
        if verbose:
            logger.debug("localextrema_ct pass(%s) - removed %s values. %s values remaining. %s values removed in total" %
                         (iteration+1, removednumberofvalues, remainingnumberofvalues, initialnumberofvalues - remainingnumberofvalues))
        previousnumberofvalues = remainingnumberofvalues
        if removednumberofvalues <= 0 and 1 < maxpasses:
            if verbose:
                logger.debug("localextrema_ct pass(%s) - exits" %
                             (iteration+1))
            break
    #
    #
    #
    return npdatacube
