try:
    import gdal
except ImportError:
    from osgeo import gdal
import pandas as pd
import numpy as np
import os
import glob
import rasterio
import xarray as xr
from loguru import logger
import rasterio.mask
from rasterio.warp import reproject, Resampling
from pathlib import Path




# Globals used when evaluating datasource ``operation`` strings with
# ``eval``: every numpy attribute whose name does not start with ``__`` is
# exposed by name (e.g. ``sqrt``, ``log10``, ``where``), so expressions such
# as ``'10*log10(x)'`` need no ``np.`` prefix.
# NOTE: ``eval`` executes arbitrary code; operations must come from trusted
# configuration, never from external input.
_GLOBAL_NAMESPACE = {key: getattr(np, key)
                     for key in dir(np) if not key.startswith('__')}

class Vrt:
    '''
    A class that deals with tile-specific VRTs that are needed for
    various data sources.

    On construction, the input images are collected (from ``infiles``, or by
    globbing ``inpattern`` under ``sourcedir``), restricted to the
    [startdate, enddate] period, sorted by acquisition date and stacked into
    a band-separated GDAL VRT written to ``vrtdir``.
    '''

    def __init__(self, sourcedir, vrtdir, startdate, enddate, inpattern,
                 infiles, outpattern, dtype=np.float32,
                 operation='1*x', nodata=None, dropna=False, mask=None):
        '''
        :param sourcedir: directory searched with ``inpattern`` when
            ``infiles`` is None
        :param vrtdir: directory where the VRT and its file list are written
        :param startdate: first acquisition date to include (inclusive)
        :param enddate: last acquisition date to include (inclusive)
        :param inpattern: glob pattern, relative to ``sourcedir``
        :param infiles: optional explicit list of input files (str or
            pathlib.Path); when given, ``inpattern`` is not used
        :param outpattern: prefix of the VRT and file-list basenames
        :param dtype: datatype in which window data is read
        :param operation: numpy-like expression in ``x`` applied on read
        :param nodata: input value to be replaced by NaN on read
        :param dropna: whether acquisitions with missing data are dropped
            when a window is read
        :param mask: name of the datasource acting as mask, or None
        '''
        self._sourcedir = sourcedir
        self._startdate = pd.to_datetime(startdate)
        self._enddate = pd.to_datetime(enddate)
        self.dtype = dtype
        self.operation = operation
        self.nodata = nodata
        self.dropna = dropna
        self.mask = mask

        self.file, self.dates, self.sourcefiles = self._create(vrtdir,
                                                               inpattern,
                                                               infiles,
                                                               outpattern)

    def _create(self, vrtdir, inpattern, infiles, outpattern):
        '''
        Build the VRT and its accompanying plain-text file list.

        Input basenames are expected to carry the acquisition date as the
        first 8 characters (YYYYMMDD) of their second ``_``-separated token.

        Side effect: sets ``self.name`` (VRT basename without extension).

        :return: tuple (path of the VRT, DatetimeIndex of the included
            acquisitions in ascending order, list of the included files in
            the same order)
        '''
        period = '_'.join([self._startdate.strftime('%Y%m%d'),
                           self._enddate.strftime('%Y%m%d')])
        self.name = outpattern + period
        outvrt = os.path.join(vrtdir, self.name + '.vrt')

        if os.path.exists(outvrt):
            os.remove(outvrt)
        logger.info('Creating: {}'.format(outvrt))

        if infiles is None:
            img_files = glob.glob(os.path.join(self._sourcedir, inpattern))
        else:
            # Accept pathlib.Path entries too: GDAL and the file list below
            # need plain strings.
            img_files = [str(f) for f in infiles]

        if len(img_files) == 0:
            logger.warning('No img files found under: {}'.format(
                os.path.join(self._sourcedir, inpattern)))

        indates = pd.to_datetime(
            np.array([os.path.basename(x).split('_')[1][0:8]
                      for x in img_files]))
        in_period = (indates >= self._startdate) & (indates <= self._enddate)

        img_files = np.array(img_files)[in_period]
        indates = indates[in_period]
        # Stable sort: files sharing an acquisition date keep their input
        # order, so the VRT band order is deterministic.
        idx_sort = np.argsort(np.asarray(indates), kind='stable')
        img_files = [str(f) for f in img_files[idx_sort]]

        # Plain-text list of the stacked files, one per line, in band order
        filename = os.path.join(
            vrtdir, '_'.join(['filelist', outpattern, period]) + '.txt')
        if os.path.exists(filename):
            os.remove(filename)
        with open(filename, 'w') as file_list:
            file_list.writelines(f + '\n' for f in img_files)

        # separate=True puts every input file in its own VRT band
        vrtopts = gdal.BuildVRTOptions(separate=True)
        vrt_ds = gdal.BuildVRT(outvrt, img_files, options=vrtopts)
        if vrt_ds is None:
            logger.error('Could not build VRT: {}'.format(outvrt))
        # Dereference the dataset so GDAL flushes the VRT to disk now
        vrt_ds = None

        return outvrt, indates[idx_sort], img_files

    def get_size(self):
        '''Return (height, width) of the VRT in pixels.'''
        with rasterio.open(self.file, 'r') as src:
            return src.height, src.width

    def get_crs(self):
        '''Return the CRS of the VRT.'''
        with rasterio.open(self.file, 'r') as src:
            return src.crs

    def get_transform(self):
        '''Return the affine geotransform of the VRT.'''
        with rasterio.open(self.file, 'r') as src:
            return src.transform

    def get_pixel_position(self, coord_x, coord_y):
        '''Return the (row, col) of the pixel containing (coord_x, coord_y).'''
        with rasterio.open(self.file, 'r') as src:
            return src.index(coord_x, coord_y)


