#!/usr/bin/env python

# ----------------------------------------------------------------------------------- #
#
#  Python macro for selecting b-jets in Aleph Z->qqbar MC:
#
#  Author: Troels C. Petersen (NBI)
#  Email:  petersen@nbi.dk
#  Date:   20th of April 2020
#
# ----------------------------------------------------------------------------------- #

from __future__ import print_function, division   # Ensures Python3 printing & division standard
from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.colors import LogNorm
import numpy as np
import csv

# Possible other packages to consider:
# cornerplot, seaplot, sklearn.decomposition(PCA)

plt.close('all')           # Start from a clean slate: close any figures that are still open
SavePlots = False          # Figures are only written to file when this is switched on

# Seed NumPy's global random number generator, so that any random numbers drawn
# through it are the same from one run to the next:
r = np.random
r.seed(42)




# ----------------------------------------------------------------------------------- #
# Evaluate (made into a function, as this is called many times):
# ----------------------------------------------------------------------------------- #
def evaluate(bquark, truth=None) :
    """Compare a b-jet selection with the truth and count the four possible outcomes.

    Parameters
    ----------
    bquark : sequence of 0/1
        The selection under test, one entry per jet (1 = called a b-jet, 0 = not).
    truth : sequence of 0/1, optional
        The true flavour of each jet (1 = b-jet, 0 = not). When omitted, the
        module-level array `isb` is used, which is how the function was called
        before this argument existed.

    Returns
    -------
    N : list of two lists of int
        Counts indexed as N[estimate][truth], i.e. N[0][0] are true negatives,
        N[0][1] false negatives, N[1][0] false positives and N[1][1] true positives.
    fracWrong : float
        (N[0][1] + N[1][0]) divided by the number of jets in `truth`.
        For an empty `truth` this is 0.0 (instead of a division by zero).

    Raises
    ------
    IndexError
        If `bquark` has fewer entries than `truth`.
    """
    if truth is None :
        truth = isb      # Default to the truth array loaded by the main program

    N = [[0,0], [0,0]]   # Make a list of lists (i.e. matrix) for counting successes/failures.
    for i, real in enumerate(truth) :
        guess = bquark[i]
        # Only clean 0/1 pairs are counted (anything else, e.g. NaN, is left out of the
        # matrix but still counts in the denominator below, exactly as before):
        if guess in (0, 1) and real in (0, 1) :
            N[int(guess)][int(real)] += 1

    Njets = len(truth)
    if Njets == 0 :
        return N, 0.0    # No jets, hence no wrong decisions - and nothing to divide by
    fracWrong = float(N[0][1]+N[1][0])/float(Njets)
    return N, fracWrong




# ----------------------------------------------------------------------------------- #
# Main program start:
# ----------------------------------------------------------------------------------- #

# Read the input file:
# ----------------------------------------------------------------------------------- #
# With names=True, genfromtxt takes the column names from the header line of the file,
# so that every column can be picked out of the resulting structured array by its name.
data = np.genfromtxt('AlephBtag_MC_small_v2.csv', names=True)

# Give each column its own variable (the two lists below are in the same order):
(energy, cTheta, phi, prob_b, spheri, pt2rel,
 multip, bqvjet, ptlrel, nnbjet, isb) = (
    data[column] for column in ('energy', 'cTheta', 'phi', 'prob_b', 'spheri', 'pt2rel',
                                'multip', 'bqvjet', 'ptlrel', 'nnbjet', 'isb'))


# Produce 1D figures:
# ----------------------------------------------------------------------------------- #

# Define the histogram range and binning (important - MatPlotLib is NOT good at this):
Nbins = 100
xmin = 0.0
xmax = 1.0
binwidth = (xmax - xmin) / Nbins     # Used in the y-axis label, so that it always matches the binning

# Make new lists selected based on what the jets really are (b-quark jet or light-quark jet).
# Jets with isb equal to 1 are the b-jets, all others are put in the light-jet lists:
prob_b_bjets = [value for value, truth in zip(prob_b, isb) if truth == 1]
prob_b_ljets = [value for value, truth in zip(prob_b, isb) if not (truth == 1)]
bqvjet_bjets = [value for value, truth in zip(bqvjet, isb) if truth == 1]
bqvjet_ljets = [value for value, truth in zip(bqvjet, isb) if not (truth == 1)]

# Produce the actual figure, here with two histograms in it:
fig, ax = plt.subplots(figsize=(12, 6))      # Create just a single figure and axes (figsize is in inches!)
hist_prob_b_bjets = ax.hist(prob_b_bjets, bins=Nbins, range=(xmin, xmax), histtype='step', linewidth=2, label='prob_b_bjets', color='blue')
hist_prob_b_ljets = ax.hist(prob_b_ljets, bins=Nbins, range=(xmin, xmax), histtype='step', linewidth=2, label='prob_b_ljets', color='red')
ax.set_xlabel("Probability of b-quark based on track impact parameters")     # Label of x-axis
ax.set_ylabel("Frequency / %g" % binwidth)                                   # Label of y-axis (follows the binning above)
ax.set_title("Distribution of prob_b")                                       # Title of plot
ax.legend(loc='best')                                                        # Legend. Could also be 'upper right'
ax.grid(axis='y')

