#################################################################################################################################################################################################################################################
# AUTHOR: Matthias Maier
# Task: Solve the supply chain graph using pyomo
#################################################################################################################################################################################################################################################


#################################################################################################################################################################################################################################################
# IMPORT

import pyomo.environ as pyo
import logging
import numpy as np

from nodes_cost_functions.timber_production_01 import implement_timber_production_cost_function
from nodes_cost_functions.biomass_gasification_03 import implement_biomass_gasification_cost_function as implement_biomass_gasification_cost_function_03
from nodes_cost_functions.wc_shipping_terminal_export import implement_wood_chip_shipping_terminal_export_cost_function
from nodes_cost_functions.ch2_shipping_terminal_export import implement_ch2_shipping_terminal_export_cost_function
from nodes_cost_functions.ch2_shipping_terminal_import import implement_ch2_shipping_terminal_import_cost_function
from nodes_cost_functions.wc_shipping_terminal_import import implement_wood_chip_shipping_terminal_import_cost_function
from nodes_cost_functions.biomass_gasification_07 import implement_biomass_gasification_cost_function as implement_biomass_gasification_cost_function_07
from nodes_constraints.timber_production import implement_timber_production_size_constraints
from nodes_constraints.biomass_gasification_03 import implement_biomass_gasification_size_constraints as implement_biomass_gasification_size_constraints_03
from nodes_constraints.wc_shipping_terminal_export import implement_wood_chip_shipping_terminal_export_size_constraints
from nodes_constraints.ch2_shipping_terminal_export import implement_ch2_export_shipping_terminal_size_constraints
from nodes_constraints.ch2_shipping_terminal_import import implement_ch2_import_shipping_terminal_size_constraints
from nodes_constraints.biomass_gasification_07 import implement_biomass_gasification_size_constraints as implement_biomass_gasification_size_constraints_07
from nodes_constraints.demand_constraint import implement_demand_constraint
from edges_constraints.cH2_tube_trailer import implement_ch2_tube_trailer_transport_constraint
from edges_constraints.cH2_pipeline import implement_ch2_pipeline_transport_constraint
from edges_constraints.cH2_shipping import implement_ch2_shipping_cost_constraint
from edges_constraints.wood_chip_shipping import implement_wood_chip_shipping_cost_constraint
from edges_constraints.timber_truck import implement_timber_truck_transport_constraint

# Only let ERROR-level (and above) messages from the 'pyomo.core' logger through; its warnings/info messages are silenced.
logging.getLogger('pyomo.core').setLevel(logging.ERROR)
#################################################################################################################################################################################################################################################


#################################################################################################################################################################################################################################################

# Conversion factor used for hydrogen production capacities: 100 MW correspond to 26.2975 kt hydrogen per year.
_KT_H2_PER_YEAR_PER_100_MW = 26.2975

# Upper size bound of a cH2 import terminal [kt hydrogen per year]; corresponds to 5000 MW (= 50 * 26.2975 kt/a).
_CH2_IMPORT_TERMINAL_MAX_SIZE_KT = 1314.875


def _declare_index_sets(model, data):
    """Create one Pyomo index set per node type from the index of the matching node table in ``data``."""
    nodes_data_timber_resources_01 = data['node_data_timber_resources_01']  # Index = Biomass nodes, Columns = Potential [1000m3], Production costs [NOK/t]
    nodes_data_gasification_hubs_03 = data['node_data_gasification_hubs_03']  # Index = Gasification Hubs, Columns = size constraints (min/max), cost data
    nodes_data_cH2_shipping_terminals_05 = data['node_data_cH2_shipping_terminals_05']  # Index = cH2 Shipping Terminals (export)
    nodes_data_wc_shipping_terminals_05 = data['node_data_wood_chip_shipping_terminals_05']  # Index = Wood chip Shipping Terminals (export)
    nodes_data_cH2_shipping_terminals_07 = data['node_data_cH2_shipping_terminals_07']  # Index = cH2 Shipping Terminals (import)
    nodes_data_wc_shipping_terminals_07 = data['node_data_wood_chip_shipping_terminals_07']  # Index = Wood chip Shipping Terminals (import)

    model.timber_resources_indices_01 = pyo.Set(initialize=nodes_data_timber_resources_01.index.values.tolist())
    model.gasification_hubs_indices_03 = pyo.Set(initialize=nodes_data_gasification_hubs_03.index.values.tolist())
    model.cH2_shipping_terminal_indices_05 = pyo.Set(initialize=nodes_data_cH2_shipping_terminals_05.index.values.tolist())
    model.wc_shipping_terminal_indices_05 = pyo.Set(initialize=nodes_data_wc_shipping_terminals_05.index.values.tolist())
    model.cH2_shipping_terminal_indices_07 = pyo.Set(initialize=nodes_data_cH2_shipping_terminals_07.index.values.tolist())
    model.wc_shipping_terminal_indices_07 = pyo.Set(initialize=nodes_data_wc_shipping_terminals_07.index.values.tolist())


