#################################################################################################################################################################################################################################################
# AUTHOR: Matthias Maier
# Task: Create matrix for one-way shipping distances between shipping terminals
# Unit: km
#################################################################################################################################################################################################################################################

# IMPORT
from pathlib import Path
import geopandas as gpd
import os
import folium
import webbrowser
from matplotlib import pyplot as plt
import networkx as nx
import pandas as pd
import searoute as sr
from pyproj import Transformer
from shapely import Point
from shapely.ops import transform as shapely_transform
import shapely
import numpy as np
from shapely.ops import nearest_points
pd.options.mode.chained_assignment = None


# Export terminals for which the first vertex of the searoute is dropped before looking for the connector.
# Usually the nearest point on the searoute is the best connector, but these places have a unique location
# and the searoute has to be trimmed.
_TERMINALS_WITH_TRIMMED_SEAROUTE = ('Steinkjer', 'Verdalsøra', 'Tromsø')


def _nearest_sea_point(point_25833, gdf_coastline):
    """
    Move a (land) point onto the nearest sea polygon
    :param point_25833: Point in EPSG:25833
    :param gdf_coastline: Geodataframe of sea polygons (joined against EPSG:25833 data, so expected in that CRS)
    :return: The point of the nearest sea polygon that is closest to point_25833
    """
    gdf_point = gpd.GeoDataFrame(geometry=[point_25833], crs='EPSG:25833')
    gdf_nearest = gpd.sjoin_nearest(gdf_point, gdf_coastline, how='left', distance_col='distance_to_coast')
    # Equidistant polygons yield several rows for the single input point; use the first one
    index_nearest_polygon = gdf_nearest['index_right'].iloc[0]
    nearest_polygon = gdf_coastline.loc[index_nearest_polygon, 'geometry']  # The sea is represented by polygons
    return nearest_points(nearest_polygon, point_25833)[0]


def _build_water_mesh(origin, connector, gdf_coastline, assume_complex_coastline):
    """
    Build a regular grid of points around origin and connector and keep only the points lying in the sea
    :param origin: Start point (EPSG:25833)
    :param connector: End point on the searoute (EPSG:25833)
    :param gdf_coastline: Geodataframe of sea polygons
    :param assume_complex_coastline: If True, use a larger search area and a finer grid
    :return: List of grid points (EPSG:25833) inside a sea polygon, each point at most once
    """
    radius_factor = 1.5 if assume_complex_coastline else 0.6
    resolution = 1000 if assume_complex_coastline else 2000  # Grid spacing in CRS units (metres for EPSG:25833)

    # Circular search area centred between origin and connector
    middle_line = shapely.LineString([origin, connector])
    middle_point = middle_line.interpolate(middle_line.length / 2)
    search_area = middle_point.buffer(middle_line.length * radius_factor)
    xmin, ymin, xmax, ymax = search_area.bounds

    candidates = [
        Point(round(x, 4), round(y, 4))
        for x in np.arange(xmin, xmax, resolution)
        for y in np.arange(ymin, ymax, resolution)
    ]
    candidates = [point for point in candidates if search_area.contains(point)]

    gdf_in_water = gpd.sjoin(gpd.GeoDataFrame(geometry=candidates, crs='EPSG:25833'), gdf_coastline, how='inner')
    # A point touching several sea polygons is returned once per match; keep only one copy so duplicates
    # (at distance 0) do not take up nearest-neighbour slots in the graph
    gdf_in_water = gdf_in_water[~gdf_in_water.index.duplicated()]
    # TODO: optionally thin out the mesh when it becomes very large
    return gdf_in_water.geometry.tolist()


def _build_nearest_neighbour_graph(gdf_points, num=3):
    """
    Connect every point to its `num` nearest other points
    :param gdf_points: Geodataframe with unique index labels; the labels become the graph nodes
    :param num: Number of nearest neighbours (= edges added per node)
    :return: Undirected networkx graph whose edge weights are the point distances rounded to 2 decimals
    """
    graph = nx.Graph()
    geometries = gdf_points.geometry
    for node, point in geometries.items():
        # Vectorised distances to all other points (replaces one spatial join per node)
        distances = geometries.distance(point).drop(index=node)
        for neighbour, dist in distances.nsmallest(num).items():
            graph.add_edge(node, neighbour, weight=float(np.round(dist, 2)))
    return graph


