from loguru import logger
from pathlib import Path
import geopandas as gpd
import pandas as pd
import xarray as xr
import multiprocessing as mp
import os


_BASE_ATTRIBUTES = [
    'CT',
    'validityTi',
    'CT_fin',
    'tile',
    'easting',
    'northing',
    'epsg',
    'zonenumber',
    'zoneletter',
    'round_lat',
    'round_lon',
    'bounds',
    'year',
    'start_date',
    'end_date',
    'ref_id',
    'source',
    'source_path',
    'geometry'
]

# Define the products that should be in the database
_PRODUCTS = ['S2_L2A_10m',
             'S2_L2A_20m',
             'AgERA5_DAILY_1000m',
             'S1_GRD-ASCENDING_10m',
             'S1_GRD-DESCENDING_10m',
             'OUTPUT']


def parametrized(dec):
    """Turn *dec* into a decorator factory that accepts arguments.

    *dec* must take the function to decorate as its first argument,
    followed by any extra positional/keyword arguments. The returned
    factory is called with those extra arguments and hands back the
    actual decorator.
    """

    def factory(*dec_args, **dec_kwargs):
        # The extra arguments are bound here; ``dec`` itself only runs
        # once the function to decorate is supplied.
        return lambda func: dec(func, *dec_args, **dec_kwargs)

    return factory


@parametrized
def sigsev_guard(fcn, default_value=None, timeout=None):
    """Used as a decorator with arguments.
    The decorated function will be called with its
    input arguments in another process, and its return
    value is sent back through a queue.

    If the execution lasts longer than *timeout* seconds,
    it will be considered failed and the child process
    is terminated.

    If the execution fails (the child process does not exit
    with code 0), *default_value* will be returned.

    NOTE: the process target is a lambda closure, which cannot be
    pickled; this therefore relies on a start method that does not
    pickle the target (i.e. 'fork').

    NOTE: the result is only read after the child has been joined,
    so decorated functions should return small values; a child that
    still has to flush a large payload to the queue may not exit
    before it is read.
    """

    def _fcn_wrapper(*args, **kwargs):
        q = mp.Queue()
        p = mp.Process(target=lambda q: q.put(fcn(*args, **kwargs)), args=(q,))
        p.start()
        p.join(timeout=timeout)

        if p.is_alive():
            # join() returned because the timeout expired. Without an
            # explicit terminate the child would keep running in the
            # background after we report the failure.
            logger.warning(('Process did not finish within '
                            f'{timeout} seconds, terminating it.'))
            p.terminate()
            p.join()

        exit_code = p.exitcode

        if exit_code == 0:
            return q.get()

        logger.warning(('Process did not exit correctly. '
                        f'Exit code: {exit_code}'))
        return default_value

    return _fcn_wrapper


@sigsev_guard(default_value=False, timeout=60)
def open_file_safely(file):
    """Check that *file* can be opened as a dataset.

    The open is attempted with the ``h5netcdf`` engine through the
    ``sigsev_guard`` decorator, i.e. in a separate process limited to
    60 seconds. Returns True when opening succeeded; the decorator
    yields False when the child process did not exit cleanly.
    """
    open_kwargs = dict(engine='h5netcdf', mask_and_scale=False)
    # Only the ability to open matters, the dataset itself is discarded.
    xr.open_dataset(file, **open_kwargs)
    return True