def _declare_edge_variables(model):
    """Declare the decision/auxiliary variables of all transport edges (truck, tube trailer, pipeline, shipping)."""
    # Timber truck
    model.edges_timber_truck_amount_02 = pyo.Var(model.timber_resources_indices_01, model.gasification_hubs_indices_03, domain=pyo.NonNegativeReals)  # Transport [kt timber per year]
    model.edges_timber_truck_amount_04 = pyo.Var(model.timber_resources_indices_01, model.wc_shipping_terminal_indices_05, domain=pyo.NonNegativeReals)  # Transport [kt timber per year]
    model.edges_timber_truck_cost = pyo.Var(domain=pyo.NonNegativeReals)  # Transport cost for the entire timber truck fleet [tNOK2024/a]
    model.edges_timber_truck_num_trucks = pyo.Var(domain=pyo.NonNegativeIntegers)  # Number of trucks in the entire timber truck fleet [-]
    model.edges_timber_truck_emissions = pyo.Var(domain=pyo.NonNegativeReals)  # CO2 emissions of the entire timber truck fleet [tCO2eq/a]

    # cH2 Tube Trailer & Pipeline
    model.edges_ch2_pipeline_amount_04 = pyo.Var(model.gasification_hubs_indices_03, model.cH2_shipping_terminal_indices_05, domain=pyo.NonNegativeReals)  # Transport for on-site hydrogen [kt hydrogen per year]
    model.edges_ch2_tube_trailer_amount_04 = pyo.Var(model.gasification_hubs_indices_03, model.cH2_shipping_terminal_indices_05, domain=pyo.NonNegativeReals)  # Transport for A-B [kt hydrogen per year]
    model.edges_ch2_tube_trailer_cost_04 = pyo.Var(domain=pyo.NonNegativeReals)  # Transport cost for the entire cH2 Tube Trailer fleet [tNOK2024/a]
    model.edges_ch2_tube_trailer_num_trucks_04 = pyo.Var(domain=pyo.NonNegativeIntegers)  # Number of trucks in the entire cH2 Tube Trailer fleet [-]
    model.edges_ch2_tube_trailer_emissions_04 = pyo.Var(domain=pyo.NonNegativeReals)  # CO2 emissions of the entire tube trailer fleet [tCO2eq/a]

    # cH2 Shipping
    model.edges_ch2_shipping_amount_06 = pyo.Var(model.cH2_shipping_terminal_indices_05, model.cH2_shipping_terminal_indices_07, domain=pyo.NonNegativeReals)  # Transport amount for A-B [kt hydrogen per year]
    model.edges_ch2_shipping_num_trips_06 = pyo.Var(model.cH2_shipping_terminal_indices_05, model.cH2_shipping_terminal_indices_07, domain=pyo.NonNegativeIntegers)  # Amount of trips done by the fleet on route A-B [-]
    model.edges_ch2_shipping_cost_06 = pyo.Var(model.cH2_shipping_terminal_indices_05, model.cH2_shipping_terminal_indices_07, domain=pyo.NonNegativeReals)  # Variable transport cost for A-B [tNOK2024/a]

    model.edges_ch2_shipping_fleet_num_ships_06 = pyo.Var(domain=pyo.NonNegativeIntegers)  # Number of ships in the fleet [-]
    model.edges_ch2_shipping_fleet_cost_06 = pyo.Var(domain=pyo.NonNegativeReals)  # Fixed fleet costs [tNOK2024/a]
    model.edges_ch2_shipping_fleet_emissions_06 = pyo.Var(domain=pyo.NonNegativeReals)  # CO2 emissions of the entire fleet [tCO2eq/a]

    # Wood chip shipping
    model.edges_wood_chip_shipping_amount_06 = pyo.Var(model.wc_shipping_terminal_indices_05, model.wc_shipping_terminal_indices_07, domain=pyo.NonNegativeReals)  # Transport amount for A-B [kt wood chips per year]
    model.edges_wood_chip_shipping_num_trips_06 = pyo.Var(model.wc_shipping_terminal_indices_05, model.wc_shipping_terminal_indices_07, domain=pyo.NonNegativeIntegers)  # Amount of trips done by the fleet on route A-B [-]
    model.edges_wood_chip_shipping_cost_06 = pyo.Var(model.wc_shipping_terminal_indices_05, model.wc_shipping_terminal_indices_07, domain=pyo.NonNegativeReals)  # Variable transport cost for A-B [tNOK2024/a]

    model.edges_wood_chip_shipping_fleet_num_ships_06 = pyo.Var(domain=pyo.NonNegativeIntegers)  # Number of ships in the fleet [-]
    model.edges_wood_chip_shipping_fleet_cost_06 = pyo.Var(domain=pyo.NonNegativeReals)  # Fixed fleet costs [tNOK2024/a]
    model.edges_wood_chip_shipping_fleet_emissions_06 = pyo.Var(domain=pyo.NonNegativeReals)  # CO2 emissions of the entire fleet [tCO2eq/a]


