import numpy as np
from skimage.morphology import selem, binary_erosion
from scipy.ndimage import (binary_dilation, binary_erosion)

# Class codes that scl_mask treats as invalid (False in its output) unless the
# caller passes its own mask_values.
# NOTE(review): these look like Sentinel-2 L2A Scene Classification Layer
# codes (0 no data, 1 saturated/defective, 3 cloud shadow, 8/9 cloud
# medium/high probability, 10 thin cirrus, 11 snow) — confirm against the
# data source before relying on that interpretation.
SCL_MASK_VALUES = [0, 1, 3, 8, 9, 10, 11]


def dilate_mask(mask, dilate_r):
    """
    Dilate each frame of a stack of binary masks with a disk footprint.

    Every ``mask[i]`` along the first axis is dilated independently and
    written back into ``mask``, so the input array is modified in place;
    the same array object is returned for convenience.

    Parameters:
    -----------
    mask : array
        Stack of 2D masks indexed by the first axis.

    dilate_r : int
        Radius of the disk structuring element.

    Returns:
    --------
    mask : array
        The input array, dilated in place.
    """
    footprint = selem.disk(dilate_r)
    n_frames = mask.shape[0]

    # Frames are processed one at a time so the 2D disk matches each
    # frame's dimensionality.
    idx = 0
    while idx < n_frames:
        frame = mask[idx]
        mask[idx] = binary_dilation(frame, footprint)
        idx += 1

    return mask


def erode_mask(mask, erode_r):
    """
    Erode each frame of a stack of binary masks with a disk footprint.

    Every ``mask[i]`` along the first axis is eroded independently and
    written back into ``mask``, so the input array is modified in place;
    the same array object is returned for convenience.

    Parameters:
    -----------
    mask : array
        Stack of 2D masks indexed by the first axis.

    erode_r : int
        Radius of the disk structuring element.

    Returns:
    --------
    mask : array
        The input array, eroded in place.
    """
    structure = selem.disk(erode_r)
    n_frames = mask.shape[0]

    # Frame-by-frame so the 2D disk matches each frame's dimensionality.
    for frame_idx in range(n_frames):
        eroded = binary_erosion(mask[frame_idx], structure)
        mask[frame_idx] = eroded

    return mask


def scl_mask(scl_data,
             *,
             mask_values=SCL_MASK_VALUES,
             erode_r=None,
             dilate_r=None,
             nodata=0,
             max_invalid_ratio=None,
             **kwargs):
    """
    From a timeseries (t, y, x) returns a binary mask False for the
    given mask_values and True elsewhere.

    Parameters:
    -----------
    scl_data: array
        Input array for computing the mask. Singleton dimensions are
        squeezed away first; the squeezed array is expected to be a
        (t, y, x) timeseries or a single (y, x) frame.

    mask_values: list
        values to set to False in the mask

    erode_r : int
        Radius for eroding disk on the mask (applied per frame, before
        dilation). Ignored when None or <= 0.

    dilate_r : int
        Radius for dilating disk on the mask (applied per frame, after
        erosion). Ignored when None or <= 0.

    nodata : int
        Nodata value; observations equal to it are not counted when
        computing the per-pixel invalid ratio.

    max_invalid_ratio : float
        Fraction (e.g. 0.8). Pixels whose ratio of invalid observations
        (after erosion/dilation, counting only observations != nodata)
        is strictly greater than this value are set to True in every
        timestep. Pixels with no observations at all are never forced.

    **kwargs :
        Accepted and ignored.

    Returns:
    --------
    mask : bool array
        Same shape as the squeezed input; True for valid pixels, False
        for invalid.

    """
    scl_data = np.squeeze(scl_data)

    # A single frame squeezes down to (y, x). Process it as a one-step
    # timeseries so the axis-0 reductions run over time (not rows) and the
    # per-frame morphology gets 2D frames; the axis is dropped on return.
    single_frame = scl_data.ndim == 2
    if single_frame:
        scl_data = scl_data[np.newaxis]

    ts_obs = scl_data != nodata
    # np.isin allocates a fresh array, so the in-place erode/dilate helpers
    # never touch the caller's data.
    mask = np.isin(scl_data, mask_values)

    if erode_r is not None and erode_r > 0:
        mask = erode_mask(mask, erode_r)

    if dilate_r is not None and dilate_r > 0:
        mask = dilate_mask(mask, dilate_r)

    # invert values to have True for valid pixels and False for clouds
    valid = ~mask

    if max_invalid_ratio is not None:
        obs = ts_obs.sum(axis=0)
        n_invalid = (mask & ts_obs).sum(axis=0)
        # Guard pixels with zero observations: leave their ratio at 0
        # instead of producing 0/0 = nan (and a garbage int cast).
        invalid_ratio = np.divide(n_invalid, obs,
                                  out=np.zeros(obs.shape, dtype=float),
                                  where=obs > 0)
        # Compare the exact ratio; truncating to an integer percentage
        # would miss ratios just above the threshold.
        too_invalid = invalid_ratio > max_invalid_ratio
        # (y, x) broadcasts across the time axis.
        valid = valid | too_invalid

    if single_frame:
        valid = valid[0]

    return valid
