#!/usr/bin/env python

# ----------------------------------------------------------------------------------- #
#
#  ROOT macro for constructing a Fisher discriminant from two correlated parameters
#
#  References:
#    Glen Cowan, SDA, pages 51-57
#    http://en.wikipedia.org/wiki/Iris_flower_data_set
#    http://en.wikipedia.org/wiki/Linear_discriminant_analysis
#
#  Author: Lars Egholm Pedersen 
#  Email:  egholm@nbi.dk
#  Date:   5th of October 2013
#
#  Create some artificial random data for two different types of processes
#  and construct a Fisher discriminant from it.
#
# ----------------------------------------------------------------------------------- #

# ----------------------------------------------------------------------------------- #
# Imports
# ----------------------------------------------------------------------------------- #
from ROOT import *

# ----------------------------------------------------------------------------------- #
# Functions
# ----------------------------------------------------------------------------------- #
def sqr( a ) :
    """Return the square of ``a``, computed as ``a * a``."""
    product = a * a
    return product

#Function for generating a set of correlated random gaussian numbers...
def get_corr( mu1, sig1, mu2, sig2, rho12 ) : 
    r = TRandom3(0) # Just gonna initialize random generator in function for simplicity
                    # Be carefull ALWAYS to use random seed that is changing faster
                    # than you are calling the function!!

    theta = 0.5 * atan( 2.0 * rho12 * sig1 * sig2 / ( sqr(sig1) - sqr(sig2) ) )
    sigu = sqrt( fabs( (sqr(sig1*cos(theta)) - sqr(sig2*sin(theta)) ) / ( sqr(cos(theta)) - sqr(sin(theta))) ) )
    sigv = sqrt( fabs( (sqr(sig2*cos(theta)) - sqr(sig1*sin(theta)) ) / ( sqr(cos(theta)) - sqr(sin(theta))) ) )

    u = r.Gaus( 0.0, sigu )
    v = r.Gaus( 0.0, sigv )

    x = mu1 + cos(theta)*u - sin(theta)*v
    y = mu2 + sin(theta)*u + cos(theta)*v

    return [x,y]

# ----------------------------------------------------------------------------------- #
# Fisher discriminant script.
# ----------------------------------------------------------------------------------- #
# Reset ROOT's global state before building anything.
gROOT.Reset()

# Statistics box: "emr" = entries, mean and RMS.
# Fit box: 1111 = probability, chi2/ndf, errors and parameter values.
gStyle.SetOptFit(1111)
gStyle.SetOptStat("emr")

# Colour index 0 (white in ROOT's default palette) for fills and canvases.
gStyle.SetFillColor(0)
gStyle.SetCanvasColor(0)

# ----------------------------------------------------------------------------------- #
# Define parameters
# ----------------------------------------------------------------------------------- #

nspec = 2  # number of 'species' (process types): signal / background

# Per-species generation settings: index 0 = process type 1, index 1 = type 2.
#              type 1  type 2
mean_par0  = [ 15.0,   12.0 ]  # mean  in x (Parameter_a)
width_par0 = [  2.0,    3.0 ]  # width in x
mean_par1  = [ 50.0,   55.0 ]  # mean  in y (Parameter_b)
width_par1 = [  6.0,    7.0 ]  # width in y
corr       = [  0.80,   0.90 ] # x-y correlation coefficient

ndata = 2000  # number of (x, y) pairs generated for each species

color_index  = [ 2, 4 ]    # ROOT colours: 2 = red (type 1), 4 = blue (type 2)
marker_index = [ 24, 25 ]  # marker styles for the correlation plot

par_name = [ "Parameter_a", "Parameter_b" ]  # axis titles for x and y
draw_opt = [ "", "same" ]  # first species draws the frame, the second overlays it

# ----------------------------------------------------------------------------------- #
# Define all your graphics here
# ----------------------------------------------------------------------------------- #

# Book one histogram per species for every quantity; each histogram's name
# doubles as its title.
hist_par0 = [ TH1D( hname, hname, 50, 0.0, 25.0 )                    # x projection ("par0")
              for hname in ( "hist_par0_spec0", "hist_par0_spec1" ) ]

hist_par1 = [ TH1D( hname, hname, 50, 20.0, 80.0 )                   # y projection ("par1")
              for hname in ( "hist_par1_spec0", "hist_par1_spec1" ) ]

hist_corr = [ TH2D( hname, hname, 50, 0.0, 25.0, 50, 20.0, 80.0 )    # x versus y
              for hname in ( "corr_par0_par1_spec0", "corr_par0_par1_spec1" ) ]

hist_fisher = [ TH1D( hname, hname, 200, -2.0, 2.0 )                 # final Fisher projection
                for hname in ( "hist_fisher_spec0", "hist_fisher_spec1" ) ]

# ----------------------------------------------------------------------------------- #
# Generate data and fill initial histograms
# ----------------------------------------------------------------------------------- #

for ispec in range( nspec ) :
    # Histograms that receive this species' samples.
    h_x, h_y, h_xy = hist_par0[ispec], hist_par1[ispec], hist_corr[ispec]

    for iexp in range( ndata ) :
        # Draw one linearly correlated (x, y) pair for this species.
        values = get_corr( mean_par0[ispec], width_par0[ispec],
                           mean_par1[ispec], width_par1[ispec], corr[ispec] )
        x_val, y_val = values

        h_x.Fill( x_val )
        h_y.Fill( y_val )
        h_xy.Fill( x_val, y_val )

# ----------------------------------------------------------------------------------- #
# Specify color etc of your histograms
# ----------------------------------------------------------------------------------- #

