ROOT 6.18/05
Reference Guide
ClassificationKeras.py
Go to the documentation of this file.
1#!/usr/bin/env python
2## \file
3## \ingroup tutorial_tmva_keras
4## \notebook -nodraw
5## This tutorial shows how to do classification in TMVA with neural networks
6## trained with keras.
7##
8## \macro_code
9##
10## \date 2017
11## \author TMVA Team
12
13from ROOT import TMVA, TFile, TTree, TCut
14from subprocess import call
15from os.path import isfile
16
17from keras.models import Sequential
18from keras.layers import Dense, Activation
19from keras.regularizers import l2
20from keras.optimizers import SGD
21
# Setup TMVA
# Create the TMVA tools singleton and initialize the embedded Python
# interpreter used by the PyMVA methods; the kPyKeras method booked below
# runs Keras through that interpreter, so this must happen before booking.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# All training/testing results are written into TMVA.root (overwritten).
output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory('TMVAClassification', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')
29
# Load data
# Download the example input only once. https avoids relying on an
# http->https redirect; -f makes curl fail (non-zero exit, no body written)
# on HTTP errors and -L follows redirects, so an error page is never saved
# under the expected filename and then trusted by the isfile() check.
if not isfile('tmva_class_example.root'):
    if call(['curl', '-f', '-L', '-O',
             'https://root.cern.ch/files/tmva_class_example.root']) != 0:
        raise RuntimeError('Failed to download tmva_class_example.root')

data = TFile.Open('tmva_class_example.root')
# TFile.Open returns a null object on failure; fail with a clear message
# instead of an AttributeError on the Get() calls below.
if not data or data.IsZombie():
    raise RuntimeError('Cannot open tmva_class_example.root')
signal = data.Get('TreeS')
background = data.Get('TreeB')
37
dataloader = TMVA.DataLoader('dataset')

# Every branch of the signal tree becomes one input variable, in tree order.
variable_names = [br.GetName() for br in signal.GetListOfBranches()]
for var_name in variable_names:
    dataloader.AddVariable(var_name)

# Register the signal and background trees, each with a global weight of 1.0.
for add_tree, tree in ((dataloader.AddSignalTree, signal),
                       (dataloader.AddBackgroundTree, background)):
    add_tree(tree, 1.0)

# No preselection cut; 4000 random events per class go to training.
split_options = 'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V'
dataloader.PrepareTrainingAndTestTree(TCut(''), split_options)
46
# Generate model

# Define model: one hidden layer of 64 ReLU units with a small L2 weight
# penalty, then a 2-unit softmax output (one unit per class).
# Keras 2 names the weight regularizer `kernel_regularizer`; the Keras 1
# keyword `W_regularizer` is no longer accepted by current Keras.
# NOTE(review): input_dim must equal the number of variables added to the
# dataloader (assumed 4 branches in tmva_class_example.root — confirm if the
# input file changes).
model = Sequential()
model.add(Dense(64, activation='relu', kernel_regularizer=l2(1e-5), input_dim=4))
model.add(Dense(2, activation='softmax'))

# Set loss and optimizer. `learning_rate` replaces the deprecated `lr`
# argument of SGD (Keras >= 2.3).
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01), metrics=['accuracy', ])

# Store model to file; PyKeras loads it via FilenameModel=model.h5.
model.save('model.h5')
model.summary()
61
# Book methods: a Fisher discriminant as a baseline, plus the Keras model
# saved above; both use the D,G input variable transformations.
fisher_options = '!H:!V:Fisher:VarTransform=D,G'
keras_options = 'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32'
booked = (
    (TMVA.Types.kFisher, 'Fisher', fisher_options),
    (TMVA.Types.kPyKeras, 'PyKeras', keras_options),
)
for method_type, method_name, method_options in booked:
    factory.BookMethod(dataloader, method_type, method_name, method_options)
67
# Run training, test and evaluation
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the output file explicitly so everything the factory wrote is
# flushed to TMVA.root, rather than relying on interpreter shutdown.
output.Close()
TCut: a specialized string object used for TTree selections.
Definition: TCut.h:25
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3980
TMVA::Factory: the main MVA steering class.
Definition: Factory.h:81
static void PyInitialize()
Initialize the Python interpreter.
static Tools & Instance()
Definition: Tools.cxx:75