def create_route_for_A2B(terminal_export_4326, terminal_import_4326, terminal_export_name, terminal_import_name, gdf_coastline, pdir, show_graph=False, assume_complex_coastline=False, save_path=False):
    """
    Create the shipping route from export terminal A to import terminal B

    The open-sea part comes from searoute. The part from the export terminal to the searoute is the shortest path
    through a nearest-neighbour graph of grid points lying in the sea. An HTML map of the route is saved to
    <pdir>/edges_shipping_transport/plots/route_<export>_<import>[_error].html.

    :param terminal_export_4326: The point for the export terminal in EPSG: 4326 ([lon, lat], needs .tolist())
    :param terminal_import_4326: The point for the import terminal in EPSG: 4326 ([lon, lat], needs .tolist())
    :param terminal_export_name: The name of the export terminal
    :param terminal_import_name: The name of the import terminal
    :param gdf_coastline: The geodataframe for the coastline. Used to find the exact route along the coast / fjord
    :param pdir: Path to the parent directory (i.e., input_calculations)
    :param show_graph: Boolean stating if the graph for constructing the detailed route along the coast should be shown. A map of the route is saved independent of this
    :param assume_complex_coastline: Boolean stating if the coastline is complex. Increases the odds of finding an exact route but requires more time
    :param save_path: If the complete route profile should be saved or only the value for the route distance
    :return: (Shipping distance [km], LineString for route in EPSG:25833 or None); (None, None) if no route was found
    """

    # Setup
    transformer_25833_to_4326 = Transformer.from_crs("EPSG:25833", "EPSG:4326")  # Output order: (lat, lon)
    transformer_4326_to_25833 = Transformer.from_crs("EPSG:4326", "EPSG:25833", always_xy=True)  # Input order: (lon, lat)

    def to_25833(lon_lat):
        """Convert a [lon, lat] pair to a Point in EPSG:25833"""
        return shapely_transform(transformer_4326_to_25833.transform, Point(lon_lat))

    def to_lat_lon(point_25833):
        """Convert an EPSG:25833 point to the (lat, lon) tuple folium expects"""
        return tuple(transformer_25833_to_4326.transform(point_25833.x, point_25833.y))

    map_i = folium.Map(tiles="cartodb positron")

    # Find approximate route (searoute)
    origin_on_land = terminal_export_4326.tolist()
    destination_germany = terminal_import_4326.tolist()
    route_on_open_sea = sr.searoute(origin_on_land, destination_germany)

    ### Find detailed route from export terminal to connector on searoute ###

    # 1. Find point for export terminal which is in the sea
    origin = _nearest_sea_point(to_25833(origin_on_land), gdf_coastline)

    # 2. Find connector to searoute (the point on the searoute closest to the origin)
    searoute_coords = route_on_open_sea.geometry.coordinates
    if terminal_export_name in _TERMINALS_WITH_TRIMMED_SEAROUTE:
        searoute_coords = searoute_coords[1:]
    route_on_open_sea_25833 = shapely.LineString([to_25833(el) for el in searoute_coords])
    connector_on_searoute = nearest_points(route_on_open_sea_25833, origin)[0]

    # Remove the tail of the searoute (only keep the part from the connector towards the destination)
    route_on_open_sea_25833 = shapely.ops.split(route_on_open_sea_25833, connector_on_searoute.buffer(100)).geoms[-1]

    # 3. Find path from origin to connector through a graph of points in the sea
    valid_points = _build_water_mesh(origin, connector_on_searoute, gdf_coastline, assume_complex_coastline)
    # Origin and connector get the fixed node labels -1 and -2 used by the shortest-path search below
    gdf_valid_points = gpd.GeoDataFrame(
        {'geometry': valid_points + [origin, connector_on_searoute]},
        index=list(range(len(valid_points))) + [-1, -2],
        crs='EPSG:25833',
    )
    G = _build_nearest_neighbour_graph(gdf_valid_points, num=3)

    if show_graph:
        nx.draw(G, with_labels=True)
        plt.show()

    # 3.e Find the shortest route through graph
    error = False
    try:
        path_at_coast = nx.dijkstra_path(G, -1, -2)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        error = True
        print('An error occured in the route calculation from {} to {}'.format(terminal_export_name, terminal_import_name))
        if not assume_complex_coastline:
            print('\tERROR: No route could be found. Trying increasing number of points')
        else:
            print('\tERROR: No route could be found despite more points')
            # Show every sea point on the map to help diagnose the missing connection
            for point in gdf_valid_points.geometry:
                folium.Marker(location=to_lat_lon(point), icon=folium.Icon(color='black', icon_color='black')).add_to(map_i)
    else:
        path_at_coast_nodes = gdf_valid_points.loc[path_at_coast, 'geometry'].tolist()
        path_at_coast_length = shapely.LineString(path_at_coast_nodes).length / 1000  # [km]
        folium.PolyLine([to_lat_lon(point) for point in path_at_coast_nodes]).add_to(map_i)

    folium.Marker(location=to_lat_lon(origin), tooltip='origin').add_to(map_i)
    folium.Marker(location=to_lat_lon(connector_on_searoute), tooltip='connector on searoute').add_to(map_i)
    folium.PolyLine(locations=[to_lat_lon(Point(xy)) for xy in route_on_open_sea_25833.coords], color='black').add_to(map_i)

    # Include both terminals in the file name so routes to different import terminals do not overwrite each other
    suffix = '_error' if error else ''
    plot_name = 'route_{}_{}{}'.format(terminal_export_name, terminal_import_name, suffix)
    plot_file = os.path.join(pdir, 'edges_shipping_transport', 'plots', plot_name + '.html')
    map_i.save(plot_file)

    if error:
        if assume_complex_coastline:
            webbrowser.open_new_tab(plot_file)
        return None, None

    route_on_open_sea_length = route_on_open_sea_25833.length / 1000  # [km]
    distance = route_on_open_sea_length + path_at_coast_length
    print("Route from {} to {}: {:.1f} km || {:.1f} km".format(terminal_export_name, terminal_import_name, path_at_coast_length, route_on_open_sea_length))

    if not save_path:
        return distance, None

    # Complete route: coastal path, remaining searoute, then the import terminal itself
    complete_route_nodes = list(path_at_coast_nodes)
    complete_route_nodes.extend(Point(xy) for xy in route_on_open_sea_25833.coords)
    complete_route_nodes.append(to_25833(destination_germany))
    return distance, shapely.LineString(complete_route_nodes)