def _declare_node_variables(model, metadata):
    """Declare the size/decision/cost/emission variables of all supply-chain nodes (production, hubs, terminals)."""
    # Timber production sites
    model.nodes_biomass_production_site_expansion_size_01 = pyo.Var(model.timber_resources_indices_01, domain=pyo.NonNegativeReals)  # Timber production amount [kt timber per year]
    model.nodes_biomass_production_site_cost_01 = pyo.Var(model.timber_resources_indices_01, domain=pyo.NonNegativeReals)  # Auxiliary variable for cost [tNOK2024/a]
    model.nodes_biomass_production_site_emission_01 = pyo.Var(domain=pyo.NonNegativeReals)  # Auxiliary variable for total CO2 emissions at timber production [tCO2eq/a]

    # Biomass gasification hubs (NO)
    model.nodes_biomass_gasification_hub_expansion_decision_03 = pyo.Var(model.gasification_hubs_indices_03, domain=pyo.Binary)  # Boolean stating if hub is built
    # Upper bound: the maximum plant size given in MW, converted to kt hydrogen per year.
    model.nodes_biomass_gasification_hub_expansion_size_03 = pyo.Var(
        model.gasification_hubs_indices_03,
        domain=pyo.NonNegativeReals,
        bounds=lambda model, index: (0, metadata['Biomass Gasification SIZE - max [MW]'] / 100 * _KT_H2_PER_YEAR_PER_100_MW),
    )  # Gasification hub size [kt hydrogen per year]
    model.nodes_biomass_gasification_hub_ch2_tube_trailer_expansion_decision_03 = pyo.Var(model.gasification_hubs_indices_03, domain=pyo.Binary)  # Boolean stating if tube trailer terminal is built
    model.nodes_biomass_gasification_hub_ch2_tube_trailer_expansion_size_03 = pyo.Var(model.gasification_hubs_indices_03, domain=pyo.NonNegativeReals)  # Size of the tube trailer terminal at the gasification plant [kt hydrogen per year]
    model.nodes_biomass_gasification_hub_cost_03 = pyo.Var(model.gasification_hubs_indices_03, domain=pyo.NonNegativeReals)  # Auxiliary variable to represent total costs (Constraint c = f(size)) [tNOK2024/a]
    model.nodes_biomass_gasification_hub_base_cost_03 = pyo.Var(model.gasification_hubs_indices_03, domain=pyo.NonNegativeReals)  # Auxiliary variable to represent piecewise linear base costs (Constraint c = f(size)) [tNOK2024/a]
    model.nodes_biomass_gasification_hub_emission_03 = pyo.Var(domain=pyo.NonNegativeReals)  # Auxiliary variable for total CO2 emissions at H2 production [tCO2eq/a]

    # cH2 shipping terminal (export)
    model.nodes_ch2_shipping_terminal_expansion_decision_05 = pyo.Var(model.cH2_shipping_terminal_indices_05, domain=pyo.Binary)  # Boolean stating if terminal is built
    model.nodes_ch2_shipping_terminal_expansion_size_05 = pyo.Var(model.cH2_shipping_terminal_indices_05, domain=pyo.NonNegativeReals)  # Terminal size [kt hydrogen per year]
    model.nodes_ch2_shipping_terminal_ch2_tube_trailer_expansion_size_05 = pyo.Var(model.cH2_shipping_terminal_indices_05, domain=pyo.NonNegativeReals)  # Size of the tube trailer terminal at the shipping terminal [kt hydrogen per year]
    model.nodes_ch2_shipping_terminal_ch2_tube_trailer_expansion_decision_05 = pyo.Var(model.cH2_shipping_terminal_indices_05, domain=pyo.Binary)  # Boolean stating if tube trailer terminal is built
    model.nodes_ch2_shipping_terminal_cost_05 = pyo.Var(model.cH2_shipping_terminal_indices_05, domain=pyo.NonNegativeReals)  # Auxiliary variable to represent costs (Constraint c = f(size)) [tNOK2024/a]
    model.nodes_ch2_shipping_terminal_emission_05 = pyo.Var(domain=pyo.NonNegativeReals)  # Auxiliary variable for total CO2 emissions at cH2 terminals [tCO2eq/a]

    # Wood chip shipping terminal (export)
    model.nodes_wood_chip_shipping_terminal_expansion_size_05 = pyo.Var(model.wc_shipping_terminal_indices_05, domain=pyo.NonNegativeReals)  # Terminal size [kt timber/a]
    model.nodes_wood_chip_shipping_terminal_cost_05 = pyo.Var(model.wc_shipping_terminal_indices_05, domain=pyo.NonNegativeReals)  # Auxiliary variable to represent terminal costs (timber supply + wood chip production) [tNOK2024/a]
    model.nodes_wood_chip_shipping_terminal_emission_05 = pyo.Var(domain=pyo.NonNegativeReals)  # Auxiliary variable for total CO2 emissions at timber terminal [tCO2eq/a]

    # cH2 shipping terminal (import)
    model.nodes_ch2_shipping_terminal_expansion_size_07 = pyo.Var(
        model.cH2_shipping_terminal_indices_07,
        domain=pyo.NonNegativeReals,
        bounds=lambda model, index: (0, _CH2_IMPORT_TERMINAL_MAX_SIZE_KT),
    )  # Terminal size [kt hydrogen per year] (Valid between 0 and 5000 MW)
    model.nodes_ch2_shipping_terminal_expansion_decision_07 = pyo.Var(model.cH2_shipping_terminal_indices_07, domain=pyo.Binary)  # Boolean stating if terminal is built
    model.nodes_ch2_shipping_terminal_cost_07 = pyo.Var(model.cH2_shipping_terminal_indices_07, domain=pyo.NonNegativeReals)  # Auxiliary variable to represent expansion costs (Constraint c = f(size)) [tNOK2024/a]
    model.nodes_ch2_shipping_terminal_emission_07 = pyo.Var(domain=pyo.NonNegativeReals)  # Auxiliary variable for total CO2 emissions at cH2 terminal [tCO2eq/a]

    # Wood chip shipping terminal (import)
    model.nodes_wood_chip_shipping_terminal_cost_07 = pyo.Var(model.wc_shipping_terminal_indices_07, domain=pyo.NonNegativeReals)  # Auxiliary variable to represent terminal costs [tNOK2024/a]

    # Biomass gasification hubs (DE) at the import terminals in Germany (indexed by the wood chip import terminals)
    model.nodes_biomass_gasification_hub_expansion_size_07 = pyo.Var(model.wc_shipping_terminal_indices_07, domain=pyo.NonNegativeReals)  # Gasification hub size [kt hydrogen per year]
    model.nodes_biomass_gasification_hub_expansion_decision_07 = pyo.Var(model.wc_shipping_terminal_indices_07, domain=pyo.Binary)  # Boolean stating if hub is built
    model.nodes_biomass_gasification_hub_cost_07 = pyo.Var(model.wc_shipping_terminal_indices_07, domain=pyo.NonNegativeReals)  # Auxiliary variable to represent costs for gasification plant [tNOK2024/a]
    model.nodes_biomass_gasification_hub_emission_07 = pyo.Var(domain=pyo.NonNegativeReals)  # Auxiliary variable for total CO2 emissions at H2 production [tCO2eq/a]