class Tile:
    '''
    A class for each S2 tile containing references to where
    the tile imagery can be found.
    Respective VRTs are created automatically each time a new datasource is
    added to this class.
    '''

    def __init__(self, tile, vrt_dir, startdate, enddate, overwrite=False):
        '''
        :param tile: tile identifier, inserted in the VRT names
        :param vrt_dir: directory where VRTs are written (created if missing)
        :param startdate: start of the period of interest
        :param enddate: end of the period of interest
        :param overwrite: stored as ``self._overwrite``; not used by the
            code in this class
        '''
        self.id = tile
        self.datasources = {}
        self._startdate = pd.to_datetime(startdate)
        self._enddate = pd.to_datetime(enddate)
        self._overwrite = overwrite
        self._vrtdir = vrt_dir

        os.makedirs(self._vrtdir, exist_ok=True)

        self.datetimeindex = self._get_datetimeindex()

    def add_datasource(self, name, imgdir='', inpattern='', infiles=None,
                       outpattern='', dtype=np.float32,
                       operation='1*x', nodata=None, dropna=False, mask=None):
        '''
        Method to add a new datasource to the tile object
        :param name: the name of the datasource (eg "s1_asc_vv")
        :param imgdir: the directory where the imagery is located
        :param inpattern: a datasource specific pattern based on which the
        list of files will be retrieved
        :param infiles: optional list of input files: cannot be specified
        together with inpattern
        :param outpattern: basename of the output vrt to be created
        :param dtype: datatype of the imagery
        :param operation: optional numpy-like operation to be performed on the
        data when a window is requested
        :param nodata: input nodata value to be considered
        :param dropna: whether or not no data acquisitions should be dropped
        from the input stack
        :param mask: if not None, 'mask' refers to an existing datasource in
        the tile that serves as the mask for the new datasource
        :raises AssertionError: if ``infiles`` is combined with
        ``inpattern``, is not a list or is empty, or if the acquisition dates
        of the new datasource differ from those of ``mask``
        :raises ValueError: if ``mask`` is not an existing datasource
        :return: None; the datasource is registered under ``name`` only when
        all checks pass
        '''
        # Validation raises AssertionError explicitly instead of using
        # ``assert``: the checks then survive ``python -O`` while keeping the
        # exception type callers may already rely on.
        if infiles is not None:
            if inpattern != '':
                raise AssertionError("A list of infiles cannot be provided"
                                     " together with an input pattern!")
            if not isinstance(infiles, list):
                raise AssertionError("Provided infiles should be a list!")
            if len(infiles) == 0:
                raise AssertionError(
                    "Infiles should contain at least one image!")
        if mask is not None and mask not in self.datasources:
            raise ValueError(
                ('Provided mask "{}" does not exist in Tile object:'
                 ' first add the required mask!').format(mask))

        vrt = Vrt(imgdir, self._vrtdir, self._startdate, self._enddate,
                  inpattern, infiles, outpattern + '_' + self.id + '_',
                  dtype=dtype, operation=operation,
                  nodata=nodata, dropna=dropna, mask=mask)

        # The mask must cover exactly the same acquisitions. Validate before
        # registering so a rejected datasource does not linger in the tile.
        if mask is not None and not np.array_equal(
                vrt.dates, self.datasources[mask].dates):
            raise AssertionError(
                ('Acquisition dates of datasource "{}" do not match those'
                 ' of mask "{}"!').format(name, mask))

        self.datasources[name] = vrt

    def get_datasource(self, name):
        '''
        Return the ``Vrt`` registered under ``name``.

        :raises ValueError: if no datasource with that name exists
        '''
        if name not in self.datasources:
            logger.error(
                'Datasource "{}" not found in this Tile object!'.format(name))
            raise ValueError(
                'Datasource "{}" not found in this Tile object!'.format(name))
        return self.datasources[name]

    def list_datasources(self):
        '''Return the datasource names in insertion order.'''
        return list(self.datasources.keys())

    def _get_datetimeindex(self):
        '''Return a 5-daily DatetimeIndex spanning [startdate, enddate].'''
        return pd.date_range(start=self._startdate, end=self._enddate,
                             freq='5D')