def _matrix_file(pdir, edge_name, kind):
    """Return the feather file for the `kind` ('distance' or 'path') output of `edge_name`"""
    return os.path.join(pdir, 'edges_shipping_transport', 'output_data', '{}_{}.feather'.format(edge_name, kind))


def create_shipping_matrix(gdf_shipping_terminals_export, gdf_shipping_terminals_import, edge_name, gdf_coastline, pdir, append=True, save_path=False):
    """
    Compute the shipping distance (and optionally the route geometry) for every export/import terminal pair and
    store the results as feather files in <pdir>/edges_shipping_transport/output_data/

    :param gdf_shipping_terminals_export: Geodataframe for the export shipping terminals (index = names, column 'geometry_4326')
    :param gdf_shipping_terminals_import: Geodataframe for the import shipping terminals (index = names, column 'geometry_4326')
    :param edge_name: Name of the shipping edge (e.g., edge_data_06_shipping)
    :param gdf_coastline: The geodataframe for the coastline. Used to find the exact route along the coast / fjord
    :param pdir: Path to the parent directory (i.e., input_calculations)
    :param append: Boolean if the matrix should be appended to an existing matrix
    :param save_path: If the complete route profile should be saved or only the value for the route distance
    :raises AssertionError: If no route could be found for at least one pair (nothing is written in that case)
    :return: None
    """
    distance_file = _matrix_file(pdir, edge_name, 'distance')
    path_file = _matrix_file(pdir, edge_name, 'path')

    # Set up calculation
    if append and not os.path.exists(distance_file):
        print('Could not find a file to append')
        append = False

    df_matrix_distance = None
    gdf_matrix_path = None
    existing_paths = set()
    if append:
        print('Reading existing file to append')
        df_matrix_distance = pd.read_feather(distance_file)
        # The path file is only written by runs with save_path=True, so it may be missing
        if save_path and os.path.exists(path_file):
            gdf_matrix_path = gpd.read_feather(path_file)
            existing_paths = set(gdf_matrix_path.index)

    def is_already_computed(export_name, import_name):
        """True if the pair has a stored distance (and a stored path when save_path is set)"""
        if not append:
            return False
        if export_name not in df_matrix_distance.index or import_name not in df_matrix_distance.columns:
            return False
        if pd.isna(df_matrix_distance.loc[export_name, import_name]):
            return False  # Empty cell left by an earlier append of other terminals
        return not save_path or (export_name, import_name) in existing_paths

    matrix_distance = {}  # [km]
    dict_path = {}
    failed_routes = []

    # Iterate over all indices
    for terminal_import_name in gdf_shipping_terminals_import.index:
        terminal_import_4326 = gdf_shipping_terminals_import.loc[terminal_import_name, 'geometry_4326']

        for terminal_export_name in gdf_shipping_terminals_export.index:
            if is_already_computed(terminal_export_name, terminal_import_name):
                continue

            terminal_export_4326 = gdf_shipping_terminals_export.loc[terminal_export_name, 'geometry_4326']

            distance, path = create_route_for_A2B(terminal_export_4326, terminal_import_4326, terminal_export_name, terminal_import_name, gdf_coastline, pdir, save_path=save_path)

            # Try again with more points (time intensive)
            if distance is None:
                distance, path = create_route_for_A2B(terminal_export_4326, terminal_import_4326, terminal_export_name, terminal_import_name, gdf_coastline, pdir, assume_complex_coastline=True, save_path=save_path)

            if distance is None:
                failed_routes.append((terminal_export_name, terminal_import_name))

            matrix_distance[terminal_export_name, terminal_import_name] = distance

            if save_path:
                dict_path[(terminal_export_name, terminal_import_name)] = path

    if failed_routes:
        print('ERROR: Calculation finished, but some routes are not feasible')
        # Raised explicitly (not via `assert`) so the check also runs under `python -O`;
        # the AssertionError type is kept for callers of the previous version
        raise AssertionError('No shipping route could be found for: {}'.format(failed_routes))

    if save_path:
        df_path = pd.DataFrame.from_dict(dict_path, orient="index", columns=["geometry"])
        df_path.index = pd.MultiIndex.from_tuples(df_path.index, names=["source", "destination"])
        gdf_matrix_path_new = gpd.GeoDataFrame(df_path, crs='EPSG:25833')

    if append:
        for (export_terminal, import_terminal), distance in matrix_distance.items():
            df_matrix_distance.loc[export_terminal, import_terminal] = distance

        # Write once after all updates instead of once per entry
        if matrix_distance:
            df_matrix_distance.to_feather(distance_file)

        if save_path and dict_path:
            if gdf_matrix_path is not None:
                # Replace recomputed routes instead of duplicating them
                gdf_matrix_path_kept = gdf_matrix_path[~gdf_matrix_path.index.isin(gdf_matrix_path_new.index)]
                gdf_matrix_path_new = pd.concat([gdf_matrix_path_kept, gdf_matrix_path_new])
            gdf_matrix_path_new.to_feather(path_file)

    else:
        df_matrix_distance = pd.Series(matrix_distance).unstack()
        df_matrix_distance.to_feather(distance_file)

        if save_path:
            gdf_matrix_path_new.to_feather(path_file)


