Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
MulticlassKeras.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_tmva_keras
## \notebook -nodraw
## This tutorial shows how to do multiclass classification in TMVA with neural
## networks trained with Keras.
##
## \macro_code
##
## \date 2017
## \author TMVA Team

from os.path import isfile

from ROOT import TMVA, TFile, TCut, gROOT

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import SGD
18
19
def create_model():
    """Build, compile and save the (untrained) Keras network for TMVA's PyKeras method.

    Architecture: 4 inputs -> Dense(32, ReLU) -> Dense(4, softmax).
    The model is compiled with categorical cross-entropy, plain SGD
    (learning rate 0.01) and a weighted accuracy metric. It is saved to
    'modelMultiClass.keras', the file named by FilenameModel when the
    PyKeras method is booked in run(). Its summary is then printed.

    NOTE(review): the input width (4) must equal the number of variables
    registered from TreeS in run(), and the output width (4) must equal the
    number of classes (Signal + 3 backgrounds). Confirm both if the input
    data changes.
    """
    # Define model. The input shape is declared with an explicit Input layer
    # instead of Dense(input_dim=...), which Keras 3 deprecates with a warning.
    model = Sequential()
    model.add(Input(shape=(4,)))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(4, activation='softmax'))

    # Set loss and optimizer
    model.compile(loss='categorical_crossentropy',
                  optimizer=SGD(learning_rate=0.01),
                  weighted_metrics=['accuracy'])

    # Store model to file so TMVA can load it, then print an overview
    model.save('modelMultiClass.keras')
    model.summary()
33
34
def run():
    """Train, test and evaluate multiclass TMVA methods on the example dataset.

    Opens 'tmva_example_multiple_background.root'. Its signal tree 'TreeS'
    and background trees 'TreeB0'..'TreeB2' become the four classes. Every
    branch of 'TreeS' is registered as an input variable. Two methods are
    booked: a Fisher discriminant and the PyKeras network stored in
    'modelMultiClass.keras'. TMVA output goes to 'TMVA.root', which is
    recreated.
    """
    input_name = 'tmva_example_multiple_background.root'
    with TFile.Open('TMVA.root', 'RECREATE') as output, TFile.Open(input_name) as data:
        factory = TMVA.Factory(
            'TMVAClassification', output,
            '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')

        # (tree name, class label) for every class; the signal tree is first
        # because its branches define the input variables.
        class_sources = (
            ('TreeS', 'Signal'),
            ('TreeB0', 'Background_0'),
            ('TreeB1', 'Background_1'),
            ('TreeB2', 'Background_2'),
        )
        class_trees = [(data.Get(tree_name), label) for tree_name, label in class_sources]

        dataloader = TMVA.DataLoader('dataset')
        signal_tree = class_trees[0][0]
        for branch in signal_tree.GetListOfBranches():
            dataloader.AddVariable(branch.GetName())

        for tree, label in class_trees:
            dataloader.AddTree(tree, label)
        dataloader.PrepareTrainingAndTestTree(
            TCut(''), 'SplitMode=Random:NormMode=NumEvents:!V')

        # Book methods: a linear Fisher baseline and the saved Keras model
        factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                           '!H:!V:Fisher:VarTransform=D,G')
        factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                           'H:!V:VarTransform=D,G:FilenameModel=modelMultiClass.keras:FilenameTrainedModel=trainedModelMultiClass.keras:NumEpochs=20:BatchSize=32')

        # Run TMVA: train, then test, then evaluate all booked methods
        for stage in (factory.TrainAllMethods,
                      factory.TestAllMethods,
                      factory.EvaluateAllMethods):
            stage()
66
67
if __name__ == "__main__":
    # Generate model: writes the untrained network to 'modelMultiClass.keras',
    # the file the PyKeras method is booked with in run().
    create_model()

    # Setup TMVA: obtain TMVA's Tools instance and initialize the Python
    # interpreter used by the PyMVA methods (PyKeras). These calls were
    # missing, which left this step empty.
    TMVA.Tools.Instance()
    TMVA.PyMethodBase.PyInitialize()

    # Load data: create the example input file with ROOT's createData.C
    # macro if it is not already present in the working directory.
    if not isfile('tmva_example_multiple_background.root'):
        createDataMacro = str(gROOT.GetTutorialDir()) + '/machine_learning/createData.C'
        print(createDataMacro)
        gROOT.ProcessLine('.L {}'.format(createDataMacro))
        gROOT.ProcessLine('create_MultipleBackground(4000)')

    # Run TMVA
    run()
A specialized string object used for TTree selections.
Definition TCut.h:25
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:3787
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:72