#################################################################################################################################################################################################################################################
# AUTHOR: Matthias Maier
# Task: Transport constraint for cH2 shipping
#################################################################################################################################################################################################################################################

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyomo.environ as pyo
from supply_chain_optimization.functions.initialize_metadata import initialize_metadata

############################################################################################################################################
# CONSTRAINTS

def implement_ch2_shipping_cost_constraint(model, data):
    """
    Attach the cH2 shipping fleet constraints (fleet size, trips, costs, emissions) to the model.

    Constraints added to ``model``:
      - ch2_shipping_num_ships_constraint_06 (scalar): fleet size >= ships required over all routes
      - ch2_shipping_num_trips_constraint_06 (export x import terminals): annual trips >= amount / payload
      - ch2_shipping_variable_cost_constraint_06 (export x import terminals): variable cost == trips * cost per trip
      - ch2_shipping_fixed_cost_constraint_06 (scalar): fleet cost == ships * annual fixed cost per ship
      - ch2_shipping_fleet_emissions_constraint_06 (scalar): fleet emissions == emissions summed over routes

    :param model: Pyomo model
    :param data: The dataset dict containing node and edge data
    :return: None
    """
    exporters = model.cH2_shipping_terminal_indices_05
    importers = model.cH2_shipping_terminal_indices_07

    # Named rule functions; Pyomo calls them as rule(m) for scalar and rule(m, exp, imp) for indexed constraints
    def _num_ships_rule(m):
        return number_of_ships_constraint(m, data)

    def _num_trips_rule(m, exp_idx, imp_idx):
        return number_of_trips_constraint(m, exp_idx, imp_idx, data)

    def _variable_cost_rule(m, exp_idx, imp_idx):
        return variable_cost_constraint(m, exp_idx, imp_idx, data)

    def _fixed_cost_rule(m):
        return fixed_cost_constraint(m, data)

    def _fleet_emissions_rule(m):
        return m.edges_ch2_shipping_fleet_emissions_06 == get_fleet_emissions(m, data)

    model.ch2_shipping_num_ships_constraint_06 = pyo.Constraint(rule=_num_ships_rule)
    model.ch2_shipping_num_trips_constraint_06 = pyo.Constraint(exporters, importers, rule=_num_trips_rule)
    model.ch2_shipping_variable_cost_constraint_06 = pyo.Constraint(exporters, importers, rule=_variable_cost_rule)
    model.ch2_shipping_fixed_cost_constraint_06 = pyo.Constraint(rule=_fixed_cost_rule)
    model.ch2_shipping_fleet_emissions_constraint_06 = pyo.Constraint(rule=_fleet_emissions_rule)


def number_of_ships_constraint(model, data):
    """
    Fleet size must cover the (fractional) ships required on every (export, import) route.

    :param model: Pyomo model with the terminal index sets, edges_ch2_shipping_amount_06 and
                  edges_ch2_shipping_fleet_num_ships_06
    :param data: The dataset dict passed on to get_required_number_of_ships_on_route
    :return: Relational expression: fleet size >= sum of per-route required ships
    """
    fleet_requirement = sum(
        get_required_number_of_ships_on_route(
            exp_idx,
            imp_idx,
            model.edges_ch2_shipping_amount_06[exp_idx, imp_idx],  # [kt hydrogen per year]
            data,
        )
        for exp_idx in model.cH2_shipping_terminal_indices_05
        for imp_idx in model.cH2_shipping_terminal_indices_07
    )  # [ships]

    return model.edges_ch2_shipping_fleet_num_ships_06 >= fleet_requirement


def number_of_trips_constraint(model, export_terminal_index, import_terminal_index, data):
    """
    Annual trips on a route must be at least the annual shipping amount divided by the ship payload.

    :param model: Pyomo model providing edges_ch2_shipping_amount_06 [kt H2/a] and edges_ch2_shipping_num_trips_06
    :param export_terminal_index: Export terminal index of the route
    :param import_terminal_index: Import terminal index of the route
    :param data: Dataset dict with 'metadata'
    :return: Relational expression: trips on route >= required trips
    """
    route = (export_terminal_index, import_terminal_index)
    md = data['metadata']

    payload_kg = md['cH2 Ship Payload per module [kgH2]'] * md['cH2 Ship Num Modules [-]']  # [kg H2/trip]
    amount_kt = model.edges_ch2_shipping_amount_06[route]  # [kt H2/a]
    required_trips = amount_kt * 1000 * 1000 / payload_kg  # kt -> kg, divided by payload [trips/a]

    return model.edges_ch2_shipping_num_trips_06[route] >= required_trips


