#!/usr/bin/env python

import numpy as np
import pandas as pd
import matplotlib.dates
import matplotlib.pyplot as plt


def _format_ts_coord(x, y):
    """Render a cursor position on a date x-axis for the status bar.

    ``x`` is a Matplotlib date number and ``y`` the plotted value; the
    result has the form ``'x=YYYY-MM-DD y=0.1234'``. The timezone that
    ``num2date`` attaches is dropped before formatting.
    """
    when = matplotlib.dates.num2date(x).replace(tzinfo=None)
    day = when.strftime('%Y-%m-%d')
    return 'x={} y={:0.4f}'.format(day, y)


def plot_cropsar(ax, cropsar, s2=None, whittaker=None,
                 s2_product=None, margin=False, legend=False):
    """Plot a CropSAR timeseries on ``ax``, optionally with its inputs.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    cropsar : DataFrame with ``q10``, ``q50`` and ``q90`` columns indexed by
        date. ``q50`` is drawn as a line and ``q10``-``q90`` as a shaded band.
    s2 : optional DataFrame with a ``data`` column. If it also has a
        ``clean`` column, the cleaned values are plotted and the samples
        removed during cleaning (``flag`` 21 = narrowed field, 22 = extended
        field, 51 = local minima) are marked separately; otherwise ``data``
        is plotted as-is.
    whittaker : optional DataFrame with a ``smooth`` column, drawn dotted.
    s2_product : label prefix for the Sentinel-2 series
        (default ``'Sentinel-2 Data'``).
    margin : if False, limit the x axis to the CropSAR date range, or to the
        Whittaker date range when the CropSAR frame is empty.
    legend : if True, draw a legend to the right of the axes.
    """

    if s2_product is None:
        s2_product = 'Sentinel-2 Data'

    # Treat the x axis as dates. This is what the deprecated Axes.plot_date
    # did implicitly before delegating to Axes.plot.

    ax.xaxis_date()

    # Plot the main CropSAR timeseries

    ax.plot(cropsar.q50.index, cropsar.q50.values, 'C0-',
            label='CropSAR q50')

    if whittaker is not None:

        # Plot the whittaker smoothed timeseries

        ax.plot(whittaker.smooth.index, whittaker.smooth.values, 'C0:',
                label=s2_product + ' smoothed (whittaker)')

    if s2 is not None and 'clean' in s2.columns:

        # Plot the cleaned Sentinel-2 timeseries, plus the samples that the
        # cleaning step removed, grouped by removal reason (flag value)

        removed_nar = s2.loc[s2.flag == 21].data
        removed_ext = s2.loc[s2.flag == 22].data
        removed_min = s2.loc[s2.flag == 51].data

        ax.plot(s2.clean.index, s2.clean.values, 'C0.',
                mew=1.0, label=s2_product)

        ax.plot(removed_nar.index, removed_nar.values, 'C2+',
                mew=1.0, label=s2_product + ' removed (narrowed field)')

        ax.plot(removed_ext.index, removed_ext.values, 'C3+',
                mew=1.0, label=s2_product + ' removed (extended field)')

        ax.plot(removed_min.index, removed_min.values, 'C4+',
                mew=1.0, label=s2_product + ' removed (local minima)')

    elif s2 is not None:

        # No cleaning info available: plot the raw Sentinel-2 timeseries

        ax.plot(s2.data.index, s2.data.values, 'C0.',
                mew=1.0, label=s2_product)

    # Plot the CropSAR uncertainty 'range'

    ax.fill_between(cropsar.q50.index,
                    cropsar.q10.values.flatten(),
                    cropsar.q90.values.flatten(),
                    color='C0', alpha=0.3,
                    label='CropSAR q10-q90')

    # Expect values to be in range [0, 1] but add some margin

    ax.set_ylim(-0.2, 1.2)
    ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

    # Optionally crop the x axis to the CropSAR date range, because the
    # Sentinel-2 input normally extends beyond it. Fall back to the
    # Whittaker range when there is no CropSAR data; whittaker is optional,
    # so it must be checked for None before use.

    if not margin:
        if not cropsar.empty:
            ax.set_xlim(cropsar.index[0], cropsar.index[-1])
        elif whittaker is not None and not whittaker.empty:
            ax.set_xlim(whittaker.index[0], whittaker.index[-1])

    # Show a label only every 3 months, but gridlines every month

    ax.xaxis.set_major_locator(matplotlib.dates.MonthLocator(bymonth=[1, 4, 7, 10]))
    ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter('%Y-%m'))
    ax.xaxis.set_minor_locator(matplotlib.dates.MonthLocator())
    ax.format_coord = _format_ts_coord
    ax.grid(True, which='minor')
    ax.grid(True)

    # Optionally draw a legend to the right of the plot

    if legend:
        ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0))


