######################################################################################################
# yieldpotentialmap.py
# ---------------------------
# Purpose
#       Integration of the yield potential map generation process in OpenEO
#
######################################################################################################

# imports
import argparse
import logging
import os

import openeo
from openeo import Connection
from openeo.rest.datacube import DataCube
from openeo.rest.udp import build_process_dict

from nextland_services.constants import *
from nextland_services.helpers import read_input, load_udf, get_extension, write_json, load_markdown

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
_log = logging.getLogger("yieldpotentialmap")


def get_terrascope_input_datacube(conn: Connection, date_range: list, features: dict) -> DataCube:
    """
    Build the FAPAR input datacube for the yield potential map from the TerraScope collection.

    The cube is loaded from the ``TERRASCOPE_S2_FAPAR_V2`` collection, filtered on the given features and
    date range, passed through the ``mask_scl_dilation`` process, reduced to the FAPAR band, masked with
    the feature polygons and finally divided by 200.
    :param conn: Existing connection to an OpenEO backend
    :param date_range: Date range for which to get the timeseries
    :param features: Feature collection used both as spatial filter and as polygon mask
    :return: Datacube containing the FAPAR data selection from Terrascope
    """
    fapar_band = get_fapar_band('terrascope')
    scl_band = 'SCENECLASSIFICATION_20M'

    # Load the FAPAR band together with the scene classification band, limited in space and time
    loaded = conn.load_collection('TERRASCOPE_S2_FAPAR_V2', bands=[fapar_band, scl_band])
    loaded = loaded.filter_spatial(features)
    loaded = loaded.filter_temporal(date_range)

    # Mask based on the scene classification band
    masked = loaded.process("mask_scl_dilation", data=loaded, scl_band_name=scl_band)

    # Only the FAPAR band is needed from here on, restricted to the field polygons
    fapar_only = masked.filter_bands(fapar_band)
    fapar_only = fapar_only.mask_polygon(mask=features)

    # presumably rescales the stored values to the physical FAPAR range — TODO confirm
    return fapar_only / 200


def get_shub_input_datacube(conn: Connection, date_range: list, features: dict) -> DataCube:
    """
    Build the FAPAR input datacube for the yield potential map from the SentinelHub collection.

    The required Sentinel-2 bands are loaded from ``SENTINEL2_L2A_SENTINELHUB``, filtered on the given
    features and date range and passed through the ``mask_scl_dilation`` process. The bands are then
    reduced with the UDF stored in ``shub_fapar_udf.py`` (located next to this file), after which a band
    dimension named after the FAPAR band is added and the result is masked with the feature polygons.
    :param conn: Existing connection to an OpenEO backend
    :param date_range: Date range for which to get the timeseries
    :param features: Feature collection used both as spatial filter and as polygon mask
    :return: Datacube containing the FAPAR data selection from SentinelHub
    """
    input_bands = ["B03", "B04", "B08", "sunAzimuthAngles", "sunZenithAngles", "viewAzimuthMean",
                   "viewZenithMean", 'SCL']

    # Load the input bands, limited in space and time
    cube = conn.load_collection('SENTINEL2_L2A_SENTINELHUB', bands=input_bands)
    cube = cube.filter_spatial(features)
    cube = cube.filter_temporal(date_range)

    # Mask based on the scene classification band
    cube = cube.process("mask_scl_dilation", data=cube, scl_band_name="SCL")

    # Collapse the band dimension using the FAPAR UDF
    fapar_udf_code = load_udf(Path(__file__).parent / 'shub_fapar_udf.py')
    cube = cube.reduce_bands_udf(code=fapar_udf_code, runtime='Python')

    # Re-introduce a band dimension that carries the FAPAR band name
    fapar_band = get_fapar_band('shub')
    cube = cube.add_dimension(label=fapar_band, name=fapar_band, type='bands')
    return cube.mask_polygon(mask=features)