def variable_cost_constraint(model, export_terminal_index, import_terminal_index, data):
    """
    Variable shipping cost of a route equals its annual trips times the fuel cost per trip.

    :param model: Pyomo model providing edges_ch2_shipping_num_trips_06 and edges_ch2_shipping_cost_06
    :param export_terminal_index: Export terminal index of the route
    :param import_terminal_index: Import terminal index of the route
    :param data: Dataset dict passed on to get_cost_per_trip_on_shipping_route
    :return: Equality expression for the route's variable cost [tNOK2024/a]
    """
    route = (export_terminal_index, import_terminal_index)
    trips_per_year = model.edges_ch2_shipping_num_trips_06[route]  # [trips/a]
    trip_cost = get_cost_per_trip_on_shipping_route(*route, data)  # [tNOK2024/Trip]

    return model.edges_ch2_shipping_cost_06[route] == trips_per_year * trip_cost


def fixed_cost_constraint(model, data):
    """
    Fleet fixed cost equals fleet size times the annual fixed cost of a single ship.

    :param model: Pyomo model providing edges_ch2_shipping_fleet_num_ships_06 and edges_ch2_shipping_fleet_cost_06
    :param data: Dataset dict with 'metadata'
    :return: Equality expression for the fleet fixed cost [tNOK2024/a]
    """
    per_ship_cost = get_fixed_annual_cost_per_ship(data['metadata'])  # [tNOK2024/a]
    fleet_size = model.edges_ch2_shipping_fleet_num_ships_06  # [-]

    return model.edges_ch2_shipping_fleet_cost_06 == fleet_size * per_ship_cost


def get_fleet_emissions(model, data):
    """
    Annual fleet emissions summed over all cH2 shipping routes.

    Per round trip, fuel is burned at the full daily rate for the sailing time (twice the one-way
    distance) and at 5 % of that rate for the time at the terminals (filling + emptying + 60 min for
    docking/undocking). Trip fuel is multiplied by the route's annual trips and by the VLSFO
    emission factor.

    :param model: Pyomo model providing the terminal index sets and edges_ch2_shipping_num_trips_06
    :param data: Dataset dict with 'metadata', 'emission_factors' and 'edge_data_shipping_distance_06'
    :return: Total fleet emissions [tCO2eq/a]; an expression in the trip variables when used in a constraint
    """
    emission_factors = data['emission_factors']
    metadata = data['metadata']
    distances = data['edge_data_shipping_distance_06']

    # Route-independent terms: computed once here instead of once per (export, import) pair
    specific_fuel_consumption = calculate_fuel_consumption(metadata)  # [tonnes/day]
    sailing_speed = metadata['cH2 Ship Sailing Speed [km/day]']  # [km/day]
    emission_factor = emission_factors['VLSFO [kgCO2eq/kg]']  # [kgCO2eq/kg] == [tCO2eq/tonne fuel]

    total_terminal_time = metadata['cH2 Shipping Terminal (export) Filling time [min]'] + metadata['cH2 Shipping Terminal (import) Emptying time [min]'] + 60  # Loading, Unloading, Docking and Undocking [min/trip]
    total_terminal_time = total_terminal_time / 60 / 24  # [days/trip]
    fuel_consumption_during_terminal = total_terminal_time * specific_fuel_consumption * 0.05  # [tonnes/trip]

    total_emissions = 0  # [tCO2eq/a]
    for export_terminal_index in model.cH2_shipping_terminal_indices_05:
        for import_terminal_index in model.cH2_shipping_terminal_indices_07:
            number_of_trips_on_route = model.edges_ch2_shipping_num_trips_06[export_terminal_index, import_terminal_index]  # [trips/a]
            shipping_distance = distances.loc[export_terminal_index, import_terminal_index] * 2  # round trip [km/trip]

            total_sailing_time = shipping_distance / sailing_speed  # [days/trip]
            fuel_consumption_during_sailing = total_sailing_time * specific_fuel_consumption  # [tonnes/trip]

            fuel_consumption = fuel_consumption_during_sailing + fuel_consumption_during_terminal  # [tonnes/trip]
            fuel_consumption = fuel_consumption * number_of_trips_on_route  # [tonnes/a]

            total_emissions += fuel_consumption * emission_factor  # [tCO2eq/a]

    return total_emissions

