# Script to preprocess AgERA5 meteo data with openEO batch jobs (one job per 20 km tile and year) and download the results as JSON.

from openeo.extra.job_management import MultiBackendJobManager, _format_usage_stat
from cropclass.utils import laea20km_id_to_extent
import openeo
from cropclass.openeo.preprocessing import add_meteo
from pathlib import Path
import os
import time
import fire
import json
import logging
import geopandas as gpd
import pandas as pd

logging.basicConfig(level=logging.INFO)
# Module-level named logger instead of the root logger; records still propagate to the
# root handler configured by basicConfig above.
logger = logging.getLogger(__name__)

# NOTE: connecting and authenticating happen at import time (authenticate_oidc may prompt
# for an OIDC login), so importing this module already requires access to the backend.
terrascope = openeo.connect("openeo.vito.be").authenticate_oidc()

def run(parallel_jobs=20, input_file="croptype2021_terrascope.geojson", status_file="eu27_2021_meteo.csv", output_dir="."):
    """
    Start and monitor the openEO jobs that preprocess AgERA5 meteo data for the EU27 croptype map project.

    One batch job is created per feature in ``input_file``. Each job requests the AgERA5 meteo data for the
    feature's 20 km LAEA tile over the calendar year ``year`` (``<year>-01-01`` to ``<year>-12-31``). Its JSON
    result is downloaded to ``<output_dir>/<year>/METEO-<name>-<year>``. Tiles whose result already exists
    locally are skipped.

    Jobs are tracked in a CSV file. Upon failure, the script can resume processing by pointing to the same
    CSV file. Delete that file to start processing from scratch.

    @param parallel_jobs: Maximum number of jobs to run in parallel on the backend.
    @param input_file: GeoJSON file with one feature per job; each feature must have a ``name`` property
        (the 20 km LAEA tile id) and a ``year`` property.
    @param status_file: Name of the CSV file, inside ``output_dir``, where job status is tracked.
    @param output_dir: Root directory for downloaded results, error logs and the status file.
    @return: None
    """
    with Path(input_file).open("r") as f:
        tiles_to_produce = gpd.GeoDataFrame.from_features(json.load(f))

    logger.info(f"Found {len(tiles_to_produce)} tiles to process.")

    def _result_path(year, title):
        """Local path where the result of job ``title`` for ``year`` is (or will be) downloaded."""
        return Path(output_dir) / str(year) / str(title)

    class CustomJobManager(MultiBackendJobManager):
        """Job manager that starts created jobs itself, downloads finished results and stores error logs."""

        def __init__(self, poll_sleep=0):
            # NOTE(review): poll_sleep=0 means the backend is polled without pause between rounds;
            # consider a larger value to reduce API load.
            super().__init__(poll_sleep)
            # TODO: track an "error_reason" column; appending ("error_reason", "") to
            # self.required_with_default reportedly did not work here.

        def on_job_error(self, job, row):
            """Write the ERROR-level log entries of a failed job to ``<output_dir>/<year>/job_<title>_errors.json``."""
            error_logs = [entry for entry in job.logs() if entry.level.lower() == "error"]
            if not error_logs:
                return

            title = job.describe_job()["title"]
            base_dir = Path(output_dir) / str(row["year"])
            # The year directory may not exist yet if no result for that year was downloaded.
            base_dir.mkdir(parents=True, exist_ok=True)
            (base_dir / f"job_{title}_errors.json").write_text(json.dumps(error_logs, indent=2))

        def _update_statuses(self, df: pd.DataFrame):
            """Update status (and usage stats) of active jobs in ``df`` (in place).

            Replaces the base implementation so that jobs still in status "created" are started, finished
            jobs get their result downloaded, and jobs that switched to "error" are passed to ``on_job_error``.
            """
            active = df.loc[df.status.isin(["created", "queued", "running"])]
            for i in active.index:
                job_id = df.loc[i, "id"]
                backend_name = df.loc[i, "backend_name"]
                con = self.backends[backend_name].get_connection()
                the_job = con.job(job_id)
                try:
                    job_metadata = the_job.describe_job()
                except Exception:
                    # Likely a transient backend/network problem: back off briefly, retry on the next poll.
                    logger.warning(f"Failed to get status of job {job_id!r} (on backend {backend_name})", exc_info=True)
                    time.sleep(5)
                    continue

                status = job_metadata["status"]
                logger.info(f"Status of job {job_id!r} (on backend {backend_name}) is {status!r}")

                if status == "created":
                    # start_job (below) only creates jobs; they are started here.
                    try:
                        the_job.start_job()
                    except Exception:
                        logger.warning("Failed to start the job, will try again after poll", exc_info=True)
                elif status == "finished":
                    target = _result_path(df.loc[i, "year"], job_metadata["title"])
                    try:
                        target.parent.mkdir(parents=True, exist_ok=True)
                        the_job.download_result(target)
                    except Exception:
                        # Leave the tracked status unchanged so the download is retried on the next poll.
                        logger.warning(f"Failed to download result of job {job_id!r}, will retry after poll", exc_info=True)
                        continue
                elif status == "error":
                    # Jobs in `active` were not in "error" before, so this is the transition into error.
                    try:
                        self.on_job_error(the_job, df.loc[i])
                    except Exception:
                        logger.warning(f"Failed to store error logs of job {job_id!r}", exc_info=True)

                df.loc[i, "status"] = status
                df.loc[i, "cpu"] = _format_usage_stat(job_metadata, "cpu")
                df.loc[i, "memory"] = _format_usage_stat(job_metadata, "memory")
                df.loc[i, "duration"] = _format_usage_stat(job_metadata, "duration")

    def start_job(row, connection_provider, connection, provider):
        """Create (but do not start) the AgERA5 meteo job for one tile/year.

        Returns None, so the job manager skips the row, when the result is already downloaded.
        The created job is started by ``CustomJobManager._update_statuses`` on a later poll.
        """
        job_options = {
            "driver-memory": "512M",
            "driver-memoryOverhead": "1G",
            "driver-cores": "1",
            "executor-memory": "512M",
            "executor-memoryOverhead": "512M",
            "executor-cores": "1",
            "max-executors": "32",
            "logging-threshold": "info",
            "mount_tmp": False,
            "goofys": "false",
            "node_caching": True,
            "task-cpus": 1,
        }
        name = str(row["name"])
        year = str(row["year"])
        title = f"METEO-{name}-{year}"

        if _result_path(year, title).exists():
            return None

        logger.info(f"submitting job to {provider}")

        extent = laea20km_id_to_extent(name)
        meteo_cube = add_meteo(
            connection=connection,
            METEO_collection="AGERA5",
            other_bands=None,
            bbox=extent,
            start=f"{year}-01-01",
            end=f"{year}-12-31",
        )

        return meteo_cube.create_job(
            title=title,
            description="create meteo ts for croptype maps",
            out_format="json",
            job_options=job_options,
        )

    manager = CustomJobManager()
    manager.add_backend("terrascope", connection=terrascope, parallel_jobs=parallel_jobs)

    manager.run_jobs(
        df=tiles_to_produce,
        start_job=start_job,
        output_file=Path(output_dir) / status_file,
    )




if __name__ == '__main__':
    def main(parallel_jobs=100,
             input_file="/vitodata/EEA_HRL_VLCC/data/production/jobsplit_laea20km_all_final.geojson",
             status_file="eu27_all_meteo.csv",
             output_dir="/vitodata/EEA_HRL_VLCC/data/ref/METEO"):
        """Run the EU27 AgERA5 preprocessing with the production settings as defaults."""
        run(parallel_jobs=parallel_jobs, input_file=input_file, status_file=status_file, output_dir=output_dir)

    # fire needs a callable (not the result of an already executed run) so that any of the
    # production defaults can be overridden from the command line, e.g. `--parallel_jobs=50`.
    fire.Fire(main)