from pyspark.sql import SparkSession
from pyspark.sql import functions as functions
from pyspark import Row
import pandas as pd
import os
import pkgutil
import logging

log = logging.getLogger(__name__)


try:
    # Prefer the C implementation `_pickle` when it can be imported; otherwise fall back to `pickle`.
    # Only ImportError is caught so that unrelated errors during import are not silently hidden.
    import _pickle as pickle
except ImportError:
    import pickle


def get_spark_sql():
    """
    Build an SQLContext on top of the SparkContext returned by ``get_spark_context()``.

    ``get_spark_context`` is called with its default arguments (non-local mode).

    :return: a ``pyspark.sql.SQLContext`` wrapping that SparkContext.
    """
    spark_context = get_spark_context()
    from pyspark.sql import SQLContext
    return SQLContext(spark_context)

def get_spark_context(name="DatastackGeneration_Program_SPARK", local=False):
    """
    Create (or reuse, via ``getOrCreate``) a SparkSession and return its SparkContext.

    Both modes request 4G of executor and driver memory and the spark-tensorflow-connector package.

    :param name: application name passed to the session builder.
    :param local: if True, use a single-threaded local master (``local[1]``) with the driver bound to
        127.0.0.1; otherwise no master is set here and it is left to the launch configuration.
    :return: the SparkContext of the session.
    """
    builder = SparkSession.builder.appName(name)
    if local:
        log.info('Running SPARK in local mode!')
        # Local-only settings; the shared settings below are applied in both modes.
        builder = builder.master('local[1]').config('spark.driver.host', '127.0.0.1')
    spark = (
        builder
        .config('spark.executor.memory', '4G')
        .config('spark.driver.memory', '4G')
        .config("spark.jars.packages", "org.tensorflow:spark-tensorflow-connector_2.11:1.11.0")
        .getOrCreate()
    )
    return spark.sparkContext

from sys import platform
from pathlib import Path

# Root directory of the CropSAR data: the 'O:' drive on Windows, a fixed mount point elsewhere.
# sys.platform is 'win32' on every Windows build (32- and 64-bit alike), so one comparison suffices.
if platform == 'win32':
    dataDir = Path('O:')
else:
    dataDir = Path('/data/CropSAR')
CROPSAR_BASE = dataDir

def read_parquet_timeseries(parquet_file:str, name = "timeseries"):
    """
    Read a parquet file of timeseries and pack its timestamp columns into one struct column.

    Paths without a ``file:``, ``hdfs:`` or ``http`` prefix are treated as local files and turned into a
    ``file:///...`` URI, both for POSIX paths ('/data/x') and Windows drive paths ('O:/x').

    :param parquet_file: the file name, for example: 'file:///home/driesj/alldata/cropsar/S1_DESCENDING_VH.parquet'.
        Non-string values (e.g. ``pathlib.Path``) are converted with ``str``.
    :param name: name of the struct column that will hold the timeseries.
    :return: A pyspark DataFrame with a ``name`` column (the timeseries identifier) and a struct column
        named after the ``name`` argument, holding one field per timestamp.
    """
    sqlContext = get_spark_sql()
    parquet_file = str(parquet_file)

    if not parquet_file.startswith(("file:", "hdfs:", "http")):
        # Local path: ensure exactly three slashes after 'file:'. POSIX paths already start with '/',
        # Windows drive paths ('O:/...') do not.
        if not parquet_file.startswith("/"):
            parquet_file = "/" + parquet_file
        parquet_file = "file://" + parquet_file
    df = sqlContext.read.parquet(parquet_file)

    return combine_timeseries_columns(df, name)


def combine_timeseries_columns(timeseries_dataframe, ts_column_name):
    """
    Pack every timestamp column of a wide timeseries DataFrame into a single struct column.

    The identifier column is ``fieldID`` if present, else ``_c0``, else ``__index_level_0__`` (the
    default name pandas uses to store its index). All other columns are treated as timestamps.
    Grouping them into one column makes it easier to join several timeseries DataFrames.

    :param timeseries_dataframe: pyspark DataFrame with one identifier column and one column per timestamp.
    :param ts_column_name: name given to the resulting struct column.
    :return: DataFrame with columns ``name`` (the identifier) and ``ts_column_name`` (struct of all
        remaining columns, in their original order).
    :raises ValueError: if none of the identifier columns is present.
    """
    # Work on a copy so the caller's column list is never mutated.
    columns = list(timeseries_dataframe.columns)
    if 'fieldID' in columns:
        identifier = 'fieldID'
    elif '_c0' in columns:
        identifier = '_c0'
    else:
        identifier = "__index_level_0__"
    if identifier not in columns:
        raise ValueError(
            "No identifier column found: expected one of 'fieldID', '_c0' or '__index_level_0__'"
        )
    columns.remove(identifier)
    ts_df = timeseries_dataframe.select(
        timeseries_dataframe[identifier].alias("name"),
        functions.struct(columns).alias(ts_column_name),
    )
    return ts_df


