#################################################################################################################################################################################################################################################
# AUTHOR: Matthias Maier
# Task: Analyse the price vs. supply of timber in Norway
#################################################################################################################################################################################################################################################

from pathlib import Path
import numpy as np
import pandas as pd
import geopandas as gpd
from matplotlib import pyplot as plt
import os
from create_biomass_nodes import run as create_biomass_nodes

def supply_price_analysis(RID, pdir, plot_supply_curve = False):
    """Estimate the economically harvestable timber potential for one run.

    Loads the biomass nodes of run ``RID`` (creating them first if the node
    file does not exist yet) and assigns each node to the ``landsdel`` whose
    fylke polygon contains it. Nodes outside every polygon are dropped. Each
    node gets the regional timber price, and only nodes whose price exceeds
    the production cost per m3 are kept. These are sorted by production cost
    and their potential is accumulated into a supply curve.

    Parameters
    ----------
    RID : str
        Run identifier; used in the node file name, plot title and plot file name.
    pdir : str
        Project root directory, including the trailing path separator.
    plot_supply_curve : bool, optional
        If True, scatter cumulative potential vs. production cost [NOK/t],
        save the figure under ``nodes_primary_biomass_resources/plots/`` and show it.

    Returns
    -------
    GeoDataFrame
        The profitable nodes, sorted by ``'Production costs [NOK/m3]'``. Added
        columns: ``'landsdel'`` (plus the ``sjoin`` index column),
        ``'Timber price [NOK2023/m3]'``, ``'price_gap'`` and
        ``'Cumulative Potential [kt/a]'``.

    Raises
    ------
    KeyError
        If a node's landsdel has no entry in the timber price table.
    """
    node_file = pdir + 'nodes_primary_biomass_resources/output_data/node_data_01_biomass_{}.feather'.format(RID)

    # Check exactly the file that is read below, so the result does not depend on the working directory.
    if not os.path.exists(node_file):
        print('Creating biomass nodes for RID = {}'.format(RID))
        create_biomass_nodes(RID, 1000, pdir, plot_nodes=False)

    gdf_reduced_centers = gpd.read_feather(node_file)
    gdf_fylkedata = gpd.read_feather(pdir + 'X_GLOBAL/data_processed/dataset_fylke.feather')
    df_timber_price = pd.read_excel(pdir + 'X_GLOBAL/data_processed/dataset_landsdeler.xlsx', index_col=0, sheet_name='skog_data').loc[:, 'Timber price [NOK2023/m3]']

    # Assign each center to its landsdel using "within"; centers outside all polygons get NaN and are dropped.
    gdf_reduced_centers = gpd.sjoin(gdf_reduced_centers, gdf_fylkedata[['geometry', 'landsdel']], how='left', predicate='within')
    gdf_reduced_centers = gdf_reduced_centers[gdf_reduced_centers['landsdel'].notna()].copy()

    # Element-wise lookup so that an unknown landsdel raises KeyError instead of silently becoming NaN.
    gdf_reduced_centers['Timber price [NOK2023/m3]'] = gdf_reduced_centers['landsdel'].map(lambda el: df_timber_price[el])
    gdf_reduced_centers['price_gap'] = gdf_reduced_centers['Timber price [NOK2023/m3]'] - gdf_reduced_centers['Production costs [NOK/m3]']

    # Keep only profitable nodes. sort_values returns a new frame, so the column assignment
    # below does not write into a slice of the previous frame.
    gdf_reduced_centers = gdf_reduced_centers[gdf_reduced_centers['price_gap'] > 0].sort_values(by='Production costs [NOK/m3]')
    gdf_reduced_centers['Cumulative Potential [kt/a]'] = gdf_reduced_centers['Potential [kt/a]'].cumsum()

    # sum() equals the last cumulative value and is 0 (instead of IndexError) when no node is profitable.
    economic_potential = gdf_reduced_centers['Potential [kt/a]'].sum()
    print('Economic timber potential: {:.2f} [kt/a]'.format(economic_potential))

    if plot_supply_curve:
        # NOTE(review): nodes are sorted by cost per m3 but the y-axis shows cost per t;
        # the curve is only monotone if the m3->t conversion is the same for all nodes.
        fig, ax = plt.subplots(figsize=(9*0.7, 6*0.7))
        ax.scatter(gdf_reduced_centers['Cumulative Potential [kt/a]'], gdf_reduced_centers['Production costs [NOK/t]'], s=1, color='#a02b92')
        ax.set_xlabel('Cumulative Potential [kt/a]', fontname='Inter')
        ax.set_ylabel('Production costs [NOK/t]', fontname='Inter')
        ax.set_title('Supply price analysis for the case = {}'.format(RID), fontname='Inter', fontsize=8)
        fig.tight_layout()
        # Save before show(): show() may close the figure, which would leave savefig with an empty canvas.
        fig.savefig(pdir + 'nodes_primary_biomass_resources/plots/supply_price_analysis_{}.png'.format(RID), dpi=600)
        plt.show()

    return gdf_reduced_centers

