import logging
import os
import json
from shapely.geometry.base import geom_from_wkt
import openeo
from openeo_processes.utils.mask import create_advanced_mask
from openeo_processes.utils.udfstring import UDFString

# openEO back-end to run against. Alternatives:
#   local dev:  'http://localhost:8080/openeo/1.0.0/'
#   production: 'http://openeo.vgt.vito.be/openeo/1.0.0/'
openeo_url='http://openeo-dev.vgt.vito.be/openeo/1.0.0/'


def local_username():
    """Return the local login name used as the default openEO credential.

    ``$USER`` is preferred when set; otherwise fall back to
    :func:`getpass.getuser` so the script also starts where ``$USER`` is
    not defined (e.g. Windows, cron jobs, minimal containers).
    """
    import getpass
    user = os.environ.get('USER')
    return user if user is not None else getpass.getuser()


# Credentials: OPENEO_USER / OPENEO_PASS take precedence. The fallbacks are
# only evaluated when the variable is missing, so an unset $USER is harmless
# as long as both OPENEO_* variables are provided.
# NOTE(review): the '<user>123' default password is a convenience pattern;
# prefer setting OPENEO_PASS explicitly.
openeo_user = (os.environ['OPENEO_USER'] if 'OPENEO_USER' in os.environ
               else local_username())
openeo_pass = (os.environ['OPENEO_PASS'] if 'OPENEO_PASS' in os.environ
               else local_username() + '123')

# Spark resources requested for the batch job (passed to execute_batch).
job_options={
    'driver-memory':'8G',
    'executor-memory':'16G'
}

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def getImageCollection(eoconn,layerid,startdate,enddate,bands,bbox,epsg):
    """Load a collection cropped in time and space, optionally to some bands.

    Args:
        eoconn: openEO connection used to load the collection.
        layerid: collection id passed to ``load_collection``.
        startdate, enddate: temporal extent passed to ``filter_temporal``.
        bands: band names to keep, or ``None`` to keep every band.
        bbox: sequence ordered (west, south, east, north).
        epsg: CRS string for the bbox, e.g. ``"EPSG:32630"``.

    Returns:
        The filtered data cube.
    """
    cube = eoconn.load_collection(layerid)
    cube = cube.filter_temporal(startdate, enddate)
    # Map the positional bbox onto filter_bbox's named edges.
    edges = dict(zip(["west", "south", "east", "north"], bbox))
    cube = cube.filter_bbox(crs=epsg, **edges)
    return cube if bands is None else cube.filter_bands(bands)

if __name__ == '__main__':
    # Build a masked NDVI cube for one Sentinel-2 tile and season
    # (Oct 1 of year-1 .. Aug 31 of year), then run a sampling UDF over it
    # in overlapping windows as an openEO batch job.

    # connection
    eoconn=openeo.connect(openeo_url)
    eoconn.authenticate_basic(openeo_user,openeo_pass)

    # Candidate tiles:
    #   |Tile |Year|Country|
    #   |30STF|2018|Spain (Andalusia)|
    #   |31TCG|2019|Spain (Catalunya)|
    #   |31UES|2018|Belgium|
    #   |33TWN|2018|Austria|
    #   |35VMC|2019|Latvia|
    #   |31TFK|2019|France (SUD)|
    #   |31UEQ|2019|France (Forested + agro)|

    tileID='30STF'
    year=2018
    startdate=str(year-1)+'-10-01'
    enddate=str(year)+'-08-31'
    # UDF chunks have a core of (windowsize - 2*overlap) * nblocks pixels per
    # side, plus `overlap` extra pixels on each side (see apply_neighborhood).
    nblocks=1
    overlap=32
    windowsize=128

    # Find bbox and EPSG of the tile in the Sentinel-2 tiling-grid parameter file.
    tile_file="resources/S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.json"
    with open(tile_file, encoding='utf-8') as f:
        tiles=json.load(f)
    if tileID not in tiles:
        raise KeyError("tile {} not found in {}".format(tileID, tile_file))
    tileinfo=tiles[tileID]
    # .bounds is (minx, miny, maxx, maxy), i.e. the (west, south, east, north)
    # order getImageCollection expects.
    bbox=geom_from_wkt(tileinfo[2]).bounds
    # str() guards against the EPSG code being stored as a number.
    epsg="EPSG:"+str(tileinfo[0])

    # check info
    print(tileID)
    print(epsg)
    print(bbox)
    print(startdate+"..."+enddate)

    # Load data and mask bands from one collection so both share the same cube.
    s2_cube=getImageCollection(eoconn,"TERRASCOPE_S2_TOC_V2",startdate,enddate,["B04","B08","SCL"],bbox,epsg)

    # compute the mask from the scene classification (SCL) band
    maskband=create_advanced_mask(s2_cube.filter_bands(["SCL"]).band("SCL"),startdate,enddate)

    # compute ndvi
    ndviband=s2_cube.filter_bands(["B04","B08"]).ndvi(red="B04",nir="B08")

    # set NaN where mask is active
    ndviband=ndviband.mask(maskband)

    # Sample the NDVI cube with a UDF assembled from udf/reduce_images.py and
    # udf/sample_save.py; its options are passed through the UDF context.
    sampler_udf_code = UDFString()\
        .load('udf/reduce_images.py')\
        .load('udf/sample_save.py')\
        .prepare()
    chunk_core=(windowsize-2*overlap)*nblocks
    ndviband = ndviband.apply_neighborhood(openeo.UDF(
        code=sampler_udf_code,
        runtime="Python",
        data={'from_parameter': 'data'},
        context={
          'maxlayers': 15,
          'dump_location': '/data/users/Private/'+openeo_user,
          'tileID': tileID,
          'year': year,
          'maxnodata': 0.05,
          'cropBounds': bbox
        }
    ), size=[
        {'dimension': 'x', 'value': chunk_core, 'unit': 'px'},
        {'dimension': 'y', 'value': chunk_core, 'unit': 'px'}
    ], overlap=[
        {'dimension': 'x', 'value': overlap, 'unit': 'px'},
        {'dimension': 'y', 'value': overlap, 'unit': 'px'}
    ])

    # Run it as a batch job. For a synchronous run,
    # ndviband.download(path, format='json') or format='netcdf' can be used instead.
    ndviband.execute_batch("test.json", out_format='json', job_options=job_options)

    logger.info('FINISHED')






