Logo ROOT  
Reference Guide
MulticlassKeras.py
Go to the documentation of this file.
#!/usr/bin/env python
## \file
## \ingroup tutorial_tmva_keras
## \notebook -nodraw
## This tutorial shows how to do multiclass classification in TMVA with neural
## networks trained with keras.
##
## \macro_code
##
## \date 2017
## \author TMVA Team

from ROOT import TMVA, TFile, TTree, TCut, gROOT
from os.path import isfile

from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.regularizers import l2
from keras.optimizers import SGD

# Setup TMVA
# The Python interpreter must be initialised before a PyKeras method is booked.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# All TMVA results are written to this ROOT file.
output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory('TMVAClassification', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')

# Load data
# Generate the example input file with the createData.C tutorial macro when it
# is not already present in the working directory.
if not isfile('tmva_example_multiple_background.root'):
    createDataMacro = str(gROOT.GetTutorialDir()) + '/tmva/createData.C'
    print(createDataMacro)
    gROOT.ProcessLine('.L {}'.format(createDataMacro))
    gROOT.ProcessLine('create_MultipleBackground(4000)')

data = TFile.Open('tmva_example_multiple_background.root')
signal = data.Get('TreeS')
background0 = data.Get('TreeB0')
background1 = data.Get('TreeB1')
background2 = data.Get('TreeB2')

# Every branch of the signal tree is registered as an input variable.
dataloader = TMVA.DataLoader('dataset')
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# One class per tree: the signal plus three background classes.
dataloader.AddTree(signal, 'Signal')
dataloader.AddTree(background0, 'Background_0')
dataloader.AddTree(background1, 'Background_1')
dataloader.AddTree(background2, 'Background_2')
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'SplitMode=Random:NormMode=NumEvents:!V')

# Generate model

# Define model
# input_dim=4 assumes the signal tree has exactly four branches -- TODO confirm.
# The four softmax outputs match the four classes added to the dataloader.
# NOTE(review): `W_regularizer` and `lr` are older Keras argument spellings;
# newer releases name them `kernel_regularizer` and `learning_rate` -- confirm
# against the Keras version in use before changing.
model = Sequential()
model.add(Dense(32, activation='relu', W_regularizer=l2(1e-5), input_dim=4))
model.add(Dense(4, activation='softmax'))

# Set loss and optimizer
model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.01), metrics=['accuracy',])

# Store model to file
# The PyKeras method below loads the network from this file (FilenameModel).
model.save('model.h5')
model.summary()

# Book methods
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, "PyKeras",
                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')

# Run TMVA
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()
A specialized string object used for TTree selections.
Definition: TCut.h:25
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3923
This is the main MVA steering class.
Definition: Factory.h:81
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition: Tools.cxx:75