if __name__ == '__main__':
    # Parent directory (input_calculations), with a trailing path separator for the current OS
    pdir = str(Path(__file__).resolve().parent.parent) + os.sep

    # Inputs are resolved relative to pdir (not the current working directory) so the script can be started from anywhere
    gdf_coastline = gpd.read_feather(os.path.join(pdir, 'X_GLOBAL', 'data_processed', 'dataset_coastline.feather'))
    gdf_cH2_shipping_terminals_export = gpd.read_feather(os.path.join(pdir, 'nodes_tertiary_cH2_shipping_terminal_export', 'output_data', 'node_data_05_cH2_shipping_terminals_export.feather'))
    gdf_cH2_shipping_terminals_import = gpd.read_feather(os.path.join(pdir, 'nodes_tertiary_cH2_shipping_terminal_import', 'output_data', 'node_data_07_cH2_shipping_terminals_import.feather'))

    # Restrict to the first five terminals of each kind (test run)
    gdf_cH2_shipping_terminals_export = gdf_cH2_shipping_terminals_export.iloc[0:5, :]
    gdf_cH2_shipping_terminals_import = gdf_cH2_shipping_terminals_import.iloc[0:5, :]

    create_shipping_matrix(gdf_cH2_shipping_terminals_export, gdf_cH2_shipping_terminals_import, 'test', gdf_coastline, pdir, append=False, save_path=False)