class DataBase(object):
    """Database of samples held in a GeoDataFrame indexed by ``sampleID``.

    The database is built from one or more sample files readable by
    ``geopandas.read_file``. It can then be enriched with the data
    location of each sample on disk (``add_root``) and with the paths
    of the products found there (``add_products``), and checked for
    problems (``check_issues``).
    """

    def __init__(self,
                 rootdir,
                 samplefiles: list,
                 sc=None):
        """
        :param rootdir: root directory under which the extracted data
            is looked up (see ``add_root``)
        :param samplefiles: list of sample files to read in ``populate``
        :param sc: optional object offering ``parallelize``; when given,
            product lookups and checks run through it instead of locally
            (presumably a SparkContext, given the log messages)
        """

        self.rootdir = rootdir
        self.source_files = samplefiles
        self._gdf = None
        self._sc = sc

    def _read_file(self, file: str):
        """Read one sample file into a GeoDataFrame.

        Raises ValueError when *file* does not exist.
        """
        if not Path(file).is_file():
            raise ValueError(f'File {file} does not exist!')

        logger.debug(f'Opening {file} ...')
        gdf = gpd.read_file(file)

        return gdf

    def append(self, dbfile):
        """Append the samples of a previously written database file.

        *dbfile* is read, indexed by ``sampleID`` and concatenated
        after the rows of the current database.
        """
        db_old = gpd.read_file(dbfile)
        db_old = db_old.set_index('sampleID')

        self._gdf = pd.concat([self.database, db_old], axis=0)

    def to_json(self, outfile):
        """Write the database to *outfile* as GeoJSON.

        Raises ValueError when the database has not been populated.
        Nothing is written (only a warning is logged) when the
        database holds no samples.
        """
        if self.database is None:
            raise ValueError('First populate the database!')
        if len(self.database) > 0:
            logger.info((f'Writing {self.database.shape[0]} samples '
                         f'to {outfile} ...'))
            self.database.to_file(outfile, driver='GeoJSON')
        else:
            logger.warning(
                'No valid samples remaining, nothing written to file!')

    def populate(self):
        """Read all source files and build the database.

        The samples of all files are stacked, indexed by ``sampleID``,
        de-duplicated on that index (first occurrence wins) and
        restricted to the ``_BASE_ATTRIBUTES`` columns.

        Raises ValueError when no source files were provided or when
        one of them does not exist.
        """

        logger.info('Populating database ...')
        if not self.source_files:
            raise ValueError('No sample files provided, '
                             'cannot populate the database!')

        # DataFrame.append was removed in pandas 2.0; concatenating all
        # frames at once is equivalent and avoids repeated copying.
        gdf = pd.concat([self._read_file(file)
                         for file in self.source_files])

        logger.info(f'Added a total of {len(gdf)} samples to database!')

        # Set the sampleID attribute as the index
        gdf = gdf.set_index('sampleID')

        # Make sure there's no duplicates
        nr_duplicates = gdf.index.duplicated().sum()
        if nr_duplicates > 0:
            logger.warning((f'Will drop {nr_duplicates} samples '
                            'with duplicate location_id!'))
            gdf = gdf[~gdf.index.duplicated()]

        # Only keep the required attributes
        gdf = gdf[_BASE_ATTRIBUTES]

        self._gdf = gdf

    def add_root(self):
        """
        Function to add "path" column to df
        referring to data location on disk:
        <rootdir>/extractions/data/<epsg>/<tile>/<sampleID>

        Raises ValueError when the database has not been populated.
        """
        if self.database is None:
            raise ValueError('First populate the database!')

        def _add_root(row):
            # row.name is the index value of the row, i.e. the sampleID
            rootpath = (Path(self.rootdir) / 'extractions' / 'data' /
                        str(row['epsg']) / row['tile'] / str(row.name))
            return str(rootpath)

        logger.info('Adding root paths ...')

        self.database['path'] = self.database.apply(_add_root,
                                                    axis=1)

    def add_products(self, products=_PRODUCTS, remove_errors=False):
        """Add one column per product holding the path of its file.

        For every sample the expected file of each product is derived
        from the sample's "path", id, epsg and start/end date. The file
        is accepted when it can be opened and, except for 'OUTPUT',
        when its first/last timestamp lie within 60 days of the
        sample's start/end date. Rejected or missing files get None.

        :param products: product names to look for
        :param remove_errors: when True, existing files that fail the
            checks are deleted from disk
        :raises ValueError: when the database has not been populated,
            or when ``add_root`` has not been run yet
        """
        if self.database is None:
            raise ValueError('First populate the database!')

        def _check_dates(ds, startdate, enddate, file):
            # The product may start at most 60 days after the expected
            # start and end at most 60 days before the expected end.
            logger.debug(f'Checking start and end of: {file}')
            start = pd.to_datetime(ds.timestamp.values[0])
            end = pd.to_datetime(ds.timestamp.values[-1])
            if ((start - pd.to_datetime(startdate))
                    > pd.Timedelta(days=60)):
                logger.error((f'"{file}" only starts at '
                              f'{start} while stack should start '
                              f' at {startdate}!'))
                return False
            elif ((pd.to_datetime(enddate) - end)
                  > pd.Timedelta(days=60)):
                logger.error((f'"{file}" ends at '
                              f'{end} while stack should end '
                              f' at {enddate}!'))
                return False
            else:
                logger.info('Date check OK!')
                return True

        def _find_products(row, products=_PRODUCTS,
                           remove_errors=False):

            def _check_file(file, pattern, remove_errors=False):
                # We check if file can be opened
                # and if we have expected time range
                # in the product
                logger.info(f'Attempt opening file: {file}')

                is_ok = False
                if open_file_safely(file):
                    # Context manager makes sure the file handle is
                    # released again (also before a possible removal).
                    with xr.open_dataset(file,
                                         engine='h5netcdf',
                                         mask_and_scale=False) as ds:
                        logger.info('Successfully openened ...')

                        if 'OUTPUT' not in pattern:
                            is_ok = _check_dates(
                                ds,
                                row['start_date'],
                                row['end_date'],
                                file)
                        else:
                            is_ok = True

                # Only existing files can be removed; a missing file
                # is simply reported as not OK.
                if (not is_ok and remove_errors
                        and Path(file).exists()):
                    logger.warning(f'Removing: {file}')
                    os.remove(file)
                return is_ok

            if 'path' not in row:
                raise ValueError('Rooth data path not found, '
                                 'first run the "add_root" functionality!')
            datapath = Path(row['path'])

            productpaths = {'sampleID': row.name}
            for pattern in products:
                if pattern != 'OUTPUT':
                    productpath = (datapath / '_'.join([
                        pattern, str(row.name), str(row.epsg),
                        row.start_date, row.end_date + '.nc'
                    ]))
                else:
                    # OUTPUT files carry the second-to-last
                    # "_"-separated part of ref_id and a fixed
                    # '10m' in their name
                    labeltype = row.ref_id.split('_')[-2]
                    productpath = (datapath / '_'.join([
                        pattern,
                        labeltype,
                        '10m', str(row.name), str(row.epsg),
                        row.start_date, row.end_date + '.nc'
                    ]))
                try:
                    is_ok = _check_file(productpath,
                                        pattern,
                                        remove_errors=remove_errors)
                    productpaths[pattern] = str(productpath) if is_ok else None

                except Exception as e:
                    # Fails when product cannot be opened or checked
                    logger.error(f'Got an error for file : {productpath}')
                    logger.error(e)
                    productpaths[pattern] = None

            return productpaths

        if self._sc is None:
            # Run locally
            logger.info('Looking for products locally...')
            found = list(self.database.apply(
                lambda row: _find_products(
                    row,
                    products=products,
                    remove_errors=remove_errors),
                axis=1).values)
        else:
            logger.info('Looking for products using spark ...')
            # One partition per sample
            found = self._sc.parallelize(
                self.database.iterrows(),
                len(self.database.index.tolist())).map(
                lambda row: _find_products(
                    row[1],
                    products=products,
                    remove_errors=remove_errors)).collect()

        productpaths = pd.DataFrame.from_dict(
            found).set_index('sampleID')
        self._gdf = pd.concat([self.database,
                               productpaths], axis=1)

    def check_s2_resolutionmatch(self):
        """Check the S2 10m/20m products of each sample.

        Raises ValueError when the database has not been populated and
        Exception when the S2 product columns are missing.

        NOTE(review): ``_is_matching`` currently only logs and returns
        None, so ``result`` holds no booleans and the final filtering
        step cannot select anything meaningful. The actual comparison
        appears to be unimplemented; ``check_issues`` contains a
        timestamp comparison of both S2 resolutions.
        """
        if self.database is None:
            raise ValueError('First populate the database!')
        if 'S2_L2A_10m' not in self.database.columns:
            raise Exception('"S2_L2A_10m" not in paths')
        if 'S2_L2A_20m' not in self.database.columns:
            raise Exception('"S2_L2A_20m" not in paths')

        def _is_matching(row):
            logger.debug(f'Checking sample {row.name} ...')

        logger.info('Checking S2 resolution match ...')
        result = self.database.apply(_is_matching, axis=1)
        result = result[result]
        return result

    def check_issues(self, products=_PRODUCTS):
        """Report the samples that have at least one failed check.

        Per sample, every product in *products* is checked for having
        a path in the database. When both S2 products are part of
        *products*, an additional ``s2_identicaltimestamps`` check
        verifies that the 10m and 20m files hold identical timestamps.

        :param products: product names to check
        :return: boolean DataFrame indexed by sampleID with one column
            per check (True = OK), restricted to the samples for which
            at least one check is False
        :raises ValueError: when the database has not been populated
        """
        if self.database is None:
            raise ValueError('First populate the database!')

        # The S2 timestamp comparison only makes sense when both
        # resolutions are among the products being checked.
        check_s2 = ('S2_L2A_10m' in products
                    and 'S2_L2A_20m' in products)

        def _check_sample(row):

            checks = dict()
            checks['sampleID'] = row.name
            logger.debug(f'Checking sample: {row.name}')

            # Check if a valid file was registered for each product
            for product in products:
                if product not in row:
                    logger.warning(f'"{product}" not in paths')
                    checks[product] = False
                elif pd.isnull(row[product]):
                    # None as well as NaN mark a missing path
                    logger.warning(f'"{product}" does not exist or is corrupt')
                    checks[product] = False
                else:
                    # File is OK
                    checks[product] = True

            if not check_s2:
                return checks

            # For S2: check if timestamps of both
            # resolutions are identical
            if (checks['S2_L2A_10m']
                    and checks['S2_L2A_20m']):

                with xr.open_dataset(row['S2_L2A_10m'],
                                     engine='h5netcdf',
                                     mask_and_scale=False) as ds_10m, \
                        xr.open_dataset(row['S2_L2A_20m'],
                                        engine='h5netcdf',
                                        mask_and_scale=False) as ds_20m:
                    # Compare the raw values: comparing the DataArrays
                    # directly is subject to xarray's label alignment,
                    # which can hide differing timestamps.
                    ts_10m = ds_10m.timestamp.values
                    ts_20m = ds_20m.timestamp.values

                if (ts_10m.shape == ts_20m.shape
                        and bool((ts_10m == ts_20m).all())):
                    checks['s2_identicaltimestamps'] = True
                else:
                    logger.warning((f'Sample {row.name} has '
                                    ' a different number of '
                                    'timestamps!'))
                    checks['s2_identicaltimestamps'] = False

            else:
                checks['s2_identicaltimestamps'] = False

            return checks

        if self._sc is None:
            logger.info('Running checks locally ...')
            results = [_check_sample(row)
                       for _, row in self.database.iterrows()]
        else:
            logger.info('Running checks on spark ...')

            # One partition per sample
            results = self._sc.parallelize(
                self.database.iterrows(),
                len(self.database.index.tolist())).map(
                lambda row: _check_sample(row[1])).collect()

        # Same construction for both branches, with guaranteed boolean
        # columns so that the inversion below is a logical "not".
        checks = pd.DataFrame(results).set_index('sampleID').astype(bool)

        # only retain samples with issues
        issuesinv = ~checks
        issues = checks[issuesinv.sum(axis=1) > 0]
        logger.info(f'Found issues with {len(issues)} files!')

        return issues

    def drop(self, index: list):
        '''
        Method to drop certain samples from database
        based on list of indexes
        '''
        logger.info('Dropping indexes from database.')
        self._gdf = self.database.drop(index=index)

    @property
    def database(self):
        """The underlying GeoDataFrame, or None (with a warning being
        logged) as long as the database has not been populated."""
        if self._gdf is None:
            logger.warning('Database not yet populated!')
            return None
        else:
            return self._gdf