import logging
import numpy as np
import pandas as pd
import geopandas as gpd
import os

# Module-level logger shared by the functions in this file
log = logging.getLogger(name=__name__)

def get_logger(name=__name__, filename=None, level=logging.INFO):
    '''
    Configure and return a logger that writes formatted records to the console and, optionally, to a file.

    Calling this function repeatedly for the same logger name does not attach duplicate handlers, so records
    are not emitted several times (and an already attached log file is not reopened/truncated again).

    :param name: name of the logger to configure
    :param filename: optional path of a log file; a new file handler opens it with mode 'w' (truncating it)
    :param level: level set on the logger
    :return: the configured logging.Logger
    '''
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)s - THREAD: %(threadName)s - %(name)s] : %(message)s")
    logger = logging.getLogger(name)

    if filename is not None:
        # FileHandler stores the absolute path in ``baseFilename``; compare against that to detect duplicates
        target = os.path.abspath(filename)
        has_file_handler = any(isinstance(h, logging.FileHandler) and h.baseFilename == target
                               for h in logger.handlers)
        if not has_file_handler:
            file_handler = logging.FileHandler(filename, mode='w')
            file_handler.setFormatter(log_formatter)
            logger.addHandler(file_handler)

    # FileHandler subclasses StreamHandler, so match the exact type to find an existing console handler
    has_console_handler = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_console_handler:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(log_formatter)
        logger.addHandler(console_handler)

    logger.setLevel(level)

    return logger

def _get_scaling_range(source):
    ranges = {}
    ranges['s2_fapar'] = [0, 1]
    ranges['s2_fapar_8band'] = [0, 1]
    ranges['s2_fcover'] = [0, 1]
    ranges['s2_ndvi'] = [0, 0.92]
    ranges['s1_vv'] = [-20, -2]
    ranges['s1_vh'] = [-33, -8]
    ranges['s1_vv_smooth'] = [-19, -3]
    ranges['s1_vh_smooth'] = [-25, -11]

    if source not in ranges.keys():
        raise Exception('Datasource "{}" not in known scalers!'.format(source))

    return ranges[source]

def minmaxscaler(data, source, range=None):
    '''
    Function to rescale input data to [-1,1] range based on predefined min and max values
    Optionally, the range can be manually provided, but use this with caution !!!
    Values outside [min, max] are NOT clipped, so the result can lie outside [-1,1].

    :param data: input data to rescale (numpy array, pandas object or scalar)
    :param source: string describing datasource [s2_fapar, s2_fapar_8band, s2_fcover, s2_ndvi, s1_vv, s1_vh,
                   s1_vv_smooth, s1_vh_smooth]
    :param range: optional list of [min, max] value to use for scaling (overrides the predefined range)
    :return: scaled data, with min mapped to -1 and max mapped to 1
    :raises ValueError: if min equals max, which would otherwise divide by zero
    '''
    # NOTE: ``range`` shadows the builtin; the name is kept for backward compatibility with keyword callers
    if range is None:
        range = _get_scaling_range(source)
    else:
        log.warning('Warning: custom scaling range being used for %s!', source)

    range_min, range_max = range[0], range[1]
    # np.any/np.equal also handle array-valued bounds
    if np.any(np.equal(range_max, range_min)):
        raise ValueError('Invalid scaling range [{}, {}] for {}: min and max must differ'.format(
            range_min, range_max, source))

    # Scale between -1 and 1
    # TODO: decide whether the result should be clipped to [-1, 1] (possibly only during inference)
    return 2 * (data - range_min) / (range_max - range_min) - 1