def _declare_objective(model):
    """Attach the ``total_cost`` minimization objective (sum of all node and edge costs, in tNOK2024/a)."""

    @model.Objective(sense=pyo.minimize)
    def total_cost(model):
        # Minimize all supply chain costs. All cost values in [tNOK2024/a]
        production_cost_timber_01 = sum(model.nodes_biomass_production_site_cost_01[biomass_site] for biomass_site in model.timber_resources_indices_01)

        transport_costs_timber_truck = model.edges_timber_truck_cost

        production_cost_gasification_hubs_03 = sum(model.nodes_biomass_gasification_hub_cost_03[hub] for hub in model.gasification_hubs_indices_03)

        transport_costs_ch2_tube_trailer_04 = model.edges_ch2_tube_trailer_cost_04

        terminal_costs_wood_chip_terminal_05 = sum(model.nodes_wood_chip_shipping_terminal_cost_05[terminal] for terminal in model.wc_shipping_terminal_indices_05)
        terminal_costs_ch2_terminal_05 = sum(model.nodes_ch2_shipping_terminal_cost_05[terminal] for terminal in model.cH2_shipping_terminal_indices_05)

        # Shipping: variable per-route costs plus the fixed fleet cost
        transport_costs_ch2_shipping_06 = sum(model.edges_ch2_shipping_cost_06[export_terminal, import_terminal]
                                              for export_terminal in model.cH2_shipping_terminal_indices_05
                                              for import_terminal in model.cH2_shipping_terminal_indices_07)
        transport_costs_ch2_shipping_06 += model.edges_ch2_shipping_fleet_cost_06
        transport_costs_wood_chip_shipping_06 = sum(model.edges_wood_chip_shipping_cost_06[export_terminal, import_terminal]
                                                    for export_terminal in model.wc_shipping_terminal_indices_05
                                                    for import_terminal in model.wc_shipping_terminal_indices_07)
        transport_costs_wood_chip_shipping_06 += model.edges_wood_chip_shipping_fleet_cost_06

        # The cH2 import terminal cost is only counted when the terminal is built (cost * binary decision).
        # NOTE(review): unlike every other cost term this is a continuous*binary product, so the objective is
        # quadratic (bilinear); confirm Gurobi handles it as intended or linearize it in the cost function.
        terminal_costs_ch2_terminal_07 = sum(model.nodes_ch2_shipping_terminal_cost_07[terminal] * model.nodes_ch2_shipping_terminal_expansion_decision_07[terminal]
                                             for terminal in model.cH2_shipping_terminal_indices_07)
        terminal_costs_wood_chip_terminal_07 = sum(model.nodes_wood_chip_shipping_terminal_cost_07[terminal] for terminal in model.wc_shipping_terminal_indices_07)
        production_cost_gasification_hubs_07 = sum(model.nodes_biomass_gasification_hub_cost_07[hub] for hub in model.wc_shipping_terminal_indices_07)

        return (production_cost_timber_01 + transport_costs_timber_truck + production_cost_gasification_hubs_03 + transport_costs_ch2_tube_trailer_04 + terminal_costs_wood_chip_terminal_05 +
                terminal_costs_ch2_terminal_05 + transport_costs_ch2_shipping_06 + transport_costs_wood_chip_shipping_06 + terminal_costs_ch2_terminal_07 + terminal_costs_wood_chip_terminal_07 + production_cost_gasification_hubs_07)


