import configparser
import os
from pathlib import Path
import rasterio
import numpy as np
from loguru import logger

from satio.utils.retry import retry


def _int_from_env(var_name, fallback):
    """Return environment variable *var_name* as an int, else *fallback*."""
    return int(os.getenv(var_name, fallback))


# Runtime tuning knobs, overridable through SATIO_* environment variables.
NR_THREADS = _int_from_env('SATIO_MAXTHREADS', 5)
# Retry policy applied to each single-file warp (see GDALWarpLoader._warp).
RETRIES = _int_from_env('SATIO_RETRIES', 50)
DELAY = _int_from_env('SATIO_DELAY', 5)
BACKOFF = _int_from_env('SATIO_BACKOFF', 1)
# Seconds before a gdalwarp subprocess is abandoned.
TIMEOUT = _int_from_env('SATIO_TIMEOUT', 30)


class GDALWarpLoader:
    """Load yearly raster files by reprojecting them with ``gdalwarp``.

    Every source file of a collection band is warped onto the collection's
    bounds and EPSG, read back with rasterio, and stacked along the first
    axis. The stack is then restricted to the collection's date range.
    """

    def __init__(self):

        # GDAL configuration options exported into the environment of the
        # gdalwarp subprocess.
        self._gdal_options = {
            'GDAL_DISABLE_READDIR_ON_OPEN': 'FALSE',
            'CPL_VSIL_CURL_ALLOWED_EXTENSIONS': '.tif',
            'VSI_CACHE': 'FALSE'
        }

    @retry(exceptions=Exception, tries=RETRIES, delay=DELAY,
           backoff=BACKOFF, logger=logger)
    def _warp(self, fn, bounds, epsg, resolution,
              resampling='cubic'):
        """Warp a single file to a temporary GeoTIFF and return its pixels.

        Any exception triggers a retry, governed by the module-level
        RETRIES / DELAY / BACKOFF settings.

        Parameters
        ----------
        fn : str or path-like
            Source raster.
        bounds : sequence
            Target extent, forwarded to ``gdal_warp``.
        epsg : int or str
            Target EPSG code.
        resolution : float
            Target pixel size, in units of the target SRS.
        resampling : str
            gdalwarp resampling method.

        Returns
        -------
        numpy.ndarray
            All bands of the warped raster, as returned by rasterio's
            ``read()``.
        """
        import tempfile

        logger.debug(f'Start loading of: {fn}')

        # Write into a fresh directory instead of a pre-created empty file.
        # gdalwarp then always creates the output from scratch rather than
        # trying to open an existing destination, and the path can be
        # reopened on every platform (unlike an open NamedTemporaryFile on
        # Windows).
        with tempfile.TemporaryDirectory() as tmp_dir:
            dst_fname = str(Path(tmp_dir) / 'warped.tif')
            self.gdal_warp([fn],
                           dst_fname,
                           bounds,
                           dst_epsg=epsg,
                           resolution=resolution,
                           resampling=resampling)
            with rasterio.open(dst_fname) as src:
                arr = src.read()

        logger.debug(f'Loading of {fn} completed.')

        return arr

    @staticmethod
    def _find_gdalwarp():
        """Return the path of a usable ``gdalwarp`` executable.

        The lookup order is:

        1. ``/bin/gdalwarp`` (system GDAL);
        2. ``gdalwarp`` in the same directory as the running Python
           interpreter (e.g. a conda environment);
        3. ``gdalwarp`` found on ``PATH``.

        Raises
        ------
        FileNotFoundError
            If no executable is found.
        """
        import shutil
        import sys

        candidates = (Path('/bin/gdalwarp'),
                      Path(sys.executable).with_name('gdalwarp'))
        for candidate in candidates:
            if candidate.is_file():
                return str(candidate)

        on_path = shutil.which('gdalwarp')
        if on_path is not None:
            return on_path

        raise FileNotFoundError('Could not find a GDAL installation.')

    def gdal_warp(self,
                  src_fnames,
                  dst_fname,
                  bounds,
                  dst_epsg,
                  resolution=100,
                  center_long=0,
                  gdal_cachemax=2000,
                  resampling='cubic'):
        """Run ``gdalwarp`` to reproject/mosaic sources into a GeoTIFF.

        Parameters
        ----------
        src_fnames : str, path-like or sequence of them
            Source raster(s).
        dst_fname : str or path-like
            Output GeoTIFF path. The output is DEFLATE-compressed.
        bounds : sequence
            Target extent passed to ``-te`` (xmin ymin xmax ymax in the
            target SRS).
        dst_epsg : int or str
            Target EPSG code, used as ``-t_srs EPSG:<dst_epsg>``.
        resolution : float
            Target pixel size, used for both axes (``-tr``).
        center_long : float
            Value of the CENTER_LONG GDAL config option.
        gdal_cachemax : int
            Value of the GDAL_CACHEMAX GDAL config option.
        resampling : str
            Resampling method passed to ``-r``.

        Raises
        ------
        FileNotFoundError
            If no gdalwarp executable is found.
        IOError
            If gdalwarp exits with a non-zero status.
        subprocess.TimeoutExpired
            If gdalwarp runs longer than ``TIMEOUT`` seconds.
        """
        import subprocess

        gdalwarp_bin = self._find_gdalwarp()

        if isinstance(src_fnames, (str, Path)):
            src_fnames = [src_fnames]
        fns = [str(fn) for fn in src_fnames]

        # Build an argv list instead of a shell string. Paths containing
        # spaces or shell metacharacters are then passed through verbatim
        # rather than being split or interpreted by a shell.
        cmd = [
            gdalwarp_bin, '-of', 'GTiff',
            '-t_srs', f'EPSG:{dst_epsg}',
            '-te', *map(str, bounds),
            '-tr', str(resolution), str(resolution), '-multi',
            '-r', str(resampling),
            '--config', 'CENTER_LONG', str(center_long),
            '--config', 'GDAL_CACHEMAX', str(gdal_cachemax),
            '-co', 'COMPRESS=DEFLATE',
            *fns,
            str(dst_fname),
        ]

        # Equivalent of prefixing the shell command with KEY=VALUE pairs.
        env = dict(os.environ)
        env.update({k: str(v) for k, v in self._gdal_options.items()})

        p = subprocess.run(cmd, env=env, timeout=TIMEOUT)
        if p.returncode != 0:
            raise IOError("GDAL warping failed")

        fns_str = " ".join(fns)
        logger.debug(f'Warping of {fns_str} completed.')

    def _dates_interval(self, start_date, end_date):
        """Return daily steps from ``start_date`` up to ``end_date``.

        The result is ``start_date + k days`` for every k in
        ``range((end_date - start_date).days)``. ``start_date`` is included
        and ``end_date`` is excluded; a trailing partial day is dropped.
        """
        import datetime
        days = [start_date + datetime.timedelta(days=x)
                for x in range((end_date - start_date).days)]
        return days

    def _yearly_dates(self, year):
        """Return a naive midnight ``datetime`` for every day of ``year``."""
        import datetime

        d1 = datetime.datetime(year, 1, 1)
        d2 = datetime.datetime(year + 1, 1, 1)

        return self._dates_interval(d1, d2)

    def load(self, collection, bands, resolution,
             src_nodata=None, dst_nodata=None,
             resampling='cubic'):
        """Load one band of ``collection`` as a daily ``Timeseries``.

        Each band file is warped onto ``collection.bounds`` /
        ``collection.epsg``. The results are stacked along axis 0 and only
        the days in ``[collection.start_date, collection.end_date)`` are
        kept.

        Parameters
        ----------
        collection
            Object exposing ``bounds``, ``epsg``, ``start_date``,
            ``end_date``, ``sensor`` and ``get_band_filenames(band)``.
        bands : list or tuple
            Exactly one band name.
        resolution : float
            Target pixel size in units of the collection's EPSG.
        src_nodata, dst_nodata
            Currently unused.
        resampling : str
            gdalwarp resampling method.

        Returns
        -------
        satio.timeseries.Timeseries
            The data with a leading band axis of size 1, labelled with the
            timestamps of the days that were kept.

        Raises
        ------
        TypeError
            If ``bands`` is not a list/tuple.
        ValueError
            If ``bands`` is empty or the band has no files.
        NotImplementedError
            If more than one band is requested.
        """
        from satio.timeseries import Timeseries
        from dateutil.parser import parse

        # Check the type first. A plain string like 'B01' also has
        # len() > 1 and would otherwise be misreported as
        # NotImplementedError.
        if not isinstance(bands, (list, tuple)):
            raise TypeError("'bands' should be a list/tuple of bands. "
                            f"Its type is: {type(bands)}")

        if not bands:
            raise ValueError("'bands' must contain exactly one band.")

        if len(bands) > 1:
            # can load only 1 band per time
            raise NotImplementedError

        band = bands[0]
        # these are the bounds and epsg requested for the final data
        dst_bounds = list(collection.bounds)
        dst_epsg = collection.epsg

        filenames = collection.get_band_filenames(band)
        start_date = parse(collection.start_date)
        end_date = parse(collection.end_date)

        arrays = []
        timestamps = []
        for fn in filenames:
            arrays.append(self._warp(fn,
                                     dst_bounds,
                                     dst_epsg,
                                     resolution,
                                     resampling=resampling))

            # The year is taken from the last '_'-separated token of the
            # file stem, e.g. '<prefix>_2020.tif'. This assumes each file
            # holds one layer per day of that year; axis 0 of the stack
            # must match the daily timestamps below.
            year = int(Path(fn).name.split('.')[0].split('_')[-1])
            timestamps.extend(self._yearly_dates(year))

        if not arrays:
            raise ValueError(f"No files found for band '{band}'.")

        # Concatenate once instead of re-copying the stack on every file.
        arr = np.concatenate(arrays, axis=0)

        # Keep the days in [start_date, end_date) and label the data with
        # exactly those days. The timestamps then always stay aligned with
        # the rows, even if the files do not cover the whole interval or
        # are not in chronological order.
        keep = [start_date <= t < end_date for t in timestamps]
        time_flag = np.array(keep, dtype=bool)
        filtered_timestamps = [t for t, k in zip(timestamps, keep) if k]

        arr = arr[time_flag, ...]
        arr = np.expand_dims(arr, axis=0)

        ts = Timeseries(arr, filtered_timestamps, bands)
        ts.attrs['sensor'] = collection.sensor

        return ts
