#!/usr/bin/env python

# ----------------------------------------------------------------------------------- #
#
#  Python macro for selecting b-jets in Aleph Z->qqbar MC - using tree based ML algorithms.
#
#  Author: Troels C. Petersen (NBI)
#  Email:  petersen@nbi.dk
#  Date:   22nd of April 2020
#
# ----------------------------------------------------------------------------------- #

from __future__ import print_function, division   # Ensures Python3 printing & division standard
from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.colors import LogNorm
import numpy as np
import csv

from sklearn.ensemble import AdaBoostClassifier
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# Possible other packages to consider:
# cornerplot, seaplot, sklearn.decomposition(PCA)

# Short alias for NumPy's random module. Seeding its global generator with a fixed
# value makes everything that draws random numbers from it repeatable between runs.
r = np.random
r.seed(42)

# SavePlots is checked before each fig.savefig call: figures are only written to file when it is True.
SavePlots = False          # For now, don't save plots (once you trust your code, switch on)
plt.close('all')           # To close all open figures


# ----------------------------------------------------------------------------------- #
# Evaluate (made into a function, as this is called many times):
# ----------------------------------------------------------------------------------- #
def evaluate(bquark, truth=None):
    """Count agreements and disagreements between a b-jet selection and the truth.

    Parameters
    ----------
    bquark : sequence of 0/1
        Estimated label for each jet (1: selected as b-jet, 0: not selected).
    truth : sequence of 0/1, optional
        True label for each jet. When omitted, the module-level array ``isb``
        is used, so the original one-argument call keeps working.

    Returns
    -------
    N : list of lists of int
        2x2 matrix of counts indexed as ``N[estimate][truth]``:
        ``N[0][0]`` true negatives, ``N[0][1]`` false negatives,
        ``N[1][0]`` false positives and ``N[1][1]`` true positives.
        Entries that are neither 0 nor 1 are not counted in any cell.
    fracWrong : float
        ``(N[0][1] + N[1][0])`` divided by the number of jets
        (``nan`` when there are no jets at all).

    Raises
    ------
    ValueError
        If the estimate and the truth do not have the same shape.
    """
    if truth is None:
        truth = isb          # Fall back on the truth column of the full data sample

    estimate = np.asarray(bquark)
    truth = np.asarray(truth)
    if estimate.shape != truth.shape:
        raise ValueError("Estimate and truth must have the same shape, got %s and %s"
                         % (estimate.shape, truth.shape))

    # Vectorized counting: one boolean mask per (estimate, truth) combination.
    N = [[int(np.count_nonzero((estimate == est) & (truth == tru))) for tru in (0, 1)]
         for est in (0, 1)]

    n_jets = len(truth)
    if n_jets == 0:
        return N, float('nan')   # A fraction of nothing is undefined (avoids ZeroDivisionError)

    fracWrong = float(N[0][1] + N[1][0]) / float(n_jets)
    return N, fracWrong



# ----------------------------------------------------------------------------------- #
# Main program start:
# ----------------------------------------------------------------------------------- #

# Get data:
# ----------------------------------------------------------------------------------- #
# The file is read (skipping its first line) into one plain 2D float array, which is the
# convenient form for the ML part below. An alternative giving named columns would be
# np.genfromtxt('AlephBtag_MC_small_v2.csv', names=True).
data = np.loadtxt('AlephBtag_MC_small_v2.csv', skiprows=1)

# Feature variables (columns 0-8 of the file, in this order):
(energy, cTheta, phi,
 prob_b, spheri, pt2rel,
 multip, bqvjet, ptlrel) = (data[:, column] for column in range(9))

# Competitor variable (column 9): output of the Aleph NN from the mid 1990s,
# which had two hidden layers of 10 neurons each.
nnbjet = data[:, 9]

# Target variable (column 10):
isb = data[:, 10]


# Define and train a simple model (AdaBoostClassifier from SciKit-Learn):
# ----------------------------------------------------------------------------------- #
# Inputs: all columns except the last two, which hold the "competition" (NN score)
# and the truth/target and must therefore not be given to the model.
X = data[:, :-2]
# Target: the last column, i.e. the truth about the origin of the jets
# (0: background light jets, 1: signal b-jets).
Y = data[:, -1]