def _add_constraints(model, data, metadata, final_hydrogen_demand, show_regression, enforce_cH2_export, test_mode):
    """Attach all node/edge size, cost and demand constraints, plus the optional scenario/test constraints."""
    # Timber production Size & Cost
    implement_timber_production_size_constraints(model, data)
    implement_timber_production_cost_function(model, data)

    # Timber Truck (02, 04) Fleet Size & Cost
    implement_timber_truck_transport_constraint(model, data)

    # Biomass Gasification (NO) Size & Cost
    implement_biomass_gasification_size_constraints_03(model, data)
    implement_biomass_gasification_cost_function_03(model, data, show_regression=show_regression)

    # cH2 Tube Trailer & Pipeline Size & Cost
    implement_ch2_tube_trailer_transport_constraint(model, data)
    implement_ch2_pipeline_transport_constraint(model, data)

    # Wood chip shipping terminal (export) Size & Cost
    implement_wood_chip_shipping_terminal_export_size_constraints(model, data)
    implement_wood_chip_shipping_terminal_export_cost_function(model, data)

    # cH2 shipping Terminal (export) Size & Cost
    implement_ch2_export_shipping_terminal_size_constraints(model, data)
    implement_ch2_shipping_terminal_export_cost_function(model, data)

    # cH2 Shipping Cost (Cost = f(edge thickness))
    implement_ch2_shipping_cost_constraint(model, data)

    # Wood chip Shipping Cost (Cost = f(edge thickness))
    implement_wood_chip_shipping_cost_constraint(model, data)

    # cH2 shipping Terminal (import) Size & Cost -- the cost function receives only the metadata, not the full data dict
    implement_ch2_import_shipping_terminal_size_constraints(model, data)
    implement_ch2_shipping_terminal_import_cost_function(model, metadata)

    # Wood chip shipping terminal (import) Cost
    implement_wood_chip_shipping_terminal_import_cost_function(model, data)

    # Biomass Gasification (DE) Size & Cost
    implement_biomass_gasification_size_constraints_07(model, data)
    implement_biomass_gasification_cost_function_07(model, data, show_regression=show_regression)

    # Demand constraint
    implement_demand_constraint(model, final_hydrogen_demand, data)

    # Enforce the cH2 supply chain by forbidding the German gasification hub at 'Wilhelmshaven'.
    # NOTE(review): only this one hard-coded hub is forced off; other entries of wc_shipping_terminal_indices_07
    # (if any) stay available, and a missing 'Wilhelmshaven' index raises when the constraint is built.
    if enforce_cH2_export:
        model.custom_constraint_ch2 = pyo.Constraint(rule=lambda model: model.nodes_biomass_gasification_hub_expansion_decision_07['Wilhelmshaven'] == 0)

    # Hook for custom constraints while testing the optimizer (examples kept for manual experiments)
    if test_mode:
        pass
        # model.custom_constraint_1 = pyo.Constraint(rule=lambda model: model.nodes_biomass_gasification_hub_expansion_decision_07['Wilhelmshaven'] == 0)
        # model.custom_constraint_2 = pyo.Constraint(rule=lambda model: model.nodes_ch2_shipping_terminal_expansion_decision_05['ch2_shipping_terminal_9']==1)
        # model.custom_constraint_3 = pyo.Constraint(rule=lambda model: model.nodes_biomass_gasification_hub_ch2_tube_trailer_expansion_size_03['biomass_gasification_plant_378'] >= 10)
        # model.custom_constraint_4 = pyo.Constraint(rule=lambda model: model.nodes_ch2_shipping_terminal_expansion_decision_05['ch2_shipping_terminal_10'] == 1) # Grimstad


