import xarray as xr
import numpy as np

"""
Wrapper functions that composite a satellite image time series into a
defined number of temporal steps (per-step mean, median, or best acquisition).
"""

def take_best_acquisition(x, n_steps=6):
    """Composite a stack by keeping the most complete acquisition per time bin.

    The time axis is split into ``n_steps`` bins with ``groupby_bins``.
    Within every bin, the acquisition with the largest number of non-null
    pixels is kept. ``argmax`` returns the first maximum, so on ties the
    first acquisition of the bin wins.

    Parameters
    ----------
    x : array-like
        Three-dimensional stack, interpreted with dims ``('t', 'x', 'y')``.
    n_steps : int
        Number of temporal bins passed to ``groupby_bins``.

    Returns
    -------
    numpy.ndarray
        The selected image of every bin, combined into one array.
    """

    def _most_complete(stack: xr.DataArray):
        # Number of valid (non-null) pixels of every acquisition in this bin.
        valid_counts = stack.notnull().sum(dim=('x', 'y'))
        return stack.isel(t=valid_counts.argmax())

    cube = xr.DataArray(x, dims=('t', 'x', 'y'))
    binned = cube.groupby_bins('t', bins=n_steps)
    return binned.map(_most_complete).values

def get_mean_acquisition(x, n_steps=6):
    """Composite a stack by averaging the acquisitions within each time bin.

    Parameters
    ----------
    x : array-like
        Three-dimensional stack, interpreted with dims ``('t', 'x', 'y')``.
    n_steps : int
        Number of temporal bins passed to ``groupby_bins``.

    Returns
    -------
    numpy.ndarray
        The per-bin mean images, as computed by xarray's groupby ``mean``.
    """
    grouped = xr.DataArray(x, dims=('t', 'x', 'y')).groupby_bins('t', bins=n_steps)
    return grouped.mean().values

def get_median_acquisition(x, n_steps=6):
    """Composite a stack by taking the median of the acquisitions in each time bin.

    Parameters
    ----------
    x : array-like
        Three-dimensional stack, interpreted with dims ``('t', 'x', 'y')``.
    n_steps : int
        Number of temporal bins passed to ``groupby_bins``.

    Returns
    -------
    numpy.ndarray
        The per-bin median images, as computed by xarray's groupby ``median``.
    """
    grouped = xr.DataArray(x, dims=('t', 'x', 'y')).groupby_bins('t', bins=n_steps)
    return grouped.median().values


def get_preprocessing_settings():
    """Return the preprocessing configuration, keyed by sensor name ('S2').

    A new dictionary is built on every call, so callers may mutate the
    result without affecting later calls.
    """
    composite = {'freq': 10, 'start': None, 'end': None}
    s2_settings = {
        'dtype': np.float32,
        'bands': ["B02", "B03", "B04", "B08"],
        'composite': composite,
        'interpolate': False,
    }
    return {'S2': s2_settings}


def get_feature_settings(n_tsteps, agg_method):
    """Build the feature configuration for the requested temporal aggregation.

    Parameters
    ----------
    n_tsteps : int
        Number of temporal steps. Forwarded to the aggregation function as
        ``n_steps`` and used to name the features ``ts0`` .. ``ts{n_tsteps-1}``.
    agg_method : str
        Aggregation to use: ``'mean'``, ``'median'`` or ``'best'``.

    Returns
    -------
    dict
        ``{'S2': {'tsteps': {...}}}`` holding the aggregation function, its
        parameters and the feature names.

    Raises
    ------
    ValueError
        If ``agg_method`` is not one of the supported names.
    """
    # Equality-based lookup instead of a dict, so any value (hashable or
    # not) that matches nothing reaches the ValueError below.
    candidates = (
        ('mean', get_mean_acquisition),
        ('median', get_median_acquisition),
        ('best', take_best_acquisition),
    )
    for method_name, method_function in candidates:
        if agg_method == method_name:
            agg_function = method_function
            break
    else:
        raise ValueError(f'Temporal aggregation method {agg_method} is not supported')

    feature_names = [f'ts{i}' for i in range(n_tsteps)]
    tsteps = {
        "function": agg_function,
        "parameters": {'n_steps': n_tsteps},
        "names": feature_names,
    }
    return {'S2': {"tsteps": tsteps}}
