import json
import logging
import shutil
import tempfile
from pathlib import Path
from typing import List

import numpy
import pandas
import pandas as pd
from openeo.metadata import CollectionMetadata
from openeo.rest.datacube import DataCube
from shapely.geometry import MultiPolygon, GeometryCollection

from cropsar import prepare_date_range
from cropsar.preprocessing.cloud_mask_openeo import evaluate


class InProcessSession:
    """Stand-in for an openEO connection that works through direct function calls.

    No real session is opened. The class provides a small subset of the
    connection interface:
    - loading a collection as a ``DataCube``;
    - looking up collection metadata in the openeo-geotrellis layer catalog;
    - listing a fixed set of output formats.
    """

    def __init__(self):
        # ``imagecollection`` is an alias of ``load_collection`` for callers using that name.
        self.imagecollection = self.load_collection

    def load_collection(self, collection_id, bands=None):
        """Build a ``DataCube`` for ``collection_id`` that uses this session as its connection."""
        cube = DataCube.load_collection(collection_id, connection=self, bands=bands)
        return cube

    def collection_metadata(self, collection_id):
        """Look up ``collection_id`` in the layer catalog and wrap it as ``CollectionMetadata``."""
        from openeogeotrellis.layercatalog import get_layer_catalog

        raw_metadata = get_layer_catalog().get_collection_metadata(collection_id)
        return CollectionMetadata(raw_metadata)

    def list_output_formats(self) -> dict:
        """Return the supported output formats (a fresh dict on every call)."""
        json_format = {
            "gis_data_types": ["raster"],
            "parameters": {},
        }
        netcdf_format = {
            "gis_data_types": ["other", "raster"],
            "parameters": {},
            "title": "Network Common Data Form",
        }
        return {"JSON": json_format, "NetCDF": netcdf_format}


# Module-level in-process "connection" used by the retrieval functions in this module
# (retrieve_timeseries loads collections from it, retrieve_all_timeseries passes it on).
session = InProcessSession()

# Available collections: http://openeo.vgt.vito.be/openeo/0.4.0/collections
# Collections relevant for CropSAR:
# S1_GRD_SIGMA0_ASCENDING
# S1_GRD_SIGMA0_DESCENDING
# S2_FAPAR_V102_WEBMERCATOR2
# S2_FAPAR_SCENECLASSIFICATION_V102_PYRAMID


def preprocess_file(in_file, out_file,buffer=-15,size_threshold=1000,skip_overlaps=True):
    """
    Buffer the parcels in ``in_file``, drop the small ones and write the result as GeoJSON.

    Buffering and the size filter both happen in an intermediate metric CRS (UTM zone 31N,
    EPSG:32631). As a result ``buffer`` is in meters and ``size_threshold`` in square meters,
    whatever the CRS of the input file. The output is written in EPSG:4326.

    :param in_file: vector file readable by ``geopandas.read_file``
    :param out_file: path of the GeoJSON file to write
    :param buffer: buffer distance in meters; negative values shrink the parcels
    :param size_threshold: minimum area, in square meters, of a buffered parcel to be kept
    :param skip_overlaps: when False, overlapping parcels are removed with ``openeohelpers.filter_overlaps``
    :return: ``(minx, miny, maxx, maxy)`` bounds of the written parcels, in EPSG:4326
    """
    import geopandas as gpd

    # Metric CRS so that distances and areas are expressed in meters / square meters.
    intermediate_crs = 'epsg:32631'

    fields = gpd.read_file(in_file)

    buffered_parcels = fields.to_crs(intermediate_crs)
    buffered_parcels.geometry = buffered_parcels.buffer(buffer)
    # Filter on area while still in the metric CRS. In a geographic input CRS the area
    # would be in square degrees and the default threshold would discard every parcel.
    buffered_parcels = buffered_parcels[buffered_parcels.area > size_threshold]
    buffered_parcels = buffered_parcels.to_crs('epsg:4326')

    if not skip_overlaps:
        # openeohelpers: https://git.vito.be/projects/BIGGEO/repos/openeo-python-helpers/browse
        from openeohelpers import filter_overlaps
        filtered = filter_overlaps(buffered_parcels)
    else:
        filtered = buffered_parcels

    filtered.to_file(out_file, driver="GeoJSON")
    minx, miny, maxx, maxy = filtered.geometry.total_bounds
    return (minx, miny, maxx, maxy)