def extract_timeseries(conn: Connection, source: str, features: dict, date_range: list) -> DataCube:
    """
    Select the input datacube builder matching the requested data source and return its result.

    The source name is compared as-is (case sensitive).
    :param conn: OpenEO connection
    :param source: Data source to use. This can be terrascope or shub
    :param features: Feature collection representing the fields to process
    :param date_range: Date range for which to get the timeseries
    :return: Datacube containing the data timeseries for the yield potential map
    :raises Exception: If the source is neither terrascope nor shub
    """
    if source == 'terrascope':
        return get_terrascope_input_datacube(conn=conn, date_range=date_range, features=features)

    if source == 'shub':
        return get_shub_input_datacube(conn=conn, date_range=date_range, features=features)

    raise Exception(f'Source {source} is not supported!')


def get_fapar_band(source: str) -> str:
    """
    Look up the name of the FAPAR band that belongs to a data source.

    The source name is matched case-insensitively.
    :param source: Data source to use. This can be terrascope or shub
    :return: Name of the FAPAR band
    :raises Exception: If the source is not one of the supported data sources
    """
    band_per_source = {
        'terrascope': 'FAPAR_10M',
        'shub': 'FAPAR',
    }

    band_name = band_per_source.get(source.lower())
    if band_name is None:
        raise Exception(f'Source {source} is not supported!')
    return band_name


def execute_yieldpotential_udf(dc: DataCube, source: str, features: dict, threshold: float,
                               exclude_months: list, exclude_croptypes: list, raw: bool, year_check: bool,
                               polygonize: bool, publish: bool = False):
    """
    Apply the yield potential map UDF (``yieldpotentialmaps_udf.py``, located next to this file) on every
    polygon of the feature collection through ``chunk_polygon``.

    When ``publish`` is set, the threshold, exclude_months, exclude_croptypes, raw and year_check entries of
    the UDF context are replaced by ``from_parameter`` references, so the resulting graph reads them from
    process parameters instead of embedding the given values.
    :param dc: Datacube that contains the input timeseries required to calculate the yield potential map
    :param source: Data source to use. This can be terrascope or shub
    :param features: Feature collection representing the fields to process
    :param threshold: Threshold for the mean value for a given date
    :param exclude_months: Months that should be removed from the yield map generation procedure
    :param exclude_croptypes: Croptypes for which we do not calculate a yield potential map
    :param raw: Flag indicating if the raw values should be returned (True) or if the map should contain categories (False)
    :param year_check: Flag indicating if the year should be checked on croptypes, being split, ...
    :param polygonize: Flag indicating if the resulting map should be polygonized
    :param publish: Flag indicating if we are publishing the graph instead of executing it
    :return: Datacube on which the UDF is applied per polygon
    """
    # The same mask value is handed to the UDF (through its context) and to chunk_polygon
    mask_value = 999.0

    def value_or_parameter(value, parameter_name):
        # While publishing, reference the process parameter instead of the concrete value
        if publish:
            return {"from_parameter": parameter_name}
        return value

    # NOTE: the argument is deliberately still called `data`, exactly like in the callback it replaces
    def run_yieldpotential_udf(data):
        udf_code = load_udf(Path(__file__).parent / 'yieldpotentialmaps_udf.py')
        udf_context = {
            'band': get_fapar_band(source),
            'threshold': value_or_parameter(threshold, "threshold"),
            'exclude_months': value_or_parameter(exclude_months, "exclude_months"),
            'exclude_croptypes': value_or_parameter(exclude_croptypes, "exclude_croptypes"),
            'mask_value': mask_value,
            'raw': value_or_parameter(raw, "raw"),
            # The published process exposes this flag under the parameter name "check"
            'year_check': value_or_parameter(year_check, "check"),
            'polygonize': polygonize
        }
        return data.run_udf(udf=udf_code, runtime='Python', context=udf_context)

    return dc.chunk_polygon(chunks=features, process=run_yieldpotential_udf, mask_value=mask_value)


