import json
from pathlib import Path

import numpy as np
import openeo
import scipy.signal
from openeo.rest.datacube import DataCube
import gc


def create_mask(start=None, end=None, session=None, scl_layer_band="TERRASCOPE_S2_TOC_V2:SCENECLASSIFICATION_20M"):
    """
    Creates a mask from Sen2Cor sceneclassification using a dilation strategy.

    Two binary masks are derived from the scene classification band. Each one is
    smoothed with a normalized 2D Gaussian kernel and thresholded, and the union
    of both results is returned.

    @param start: deprecated, not needed
    @param end:  deprecated, not needed
    @param session: an openEO connection
    @param scl_layer_band:  The SCL layer and band to use, formatted as
        "<collection id>:<band name>". None selects the default layer and band.
    @return: A mask datacube (1 means mask/remove pixel, 0 means keep pixel)
    @raise ValueError: if no session is given, or if scl_layer_band does not
        contain a ':' separating the collection id from the band name
    """

    if session is None:
        raise ValueError("openEO session needs to be specified")

    if scl_layer_band is None:
        scl_layer_band = "TERRASCOPE_S2_TOC_V2:SCENECLASSIFICATION_20M"
    layer_band = scl_layer_band.split(':')
    if len(layer_band) < 2:
        # Fail early with a clear message instead of an IndexError further down.
        raise ValueError(
            "scl_layer_band must be formatted as '<layer>:<band>', got {!r}".format(scl_layer_band)
        )
    s2_sceneclassification = session.imagecollection(layer_band[0], bands=[layer_band[1]])

    classification = s2_sceneclassification.band(layer_band[1])

    def makekernel(iwindowsize):
        # Square Gaussian kernel of iwindowsize x iwindowsize (std = size / 6),
        # normalized so that all weights sum to 1.
        kernel_vect = scipy.signal.windows.gaussian(iwindowsize, std=iwindowsize / 6.0, sym=True)
        kernel = np.outer(kernel_vect, kernel_vect)
        kernel = kernel / kernel.sum()
        return kernel

    # in openEO, 1 means mask (remove pixel) 0 means keep pixel

    # First mask: 1 for every pixel whose class is NOT one of 2, 4, 5, 6, 7
    # (the classes treated as usable here; presumably dark area, vegetation,
    # bare soil, water and unclassified in the Sen2Cor legend -- confirm).
    first_mask = ~ ((classification == 2) | (classification == 4) | (classification == 5) | (classification == 6) | (classification == 7))
    # Smooth with a small (17 px) kernel; because the kernel sums to 1 the result
    # is a weighted fraction of unusable pixels around each pixel.
    first_mask = first_mask.apply_kernel(makekernel(17))
    # Mask the pixel as soon as that weighted fraction exceeds the threshold.
    first_mask = first_mask > 0.057

    # Second mask: 1 for the cloud-like classes 3, 8, 9, 10, 11, dilated with a
    # much larger (201 px) kernel and a lower threshold, so that a wide
    # neighbourhood around these pixels is removed as well.
    second_mask = (classification == 3) | (classification == 8) | (classification == 9) | (classification == 10) | (classification == 11)
    second_mask = second_mask.apply_kernel(makekernel(201))
    second_mask = second_mask > 0.025

    # A pixel is removed when either of the two masks flags it.
    return first_mask | second_mask



def download_mask(session,date = "2018-08-17"):
    """
    Downloads the dilated scene classification mask for a single date to
    "mask20m.tiff" (GeoTIFF) in the current working directory.

    The bounding box is read from the module-level names minx, miny, maxx and
    maxy, which in this file are only assigned when it is run as a script.

    @param session: an openEO connection
    @param date: the date for which the mask is downloaded
    """
    mask = create_mask(date, date, session)
    mask = mask.filter_bbox(west=minx, east=maxx, north=maxy, south=miny, crs="EPSG:4326")
    # create_mask ignores its (deprecated) start/end arguments, so the requested
    # date has to be applied explicitly; otherwise it would have no effect.
    mask = mask.filter_temporal(date, date)

    mask.download("mask20m.tiff", format='GTIFF')

def download_rgb(session,date = "2018-05-06"):
    """
    Downloads masked Sentinel-2 radiometry (bands "2", "3", "4") for a single
    date to "masked_3.tiff" (GeoTIFF) in the current working directory.

    The bounding box is read from the module-level names minx, miny, maxx and
    maxy, which in this file are only assigned when it is run as a script.

    @param session: an openEO connection
    @param date: used as both start and end of the temporal filter
    """
    # Build the dilated scene classification mask first, then crop it.
    scene_mask = create_mask(date, date, session)
    area = dict(west=minx, east=maxx, north=maxy, south=miny, crs="EPSG:4326")
    scene_mask = scene_mask.filter_bbox(**area)

    # Load the radiometry, restrict it in space and time, and apply the mask.
    radiometry = session.imagecollection("CGS_SENTINEL2_RADIOMETRY_V102_001", bands=["2", "3", "4"])
    radiometry = radiometry.filter_bbox(west=minx, east=maxx, north=maxy, south=miny, crs="EPSG:4326")
    radiometry = radiometry.filter_temporal(date, date)
    masked_radiometry = radiometry.mask(scene_mask)

    masked_radiometry.download("masked_3.tiff", format='GTIFF')


