#################################################################################################################################################################################################################################################
# AUTHOR: Matthias Maier
# Task: Useful functions for plotting geodataframes
#################################################################################################################################################################################################################################################

######################################################################################################################################################
# IMPORT
import matplotlib.pyplot as plt
import contextily as cx
import webbrowser
import os
######################################################################################################################################################


######################################################################################################################################################
def plot_selected_geodataframe(df, name_saved_plot, folder='', title=None, interactive=True, df_baselayer=None):
    """Function to create an (interactive) plot of a geodataframe"""

    if title is None:
        title = name_saved_plot

    if interactive:
        map = df.explore(color='k', radius=5, weight=10)

        if df_baselayer is not None:
            map = df_baselayer.explore(color='#88db8c', m=map, opacity=0.7)

        map.save(folder + name_saved_plot + '.html')
        webbrowser.open_new_tab(os.getcwd() + '/' + folder + name_saved_plot + ".html")
    else:
        ax = df.plot(figsize=(20, 20), linewidth = 2, edgecolor = 'k', color='k')
        if df_baselayer is not None:
            ax = df_baselayer.plot(ax=ax, color='#88db8c', opacity=0.7)

        cx.add_basemap(ax, crs=df.crs, source=cx.providers.Esri.WorldTerrain)

        plt.title(title)
        plt.xticks([])
        plt.yticks([])
        plt.xlabel('Longitude')
        plt.ylabel('Latitude')
        plt.savefig(folder + name_saved_plot + '.png', dpi=600)
        plt.show()
######################################################################################################################################################


######################################################################################################################################################
def plot_multiple_geodataframes(df_list, name_saved_plot, folder='', title=None, add_basemap=True, color_list = None):
    """Function to create a plot of one or more geodataframes"""

    fig, ax = plt.subplots(figsize=(20, 20))
    crs = df_list[0].crs
    color_i = 0

    for df in df_list:
        assert df.crs == crs, 'Dataframes have different crs!'

        color = 'k' if color_list is None else color_list[color_i]
        ax = df.plot(ax=ax, linewidth=2, edgecolor=color, color=color)
        color_i += 1

    if add_basemap:
        cx.add_basemap(ax, crs=crs, source=cx.providers.Esri.WorldTerrain)

    if title is None:
        title = name_saved_plot

    plt.title(title)
    plt.xticks([])
    plt.yticks([])
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.savefig(folder + name_saved_plot + '.png', dpi=600)
######################################################################################################################################################