def minmaxunscaler(data, source, range=None, clip=False):
    '''
    Function to unscale input data from [-1,1] range based on predefined min and max values
    Optionally, the range can be manually provided, but use this with caution !!!

    :param data: input data to unscale (numpy array, pandas object or scalar)
    :param source: string describing datasource [s2_fapar, s2_fapar_8band, s2_fcover, s2_ndvi, s1_vv, s1_vh,
                   s1_vv_smooth, s1_vh_smooth]
    :param range: optional list of [min, max] value to use for unscaling (overrides the predefined range)
    :param clip: whether or not to clip unscaled values to the [min, max] range
    :return: unscaled data in the original range (a new object; the input is not modified)
    '''
    # NOTE: ``range`` shadows the builtin; the name is kept for backward compatibility with keyword callers
    if range is None:
        range = _get_scaling_range(source)
    else:
        log.warning('Warning: custom scaling range being used for %s!', source)

    range_min, range_max = range[0], range[1]

    # Unscale: -1 maps to min, 1 maps to max
    dataunscaled = 0.5 * (data + 1) * (range_max - range_min) + range_min

    # Optional clipping; np.clip (unlike masked in-place assignment) also works for scalar inputs
    # and leaves NaN values untouched
    if clip:
        log.info('Clipping values to [%s,%s] range!', range_min, range_max)
        dataunscaled = np.clip(dataunscaled, range_min, range_max)

    return dataunscaled

def to_linear(data_db):
    '''
    Convert decibel values to linear (power) values: 10 ** (x / 10).

    :param data_db: numpy.ndarray (or compatible object) holding values in dB
    :return: values in linear scale, same type/shape as produced by numpy for the input
    '''
    exponent = data_db / 10
    return np.power(10, exponent)

def to_db(data_linear):
    '''
    Convert linear (power) values to decibels: 10 * log10(x).

    :param data_linear: numpy.ndarray (or compatible object) holding linear values
    :return: values in dB
    '''
    decades = np.log10(data_linear)
    return 10 * decades

def sigma_to_gamma(sigma_linear, incidenceangle_degree):
    '''
    Function to transform sigma0 to gamma0 using cosine correction: gamma0 = sigma0 / cos(incidence angle)

    :param sigma_linear: numpy.ndarray holding sigma0 backscatter values in linear format
    :param incidenceangle_degree: numpy.ndarray holding incidence angles in degrees, same shape as sigma_linear
    :return: numpy.ndarray holding calculated gamma0 values in linear format
    :raises ValueError: if the two inputs do not have the same shape
    '''
    # Explicit check instead of ``assert``: assertions are stripped when Python runs with -O
    if sigma_linear.shape != incidenceangle_degree.shape:
        raise ValueError('Shape mismatch: sigma0 has shape {} but incidence angle has shape {}'.format(
            sigma_linear.shape, incidenceangle_degree.shape))

    return sigma_linear / np.cos(np.deg2rad(incidenceangle_degree))

def combine_s1_orbits(s1data, fieldID):
    '''
    Merge the ascending and descending S1 orbits of one field into a single time series per variable.

    VV and VH backscatter (in dB) are averaged per index label in the linear domain and converted back to dB;
    incidence angles are averaged per index label directly.

    :param s1data: dict with 'ASCENDING' and 'DESCENDING' entries, each holding 'VV', 'VH' and 'incidenceAngle'
                   collections indexable by fieldID
    :param fieldID: str representing the fieldID to produce the combined series for
    :return: dict {'VV': {fieldID: ...}, 'VH': {fieldID: ...}, 'incidenceAngle': {fieldID: ...}}
    '''
    log.warning('S1 input data provided as separate ASC/DES orbits -> need to combine orbits ...')

    ascending = s1data['ASCENDING']
    descending = s1data['DESCENDING']

    def _mean_per_label(first, second):
        # Stack both orbits and average the entries that share an index label
        return pd.concat([first, second]).groupby(level=0).mean()

    merged = {}
    for polarisation in ('VV', 'VH'):
        linear_mean = _mean_per_label(to_linear(ascending[polarisation][fieldID]),
                                      to_linear(descending[polarisation][fieldID]))
        merged[polarisation] = {fieldID: to_db(linear_mean)}

    merged['incidenceAngle'] = {fieldID: _mean_per_label(ascending['incidenceAngle'][fieldID],
                                                         descending['incidenceAngle'][fieldID])}
    return merged