def plot_vv_vh(ax, s1, s1_product=None, legend=False):
    """Plot Sentinel-1 VV and VH backscatter timeseries on ``ax``.

    Both the raw samples and a smoothed curve are drawn for each channel.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    s1 : DataFrame with ``vv`` and ``vh`` columns indexed by date. The values
        are treated as decibels: they are converted to linear power with
        ``10 ** (x / 10)`` before smoothing.
    s1_product : label prefix for the series (default ``'Sentinel-1 Data'``).
    legend : if True, draw a legend to the right of the axes.
    """

    if s1_product is None:
        s1_product = 'Sentinel-1 Data'

    # Smooth with a centred 7-sample rolling mean. The averaging is done in
    # linear power and converted back to dB, because averaging dB values
    # directly gives a geometric rather than an arithmetic mean. With the
    # default min_periods, incomplete windows (the first/last 3 samples, or
    # any window containing NaN) produce NaN.

    s1_vv_lin = np.power(10.0, s1.vv / 10.0)
    s1_vh_lin = np.power(10.0, s1.vh / 10.0)

    s1_vv_sm_lin = s1_vv_lin.rolling(7, center=True).mean()
    s1_vh_sm_lin = s1_vh_lin.rolling(7, center=True).mean()

    s1_vv_sm = 10.0 * np.log10(s1_vv_sm_lin)
    s1_vh_sm = 10.0 * np.log10(s1_vh_sm_lin)

    # Treat the x axis as dates. This replaces the implicit call made by
    # the deprecated Axes.plot_date.

    ax.xaxis_date()

    # Plot the original and smoothed timeseries

    ax.plot(s1.vv.index, s1.vv.values,
            'C1.', mew=0.3, label=s1_product + ' vv')
    ax.plot(s1_vv_sm.index, s1_vv_sm.values,
            'C1-', label=s1_product + ' vv (smoothed)')

    ax.plot(s1.vh.index, s1.vh.values,
            'C2.', mew=0.3, label=s1_product + ' vh')
    ax.plot(s1_vh_sm.index, s1_vh_sm.values,
            'C2-', label=s1_product + ' vh (smoothed)')

    # Expect values to be in range [-25, 0] but add some margin

    ax.set_ylim(-30.0, 5.0)
    ax.set_yticks([-25.0, -20.0, -15.0, -10.0, -5.0, 0.0])

    # Show a label only every 3 months, but gridlines every month

    ax.xaxis.set_major_locator(matplotlib.dates.MonthLocator(bymonth=[1, 4, 7, 10]))
    ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter('%Y-%m'))
    ax.xaxis.set_minor_locator(matplotlib.dates.MonthLocator())
    ax.format_coord = _format_ts_coord
    ax.grid(True, which='minor')
    ax.grid(True)

    # Optionally draw a legend to the right of the plot

    if legend:
        ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0))


def plot_rvi(ax, s1, s1_product=None, legend=False):
    """Plot a Sentinel-1 RVI timeseries (derived from VV/VH) on ``ax``.

    RVI is computed per sample as ``4 * VH / (VV + VH)`` on linear power.
    Both the raw values and a smoothed curve are drawn.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    s1 : DataFrame with ``vv`` and ``vh`` columns indexed by date, in dB
        (they are converted to linear power with ``10 ** (x / 10)``).
    s1_product : label prefix for the series (default ``'Sentinel-1 Data'``).
    legend : if True, draw a legend to the right of the axes.
    """

    if s1_product is None:
        s1_product = 'Sentinel-1 Data'

    # Convert the backscatter from dB to linear power

    s1_vv_lin = np.power(10.0, s1.vv / 10.0)
    s1_vh_lin = np.power(10.0, s1.vh / 10.0)

    # Calculate the RVI timeseries, and a smoothed version using a centred
    # 7-sample rolling mean (incomplete windows yield NaN)

    s1_rvi = (4.0 * s1_vh_lin) / (s1_vv_lin + s1_vh_lin)
    s1_rvi_sm = s1_rvi.rolling(7, center=True).mean()

    # Treat the x axis as dates. This replaces the implicit call made by
    # the deprecated Axes.plot_date.

    ax.xaxis_date()

    # Plot the original and smoothed RVI timeseries

    ax.plot(s1_rvi.index, s1_rvi.values, 'C3.', mew=0.3,
            label=s1_product + ' rvi')

    ax.plot(s1_rvi_sm.index, s1_rvi_sm.values, 'C3-',
            label=s1_product + ' rvi (smoothed)')

    # By construction RVI lies in [0, 4). The axis is limited to [0, 1.6],
    # which is presumably enough for typical values (assumption; values
    # above 1.6 are clipped from view).

    ax.set_ylim(0.0, 1.6)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4])

    # Show a label only every 3 months, but gridlines every month

    ax.xaxis.set_major_locator(matplotlib.dates.MonthLocator(bymonth=[1, 4, 7, 10]))
    ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter('%Y-%m'))
    ax.xaxis.set_minor_locator(matplotlib.dates.MonthLocator())
    ax.format_coord = _format_ts_coord
    ax.grid(True, which='minor')
    ax.grid(True)

    # Optionally draw a legend to the right of the plot

    if legend:
        ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0))


