import glob
import logging
import os
from pathlib import Path
from typing import Dict, List, TypeVar

import numpy as np
import pandas as pd
import xarray as xr
from satio.collections import (AgERA5Collection,
                               SIGMA0TrainingCollection)
from satio.utils.errors import EmptyCollection

from cropclass.geoloader import GDALWarpLoader


class AgERA5YearlyCollection(AgERA5Collection):
    """AgERA5 collection in which every row points to one year of data.

    Each row's directory is expected to hold one GeoTIFF per band, named
    ``AgERA5_<band>_<year>.tif``.
    """

    sensor = 'AgERA5'

    def __init__(self, *args, **kwargs):
        """
        Build the collection; all arguments go to AgERA5Collection.

        The collection csv/dataframe must provide the yearly directories
        in a 'path' column, for example:

        date      |    path
        20200101   .../ewoc-agera5-yearly/2020
        """
        super().__init__(*args, **kwargs)
        # Yearly rasters are read through the GDAL warp based loader.
        self._loader = GDALWarpLoader()

    def filter_dates(self, start_date, end_date):
        """Keep the yearly rows whose year lies in [start year, end year].

        Only the first four characters (the year) of ``start_date`` and
        ``end_date`` are used, so complete years are retained.
        """
        first_year = int(start_date[:4])
        year_after_end = int(end_date[:4]) + 1
        lower_bound = f'{first_year}0101'
        upper_bound = f'{year_after_end}0101'

        keep = (self.df.date >= lower_bound) & (self.df.date < upper_bound)
        return self._clone(df=self.df[keep],
                           start_date=start_date,
                           end_date=end_date)

    def get_band_filenames(self, band, resolution=None):
        """Return the yearly GeoTIFF path of ``band`` for every row.

        ``resolution`` is accepted for interface compatibility and ignored.
        """
        def _row_to_filename(record):
            return f"{record.path}/AgERA5_{band}_{record.date.year}.tif"

        return self.df.apply(_row_to_filename, axis=1).values.tolist()

    def load_timeseries(self,
                        *bands,
                        resolution=100,
                        resampling='cubic',
                        **kwargs):
        """Load ``bands`` as a time series via the parent implementation.

        Only the defaults differ from the parent: ``resolution=100`` and
        ``resampling='cubic'``.
        """
        parent = super()
        return parent.load_timeseries(*bands,
                                      resolution=resolution,
                                      resampling=resampling,
                                      **kwargs)


OTC = TypeVar('OTC', bound='OpenEOTrainingCollection')