def preprocess_sentinel1(sentinel1_timeseries, S1layername, sentinel1_incidenceangle_timeseries,
                         S1smoothing, S1var='gamma'):
    '''
    Function to preprocess an S1 time series. Includes conversion from sigma0 to gamma0 if needed, linear
    interpolation of NaN values, an optional smoothing, and scaling to [-1,1] range.

    :param sentinel1_timeseries: input S1 time series (pandas object) -> SHOULD BE IN DB!!!
    :param S1layername: scaling source passed to minmaxscaler (e.g. 's1_vv', 's1_vh', 's1_vv_smooth', 's1_vh_smooth')
    :param sentinel1_incidenceangle_timeseries: incidence angles in degrees, same shape as sentinel1_timeseries;
                                                only used when S1var == 'sigma'
    :param S1smoothing: window passed to pandas ``rolling`` (centred mean, min_periods=1), or None for no smoothing
    :param S1var: 'sigma' if the input is sigma0 and must be converted to gamma0; any other value (default 'gamma')
                  leaves the backscatter values unchanged
    :return: preprocessed time series in dB, rescaled with minmaxscaler
    '''
    # To linear, since the current S1 input is always in dB
    sentinel1_timeseries = to_linear(sentinel1_timeseries)

    # If the input is sigma0, convert it to gamma0. sigma_to_gamma returns a new object, so its result must be
    # assigned for the conversion to take effect.
    # NOTE: with an incidence angle of 0 degrees the conversion leaves the values unchanged (cos(0) == 1).
    if S1var == 'sigma':
        sentinel1_timeseries = sigma_to_gamma(sentinel1_timeseries, sentinel1_incidenceangle_timeseries)

    # Linear interpolate NaN values
    sentinel1_timeseries = sentinel1_timeseries.interpolate(method='linear')

    # Smooth S1 if needed (centred rolling mean, computed in the linear domain)
    if S1smoothing is not None:
        sentinel1_timeseries = sentinel1_timeseries.rolling(S1smoothing, center=True, min_periods=1).mean()

    # Back to dB and scale to [-1, 1]
    return minmaxscaler(to_db(sentinel1_timeseries), S1layername)


def ts_to_dict(S1VV, S1VH, S1IncidenceAngle,
               S2data, identifier, startDate, endDate, S2layername='FAPAR',
               margin_in_days=0):
    '''
    Convert the pandas time series of all input variables into the nested dictionary used to build a
    dataStack for a CropSAR model.

    Every series is reshaped by ts_to_df to a daily DataFrame covering
    [startDate - margin_in_days, endDate + margin_in_days] with a single column named ``identifier``.

    :param S1VV: pandas time series of S1 VV backscatter in dB
    :param S1VH: pandas time series of S1 VH backscatter in dB
    :param S1IncidenceAngle: pandas time series of S1 incidence angle in degrees
    :param S2data: pandas time series of S2 data
    :param identifier: string that identifies the object to which the time series belong (e.g. fieldID)
    :param startDate: string representing desired start date of returned time series (format: yyyy-mm-dd)
    :param endDate: string representing desired end date of returned time series (format: yyyy-mm-dd)
    :param S2layername: layername describing the biopar variable of S2
    :param margin_in_days: number of extra days added before startDate and after endDate
    :return: {'S1': {'VV', 'VH', 'incidenceAngle'}, 'S2': {S2layername}} dictionary in the shape expected by
             getDataStacks or getFullDataStack
    '''
    def _as_frame(series):
        return ts_to_df(series, identifier, startDate, endDate, margin_in_days)

    s1_frames = {
        'VV': _as_frame(S1VV),
        'VH': _as_frame(S1VH),
        'incidenceAngle': _as_frame(S1IncidenceAngle),
    }
    s2_frames = {S2layername: _as_frame(S2data)}

    return {'S1': s1_frames, 'S2': s2_frames}


