from typing import List

from loguru import logger
from pathlib import Path
import copy
import json

from satio.grid import get_blocks_gdf
from satio import layers
from satio.collections import (AgERA5Collection)
from satio.utils.logs import (exitlogs, proclogs,
                              memlogs, ExitLogsMonitor)

from worldcereal.seasons import get_processing_dates, NoSeasonError
from worldcereal.utils.aez import get_matching_aez_id, group_from_id
from worldcereal.utils import get_best_model, needs_s3_loader
from worldcereal.fp import (TSL2AFeaturesProcessor,
                            TSSIGMA0FeaturesProcessor,
                            TSSIGMA0TiledFeaturesProcessor,
                            WorldCerealAgERA5FeaturesProcessor)
from cropclass.processors import CropClassProcessor
from satio.collections import (TerrascopeV200Collection,
                               TerrascopeSigma0Collection,
                               create_terrascope_products_df)
from worldcereal.collections import TerrascopeSigma0TiledCollection
from worldcereal.collections import WorldCerealDEMCollection
from worldcereal.processors import CropTypeProcessor
from worldcereal.worldcereal_products import get_processing_blocks
from cropclass.features.settings_old import (get_croptype_catboost_parameters,
                                             get_croptype_tsteps_parameters)
from worldcereal.postprocess import PostProcessor
from worldcereal.collections import (get_S2_products,
                                     create_products_df)
from worldcereal.collections import get_S1_products as get_S1_products_TS

# Sentinel-2 tiling grid, loaded once at import time via satio.
# Rows are selected on the 'tile' column to look up a tile's `epsg` code
# and `geometry` (used for the S1 bounding-box query and AEZ matching).
# NOTE(review): presumably a GeoDataFrame -- confirm against satio.layers.
S2_GRID = layers.load('s2grid')


def get_S1_products(tile, basedir='/data/worldcereal/data/s1_sigma0_tiled'):
    """List the tiled Sentinel-1 sigma0 products available for a tile.

    Every entry found in ``<basedir>/<tile>`` is returned.

    Args:
        tile (str): tile ID; used as the sub-folder name under ``basedir``.
        basedir (str or Path, optional): root folder holding one
            sub-folder per tile. Defaults to the historical hard-coded
            location, so existing callers are unaffected.

    Returns:
        list of str: paths of the entries in the tile folder, sorted so the
        result does not depend on the filesystem's listing order.

    Raises:
        FileNotFoundError: if the tile folder does not exist.
    """
    tile_dir = Path(basedir) / tile
    # iterdir() order is arbitrary; sort for reproducible results
    return sorted(str(entry) for entry in tile_dir.iterdir())


def get_collections(tile, start_date, end_date, epsg):
    """Create the input collections and feature processors for one tile.

    Args:
        tile (str): MGRS tile ID, e.g. '31UFS'. Must be present in
            ``S2_GRID``.
        start_date: start of the period passed to the S1/S2 product
            queries.
        end_date: end of the period passed to the S1/S2 product queries.
        epsg: accepted for interface compatibility but NOT used. The EPSG
            code is always taken from the ``S2_GRID`` entry of ``tile``.

    Returns:
        tuple: ``(collections, fps)``. ``collections`` maps 'OPTICAL',
        'SAR', 'METEO' and 'DEM' to collection instances. ``fps`` maps
        'OPTICAL', 'SAR' and 'METEO' to feature-processor classes.

    Raises:
        ValueError: if ``tile`` is not found in ``S2_GRID``.
    """
    logger.info('Creating collections ...')

    # ----------------------------
    # Look up the tile in the S2 grid (once)
    # ----------------------------
    tile_rows = S2_GRID[S2_GRID['tile'] == tile]
    if tile_rows.empty:
        raise ValueError(f'Tile `{tile}` not found in the S2 grid.')
    # The grid's EPSG code takes precedence over the `epsg` argument,
    # exactly as in the previous implementation.
    epsg = tile_rows.epsg.values[0]
    tile_bounds = tile_rows.geometry.values[0].bounds

    # ----------------------------
    # Look for S1 and S2 products
    # ----------------------------
    s2products = get_S2_products(tile, start_date, end_date)
    s2_products_df = create_products_df(s2products, tile, epsg)

    # S1 products are queried from Terrascope over the tile bounds
    # (ascending orbit only). get_S1_products() offers a local-folder
    # alternative.
    bbox = ','.join(str(coord) for coord in tile_bounds)
    s1products = get_S1_products_TS(bbox, 'ASCENDING', start_date, end_date)
    s1_products_df = create_terrascope_products_df(
        s1products, tile, epsg, sensor='s1')

    # ----------------------------
    # Create the collections
    # ----------------------------
    S2coll = TerrascopeV200Collection(s2_products_df)
    S1coll = TerrascopeSigma0Collection(s1_products_df)
    demcoll = WorldCerealDEMCollection(folder='/data/MEP/DEM/'
                                       'COP-DEM_GLO-30_DTED/S2grid_20m')
    AgERA5coll = AgERA5Collection.from_path('/data/MTDA/AgERA5')
    collections = {
        'OPTICAL': S2coll,
        'SAR': S1coll,
        'METEO': AgERA5coll,
        'DEM': demcoll
    }

    # Feature processor class per sensor; DEM has no dedicated processor
    fps = {
        'OPTICAL': TSL2AFeaturesProcessor,
        'SAR': TSSIGMA0FeaturesProcessor,
        'METEO': WorldCerealAgERA5FeaturesProcessor,
    }

    return collections, fps