class OpenEOTrainingCollection:
    """Training collection whose samples live in netCDF files.

    The collection dataframe must provide a ``location_id`` column and a
    ``data_path`` column pointing to the netCDF file (read with the
    ``h5netcdf`` engine) that holds the data of that location.
    """

    sensor = 'OpenEO'

    def __init__(self,
                 df: pd.DataFrame):

        self.df = df
        # NOTE(review): not read anywhere in this class; kept for
        # backward compatibility with external code that may access it.
        self._location_ids = None

    def _clone(self, df):
        """Return a new collection of the same class wrapping ``df``."""
        return self.__class__(df)

    def save(self, filename: str):
        """Write the collection dataframe to ``filename`` as CSV (no index)."""
        self.df.to_csv(filename, index=False)

    @property
    def location_ids(self) -> List:
        """All location ids of the collection, in dataframe order."""
        return self.df.location_id.values.tolist()

    def filter_location(self, *location_ids) -> OTC:
        """Return a new collection restricted to the given location ids.

        Raises:
            EmptyCollection: if none of ``location_ids`` is present.
        """
        new_df = self.df[self.df.location_id.isin(location_ids)]
        if new_df.empty:
            raise EmptyCollection("No products in collection for given "
                                  f"location_ids: '{location_ids}'")

        return self._clone(new_df)

    def load(self,
             bands: List = None,
             location_id: int = None,
             force_dtype: type = np.uint16) -> Dict:
        """Load the bands of one location as in-memory DataArrays.

        Args:
            bands: band (variable) names to load; defaults to every data
                variable of the selected sample.
            location_id: location to load; defaults to the first location
                of the collection.
            force_dtype: if truthy, every band except 'SCL' is cast to this
                dtype. NOTE(review): casting NaN to an integer dtype gives
                undefined values.

        Returns:
            dict mapping band name to a DataArray with dims
            ('timestamp', 'x', 'y'), where x and y have length 1.

        Raises:
            FileNotFoundError: if the location's ``data_path`` is missing.
            ValueError: if a requested band is not in the file.
            IndexError: if ``location_id`` is not in the collection.
        """
        if location_id is None:
            # get first location available
            location_id = self.location_ids[0]

        row = self.df[self.df.location_id == location_id].iloc[0]
        data_path = row['data_path']

        if not Path(data_path).is_file():
            raise FileNotFoundError(f"{data_path} not found.")

        # Use a context manager so the file handle is released. Every band
        # is converted to a plain numpy-backed DataArray before the file
        # is closed.
        with xr.open_dataset(data_path, engine='h5netcdf') as ds:
            sample = ds.where(ds.sampleID == location_id, drop=True)
            avail_bands = list(sample.keys())

            if bands is None:
                bands = avail_bands.copy()

            for b in bands:
                if b not in avail_bands:
                    raise ValueError(f"Given band {b} is not available. "
                                     f"Available bands: {avail_bands}")

            xarrs = {}
            for b in bands:
                arr = sample[b]
                if arr.dtype == np.float64:
                    # Cast to float32
                    arr = arr.astype(np.float32)
                # rename t to timestamp
                arr = arr.rename({'t': 'timestamp'})
                # expand dimensions to match what satio Timeseries expects
                arr = arr.expand_dims(dim={'x': 1, 'y': 1}, axis=[2, 3])
                # drop the leading (presumably per-sample/feature) dimension
                xarrs[b] = xr.DataArray(
                    data=arr.values[0, ...],
                    dims=['timestamp', 'x', 'y'],
                    coords={'timestamp': arr.timestamp.values})

        # monkey patch for points in int16 instead of uint16
        if force_dtype:
            for b in xarrs:
                if b != 'SCL':
                    xarrs[b] = xarrs[b].astype(force_dtype)

        return xarrs

    def load_timeseries(self,
                        *bands,
                        mask_and_scale=False,
                        **kwargs):
        """Load ``bands`` as a satio Timeseries via satio.timeseries."""
        from satio.timeseries import load_timeseries

        return load_timeseries(self, *bands, mask_and_scale=mask_and_scale,
                               **kwargs)


