ROOT Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
MulticlassKeras.py
Go to the documentation of this file.
#!/usr/bin/env python
# \file
# \ingroup tutorial_tmva_keras
# \notebook -nodraw
# This tutorial shows how to do multiclass classification in TMVA with neural
# networks trained with Keras.
#
# \macro_code
#
# \date 2017
# \author TMVA Team
12
13from ROOT import TMVA, TFile, TCut, gROOT
14from os.path import isfile
15
16from tensorflow.keras.models import Sequential
17from tensorflow.keras.layers import Dense
18from tensorflow.keras.optimizers import SGD
19
20
def create_model():
    """Build, compile and save the untrained Keras network used by PyKeras.

    Architecture: 4 inputs -> 32 ReLU units -> 4 softmax outputs, one output
    per class (Signal plus the three backgrounds added in ``run()``).
    The network is compiled with categorical cross-entropy, plain SGD
    (learning rate 0.01) and a weighted accuracy metric. It is then written
    to ``modelMultiClass.h5``, the file named by ``FilenameModel`` in the
    PyKeras booking options.

    Returns:
        None. The model is only persisted to disk.
    """
    model = Sequential()

    # NOTE(review): input_dim=4 presumably equals the number of branches in
    # TreeS that run() registers as variables -- confirm against the data.
    stack = (
        Dense(32, activation='relu', input_dim=4),
        Dense(4, activation='softmax'),  # one probability per class
    )
    for layer in stack:
        model.add(layer)

    optimizer = SGD(learning_rate=0.01)
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  weighted_metrics=['accuracy'])

    # The file name must agree with FilenameModel in run()'s PyKeras booking.
    model.save('modelMultiClass.h5')
35
def run():
    """Train, test and evaluate the multiclass classifiers with TMVA.

    Reads the signal tree ``TreeS`` and the background trees ``TreeB0``,
    ``TreeB1`` and ``TreeB2`` from ``tmva_example_multiple_background.root``.
    It books a Fisher discriminant and a PyKeras network on them, then
    trains, tests and evaluates both. Output goes to ``TMVA.root``, which is
    recreated on every call.

    The PyKeras method loads the untrained network from
    ``modelMultiClass.h5``, so ``create_model()`` must have run first. It
    stores the trained network in ``trainedModelMultiClass.h5``.
    """
    with TFile.Open('TMVA.root', 'RECREATE') as output, TFile.Open('tmva_example_multiple_background.root') as data:
        factory = TMVA.Factory('TMVAClassification', output,
                               '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')

        signal = data.Get('TreeS')
        background0 = data.Get('TreeB0')
        background1 = data.Get('TreeB1')
        background2 = data.Get('TreeB2')

        dataloader = TMVA.DataLoader('dataset')
        # Every branch of the signal tree becomes an input variable.
        # NOTE(review): the background trees are assumed to share the same
        # branch layout -- confirm against the data file.
        for branch in signal.GetListOfBranches():
            dataloader.AddVariable(branch.GetName())

        # One tree per class. Signal plus three backgrounds gives four
        # classes, matching the 4-unit softmax output built in create_model().
        dataloader.AddTree(signal, 'Signal')
        dataloader.AddTree(background0, 'Background_0')
        dataloader.AddTree(background1, 'Background_1')
        dataloader.AddTree(background2, 'Background_2')
        # Empty cut = no preselection; events are split randomly into
        # training and test samples.
        dataloader.PrepareTrainingAndTestTree(TCut(''),
                                              'SplitMode=Random:NormMode=NumEvents:!V')

        # Book methods
        factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                           '!H:!V:Fisher:VarTransform=D,G')
        factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                           'H:!V:VarTransform=D,G:FilenameModel=modelMultiClass.h5:FilenameTrainedModel=trainedModelMultiClass.h5:NumEpochs=20:BatchSize=32')

        # Run TMVA: without these calls the booked methods are never used.
        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()
68
if __name__ == "__main__":
    # Generate model: writes modelMultiClass.h5, which the PyKeras method
    # booked in run() loads via its FilenameModel option.
    create_model()

    # Setup TMVA
    TMVA.Tools.Instance()
    # Initialise the Python side used by PyMethodBase-derived methods
    # such as PyKeras before any of them is booked.
    TMVA.PyMethodBase.PyInitialize()

    # Load data: if the input file is missing, generate it with the
    # createData.C macro shipped in ROOT's tutorial directory.
    if not isfile('tmva_example_multiple_background.root'):
        createDataMacro = str(gROOT.GetTutorialDir()) + '/tmva/createData.C'
        print(createDataMacro)
        gROOT.ProcessLine('.L {}'.format(createDataMacro))
        gROOT.ProcessLine('create_MultipleBackground(4000)')

    # Run TMVA
    run()
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
TCut: a specialized string object used for TTree selections.
Definition TCut.h:25
TMVA::Factory: the main MVA steering class.
Definition Factory.h:80