# TMVA classification example: book a Fisher discriminant and a Keras model
# (through TMVA's PyKeras interface) and train/evaluate both on the same
# signal/background sample.
from os.path import isfile
from subprocess import call

from ROOT import TMVA, TFile, TTree, TCut

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.regularizers import l2
from keras.optimizers import SGD

# Set up TMVA and the embedded Python interpreter required by PyKeras.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# Output file handed to the Factory; TMVA stores its training/testing
# results in it. TFile.Open yields a null object on failure.
output = TFile.Open('TMVA.root', 'RECREATE')
if not output:
    raise RuntimeError('Could not create TMVA.root')

factory = TMVA.Factory(
    'TMVAClassification', output,
    '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')
# Fetch the example signal/background sample on first use.
if not isfile('tmva_class_example.root'):
    # -L follows HTTP redirects and -f makes curl exit non-zero on HTTP
    # errors; without them curl can save a redirect/error page under the
    # .root name. call() never raises on failure, so check its status.
    status = call(['curl', '-fL', '-O',
                   'http://root.cern.ch/files/tmva_class_example.root'])
    if status != 0:
        raise RuntimeError(
            'Downloading tmva_class_example.root failed (curl exit status %d)'
            % status)

data = TFile.Open('tmva_class_example.root')
# TFile.Open yields a null object if the file cannot be opened, and a
# "zombie" TFile if it is unreadable (e.g. a truncated download).
if not data or data.IsZombie():
    raise RuntimeError('Could not open tmva_class_example.root '
                       '(delete it to force a fresh download)')

signal = data.Get('TreeS')
background = data.Get('TreeB')
# Get() returns a null object when the key is missing.
if not signal or not background:
    raise RuntimeError('TreeS/TreeB not found in tmva_class_example.root')
# The DataLoader collects the input variables and the signal/background trees.
dataloader = TMVA.DataLoader('dataset')

# Use every branch of the signal tree as an input variable (assumes the
# background tree has branches with the same names -- verify if the input
# file changes).
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# Register both trees with a global event weight of 1.0.
dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)

# No selection cut (empty TCut); 4000 signal and 4000 background events are
# picked at random for training, the remaining events form the test sample.
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')
# Keras model for the PyKeras method: one hidden ReLU layer with a small L2
# weight penalty, followed by a two-node softmax output (two classes).
# input_dim must match the number of variables added to the DataLoader;
# 4 is assumed here -- TODO confirm against the input file.
model = Sequential()
model.add(Dense(64, activation='relu', kernel_regularizer=l2(1e-5),
                input_dim=4))
model.add(Dense(2, activation='softmax'))

# Keras 2 keyword names: `kernel_regularizer` replaces Keras 1's removed
# `W_regularizer`, and `learning_rate` replaces the deprecated `lr`.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01),
              metrics=['accuracy'])

# Store the compiled (untrained) model; the PyKeras booking option
# FilenameModel=model.h5 points TMVA at this file.
model.save('model.h5')
# Book the two classifiers to compare: a Fisher discriminant and the Keras
# model stored in model.h5 (trained by TMVA for 20 epochs, batch size 32).
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')

# Train every booked method, apply them to the test sample, then compute
# the evaluation figures of merit.
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()
# API notes (ROOT reference documentation):
# - static void TMVA::PyMethodBase::PyInitialize():
#       initialize the Python interpreter.
# - static TFile *TFile::Open(const char *name, Option_t *option = "",
#       const char *ftitle = "", Int_t compress = 1, Int_t netopt = 0):
#       create / open a file.
# - TMVA::Factory: the main MVA steering class.