for ispec in range( nspec ) :
    colour = color_index[ispec]

    # 1D histograms: species colour on the line, quantity name on the x axis.
    for hist, title in ( ( hist_par0[ispec],   par_name[0] ),
                         ( hist_par1[ispec],   par_name[1] ),
                         ( hist_fisher[ispec], "Fisher Discriminant" ) ) :
        hist.SetLineColor( colour )
        hist.GetXaxis().SetTitle( title )

    # 2D correlation histogram: coloured markers, both axes titled.
    corr_hist = hist_corr[ispec]
    corr_hist.SetMarkerColor( colour )
    corr_hist.SetMarkerStyle( marker_index[ispec] )
    corr_hist.GetXaxis().SetTitle( par_name[0] )
    corr_hist.GetYaxis().SetTitle( par_name[1] )

# ----------------------------------------------------------------------------------- #
# Plot your generated data
# ----------------------------------------------------------------------------------- #

# x and y projections: one pad per parameter, species overlaid with "same".
canvas_1D = TCanvas("canvas_1D", "canvas_1D", 1200, 900)
canvas_1D.Divide(2)

# The first histogram drawn on a pad fixes that pad's y-axis range. Raise its
# maximum to cover the tallest species; otherwise the overlaid histogram is
# clipped whenever it peaks higher (e.g. after changing the widths).
for hists in ( hist_par0, hist_par1 ) :
    hists[0].SetMaximum( 1.1 * max( h.GetMaximum() for h in hists ) )

for ispec in range(nspec) :
    canvas_1D.cd( 1 )
    hist_par0[ispec].Draw( draw_opt[ispec] )
    canvas_1D.cd( 2 )
    hist_par1[ispec].Draw( draw_opt[ispec] )

# Repaint now so the plots are visible while the script waits for input.
canvas_1D.Update()

# Correlation plot
canvas_2D = TCanvas("canvas_2D", "canvas_2D", 1200, 900)

for ispec in range(nspec) :
    hist_corr[ispec].Draw( draw_opt[ispec] )

canvas_2D.Update()



# ----------------------------------------------------------------------------------- #
# start calculating discriminant here...
# ----------------------------------------------------------------------------------- #



# # ----------------------------------------------------------------------------------- #
# # Calculate means and widths of individual parameters
# # ----------------------------------------------------------------------------------- #
# 
# mu_par0  = [0.0, 0.0]
# mu_par1  = [0.0, 0.0]
# 
# # ----------------------------------------------------------------------------------- #
# # Now we need to get the sum of the covariance matrices
# # ----------------------------------------------------------------------------------- #
# 
# covmat_sum = TMatrixD(2,2) # This will need to be generalized to more parameters 
#                            # The diagonal terms are the variances of the 1d histograms
#                            # While the off diagonal terms are the covariances of the
#                            # 2D histogram
# 
# covmat_sum.Invert() #Invert matrix
# 
# # ----------------------------------------------------------------------------------- #
# # Ready to calculate the fisher weights
# # ----------------------------------------------------------------------------------- #
# 
# # Multiply the inverted covariance matrix by the (mu_Null - mu_alt) difference
# wf = [???]
# 
# # ----------------------------------------------------------------------------------- #
# # Generate some independent data (but with same specs) as before and apply fisher
# # ----------------------------------------------------------------------------------- #
# 
# for ispec in range( nspec ) : 
#     for iexp in range( ndata ) : 
#         values = get_corr( mean_par0[ispec], width_par0[ispec], mean_par1[ispec], width_par1[ispec], corr[ispec] )
#         hist_fisher[ispec].Fill( ??? )
# 
# # ----------------------------------------------------------------------------------- #
# # Finally go draw the resulting distribution
# # ----------------------------------------------------------------------------------- #
# 
# # x and y projections
# canvas_fisher = TCanvas("canvas_fisher", "canvas_fisher", 1200, 900)
# 
# for ispec in range(nspec) : 
#     hist_fisher[ispec].Draw( draw_opt[ispec] )
# 
# canvas_fisher.Update()
# canvas_fisher.Draw()

# Keep the canvases on screen until <Enter> is pressed. raw_input is the
# Python 2 name; Python 3 renamed it to input (raw_input raises NameError there).
import sys
_pause = raw_input if sys.version_info[0] < 3 else input
_pause(" ... ")

# ----------------------------------------------------------------------------------- #
# Questions
# ----------------------------------------------------------------------------------- #
#
# 
# Looking at the 2D plot, you can by eye see some separation. 
# But how would you quantify this? 
# In the following we will be working towards a Fisher linear discriminant, 
# which does exactly this!
#
#
# As a measure of how good the separation obtained is, we consider the
# "distance" between the two distributions as a measure of goodness:
#   separation = (mu_NULL - mu_ALT)**2 / (sigma_NULL**2 + sigma_ALT**2)
#
# What separation do you get from the two 1D histograms of par0 and par1 using the above formula?
#
#
#
# Looking at the correlation plot. Can you give a guess of what the Fisher
# Weights should turn out to be? (roughly)
# ... or continue, on the questions below, which will lead you to the answer
#
#
# Uncomment the steps that go into calculating the discriminant and
# fill in the blanks (the "???" placeholders).
#
# For references on how to do this, look e.g. at today's lectures or the
# Wikipedia article that has been linked under week 6.
#
# 
# What is the separation that you get from the discriminant as compared to the 
# one dimensional histograms?
#
# ---------------------------------------------------------------------------
# Optional below. Don't use more than 10-15 minutes on this
# ---------------------------------------------------------------------------
#
# Try to change the mean, widths and correlations of the generated samples.
# When does the discriminant seem to perform optimally?
# Can you construct cases where it does not work?
# 
# Change the number of data points to 100 and see what happens.
# Are the weights you calculate constant?
#
# ----------------------------------------------------------------------------------- #