def _run_block(output_folder, processing_tuple):
    """Process a single block with the crop-type classification chain.

    Args:
        output_folder: output location handed to ``CropClassProcessor``.
        processing_tuple (tuple): ``(block_id, block)``. ``block`` must
            provide ``parameters`` (a dict of processing settings),
            ``tile``, ``bounds``, ``epsg`` and ``block_id``.
    """
    block_id, block = processing_tuple
    params = block.parameters

    # Mandatory settings -- a missing key raises KeyError here
    aez_id = params['aez']
    season = params['season']
    class_model = params['class_model']
    start_date = params['start_date']
    end_date = params['end_date']
    featuresettings = params['featuresettings']
    filtersettings = params['filtersettings']

    # Optional settings and their defaults
    mask_model = params.get('mask_model')
    use_local_models = params.get('localmodels', True)
    segment = params.get('segment', False)
    segment_feat = params.get('segment_feat')
    save_confidence = params.get('save_confidence', True)
    save_meta = params.get('save_meta', True)
    save_features = params.get('save_features', False)
    class_encoder = params.get('class_encoder')
    mask_encoder = params.get('mask_encoder')

    rule = '-' * 50
    logger.info(rule)
    logger.info(f'Starting processing block: {block_id}')
    logger.info(rule)
    logger.info('PARAMETERS:')
    for name, value in params.items():
        logger.info(f'{name}: {value}')

    collections, fps = get_collections(block.tile, start_date, end_date,
                                       block.epsg)

    # Either keep the configured (parent) models or swap in the best
    # AEZ-specific ones
    if not use_local_models:
        logger.info('Using parent model for the class/mask models!')
    else:
        class_model = get_best_model(class_model, aez_id)
        if mask_model is not None:
            mask_model = get_best_model(mask_model, aez_id)

    chain_collections = {key: collections[key]
                         for key in ('OPTICAL', 'SAR', 'METEO', 'DEM')}

    # NOTE(review): str.strip('-') only trims leading/trailing dashes, so a
    # 'YYYY-MM-DD' date passes through unchanged. If CropClassProcessor
    # expects 'YYYYMMDD', this should be .replace('-', '') -- confirm.
    chain = CropClassProcessor(
        output_folder,
        class_model=class_model,
        mask_model=mask_model,
        class_encoder=class_encoder,
        mask_encoder=mask_encoder,
        season=season,
        aez=aez_id,
        collections=chain_collections,
        settings=featuresettings['settings'],
        features_meta=featuresettings['features_meta'],
        rsi_meta=featuresettings['rsi_meta'],
        ignore_def_feat=featuresettings['ignore_def_feat'],
        gdd_normalization=featuresettings['gddnormalization'],
        fps=fps,
        start_date=start_date.strip('-'),
        end_date=end_date.strip('-'),
        save_features=save_features,
        save_meta=save_meta,
        avg_segm=segment,
        segm_feat=segment_feat,
        save_confidence=save_confidence,
        filtersettings=filtersettings)

    chain.process(block.tile, block.bounds, block.epsg, block.block_id)


