import itertools

from loguru import logger

from cropclass.products import run_tile
from worldcereal.utils.spark import get_spark_context


def process_blocks(tasks, config, outputfolder, sc,
                   skip_processed=True):
    """Run the block-processing step of ``run_tile`` on Spark executors.

    Every task is placed in its own Spark partition and handed to
    ``run_tile`` with ``postprocess=False`` and ``debug=False``.

    Args:
        tasks: iterable of ``(tile, block)`` pairs. ``tile`` is passed as the
            first positional argument of ``run_tile`` and ``block`` as its
            ``blocks=`` keyword.
        config: run configuration forwarded to ``run_tile`` (the script in
            this file passes the path of a config JSON file).
        outputfolder: base output folder forwarded to ``run_tile``.
        sc: active SparkContext used to distribute the tasks.
        skip_processed: forwarded unchanged to ``run_tile``.
    """
    # Materialize once so generators (e.g. itertools.product) are accepted
    # and len() is available for the partition count.
    tasks = list(tasks)
    if not tasks:
        # sc.parallelize(..., numSlices=0) fails inside PySpark (it divides
        # by the slice count), so stop here with a clear message instead.
        logger.warning('No tasks to process, skipping block processing.')
        return

    logger.info(f'Launching {len(tasks)} tasks on executors ...')
    # One slice per task so each (tile, block) pair runs as its own
    # Spark task on an executor.
    sc.parallelize(tasks, len(tasks)).foreach(
        lambda x: run_tile(
            x[0],
            config,
            outputfolder=outputfolder,
            blocks=x[1],
            skip_processed=skip_processed,
            debug=False,
            postprocess=False
        )
    )


def postprocess_tiles(tiles, config, outputfolder, sc,
                      skip_processed=True):
    """Run the postprocessing step of ``run_tile`` on Spark executors.

    Every tile is placed in its own Spark partition and handed to
    ``run_tile`` with ``process=False``, ``postprocess=True`` and
    ``debug=False``.

    Args:
        tiles: iterable of tile identifiers, each passed as the first
            positional argument of ``run_tile``.
        config: run configuration forwarded to ``run_tile`` (the script in
            this file passes the path of a config JSON file).
        outputfolder: base output folder forwarded to ``run_tile``.
        sc: active SparkContext used to distribute the tiles.
        skip_processed: forwarded unchanged to ``run_tile``.
    """
    # Materialize once so any iterable is accepted and len() is available
    # for the partition count.
    tiles = list(tiles)
    if not tiles:
        # sc.parallelize(..., numSlices=0) fails inside PySpark (it divides
        # by the slice count), so stop here with a clear message instead.
        logger.warning('No tiles to postprocess, skipping postprocessing.')
        return

    logger.info('Starting postprocessing on executors ...')
    # One slice per tile so each tile is postprocessed as its own Spark task.
    sc.parallelize(tiles, len(tiles)).foreach(
        lambda x: run_tile(
            x,
            config,
            outputfolder=outputfolder,
            skip_processed=skip_processed,
            debug=False,
            process=False,
            postprocess=True
        )
    )


if __name__ == '__main__':

    # Select which run mode to execute:
    #   True  -> run one entire tile via a single run_tile call that is
    #            given the SparkContext.
    #   False -> distribute (tile, block) tasks over the executors with
    #            process_blocks, then postprocess every tile with
    #            postprocess_tiles.
    RUN_SINGLE_TILE = True

    # Initialize spark
    sc = get_spark_context()

    # Path to the run CONFIG.
    # Make sure the config file is NOT changed during the entire time
    # that the job is running!!!
    run_config = ('/vitodata/CropSAR/cropmap/NEXTLAND/'
                  'worldcereal/runs/config.json')

    if RUN_SINGLE_TILE:
        # ----------------------------------------------------
        # VERSION TO RUN ONE ENTIRE TILE ON SPARK:
        # ----------------------------------------------------
        tile = '31UFS'
        # Passed as blocks=; presumably None selects every block of the
        # tile -- confirm against run_tile.
        block = None

        outputfolder = ('/vitodata/CropSAR/cropmap/'
                        'NEXTLAND/worldcereal/runs/'
                        'CroptypeBelgiumRNNsummer1_V050_TScoll')

        run_tile(tile,
                 run_config,
                 outputfolder=outputfolder,
                 blocks=block,
                 skip_processed=False,
                 debug=False,
                 sparkcontext=sc
                 )
    else:
        # ----------------------------------------------------
        # VERSION TO RUN MULTIPLE TILES SIMULTANEOUSLY ON SPARK:
        # ----------------------------------------------------

        # Base output folder
        outputfolder = ('/vitodata/CropSAR/cropmap/'
                        'NEXTLAND/worldcereal/runs/'
                        'CroptypeBelgiumRNNsummer1_V050')

        skip_processed = True

        # List of tiles to process
        tiles = ['31UFS']

        # List of all blocks in a tile
        blocks = list(range(121))

        # The product of both makes up the total task list
        tasks = list(itertools.product(tiles, blocks))

        logger.info('DO BLOCK PROCESSING')
        process_blocks(tasks, run_config, outputfolder,
                       sc, skip_processed=skip_processed)

        logger.info('DO TILE POSTPROCESSING')
        postprocess_tiles(tiles, run_config, outputfolder,
                          sc, skip_processed=skip_processed)

        logger.success('ALL DONE!')