def retrieve_timeseries(collection,parcels_file, start, end, output_file: Path = Path("timeseries.json"),env=None):
    """
    Compute the mean time series of ``collection`` over each polygon and hand it to ``evaluate``.

    :param collection: openEO collection id, loaded through the module-level ``session``
    :param parcels_file: polygons (or polygon file) passed to ``polygonal_mean_timeseries``
    :param start: lower bound of the temporal filter
    :param end: upper bound of the temporal filter
    :param output_file: output path passed to ``evaluate``
    :param env: environment passed to ``evaluate``
    """
    filtered_cube = session.load_collection(collection).filter_temporal(start, end)
    json_result = filtered_cube.polygonal_mean_timeseries(parcels_file).save_result(format="json")
    evaluate(json_result, output_file, env=env)


from cropsar.preprocessing.cloud_mask_openeo import retrieve_timeseries as retrieve_clean_fapar
def retrieve_all_timeseries(parcels_file, start, end, output_dir: Path,use_gamma=True,params=None,env=None):
    """
    Retrieve all input time series (Sentinel-1 and cleaned S2 FAPAR) for the parcels in a file.

    The period from ``start`` to ``end`` (inclusive) is split into chunks of at most 5 months.
    For every chunk one JSON file per product is written to ``output_dir``, named
    ``<PRODUCT>_<YYYYMMDD>_<YYYYMMDD>.json`` with the chunk's first and last day.

    :param parcels_file: vector file with the parcel polygons
    :param start: start date (anything ``pandas.to_datetime`` accepts)
    :param end: end date, inclusive (anything ``pandas.to_datetime`` accepts)
    :param output_dir: directory in which the JSON files are written (str or Path)
    :param use_gamma: when True retrieve TERRASCOPE_S1_GAMMA0_V1, otherwise the ascending and
        descending sigma0 collections
    :param params: extra parameters forwarded to the FAPAR retrieval (default: empty dict)
    :param env: environment forwarded to the process evaluation
    :return: None
    """
    # Accept plain strings too, e.g. when called through the ``fire`` command line.
    output_dir = Path(output_dir)
    # Fresh dict per call instead of a shared mutable default.
    params = {} if params is None else params

    start_date = pd.to_datetime(start).date()
    end_date = pd.to_datetime(end).date()

    print("Start preprocessing of cropsar geometries!")
    prepared_geometry_list = _prepare_cropsar_geometries(parcels_file)
    print("Done preprocessing geometries.")

    for current_date, upper in _iter_date_chunks(start_date, end_date):
        print(current_date)
        print(upper)

        postfix = "_%s_%s.json" % (current_date.strftime('%Y%m%d'), upper.strftime('%Y%m%d'))

        if use_gamma:
            asc_output_file = output_dir / ("S1_GAMMA0" + postfix)
            print("retrieving Gamma0 timeseries for: %s to %s " % (str(current_date),str(upper)))
            retrieve_timeseries("TERRASCOPE_S1_GAMMA0_V1", prepared_geometry_list, current_date, upper,
                                output_file=asc_output_file,env=env)
        else:
            asc_output_file = output_dir / ("S1_ASCENDING" + postfix)
            desc_output_file = output_dir / ("S1_DESCENDING" + postfix)
            retrieve_timeseries("S1_GRD_SIGMA0_ASCENDING", prepared_geometry_list, current_date, upper, output_file=asc_output_file,env=env)
            retrieve_timeseries("S1_GRD_SIGMA0_DESCENDING", prepared_geometry_list, current_date, upper, output_file=desc_output_file,env=env)

        fapar_output_file = output_dir / ("FAPAR_CLEAN" + postfix)
        print("retrieving FAPAR timeseries for: %s to %s " % (str(current_date), str(upper)))
        retrieve_clean_fapar(prepared_geometry_list, current_date, upper, session, output_file=fapar_output_file,params=params,env=env)


