#################################################################################################################################################################################################################################################
# AUTHOR: Matthias Maier
# Task: Calculate the transport distance for road transport
# Comment: This file is run from main and requires docker to run!
# Structure: The transport matrix contains values for transport from A (index value) to B (column value)
# Units:
#     Distance: m
#     Driving time: sec
#     Search radius: m
#     Cost: NOK2024/t
#################################################################################################################################################################################################################################################

# IMPORTS
from pathlib import Path
import geopandas as gpd
import time
import pandas as pd
import numpy as np
import warnings
from tqdm import tqdm
from matplotlib import pyplot as plt
from pyproj import Transformer
from shapely import Point, LineString
from shapely.ops import transform as shapely_transform
from input_calculations.X_GLOBAL_SCRIPTS import network_operations_openrouteservice as no

# Silence pandas' PerformanceWarning: allocating distance and driving time to the dataframe is not time-efficient, but the rest of the calculation takes much longer, so the warning is ignored
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)


def _save_histogram(values, xlabel, output_path, require_spread=False):
    """
    Save a density histogram of the finite, non-negative entries of ``values`` as an image file.

    :param values: Iterable of numeric values (e.g., the values of one of the transport matrix dicts)
    :param xlabel: Label of the x-axis
    :param output_path: File path the figure is written to
    :param require_spread: If True, no figure is written when all remaining values are identical
    :return: True if a figure was written, otherwise False
    """
    data = np.asarray(list(values), dtype=float)
    # Negative entries are dropped as before; non-finite entries are dropped as well because they cannot be binned
    data = data[np.isfinite(data) & (data >= 0)]

    if data.size == 0:
        return False  # min()/max() of an empty array would raise

    lower, upper = data.min(), data.max()
    if require_spread and not upper > lower:
        return False

    fig, ax = plt.subplots()
    try:
        ax.hist(data, bins=50, edgecolor='black', color='grey', density=True)
        if upper > lower:  # identical limits only trigger a matplotlib warning
            ax.set_xlim([lower, upper])
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Occurrence density')
        ax.set_title('Occurrence distribution of selected feature')
        fig.savefig(output_path)
    finally:
        # Release the figure; otherwise figures pile up when run() is called repeatedly
        plt.close(fig)
    return True


def run(RID, gdf_nodes_start, gdf_nodes_destination, edge_name, transform_input_points, pdir, debug=False, plot_distributions=False):
    """
    Function to create transport edges (distance, driving time, route) between nodes

    Every combination of start node and destination node is routed with
    no.calculate_route_between_two_points. The results are written to
    pdir + 'edges_inland_transport/output_data/' as
        {edge_name}_{RID}_distance.feather       (matrix: index = start node, columns = destination node)
        {edge_name}_{RID}_driving_time.feather   (matrix: index = start node, columns = destination node)
        {edge_name}_{RID}_driving_route.feather  (route geometry per (source, destination) in EPSG:25833)
    Pairs without a route geometry (infeasible route or no route returned) get an empty LineString.

    :param RID: Run-ID. If run locally, RID = 'test'
    :param gdf_nodes_start: Geodataframe for nodes where the transport edge starts. The point must be in the geometry!
    :param gdf_nodes_destination: Geodataframe for nodes where the transport edge ends. The point must be in the geometry!
    :param edge_name: Name of the edge, e.g., 02_timber_to_gasification_plant
    :param transform_input_points: Points should be in EPSG:25833 Point(latitude, longitude) if transform==True else EPSG:4326 tuple(longitude, latitude), e.g., (False, False)
    :param pdir: Path to the parent directory (i.e., input_calculations) as string including the trailing path separator
    :param debug: Forwarded unchanged to no.calculate_route_between_two_points
    :param plot_distributions: If True, histograms of the search area and the linear estimation share are saved to pdir + 'edges_inland_transport/outputs/'
    :return: None
    """

    #######################################################################################################
    # SETUP
    print('### DATA IMPORT ###')
    start_time = time.time()

    transformer_4326_to_25833 = Transformer.from_crs("EPSG:4326", "EPSG:25833", always_xy=True)

    # Import files
    gdf_road_network = gpd.read_feather(pdir + 'X_GLOBAL/data_processed/dataset_Elveg2.feather')

    # Trim data
    gdf_road_network.geometry = gdf_road_network.geometry.force_2d()

    print('Time passed: {:.2f} seconds'.format(time.time() - start_time))
    print('\n')
    #######################################################################################################

    #######################################################################################################
    # CALCULATE DISTANCE - Connection point on closest road to destination point in the road network
    print('### CALCULATING ROUTES ###\n')

    transport_matrix_distance = {}
    transport_matrix_driving_time = {}
    transport_matrix_search_area = {}
    transport_matrix_lin_estimation_share = {}
    transport_dict_route = {}

    infeasible_routes = 0

    for index_destination in tqdm(gdf_nodes_destination.index):
        # The destination point does not change in the inner loop
        end_point_25833 = gdf_nodes_destination.loc[index_destination, 'geometry']

        for index_start in gdf_nodes_start.index:
            print('Calculating route from ' + str(index_start) + ' to ' + str(index_destination))

            start_point_25833 = gdf_nodes_start.loc[index_start, 'geometry']

            distance, driving_time, search_radius, lin_estimation_share, route_ors = no.calculate_route_between_two_points(start_point_25833, end_point_25833, transform_input_points=transform_input_points, gdf_road_network=gdf_road_network, debug=debug)

            # Always assign the route for this pair. Otherwise an infeasible pair would either
            # raise (first iteration) or silently inherit the route of the previous pair.
            if distance == np.inf:
                infeasible_routes += 1
                route_25833 = LineString()
            elif route_ors is None:
                route_25833 = LineString()
            else:
                route_4326 = route_ors['features'][0]['geometry']['coordinates']
                # Reproject the whole line in one call instead of point by point
                route_25833 = shapely_transform(transformer_4326_to_25833.transform, LineString(route_4326))

            transport_matrix_distance[index_start, index_destination] = distance
            transport_matrix_driving_time[index_start, index_destination] = driving_time
            transport_matrix_search_area[index_start, index_destination] = search_radius
            transport_matrix_lin_estimation_share[index_start, index_destination] = lin_estimation_share
            transport_dict_route[(index_start, index_destination)] = route_25833

    print('Infeasible routes: {}'.format(infeasible_routes))
    print('Time passed: {:.2f} seconds'.format(time.time() - start_time))
    print('\nExporting transport matrices')

    # Keys are (start, destination) -> unstack yields index = start, columns = destination
    df_transport_matrix_distance = pd.Series(transport_matrix_distance).unstack()
    df_transport_matrix_driving_time = pd.Series(transport_matrix_driving_time).unstack()

    df_transport_matrix_distance.to_feather(pdir + 'edges_inland_transport/output_data/{}_{}_distance.feather'.format(edge_name, RID))
    df_transport_matrix_driving_time.to_feather(pdir + 'edges_inland_transport/output_data/{}_{}_driving_time.feather'.format(edge_name, RID))

    df_route = pd.DataFrame.from_dict(transport_dict_route, orient="index", columns=["geometry"])
    df_route.index = pd.MultiIndex.from_tuples(df_route.index, names=["source", "destination"])
    gdf_route = gpd.GeoDataFrame(df_route, crs='EPSG:25833')
    gdf_route.to_feather(pdir + 'edges_inland_transport/output_data/{}_{}_driving_route.feather'.format(edge_name, RID))

    print('Time passed: {:.2f} seconds'.format(time.time() - start_time))

    if plot_distributions:
        print('Plotting distributions')

        # Search area distribution
        _save_histogram(
            transport_matrix_search_area.values(),
            'Search area [m]',
            pdir + 'edges_inland_transport/outputs/{}_{}_searcharea_distribution.png'.format(edge_name, RID),
        )

        # Linear estimation share distribution (only plotted if the values differ)
        _save_histogram(
            transport_matrix_lin_estimation_share.values(),
            'Linear estimation share [%]',
            pdir + 'edges_inland_transport/outputs/{}_{}_linest_distribution.png'.format(edge_name, RID),
            require_spread=True,
        )