class SIGMA0HRLTrainingCollection(SIGMA0TrainingCollection):
    '''Override parent class to support simultaneous
    ascending/descending file loading into one collection.
    Using the orbits parameter, you can control which
    orbits get loaded (either both or one of the two).
    In case only one orbit is available for a given sample,
    only that orbit will be loaded, irrespective of your
    preference as indicated by the orbits parameter.

    NOTE: no way to retrieve original orbit direction so once
    you combine both orbits, there is no turning back!
    '''

    def __init__(self,
                 df: pd.DataFrame,
                 dataformat: str = 'ewoco',
                 orbits=None):
        """
        Args:
            df: collection dataframe; ``load`` reads its ``location_id``
                and ``path`` columns.
            dataformat: forwarded to SIGMA0TrainingCollection.
            orbits: list with a subset of ['ASCENDING', 'DESCENDING'];
                defaults to both.

        Raises:
            ValueError: if ``orbits`` is not a list or holds other values.
        """
        super().__init__(df, dataformat)

        default_s1_orbits = ['ASCENDING', 'DESCENDING']
        if orbits is None:
            orbits = default_s1_orbits
        elif not isinstance(orbits, list):
            raise ValueError('Specified S1 orbits should be a list, '
                             f'got {type(orbits)}')
        else:
            unexpected = [o for o in orbits
                          if o not in default_s1_orbits]
            if len(unexpected) > 0:
                raise ValueError('Unexpected value for S1 orbits: '
                                 f'{unexpected}. Allowed values: '
                                 f'{default_s1_orbits}')
        # Store a copy so that mutating the caller's list afterwards does
        # not silently change this collection (or its clones).
        self.orbits = list(orbits)

    def _clone(self, df):
        """Return a new collection with the same dataformat and orbits."""
        return self.__class__(df, dataformat=self._dataformat,
                              orbits=self.orbits)

    def load(self,
             bands: List = None,
             location_id: int = None,
             mask_and_scale: bool = False) -> Dict:
        '''
        Override of parent class method to support native
        Int32 datatype of GEE-processed sigma0 and allow
        choice of S1 orbit. Default is to load both
        orbits at once and combine them.

        Args:
            bands: bands to load; defaults to every band (except
                'spatial_ref') found in the first selected orbit.
            location_id: location to load; defaults to the first one.
            mask_and_scale: forwarded to ``xr.open_dataset``.

        Returns:
            dict mapping band name to an in-memory DataArray that is
            concatenated over the selected orbits and sorted along
            'timestamp'. float64 data is cast to float32.

        Raises:
            FileNotFoundError: if the location directory or all S1 files
                are missing.
            ValueError: if a requested band is missing in an orbit.
        '''
        if location_id is None:
            # get first location available
            location_id = self.location_ids[0]

        row = self.df[self.df.location_id == location_id].iloc[0]
        path = row['path']

        if not os.path.isdir(path):
            raise FileNotFoundError(f"{path} not found.")

        # Map each orbit direction to its netCDF files (ASCENDING first).
        available = {}
        for orbit in ('ASCENDING', 'DESCENDING'):
            pattern = f'*{self.sensor}*{orbit}*{self.processing_level}*nc'
            nc_files = glob.glob(os.path.join(path, pattern))
            if nc_files:
                available[orbit] = nc_files
        if not available:
            raise FileNotFoundError(f'No S1 data available in {path}')

        # Keep the preferred orbits that exist; if none of them exist,
        # fall back to whatever orbits are available.
        selected_orbits = [o for o in self.orbits if o in available]
        if not selected_orbits:
            selected_orbits = list(available)
        logging.getLogger(__name__).info('S1 orbits selected: %s',
                                         selected_orbits)

        xarrs = {}

        for orbit, nc_files in available.items():
            if orbit not in selected_orbits:
                continue

            # band name -> file providing it (a later file wins on clashes)
            avail_bands = {}
            for nc_file in nc_files:
                with xr.open_dataset(nc_file, engine='h5netcdf',
                                     mask_and_scale=mask_and_scale) as ds:
                    avail_bands.update(dict.fromkeys(ds.keys(), nc_file))
            avail_bands.pop('spatial_ref', None)

            if bands is None:
                bands = list(avail_bands)

            for b in bands:
                if b not in avail_bands:
                    raise ValueError(f"Given band {b} is not available. "
                                     f"Available bands: {list(avail_bands)}")

            for b in bands:
                with xr.open_dataset(avail_bands[b], engine='h5netcdf',
                                     mask_and_scale=mask_and_scale) as ds:
                    # Load into memory so the file can be closed here.
                    new_xarr = ds[b].load()
                if b in xarrs:
                    xarrs[b] = xr.concat([xarrs[b], new_xarr],
                                         dim='timestamp')
                else:
                    xarrs[b] = new_xarr

        for b in xarrs:
            if xarrs[b].dtype == np.float64:
                # Cast to float32
                xarrs[b] = xarrs[b].astype(np.float32)
            # Orbits were concatenated one after the other; restore time order.
            xarrs[b] = xarrs[b].sortby('timestamp')

        return xarrs