############################################################################################################################################


############################################################################################################################################
# HELPER FUNCTIONS

def get_required_number_of_ships_on_route(export_terminal_index, import_terminal_index, shipping_amount, data):
    """
    Number of ships (float) needed on one route to move the given annual amount of cH2.

    Required ships = trips the fleet must make per year / trips one ship can make per year.
    One round trip consists of sailing there and back plus the transfer time at both terminals
    (filling or emptying time plus 30 min each for docking and undocking).

    :param export_terminal_index: Row label of the export terminal in data['edge_data_shipping_distance_06']
    :param import_terminal_index: Column label of the import terminal in that table
    :param shipping_amount: cH2 shipping amount [kt hydrogen per year]
    :param data: Dataset dict with 'metadata' and 'edge_data_shipping_distance_06'
    :return: number of ships (float) [-]
    """
    md = data['metadata']

    payload_t = md['cH2 Ship Payload per module [kgH2]'] * md['cH2 Ship Num Modules [-]'] / 1000  # [tonnes H2/trip]
    fleet_trips = shipping_amount * 1000 / payload_t  # [trips/a for the whole fleet]

    export_transfer_min = md['cH2 Shipping Terminal (export) Filling time [min]'] + 30  # [min/trip]
    import_transfer_min = md['cH2 Shipping Terminal (import) Emptying time [min]'] + 30  # [min/trip]
    terminal_hours = (export_transfer_min + import_transfer_min) / 60  # [h/trip]

    one_way_km = data['edge_data_shipping_distance_06'].loc[export_terminal_index, import_terminal_index]  # [km]
    round_trip_days = one_way_km * 2 / md['cH2 Ship Sailing Speed [km/day]'] + terminal_hours / 24  # [days/trip]
    trips_per_ship = md['Ship Utilization [h/a]'] / (round_trip_days * 24)  # [trips/a per ship]

    return fleet_trips / trips_per_ship  # [ships]


def get_fixed_annual_cost_per_ship(metadata):
    """
    Annual fixed cost of a single cH2 ship; the same value applies to every ship in the fleet.

    The cost is annualized CAPEX plus fixed OPEX:
      - CAPEX: hull (Mulligan 2007 correlation, escalated to 2024) plus storage modules,
        annualized with the capital recovery factor at the WACC.
      - Fixed OPEX: the metadata value (salaries + insurance) plus 3 % of CAPEX for maintenance.

    :param metadata: The metadata dataset
    :return: Annual fixed ship's cost [tNOK2024/a]
    """
    dwt = metadata['cH2 Ship Deadweight Tonnage [tonnes] (internal)']  # [tonnes]
    dwt_kt = dwt / 1000  # [kt]
    n_modules = metadata['cH2 Ship Num Modules [-]']

    # Hull investment, Mulligan 2007 [M$2007]
    hull_capex = (-380 + 2.6 * 149.3 + 1.8055 * dwt / 1000
                  - 0.01009 * np.power(dwt_kt, 2)
                  + 0.0000189 * np.power(dwt_kt, 3))
    # Escalate with the Producer Price Index for Ship and Boat Building (PCU3366133661, 186.2 -> 287.6),
    # then x10.75 (presumably the USD->NOK rate) [MNOK2024]
    hull_capex = hull_capex / 186.2 * 287.6 * 10.75

    module_capex = metadata['cH2 Ship TASC module [tNOK2024/module]'] * n_modules / 1000  # [MNOK2024]
    total_capex = hull_capex + module_capex  # [MNOK2024]

    annualized_capex = total_capex * crf(metadata['cH2 Ship Lifetime [a]'], metadata['WACC [%]'] / 100)  # [MNOK2024/a]
    # Salaries + insurance from metadata, plus maintenance at 3 % of CAPEX [MNOK2024/a]
    fixed_opex = metadata['cH2 Ship Fixed OPEX [MNOK2024/a]'] + total_capex * 0.03

    return (annualized_capex + fixed_opex) * 1000  # MNOK -> tNOK [tNOK2024/a]


