ROOT 6.08/07
Reference Guide
MulticlassKeras.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from ROOT import TMVA, TFile, TTree, TCut, gROOT
4 from os.path import isfile
5 
6 from keras.models import Sequential
7 from keras.layers.core import Dense, Activation
8 from keras.regularizers import l2
9 from keras import initializations
10 from keras.optimizers import SGD
11 
# Setup TMVA
# Create the TMVA tools singleton and start the embedded Python interpreter
# that the PyMVA methods (the kPyKeras method booked below) run inside.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# All training/testing output goes to TMVA.root, recreated on every run.
output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory('TMVAClassification', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:'
                       'AnalysisType=multiclass')

# Load data
# The input file is generated on first use by ROOT's createData.C tutorial
# macro: one signal tree plus three background trees.
dataFileName = 'tmva_example_multiple_background.root'
if not isfile(dataFileName):
    createDataMacro = gROOT.GetTutorialsDir() + '/tmva/createData.C'
    print(createDataMacro)
    gROOT.ProcessLine('.L {}'.format(createDataMacro))
    gROOT.ProcessLine('create_MultipleBackground(4000)')

# PyROOT null objects are falsy, so a failed open/lookup can be reported
# clearly here instead of surfacing later as a null-pointer access.
data = TFile.Open(dataFileName)
if not data:
    raise IOError('Cannot open input file {}'.format(dataFileName))
signal = data.Get('TreeS')
background0 = data.Get('TreeB0')
background1 = data.Get('TreeB1')
background2 = data.Get('TreeB2')

# (tree, TMVA class name) pairs; each tree becomes one class of the
# multiclass problem, in this order.
classTrees = (
    (signal, 'Signal'),
    (background0, 'Background_0'),
    (background1, 'Background_1'),
    (background2, 'Background_2'),
)
for tree, className in classTrees:
    if not tree:
        raise IOError('Tree for class {} not found in {}'.format(className, dataFileName))

# Every branch of the signal tree is used as an input variable.
dataloader = TMVA.DataLoader('dataset')
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

for tree, className in classTrees:
    dataloader.AddTree(tree, className)
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'SplitMode=Random:NormMode=NumEvents:!V')

# Generate model

def normal(shape, name=None):
    """Initialize a weight tensor from a narrow normal distribution.

    Used as ``init=normal`` on the Dense layers of the model below.

    NOTE(review): ``keras.initializations`` and its ``scale`` argument belong
    to the Keras 1.x API (Keras 1 treats ``scale`` as the standard deviation
    of a zero-mean normal); Keras 2 renamed the module ``keras.initializers``.

    :param shape: shape of the weight tensor Keras requests.
    :param name: optional variable name, forwarded unchanged to Keras.
    :returns: the result of ``initializations.normal`` for that shape.
    """
    spread = 0.05
    return initializations.normal(shape, name=name, scale=spread)

# Define model
# The input width must match the variables registered with the dataloader
# (every branch of the signal tree), so derive it instead of hard-coding it.
num_variables = signal.GetListOfBranches().GetEntries()
num_classes = 4  # Signal, Background_0, Background_1, Background_2

model = Sequential()
model.add(Dense(32, init=normal, activation='relu', W_regularizer=l2(1e-5),
                input_dim=num_variables))
# A second hidden layer can be added here with another
# Dense(32, init=normal, activation='relu', W_regularizer=l2(1e-5)).
model.add(Dense(num_classes, init=normal, activation='softmax'))

# Set loss and optimizer
model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.01),
              metrics=['accuracy'])

# Store model to file; the PyKeras method loads it via FilenameModel=model.h5.
model.save('model.h5')
model.summary()

# Book methods: a Fisher discriminant as a baseline, and the Keras model
# saved to model.h5 above (trained by TMVA for 20 epochs, batch size 32).
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, "PyKeras",
                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')

# Run TMVA
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the output file explicitly so all results are flushed to TMVA.root
# rather than relying on interpreter teardown.
output.Close()
static Tools & Instance()
Definition: Tools.cxx:80
static std::string format(double x, double y, int digits, int width)
static void PyInitialize()
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3907
def normal(shape, name=None)