def _iter_date_chunks(start_date, end_date, months=5):
    """Yield consecutive ``(lower, upper)`` date pairs, both inclusive, covering ``start_date``..``end_date``.

    Each chunk spans at most ``months`` calendar months and the last chunk ends exactly on
    ``end_date``. Nothing is yielded when ``start_date >= end_date``.
    """
    # Work with Timestamps internally so comparisons never mix ``date`` and ``Timestamp``.
    current = pd.Timestamp(start_date)
    end = pd.Timestamp(end_date)
    while current < end:
        next_start = current + pd.DateOffset(months=months)
        if next_start >= end:
            # Last chunk: include end_date itself, also when it falls exactly on a chunk boundary.
            yield current.date(), end.date()
            return
        yield current.date(), (next_start - pd.DateOffset(days=1)).date()
        current = next_start


def _prepare_cropsar_geometries(parcels_file) -> GeometryCollection:
    """
    Load the parcels and shrink them inward by the CropSAR field-border distance.

    The negative buffer is applied in a UTM CRS chosen for the parcels' bounding box. The
    result is returned in the input CRS. Parcels that become empty after buffering keep
    their original geometry.

    :param parcels_file: vector file readable by ``DelayedVector``
    :return: GeometryCollection of the buffered parcels, in input order
    """
    from openeo_driver.delayed_vector import DelayedVector

    from cropsar._crs import auto_utm_crs_for_geometry
    from cropsar._generate import _get_param
    import shapely.geometry

    gdf = DelayedVector(parcels_file).as_geodataframe()
    bbox = shapely.geometry.box(*gdf.total_bounds)
    crs = gdf.crs
    print("Found this CRS in the input data: %s." % str(crs))
    # Newer geopandas versions expose a pyproj CRS instead of a dict; normalize to a dict, see
    # https://geopandas.org/projections.html#upgrading-to-geopandas-0-7-with-pyproj-2-2-and-proj-6
    if not isinstance(crs,dict):
        crs = crs.to_dict()
    crs_utm = auto_utm_crs_for_geometry(bbox, crs)
    already_utm = crs == crs_utm

    # Negative distance: shrink each parcel by the configured border width. The parameter
    # name indicates meters; the buffer is applied in the UTM CRS.
    border = -_get_param({}, 's2_clean', 'inarrowfieldbordersinmeter')
    buffer_args = {'cap_style': 1, 'join_style': 2, 'resolution': 4}

    gdf_utm = gdf if already_utm else gdf.to_crs(crs_utm)
    buffered_utm = gdf_utm.buffer(border, **buffer_args)
    # Report the real polygon count and the buffered area (in the UTM CRS) for both branches.
    print("Total number of polygons: %s with area: %s" % (len(gdf), buffered_utm.area.sum()))
    gdf_buffer = buffered_utm if already_utm else buffered_utm.to_crs(gdf.crs)

    # Parcels narrower than twice the border vanish after the negative buffer;
    # fall back to their unbuffered geometry instead of dropping them.
    is_empty = gdf_buffer.geometry.is_empty
    if is_empty.any():
        gdf_buffer.geometry.loc[is_empty] = gdf.geometry.loc[is_empty]

    return GeometryCollection(gdf_buffer.geometry.array)