if __name__ == '__main__':

    # Project root with a trailing separator. os.sep keeps the original '\\' on Windows and
    # yields '/' on POSIX, where a literal backslash would become part of the file name.
    pdir = str(Path(__file__).resolve().parent.parent) + os.sep

    # supply_price_analysis(RID='BC', pdir=pdir, plot_supply_curve=True)

    ### Create estimate for different runs ###
    # Pool the (cumulative potential, production cost) points of several sampled runs and
    # summarise them per cumulative-potential bin by their mean and standard deviation.

    X_train = []  # Cumulative potential [kt/a]
    y_train = []  # Production costs at potential [NOK/t]

    n = 10
    for run_i in range(n):
        RUN_ID = 'SPA_sample_{}'.format(run_i)
        gdf_reduced_centers = supply_price_analysis(RUN_ID, pdir, plot_supply_curve=False)

        X_train.extend(gdf_reduced_centers['Cumulative Potential [kt/a]'].tolist())
        y_train.extend(gdf_reduced_centers['Production costs [NOK/t]'].tolist())

    df_train = pd.DataFrame(data={'Cumulative Potential [kt/a]': X_train, 'Production costs [NOK/t]': y_train})
    df_train.sort_values(by='Cumulative Potential [kt/a]', inplace=True)

    # Bin edges from 0 in steps of 10 kt/a, at least up to 2000 and extended past the largest
    # observation. With right=False a value on or beyond the last edge would become NaN and be
    # silently dropped from the statistics.
    bin_width = 10
    max_potential = df_train['Cumulative Potential [kt/a]'].max() if not df_train.empty else 0
    upper_edge = int(max(2000, bin_width * (np.floor(max_potential / bin_width) + 1)))
    bin_edges = np.arange(0, upper_edge + 1, bin_width)

    # Assign each row to a bin [left, right)
    df_train['bin'] = pd.cut(df_train['Cumulative Potential [kt/a]'], bins=bin_edges, right=False)

    # Mean and std per bin; observed=False keeps empty bins (their statistics are NaN)
    df_stats = df_train.groupby('bin', observed=False)['Production costs [NOK/t]'].agg(['mean', 'std']).reset_index()

    # One row per bin: bin midpoint, mean and std of the production costs inside it
    df2 = pd.DataFrame({
        'Cumulative Potential [kt/a]': [interval.mid for interval in df_stats['bin']],
        'Mean Production costs [NOK/t]': df_stats['mean'],
        'Std Production costs [NOK/t]': df_stats['std'],
    })

    # Plot raw observations of all runs
    fig, ax = plt.subplots(figsize=(9 * 0.7, 6 * 0.7))
    ax.scatter(X_train, y_train, s=1, color='black', label="Observations", alpha=0.1)
    ax.set_title('Supply price analysis', fontname='Inter', fontsize=8)
    ax.set_xlabel('Cumulative Potential [kt/a]', fontname='Inter')
    ax.set_ylabel('Production costs [NOK/t]', fontname='Inter')
    fig.tight_layout()
    plt.show()

    # Plot mean and an approximate 95 % band (mean +/- 1.96 std) per bin
    mean_vals = df2['Mean Production costs [NOK/t]']
    std_vals = df2['Std Production costs [NOK/t]']
    x_vals = df2['Cumulative Potential [kt/a]']

    upper = mean_vals + 1.96 * std_vals
    lower = mean_vals - 1.96 * std_vals

    fig, ax = plt.subplots(figsize=(9 * 0.7, 6 * 0.7))
    ax.plot(x_vals, mean_vals, color='black', label=r'$\mu$')
    ax.fill_between(x_vals, lower, upper, color='#a02b92', alpha=0.2, label=r'$\mu \pm 1.96\sigma$')

    ax.set_title('Supply price analysis', fontname='Inter', fontsize=8)
    ax.set_xlabel('Cumulative Potential [kt/a]', fontname='Inter')
    ax.set_ylabel('Production costs [NOK/t]', fontname='Inter')
    ax.legend(loc='upper left')
    fig.tight_layout()

    # Save before show(): show() may close the figure and leave savefig with an empty canvas.
    # Path is anchored at pdir like the per-run plots, so it does not depend on the working directory.
    fig.savefig(pdir + 'nodes_primary_biomass_resources/plots/supply_price_analysis_cumulative.png', dpi=600)
    plt.show()