import numpy as np


def read_bands(basepath, tilesize, *, sza=30.0, saa=156.5, vza=7.0, vaa=105.0):
    """Read the B03/B04/B08 bands of a tile and stack them with geometry features.

    The three band files are located by replacing ``"B03"`` in ``basepath`` with
    each band id, so ``basepath`` is expected to be the path of the B03 file.
    Only band 1 of each file is read, restricted to the top-left
    ``tilesize`` x ``tilesize`` window. Pixel values are multiplied by 1e-4
    (presumably DN -> reflectance scaling; confirm against the product spec).

    The three geometry rows are constant over the tile. They are built from the
    angle arguments below, not from per-pixel angle grids.

    :param basepath: path of the B03 band file; must contain ``"B03"``.
    :param tilesize: side length in pixels of the square window to read.
    :param sza: solar zenith angle in degrees.
    :param saa: solar azimuth angle in degrees.
    :param vza: viewing zenith angle in degrees (original note: ranges from 2 to 12).
    :param vaa: viewing azimuth angle in degrees (original note: largest range).
    :return: ``(bands, mask)``.
        ``bands`` is a ``(6, n)`` float64 array with rows
        ``[B03, B04, B08, cos(vza), cos(sza), cos(saa - vaa)]`` for the ``n``
        pixels not masked in B03. B04/B08 values are taken at those same pixels,
        and their own masks are not applied. ``mask`` is B03's mask.
    """
    import rasterio
    from math import cos, radians

    read_window = rasterio.windows.union([((0, tilesize), (0, tilesize))])
    size = int(round(tilesize))

    def read_band_10m(band_id):
        # Destination buffer for the read. Kept local: the original declared it
        # `global`, leaking it (and the dataset handle) into module state.
        # NOTE(review): rasterio requires `out` to match the file dtype; confirm
        # the band files really are int16 (S2 products are often uint16).
        buffer = np.empty(shape=(1, size, size), dtype=np.int16)
        with rasterio.open(basepath.replace("B03", band_id)) as dataset:
            band = dataset.read(1, out=buffer, window=read_window, masked=True)
        # `np.float` was removed in NumPy 1.24; it was an alias for the builtin
        # float, i.e. float64, so this conversion is numerically unchanged.
        return band.astype(np.float64) * 0.0001

    band3 = read_band_10m("B03")
    band4 = read_band_10m("B04")
    band8 = read_band_10m("B08")

    # All bands are sampled at the pixels that are valid in B03.
    valid = ~band3.mask
    green = band3[valid].flatten()
    red = band4[valid].flatten()
    nir = band8[valid].flatten()

    n_shape = green.shape
    g1 = np.full(n_shape, cos(radians(vza)))
    g2 = np.full(n_shape, cos(radians(sza)))
    g3 = np.full(n_shape, cos(radians(saa) - radians(vaa)))

    # np.asarray drops the masked-array masks and stacks the data as (6, n).
    bands = np.asarray([green, red, nir, g1, g2, g3])
    return (bands, band3.mask)





def _get_namespace_from_xml_tree(xml_tree):
    """Return the namespace URI embedded in an element's tag.

    ElementTree stores qualified tags as ``{namespace}localname``. This returns
    the text between the first ``{`` and the first ``}`` of ``xml_tree.tag``.

    :param xml_tree: an ElementTree element (anything with a string ``tag``).
    :return: the namespace string.
    :raises ValueError: if the tag contains no ``{`` or no ``}``.
    """
    tag = xml_tree.tag
    open_brace, close_brace = tag.index('{'), tag.index('}')
    return tag[open_brace + 1:close_brace]


def read_angles(angle_xml):
    """
    Read sun angle grids and mean viewing angles from S2 tile metadata.

    Reuses ``AngleGrid`` from the S2 processing chain (``biopar.geometry``) for
    the sun angles. Viewing angles are not read per pixel. Instead, every
    ``Mean_Viewing_Incidence_Angle`` entry is averaged into one zenith value and
    one azimuth value. Each value is then broadcast onto a float32 grid shaped
    like the corresponding sun grid.

    :param angle_xml: path or file object of the tile metadata XML, as accepted
        by ``xml.etree.ElementTree.parse``.
    :return: tuple ``(saa, sza, vaa, vza)``. ``saa``/``sza`` are the sun
        azimuth/zenith grids from ``AngleGrid``. ``vaa``/``vza`` are the constant
        float32 viewing azimuth/zenith grids.
    :raises ValueError: if the metadata lists no ``Mean_Viewing_Incidence_Angle``.
    """
    from biopar.geometry import AngleGrid
    from xml.etree.ElementTree import parse

    xml_tree = parse(angle_xml).getroot()
    ns = _get_namespace_from_xml_tree(xml_tree)
    # Only the root-level Geometric_Info is namespace-qualified; its children
    # are looked up by plain tag name.
    geom_info = xml_tree.find('{%s}Geometric_Info' % ns)
    tile_geocoding = geom_info.find('Tile_Geocoding')
    tile_angles = geom_info.find('Tile_Angles')

    # Upper-left corner of the 10 m geoposition.
    # NOTE(review): if no 10 m Geoposition is present, None is passed on to
    # AngleGrid; confirm that is acceptable there.
    upper_left_x = None
    upper_left_y = None
    for geo_position in tile_geocoding.findall('Geoposition'):
        if int(geo_position.attrib['resolution']) == 10:
            upper_left_x = int(geo_position.find('ULX').text)
            upper_left_y = int(geo_position.find('ULY').text)

    # Build the Solar Zenith Angle (SZA) and Solar Azimuth Angle (SAA) grids.
    sun_angles_grid = AngleGrid(tile_angles.find('Sun_Angles_Grid'),
                                upper_left_x, upper_left_y)

    mean_viewing_incidence_angles = tile_angles \
        .find('Mean_Viewing_Incidence_Angle_List') \
        .findall('Mean_Viewing_Incidence_Angle')
    if not mean_viewing_incidence_angles:
        # The original divided by a zero count here (bare ZeroDivisionError).
        raise ValueError("no Mean_Viewing_Incidence_Angle entries in %r" % (angle_xml,))
    count = len(mean_viewing_incidence_angles)
    mean_view_zenith = sum(float(mvia.find("ZENITH_ANGLE").text)
                           for mvia in mean_viewing_incidence_angles) / count
    # NOTE(review): a plain arithmetic mean of azimuths is not wrap-safe across
    # 0/360 degrees; a circular mean may be more appropriate.
    mean_view_azimuth = sum(float(mvia.find("AZIMUTH_ANGLE").text)
                            for mvia in mean_viewing_incidence_angles) / count

    vza_grid = np.full_like(sun_angles_grid.zenithGrid, mean_view_zenith, dtype=np.float32)
    vaa_grid = np.full_like(sun_angles_grid.azimuthGrid, mean_view_azimuth, dtype=np.float32)

    # (saa, sza, vaa, vza)
    return (sun_angles_grid.azimuthGrid, sun_angles_grid.zenithGrid, vaa_grid, vza_grid)