from cropclass.openeo.classification import load_cropclass_udf
from cropclass.openeo.preprocessing import cropclass_preprocessed_inputs


def croptype_map(extent, start_date, end_date, connection, collections_options, processing_options):
    """Define (but do not execute) the openEO crop type classification workflow.

    The input cube is built with ``cropclass_preprocessed_inputs``. It is then
    classified chunk-wise by the crop classification UDF. The result is
    rescaled with ``linear_scale_range`` and its bands are relabelled.

    :param extent: spatial extent, forwarded unchanged to
        ``cropclass_preprocessed_inputs``.
    :param start_date: start of the time range. It is forwarded to the
        preprocessing step and passed to the UDF context as ``'startdate'``.
    :param end_date: end of the time range. It is forwarded to the
        preprocessing step and passed to the UDF context as ``'enddate'``.
    :param connection: openEO connection handed to the preprocessing step.
    :param collections_options: mapping expanded as keyword arguments of the
        preprocessing step. When it holds a non-None
        ``"WORLDCOVER_collection"`` entry, the UDF context gets
        ``'worldcovermask': True``; otherwise it gets ``False``.
    :param processing_options: mapping expanded as keyword arguments of the
        preprocessing step. All of its entries are also copied into the UDF
        context. ``'modeldir'`` (default ``'tmp/model/'``), ``'modeltag'`` and
        ``'custom_dependency_path'`` (default ``''``) are read from it with
        fall-back defaults.
    :return: the lazily defined result cube, with bands renamed to
        ``"croptype"`` and ``"probability"``.
    """
    # Both option mappings are expanded into the same call. A key present in
    # both would therefore raise TypeError (duplicate keyword argument).
    preprocessed = cropclass_preprocessed_inputs(
        connection,
        extent,
        start_date,
        end_date,
        **collections_options,
        **processing_options,
    )

    worldcover_enabled = collections_options.get("WORLDCOVER_collection", None) is not None
    udf_settings = {
        'startdate': start_date,
        'enddate': end_date,
        'worldcovermask': worldcover_enabled,
        'modeldir': processing_options.get('modeldir', 'tmp/model/'),
        'modeltag': processing_options.get('modeltag', '20221102T134330-transformer_optical_sar_dem'),
        'custom_dependency_path': processing_options.get('custom_dependency_path', ''),
    }
    # Every processing option is forwarded to the UDF. The explicit settings
    # above win over same-named processing options.
    udf_context = {**processing_options, **udf_settings}

    udf_code = load_cropclass_udf()

    def _classify_chunk(data):
        # Keep the argument named ``data``: the openEO client may match
        # callback argument names against the parent process's parameters.
        return data.run_udf(udf=udf_code, runtime='Python', context=udf_context)

    # 256 x 256 pixel chunks along x and y.
    chunk_size = [{'dimension': axis, 'value': 256, 'unit': 'px'} for axis in ('x', 'y')]

    # This only extends the process graph; nothing is executed on the back-end yet.
    classified = preprocessed.apply_neighborhood(_classify_chunk, size=chunk_size, overlap=[])
    # NOTE(review): input and output ranges are identical, so this presumably
    # only clips values to [0, 10000] (or acts as a back-end data-type hint).
    # Confirm before changing.
    scaled = classified.linear_scale_range(0, 10000, 0, 10000)
    return scaled.rename_labels("bands", ["croptype", "probability"])