Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
ClassificationKeras.py
Go to the documentation of this file.
1## \file
2## \ingroup tutorial_tmva_keras
3## \notebook -nodraw
4## This tutorial shows how to do classification in TMVA with neural networks
5## trained with keras.
6##
7## \macro_code
8##
9## \date 2017
10## \author TMVA Team
11
12from ROOT import TMVA, TFile, TCut, gROOT
13from subprocess import call
14from os.path import isfile
15
16from tensorflow.keras.models import Sequential
17from tensorflow.keras.layers import Dense
18from tensorflow.keras.optimizers import SGD
19
20
def create_model():
    """Build, compile and save the untrained Keras classification model.

    The network takes 4 inputs, has one hidden dense layer of 64 ReLU
    units and a 2-unit softmax output layer. It is compiled with the
    categorical cross-entropy loss and an SGD optimizer (learning rate
    0.01), tracking weighted accuracy. The compiled model is written to
    'modelClassification.keras' and its summary is printed.
    """
    network = Sequential()

    # Hidden layer: declares the input dimension of the network
    hidden_layer = Dense(64, activation='relu', input_dim=4)
    network.add(hidden_layer)

    # Output layer: one softmax unit per class
    output_layer = Dense(2, activation='softmax')
    network.add(output_layer)

    # Loss, optimizer and the metric reported during training
    sgd = SGD(learning_rate=0.01)
    network.compile(
        loss='categorical_crossentropy',
        optimizer=sgd,
        weighted_metrics=['accuracy', ],
    )

    # Write the model to disk, then print its layout
    network.save('modelClassification.keras')
    network.summary()
36
37
def run():
    """Train, test and evaluate a Fisher and a PyKeras classifier with TMVA.

    Opens (recreating) 'TMVA_Classification_Keras.root' as the output file
    and the tutorial input file 'tmva_class_example.root', registers every
    branch of the signal tree as an input variable, books the two methods
    and runs the training, testing and evaluation steps. Both files are
    closed when the function returns.
    """
    with TFile.Open('TMVA_Classification_Keras.root', 'RECREATE') as output:
        input_path = str(gROOT.GetTutorialDir()) + '/machine_learning/data/tmva_class_example.root'
        with TFile.Open(input_path) as data:
            factory_options = '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification'
            tmva_factory = TMVA.Factory('TMVAClassification', output, factory_options)

            sig_tree = data.Get('TreeS')
            bkg_tree = data.Get('TreeB')

            loader = TMVA.DataLoader('dataset')
            # Every branch of the signal tree becomes an input variable
            for br in sig_tree.GetListOfBranches():
                loader.AddVariable(br.GetName())

            loader.AddSignalTree(sig_tree, 1.0)
            loader.AddBackgroundTree(bkg_tree, 1.0)
            split_options = 'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V'
            loader.PrepareTrainingAndTestTree(TCut(''), split_options)

            # Book methods
            fisher_options = '!H:!V:Fisher:VarTransform=D,G'
            tmva_factory.BookMethod(loader, TMVA.Types.kFisher, 'Fisher', fisher_options)

            keras_options = 'H:!V:VarTransform=D,G:FilenameModel=modelClassification.keras:FilenameTrainedModel=trainedModelClassification.keras:NumEpochs=20:BatchSize=32:LearningRateSchedule=10,0.01;20,0.005'
            tmva_factory.BookMethod(loader, TMVA.Types.kPyKeras, 'PyKeras', keras_options)

            # Run training, test and evaluation
            tmva_factory.TrainAllMethods()
            tmva_factory.TestAllMethods()
            tmva_factory.EvaluateAllMethods()
65
66
if __name__ == "__main__":
    # Setup TMVA: obtain the TMVA Tools singleton and initialize the Python
    # interpreter used by the PyKeras method. These two calls were missing
    # after the "Setup TMVA" comment, leaving the setup step empty.
    TMVA.Tools.Instance()
    TMVA.PyMethodBase.PyInitialize()

    # Create and store the ML model
    create_model()

    # Run TMVA
    run()
A specialized string object used for TTree selections.
Definition TCut.h:25
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:3787
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:72