def get_cost_per_trip_on_shipping_route(export_terminal_index, import_terminal_index, data):
    """
    Fuel cost of one round trip on a shipping route, excluding terminal costs.

    One trip means:
      1. loading at the export terminal;
      2. sailing to the import terminal;
      3. unloading there;
      4. sailing back to the export terminal.

    Fuel is burned at the full daily rate while sailing and at 5 % of that rate while at the terminals.

    :param export_terminal_index: Row label of the export terminal in data['edge_data_shipping_distance_06']
    :param import_terminal_index: Column label of the import terminal in that table
    :param data: Dataset dict with 'metadata' and 'edge_data_shipping_distance_06'
    :return: Cost per trip [tNOK2024/Trip]
    """
    md = data['metadata']
    one_way_km = data['edge_data_shipping_distance_06'].loc[export_terminal_index, import_terminal_index]  # [km]

    # Terminal time: filling + emptying, plus 30 min per terminal for docking and undocking [h/trip]
    hours_at_terminals = ((md['cH2 Shipping Terminal (export) Filling time [min]'] + 30)
                          + (md['cH2 Shipping Terminal (import) Emptying time [min]'] + 30)) / 60

    round_trip_days = one_way_km * 2 / md['cH2 Ship Sailing Speed [km/day]'] + hours_at_terminals / 24  # [days/trip]

    daily_fuel_cost = calculate_fuel_consumption(md) * md['Ship Fuel price (VLSFO) [tNOK2024/tonne]']  # [tNOK2024/day]

    sailing_cost = daily_fuel_cost * (round_trip_days - hours_at_terminals / 24)  # [tNOK2024/Trip]
    terminal_cost = daily_fuel_cost * 0.05 / 24 * hours_at_terminals  # export + import [tNOK2024/Trip]

    return sailing_cost + terminal_cost  # [tNOK2024/Trip]


def calculate_fuel_consumption(metadata):
    """
    Daily fuel consumption of the cH2 ship from its deadweight tonnage and sailing speed,
    using the regression cited as Cepowski et al. 2007.

    :param metadata: The metadata dataset
    :return: Fuel consumption [tonnes/day]
    """
    dwt = metadata['cH2 Ship Deadweight Tonnage [tonnes] (internal)']  # [tonnes]
    speed_kn = metadata['cH2 Ship Sailing Speed [km/day]'] / 44.45  # 1 knot ~ 44.45 km/day [knots]

    dwt_term = (dwt * 2.58 / 1000 / 1000 - 0.04178) * 0.90469
    speed_term = (speed_kn * 0.294118 - 3.52941) * 0.130693

    return 87.108 * (dwt_term + speed_term + 0.14126)


def crf(lifetime, i):
    """
    Capital recovery factor: share of an investment to be paid each year over its lifetime.

    CRF = i * (1 + i)^n / ((1 + i)^n - 1)

    :param lifetime: Depreciation period n [a]
    :param i: Interest rate as a fraction (e.g. 0.08 for 8 %)
    :return: Capital recovery factor [1/a]
    """
    # For a scalar i == 0 the closed form is 0/0 (NaN); its limit is straight-line depreciation 1/n
    if np.ndim(i) == 0 and i == 0:
        return 1 / lifetime
    growth = np.power(1 + i, lifetime)
    return (i * growth) / (growth - 1)


