#################################################################################################################################################################################################################################################
# AUTHOR: Matthias Maier
# Task: Piecewise linear regression
#################################################################################################################################################################################################################################################

import pwlf
import numpy as np
import matplotlib.pyplot as plt


def _rounded(values, ndigits=4):
    """Return the array-like *values* as a plain list of floats rounded to *ndigits* decimals."""
    return [round(v, ndigits) for v in np.asarray(values, dtype=float).tolist()]


def estimate(x, y, n_pieces, x_label, y_label, title='', constraints=None, print_output=True, xy=None,
             n_plot_points=50):
    """
    Fit a piecewise linear function with n_pieces segments to (x, y) using pwlf, optionally plot and print it.

    :param x: The list of x values
    :param y: The list of y values
    :param n_pieces: The number of pieces (line segments) to fit
    :param x_label: The x label (used for the plot axis and the printed equation)
    :param y_label: The y label (used for the plot axis and the printed equation)
    :param title: The title of the plot
    :param constraints: A tuple (x_c, y_c) of point(s) the fit has to go through. Each element may be a
        scalar (a single point) or a sequence (several points), e.g. (1, 1) or ([1], [1]).
    :param print_output: If True, show a plot of the data and the fit, and print the pieces
        (slopes, intercepts, breakpoints) and the R2 score
    :param xy: Position of the R2 annotation in axes-fraction coordinates; defaults to (0.95, 0.05)
    :param n_plot_points: Number of evenly spaced x values used to draw the fitted line; the breakpoints
        are always added to them so the kinks of the line are drawn exactly
    :return: A tuple (breakpoints, slopes, intercepts, r2) - three lists of floats and the R2 score
    :raises TypeError: If constraints is given but is not a tuple
    :raises ValueError: If constraints does not contain exactly two elements
    """

    # Estimation
    model = pwlf.PiecewiseLinFit(x, y)

    if constraints is None:
        breakpoints = model.fit(n_pieces)
    else:
        # Explicit checks instead of `assert`, which is stripped when running under `python -O`.
        if not isinstance(constraints, tuple):
            raise TypeError('constraints must be a tuple (x_c, y_c), got {}'.format(type(constraints).__name__))
        if len(constraints) != 2:
            raise ValueError('constraints must have exactly two elements (x_c, y_c), got {}'.format(len(constraints)))

        x_constraint, y_constraint = constraints
        # pwlf expects array-like constraint coordinates; atleast_1d also lets a single scalar point through.
        breakpoints = model.fit(n_pieces, np.atleast_1d(x_constraint), np.atleast_1d(y_constraint))

    # calc_slopes() is what makes model.intercepts available for the output below.
    slopes = model.calc_slopes()
    score = model.r_squared()

    if print_output:
        # Predict on an even grid merged with the breakpoints (union1d sorts and deduplicates), so the
        # drawn line passes through every kink instead of cutting corners between grid points.
        x_hat = np.union1d(np.linspace(np.min(x), np.max(x), n_plot_points), breakpoints)
        y_hat = model.predict(x_hat)

        # Plot
        fig, ax = plt.subplots(figsize=(9 * 0.7, 6 * 0.7))
        ax.scatter(x, y, s=1, color='black', label='Nonlinear function')
        ax.plot(x_hat, y_hat, color='#a02b92', label='Piecewise linear estimate')
        ax.set_title(title, fontname='Inter', fontsize=8)
        ax.set_xlabel(x_label, fontname='Inter')
        ax.set_ylabel(y_label, fontname='Inter')
        ax.annotate('Piecewise linear prediction: R2 = {:.4f}'.format(score),
                    xy=(0.95, 0.05) if xy is None else xy, xycoords='axes fraction', fontsize=9,
                    horizontalalignment='right', verticalalignment='bottom', fontname='Inter')
        ax.legend(loc='upper left')
        fig.tight_layout()
        plt.show()

        print('Piecewise linear prediction: y=ax+b ({} pieces)'.format(n_pieces))
        print('\t{} = a * {} + b'.format(y_label, x_label))
        print('\ta Values: {}'.format(_rounded(slopes)))
        print('\tb Values: {}'.format(_rounded(model.intercepts)))
        print('\tBreakpoints: {}'.format(_rounded(breakpoints)))
        print('\tR2 = {:.4f}'.format(score))

    return breakpoints.tolist(), slopes.tolist(), model.intercepts.tolist(), score

if __name__ == '__main__':
    # Demo: approximate the power law (S/S0)^0.67 on [0.75, 1.25] with two linear
    # pieces, forcing the fitted line through the single point (1, 1).
    ratios = np.linspace(0.75, 1.25, 20)  # S/S0
    powered = np.power(ratios, 0.67)  # (S/S0)^0.67

    anchor = ([1], [1])  # (x_c, y_c): one constraint point at x=1, y=1

    estimate(ratios, powered, 2, 'Independent variable', 'Dependent variable', constraints=anchor)