import numpy as np

from osgeo import gdal, osr, ogr

"""
Classes and functions related to geometrical information and transformations.
"""

# -------------------------------------------------------------------------------------------------
# CLASSES
# -------------------------------------------------------------------------------------------------

class CoordinateHelper(object):
    """
    Geometry of a north-up raster: origin, pixel resolution and raster size.

    The stored values correspond to the entries of a geotransform tuple (adfGeoTransform layout):
        [0] top left x
        [1] w-e pixel resolution
        [2] 0
        [3] top left y
        [4] 0
        [5] n-s pixel resolution (negative value)
    The rotation terms [2] and [4] are not stored; geoTransform always reports them as 0.0.
    """

    def __init__(self, topLeftX, xRes, topLeftY, yRes, xSize, ySize):
        # Origin (top left corner).
        self.topLeftX, self.topLeftY = topLeftX, topLeftY
        # Pixel resolution along x and y.
        self.xRes, self.yRes = xRes, yRes
        # Raster size in pixels.
        self.xSize, self.ySize = xSize, ySize

    @property
    def geoTransform(self):
        """The six element geotransform tuple built from origin and resolution."""
        return (self.topLeftX, self.xRes, 0.0, self.topLeftY, 0.0, self.yRes)

    @property
    def topLeftCoords(self):
        """(x, y) of the top left corner."""
        return (self.topLeftX, self.topLeftY)

    @property
    def topRightCoords(self):
        """(x, y) of the top right corner."""
        rightX = self.topLeftX + self.xRes * self.xSize
        return (rightX, self.topLeftY)

    @property
    def bottomLeftCoords(self):
        """(x, y) of the bottom left corner."""
        bottomY = self.topLeftY + self.yRes * self.ySize
        return (self.topLeftX, bottomY)

    @property
    def bottomRightCoords(self):
        """(x, y) of the bottom right corner."""
        rightX = self.topLeftX + self.xRes * self.xSize
        bottomY = self.topLeftY + self.yRes * self.ySize
        return (rightX, bottomY)

    @property
    def bbox(self):
        """
        Bounding box as (xMin, yMin, xMax, yMax).

        The top left and bottom right corners are ordered per axis, so the result is also valid
        when a resolution has an unexpected sign.
        """
        leftX, topY = self.topLeftCoords
        rightX, bottomY = self.bottomRightCoords

        xMin, xMax = (leftX, rightX) if leftX <= rightX else (rightX, leftX)
        yMin, yMax = (topY, bottomY) if topY <= bottomY else (bottomY, topY)

        return xMin, yMin, xMax, yMax

    @staticmethod
    def CreateFromGeoTransForm(GeoTransform, xSize, ySize):
        """Build a CoordinateHelper from a geotransform sequence and the raster size."""
        originX, resX = GeoTransform[0], GeoTransform[1]
        originY, resY = GeoTransform[3], GeoTransform[5]
        return CoordinateHelper(originX, resX, originY, resY, xSize, ySize)

class GeomGrid(object):
    """
    Base class for a pair of angle grids (zenith and azimuth) anchored at an upper left corner.

    The constructor stores its arguments, resets steps and grids to None and then calls _parse().
    _parse() does nothing in this class; subclasses override it to fill in zenithStep,
    azimuthStep and the two grids.
    """

    def __init__(self, angleSubElement, ulx, uly):
        self._angleSubElement = angleSubElement
        self.ulx, self.uly = ulx, uly
        # (column step, row step) pairs, set by _parse() in subclasses.
        self.zenithStep = self.azimuthStep = None
        self._zenithGrid = self._azimuthGrid = None
        self._parse()

    def _parse(self):
        """Parsing hook; intentionally a no-op in the base class."""
        return None

    def _geoTransformForStep(self, step):
        """Return the geotransform of a grid whose (column, row) step is given."""
        colStep = step[0]
        # YRes must be a negative number!
        negatedRowStep = 0 - step[1]
        helper = CoordinateHelper(self.ulx, colStep, self.uly, negatedRowStep, 0, 0)
        return helper.geoTransform

    @property
    def zenithGeoTransform(self):
        """Geotransform of the zenith grid, derived from ulx, uly and zenithStep."""
        return self._geoTransformForStep(self.zenithStep)

    @property
    def azimuthGeoTransform(self):
        """Geotransform of the azimuth grid, derived from ulx, uly and azimuthStep."""
        return self._geoTransformForStep(self.azimuthStep)

    @property
    def zenithGrid(self):
        """The zenith grid (None until a subclass has parsed it)."""
        return self._zenithGrid

    @property
    def azimuthGrid(self):
        """The azimuth grid (None until a subclass has parsed it)."""
        return self._azimuthGrid

