Logo ROOT   6.16/01
Reference Guide
MulticlassKeras.py
Go to the documentation of this file.
#!/usr/bin/env python
"""Multiclass classification with TMVA: Fisher discriminant vs. a Keras model.

Trains, tests and evaluates two TMVA methods (Fisher and PyKeras) on the
``tmva_example_multiple_background.root`` dataset, which holds one signal
tree and three background trees. TMVA results are written to ``TMVA.root``;
the (untrained) Keras model handed to PyKeras is stored in ``model.h5``.
"""

from ROOT import TMVA, TFile, TTree, TCut, gROOT
from os.path import isfile

from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.regularizers import l2
from keras.optimizers import SGD

# Setup TMVA: initialize the TMVA tools and the embedded Python interpreter
# that the PyKeras method relies on.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory('TMVAClassification', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')

# Load data, generating it with the tutorial macro if it is not present yet.
DATA_FILE = 'tmva_example_multiple_background.root'
if not isfile(DATA_FILE):
    createDataMacro = str(gROOT.GetTutorialDir()) + '/tmva/createData.C'
    print(createDataMacro)
    gROOT.ProcessLine('.L {}'.format(createDataMacro))
    gROOT.ProcessLine('create_MultipleBackground(4000)')

data = TFile.Open(DATA_FILE)
if not data or data.IsZombie():
    raise IOError('Cannot open {}'.format(DATA_FILE))
signal = data.Get('TreeS')
background0 = data.Get('TreeB0')
background1 = data.Get('TreeB1')
background2 = data.Get('TreeB2')

# (TMVA class name, input tree) pairs. The Keras output layer gets one unit
# per entry, so adding a class here keeps the network shape consistent.
classes = [
    ('Signal', signal),
    ('Background_0', background0),
    ('Background_1', background1),
    ('Background_2', background2),
]
for class_name, tree in classes:
    # TFile.Get yields a null (falsy) object when the key does not exist.
    if not tree:
        raise RuntimeError('Missing input tree for class {} in {}'.format(
            class_name, DATA_FILE))

dataloader = TMVA.DataLoader('dataset')
# Every branch of the signal tree is used as an input variable.
variables = [branch.GetName() for branch in signal.GetListOfBranches()]
for variable in variables:
    dataloader.AddVariable(variable)

for class_name, tree in classes:
    dataloader.AddTree(tree, class_name)
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'SplitMode=Random:NormMode=NumEvents:!V')

# Generate model

# Define model: the input width follows the number of variables and the
# softmax output has one unit per class (both 4 for the tutorial dataset).
# NOTE(review): W_regularizer / lr are the Keras-1 spellings; Keras 2 renamed
# them kernel_regularizer / learning_rate -- confirm against the installed
# Keras version before changing.
model = Sequential()
model.add(Dense(32, activation='relu', W_regularizer=l2(1e-5),
                input_dim=len(variables)))
model.add(Dense(len(classes), activation='softmax'))

# Set loss and optimizer
model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.01),
              metrics=['accuracy'])

# Store model to file; the PyKeras method below loads it via FilenameModel.
model.save('model.h5')
model.summary()

# Book methods
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, "PyKeras",
                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')

# Run TMVA
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Flush and close the TMVA output file explicitly, as the C++ TMVA
# tutorials do, instead of relying on interpreter shutdown.
output.Close()
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3975
This is the main MVA steering class.
Definition: Factory.h:81
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition: Tools.cxx:75