def retrieve_timeseries(parcels, start, end, session, output_file=Path("timeseries.json"),params=None,env=None):
    """
    Builds the FAPAR timeseries process for the given parcels and evaluates it
    in-process, writing the JSON result to output_file.

    @param parcels: the polygons to aggregate over, passed on to create_fapar_process
    @param start: start of the temporal extent
    @param end: end of the temporal extent
    @param session: an openEO connection
    @param output_file: path of the JSON file to write
    @param params: optional dict of settings; "correlation_id" is read here, the
        remaining keys are interpreted by create_fapar_process
    @param env: optional evaluation environment handed to evaluate(); when None,
        evaluate() builds its own
    """
    # Avoid a shared mutable default argument.
    if params is None:
        params = {}
    time_series = create_fapar_process(start, end, parcels, session, params)
    correlation_id = params.get("correlation_id", '')
    # Forward env so that a caller-supplied environment is actually used.
    evaluate(time_series, output_file, correlation_id, env=env)


def create_fapar_process(start, end, parcels, session,params):
    """
    Builds the process that computes a masked, per-parcel mean timeseries and
    saves it as JSON.

    @param start: start of the temporal extent
    @param end: end of the temporal extent
    @param parcels: the polygons passed to polygonal_mean_timeseries
    @param session: an openEO connection
    @param params: dict of settings; the keys read here are "fapar_bands",
        "fapar_layer", "mask_strategy" and "scl_layer_band"
    @return: the datacube ending in a save_result(format="json") step
    """
    band_names = params.get("fapar_bands", ["FAPAR_10M", "SCENECLASSIFICATION_20M"])
    collection_id = params.get("fapar_layer", "TERRASCOPE_S2_FAPAR_V2")
    cube = session.load_collection(collection_id, bands=band_names)

    use_backend_dilation = params.get("mask_strategy", "mask_scl_dilation") == "mask_scl_dilation"
    if use_backend_dilation:
        # Default strategy: let the backend's "mask_scl_dilation" process do the
        # masking, using the SCL band that is part of the loaded cube.
        cloud_free = cube.process(
            "mask_scl_dilation",
            data=cube,
            scl_band_name=params.get('scl_layer_band', 'SCENECLASSIFICATION_20M'),
        )
    else:
        # Any other strategy: build the mask client-side with create_mask.
        scl_mask = create_mask(start, end, session, scl_layer_band=params.get('scl_layer_band', None))
        cloud_free = cube.mask(scl_mask)

    # Restrict in time, keep only the first configured band and aggregate it
    # per parcel.
    in_period = cloud_free.filter_temporal(start, end)
    first_band_only = in_period.filter_bands(band_names[0])
    per_parcel_mean = first_band_only.polygonal_mean_timeseries(parcels)
    return per_parcel_mean.save_result(format="json")


def evaluate(image_collection: DataCube, output_file, correlation_id: str = '',env=None):
    """
    Evaluates a DataCube in-process and writes its result to a file.

    Note: only supports JSON results.

    @param image_collection: the datacube whose flat process graph is evaluated
    @param output_file: path of the file the JSON result is written to
    @param correlation_id: stored in the evaluation environment; only used when
        env is None
    @param env: optional evaluation environment; when None a default one is
        built with the correlation id, version "1.0.0" and pyramid_levels "highest"
    """
    #imports are here because they create a dependency on openEO backend code that we want to keep optional
    from openeo_driver import ProcessGraphDeserializer
    from openeo_driver.save_result import JSONResult
    from openeo_driver.utils import EvalEnv

    process_graph = image_collection.flat_graph()

    if env is None:
        env = EvalEnv()
        env = env.push({'correlation_id': correlation_id})
        env = env.push({
            "version": "1.0.0",
            "pyramid_levels": "highest"
        })

    print(env)

    try:
        result = ProcessGraphDeserializer.evaluate(process_graph, env,do_dry_run=True)
        assert isinstance(result, JSONResult), type(result)

        # Build the payload before opening the file, so that a failure while
        # preparing the result does not leave an empty output file behind.
        payload = result.prepare_for_json()
        with open(output_file, 'w') as f:
            json.dump(payload, f)
    finally:
        # Run a garbage collection pass whether or not the evaluation succeeded.
        gc.collect()



if __name__ == '__main__':
    # Manual driver: connects to the backend and downloads a masked RGB image
    # for one area and date. The download_* functions read minx/miny/maxx/maxy
    # from the module globals assigned here.
    session = openeo.connect("http://openeo.vgt.vito.be/openeo/0.4.0")

    # this dummy login is needed for now, will be replaced with actual credentials!
    # NOTE(review): the credentials are hard-coded in source; do not reuse this
    # pattern once real credentials are required.
    session.authenticate_basic("driesj", "driesj123")

    # First area of interest, enlarged by 0.05 on every side to also have some
    # data outside of our parcel. It is overwritten by the RETIE bounds below.
    minx, miny, maxx, maxy = (3.057030657924054, 50.99958367677388, 3.058236553549667, 51.00226308446294)
    minx, miny = minx - 0.05, miny - 0.05
    maxx, maxy = maxx + 0.05, maxy + 0.05

    # RETIE: these are the bounds that are actually used.
    minx, miny, maxx, maxy = (4.996033, 51.238922, 5.121603, 51.282696)
    # Alternative (smaller) area:
    #minx,miny,maxx,maxy = (5.611086, 51.018870, 5.614706, 51.022064)
    date = "2018-08-14"

    download_rgb(session, date)
    # Other entry points that can be enabled instead:
    #download_mask(session,date)
    #retrieve_timeseries("/data/users/Public/driesj/some_polygons.shp","2018-04-01","2018-11-01",session)