class AngleGrid(GeomGrid):
    """
    Angle grid whose steps and values are parsed from an XML-like element.

    The element passed as angleSubElement must offer find()/findall() and contain a 'Zenith' and
    an 'Azimuth' child. Each of them holds a 'COL_STEP' and a 'ROW_STEP' child with integer text
    and a 'Values_List' child with one 'VALUES' child per grid row (whitespace separated floats).
    """

    def __init__(self, angleSubElement, ulx, uly):
        super(AngleGrid, self).__init__(angleSubElement, ulx, uly)

    def _parse(self):
        """
        Fill zenithStep, azimuthStep and both grids from the angle element.

        Raises:
            Exception: if anything goes wrong while reading the element; the message contains the
                class name and the text of the underlying error.
        """
        try:
            # Find the zenith and azimuth steps.
            self.zenithStep = self._readStep('Zenith')
            self.azimuthStep = self._readStep('Azimuth')

            # Read the values and convert them to NumPy arrays.
            self._zenithGrid = self._readGrid('Zenith')
            self._azimuthGrid = self._readGrid('Azimuth')
        except Exception as e:
            raise Exception('Something went wrong while parsing the angleSubElement (%s): %s' %
                            (self.__class__.__name__, str(e)))

    def _readStep(self, angleName):
        """Return the (column step, row step) tuple of the 'Zenith' or 'Azimuth' child."""
        angleElement = self._angleSubElement.find(angleName)
        colStep = int(angleElement.find('COL_STEP').text)
        rowStep = int(angleElement.find('ROW_STEP').text)
        return (colStep, rowStep)

    def _readGrid(self, angleName):
        """Return the values of the 'Zenith' or 'Azimuth' child as a 2D float32 NumPy array."""
        rows = self._angleSubElement.find(angleName).find('Values_List').findall('VALUES')
        if not rows:
            raise Exception("No VALUES rows found in the Values_List of '%s'." % angleName)

        grid = []
        # The element count of the first row is the reference for all following rows.
        expectedCount = None
        for row in rows:
            # split() without argument tolerates repeated, leading and trailing whitespace,
            # which split(' ') would turn into empty strings that float() rejects.
            rowValues = [float(value) for value in row.text.split()]
            if expectedCount is None:
                expectedCount = len(rowValues)
            elif len(rowValues) != expectedCount:
                raise Exception('%d elements expected in a row. got %d elements. (Row Data: %s)'
                                % (expectedCount, len(rowValues), row.text))
            grid.append(rowValues)

        return np.array(grid, dtype=np.float32)

# -------------------------------------------------------------------------------------------------
# FUNCTIONS
# -------------------------------------------------------------------------------------------------

def convert_WKT_to_EPSG(WKT):
    """
    Converts a WKT (usually obtained from GDAL.GetProjection()) to a string of the form
    'EPSG:<code>'.

    The authority code of the PROJCS node is preferred; the one of the GEOGCS node is used as a
    fallback. Returns None if neither node carries an authority code.
    """
    spatial_ref = osr.SpatialReference()
    spatial_ref.ImportFromWkt(WKT)

    for authority_path in ('PROJCS|AUTHORITY', 'GEOGCS|AUTHORITY'):
        authority_code = spatial_ref.GetAttrValue(authority_path, 1)
        if authority_code:
            return 'EPSG:' + str(authority_code)

    return None

def convert_EPSG_to_WKT(EPSG):
    """
    Converts an EPSG string (with or without the 'EPSG:' prefix) to the WKT exported by OSR for
    that code.
    """
    spatial_ref = osr.SpatialReference()

    # Strip the prefix so that only the numeric code remains.
    numeric_part = EPSG.replace('EPSG:', '')
    spatial_ref.ImportFromEPSG(int(numeric_part))

    wkt = spatial_ref.ExportToWkt()
    return wkt

def _epsg_string_to_int(EPSG):
    """
    Return the integer code of an EPSG string such as 'EPSG:4326' or '4326'.

    Raises an Exception naming the offending value if no integer code can be derived.
    """
    try:
        return int(EPSG.replace('EPSG:', ''))
    except Exception:
        # 'except Exception' instead of a bare 'except' so that KeyboardInterrupt and SystemExit
        # are not converted into a parsing error.
        raise Exception("Could not determine the EPSG code of '%s'" % EPSG)


