import logging
import numpy as np
import numbers

logger = logging.getLogger(__name__)

def flaglocalminima(npdatacube, maxdip=None, maxdif=None,
                    maxgap=None, maxpasses=1, verbose=True):
    '''
    Remove dips and difs (replace by np.nan) from the input npdatacube.

    The series runs along axis 0 of npdatacube; for multi-dimensional cubes every
    position along the remaining axes is an independent series. NaN values are
    missing observations and are skipped when looking for neighbours.
    npdatacube is modified in place and also returned.

    For position i with value xi, let
        n be the nearest position 'left' of i (n < i) whose value xn is not NaN
        m be the nearest position 'right' of i (m > i) whose value xm is not NaN

    dip on position i (maxdip):
        (xn - xi) > (i-n) * maxdip AND (xm - xi) > (m-i) * maxdip

    dif on position i (maxdif):
        (xn - xi) > (i-n) * maxdif OR (xm - xi) > (m-i) * maxdif

    A neighbour only takes part when it exists and, if maxgap is given, when its
    distance (i-n or m-i) is at most maxgap.

    :param npdatacube: numpy array, modified in place (flagged values become NaN,
        so presumably a float array is expected)
    :param maxdip: positive number, or None to skip dip detection
    :param maxdif: positive number, or None to skip dif detection
    :param maxgap: positive integer, or None for no distance limit
    :param maxpasses: maximum number of passes; each pass works on the already
        cleaned data. When maxpasses > 1, processing stops after a pass that
        removes nothing.
    :param verbose: log per-pass statistics at debug level
    :return: npdatacube (the same object) with the flagged values set to np.nan
    :raises ValueError: when maxdip, maxdif or maxgap is invalid
    '''

    return _flaglocalextrema_ct(npdatacube, maxdip, maxdif,
                                maxgap=maxgap, maxpasses=maxpasses,
                                doflagmaxima=False, verbose=verbose)

