# Classification with TMVA: a Fisher discriminant and a Keras neural network
# (booked through TMVA's PyKeras interface) are trained, tested and evaluated
# on the signal (TreeS) and background (TreeB) trees of the TMVA example file.
#
# NOTE(review): this listing was recovered from a rendered page with line
# numbers fused into the code and several lines missing (TMVA setup, output
# file, factory, data file, data loader, model creation/saving). The missing
# lines are restored following the upstream ROOT tutorial
# ClassificationKeras.py -- confirm against the ROOT version in use.

from ROOT import TMVA, TFile, TTree, TCut
from subprocess import call
from os.path import isfile

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.regularizers import l2
from keras.optimizers import SGD

DATA_FILE = 'tmva_class_example.root'
DATA_URL = 'https://root.cern.ch/files/' + DATA_FILE
# PyKeras loads the network from this file (see FilenameModel below).
MODEL_FILE = 'model.h5'

# Set up TMVA and the embedded Python interpreter that PyKeras relies on.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory(
    'TMVAClassification', output,
    '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')

# Download the example dataset on first use. -f makes curl exit non-zero on
# HTTP errors instead of saving the error page as the data file; -L follows
# redirects.
if not isfile(DATA_FILE):
    if call(['curl', '-f', '-L', '-O', DATA_URL]) != 0:
        raise RuntimeError('failed to download {}'.format(DATA_URL))

data = TFile.Open(DATA_FILE)
# TFile.Open yields a null object (falsy) or a zombie file on failure.
if not data or data.IsZombie():
    raise RuntimeError('could not open {}'.format(DATA_FILE))
signal = data.Get('TreeS')
background = data.Get('TreeB')

dataloader = TMVA.DataLoader('dataset')
# Every branch of the signal tree becomes an input variable.
variables = [branch.GetName() for branch in signal.GetListOfBranches()]
for name in variables:
    dataloader.AddVariable(name)

dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')

# One hidden ReLU layer (L2-regularized) and a two-class softmax output.
# input_dim follows the number of booked variables instead of a fixed 4.
model = Sequential()
model.add(Dense(64, activation='relu', kernel_regularizer=l2(1e-5),
                input_dim=len(variables)))
model.add(Dense(2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01), metrics=['accuracy'])

# The PyKeras method reads the model from FilenameModel, so save it first.
model.save(MODEL_FILE)
model.summary()

factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                   'H:!V:VarTransform=D,G:FilenameModel=' + MODEL_FILE +
                   ':NumEpochs=20:BatchSize=32')

factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

output.Close()
# TCut: a specialized string object used for TTree selections.
# TFile::Open -- static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose, Int_t netopt=0)
# Create or open a file.
# TMVA::Factory: the main MVA steering class.
# TMVA::PyMethodBase::PyInitialize -- static void PyInitialize()
# Initialize the Python interpreter.