def transform_coordinates(in_EPSG, out_EPSG, in_x, in_y):
    """
    Transform coordinates from one spatial reference system to another one.

    Parameters:
        in_EPSG: EPSG string of the source reference system (e.g. 'EPSG:4326').
        out_EPSG: EPSG string of the target reference system.
        in_x, in_y: coordinates of the point in the source reference system.

    Returns:
        The (x, y) tuple of the transformed point.

    Raises:
        Exception: if the EPSG code of in_EPSG or out_EPSG cannot be determined.

    NOTE(review): the point is created as (in_x, in_y) and the reference systems are imported
    without setting an axis mapping strategy. With GDAL 3 and later some EPSG definitions
    (e.g. EPSG:4326) presumably use the authority's lat/lon axis order - confirm that callers
    pass the coordinates in the order the installed GDAL version expects.
    """
    # Validate both codes before any geometry or spatial reference object is created.
    in_EPSG_int = _epsg_string_to_int(in_EPSG)
    out_EPSG_int = _epsg_string_to_int(out_EPSG)

    point = ogr.Geometry(ogr.wkbPoint)
    point.AddPoint(in_x, in_y)

    in_spatial_ref = osr.SpatialReference()
    in_spatial_ref.ImportFromEPSG(in_EPSG_int)

    out_spatial_ref = osr.SpatialReference()
    out_spatial_ref.ImportFromEPSG(out_EPSG_int)

    coord_transform = osr.CoordinateTransformation(in_spatial_ref, out_spatial_ref)

    # Transforms the point in place.
    point.Transform(coord_transform)

    return point.GetX(), point.GetY()

def reference_warp(input_file, reference_file, output_file, shell, resampling_method='near',
                   deflate=True):
    """
    Reference warp the input file to the output file. The geometrical information (extent, pixel
    resolution and spatial reference) is extracted from the reference file.

    Parameters:
        input_file: path of the raster to warp.
        reference_file: path of the raster that defines the target geometry.
        output_file: path of the raster written by gdalwarp.
        shell: object whose run(cmd) method is called with the gdalwarp command string.
        resampling_method: value passed to the gdalwarp '-r' option.
        deflate: if True, '-co COMPRESS=DEFLATE' is added to the command.

    Raises:
        Exception: if the reference file cannot be opened, has no spatial reference set, or no
            EPSG code can be derived from its spatial reference.

    NOTE(review): the file names are inserted into the command unquoted; whether paths with
    spaces work depends on how shell.run() tokenizes the command - confirm against the shell
    implementation.
    """
    # Open the reference dataset.
    refDataset = gdal.Open(reference_file, gdal.GA_ReadOnly)
    if not refDataset:
        raise Exception("Cannot open reference file '%s'." % reference_file)

    # Get some information from the dataset.
    refGeoTransform = refDataset.GetGeoTransform()
    refProjection = refDataset.GetProjection()
    refXSize = refDataset.RasterXSize
    refYSize = refDataset.RasterYSize

    # Everything needed has been read; close the dataset before running the command.
    refDataset = None

    refCoorHelper = CoordinateHelper.CreateFromGeoTransForm(refGeoTransform, refXSize, refYSize)

    # Get the EPSG from the reference file.
    if not refProjection:
        raise Exception('The reference image does not have a spatial reference set.')
    epsg = convert_WKT_to_EPSG(refProjection)
    if not epsg:
        # Without this check the command would contain the literal '-t_srs None'.
        raise Exception("Could not determine the EPSG code of the reference file '%s'."
                        % reference_file)

    xMin, yMin, xMax, yMax = refCoorHelper.bbox

    # Build the command.
    cmd = 'gdalwarp -overwrite'
    # Cut the output to the extent of the reference file.
    cmd += ' -te %.15f %.15f %.15f %.15f' % (xMin, yMin, xMax, yMax)
    # Use the pixel resolution of the reference file (gdalwarp expects positive values).
    cmd += ' -tr %.15f %.15f ' % (np.fabs(refCoorHelper.xRes), np.fabs(refCoorHelper.yRes))
    # Set the resampling method.
    cmd += ' -r %s' % resampling_method
    # Set the target spatial reference.
    cmd += ' -t_srs %s' % epsg
    # Set compression.
    if deflate:
        cmd += ' -co COMPRESS=DEFLATE'
    # Add input and output file.
    cmd += ' %s %s' % (input_file, output_file)

    # Run the command.
    shell.run(cmd)