def test_cost_function(data, shipping_distance, hydrogen_demand):
    """
    Calculate specific costs of a fully utilized ship on a given shipping route.

    ``data['metadata']`` is treated as raw (un-initialized) input. A copy of it is initialized for
    the given demand and distance, and ``data`` itself is left unchanged, so repeated calls
    (e.g. a distance sweep) all start from the same raw metadata.

    :param data: The dataset containing node and edge data; data['metadata'] holds the un-initialized metadata
    :param shipping_distance: One way shipping distance [km]
    :param hydrogen_demand: Hydrogen demand [MW]
    :return: Specific shipping costs [NOK2024/kgH2]
    """
    # Initialize a copy and do not write it back: storing the initialized metadata in data would make
    # the next call initialize already-initialized metadata
    metadata = initialize_metadata(data['metadata'].copy(), hydrogen_demand, shipping_distance)

    fixed_annual_cost_per_ship = get_fixed_annual_cost_per_ship(metadata)  # [tNOK2024/a]
    ship_payload = metadata['cH2 Ship Payload per module [kgH2]'] * metadata['cH2 Ship Num Modules [-]'] / 1000  # [tonnes hydrogen per trip]

    transfer_time_export_terminal = metadata['cH2 Shipping Terminal (export) Filling time [min]'] + 30  # [min/trip]
    transfer_time_import_terminal = metadata['cH2 Shipping Terminal (import) Emptying time [min]'] + 30  # [min/trip]
    trip_terminal_time = (transfer_time_export_terminal + transfer_time_import_terminal) / 60  # Filling time + Emptying time + docking and undocking [h/trip]

    trip_time = shipping_distance * 2 / metadata['cH2 Ship Sailing Speed [km/day]'] + trip_terminal_time / 24  # [days/trip]
    annual_trips_per_ship = metadata['Ship Utilization [h/a]'] / (trip_time * 24)  # [trips/a]

    ### Cost per trip ###

    # Fuel costs during sailing
    fuel_consumption_during_sailing = calculate_fuel_consumption(metadata)  # [tonnes/day]
    fuel_costs_during_sailing = fuel_consumption_during_sailing * metadata['Ship Fuel price (VLSFO) [tNOK2024/tonne]']  # [tNOK2024/day]
    sailing_costs_per_trip = fuel_costs_during_sailing * (trip_time - trip_terminal_time / 24)  # [tNOK2024/Trip]

    # Fuel costs at terminals (export + import), burned at 5 % of the sailing rate
    fuel_costs_during_terminal = fuel_costs_during_sailing * 0.05 / 24 * trip_terminal_time  # [tNOK2024/Trip]

    cost_per_trip = 0  # [tNOK2024/Trip]
    cost_per_trip += sailing_costs_per_trip
    cost_per_trip += fuel_costs_during_terminal

    ### Specific costs (ship) ###

    cost_absolute_ship = fixed_annual_cost_per_ship + cost_per_trip * annual_trips_per_ship  # [tNOK2024/a]
    # tNOK -> NOK in the numerator, tonnes -> kg in the denominator [NOK2024/kgH2]
    cost_specific_ship = cost_absolute_ship * 1000 / (annual_trips_per_ship * ship_payload * 1000)

    return cost_specific_ship

############################################################################################################################################
# TESTING

if __name__ == '__main__':

    # TEST FUNCTIONS AND CONDUCT SENSITIVITY ANALYSIS

    # Set up test environment
    data_dir = '../data/'

    data = {}
    metadata = pd.read_excel(data_dir + 'general/metadata.xlsx', index_col=0)['Value']
    data['metadata'] = metadata  # Un-initialized metadata

    data['edge_data_shipping_distance_06'] = pd.read_feather(data_dir + 'edge_data_06_shipping_distance.feather')  # Index = Shipping terminal names (export), Columns = Shipping terminal names (import), Values = Distance [km]
    data['node_data_cH2_shipping_terminals_05'] = pd.read_feather(data_dir + 'node_data_05_cH2_shipping_terminals_export.feather')  # Index = cH2 shipping terminals (export)

    hydrogen_demand = 300  # [MW]

    ### Shipping costs for a single reference distance ###
    reference_distance = 500  # One-way distance [km]
    specific_cost_ship = test_cost_function(data, reference_distance, hydrogen_demand)
    # The label is built from reference_distance so it always matches the distance actually evaluated
    print('Specific costs (fully utilized ship, {}km one-way distance): {:.2f} [NOK/kg H2]'.format(reference_distance, specific_cost_ship))

    ### Sensitivity of specific ship's cost on shipping distance ###

    shipping_distance_list = np.linspace(250, 20000, 20)  # One-way distance [km]
    specific_cost_ship_list = [test_cost_function(data, shipping_distance_i, hydrogen_demand) for shipping_distance_i in shipping_distance_list]  # [NOK/kg H2]

    plt.plot(shipping_distance_list, specific_cost_ship_list, 'k-*')
    plt.title('Sensitivity of specific shipping cost', fontsize=8)
    plt.xlabel('One-way shipping distance [km]')
    plt.ylabel('Specific cost (ship) [NOK2024/kg H2]')
    plt.show()

    # Slope between the first and last sweep points, derived from the sweep itself [NOK/kg per 100 km]
    distance_span_100km = (shipping_distance_list[-1] - shipping_distance_list[0]) / 100
    print('Factor cost increase with shipping distance: {:.2f} [NOK/kg per 100 km]'.format((specific_cost_ship_list[-1] - specific_cost_ship_list[0]) / distance_span_100km))
    print('Shipping costs for {} km: {:.2f} NOK/kg H2'.format(shipping_distance_list[-1], specific_cost_ship_list[-1]))