class Window:
    '''
    Class that deals with individual windows (data cubes)
    that are input to the neural nets.

    Data of every datasource of the associated ``Tile`` is read lazily on the
    first ``get_band_data`` call and cached in ``self.data`` as
    ``xarray.DataArray`` objects with dims ('x', 'y', 't').
    '''

    # S2 radiometry bands whose NaN pixels are copied onto 's2_mask'
    _S2_MASK_SOURCE_BANDS = ('s2_b02', 's2_b03', 's2_b04', 's2_b08')

    def __init__(self, window, tile):
        '''
        :param window: tuple of ((row_start, row_stop), (col_start, col_stop))
        :param tile: corresponding Tile object to get data from
        :raises TypeError: if ``tile`` is not a ``Tile`` instance
        '''
        if not isinstance(tile, Tile):
            raise TypeError('"tile" argument should be instance of Tile')

        self.window = window
        # Number of rows of the window (the column count is derived from
        # ``window`` separately, so non-square windows are supported)
        self.dim = window[0][1] - window[0][0]
        self.tile = tile
        self.data = {}
        self._hasdata = False
        self._datastacks = None
        self._datastacks_array = None

    def _window_shape(self):
        '''Return (n_rows, n_cols) of ``self.window``.'''
        n_rows = self.window[0][1] - self.window[0][0]
        n_cols = self.window[1][1] - self.window[1][0]
        return n_rows, n_cols

    def _read_datasource(self, name):
        '''Read datasource ``name`` for this window as a DataArray.'''
        source = self.tile.get_datasource(name)
        n_rows, n_cols = self._window_shape()

        logger.info('Reading datasource: {}'.format(name))

        with rasterio.open(source.file, 'r') as src:
            # Stack the bands (one per acquisition) along the last axis;
            # assignment casts to the datasource dtype.
            data = np.empty((n_rows, n_cols, src.count), dtype=source.dtype)
            for band in range(src.count):
                data[:, :, band] = src.read(band + 1, window=self.window)

        if source.nodata is not None:
            # Needs a floating dtype: NaN cannot be stored in integer arrays
            data[data == source.nodata] = np.nan

        # NOTE: ``operation`` is executed with ``eval`` against the numpy
        # namespace; it must come from trusted configuration only.
        values = eval(source.operation, _GLOBAL_NAMESPACE, {'x': data})
        array = xr.DataArray(
            values,
            coords=[np.arange(n_rows), np.arange(n_cols), source.dates],
            dims=['x', 'y', 't'])
        if source.dropna:
            # Drops every acquisition that contains at least one NaN
            array = array.dropna('t')
        return array

    def _sync_s2_nans(self):
        '''Propagate NaNs between the S2 radiometry bands and the S2 mask.'''
        # HACK: sometimes a pixel in S2 radiometry is NaN while the mask is
        # 0 or 1. If NaN values are forward-imputed at a later stage, the S2
        # radiometry and its mask would go out of sync for those images, so
        # mask pixels are set to NaN wherever the radiometry is NaN.
        # Skipped when no 's2_mask' datasource was loaded.
        if 's2_mask' in self.data:
            mask_values = self.data['s2_mask'].values
            for band in self._S2_MASK_SOURCE_BANDS:
                if band in self.data:
                    mask_values[np.isnan(self.data[band].values)] = np.nan

        # HACK: a pixel that is NaN in one S2 band is set to NaN in every S2
        # band. The union of NaNs is built once and applied to each band.
        s2_bands = [key for key in self.data if key.startswith('s2_b')]
        if not s2_bands:
            return
        any_nan = np.zeros(self.data[s2_bands[0]].shape, dtype=bool)
        for band in s2_bands:
            any_nan |= np.isnan(self.data[band].values)
        for band in s2_bands:
            self.data[band].values[any_nan] = np.nan

    def _read_window_data(self):
        '''
        method that initiates the loading of all input data from the available
        datasources for this window
        the result is stored in the object as xarrays
        :return:
        '''
        logger.info('Loading data for window {} ...'.format(self.window))

        for datasource in self.tile.list_datasources():
            self.data[datasource] = self._read_datasource(datasource)

        self._sync_s2_nans()
        self._hasdata = True

    def get_band_data(self, datasource):
        """
        Method to return a particular band as xarray
        :param datasource: name of a datasource of the tile
        :raises ValueError: if the datasource was not loaded for this window
        :return: xarray.DataArray with dims ('x', 'y', 't')
        """
        if not self._hasdata:
            self._read_window_data()
        if datasource not in self.data:
            raise ValueError(
                "Datasource {} not found in window!".format(datasource))
        logger.info('Getting data as raw arrays ...')
        return self.data[datasource]


