import sys
sys.path.insert(0, 'rasterio')
from biopar.bioparnnw import BioParNNW
import rasterio
#from rasterio.windows import Window
import numpy as np
from biopar.filereading import *
from pathlib import Path

# Module-level defaults.
# Path of the B03 10 m band image of a sample Sentinel-2 L2A granule.
basepath = "/data/cel_vol1/HRL_PHENO/S2/S2_TOC/30TXP/S2A_MSIL2A_20180101T105441_N0206_R051_T30TXP_20180101T124911.SAFE/GRANULE/L2A_T30TXP_A013204_20180101T105435/IMG_DATA/R10m/T30TXP_20180101T105441_B03_10m.jp2"
# Placeholder for a raster profile; nothing is assigned to it at import time.
profile = None

class S2BioparProcessor():
    """Compute a biophysical-parameter raster from the bands of one Sentinel-2 product.

    The processor is anchored on the path of the B03 band image at 10 m
    resolution. Paths of all other bands, of the 20 m variants and of the
    ``MTD_TL.xml`` angle file are derived from that path by string
    substitution.

    :param basepath: path (``str`` or ``pathlib.Path``) of the ``B03`` 10 m
        band image; it must contain the substrings ``B03`` and ``10m``.
    :param output_dir: directory in which the result raster is written.
    :param parameter: parameter name handed to ``BioParNNW`` (e.g. ``'FAPAR'``).
    :param use10m: when true, B03/B04 are read at 10 m and the output is
        produced on the 10 m grid; otherwise every band is read at 20 m.
    """

    def __init__(self, basepath, output_dir=".", parameter='FAPAR', use10m=False):
        self.basepath = basepath
        self.output_dir = output_dir
        self.parameter = parameter
        self.use10m = use10m

    def get_basepath_10m(self):
        """Return the base path exactly as it was given to the constructor."""
        return self.basepath

    def get_basepath_20m(self):
        """Return the base path as a string with every ``10m`` replaced by ``20m``."""
        return str(self.basepath).replace('10m', '20m')

    def get_angle_xml(self):
        """Return the ``MTD_TL.xml`` path located two directories above the 20 m image folder.

        The ``..`` components are kept as-is; the path is not normalised.
        """
        return Path(self.get_basepath_20m()).parent / ".." / ".." / "MTD_TL.xml"

    def get_path(self, band_id, resolution='10m'):
        """Return the path of band ``band_id`` at the requested resolution.

        :param band_id: band identifier substituted for ``B03`` (e.g. ``'B8A'``).
        :param resolution: ``'10m'`` or ``'20m'``.
        :return: the band path as a string.
        :raises ValueError: for any other resolution.
        """
        if resolution == '10m':
            # str() so that a pathlib.Path base path gets a *string* replace
            # rather than Path.replace(), which renames files.
            return str(self.get_basepath_10m()).replace("B03", band_id)
        elif resolution == '20m':
            return self.get_basepath_20m().replace("B03", band_id)
        else:
            raise ValueError('Invalid resolution: ' + str(resolution))

    def get_output_path(self):
        """Return ``output_dir`` joined with the file name of the base path, as a string."""
        return str(Path(self.output_dir) / Path(self.basepath).name)

    @staticmethod
    def _read_band(dataset, window, rows, cols):
        """Read band 1 of ``dataset`` inside ``window`` into a ``(rows, cols)`` array.

        The output buffer fixes the result shape, so a window whose size
        differs from ``(rows, cols)`` is resampled by rasterio while reading.
        """
        buffer = np.empty(shape=(1, rows, cols), dtype=dataset.profile['dtype'])
        return dataset.read(1, out=buffer, window=window)

    def compute_biopar_S2(self):
        """
        Computes the configured biopar for a Sentinel-2 product, block by block.

        Bands B03 and B04 are read at 10m when ``use10m`` is set and at 20m
        otherwise; B05, B06, B07, B8A, B11 and B12 are always read at 20m.
        The result is written as a DEFLATE-compressed uint8 GeoTIFF with
        nodata 0 to :meth:`get_output_path` (which keeps the file name of the
        input image). Pixels where B03 equals 0 are treated as nodata and
        left at 0.

        :return: None
        """
        # Local imports, in line with how this module pulls in optional helpers.
        from contextlib import ExitStack

        import tqdm

        bands_10m = ['B03', 'B04']
        bands_20m = ["B05", "B06", "B07", "B8A", "B11", "B12"]

        # rasterio 0.36.0 describes windows as nested (start, stop) tuples,
        # ordered rows first; other releases are handled as Window objects.
        legacy_windows = rasterio.__version__ == "0.36.0"
        if not legacy_windows:
            from rasterio.windows import Window

        # The ExitStack closes every input dataset on exit, also on errors.
        with rasterio.Env(CPL_DEBUG=0, GDAL_CACHEMAX=1024, GTIFF_IMPLICIT_JPEG_OVR=False), \
                ExitStack() as stack:

            def open_band(band_id, resolution):
                return stack.enter_context(rasterio.open(self.get_path(band_id, resolution)))

            if self.use10m:
                basepath = self.get_basepath_10m()
                dataset_10m = [open_band(band_id, '10m') for band_id in bands_10m]
                # A 10m window covers half as many 20m pixels per axis.
                win_scale_factor = 2.0
            else:
                basepath = self.get_basepath_20m()
                dataset_10m = []
                bands_20m = bands_10m + bands_20m
                win_scale_factor = 1.0

            print('Processing: ' + str(basepath))
            dataset_20m = [open_band(band_id, '20m') for band_id in bands_20m]

            basedataset = dataset_10m[0] if self.use10m else dataset_20m[0]

            profile = basedataset.profile
            profile['dtype'] = rasterio.uint8
            profile['driver'] = 'GTIFF'
            profile['compress'] = 'DEFLATE'
            profile['nodata'] = 0

            saa, sza, vaa, vza = self.compute_angles()

            nnw = BioParNNW(version='snap', parameter=self.parameter)

            with rasterio.open(self.get_output_path(), 'w', **profile) as dst:
                for ji, win in tqdm.tqdm(basedataset.block_windows()):
                    # Arrays are (rows, cols) == (height, width) throughout.
                    if legacy_windows:
                        (row_start, row_stop), (col_start, col_stop) = win
                        rows = int(round(row_stop - row_start))
                        cols = int(round(col_stop - col_start))
                        slices = (slice(row_start, row_stop), slice(col_start, col_stop))
                        # TODO compute reduced window
                        win20m = win
                    else:
                        rows = int(round(win.height))
                        cols = int(round(win.width))
                        # toslices() yields (row_slice, col_slice).
                        slices = win.toslices()
                        win20m = Window(col_off=win.col_off / win_scale_factor,
                                        row_off=win.row_off / win_scale_factor,
                                        width=win.width / win_scale_factor,
                                        height=win.height / win_scale_factor)

                    if self.use10m:
                        tiles_10m = [self._read_band(ds, win, rows, cols) for ds in dataset_10m]
                        tile0 = tiles_10m[0]
                    else:
                        tiles_10m = []
                        tile0 = self._read_band(basedataset, win, rows, cols)

                    # Zero in the reference band (B03) marks nodata.
                    valid = tile0 != 0
                    if not valid.any():
                        continue

                    # 20m bands are read into buffers of the base tile size, so
                    # in 10m mode they come out resampled to the 10m grid.
                    tiles_20m = [self._read_band(ds, win20m, rows, cols) for ds in dataset_20m]

                    # Keep only valid pixels and scale the stored integers by 0.0001.
                    # np.float64 replaces the removed np.float alias (same type).
                    all_reflectance_tiles = [
                        tile[valid].astype(np.float64) * 0.0001
                        for tile in tiles_10m + tiles_20m
                    ]

                    # Angular inputs: cos(view zenith), cos(sun zenith) and
                    # cos(sun azimuth - view azimuth), sampled at the same pixels.
                    g1 = np.cos(np.radians(vza[slices][valid]))
                    g2 = np.cos(np.radians(sza[slices][valid]))
                    g3 = np.cos(np.radians(saa[slices][valid]) - np.radians(vaa[slices][valid]))

                    # One row per network input, one column per valid pixel.
                    bands = np.array([arr.flatten() for arr in all_reflectance_tiles + [g1, g2, g3]])
                    image = nnw.run(bands, output_scale=200.)

                    as_image = np.zeros((rows, cols), dtype=rasterio.uint8)
                    as_image[valid] = image.flatten()

                    dst.write(as_image, 1, window=win)

    def compute_angles(self):
        """Read the sun/view angle grids and upsample them to the processing resolution.

        The grids returned by ``read_angles`` are zoomed by ``5000 / 10`` in
        10m mode and ``5000 / 20`` otherwise (presumably the angle grid has a
        5000m spacing -- confirm against ``read_angles``).

        :return: tuple ``(saa, sza, vaa, vza)`` of zoomed arrays.
        """
        # Import the submodule explicitly: a bare "import scipy" does not
        # guarantee that scipy.ndimage is loaded.
        from scipy import ndimage

        angle_file = self.get_angle_xml()
        (saa, sza, vaa, vza) = read_angles(str(angle_file))
        zoom_factor = 5000 / 10 if self.use10m else 5000 / 20
        saa = ndimage.zoom(saa, zoom_factor)
        sza = ndimage.zoom(sza, zoom_factor)
        vaa = ndimage.zoom(vaa, zoom_factor)
        vza = ndimage.zoom(vza, zoom_factor)
        return saa, sza, vaa, vza


if __name__ == '__main__':
    # Script entry point: process one hard-coded sample granule with the
    # constructor defaults (FAPAR, 20 m bands, output in the current directory).
    sample_b03_path = "/data/cel_vol1/HRL_PHENO/S2/S2_TOC/30TXP/S2A_MSIL2A_20180101T105441_N0206_R051_T30TXP_20180101T124911.SAFE/GRANULE/L2A_T30TXP_A013204_20180101T105435/IMG_DATA/R10m/T30TXP_20180101T105441_B03_10m.jp2"
    processor = S2BioparProcessor(sample_b03_path)
    processor.compute_biopar_S2()