def _flaglocalextrema_ct(npdatacube, maxdip, maxdif, maxgap=None,
                         maxpasses=1, doflagmaxima=False, verbose=True):
    '''
    Flag local minima (default) or maxima by replacing them with np.nan, in place.

    The series runs along axis 0 of npdatacube; every position along the other
    axes is an independent series. NaN values are treated as missing and are
    skipped when searching for the nearest left/right neighbours.

    For position i, with n the nearest non-NaN position left of i and m the
    nearest non-NaN position right of i, the per-step slopes are
        prev = (xn - xi) / (i - n)        next = (xm - xi) / (m - i)
    A neighbour is only comparable when it exists and, if maxgap is given, when
    its distance is at most maxgap.

    A slope "exceeds" a (positive) threshold t when
        minima (doflagmaxima=False): slope >  t   (neighbour higher by more than t per step)
        maxima (doflagmaxima=True):  slope < -t   (neighbour lower by more than t per step)

    dip (maxdip): both neighbours comparable and both slopes exceed maxdip
    dif (maxdif): at least one comparable neighbour whose slope exceeds maxdif

    :param npdatacube: numpy array, modified in place (flagged values become NaN,
        so presumably a float array is expected)
    :param maxdip: positive number, or None to skip dip detection
    :param maxdif: positive number, or None to skip dif detection
    :param maxgap: positive integer, or None for no distance limit
    :param maxpasses: maximum number of passes; each pass re-evaluates the cleaned
        data. When maxpasses > 1, stops early after a pass that removes nothing.
    :param doflagmaxima: flag local maxima instead of local minima
    :param verbose: log per-pass statistics at debug level
    :return: npdatacube (the same object)
    :raises ValueError: on an invalid maxdip, maxdif or maxgap
    '''

    def slopeprev(cube, maxgap):
        """
        Slope from every position back to its nearest earlier non-NaN value.

        Returns (slope, comparable): slope is NaN where the position itself is NaN
        or no earlier non-NaN value exists; comparable is True where slope is not
        NaN and (if maxgap is given) the distance does not exceed maxgap.
        """
        numberofrasters = cube.shape[0]
        shiftedval = np.full_like(cube, np.nan, dtype=float)
        shifteddis = np.full_like(cube, 1, dtype=int)
        shiftedval[1:numberofrasters, ...] = cube[0:numberofrasters-1, ...]

        nans = np.isnan(cube)
        if cube.ndim == 1:
            # Element-wise: indexing a 1-D array yields a numpy scalar (a copy), so
            # masked assignment on it would not write back.
            # (ndim instead of probing cube[0], so empty input does not raise.)
            for iIdx in range(1, numberofrasters):
                if nans[iIdx-1]:
                    # carry the earlier neighbour forward; stays NaN if the series started with NaN
                    shiftedval[iIdx] = shiftedval[iIdx-1]
                    shifteddis[iIdx] = shifteddis[iIdx-1] + 1
        else:
            for iIdx in range(1, numberofrasters):
                gaps = nans[iIdx-1]
                shiftedval[iIdx][gaps] = shiftedval[iIdx-1][gaps]
                shifteddis[iIdx][gaps] = shifteddis[iIdx-1][gaps] + 1

        slope = (shiftedval - cube) / shifteddis
        comparable = ~np.isnan(slope)
        if maxgap is not None:
            comparable &= shifteddis <= maxgap
        return slope, comparable

    def slopenext(cube, maxgap):
        """
        Slope from every position forward to its nearest later non-NaN value.

        Looking forward in the series is looking backward in the reversed series,
        so this is slopeprev on cube[::-1], mirrored back.
        """
        slope, comparable = slopeprev(cube[::-1], maxgap)
        return slope[::-1], comparable[::-1]

    def ispositivenumber(value):
        # float(value) == value rejects NaN
        return isinstance(value, numbers.Real) and float(value) == value and value > 0

    def ispositiveinteger(value):
        if isinstance(value, numbers.Integral):
            return value > 0
        if not isinstance(value, numbers.Real):
            return False
        # is_integer() is False for NaN and +/-inf (int(inf) would raise OverflowError)
        value = float(value)
        return value.is_integer() and value > 0

    if doflagmaxima:
        def maskextrema(slopes, threshold):
            # peak: neighbour lower by more than threshold per step
            return slopes < -threshold
    else:
        def maskextrema(slopes, threshold):
            # dip: neighbour higher by more than threshold per step
            return slopes > threshold

    if maxdip is not None and not ispositivenumber(maxdip):
        raise ValueError("maxdip must be positive number or None")
    if maxdif is not None and not ispositivenumber(maxdif):
        raise ValueError("maxdif must be positive number or None")
    if maxgap is not None and not ispositiveinteger(maxgap):
        raise ValueError("maxgap must be positive integer or None")

    initialnumberofvalues = np.sum(~np.isnan(npdatacube))
    previousnumberofvalues = initialnumberofvalues
    for iteration in range(maxpasses):
        # slopes are computed once per pass, before anything of this pass is removed
        prevslope, prevcomparable = slopeprev(npdatacube, maxgap)
        nextslope, nextcomparable = slopenext(npdatacube, maxgap)

        flagged = np.zeros(npdatacube.shape, dtype=bool)
        if maxdip is not None:
            # both neighbours must be comparable AND both slopes must exceed maxdip
            isdip = prevcomparable & nextcomparable
            isdip[isdip] = maskextrema(prevslope[isdip], maxdip)
            isdip[isdip] = maskextrema(nextslope[isdip], maxdip)
            flagged |= isdip
        if maxdif is not None:
            # any comparable neighbour whose slope exceeds maxdif
            flagged[prevcomparable] |= maskextrema(prevslope[prevcomparable], maxdif)
            flagged[nextcomparable] |= maskextrema(nextslope[nextcomparable], maxdif)
        if flagged.any():
            npdatacube[flagged] = np.nan

        remainingnumberofvalues = np.sum(~np.isnan(npdatacube))
        removednumberofvalues = previousnumberofvalues - remainingnumberofvalues
        if verbose:
            logger.debug("localextrema_ct pass(%s) - removed %s values. %s values remaining. %s values removed in total",
                         iteration + 1, removednumberofvalues, remainingnumberofvalues,
                         initialnumberofvalues - remainingnumberofvalues)
        previousnumberofvalues = remainingnumberofvalues
        if removednumberofvalues <= 0 and 1 < maxpasses:
            # a pass that removed nothing leaves the data unchanged, so further passes are no-ops
            if verbose:
                logger.debug("localextrema_ct pass(%s) - exits", iteration + 1)
            break

    return npdatacube