def ts_to_df(ts, identifier, start_date, end_date, margin_in_days=0):
    '''
    Reshape a pandas Series (or one column of a DataFrame) into a daily, float-typed DataFrame.

    The result has one row per day in [start_date - margin, end_date + margin] and a single column named
    ``identifier``; days without data in ``ts`` are NaN, data outside the window is dropped.

    :param ts: pandas Series, or DataFrame from which column ``identifier`` is taken
    :param identifier: name of the series / column, also used as the output column name
    :param start_date: first date of the window (anything accepted by pd.to_datetime)
    :param end_date: last date of the window (anything accepted by pd.to_datetime)
    :param margin_in_days: number of days added before start_date and after end_date
    :return: pandas DataFrame indexed by day
    '''
    if isinstance(ts, pd.DataFrame):
        ts = ts[identifier]

    padding = pd.to_timedelta(margin_in_days, unit='days')
    first_day = pd.to_datetime(start_date) - padding
    last_day = pd.to_datetime(end_date) + padding

    # Keep only the observations inside the (padded) window
    in_window = (ts.index >= first_day) & (ts.index <= last_day)
    windowed = ts.loc[in_window]

    daily = pd.DataFrame(index=pd.date_range(first_day, last_day), columns=[identifier], dtype=float)
    daily.loc[windowed.index, identifier] = windowed
    return daily


def load_fields(parquetfile):
    '''
    Load the cal/val/test field split.

    :param parquetfile: parquet file with a 'CALVALTEST' column holding 'CAL', 'VAL' or 'TEST' per row;
                        the row index values are what gets returned
    :return: tuple (CAL_fields, VAL_fields, TEST_fields) of lists with the index values of each split
    '''
    fields = pd.read_parquet(parquetfile)
    split_labels = fields['CALVALTEST']

    cal, val, test = (fields.loc[split_labels == split].index.tolist() for split in ('CAL', 'VAL', 'TEST'))
    return cal, val, test

def write_to_temp(tempDir, fieldID, input_output_tuple):
    '''
    Save the inputs and outputs of a datastack as numpy arrays, used in test functions.

    Writes ``<fieldID>_inputs`` and ``<fieldID>_outputs`` in tempDir (np.save appends the .npy extension).

    :param tempDir: directory to write the files to
    :param fieldID: field identifier used as file name prefix (concatenated as a string)
    :param input_output_tuple: sequence whose element 0 holds the inputs and element 1 the outputs
    :return: None
    '''
    log.info('Datastack built, saving to numpy arrays ...')
    for position, suffix in enumerate(('_inputs', '_outputs')):
        target = os.path.join(tempDir, fieldID + suffix)
        np.save(target, input_output_tuple[position])

def generate_outputs(input_df, S2layername='FAPAR'):
    """
    Generate a key-value RDD mapping field names to their scaled S2 observations.

    For every row, the ``s2_<layer>`` value is converted with ``cropsar.readers.row_to_pandas``. Fields with fewer
    than 3 valid (non-null) values are dropped. For the remaining fields, the valid values except the first one
    are rescaled to [-1, 1] with ``minmaxscaler`` using the ``s2_<layer>`` scaling range.

    :param input_df: (PySpark) DataFrame with a 'name' column and an 's2_<S2layername.lower()>' column
    :param S2layername: layername of S2 data [FAPAR, FCOVER]
    :return: RDD of (name, scaled S2 values) pairs
    """
    from cropsar.readers import row_to_pandas

    column = 's2_' + S2layername.lower()

    def _row_to_output(row):
        # Generating the output is nothing more than filtering and applying a minmaxscaler
        S2data = row_to_pandas(row)
        # Require at least 3 valid S2 points; otherwise skip this field (removed by the filter below)
        if np.sum(S2data.notnull()) < 3:
            log.warning('Not enough valid S2 data points! Skipping ...')
            return None
        # NOTE(review): the first valid observation is dropped here; the reason is not documented - confirm
        candidatePoints = S2data.loc[S2data.notnull()].iloc[1:]
        return minmaxscaler(candidatePoints, column)

    return input_df.select("name", column).rdd.mapValues(_row_to_output).filter(lambda t: t[1] is not None)