def reproject_raster(in_raster, ref_raster, resampling=None):
    '''
    Reproject ``in_raster`` onto the grid (CRS, transform, width, height) of
    ``ref_raster`` and write the result next to the input as
    ``<stem>_warped.tif``.

    :param in_raster: path of the raster to reproject
    :param ref_raster: path of the raster whose grid is used as target
    :param resampling: rasterio ``Resampling`` method; defaults to
        ``Resampling.nearest``
    :return: path (str) of the written reprojected raster
    '''
    if resampling is None:
        resampling = Resampling.nearest

    # Read the target grid from the reference raster, opened once and closed
    with rasterio.open(ref_raster) as ref:
        crs_to = ref.crs
        transform_to = ref.transform
        width = ref.width
        height = ref.height

    in_path = Path(in_raster)
    outpath = os.path.join(in_path.parent, in_path.stem + '_warped.tif')

    with rasterio.open(in_raster) as src:
        # NOTE(review): the output inherits the input's driver via
        # ``src.meta`` despite the ``.tif`` suffix; confirm inputs are GeoTIFF.
        kwargs = src.meta.copy()
        kwargs.update({
            'crs': crs_to,
            'transform': transform_to,
            'width': width,
            'height': height
        })

        with rasterio.open(outpath, 'w', **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, i),
                    destination=rasterio.band(dst, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform_to,
                    dst_crs=crs_to,
                    resampling=resampling
                )

    return outpath



def read_window_raster(raster_dir, bounds, dimension
                       , dtype=np.uint32):
    """
    Function that will read a raster from a certain defined window

    The window's upper-left corner is (minx, maxy) of ``bounds``; its size
    is ``dimension`` x ``dimension`` pixels (miny and maxx are not used).
    Pixel offsets use the raster's own pixel size and assume a north-up
    geotransform without rotation.

    :param raster_dir: path of the raster that should be read
    :param bounds: boundary used to locate the window
        (format: minx, miny, maxx, maxy, in the raster's CRS)
    :param dimension: number of rows and columns of the square window,
        e.g. 128 for a 128 x 128 window
    :param dtype: datatype at which the raster should be read
    :return: 2-D array of shape (dimension, dimension) holding the last band
        of the raster, or an empty array when the window is not fully
        contained in the raster
    """
    import math
    from rasterio.windows import Window as RasterioWindow

    with rasterio.open(raster_dir, 'r') as src:
        transform = src.transform
        x_res = transform.a
        y_res = -transform.e

        # math.floor keeps offsets of windows starting before the raster
        # origin negative (int() truncated them to 0), so such windows are
        # rejected below instead of being silently shifted.
        x_off = math.floor((bounds[0] - transform.xoff) / x_res)
        y_off = math.floor((transform.yoff - bounds[3]) / y_res)

        if x_off < 0 or y_off < 0:
            return np.array([])
        if (x_off + dimension) > src.width or (y_off + dimension) > src.height:
            return np.array([])

        data = np.empty((dimension, dimension), dtype=dtype)
        # Read only the requested window rather than the whole band. A single
        # 2-D buffer is returned, so only the last band is read.
        # NOTE(review): confirm that returning the last band is intended for
        # multi-band rasters.
        data[:, :] = src.read(src.count,
                              window=RasterioWindow(x_off, y_off,
                                                    dimension, dimension))

    return data

def raster_to_xr(raster_dir, windowsize, dimension
                 , dtype=np.uint32):
    """
    Function that will convert a specific raster to an xarray with specified dimensions

    :param raster_dir: path of the raster to read
    :param windowsize: rasterio window to read (passed as ``window=`` to
        ``src.read``; despite its name it is a window, not a size)
    :param dimension: number of rows and columns of the square window
    :param dtype: datatype of the returned data
    :return: xarray.DataArray of shape (dimension, dimension) with dims
        ('x', 'y') and integer coordinates, holding the last band of the
        raster
    """
    with rasterio.open(raster_dir, 'r') as src:
        data = np.empty((dimension, dimension), dtype=dtype)
        # A single 2-D buffer is returned, so only the last band is read.
        # NOTE(review): confirm that returning the last band is intended for
        # multi-band rasters.
        data[:, :] = src.read(src.count, window=windowsize)

    # ``1 * data`` is the identity operation ('1*x') applied directly,
    # without building a numpy namespace and calling eval.
    return xr.DataArray(1 * data,
                        coords=[np.arange(dimension), np.arange(dimension)],
                        dims=['x', 'y'])