#!/usr/bin/env python
# ----------------------------------------------------------------------------------- #
#
#  Python/pyROOT macro for training classifiers on a dataset.
#  Specify input, training variables, etc. in top of script before running in shell.
#  Weights are stored in ./weights[title]
#  Validation data stored in: ./[outpath]
#  TMVA Manual http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf
#
#  Authors: Lars Egholm Pedersen & Troels Petersen (NBI)
#  Date:    7th of December 2016
#
# ----------------------------------------------------------------------------------- #

import ROOT

# ------------------------------------------------------------
# Setup input and training parameters here

# Use ASCII or .root input.
# The variables defined in varlist should also be checked to match the input file headers.
read_ascii = True

title = "TMVAClassification"
outpath = "TMVAClassification.root"

# Path to signal and background files (may be the same, such that the difference is
# specified by the cuts below). The file type follows the read_ascii switch so that
# flipping the switch selects a matching input file.
if read_ascii:
    sigpath = "./SimpleDataset.txt"
    bkgpath = "./SimpleDataset.txt"
else:
    sigpath = "./SimpleDataset.root"
    bkgpath = "./SimpleDataset.root"

# Names of the signal and background trees (may be different).
# Only used when reading from ROOT files, but always defined so that the
# ROOT-input branch below does not fail with a NameError.
sigtree_name = "ntuple"
bkgtree_name = "ntuple"

# Define cut/requirements specifying whether a training point is signal or background
sigcut = "isSignal > 0.5"
bkgcut = "isSignal < 0.5"

# Define list of variables that will be trained on. ( , "f") is there to tell TMVA that it
# is a floating point number. Use I for integers (discrete variables).
# NOTE(review): with three string arguments, TMVA's Factory.AddVariable signature appears
# to be (expression, title, unit, type='F'), so this third value may be taken as the unit
# rather than the type -- confirm against the TMVA version in use.
varlist = [
    ("x", "f"),
    ("y", "f"),
]

# Define list of methods that you want TMVA to test the performance of.
# (see Ch. 8 of manual for full options, BDT is around p.110)
# Examples here include:
#  - Fisher discriminant
#  - Boosted Decision Tree (BDT)
# Note that ":".join(["A","B","C"]) is equivalent to "A:B:C"
methodlist = [
    # Fisher is Fisher, meaning not so many options
    (ROOT.TMVA.Types.kFisher, "Fisher", ""),
    (ROOT.TMVA.Types.kBDT, "BDTA", ":".join([
        "NTrees=100",                # Number of trees
        "MinNodeSize=4",             # Percentage of events required to be in each final leaf node
        "MaxDepth=3",                # Maximal allowed depth of tree
        "BoostType=AdaBoost",        # Adaptive Boosting: Subsequent trees are trained with higher
                                     # weights on events where predecessor failed
        "AdaBoostBeta=0.5",          # Penalty factor used in adaptive boosting
        "SeparationType=GiniIndex",  # How is separation defined, used to figure out when nodes
                                     # should be split
        "nCuts=20",                  # Number of cuts that are tried along a variable to figure
                                     # out where it should be split
        "PruneMethod=NoPruning",     # Method for removing statistically insignificant branches.
                                     # It is recommended to not use this but instead use trees
                                     # with low complexity
    ])),
]

# ------------------------------------------------------------
# Section for setting up TMVA

# ROOT file that will store TMVA validation data
tmvaoutput = ROOT.TFile(outpath, "RECREATE")

factory = ROOT.TMVA.Factory(title, tmvaoutput, ":".join([
    "!V",                           # Don't print everything
    "!Silent",                      # Print something
    "Transformations=I",            # Perform training on dataset without transforming
    "AnalysisType=Classification",  # Analysis is of 'A vs B' separation type
]))

# Connect TMVA factory to input:
if read_ascii:
    # ASCII input: the two file paths are handed directly to the factory
    factory.SetInputTrees(sigpath, bkgpath)
else:
    # Standard ROOT format (Ntuple) input:
    sigfile = ROOT.TFile(sigpath, 'READ')
    bkgfile = ROOT.TFile(bkgpath, 'READ')

    sigtree = sigfile.Get(sigtree_name)
    bkgtree = bkgfile.Get(bkgtree_name)

    factory.AddSignalTree(sigtree)
    factory.AddBackgroundTree(bkgtree)

# Add variables:
for ivar in varlist:
    # Arguments :       name     title    data-type
    factory.AddVariable(ivar[0], ivar[0], ivar[1])

factory.PrepareTrainingAndTestTree(ROOT.TCut(sigcut), ROOT.TCut(bkgcut), ":".join([
    "nTrain_Signal=0",      # Number of signal events used, 0 = ALL
    "nTrain_Background=0",  # Number of background events, 0 = ALL
    "SplitMode=Random",     # How are events chosen to be used for either training or testing
    "NormMode=NumEvents",   # Integral of datasets is given by number of events
                            # (could e.g. also be sum of weights or simply defined to be 1)
    "!V",                   # Don't print everything (i.e. not verbose)
]))

# Tell TMVA which types of classifiers it should try out (i.e. the ones specified in methodlist)
for imeth in methodlist:
    # Arguments       type      name      options
    factory.BookMethod(imeth[0], imeth[1], imeth[2])

# Perform training, testing and evaluation
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close opened files (the input files only exist when reading ROOT input)
tmvaoutput.Close()
if not read_ascii:
    sigfile.Close()
    bkgfile.Close()

# This will open the Gui where you can examine the classifier behaviour
# this can also be done later by :
# $ root -l
# > TMVA::TMVAGui( "TMVAClassification.root" )
ROOT.TMVA.TMVAGui(outpath)

# Block until the user responds, so the script does not exit right after opening the Gui.
# raw_input only exists in Python 2; fall back to input on Python 3.
try:
    _wait_for_user = raw_input
except NameError:
    _wait_for_user = input

_wait_for_user(' Press any key to exit ')