def get_delta_obs(input_data: np.ndarray) -> np.ndarray:
    """
    This function returns an array containing for each point the amount of time that has passed since the last valid
    observation

    Time is counted in steps along axis 1 (columns); each row is handled as a separate series.

    :param input_data: 2-D numpy array; NaN marks a missing observation
    :return: float array of the same shape: 0 at valid (finite) points, the number of steps since the previous
             observation elsewhere, and NaN before a row's first observation or for rows without any finite value

    NOTE(review): +/-inf entries count as observations when restarting the counter (they are not NaN) but are not
    set to 0 at the end (they are not finite); confirm this is intended.
    """

    # Per-step increments of the "steps since last observation" counter; the running sum (cumsum) over
    # axis 1 below turns them into the counter itself
    m0 = np.ones(input_data.shape, dtype=float)
    for i, item in enumerate(input_data):
        if np.sum(np.isfinite(item)) == 0:
            # No finite value in this row: the whole row of the result becomes NaN
            m0[i, :] = np.nan
        else:
            # Positions of the observations in this row (everything that is not NaN)
            idx = np.flatnonzero(~np.isnan(item))
            # Keep the counter at 0 until the first observation (turned into NaN further down)
            m0[i, :idx[0]] = 0
            # At every later observation, an increment of -(gap - 1) brings the running sum back to 1,
            # so counting restarts right after each observation
            m0[i, idx[1:]] = idx[:-1] - idx[1:] + 1

    out = np.full(input_data.shape, np.nan, dtype=float)
    # Shift by one step: the value at position k is the running sum of the increments before k
    out[:, 1:] = m0[:, :-1].cumsum(1)
    # A zero can only remain before a row's first observation -> no previous observation, so NaN
    out[out == 0] = np.nan
    # Valid observations themselves are 0 steps away from an observation
    out[np.isfinite(input_data)] = 0
    return out

# ----------------------------------------------------------------------------------------------------------------------
# Below are functions that are not used any longer

def _maskAngle(TS, incidenceAngles):
    '''
    This function takes a Pandas TS of backscatter and a corresponding
    TS of incidence angles and returns only those obs that were made
    at minimum incidence angle
    :param TS:
    :param incidenceAngles:
    :return:
    '''

    TScopy = pd.DataFrame.copy(TS, deep=True)
    minAngle = np.nanmin(incidenceAngles)
    TScopy[np.abs(incidenceAngles - minAngle) > 1] = np.nan
    return TScopy

def compute_delta_obs(S2_currentTSCopy, outputResolution, scalers, scalingMethod):
    '''
    Rescaled time elapsed since the last valid S2 observation, for every entry of the time series.

    :param S2_currentTSCopy: pandas object of S2 observations (presumably a DataFrame with time along the index;
                             NaN = missing observation) - TODO confirm
    :param outputResolution: factor applied to the elapsed step counts (presumably days per time step - confirm)
    :param scalers: nested dict providing the fitted scaler at scalers['S2']['deltaObs']
    :param scalingMethod: scaling method forwarded to rescale
    :return: rescaled DataFrame with the same index as S2_currentTSCopy
    '''
    # get_delta_obs counts along axis 1, so transpose to count along the index and transpose back afterwards
    elapsed = get_delta_obs(S2_currentTSCopy.values.transpose()).transpose()
    elapsed = elapsed * outputResolution

    # Entries without a preceding observation get a fixed placeholder, which is rescaled below
    elapsed[np.isnan(elapsed)] = 100

    elapsed_frame = pd.DataFrame(data=elapsed, index=S2_currentTSCopy.index)
    return rescale(elapsed_frame, scalers['S2']['deltaObs'], scalingMethod)

