ROOT 6.14/05
Reference Guide
MulticlassKeras.py
Go to the documentation of this file.
#!/usr/bin/env python
"""Multiclass classification with TMVA: Fisher discriminant vs. a Keras model.

One signal tree and three background trees are loaded into a TMVA DataLoader.
A Fisher discriminant and a PyKeras neural network are booked. Both methods
are trained, tested and evaluated, with results written to ``TMVA.root``.
The untrained Keras model is saved to ``model.h5``, which the PyKeras method
loads via its ``FilenameModel`` option.
"""

from ROOT import TMVA, TFile, TTree, TCut, gROOT
from os.path import isfile

from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.regularizers import l2
from keras.optimizers import SGD

# Setup TMVA
TMVA.Tools.Instance()
# The Python interpreter must be initialized before a PyMVA method
# (PyKeras below) can be booked.
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory('TMVAClassification', output,
        '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')

# Load data; generate the example file with the tutorial macro if it is missing
if not isfile('tmva_example_multiple_background.root'):
    createDataMacro = str(gROOT.GetTutorialDir()) + '/tmva/createData.C'
    print(createDataMacro)
    gROOT.ProcessLine('.L {}'.format(createDataMacro))
    gROOT.ProcessLine('create_MultipleBackground(4000)')

data = TFile.Open('tmva_example_multiple_background.root')
# TFile.Open yields a null object on failure; stop here with a clear message
# instead of failing later on an attribute access of a null tree.
if not data or data.IsZombie():
    raise IOError('cannot open tmva_example_multiple_background.root')

signal = data.Get('TreeS')
background0 = data.Get('TreeB0')
background1 = data.Get('TreeB1')
background2 = data.Get('TreeB2')

dataloader = TMVA.DataLoader('dataset')

# Every branch of the signal tree is used as an input variable.
variables = [branch.GetName() for branch in signal.GetListOfBranches()]
for variable in variables:
    dataloader.AddVariable(variable)

# One tree per class; the class names become the TMVA class labels.
class_trees = [
    (signal, 'Signal'),
    (background0, 'Background_0'),
    (background1, 'Background_1'),
    (background2, 'Background_2'),
]
for tree, class_name in class_trees:
    dataloader.AddTree(tree, class_name)

dataloader.PrepareTrainingAndTestTree(TCut(''),
        'SplitMode=Random:NormMode=NumEvents:!V')

# Generate model

# Define model: the input width follows the number of variables added above,
# and the softmax layer has one output unit per class tree.
model = Sequential()
# `kernel_regularizer` is the Keras 2 name of the deprecated `W_regularizer`.
model.add(Dense(32, activation='relu', kernel_regularizer=l2(1e-5),
                input_dim=len(variables)))
model.add(Dense(len(class_trees), activation='softmax'))

# Set loss and optimizer
model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.01), metrics=['accuracy',])

# Store model to file
model.save('model.h5')
model.summary()

# Book methods
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
        '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, "PyKeras",
        'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')

# Run TMVA
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Flush and close the TMVA results file.
output.Close()
Referenced symbols (tooltips from the generated documentation):
- `static Tools & Instance()` — definition: Tools.cxx:75
- `static void PyInitialize()` — initialize the Python interpreter.
- `static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)` — create / open a file; definition: TFile.cxx:3976
- `Factory` — the main MVA steering class; definition: Factory.h:81