# Script to preprocess AgERA5 meteo data with openEO (one batch job per LAEA 20 km tile) and download the results as JSON.

from openeo.extra.job_management import MultiBackendJobManager, _format_usage_stat
from cropclass.utils import laea20km_id_to_extent
from openeo_classification.connection import terrascope
from cropclass.openeo.preprocessing import add_meteo
from pathlib import Path
import os
import time
import fire
import json
import logging
import geopandas as gpd
import pandas as pd

# Configure root logging at INFO level and log through the root logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(None)



def run(year=2021, parallel_jobs=20, input_file="croptype2021_terrascope.geojson",
        status_file="eu27_2021_meteo.csv", output_dir="."):
    """
    Start and monitor jobs for the EU27 croptype map project in openEO platform CCN,
    specifically for preprocessing AgERA5 meteo data.

    One job is created per feature of ``input_file``. The script can use multiple backends
    to maximize throughput. Jobs are tracked in a CSV file; after a failure the script can
    resume processing by pointing to the same CSV file. Delete that file to start from scratch.

    @param year: The year for which meteo data is processed (``<year>-01-01`` to ``<year>-12-31``).
    @param parallel_jobs: Number of parallel jobs allowed on the backend (passed to ``add_backend``).
    @param input_file: GeoJSON file with one feature per tile. Each feature needs a ``name``
        property (20 km LAEA tile id) and a ``meteo_2021`` property (job title template;
        "2021" is replaced by ``year``).
    @param status_file: Name of the CSV file, inside ``output_dir``, in which job status is tracked.
    @param output_dir: Base directory. Results and error logs go to ``<output_dir>/<year>``.
    @return: None
    """
    # Per-year sub directory for downloaded results and error logs. Create it up front
    # (this also creates output_dir, which holds the status CSV).
    output = os.path.join(output_dir, str(year))
    Path(output).mkdir(parents=True, exist_ok=True)

    with Path(input_file).open('r') as f:
        tiles_to_produce = gpd.GeoDataFrame.from_features(json.load(f))

    # TODO: prioritise a selection of Sentinel-2 tiles. The snippet below does not seem to
    # work inline, but it works when run separately:
    #   S2_list = ['30SVH','31UFS','32TPT','37TBF','34VFN','33VWG','35VME','29SNC','29TPG','30TUM','30SVH','31TCJ','31UDP','33UYP','33TWN','33TWM','33TXL','33TWJ','35WMM']
    #   regex = '|'.join(S2_list)
    #   tiles_to_produce2 = pd.DataFrame(tiles_to_produce).sort_values(
    #       by='mgrs_ids', key=lambda col: col.str.contains(regex) * -1)

    logger.info(f"Found {len(tiles_to_produce)} tiles to process. Year: {year}")

    class CustomJobManager(MultiBackendJobManager):
        """Job manager that starts created jobs, downloads finished results into the
        per-year output directory and stores the error logs of failed jobs."""

        def __init__(self, poll_sleep=1):
            super().__init__(poll_sleep)
            # NOTE(review): `self.required_with_default.append(("error_reason", ""))` did not
            # work; presumably `required_with_default` is not an instance attribute in the
            # installed openeo version - confirm before re-enabling.

        def on_job_error(self, job, row):
            """Write the error-level log entries of ``job`` to ``<output>/job_<title>_errors.json``.

            ``row`` is the job's row in the status DataFrame (not used here).
            """
            error_logs = [entry for entry in job.logs() if entry.level.lower() == "error"]
            if not error_logs:
                return
            title = job.describe_job()['title']
            (Path(output) / f'job_{title}_errors.json').write_text(json.dumps(error_logs, indent=2))

        def _update_statuses(self, df: pd.DataFrame):
            """Update status (and usage stats) of active jobs in ``df`` (in place).

            Jobs still in "created" state are (re)started, finished jobs are downloaded
            to ``<output>/<title>`` and failed jobs get their error logs stored.
            """
            active = df.loc[df.status.isin(["created", "queued", "running"])]
            for i in active.index:
                job_id = df.loc[i, 'id']
                backend_name = df.loc[i, "backend_name"]
                con = self.backends[backend_name].get_connection()
                the_job = con.job(job_id)
                try:
                    job_metadata = the_job.describe_job()
                except Exception:
                    # Back off and retry this job on the next poll.
                    logger.warning(f"Failed to get status of job {job_id!r} (on backend {backend_name}), "
                                   f"will try again after poll", exc_info=True)
                    time.sleep(5)
                    continue

                status = job_metadata["status"]
                logger.info(f"Status of job {job_id!r} (on backend {backend_name}) is {status!r}")

                if status == "created":
                    try:
                        the_job.start_job()
                    except Exception:
                        logger.warning("Failed to start the job, will try again after poll", exc_info=True)
                elif status == "finished":
                    the_job.download_result(os.path.join(output, job_metadata['title']))
                elif status == "error":
                    # Best effort: failing to fetch logs must not stop monitoring of other jobs.
                    try:
                        self.on_job_error(the_job, df.loc[i])
                    except Exception:
                        logger.exception(f"Failed to store error logs of job {job_id!r}")

                df.loc[i, "status"] = status
                df.loc[i, "cpu"] = _format_usage_stat(job_metadata, "cpu")
                df.loc[i, "memory"] = _format_usage_stat(job_metadata, "memory")
                df.loc[i, "duration"] = _format_usage_stat(job_metadata, "duration")

    def start_job(row, connection_provider, connection, provider):
        """Create and start the AgERA5 meteo job for one tile ``row`` on ``connection``.

        The parameter names follow the ``start_job`` callback of ``MultiBackendJobManager.run_jobs``;
        ``connection_provider`` is not used.
        """
        job_options = {
            "driver-memory": "512M",
            "driver-memoryOverhead": "1G",
            "driver-cores": "1",
            "executor-memory": "512M",
            "executor-memoryOverhead": "512M",
            "executor-cores": "1",
            "max-executors": "32",
            "logging-threshold": "info",
            "mount_tmp": False,
            "goofys": "false",
            "node_caching": True,
            "task-cpus": 1,
        }

        extent_20km = laea20km_id_to_extent(row['name'])

        logger.info(f"submitting job to {provider}")

        meteo_cube = add_meteo(connection=connection, METEO_collection='AGERA5', other_bands=None,
                               bbox=extent_20km, start=f"{year}-01-01", end=f"{year}-12-31")

        job = meteo_cube.create_job(
            title=row["meteo_2021"].replace("2021", str(year)),
            description="create meteo ts for croptype maps",
            out_format="json",
            job_options=job_options,
        )

        try:
            job.start_job()
        except Exception:
            # The job stays in "created" state; _update_statuses retries starting it.
            logger.warning("Failed to start the job, will try again after poll", exc_info=True)
            time.sleep(10)

        return job

    manager = CustomJobManager()
    manager.add_backend("terrascope", connection=terrascope, parallel_jobs=parallel_jobs)

    manager.run_jobs(
        df=tiles_to_produce,
        start_job=start_job,
        output_file=Path(output_dir) / status_file,
    )




def _main(
    year=2018,
    parallel_jobs=20,
    input_file="/home/deroob/git/cropclass/src/cropclass/production/jobsplit_laea20km_2021_all_final.geojson",
    status_file=None,
    output_dir="/vitodata/EEA_HRL_VLCC/data/ref/METEO",
):
    """Command-line entry point: run the meteo preprocessing with the production defaults.

    Every argument can be overridden on the command line (e.g. ``--year=2019``).
    When ``status_file`` is not given it defaults to ``eu27_<year>_meteo.csv``.
    """
    if status_file is None:
        status_file = f"eu27_{year}_meteo.csv"
    run(year=year, parallel_jobs=parallel_jobs, input_file=input_file,
        status_file=status_file, output_dir=output_dir)


if __name__ == '__main__':
    # Hand the callable (not its result) to Fire so command-line flags are parsed and forwarded.
    fire.Fire(_main)