import getpass
import os
from pathlib import Path

import openeo.processes
import satio.layers
from loguru import logger
from shapely.geometry import box

from cropclass.openeo.preprocessing import cropclass_inputs, cropclass_raw_inputs
from cropclass.openeo.udf import load_cropclass_udf
from worldcereal.seasons import get_processing_dates


# Resource settings forwarded as ``job_options`` to ``execute_batch`` in
# ``main``. All values are given as strings.
JOB_OPTIONS = {
    # --- driver ---
    "driver-memory": "2G",
    "driver-memoryOverhead": "6G",
    "driver-cores": "2",
    # --- executors ---
    "executor-memory": "2G",
    "executor-memoryOverhead": "2G",
    "executor-cores": "1",
    "max-executors": "300",
    # NOTE(review): presumably maps to Spark's spark.driver.maxResultSize,
    # where "0" means unlimited -- confirm with the backend documentation.
    "driver-maxResultSize": "0",
}


def main(connection, extent, aez_id, year, outdir, outnamepattern):
    """Run the cropclass classification workflow as an openEO batch job.

    Loads the cropclass UDF, derives the processing period for the
    ``summer1`` season from ``aez_id``/``year``, builds the input cube,
    reduces its ``t`` dimension with the UDF and submits the result as a
    GTiff batch job. Every ``image/tiff`` asset of the finished job is then
    downloaded into ``outdir``.

    Args:
        connection: authenticated openEO connection the cube is built on.
        extent: spatial extent forwarded to ``cropclass_inputs`` (the
            ``__main__`` block passes a west/south/east/north/crs dict).
        aez_id: AEZ identifier used, with ``year``, to look up the
            processing dates.
        year: year used to look up the processing dates.
        outdir: directory the results are written to; created if missing.
        outnamepattern: filename prefix; each asset's name is appended.
    """

    logger.info('-' * 40)
    logger.info('STARTING CROPCLASS WORKFLOW - OPENEO')
    logger.info('-' * 40)
    logger.info('Parameters:')
    logger.info(f'AEZ_ID: {aez_id}')
    logger.info(f'YEAR: {year}')
    logger.info(f'EXTENT: {extent}')
    logger.info(f'OUTDIR: {outdir}')
    logger.info(f'OUTNAME: {outnamepattern}')
    logger.info('-' * 40)

    # Make sure the output location is usable *before* submitting the batch
    # job, so a bad path fails immediately instead of after the job is done.
    out_path = Path(outdir)
    out_path.mkdir(parents=True, exist_ok=True)

    # Load the UDF we need to run
    logger.info('Loading UDF ...')
    cropclass_udf = load_cropclass_udf()

    # Infer start and end date
    startdate, enddate = get_processing_dates('summer1', aez_id, year)

    # Assemble inputs
    logger.info('Loading job [openeo]...')
    inputs = cropclass_inputs(connection, extent, startdate, enddate)

    # Compute features and run classification by reducing the time
    # dimension with the UDF; the processing period is passed as context.
    # (An apply_neighborhood variant on 256x256 px chunks was tried
    # earlier; see version control history.)
    clf_results = inputs.reduce_dimension(
        dimension='t',
        reducer=lambda data: data.run_udf(
            cropclass_udf,
            runtime='Python',
            context={
                'startdate': startdate,
                'enddate': enddate
            })
        # linear_scale_range clips values to [0, 10000] with identity
        # scaling. NOTE(review): presumably also meant to influence the
        # output data type on the backend -- confirm.
    ).linear_scale_range(0, 10000, 0, 10000)

    # Submit the workflow job
    logger.info('Submitting job ...')
    job = clf_results.execute_batch(
        title="Cropclass-Classification-Workflow",
        out_format="GTiff",
        job_options=JOB_OPTIONS)

    # Get the results
    results = job.get_results()

    # Loop over the resulting assets and download the GeoTIFFs
    n_downloaded = 0
    for asset in results.get_assets():
        # An asset may lack a "type" entry; skip it rather than raising a
        # KeyError after the (expensive) job has already finished.
        asset_type = asset.metadata.get("type") or ""
        if asset_type.startswith("image/tiff"):
            asset.download(out_path / f"{outnamepattern}{asset.name}")
            n_downloaded += 1

    if n_downloaded == 0:
        logger.warning('Job finished but produced no GeoTIFF assets.')
    else:
        logger.info(f'Downloaded {n_downloaded} asset(s) to {out_path}')

    logger.success('Process done!')


if __name__ == '__main__':

    # Specifications of the requested product
    aez_id = 46172  # includes Belgium
    year = 2020
    tile = '31UFS'
    outnamepattern = f'{year}_{tile}_cropclass-REPROJ-OPENEO-'
    outdir = '/data/users/Public/kristofvt/NEXTLAND/results'

    # Get the extent of the requested Sentinel-2 tile from the satio grid.
    # Tip: to test on a smaller area, shrink the geometry first, e.g.
    # ``.to_crs('EPSG:32631').buffer(-53000).to_crs('EPSG:4326')``.
    s2grid = satio.layers.load('s2grid')
    tile_geoms = s2grid.loc[s2grid.tile == tile, 'geometry']
    # Exactly one geometry is required; otherwise `.bounds` below would be
    # missing or a table of bounds, silently producing a wrong extent.
    if len(tile_geoms) != 1:
        raise ValueError(
            f'Expected exactly one geometry for tile {tile!r} in s2grid, '
            f'found {len(tile_geoms)}.')
    bounds = tile_geoms.iloc[0].bounds
    extent = dict(zip(["west", "south", "east", "north"], bounds))
    # NOTE(review): assumes the s2grid layer is stored in EPSG:4326.
    extent['crs'] = 'EPSG:4326'

    # Setup backend connection. Credentials are never stored in source:
    # read them from the environment, or prompt interactively.
    # (The password previously committed here should be rotated.)
    username = os.environ.get('OPENEO_USERNAME') or input('openEO username: ')
    password = (os.environ.get('OPENEO_PASSWORD')
                or getpass.getpass('openEO password: '))
    conn = openeo.connect("https://openeo-dev.vito.be").authenticate_basic(
        username, password)

    # Run the pipeline
    main(conn, extent, aez_id, year, outdir, outnamepattern)
