"""
Wrapper functions to obtain field delineation results from a trained Radix model.
"""

import functools
from pathlib import Path
from openeo.udf.debug import inspect

# def _setup_logging():
#     global logger
#     from loguru import logger


@functools.lru_cache(maxsize=25)
def load_delineation_model(modeldir: str, modeltag: str, cache_dir: str):
    """Load a trained delineation model, memoising the result.

    The function is wrapped in ``functools.lru_cache(maxsize=25)``, so calls
    with the same ``(modeldir, modeltag, cache_dir)`` arguments return the
    already-loaded model object instead of loading it again.

    Args:
        modeldir: Local directory that contains the model, or a URL starting
            with ``'http'`` from which the model is fetched with
            ``fielddelineation.utils.download_and_unpack``.
        modeltag: Model name, optionally ending in ``.zip``. When downloading,
            its file extension is passed as the archive ``format``. Tags that
            contain ``'MultiHead'`` are loaded with
            ``vito_lot_delineation.models.load_model``; all other tags are
            loaded with the legacy ``SemanticModel.load``.
        cache_dir: Directory passed to ``download_and_unpack``. ``None`` is
            accepted and falls back to the current working directory.

    Returns:
        The loaded model object.
    """
    import os
    from vito_lot_delineation.models import load_model
    from fielddelineation.utils import download_and_unpack

    if cache_dir is None:
        cache_dir = os.getcwd()

    if modeldir.startswith('http'):
        # The archive format (e.g. 'zip') is derived from the tag's extension.
        archive_format = os.path.splitext(modeltag)[1].replace('.', '')
        modeldir, modeltag = download_and_unpack(modeldir, modeltag, cache_dir,
                                                 format=archive_format)

    # Strip only a trailing '.zip'; str.replace would also drop '.zip'
    # appearing elsewhere in the tag and point to the wrong directory.
    zip_suffix = '.zip'
    model_name = modeltag[:-len(zip_suffix)] if modeltag.endswith(zip_suffix) else modeltag
    model_path = Path(modeldir) / model_name

    # load a model from its tag
    if 'MultiHead' in modeltag:
        return load_model(model_path)

    # We only import this if we are running the older version that does not use MultiHead.
    # Otherwise you may be blocked by an ImportError if you can't access the correct
    # version of vito_lot_delineation.
    # TODO: This is a stopgap solution. Would be better that the correct version of vito_lot_delineation is available in Artifactory, or alike.
    from vito_lot_delineation.models.EnanchedResUnet3D.main import SemanticModel
    return SemanticModel.load(model_path)


def _apply_delineation(features, modeldir, modeltag, cache_dir, return_model=False):
    """Run the delineation model on raw input features.

    Args:
        features: Raw satellite features, passed unchanged as ``sample`` to
            ``vito_lot_delineation.inference.preprocess_raw``.
        modeldir: Forwarded to ``load_delineation_model``.
        modeltag: Forwarded to ``load_delineation_model``.
        cache_dir: Forwarded to ``load_delineation_model``.
        return_model: When True, also return the loaded model.

    Returns:
        The semantic prediction (``preds['semantic']`` converted with
        ``.numpy()``). If it has more than two dimensions, only the first
        slice along the leading axis is kept. When ``return_model`` is True,
        a ``(preds_sem, model)`` tuple is returned instead.
    """
    from vito_lot_delineation import run_prediction
    from vito_lot_delineation.inference import preprocess_raw
    inspect(
        message=f"Checking if model directory exists: {modeldir}",
        code='load_delineation_model',
        level='debug'
    )
    inspect(message='******* Start actual delineation ********',
            code='_apply_delineation', level='debug')
    # load the pre-trained model (memoised by load_delineation_model)
    model = load_delineation_model(modeldir, modeltag, cache_dir)
    inspect(message=f'Delineation model loaded: {modeltag}',
            code='_apply_delineation', level='debug')

    # preprocess satellite dataset before feeding it into a batch
    inspect(message='Preprocessing of features and transforming it to Torch sensor',
            code='_apply_delineation', level='debug')
    sample = preprocess_raw(sample=features,  # input satellite features
                            cfg=model.cfg["input"])  # input section of the model config
    # Add a leading batch axis. NOTE(review): repeat(1, 1, 1, 1, 1) copies the
    # tensor and, if it has fewer than 5 dims, prepends size-1 dims; kept
    # as-is because downstream code may rely on the resulting 5-D shape.
    batch = sample.unsqueeze(0).repeat(1, 1, 1, 1, 1)
    inspect(message=f'Created batch sample of shape: {batch.shape}',
            code='_apply_delineation', level='debug')

    inspect(message='Apply model prediction',
            code='_apply_delineation', level='debug')
    # run now actually the prediction
    preds = run_prediction(
        model=model,  # model to use for prediction
        batch=batch  # batch of images over which want to predict
    )
    inspect(message='Model prediction succeeded',
            code='_apply_delineation', level='debug')
    # keep only the semantic segmentation
    preds_sem = preds['semantic'].numpy()

    # for the moment reduce outcome from 3D to 2D by keeping the first slice
    if preds_sem.ndim > 2:
        # Pass the text as `message=`: inspect's first positional parameter is
        # `data`, so a positional string would be logged with an empty message.
        inspect(message='Model output is 3D --> convert to 2D',
                code='_apply_delineation', level='debug')
        preds_sem = preds_sem[0, :, :]

    if return_model:
        return preds_sem, model

    return preds_sem