def generate_yield_potential_map(conn: Connection, source: str, date_range: list, features: dict,
                                 threshold: float = 0.5,
                                 exclude_months: list = None,
                                 exclude_croptypes: list = None,
                                 raw: bool = False,
                                 year_check: bool = True,
                                 polygonize: bool = False,
                                 publish: bool = False) -> DataCube:
    """
    Generate a yield potential map using an existing connection to an OpenEO backend
    :param conn: Connection to an OpenEO supported backend
    :param source: Data source to use. This can be terrascope or shub
    :param date_range: Date range for which to calculate the yield potential map
    :param features: Feature collection for which to generate the yield potential map
    :param threshold: Threshold for the mean value for a given date
    :param exclude_months: Months that should be removed from the yield potential map generation procedure.
                           Defaults to [1, 2, 3, 4, 9, 10, 11, 12] when not provided (None)
    :param exclude_croptypes: Croptypes for which we do not calculate a yield potential map. Defaults to
                              ['Wintertarwe', 'Wintergerst', '6', '743', '321', '311'] when not provided (None)
    :param raw: Flag indicating if the raw values should be returned (True) or if the map should contain categories (False)
    :param year_check: Flag indicating if the year should be checked on croptypes, being split, ...
    :param polygonize: Flag indicating if the resulting map should be polygonized
    :param publish: Flag indicating if we are publishing the graph and not executing it!
    :return: Datacube containing the yield potential map
    :raises Exception: If the source is not supported
    """
    # Create the default lists per call: list defaults in the signature would be shared by every call
    if exclude_months is None:
        exclude_months = [1, 2, 3, 4, 9, 10, 11, 12]
    if exclude_croptypes is None:
        exclude_croptypes = ['Wintertarwe', 'Wintergerst', '6', '743', '321', '311']

    _log.info('Starting the generation of the yield potential map')

    _log.debug('Retrieving the timeseries')
    dc = extract_timeseries(conn=conn, source=source, features=features, date_range=date_range)

    _log.debug('Executing the yield potential map UDF')
    return execute_yieldpotential_udf(dc=dc, source=source, features=features,
                                      threshold=threshold,
                                      exclude_months=exclude_months,
                                      exclude_croptypes=exclude_croptypes,
                                      raw=raw, year_check=year_check, polygonize=polygonize, publish=publish)


def publish_processing_graph(conn: Connection, source: str, threshold: float,
                             exclude_months: list,
                             exclude_croptypes: list,
                             raw: bool,
                             year_check: bool):
    """
    Publish the yield potential map generation as a user defined process
    :param conn: Connection to an OpenEO supported backend, used to build and to store the process
    :param source: Data source to use. This can be terrascope or shub
    :param threshold: Default threshold value
    :param exclude_months: Default months that should be removed from the yield potential map generation procedure
    :param exclude_croptypes: Default croptypes for which we do not calculate a yield potential map
    :param raw: Default flag value indicating if the raw values should be returned (True) or if the map should contain categories (False)
    :param year_check: Default flag value indicating if the year should be checked
    :return: Process definition as created by build_process_dict
    """

    # Setup of parameters
    date_range_param = date_parameter()

    features_param = polygon_param(
        description="Feature collection of fields for which to generate the yield potential map"
    )
    threshold_param = Parameter.number(
        name='threshold',
        description='Average FAPAR threshold',
        default=threshold
    )
    exclude_months_param = Parameter.array(
        name='exclude_months',
        description='Months that should not be included in the calculation of the yield potential map',
        default=exclude_months
    )
    exclude_croptypes_param = Parameter.array(
        name='exclude_croptypes',
        description='No yield potential map is calculated if the field has one of these croptypes. This is only applicable if the check is set.',
        default=exclude_croptypes
    )
    raw_param = Parameter.boolean(
        name='raw',
        description='Flag indicating if the yield map contains the raw differences or the result is categorized',
        default=raw
    )
    check_param = Parameter.boolean(
        name='check',
        description='Flag indicating if addtional checks needs to be executed to see if the  year is valid for the field (only working for fields in Belgium!). This check will validate that the field was not split '
                    'during the given year and validate if the croptype is not included in the exclude_croptypes.',
        default=year_check
    )

    # Single definition of the parameter list, shared by the stored process and the returned definition
    parameters = [date_range_param, features_param, threshold_param, exclude_months_param,
                  exclude_croptypes_param, raw_param, check_param]

    # Create processing graph
    process_id = f'yieldpotentialmap_{source}'
    description = load_markdown('yieldpotentialmap.md')
    process = generate_yield_potential_map(conn=conn, source=source, date_range=date_range_param,
                                           features=features_param,
                                           threshold=threshold_param, exclude_months=exclude_months_param,
                                           exclude_croptypes=exclude_croptypes_param, raw=raw_param,
                                           year_check=check_param, publish=True)

    # Publish as a public service for the current user.
    # Use the connection that was passed in: relying on a module level connection only works
    # when this file is executed as a script.
    conn.save_user_defined_process(
        process_id,
        process.graph,
        description=description,
        parameters=parameters,
        public=True)

    return build_process_dict(
        process_id=process_id,
        description=description,
        process_graph=process,
        parameters=parameters
    )