def _print_model_statistics(model):
    """Print how many binary, continuous and (non-binary) integer variables the model contains."""
    variables = list(model.component_data_objects(pyo.Var, active=True))
    print('Number of binary variables: {}'.format(sum(1 for v in variables if v.is_binary())))
    print('Number of continuous variables: {}'.format(sum(1 for v in variables if v.is_continuous())))
    # Pyomo's is_integer() is also True for Binary variables; exclude them so binaries are not counted twice.
    print('Number of integer variables (excl. binary): {}'.format(sum(1 for v in variables if v.is_integer() and not v.is_binary())))


def _check_solution_consistency(model):
    """Assert that the tube-trailer sizes/decisions of the solved model are mutually consistent.

    Raises AssertionError when a check fails (these asserts are skipped under ``python -O``).
    """
    # Total tube-trailer size at the gasification hubs must not exceed the total at the export terminals (tolerance 0.001).
    # NOTE(review): the check is one-sided -- a surplus at the export terminals passes; use abs(...) if equality is intended.
    hub_tube_trailer_size = sum(pyo.value(model.nodes_biomass_gasification_hub_ch2_tube_trailer_expansion_size_03[hub]) for hub in model.gasification_hubs_indices_03)
    terminal_tube_trailer_size = sum(pyo.value(model.nodes_ch2_shipping_terminal_ch2_tube_trailer_expansion_size_05[terminal]) for terminal in model.cH2_shipping_terminal_indices_05)
    assert hub_tube_trailer_size - terminal_tube_trailer_size < 0.001

    # A tube-trailer terminal may only be built where its gasification hub / shipping terminal is built.
    # Binary values are rounded to 2 decimals to absorb solver tolerances.
    for gasification_hub in model.gasification_hubs_indices_03:
        if np.round(pyo.value(model.nodes_biomass_gasification_hub_ch2_tube_trailer_expansion_decision_03[gasification_hub]), decimals=2) == 1:
            assert np.round(pyo.value(model.nodes_biomass_gasification_hub_expansion_decision_03[gasification_hub]), decimals=2) == 1

    for terminal in model.cH2_shipping_terminal_indices_05:
        if np.round(pyo.value(model.nodes_ch2_shipping_terminal_ch2_tube_trailer_expansion_decision_05[terminal]), decimals=2) == 1:
            assert np.round(pyo.value(model.nodes_ch2_shipping_terminal_expansion_decision_05[terminal]), decimals=2) == 1