def run_tile(tile: str,
             configfile: str,
             outputfolder: str,
             blocks: List = None,
             skip_processed: bool = True,
             debug: bool = False,
             process: bool = True,
             postprocess: bool = True,
             sparkcontext=None):
    """Generates WorldCereal products.

    Args:
        tile (str): MGRS tile ID to process. Example: '31UFS'
        configfile (str): path to config.json containing processing settings
        outputfolder (str): path to use for saving products and logs
        blocks (List, optional): Block ids of the blocks to process
                from the given tile. Should be a sequence of integers
                between 0 and 120 included. If not provided, all blocks
                will be processed.
        skip_processed (bool, optional): Skip already processed blocks
                by checking the exitlogs folder. Defaults to True.
        debug (bool, optional): Run in debug mode, processing only
                one part of one block. Defaults to False.
        process (bool, optional): If False, skip block processing.
                Defaults to True.
        postprocess (bool, optional): If False, skip post-processing to
                COG. Defaults to True.
        sparkcontext (optional): Optional sparkcontext to parallelize
                block processing using spark.

    Returns:
        None. Returns early (after logging a warning) when no dates are
        configured and no valid season can be found for the tile.

    Raises:
        FileNotFoundError: if ``configfile`` does not exist.
        ValueError: if the config parameters define only one of
            ``start_date`` and ``end_date``.
    """

    # Load the config file; the context manager guarantees the handle
    # is closed
    configpath = Path(configfile)
    if not configpath.is_file():
        raise FileNotFoundError(
            'Required config file '
            f'`{configfile}` not found. Cannot continue.')
    with configpath.open('r') as fh:
        config = json.load(fh)

    # Normalise once so processing and post-processing get the same type
    outputfolder = Path(outputfolder)

    # Get processing parameters
    parameters = config['parameters']

    # Get year and season to process
    year = parameters['year']
    season = parameters['season']
    start_date = parameters.get('start_date', None)
    end_date = parameters.get('end_date', None)

    # A half-specified period would otherwise only fail later, inside
    # every single block (missing or None date). Fail fast instead.
    if (start_date is None) != (end_date is None):
        raise ValueError(
            'Config parameters must define both `start_date` and '
            '`end_date`, or neither (to derive them from the season); '
            f'got start_date={start_date!r}, end_date={end_date!r}.')

    # Get the right feature settings
    parameters['featuresettings'] = get_croptype_tsteps_parameters()

    # Add models
    parameters['class_model'] = config['models']['class_model']
    parameters['mask_model'] = config['models']['mask_model']

    # Determine AEZ ID and add as parameter
    aez_id = get_matching_aez_id(
        S2_GRID[S2_GRID.tile == tile].geometry.values[0])
    parameters['aez'] = aez_id

    # Derive the processing dates from the season when not given
    if start_date is None and end_date is None:
        try:
            start_date, end_date = get_processing_dates(season, aez_id, year)
        except NoSeasonError:
            logger.warning(f'No valid `{season}` season found for this tile.')
            return
        parameters['start_date'] = start_date
        parameters['end_date'] = end_date

    if process:

        # ----------------------------
        # Get processing blocks
        # ----------------------------
        blocks = get_processing_blocks([tile], parameters, debug,
                                       blocks=blocks)

        # ----------------------------
        # Create processing tuples
        # ----------------------------
        logger.info('Getting processing tuples ...')
        processing_tuples = [(f'{row.tile}_{row.block_id:03d}', row)
                             for row in blocks.itertuples()]

        # Setup logging folders
        exitlogs_folder = outputfolder / 'exitlogs' / f'{year}_{season}'
        proclogs_folder = outputfolder / 'proclogs' / f'{year}_{season}'

        @proclogs(proclogs_folder, level='DEBUG')
        @exitlogs(exitlogs_folder, skip_processed=skip_processed)
        # For memory profiling, add:
        # @memlogs(outputfolder / 'memlogs' / f'{year}_{season}')
        def _log_run_block(processing_tuple):
            _run_block(outputfolder, processing_tuple)

        exitlogs_monitor = ExitLogsMonitor(exitlogs_folder,
                                           interval=300)
        exitlogs_monitor.start()
        try:
            if sparkcontext is not None:
                logger.info('Starting spark parallellization ...')
                # At least one slice: Spark cannot split into 0 partitions
                sparkcontext.parallelize(
                    processing_tuples,
                    max(1, len(processing_tuples))).foreach(_log_run_block)

            else:
                logger.info('Running in serial ...')
                # TODO: allow processing in a few threads
                for tup in processing_tuples:
                    _log_run_block(tup)

            exitlogs_monitor.task()
        finally:
            # Stop the monitor even if processing raises, so it is not
            # left running after this call
            exitlogs_monitor.stop()

    # ----------------------------
    # Do post-processing
    # ----------------------------

    if postprocess:

        logger.info('Start post-processing to COGs...')
        cog_folder = outputfolder / 'cogs'

        # Cropland/Croptype product(s)
        product = 'croptype'
        logger.info(f'Working on product: {product}')

        postprocessor = PostProcessor(outputfolder, cog_folder,
                                      product, tile, parameters)

        postprocessor.run(improve_consistency=False,
                          generate_metadata=False,
                          skip_processed=skip_processed,
                          in_memory=False,)

    logger.success('Finished!')


if __name__ == '__main__':

    # ------------------------------------------------------------
    # Manual debug run on a single block of one tile
    # ------------------------------------------------------------
    tile = '31UFS'
    blocks = [1]
    # Alternative: the bundled resources/configs/config.json next to
    # this file.
    configfile = ('/vitodata/CropSAR/cropmap/NEXTLAND/'
                  'worldcereal/runs/config.json')
    outputfolder = '/vitodata/CropSAR/cropmap/NEXTLAND/worldcereal/runs/test'

    debug = True
    spark = False
    skip_processed = True

    # Spark context only when parallel processing is requested
    sc = None
    if spark:
        from worldcereal.utils.spark import get_spark_context
        sc = get_spark_context()

    run_tile(tile, configfile, outputfolder,
             blocks=blocks,
             skip_processed=skip_processed,
             debug=debug,
             sparkcontext=sc)
