import sys
sys.path.insert(0, 'rasterio')
from biopar.bioparnnw import BioParNNW
import rasterio
#from rasterio.windows import Window
import numpy as np
from biopar.filereading import *
from pathlib import Path

# Module-level placeholder; compute_biopar_S2 builds its own local profile.
profile = None

# Default input: the B03 10m band file of one 30TXP L2A granule. The other
# bands are located by substituting the band id in this path.
basepath = (
    "/data/cel_vol1/HRL_PHENO/S2/S2_TOC/30TXP/"
    "S2A_MSIL2A_20180101T105441_N0206_R051_T30TXP_20180101T124911.SAFE/"
    "GRANULE/L2A_T30TXP_A013204_20180101T105435/"
    "IMG_DATA/R10m/T30TXP_20180101T105441_B03_10m.jp2"
)

class S2Biopar10MProcessor:
    """
    Computes a broadband parameter (ALBEDO by default) from Sentinel-2 10m bands.

    The processor is pointed at the B03 band file of a granule; the paths of the
    other bands are derived by replacing "B03" with the requested band id.
    """

    def __init__(self, basepath, output_dir=".", parameter='ALBEDO'):
        """
        :param basepath: path to the B03 10m band file.
        :param output_dir: directory in which the output file is written.
        :param parameter: name substituted for "B03" in the output file name.
        """
        self.basepath = basepath
        self.output_dir = output_dir
        self.parameter = parameter

    def get_basepath_10m(self):
        """Return the path of the B03 10m band file."""
        return self.basepath

    def get_basepath_20m(self):
        """
        Return the path of the B03 20m band file.

        Derived from the 10m path using the L2A SAFE layout, e.g.
        ``IMG_DATA/R10m/..._B03_10m.jp2`` -> ``IMG_DATA/R20m/..._B03_20m.jp2``.

        :raises ValueError: if the 10m file does not live in an ``R10m`` directory.
        """
        path_10m = Path(self.get_basepath_10m())
        resolution_dir = path_10m.parent
        if resolution_dir.name != "R10m":
            raise ValueError('Cannot derive 20m path from: ' + str(path_10m))
        name_20m = path_10m.name.replace("_10m", "_20m")
        return str(resolution_dir.with_name("R20m") / name_20m)

    def get_path(self, band_id, resolution='10m'):
        """
        Return the path of ``band_id`` at the given resolution.

        :param band_id: band identifier substituted for "B03", e.g. 'B04'.
        :param resolution: '10m' or '20m'.
        :raises ValueError: for any other resolution.
        """
        if resolution == '10m':
            return self.get_basepath_10m().replace("B03", band_id)
        elif resolution == '20m':
            return self.get_basepath_20m().replace("B03", band_id)
        else:
            raise ValueError('Invalid resolution: ' + resolution)

    def get_output_path(self):
        """
        Return the output file path: the B03 file name with "B03" replaced by
        the parameter name, placed in ``output_dir``.
        """
        # NOTE(review): the name keeps the input's .jp2 extension although the
        # file is written with the GTiff driver; kept for downstream compatibility.
        return str(Path(self.output_dir) / Path(self.basepath).name.replace("B03", self.parameter))

    def compute_biopar_S2(self):
        """
        Computes the parameter for a Sentinel-2 product and writes it to disk.

        Uses the 10m bands B02, B03, B04 and B08. The result is written block by
        block (following the block layout of the B02 input) as a uint8,
        DEFLATE-compressed GeoTIFF with nodata 255 to :meth:`get_output_path`.

        :return: None
        """
        import tqdm
        from contextlib import ExitStack

        bands_10m = ['B02', 'B03', 'B04', 'B08']
        nodata_value = 255

        with rasterio.Env(CPL_DEBUG=0, GDAL_CACHEMAX=1024, GTIFF_IMPLICIT_JPEG_OVR=False), \
                ExitStack() as stack:
            # ExitStack closes every input dataset, even if processing fails.
            datasets_10m = [
                stack.enter_context(rasterio.open(self.get_path(band_id, '10m')))
                for band_id in bands_10m
            ]
            basedataset = datasets_10m[0]

            profile = basedataset.profile
            profile['dtype'] = rasterio.uint8
            profile['driver'] = 'GTIFF'
            profile['compress'] = 'DEFLATE'
            profile['nodata'] = nodata_value

            with rasterio.open(self.get_output_path(), 'w', **profile) as dst:
                for _, win in tqdm.tqdm(basedataset.block_windows()):
                    # Let rasterio size each tile from the window itself so that
                    # non-square edge blocks keep their (rows, cols) shape.
                    tiles_10m = [ds.read(1, window=win) for ds in datasets_10m]
                    dst.write(self._tiles_to_image(tiles_10m, nodata_value), 1, window=win)

    def _tiles_to_image(self, tiles, nodata_value=255):
        """
        Convert co-registered band tiles into the uint8 parameter image.

        :param tiles: list of 2D arrays of equal shape (rows, cols), one per band
            in the order expected by :meth:`compute_parameter`, holding
            reflectance scaled by 10000.
        :param nodata_value: value written where any band is outside [0, 10000].
        :return: uint8 array with the same shape as the tiles.
        """
        # A pixel is invalid as soon as one band is outside the valid DN range.
        mask = np.zeros(tiles[0].shape, dtype=bool)
        for tile in tiles:
            mask |= (tile < 0) | (tile > 10000)

        as_image = np.full(tiles[0].shape, nodata_value, dtype=np.uint8)
        if np.all(mask):
            return as_image

        valid = ~mask
        # One row per band, one column per valid pixel, as reflectance in [0, 1].
        bands = np.array([tile[valid].astype(np.float64) * 0.0001 for tile in tiles])
        image = self.compute_parameter(bands, 200)
        # rounding rule, as requested by FRAME project (Herman Eerens)
        as_image[valid] = np.floor(image + 0.5)
        return as_image

    def compute_parameter(self, array, output_scale=200):
        """
        Computes a parameter based on provided bands.

        :param array: Numpy ndarray with one row per band (B02, B03, B04, B08)
            and one column per pixel.
        :param output_scale: factor applied to the weighted band sum.
        :return: 1D array with the scaled weighted sum for each pixel.
        """
        # Band weights:
        # 1 0.307540 B02-Blue
        # 2 0.288688 B03-Green
        # 3 0.239615 B04-Red
        # 4 0.164157 B08-NIR
        weights = [0.307540, 0.288688, 0.239615, 0.164157]
        result = np.dot(weights, array)
        return output_scale * result


if __name__ == '__main__':
    # Process the B03 10m band file given on the command line, falling back to
    # the module-level default granule when no argument is passed.
    processor = S2Biopar10MProcessor(sys.argv[1] if len(sys.argv) > 1 else basepath)
    processor.compute_biopar_S2()