def execute_datacube(dc: DataCube, output_file: Path, format: str, title: str, batch: bool):
    """
    Run an OpenEO datacube and store the results locally.

    Without batch mode the result is downloaded synchronously to ``output_file`` in the requested format.
    In batch mode a job is created, started and awaited, and its results are downloaded into the parent
    directory of ``output_file``; the batch job always requests "GTiff" output, ``format`` is not used there.
    :param dc: OpenEO DataCube to process
    :param output_file: Path of the output file
    :param format: Format of the result to be created (synchronous execution only)
    :param title: Title of the batch job (if enabled)
    :param batch: Flag indicating if the job needs to be executed in batch mode
    :return: None
    """
    if not batch:
        # Synchronous execution: a single file is written
        _log.info(f'Downloading the result to {output_file}')
        _log.debug(f'Starting download')
        dc.download(output_file, format=format)
        return

    # Batch execution: the results end up in the directory of the output file
    result_dir = output_file.parent
    _log.info(f'Downloading the result to {result_dir}')
    _log.debug(f'Executing batch job')
    batch_job = dc.send_job(out_format="GTiff", title=title, sample_by_feature=True)
    finished_job = batch_job.start_and_wait()
    finished_job.download_results(result_dir)


def get_script_args():
    """
    Parse the command line arguments of the script
    :return: Namespace containing the parsed command line arguments
    """
    # (flags, options) pairs; they are registered in exactly this order
    argument_specs = [
        # General actions of the script
        (("-p", "--publish"),
         dict(action='store_true', help="Publish the yield potential map service")),
        (("-e", "--execute"),
         dict(action='store_true', help="Execute the yield potential map service")),

        # Execution params
        (("-v", "--verbose"),
         dict(action='store_true', help="Verbose output logging")),
        (("-o", "--output"),
         dict(type=str, default='.', help="Output directory to store results")),
        (("-b", "--batch"),
         dict(action='store_true', help="Calculate the yield potential map in batch mode")),
        (("-s", "--split"),
         dict(action='store_true', help="Split all features of the input file into separate executions")),
        (("-f", "--format"),
         dict(type=str, default='gtiff', help="Output format of the result (supported by OpenEO)")),

        # Yield potential map params
        (("input_file",),
         dict(type=str, help="Path to the geojson file that contains the featurecollection to process")),
        (("year",),
         dict(type=str, help="Year for which to calculate the yield potential map")),

        # Optional yield potential map params
        (("--source",),
         dict(type=str, default='terrascope', choices=['terrascope', 'shub'],
              help="Datasource to use (terrascope or shub)")),
        (("-t", "--threshold"),
         dict(type=float, default=0.5, help="Threshold to apply for selecting the fapar images")),
        (("--exclude_months",),
         dict(type=int, nargs='*', default=[1, 2, 3, 4, 9, 10, 11, 12],
              help="Months to exclude in the selection of the fapar images")),
        (("--exclude_croptypes",),
         dict(type=str, nargs='*', default=[],
              help="No yield potential maps will be created if the field contained one of the specified croptypes for the given year.")),
        (("-r", "--raw"),
         dict(action='store_true', default=False,
              help="Flag indicating if the raw values should be returned or categories (default)")),
        (("-c", "--check"),
         dict(action='store_true', default=False,
              help="Flag indicating if additional checks should be applied to the field before calculating the yield potential map.")),
    ]

    arg_parser = argparse.ArgumentParser()
    for flags, options in argument_specs:
        arg_parser.add_argument(*flags, **options)

    return arg_parser.parse_args()