def run_optimization(data, final_hydrogen_demand, print_details = False, test_mode=False, show_regression=False, enforce_cH2_export=False):
    """
    Build the supply chain MILP, solve it with Gurobi and sanity-check the solution.

    :param data: Dict with optimization data ('metadata' plus one node table per node type)
    :param final_hydrogen_demand: Final hydrogen demand in kt/a
    :param print_details: Print model statistics and the solver results summary
    :param test_mode: Test optimizer through modifying constraints
    :param show_regression: Show cost regression plots for nodes (i.e., biomass gasification)
    :param enforce_cH2_export: Enforce the cH2 supply chain via deactivating the cH2 production option in Germany
    :return: The solved Pyomo ConcreteModel, or None if the solver status is not 'ok'
    :raises AssertionError: If a post-solve consistency check fails
    """
    if print_details:
        print('### BUILDING MODEL ###')

    metadata = data['metadata']

    # SET UP MODEL
    model = pyo.ConcreteModel()
    model.dual = pyo.Suffix(direction=pyo.Suffix.IMPORT)  # Request dual values from the solver
    _declare_index_sets(model, data)

    # VARIABLES, OBJECTIVE, CONSTRAINTS
    _declare_edge_variables(model)
    _declare_node_variables(model, metadata)
    _declare_objective(model)
    _add_constraints(model, data, metadata, final_hydrogen_demand, show_regression, enforce_cH2_export, test_mode)

    # SOLVE MODEL
    if print_details:
        print('### MODEL DETAILS ###')
        _print_model_statistics(model)
        print('### SOLVING MODEL ###')

    results = pyo.SolverFactory('gurobi').solve(model)

    if results.solver.status != pyo.SolverStatus.ok:
        print("No Valid Solution Found")
        return None

    if print_details:
        results.write()

    _check_solution_consistency(model)
    return model