#!/usr/bin/env python

# ----------------------------------------------------------------------------------- #
#
#  PyROOT script for training different types of TMVA discriminants
#
#  Author: Lars Egholm Pedersen 
#  Email:  egholm@nbi.dk
#  Date:   5th of October 2013
#
#
# ----------------------------------------------------------------------------------- #

from ROOT import *

#Main call here
if __name__ == '__main__':

    # Everything below is given as [signal, background]: input files,
    # input trees and selection cuts.
    # Exactly one input configuration should be active; the alternatives are
    # kept commented out so they can be swapped in for the exercises below.

    # ----------------------------------------------------------------------------------- #
    # Define variables that go into training here: Fisher Iris data
    infile = [ TFile("iris_data_set.root", "READ"),
               TFile("iris_data_set.root", "READ") ]

    intree = [infile[0].Get("tree") , infile[1].Get("tree") ]

    # Variable name -> type character handed to factory.AddVariable below
    vardict = { "sepal_length" : 'f' ,
                "sepal_width"  : 'f' ,
                "petal_length" : 'f' ,
                "petal_width"  : 'f' }

    cut    = ["flower_type == 1", "flower_type == 2"] # Signal / background criteria


#    # ----------------------------------------------------------------------------------- #
#    # Define variables that go into training here: For heart disease data
#    infile = [ TFile("south_african_heart_disease.root", "READ"),
#               TFile("south_african_heart_disease.root", "READ") ]
#
#    intree = [infile[0].Get("tree") , infile[1].Get("tree") ]
#
#    vardict = { "sbp"       : 'f' ,
#                "tobacco"   : 'f' ,
#                "ldl"       : 'f' ,
#                "adiposity" : 'f' ,
#                "famhist"   : 'f' ,
#                "typea"     : 'f' ,
#                "obesity"   : 'f' ,
#                "alcohol"   : 'f' ,
#                "age"       : 'f' } # Define variables and type of variable
#
#    cut    = ["chd == 0", "chd == 1"]
#

#    # ----------------------------------------------------------------------------------- #
#    # Define variables that go into training here: For Higgs / ZZ data
#    infile = [ TFile("Higgs14TeV.root", "READ"),
#               TFile("ZZ14TeV.root", "READ") ]
#
#    intree = [infile[0].Get("Default_M125") , infile[1].Get("Default") ]
#
#    vardict = { "cts"  : 'f' ,
#                "phi1" : 'f' ,
#                "ct1"  : 'f' ,
#                "ct2"  : 'f' ,
#                "phi"  : 'f' ,
#                "mZ1"  : 'f' ,
#                "mZ2"  : 'f' } # Define variables and type of variable
#
#    cut    = ["", ""] # No need to cut here: two separate files

    # ----------------------------------------------------------------------------------- #
    # Sanity-check the inputs before handing them to TMVA. A file that failed
    # to open is a "zombie" TFile, and Get() on a missing key returns a null
    # object; either would otherwise only show up as an obscure TMVA failure.
    for label, rootfile, tree in zip(("signal", "background"), infile, intree):
        if rootfile.IsZombie():
            raise IOError("Could not open %s input file '%s'" % (label, rootfile.GetName()))
        if not tree:
            raise IOError("Could not find the %s tree in '%s'" % (label, rootfile.GetName()))

    # ----------------------------------------------------------------------------------- #
    # TMVA is called from here...

    # Define factory options
    factoryOption = "!V:!Silent:Transformations=I;P:AnalysisType=Classification"

    # TMVA title: also used to name the output file below
    tmvatitle = "TMVAClassifier"

    # Start training here:
    output  = TFile("./tmva." + tmvatitle + ".root", "RECREATE" )
    factory = TMVA.Factory( tmvatitle, output, factoryOption )

    factory.AddSignalTree(     intree[0], 1.0 ) # Tell TMVA where it should find the signal
    factory.AddBackgroundTree( intree[1], 1.0 ) # and background trees

    # Add every training variable (name, title, unit, type) to the factory
    for varname, vartype in vardict.items():
        factory.AddVariable( varname, varname, "", vartype )

    # Tell the factory how to split events into training and test samples.
    # nTrain_*=0 leaves the split size to TMVA (see the TMVA Users Guide for
    # the default used by your ROOT version).
    factory.PrepareTrainingAndTestTree( TCut(cut[0]), TCut(cut[1]),
                                        "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V" )

    # Book the classifiers to train (uncomment more methods for the exercises below)
    factory.BookMethod( TMVA.Types.kFisher   , "Fisher"  , "!H:!V:Fisher:VarTransform=None")