if __name__ == '__main__':
    # Script entry point: optionally execute the yield potential map for the given input file and/or
    # publish the processing graphs as user defined processes.

    args = get_script_args()

    if args.verbose:
        _log.setLevel(logging.DEBUG)

    # Read the feature collection from the input file
    features = read_input(args.input_file)

    # Setting up connection with OpenEO
    _log.debug('Connecting to OpenEO')
    # The system user is only a fallback: it must not be required when OPENEO_USER and OPENEO_PASS
    # are provided (USER is not defined on every platform).
    system_user = os.environ.get('USER')
    openeo_user = os.environ.get('OPENEO_USER', system_user)
    openeo_pass = os.environ.get('OPENEO_PASS', None if system_user is None else system_user + '123')
    if openeo_user is None or openeo_pass is None:
        raise SystemExit('No OpenEO credentials available: set OPENEO_USER and OPENEO_PASS '
                         '(or USER to use the default credentials)')
    # NOTE(review): basic authentication is done against a plain http URL — confirm if https can be used
    eoconn = openeo.connect("http://openeo.vito.be").authenticate_basic(openeo_user, openeo_pass)

    if args.execute:
        # Either process the full feature collection at once or one feature per execution
        if args.split:
            jobs = [{'type': 'FeatureCollection', 'features': [feature]} for feature in features['features']]
        else:
            jobs = [features]

        # Values that are identical for every job
        base_name = Path(args.input_file).name.split(".")[0]
        date_range = [f'{args.year}-01-01', f'{args.year}-12-31']

        for cnt, job in enumerate(jobs, start=1):
            result = generate_yield_potential_map(conn=eoconn, source=args.source,
                                                  date_range=date_range,
                                                  features=job, threshold=args.threshold,
                                                  exclude_months=args.exclude_months,
                                                  exclude_croptypes=args.exclude_croptypes, raw=args.raw,
                                                  year_check=args.check,
                                                  polygonize=False)

            # Additional step for vector data
            if args.format.lower() == 'geojson':
                result = result.raster_to_vector()

            # Execute the processing graph
            output_file = Path(
                args.output) / f'{base_name}_{args.source}_{args.year}_{cnt}.{get_extension(args.format)}'
            execute_datacube(result, output_file=output_file, format=args.format,
                             title=f'Yield potential map - {args.year} - {base_name} - {args.format}', batch=args.batch)
            _log.info(f'Created yield potential map: {output_file}')

    if args.publish:
        for source in ['terrascope', 'shub']:
            graph = publish_processing_graph(conn=eoconn, source=source,
                                             threshold=args.threshold,
                                             exclude_months=args.exclude_months,
                                             exclude_croptypes=args.exclude_croptypes, raw=args.raw,
                                             year_check=args.check)

            # Write the processing graph to the resulting directory
            output_file = Path(
                __file__).parent / 'resources' / 'process_graphs' / f'yieldpotentialmap_{source}_graph.json'
            write_json(path=output_file, data=graph)
            _log.info(f'Published processing graph at {output_file}')