def read_csv_timeseries(csv_file:str, name ="timeseries"):
    """
    Read a CSV file of timeseries, transpose it, and pack its timestamp columns into one struct column.

    The CSV is read with a header row and is expected to hold one row per date, with the dates in the first
    column (read as ``_c0``) and one column per field; see ``transpose_spark_dataframe``.
    Paths without a ``file:``, ``hdfs:`` or ``http`` prefix are treated as local files and turned into a
    ``file:///...`` URI, both for POSIX paths ('/data/x') and Windows drive paths ('O:/x').

    :param csv_file: path or URI of the CSV file; non-string values (e.g. ``pathlib.Path``) are converted
        with ``str``.
    :param name: name of the struct column that will hold the timeseries.
    :return: A pyspark DataFrame with a ``name`` column (the field identifier) and a struct column named
        after the ``name`` argument, holding one field per date.
    """
    sqlContext = get_spark_sql()
    csv_file = str(csv_file)

    if not csv_file.startswith(("file:", "hdfs:", "http")):
        # Local path: ensure exactly three slashes after 'file:'. Without the leading '/', a Windows
        # drive path such as 'O:/x' would be parsed as the URI authority ('file://O:/x').
        if not csv_file.startswith("/"):
            csv_file = "/" + csv_file
        csv_file = "file://" + csv_file
    # A wide CSV (one column per field) can exceed the CSV parser's default column limit.
    df = sqlContext.read.option("maxColumns", 200000).csv(csv_file, header=True)

    merged = transpose_spark_dataframe(df)

    return combine_timeseries_columns(merged, name)


def transpose_spark_dataframe(df):
    """
    Transpose a wide Spark DataFrame so that its rows become columns.

    Dataframes are expected to have columns::

        _c0,"Field1","Field2"

    where the ``_c0`` column contains the list of dates (one row per date). The result has one row per
    field: ``_c0`` holds the original column name (the field id), followed by one float column per date,
    named after that date.

    All dates are collected to the driver, because they define the output schema.
    NOTE(review): a date occurring in several input rows is probably not handled by the per-group
    assignment below; confirm inputs have unique dates.

    :param df: pyspark DataFrame with a ``_c0`` date column and one column per field.
    :return: the transposed pyspark DataFrame.
    """
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    from pyspark.sql.types import StructField, StructType, StringType, FloatType

    # Every date becomes a float column of the output schema, preceded by the string id column _c0.
    all_dates = df.select(df._c0).rdd.map(lambda row: row._c0).collect()
    fields = [StructField(field_name, FloatType(), True) for field_name in all_dates]
    fields.insert(0, StructField("_c0", StringType()))
    schema = StructType(fields)

    @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
    def transpose(dataframe):
        # `dataframe` is a pandas.DataFrame holding the row(s) of a single date (grouped by _c0).
        column_dates = dataframe["_c0"]
        dataframe = dataframe.drop("_c0", axis=1)
        # After transposing, the index holds the field names and each column one input row.
        transposed = dataframe.T
        transposed = transposed.apply(pd.to_numeric)
        transposed.columns = column_dates

        # Emit a frame with the full output schema in which only this group's date column is filled.
        # The other columns stay NaN, presumably arriving in Spark as nulls that max() below ignores.
        full = pd.DataFrame(index=transposed.index, columns=all_dates, dtype='float')
        # Positional access, so the code does not rely on the group's index containing the label 0.
        full[column_dates.iloc[0]] = transposed
        full.insert(loc=0, column='_c0', value=full.index)

        return full

    transposed = df.groupBy('_c0').apply(transpose)
    # Collapse the sparse per-date rows into one row per field; this yields columns named 'max(<date>)'.
    merged = transposed.groupBy(transposed._c0).max()
    columns = list(merged.columns)
    identifier = "_c0"
    if 'fieldID' in columns:
        identifier = 'fieldID'
    columns.remove(identifier)
    # Rename the aggregated columns back to the plain dates in ONE projection. Chaining a
    # withColumnRenamed per date adds a projection per column, which gets very slow for wide frames.
    new_names = dict(zip(columns, all_dates))
    merged = merged.toDF(*[new_names.get(column, column) for column in merged.columns])
    return merged


