from typing import Dict
from openeo.rest.connection import Connection
from cropclass.utils.seasons import get_processing_dates
from cropclass import classification
from cropclass.openeo.preprocessing import cropclass_preprocessed_inputs
from cropclass.config import get_processing_options, get_collection_options


def croptype_map(extent,
                 connection: Connection,
                 provider: str,
                 processing_options: Dict = None):
    """Build the OpenEO process graph that produces a crop type map.

    Args:
        extent (Dict): extent of the area of interest (AOI); passed
            unchanged to ``cropclass_preprocessed_inputs``.
        connection (Connection): OpenEO connection used to load and
            preprocess the input collections.
        provider (str): backend provider name. It selects both the
            collection definitions (``get_collection_options``) and the
            default processing options (``get_processing_options``).
        processing_options (Dict, optional): additional or non-default
            processing options that override the provider defaults.
            After merging with the defaults, the options must contain
            ``start_month``, ``end_month`` and ``year``.

    Returns:
        OpenEO data cube (processing graph) with bands ``croptype`` and
        ``probability`` and no ``t`` dimension. It can be submitted as a
        batch job.
    """

    # Get the appropriate collections for this provider
    collections_options = get_collection_options(provider)

    # Start from a copy of the provider defaults. Applying the user
    # overrides then never mutates the object get_processing_options
    # returned, which may be shared between calls.
    options = dict(get_processing_options(provider))
    if processing_options is not None:
        options.update(processing_options)

    # Derive the processing period from the (merged) options
    start_date, end_date = get_processing_dates(
        options['start_month'],
        options['end_month'],
        options['year']
    )

    # Get preprocessed inputs from OpenEO
    inputs = cropclass_preprocessed_inputs(
        connection,
        extent,
        start_date,
        end_date,
        **collections_options,
        **options)

    # The classification module's source code is sent as the UDF
    cropclass_udf = load_cropclass_udf()

    # Run the UDF on 128x128 px spatial chunks without overlap. The merged
    # options are passed along as the UDF context.
    clf_results = inputs.apply_neighborhood(
        lambda data: data.run_udf(
            udf=cropclass_udf, runtime='Python',
            context=options),
        size=[
            {'dimension': 'x', 'value': 128, 'unit': 'px'},
            {'dimension': 'y', 'value': 128, 'unit': 'px'}
        ],
        overlap=[]
    # linear_scale_range(0, 10000, 0, 10000) is an identity mapping that
    # clips values to [0, 10000]. It presumably also lets the backend
    # narrow the output data type (TODO confirm).
    ).linear_scale_range(0, 10000, 0, 10000).drop_dimension('t').rename_labels(
        "bands", ["croptype", "probability"])

    return clf_results


def load_cropclass_udf() -> str:
    """Return the source code of the ``cropclass.classification`` module.

    The source text is read verbatim from the file the module was imported
    from, so that it can be submitted to OpenEO as a Python UDF.

    Returns:
        str: full contents of the classification module's source file.
    """
    import os

    udf_path = os.path.realpath(classification.__file__)
    # Open read-only. Write access ('r+') is not needed, and requesting it
    # fails with PermissionError when the package is installed in a
    # read-only location. Python source files are UTF-8 by default
    # (PEP 3120), so decode explicitly rather than using the locale.
    with open(udf_path, 'r', encoding='utf-8') as f:
        return f.read()