#    factory.BookMethod( TMVA.Types.kBDT      , "BDT"     , "!H:!V:NTrees=500:NNodesMax=10" )
#    factory.BookMethod( TMVA.Types.kCFMlpANN , "CFMlpANN", "!H:!V:NCycles=200:HiddenLayers=N+1,N" )

    factory.TrainAllMethods()
    factory.TestAllMethods()
    factory.EvaluateAllMethods()

    # Close the output file explicitly (as ROOT's TMVAClassification.py tutorial
    # does), so it is finalized on disk instead of relying on cleanup at
    # interpreter exit. The TMVA GUI reads this file afterwards.
    output.Close()


# --------------------------------------------------------------------------------
# Questions
# --------------------------------------------------------------------------------
# 
# The script has initially been set up to construct a Fisher discriminant on the
# Iris dataset from last time
#
# Try to run it by ./train_discriminant.py
#
# TMVA will also tell you the separation power of the individual variables:
#
# --- IdTransformation         : Ranking result (top variable is best ranked)
# --- IdTransformation         : -------------------------------------
# --- IdTransformation         : Rank : Variable     : Separation
# --- IdTransformation         : -------------------------------------
# --- IdTransformation         :    1 : petal_length : ...
#
#
# and what the resulting Fisher coefficients are:
#
# --- Fisher                   : Results for Fisher coefficients:
# --- Fisher                   : ---------------------------
# --- Fisher                   :     Variable:  Coefficient:
# --- Fisher                   : ---------------------------
# --- Fisher                   :  sepal_width:       ...
#
# Finally it will also print out the ROC integral for you:
#
# --- Factory                  : --------------------------------------------------------------------------------
# --- Factory                  : MVA              Signal efficiency at bkg eff.(error):       | Sepa-    Signifi- 
# --- Factory                  : Method:          @B=0.01    @B=0.10    @B=0.30    ROC-integ. | ration:  cance:   
# --- Factory                  : --------------------------------------------------------------------------------
#
# You can study the output by typing in your terminal : 
# 
# root -l $ROOTSYS/tmva/test/TMVAGui.C\(\"tmva.TMVAClassifier.root\"\)
#
# This will open a GUI that can plot all input parameters, correlations, the resulting discriminant
# (Classifier Output Distributions ... ) and ROC curve in an automated way
#
# --------------------------------------------------------------------------------
#
# Change the script so that it trains on the heart disease data defined at the top
# of the script.
# Also add the two lines : 
#
#    factory.BookMethod( TMVA.Types.kBDT      , "BDT"     , "!H:!V:NTrees=500:NNodesMax=10" )
#    factory.BookMethod( TMVA.Types.kCFMlpANN , "CFMlpANN", "!H:!V:NCycles=500:HiddenLayers=N+1,N"  );
#
# back into the code. This will train a Boosted Decision Tree (top line) and a Neural Network (Bottom)
# This will initially be horribly overtrained. Do it anyway, and use the GUI to look at the distributions
# and linear correlation coefficients.
# Looking at the distributions, can you get an idea of which variables will give a good separation?
# Does this correspond to what TMVA writes out when running?
# 
# Try to reduce the overtraining of the BDT by changing how many trees and how many nodes it uses.
# Can you make the Kolmogorov-Smirnov probability more reasonable?
# 
# Try to remove the 3 least significant parameters. Does that help?
# Does this change the ROC integral?
#
# Which of the methods seem to be most sensitive to the amount of statistics?
# 
# ----------------------------------------------------------------------------------
# Bonus if you want to play with something that has much more statistics and is more
# related to physics : 
# ----------------------------------------------------------------------------------
# 
# Use the Higgs / ZZ data to train a BDT and/or a neural network.
# 
# How complex can you make the classifier (i.e. how many trees / nodes...)
# without significant overtraining?
# 
# Can you remove any parameters without loss of separation? (Hint: look at the variable ranking.)
# 
# ----------------------------------------------------------------------------------