def rescale(timeseries, scaler, scalingMethod):
    '''

    DEPRECATED rescaling function

    Rescale the finite values of a pandas object with a fitted scaler (e.g. a scikit-learn scaler);
    NaN/inf entries are left untouched.

    :param timeseries: a pandas time series or a dataframe (not modified)
    :param scaler: object with a ``transform`` method accepting an (n, 1) array, e.g. an sklearn scaler
    :param scalingMethod: scalingmethod; for 'minmax' the scaled values are clipped to [-1, 1]
    :return: new scaled time series (for Series input, also when it has length 1) or dataframe; column labels are
             replaced by a default integer range
    '''
    # If only one dimension, convert temporarily from series to dataframe, otherwise scaling fails
    series_input = len(timeseries.shape) == 1
    if series_input:
        timeseries = pd.DataFrame(timeseries)

    # Work on a float copy: ``.values`` may be a view on the caller's data (which would then be modified in
    # place, or be read-only under pandas Copy-on-Write), and integer data would truncate the scaled values
    data = timeseries.to_numpy(dtype=float, copy=True)

    # Only transform finite values, because scikit-learn < 0.20.0 cannot directly cope with NaN;
    # skip the scaler entirely when there is nothing to transform (it rejects empty input)
    valid = np.isfinite(data)
    if valid.any():
        scaled = np.asarray(scaler.transform(data[valid].reshape((-1, 1)))).ravel()
        if scalingMethod == 'minmax':
            scaled = np.clip(scaled, -1, 1)
        data[valid] = scaled

    # Back to pd object
    result = pd.DataFrame(index=timeseries.index, data=data)

    # If input was a series, transform back to a series (axis=1 so a length-1 series does not become a scalar)
    if series_input:
        return result.squeeze(axis=1)
    return result


def buffer_geometry(df, distance, **kwargs):
    '''
    Buffer every geometry by ``distance`` in its local UTM zone and return the result in the original CRS.

    Each geometry gets a WGS84 / UTM zone (EPSG 326xx north, 327xx south) based on a representative point.
    Geometries are grouped per zone, projected to that zone, buffered and projected back.

    :param df: GeoDataFrame (or GeoSeries) with a CRS set
    :param distance: buffer distance in UTM units (metres); negative values shrink the geometries
    :param kwargs: keyword arguments for ``GeoSeries.buffer``; if none are given, defaults are chosen
                   depending on whether the buffer is inward (distance <= 0) or outward
    :return: GeoSeries with the buffered geometries, same index and CRS as ``df``
    '''
    # If no buffer kwargs were supplied, pick some sensible defaults
    # depending on whether we have an inward or outward buffer
    # (shapely styles: cap_style 1 = round; join_style 1 = round, 2 = mitre)
    if not kwargs:
        if distance <= 0:
            kwargs = {'cap_style': 1, 'join_style': 2, 'resolution': 4}
        else:
            kwargs = {'cap_style': 1, 'join_style': 1, 'resolution': 4}

    # Select a lat/lon point per geometry
    points = df.geometry.representative_point()
    points.crs = df.crs
    points = points.to_crs('epsg:4326')

    # GeoPandas transforms with always_xy=True, so in EPSG:4326 x is the longitude and y the latitude
    # (not the authority lat/lon axis order)
    lon = points.x
    lat = points.y

    # Use longitude to determine the 6-degree UTM zone (1..60); builtin int because np.int was removed in NumPy 1.24
    zone = (np.floor((lon + 180.0) / 6.0) % 60).astype(int) + 1

    # Use latitude to determine north/south
    epsg = zone + np.where(lat >= 0.0, 32600, 32700)

    # Now convert to UTM and apply buffer per EPSG code/UTM zone
    buffered = gpd.GeoSeries(index=df.index, crs=df.crs)

    for code in epsg.unique():
        in_zone = (epsg == code)

        # Convert to UTM, apply buffer and convert back to the original CRS
        geometry = df.loc[in_zone].geometry
        geometry = geometry.to_crs('epsg:{}'.format(code))
        geometry = geometry.buffer(distance, **kwargs)
        geometry = geometry.to_crs(df.crs)

        buffered.loc[in_zone] = geometry

    return buffered


def buffer_geometry_if_not_empty(df, distance, **kwargs):
    '''
    Buffer geometries with buffer_geometry, falling back to the original geometry wherever the buffered
    result is empty.

    The fallback geometries still get a zero-distance buffer, as an easy fix for some 'bad' geometries.

    :param df: GeoDataFrame (or GeoSeries) with a CRS set
    :param distance: buffer distance forwarded to buffer_geometry
    :param kwargs: buffer keyword arguments forwarded to buffer_geometry
    :return: GeoSeries with the buffered (or fallback) geometries
    '''
    result = buffer_geometry(df, distance, **kwargs)

    collapsed = result.is_empty
    fallback = df.loc[collapsed].geometry.buffer(0)
    result.loc[collapsed] = fallback

    return result