if __name__ == '__main__':
    import os

    # Test function: runs three edges on the first five nodes of each node set

    # Parent directory (input_calculations) with a trailing, platform-specific path separator
    pdir = os.path.join(str(Path(__file__).resolve().parent.parent), '')

    gdf_biomass_resources = gpd.read_feather(pdir + 'nodes_primary_biomass_resources/output_data/node_data_01_biomass_BC.feather')
    gdf_biomass_resources['geometry'] = gdf_biomass_resources['point_on_closest_road']
    gdf_biomass_resources = gdf_biomass_resources.iloc[0:5, :]

    gdf_gasifiers = gpd.read_feather(pdir + 'nodes_secondary_biomass_gasification/output_data/node_data_03_biomass_gasification.feather')
    gdf_gasifiers = gdf_gasifiers.iloc[0:5, :]

    # transform_input_points and pdir are passed by keyword: run() expects transform_input_points
    # BEFORE pdir, so passing them positionally as (pdir, (True, True)) swaps the two arguments.
    run('test', gdf_biomass_resources, gdf_gasifiers, 'edge_data_02_transport_timber_from_forest_to_gasification_hub', transform_input_points=(True, True), pdir=pdir, debug=True)

    gdf_cH2_shipping_terminals = gpd.read_feather(pdir + 'nodes_tertiary_cH2_shipping_terminal_export/output_data/node_data_05_cH2_shipping_terminals_export.feather')
    gdf_cH2_shipping_terminals = gdf_cH2_shipping_terminals.iloc[0:5, :]

    run('test', gdf_gasifiers, gdf_cH2_shipping_terminals, 'edge_data_04_transport_cH2_from_biomass_gasifier_to_cH2_shipping_terminal', transform_input_points=(True, True), pdir=pdir, debug=True)

    gdf_timber_shipping_terminals = gpd.read_feather(pdir + 'nodes_tertiary_wc_shipping_terminal_export/output_data/node_data_05_wc_shipping_terminals_export.feather')
    gdf_timber_shipping_terminals = gdf_timber_shipping_terminals.iloc[0:5, :]

    # Destination is the timber (wc) shipping terminals loaded above, not the cH2 terminals.
    # NOTE(review): the edge name says "timber_from_forest", so the start nodes are presumably
    # gdf_biomass_resources rather than gdf_gasifiers - left unchanged, please confirm.
    run('test', gdf_gasifiers, gdf_timber_shipping_terminals, 'edge_data_04_transport_timber_from_forest_to_wc_shipping_terminal', transform_input_points=(True, True), pdir=pdir, debug=True)
#################################################################################################################################################################################################################################################