def plot(fig, cropsar, s1=None, s2=None, whittaker=None,
         margin=False, s1_product=None, s2_product=None, legend=True):
    """Draw a full CropSAR analysis onto ``fig``.

    The first subplot shows the CropSAR timeseries, optionally together with
    the cleaned Sentinel-2 data and the Whittaker-smoothed series it was
    based on. When Sentinel-1 data is available, two more subplots are
    stacked below it, with a shared x axis: one for VV/VH backscatter and
    one for the RVI.

    ``cropsar`` is either a DataFrame with ``q10``, ``q50`` and ``q90``
    columns, or a dict holding a full CropSAR analysis as produced by
    ``cropsar.retrieve_cropsar_analysis()``. In the dict case, any of
    ``s1``, ``s2``, ``whittaker``, ``s1_product`` and ``s2_product`` that
    were left as None are taken from the analysis.
    """
    if not isinstance(cropsar, pd.DataFrame):
        analysis = cropsar
        metadata = analysis.get('sources', {}).get('.metadata', {})

        s1_product = (metadata.get('s1-product')
                      if s1_product is None else s1_product)
        s2_product = (metadata.get('s2-product')
                      if s2_product is None else s2_product)

        s1 = analysis.get('clean', {}).get('s1-data') if s1 is None else s1
        s2 = analysis.get('clean', {}).get('s2-data') if s2 is None else s2
        whittaker = (analysis.get('whittaker')
                     if whittaker is None else whittaker)

        cropsar = analysis['cropsar']

    has_s1 = s1 is not None
    rows = 3 if has_s1 else 1

    # Top row: CropSAR (+ Sentinel-2 / Whittaker)
    top = fig.add_subplot(rows, 1, 1)
    plot_cropsar(top, cropsar, s2, whittaker, s2_product, margin, legend)

    if not has_s1:
        return

    # Middle row: Sentinel-1 VV/VH
    middle = fig.add_subplot(3, 1, 2, sharex=top)
    plot_vv_vh(middle, s1, s1_product, legend)

    # Bottom row: Sentinel-1 RVI
    bottom = fig.add_subplot(3, 1, 3, sharex=middle)
    plot_rvi(bottom, s1, s1_product, legend)


def _resolve_style(style):
    """Map pre-3.6 seaborn style names onto their current matplotlib names.

    Matplotlib 3.6 renamed its bundled seaborn styles to
    ``seaborn-v0_8-<name>`` and later dropped the old names, so styles like
    ``'seaborn-darkgrid'`` fail to load on current versions. Names that
    matplotlib still knows, and non-string styles (dicts, lists), are
    returned unchanged.
    """
    if not isinstance(style, str) or style in plt.style.available:
        return style

    if style == 'seaborn' or style.startswith('seaborn-'):
        renamed = 'seaborn-v0_8' + style[len('seaborn'):]
        if renamed in plt.style.available:
            return renamed

    return style


def show(cropsar, s1=None, s2=None, whittaker=None,
         margin=False, s1_product=None, s2_product=None,
         legend=True, style='seaborn-darkgrid'):
    """Show a full CropSAR analysis in a new 12x6 inch figure.

    The figure contains:
      - a CropSAR timeseries
      - optionally the cleaned Sentinel-2 data it was based on
      - optionally a Sentinel-1 VV, VH and RVI timeseries

    All arguments except ``style`` are forwarded to ``plot()``. ``style`` is
    a matplotlib style applied while the figure is built and shown. Old
    seaborn style names are translated to their ``seaborn-v0_8-*``
    equivalents when matplotlib no longer ships them under the old name.
    """

    with plt.style.context(_resolve_style(style)):

        # tight_layout=True at construction replaces the discouraged
        # Figure.set_tight_layout(True) and works on old and new matplotlib
        fig = plt.figure(figsize=(12, 6), tight_layout=True)

        plot(fig, cropsar, s1, s2, whittaker,
             margin, s1_product, s2_product, legend)

        plt.show()