# Keep 20% of the jets aside for testing (a training share of 50-80% is typical):
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)
print("  Size of training sample: ", X_train.shape[0])
print("  Size of testing sample:  ", X_test.shape[0])

# Build the model with 50 trees and train it on the training sample in one go.
# Note that more trees is not necessarily better, cf. overtraining!
model_ABC = AdaBoostClassifier(n_estimators=50).fit(X_train, Y_train)

# Let us plot the ML output distributions:
# The score of a jet is the model's predicted probability for class 1 (b-jet). It is
# computed once for the whole test sample and then split by the true label.
scores_test = model_ABC.predict_proba(X_test)[:, 1]

fig, ax = plt.subplots(figsize=(12, 6))      # Create just a single figure and axes (figsize is in inches!)
hist_bjets = ax.hist(scores_test[Y_test==1], bins=48, range=(0.44,0.56), histtype='step', linewidth=2, label='True b-jets', color='red')
hist_ljets = ax.hist(scores_test[Y_test==0], bins=48, range=(0.44,0.56), histtype='step', linewidth=2, label='True l-jets', color='blue')
ax.set_xlabel("Prediction score from ML model")          # Label of x-axis
ax.set_ylabel("Frequency / 0.0025")                       # Label of y-axis (bin width = 0.12 / 48)
ax.set_title("Distribution of ML model scores")          # Title of plot
ax.legend(loc=(0.75, 0.75), fontsize=20)                 # Legend

# Mark the decision threshold at a score of 0.50 with a dashed vertical line from y=0 to y=500.
# Everything is drawn through "ax", so it ends up in this figure regardless of which
# figure happens to be the current one.
point1 = [0.50,   0.0]
point2 = [0.50, 500.0]
X_values = [point1[0], point2[0]]
Y_values = [point1[1], point2[1]]
ax.plot(X_values, Y_values, color="gray", linestyle="dashed")

ax.text(0.51, 400.0, "Threshold", size=20, ha='center', va='center', color="gray")

fig.tight_layout()
fig.show()
if SavePlots :
    fig.savefig('MLmodelPredictions_AdaBoostClassifier.pdf', dpi=600)


# Evaluate the model:
# ----------------------------------------------------------------------------------- #

# SciKit-Learn has built-in functions for the scoring (choosing 0.5 as the threshold).
# All numbers below are obtained on the test sample only.
Accuracy = model_ABC.score(X_test, Y_test)
print('\n  Accuracy: ', Accuracy, "      Fraction Wrong: ", 1.0-Accuracy)

# Predicted labels of the test sample, used for both the confusion matrix and the report:
Y_pred = model_ABC.predict(X_test)
print('\n  Confusion matrix:\n', confusion_matrix(Y_test, Y_pred))

# There is also a function producing a "full report" in one call:
print(classification_report(Y_test, Y_pred, digits=3))




# ----------------------------------------------------------------------------------- #
# Compare with NN-approach from the 1990s (note the different threshold):
# ----------------------------------------------------------------------------------- #

# Tag a jet as b-jet (1) when its NN score is above 0.82, and as non-b (0) otherwise.
# The comparison is done on the whole array at once, and converted back to a plain
# list of ints. Note that this selection is evaluated on all jets in the file.
bquark = (nnbjet > 0.82).astype(int).tolist()

N, fracWrong = evaluate(bquark)
print("\nALEPH BJET TAG:")
print("  First number is my estimate, second is the MC truth:")
print("  True-Negative (0,0)  = ", N[0][0])
print("  False-Negative (0,1) = ", N[0][1])
print("  False-Positive (1,0) = ", N[1][0])
print("  True-Positive (1,1)  = ", N[1][1])
print("    Fraction wrong = ( (0,1) + (1,0) ) / sum = ", fracWrong)


# If you want to execute your scripts both in Spyder and through the terminal,
# the snippet below pauses scripts running in the terminal, so your figures
# don't immediately close.
# If you are only going to use Spyder, just remove or ignore the part below.

try:
    __IPYTHON__                              # Only defined when running under IPython (e.g. Spyder)
except NameError:
    # Plain terminal session: wait for the user before the script (and its figures) end.
    try:
        wait_for_enter = raw_input           # Python 2
    except NameError:
        wait_for_enter = input               # Python 3, where raw_input no longer exists
    try:
        wait_for_enter('Press Enter to exit')
    except EOFError:
        pass                                 # No interactive input available, so just exit