def retrieve_cropsar(parcels_file, start:str, end:str,use_gamma=False, working_directory=None, params=None,env=None) -> List[Path]:
    """
    Retrieve cropsar curves for all polygons in a given file. This method is invoked by the openEO cropsar process.

    :param parcels_file: shp or geojson file with polygons
    :param start: start date string, formatted as '%Y-%m-%d'
    :param end: end date string, formatted as '%Y-%m-%d'
    :param use_gamma: use Sentinel-1 gamma0 input instead of ascending/descending sigma0
    :param working_directory: directory for the intermediate and result files. When None,
        temporary "in_"/"out_" directories are created in the current working directory.
        The input one is always removed afterwards; the output one is removed only on failure.
    :param params: extra parameters (default: empty dict). 'include_inputs' (bool) also writes
        the input time series as CSV. The whole dict is forwarded to the time-series retrieval.
    :param env: environment forwarded to the process evaluation
    :return: list of paths containing cropsar results
    """
    params = {} if params is None else params
    use_temp_dirs = working_directory is None

    if use_temp_dirs:
        input_timeseries_dir = Path(tempfile.mkdtemp(prefix="in_", dir=Path.cwd()))
    else:
        input_timeseries_dir = Path(working_directory)

    print("Running CropSAR process on: %s, start date: %s, end date: %s " % (str(parcels_file),str(start),str(end)))
    print("Working directory: %s" % str(input_timeseries_dir))

    # Everything after creating the input directory is inside the try, so the temporary
    # input directory is also removed when the retrieval itself fails.
    try:
        start,end = prepare_date_range(start,end)

        retrieve_all_timeseries(parcels_file, start, end, input_timeseries_dir,use_gamma,params=params,env=env)

        if use_gamma:
            gamma_input_glob = str(input_timeseries_dir / "S1_GAMMA0*.json")
            asc_input_glob = desc_input_glob = None
        else:
            asc_input_glob = str(input_timeseries_dir / "S1_ASCENDING_*.json")
            desc_input_glob = str(input_timeseries_dir / "S1_DESCENDING_*.json")
            gamma_input_glob = None

        fapar_input_glob = str(input_timeseries_dir / "FAPAR_CLEAN_*.json")

        if use_temp_dirs:
            output_cropsar_dir = Path(tempfile.mkdtemp(prefix="out_", dir=Path.cwd()))
        else:
            output_cropsar_dir = Path(working_directory)

        try:
            return run_cropsar_on_files(asc_input_glob, desc_input_glob, fapar_input_glob, output_cropsar_dir,gamma0_glob=gamma_input_glob,include_inputs=params.get('include_inputs',False))
        except BaseException:
            # On success the caller owns the output directory; on failure don't leak the temporary one.
            if use_temp_dirs:
                shutil.rmtree(output_cropsar_dir, ignore_errors=True)
            raise
    finally:
        if use_temp_dirs:
            shutil.rmtree(input_timeseries_dir, ignore_errors=True)


