#################################################################################################################################################################################################################################################
# AUTHOR: Matthias Maier
# Task: Calculate the transport distance between points
# Comment: These functions work but scale badly and are therefore not used. OpenRouteService is used instead.
#################################################################################################################################################################################################################################################


######################################################################################################################################################
# IMPORT

import geopandas as gpd
import networkx as nx
from shapely.ops import unary_union, nearest_points
from shapely.geometry import Point, LineString
import numpy as np
import pandas as pd
from input_calculations.X_GLOBAL_SCRIPTS import helper_functions as hf
######################################################################################################################################################


######################################################################################################################################################
def _extract_geometry_data(df_exploded_row_geometry):
    """Describe a line by its two end vertices and its length.

    :param df_exploded_row_geometry: Line geometry exposing ``coords`` (sequence of coordinate tuples) and ``length``
    :return: List ``[start, end, length]`` where start/end are the first/last vertex written as comma separated strings (e.g. '2.0,0.0') and length is the line's ``length`` attribute
    """

    vertices = df_exploded_row_geometry.coords
    first_vertex, last_vertex = vertices[0], vertices[-1]  # e.g. (2.0, 0.0) and (3.0, 3.0)

    # One string per end vertex, all ordinates joined by commas -> '2.0,0.0' and '3.0,3.0'
    vertex_keys = [','.join(map(str, vertex)) for vertex in (first_vertex, last_vertex)]

    return [*vertex_keys, df_exploded_row_geometry.length]
######################################################################################################################################################