def _read_s1_timeseries(directory:str,mode:str):
    """
    Load the VH, VV and incidenceAngle parquet files of one S1 orbit direction and inner-join them.

    :param directory: folder containing ``S1_<mode>_<layer>.parquet`` files.
    :param mode: orbit direction used in the file and column names (e.g. 'ASCENDING').
    :return: DataFrame with a ``name`` column plus the struct columns ``S1_<mode>_VH``,
        ``S1_<mode>_VV`` and ``S1_<mode>_incidenceAngle``.
    """
    prefix = 'S1_' + mode

    def _load(layer):
        column = prefix + '_' + layer
        return read_parquet_timeseries(os.path.join(directory, column + '.parquet'), column)

    result = _load('VH')
    for layer in ('VV', 'incidenceAngle'):
        other = _load(layer)
        # Keep only one copy of the join key.
        result = result.join(other, result.name == other.name, 'inner').drop(other.name)
    return result

def read_full_timeseries_input(directory:str, s2_path = None, S2layername='FAPAR'):
    """
    Load the S2 layer plus ascending and descending S1 timeseries and inner-join them on field name.

    :param directory: folder containing the S1 parquet files (and the default S2 parquet file).
    :param s2_path: explicit S2 file; a path ending in ``csv`` is read as CSV, anything else as parquet.
        Defaults to ``<directory>/S2_<S2layername>.parquet``.
    :param S2layername: S2 layer name; the S2 struct column is called ``s2_<layer name in lower case>``.
    :return: pyspark DataFrame with one ``name`` column and one struct column per loaded layer.
    """
    directory = str(directory)
    s2_path = os.path.join(directory, 'S2_' + S2layername + '.parquet') if s2_path is None else str(s2_path)

    reader = read_csv_timeseries if s2_path.endswith("csv") else read_parquet_timeseries
    result = reader(s2_path, 's2_' + S2layername.lower())

    ascending = _read_s1_timeseries(directory, 'ASCENDING')
    descending = _read_s1_timeseries(directory, 'DESCENDING')

    for orbit_df in (ascending, descending):
        # Inner join on the field name, keeping a single copy of the key column.
        result = result.join(orbit_df, result.name == orbit_df.name, 'inner').drop(orbit_df.name)
    return result

def row_to_pandas(row:Row) -> pd.Series:
    """
    Turn a PySpark Row into a pandas Series indexed by datetime.

    The Row's field names are parsed with ``pd.to_datetime`` and become the Series index; the field
    values become the Series values.

    :param row: a Row whose field names are date strings (presumably one of the struct columns built
        by ``combine_timeseries_columns``).
    :return: pandas Series of the Row values, indexed by the parsed dates.
    """
    values_by_date = row.asDict()
    series = pd.Series(values_by_date)
    series.index = pd.to_datetime(series.index)
    return series

def df_row_to_dict(fieldID, single_row, S2layername):
    '''
    Convert one joined timeseries row into the nested dict layout used downstream.

    The result has the shape ``{'S1': {orbit: {band: {fieldID: series}}}, 'S2': {S2layername: {fieldID: series}}}``
    for orbits ASCENDING/DESCENDING and bands VV, VH and incidenceAngle.
    todo: NOTE: while we're currently moving away from split ascending/descending, we need to keep it here for now as this is how the training data is generated
    :param fieldID: key used for every innermost entry.
    :param single_row: row with struct columns ``S1_<orbit>_<band>`` and ``s2_<S2layername lower-cased>``.
    :param S2layername: S2 layer name used as key under 'S2'.
    :return: the nested dict of pandas Series (see ``row_to_pandas``).
    '''
    def _series(column):
        return {fieldID: row_to_pandas(single_row[column])}

    s1_part = {
        orbit: {
            band: _series('S1_' + orbit + '_' + band)
            for band in ('VV', 'VH', 'incidenceAngle')
        }
        for orbit in ('ASCENDING', 'DESCENDING')
    }
    s2_part = {S2layername: _series('s2_' + S2layername.lower())}
    return {'S1': s1_part, 'S2': s2_part}