def run_cropsar_on_files(s1_asc_glob: str, s1_desc_glob: str, S2_FAPAR_clean_glob: str, output_dir: Path = None, gamma0_glob = None, include_inputs = False) -> List[Path]:
    """
    Load openEO time-series JSON files and run CropSAR on them.

    :param s1_asc_glob: glob of S1 ascending sigma0 JSON files (ignored when ``gamma0_glob`` is given)
    :param s1_desc_glob: glob of S1 descending sigma0 JSON files (ignored when ``gamma0_glob`` is given)
    :param S2_FAPAR_clean_glob: glob of cleaned S2 FAPAR JSON files
    :param output_dir: directory for the result files (str or Path); defaults to the current
        working directory at call time
    :param gamma0_glob: glob of S1 gamma0 JSON files; when given it replaces the sigma0 globs
    :param include_inputs: also write the loaded input time series as CSV files
    :return: paths of all written files
    """
    from openeo.rest.conversions import timeseries_json_to_pandas
    import glob

    # Resolve the default at call time (a ``Path.cwd()`` default is frozen at import time) and
    # accept plain strings, e.g. when invoked through the ``fire`` command line.
    output_dir = Path.cwd() if output_dir is None else Path(output_dir)

    def load_ts_json(glob_pattern):
        """Concatenate all time-series JSON files matching ``glob_pattern`` into one DataFrame.

        Each file is converted with ``timeseries_json_to_pandas`` and its index is turned into
        ``datetime.date`` values. The frames are concatenated, sorted by date and transposed, so
        dates become the columns. Files with empty JSON are skipped. When nothing is loaded an
        empty DataFrame is returned.
        """
        frames = []
        for file in sorted(glob.glob(glob_pattern)):
            with open(file, "r") as f:
                ts_json = json.load(f)
            if not ts_json:
                print("Skipping file with empty json: %s" % file)
                continue
            ts_dataframe = timeseries_json_to_pandas(ts_json, auto_collapse=False)
            ts_dataframe.index = pd.to_datetime(ts_dataframe.index).date
            frames.append(ts_dataframe)
        if not frames:
            return pd.DataFrame()
        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
        return pd.concat(frames).sort_index().T

    if gamma0_glob is None:
        df_S1_asc = load_ts_json(s1_asc_glob)
        df_S1_desc = load_ts_json(s1_desc_glob)
    else:
        # Gamma0 data is passed in the "ascending" slot; no separate descending data exists.
        df_S1_asc = load_ts_json(gamma0_glob)
        df_S1_desc = None

    df_S2_clean = load_ts_json(S2_FAPAR_clean_glob)

    # Scaling factor for the stored S2 FAPAR values.
    df_S2_clean *= 0.005

    cropsar_result = run_cropsar_dataframes(df_S2_clean, df_S1_asc, df_S1_desc, output_dir)
    if include_inputs:
        cropsar_result.extend(_write(df_S2_clean,"S2_FAPAR",output_dir))
        if df_S1_desc is None:
            cropsar_result.extend(_write(df_S1_asc,"S1_GAMMA0",output_dir))
        else:
            cropsar_result.extend(_write(df_S1_desc, "S1_SIGMA0_DESCENDING", output_dir))
            cropsar_result.extend(_write(df_S1_asc, "S1_SIGMA0_ASCENDING", output_dir))

    return cropsar_result

def _write(df,name, output_dir:Path) -> List:
    """Save ``df`` as ``<output_dir>/<name>.csv`` with dates formatted as YYYY-MM-DD.

    Returns a one-element list holding the written path, so callers can ``extend`` a result list.
    """
    csv_path = output_dir / "{}.csv".format(name)
    df.to_csv(csv_path, date_format='%Y-%m-%d')
    return [csv_path]

