Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
ClassificationKeras.py
Go to the documentation of this file.
1#!/usr/bin/env python
2# \file
3# \ingroup tutorial_tmva_keras
4# \notebook -nodraw
5# This tutorial shows how to do classification in TMVA with neural networks
6# trained with Keras.
7#
8# \macro_code
9#
10# \date 2017
11# \author TMVA Team
12
13from ROOT import TMVA, TFile, TCut, gROOT
14from subprocess import call
15from os.path import isfile
16
17from tensorflow.keras.models import Sequential
18from tensorflow.keras.layers import Dense
19from tensorflow.keras.optimizers import SGD
20
21
def create_model():
    """Build a small dense Keras classifier and save it as ``modelClassification.h5``.

    Architecture: 4 inputs -> 64-unit ReLU hidden layer -> 2-unit softmax
    output. The model is compiled with categorical cross-entropy, plain SGD
    (learning rate 0.01) and a weighted accuracy metric. It is saved untrained;
    the PyKeras method booked in ``run()`` loads this file via its
    ``FilenameModel`` option and does the training there.
    """
    # NOTE(review): input_dim=4 presumably matches the number of TreeS
    # branches registered as variables in run() -- confirm if the data changes.
    hidden_layer = Dense(64, activation='relu', input_dim=4)
    output_layer = Dense(2, activation='softmax')

    network = Sequential()
    network.add(hidden_layer)
    network.add(output_layer)

    # Loss, optimizer and (event-weighted) metrics used during training.
    sgd = SGD(learning_rate=0.01)
    network.compile(
        loss='categorical_crossentropy',
        optimizer=sgd,
        weighted_metrics=['accuracy'],
    )

    # Persist the compiled, untrained model for TMVA to pick up.
    network.save('modelClassification.h5')


def run():
    """Train and evaluate a Fisher and a PyKeras classifier with TMVA.

    Opens the tutorial input ``tmva/data/tmva_class_example.root``, which has a
    signal tree 'TreeS' and a background tree 'TreeB'. It registers every
    signal-tree branch as an input variable and books two methods: a Fisher
    discriminant, and a PyKeras network loaded from ``modelClassification.h5``
    (so ``create_model()`` must have been run first). It then trains, tests and
    evaluates all booked methods. The TMVA output is written to
    ``TMVA_Classification_Keras.root``, which is recreated on every call.
    """
    input_path = str(gROOT.GetTutorialDir()) + '/tmva/data/tmva_class_example.root'

    with TFile.Open('TMVA_Classification_Keras.root', 'RECREATE') as output, \
            TFile.Open(input_path) as data:
        factory = TMVA.Factory('TMVAClassification', output,
                               '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')

        signal = data.Get('TreeS')
        background = data.Get('TreeB')

        dataloader = TMVA.DataLoader('dataset')
        # Every branch of the signal tree becomes an input variable.
        # NOTE(review): assumes TreeB has the same branch layout as TreeS.
        for branch in signal.GetListOfBranches():
            dataloader.AddVariable(branch.GetName())

        dataloader.AddSignalTree(signal, 1.0)
        dataloader.AddBackgroundTree(background, 1.0)
        # Empty cut: no preselection. 4000 signal and 4000 background events
        # are randomly chosen for training; TMVA uses the rest for testing.
        dataloader.PrepareTrainingAndTestTree(
            TCut(''),
            'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')

        # Book methods
        factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                           '!H:!V:Fisher:VarTransform=D,G')
        factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                           'H:!V:VarTransform=D,G:FilenameModel=modelClassification.h5:FilenameTrainedModel=trainedModelClassification.h5:NumEpochs=20:BatchSize=32')

        # Run training, test and evaluation
        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()


if __name__ == "__main__":
    # Setup TMVA: create the TMVA tools singleton and initialise the Python
    # bridge that the PyMVA methods (PyKeras) rely on.
    TMVA.Tools.Instance()
    TMVA.PyMethodBase.PyInitialize()

    # Create and store the ML model. run() books PyKeras with
    # FilenameModel=modelClassification.h5, so this file must exist first.
    create_model()

    # Run TMVA
    run()
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
A specialized string object used for TTree selections.
Definition TCut.h:25
This is the main MVA steering class.
Definition Factory.h:80