fig.tight_layout()
fig.show()

if SavePlots :
    fig.savefig('Hist_prob_b_and_bqvjet.pdf', dpi=600)



# Produce 2D figures:
# ----------------------------------------------------------------------------------- #

# First we try a scatter plot, to see how the individual events distribute themselves:
fig2, ax2 = plt.subplots(figsize=(12, 6))
scat2_prob_b_vs_bqvjet_bjets = ax2.scatter(prob_b_bjets, bqvjet_bjets, label='b-jets', color='blue')
scat2_prob_b_vs_bqvjet_ljets = ax2.scatter(prob_b_ljets, bqvjet_ljets, label='l-jets', color='red')
ax2.set_xlabel("prob_b")                     # Without axis labels one cannot tell which variable is where
ax2.set_ylabel("bqvjet")
ax2.set_title("bqvjet vs. prob_b")
ax2.legend(loc='best')
fig2.tight_layout()
fig2.show()

if SavePlots :
    fig2.savefig('Scatter_prob_b_vs_bqvjet.pdf', dpi=600)


# However, as can be seen in the figure, the overlap between b-jets and light-jets is large,
# and one covers much of the other in a scatter plot, which also does not show the amount of
# statistics in the dense regions. Therefore, we try two separate 2D histograms (zoomed):
fig3, ax3 = plt.subplots(1, 2, figsize=(12, 6))
hist2_prob_b_vs_bqvjet_bjets = ax3[0].hist2d(prob_b_bjets, bqvjet_bjets, bins=[40,40], range=[[0.0, 0.4], [0.0, 0.4]])
hist2_prob_b_vs_bqvjet_ljets = ax3[1].hist2d(prob_b_ljets, bqvjet_ljets, bins=[40,40], range=[[0.0, 0.4], [0.0, 0.4]])
ax3[0].set_title("b-jets")
ax3[1].set_title("light-jets")

# Label the axes and add a colour scale to each panel, so that the number of jets per bin
# can actually be read off (hist2d returns the drawn image as the fourth element):
for axis, histogram in zip(ax3, (hist2_prob_b_vs_bqvjet_bjets, hist2_prob_b_vs_bqvjet_ljets)) :
    axis.set_xlabel("prob_b")
    axis.set_ylabel("bqvjet")
    fig3.colorbar(histogram[3], ax=axis)

fig3.tight_layout()
fig3.show()

if SavePlots :
    fig3.savefig('Hist2D_prob_b_vs_bqvjet.pdf', dpi=600)



# Selection:
# ----------------------------------------------------------------------------------- #

# The cut values are named, so that each of them is defined (and changed) in ONE place only:
loose_propb = 0.04
tight_propb = 0.38
loose_bqvjet = 0.02
tight_bqvjet = 0.89

# A jet is called a b-quark jet (1) if either variable on its own passes its tight cut,
# or if both variables pass their loose cuts. Otherwise it is not (0):
bquark = [
    1 if (p > tight_propb
          or q > tight_bqvjet
          or (p > loose_propb and q > loose_bqvjet)) else 0
    for p, q in zip(prob_b, bqvjet)
]

# Score the selection against the truth, and print the resulting 2x2 table.
# N is indexed as N[my estimate][MC truth]:
N, fracWrong = evaluate(bquark)

print("\nRESULT OF HUMAN ATTEMPT AT A GOOD SELECTION:")
print("  First number is my estimate, second is the MC truth:")
for text, number in (
    ("  True-Negative (0,0)  = ", N[0][0]),
    ("  False-Negative (0,1) = ", N[0][1]),
    ("  False-Positive (1,0) = ", N[1][0]),
    ("  True-Positive (1,1)  = ", N[1][1]),
    ("    Fraction wrong = ( (0,1) + (1,0) ) / sum = ", fracWrong),
) :
    print(text, number)



# If you want to execute your scripts in Spyder and normally through the terminal,
# you can use the snippet below to pause scripts running in the terminal,
# so your figures don't immediately close.
# If you are only going to use Spyder, just remove or ignore the part below

try:
    __IPYTHON__                          # Only defined when running inside IPython (e.g. Spyder)
except NameError:
    # Plain terminal: wait for the user. The function that reads a line from the keyboard
    # is called raw_input in Python 2 and input in Python 3, so pick whichever exists:
    try:
        wait_for_enter = raw_input
    except NameError:
        wait_for_enter = input
    wait_for_enter('Press Enter to exit')