def run_cropsar_dataframes(df_S2_clean, df_S1_asc, df_S1_desc, output_dir=None, scale=0.0005, offset=29, use_spark=True):
    """
    Run CropSAR for every parcel and either return or write the q50/q10/q90 curves.

    :param df_S2_clean: cleaned S2 FAPAR time series; the rows of one parcel are selected
        with ``.loc[field_id]``
    :param df_S1_asc: S1 ascending sigma0 (or gamma0) time series. The first index level holds
        the parcel ids and the columns are dates; their min/max define the processed period.
    :param df_S1_desc: S1 descending sigma0 time series, or None when ``df_S1_asc`` holds gamma0 data
    :param output_dir: directory for the result files (str or Path); when None the three
        DataFrames are returned instead
    :param scale: scale applied to the sigma0 'angle' band before combining the orbits
    :param offset: offset added to the scaled sigma0 'angle' band
    :param use_spark: when True (default) the parcels are processed in parallel with Spark;
        when False they are processed sequentially in the current process
    :return: ``(q50, q10, q90)`` DataFrames with one column per parcel when ``output_dir`` is
        None, otherwise the paths of the written CSV files followed by the JSON files
    """
    import cropsar

    start = df_S1_asc.columns.min()
    end = df_S1_asc.columns.max()
    # S2 layer name passed to cropsar.generate_timeseries. It was previously derived with
    # api._map_ts_product("S2_FAPAR"), presumably to select the CropSAR RNN model scalers.
    s2_data_layer = "S2_FAPAR"
    cropsar_df = pd.DataFrame()
    cropsar_df_q10 = pd.DataFrame()
    cropsar_df_q90 = pd.DataFrame()

    def outlier_filter(cleanfielddatatimeseries, start_date, end_date):
        """Drop observations flagged as local minima ("dips") from a cleaned FAPAR series.

        The series is placed on a daily grid between ``start_date`` and ``end_date`` and passed
        to ``smooth.flaglocalminima``. That call presumably sets flagged entries to NaN in place.
        Observations that turned NaN on the grid are then removed from the input series.
        """

        from s2_clean import smooth
        from cropsar._generate import _get_param

        params = None
        localminimamaxdip = _get_param(params, 's2_clean', 'localminimamaxdip')
        localminimamaxdif = _get_param(params, 's2_clean', 'localminimamaxdif')
        localminimamaxgap = _get_param(params, 's2_clean', 'localminimamaxgap')
        localminimamaxpas = _get_param(params, 's2_clean', 'localminimamaxpas')

        cleanfieldfulldatatimeseries = pandas.Series(
            index=pandas.date_range(start=start_date, end=end_date), dtype=float)
        cleanfielddatatimeseries.index = pd.to_datetime(cleanfielddatatimeseries.index)
        ts_intersection = cleanfielddatatimeseries[start_date:end_date]
        cleanfieldfulldatatimeseries.loc[ts_intersection.index] = ts_intersection
        smooth.flaglocalminima(cleanfieldfulldatatimeseries.values, maxdipvalueornumpycube=localminimamaxdip,
                               maxdifvalueornumpycube=localminimamaxdif,
                               maxgap=localminimamaxgap, maxpasses=localminimamaxpas)
        dippedindices = cleanfieldfulldatatimeseries.loc[cleanfieldfulldatatimeseries.isnull()].index.intersection(
            cleanfielddatatimeseries.loc[cleanfielddatatimeseries.notnull()].index)
        cleanfielddatatimeseries.loc[dippedindices] = numpy.nan
        cleanfielddatatimeseries.dropna(inplace=True)
        return cleanfielddatatimeseries

    def do_cropsar(item):
        """Run CropSAR for one parcel work item; returns ``(field_id, result)``."""
        ascending_row = item['s1_asc']
        descending_row = item['s1_desc']
        s2_row = item['s2_data']

        combined_orbits_row = _create_combined_sigma0(ascending_row, descending_row, scale, offset)

        s2_row.columns = ['clean']
        s2_row['clean'] = outlier_filter(s2_row['clean'], item['date_start'], item['date_end'])
        result = cropsar.generate_timeseries(combined_orbits_row, s2_row, None, s2_data_layer,
                                             item['date_start'], item['date_end'],
                                             s1_var='sigma' if descending_row is not None else 'gamma')
        return (item['fieldid'], result[0])

    # One self-contained work item per parcel. Both execution modes use the same items and the
    # same per-parcel function, so they honour scale/offset and pass dates as strings.
    work_items = [{'fieldid': field_index,
                   's2_data': pd.DataFrame(df_S2_clean.loc[field_index].T),
                   's1_asc': df_S1_asc.loc[field_index].T,
                   's1_desc': df_S1_desc.loc[field_index].T if df_S1_desc is not None else None,
                   'date_start': start.strftime('%Y-%m-%d'),
                   'date_end': end.strftime('%Y-%m-%d')
                   } for field_index in df_S1_asc.index.get_level_values(0).unique()]

    if use_spark:
        # os.environ['PYSPARK_PYTHON'] = '/usr/bin/python3.6'
        from pyspark import SparkContext

        def do_cropsar_on_worker(item):
            # Configure logging in the Spark worker process before running the model.
            logging.basicConfig(level=logging.DEBUG)
            return do_cropsar(item)

        spark_context = SparkContext.getOrCreate()
        cropsar_ts = spark_context.parallelize(work_items, max(len(work_items) // 10, 2)) \
            .map(do_cropsar_on_worker).collectAsMap()
    else:
        cropsar_ts = dict(do_cropsar(item) for item in work_items)

    for field_index, result in cropsar_ts.items():
        cropsar_df[field_index] = result.q50
        cropsar_df_q10[field_index] = result.q10
        cropsar_df_q90[field_index] = result.q90

    if output_dir is None:
        return cropsar_df, cropsar_df_q10, cropsar_df_q90

    output_dir = Path(output_dir)
    named_results = (
        ("cropsar", cropsar_df),
        ("cropsar_q10", cropsar_df_q10),
        ("cropsar_q90", cropsar_df_q90),
    )

    # NOTE(review): pandas ``to_json`` only documents 'epoch' and 'iso' for ``date_format``;
    # confirm that '%Y-%m-%d' gives the intended date encoding in the JSON outputs.
    json_files = []
    for name, frame in named_results:
        json_file = output_dir / (name + ".json")
        frame.to_json(json_file, date_format='%Y-%m-%d')
        json_files.append(json_file)

    csv_files = []
    for name, frame in named_results:
        csv_file = output_dir / (name + ".csv")
        frame.to_csv(csv_file, date_format='%Y-%m-%d')
        csv_files.append(csv_file)

    return csv_files + json_files


def _create_combined_sigma0(ascending_row, descending_row, scale=0.0005, offset=29):
    """
    Build one S1 DataFrame with columns 'vh', 'vv' and 'angle' for CropSAR.

    The input DataFrames are not modified; all renaming and rescaling happens on copies.

    @param ascending_row:  ascending data or already combined gamma0 data
    @param descending_row: descending data or None if data is already combined
    @param scale: scale applied to the 'angle' band of both sigma0 orbits (unused for gamma0)
    @param offset: offset added to the scaled 'angle' band of both sigma0 orbits (unused for gamma0)
    @return: for gamma0 input (``descending_row`` is None), a copy of ``ascending_row`` with its two
        columns renamed to 'vh'/'vv' and a constant 0.0 'angle' column. Otherwise the row-wise
        mean of both orbits, grouped on the first index level.
    """
    if descending_row is None:
        combined = ascending_row.copy()
        combined.columns = ['vh', 'vv']
        combined['angle'] = 0.0
        return combined

    ascending = ascending_row.copy()
    descending = descending_row.copy()
    for orbit in (ascending, descending):
        orbit.columns = ['vh', 'vv', 'angle']
        orbit['angle'] = scale * orbit['angle'] + offset
    # Rows sharing an index value (the same date in both orbits) are averaged.
    return pd.concat((ascending, descending)).groupby(level=0).mean()


def get_result(job_id:str,session, output_file="out.json"):
    """
    Download the results of an existing openEO batch job.

    :param job_id: id of the batch job
    :param session: openEO connection the job belongs to
    :param output_file: local path the results are downloaded to (default "out.json")
    """
    from openeo.rest.job import RESTJob
    # Use the requested job id (the builtin ``id`` function must not end up here).
    job = RESTJob(job_id, session)
    job.download_results(output_file)


if __name__ == '__main__':
    import fire

    # Command-line interface via python-fire, exposing two commands:
    #   cropsar_on_files         -> run_cropsar_on_files
    #   retrieve_all_timeseries  -> retrieve_all_timeseries
    fire.Fire({"cropsar_on_files": run_cropsar_on_files,"retrieve_all_timeseries": retrieve_all_timeseries})

    # Example of a manual run, kept for reference (not executed):
#    start = '2019-02-01'
#    end = '2019-12-31'

#    parcels_file = "/data/CapSat/DLV/EPR/DLV2018_maize.gpkg"
#    parcels_file_processed = "/data/users/Public/driesj/hpg_flanders_v3_filtered.geojson"

    # IMPORTANT: make sure the parcels in the file do not overlap.
    #minx, miny, maxx, maxy = preprocess_file(parcels_file, parcels_file_processed)

#    retrieve_all_timeseries(parcels_file_processed,start,end)