######################################################################################################################################################
def calculate_Elveg2_route(return_dict, start_point, end_point, df_transport_network, plot_name = None):

    """Function to calculate the shortest distance along the transport network between two points
    :param return_dict: Empty multiprocessing.manager.dict if function is called from multiprocessing, else None. If given, the results are also stored under the keys 'distance' and 'route'
    :param start_point: Start point (must be a vertex of a linestring in the dataset) as a shapely point in CRS 25833
    :param end_point: End point (must be a vertex of a linestring in the dataset) as a shapely point in CRS 25833
    :param df_transport_network: Transport network (Linestrings for roads). Must contain the column 'fartsgrenseVerdi'. The passed object itself is not modified
    :param plot_name: If specified, plot transport network and the start/end point (no file ending)
    :return: distance, route. distance is the summed length of the road segments on the shortest route (np.inf if no route exists); route is a LineString through the junctions of that route (None if no route exists)
    :raises AssertionError: If start or end point is not a vertex of the network, if both points are equal, or if the closest linestring is 10 or more map units away
    :raises ValueError: If a point is not a vertex of any linestring (only reachable when assertions are disabled)
    """

    # All vertices of all linestrings as coordinate tuples. A set makes the membership tests below cheap
    # and, unlike set.union(*[...]), also works for an empty network
    all_nodes_in_network = {coord for el in df_transport_network.geometry for coord in el.coords}

    assert start_point.coords[0] in all_nodes_in_network, 'Start point not contained in network!'
    assert end_point.coords[0] in all_nodes_in_network,  'End point not contained in network!'
    assert start_point != end_point, 'Start and end must be different points'

    # Graph node names of start and end point. They are built exactly like the segment end points in
    # _extract_geometry_data (all ordinates, comma separated), so the names match for 2D and 3D data
    node_start = ','.join(str(c) for c in start_point.coords[0])
    node_end = ','.join(str(c) for c in end_point.coords[0])

    # Index of the new split linestring
    idx_split = 'split_I'

    # Cut linestring into 2 pieces where start_point/end_point is if start_point/end_point is in the middle of a linestring
    for point in [start_point, end_point]:
        point_coords = point.coords[0]

        # Row positions ordered by distance to the point (closest first)
        distances = df_transport_network.distance(point).to_numpy()
        positions_by_distance = np.argsort(distances, kind='stable')
        assert distances[positions_by_distance[0]] < 10

        # Take the closest linestring that really has the point as a vertex. Several lines can tie at distance 0
        # (e.g. a road passing through the point without a vertex there); such a line cannot be split by vertex index
        for position in positions_by_distance:
            linestring_coords = list(df_transport_network.geometry.iloc[position].coords)
            if point_coords in linestring_coords:
                break
        else:
            raise ValueError('Point is not a vertex of any linestring in the network')

        if point_coords == linestring_coords[0] or point_coords == linestring_coords[-1]:
            continue  # If point is in the start or end of the linestring, then everything is fine

        # Linestring has to be split as point is in the middle of it (both pieces share the point as a vertex)
        index_of_point = linestring_coords.index(point_coords)
        linestring_before = LineString(linestring_coords[:index_of_point+1])
        linestring_after = LineString(linestring_coords[index_of_point:])

        speed_limit = df_transport_network.iloc[position]['fartsgrenseVerdi']

        # Remove linestring (by row position, creating a new object so the caller's data stays untouched) ...
        keep_row = np.ones(len(df_transport_network), dtype=bool)
        keep_row[position] = False
        df_transport_network = df_transport_network[keep_row]

        # ... and add the split up linestring (2 pieces) instead
        for ls in [linestring_after, linestring_before]:
            new_row = pd.DataFrame({'fartsgrenseVerdi':[speed_limit], 'geometry':[ls]}, index=[idx_split])
            df_transport_network = pd.concat([df_transport_network, new_row], axis='index')
            idx_split = idx_split + 'I'

    # Union the network to insert cross-connections at every intersection. The parts of the union are the
    # line segments from and to an intersection (a network that unions to one single line has no parts)
    union = unary_union(df_transport_network.geometry)
    segments = getattr(union, 'geoms', [union])

    # Fill the graph with edges; nodes for start and end of each segment are created automatically
    G = nx.Graph()
    for segment in segments:
        segment_start, segment_end, segment_length = _extract_geometry_data(segment)

        # nx.Graph stores one edge per node pair: keep the shortest of parallel segments instead of the last one added
        if not G.has_edge(segment_start, segment_end) or segment_length < G[segment_start][segment_end]['length']:
            G.add_edge(segment_start, segment_end, length=segment_length)

    # Calculate the shortest path for this one pair of nodes only (no all-pairs computation)
    try:
        shortest_path_nodes = nx.dijkstra_path(G, node_start, node_end, weight='length')

    except (nx.NetworkXNoPath, nx.NodeNotFound):
        # End node unreachable or start/end node missing from the graph
        print('\t\tElveg2: No feasible path found')
        distance = np.inf
        shortest_path = None

    else:
        distance = nx.path_weight(G=G, path=shortest_path_nodes, weight='length')
        shortest_path = LineString([tuple(float(c) for c in node.split(',')) for node in shortest_path_nodes])

        # INFO: The shortest path is a linestring going through all the junctions but does not follow the road segments. However, the distance is equal to the combined distance of all road segments, since it comes from the graph!

    # Plot the transport network (black) and the start/end point plus the chosen path, if any (red)
    if plot_name is not None:
        plot_geometries = [start_point, end_point]
        if shortest_path is not None:
            plot_geometries.append(shortest_path)

        hf.plot_multiple_geodataframes([df_transport_network, gpd.GeoDataFrame(crs=df_transport_network.crs, geometry=plot_geometries)], add_basemap=False, name_saved_plot=plot_name, color_list=['k', 'r'])

    if return_dict is not None:
        return_dict['distance'] = distance
        return_dict['route'] = shortest_path

    return distance, shortest_path
######################################################################################################################################################


######################################################################################################################################################
def get_road_subset(df_road, center_point, distance_treshold):
    """Function to calculate a subset of the whole road transport network given a center and a maximum radius around

    :param df_road: Road network (GeoDataFrame); its ``distance`` method is used, which works on the active geometry column
    :param center_point: Center of the search area (shapely geometry)
    :param distance_treshold: Rows whose geometry is strictly closer to center_point than this value are kept
    :return: The matching rows of df_road in their original order (empty if no row is close enough)
    """

    # Distance of every row to the center point in one vectorised call instead of a Python loop over rows
    distances = df_road.distance(center_point)

    # Positional boolean mask: selecting by position instead of by label cannot duplicate rows when index labels repeat
    is_within_threshold = (distances < distance_treshold).to_numpy()

    return df_road.loc[is_within_threshold, :]
######################################################################################################################################################