TMVA_RNN_Classification.py
## \file
## \ingroup tutorial_tmva
## \notebook
## TMVA Classification Example Using a Recurrent Neural Network
##
## This is an example of using an RNN in TMVA. We do classification using a toy time-dependent data set
## that is generated when running this example macro.
##
## \macro_image
## \macro_output
## \macro_code
##
## \author Harshal Shende


# TMVA Classification Example Using a Recurrent Neural Network

# This is an example of using an RNN in TMVA.
# We do the classification using a toy data set containing a time series of `ntime` samples,
# each of dimension `ndim`, generated when running the provided function `MakeTimeData(nevents, ntime, ndim)`.


import ROOT

num_threads = 4  # use max 4 threads
# enable implicit multi-threading (MT) if ROOT supports it
if "imt" in ROOT.gROOT.GetConfigFeatures():
    ROOT.EnableImplicitMT(num_threads)
    # switch off MT in OpenBLAS to avoid a conflict with tbb
    ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")
    print("Running with nthreads = {}".format(ROOT.GetThreadPoolSize()))
else:
    print("Running in serial mode since ROOT does not support MT")


TMVA = ROOT.TMVA
TFile = ROOT.TFile

import os
import importlib


## Helper function to generate the time data set:
## for each event it creates `ntime` histograms with `ndim` bins,
## filled from Gaussians whose mean and width vary slowly with the time step


def MakeTimeData(n, ntime, ndim):
    # example values: ntime = 10, ndim = 30 (number of dimensions per time step)

    fname = "time_data_t" + str(ntime) + "_d" + str(ndim) + ".root"
    v1 = []
    v2 = []

    for i in range(ntime):
        v1.append(ROOT.TH1D("h1_" + str(i), "h1", ndim, 0, 10))
        v2.append(ROOT.TH1D("h2_" + str(i), "h2", ndim, 0, 10))

    f1 = ROOT.TF1("f1", "gaus")
    f2 = ROOT.TF1("f2", "gaus")

    sgn = ROOT.TTree("sgn", "sgn")
    bkg = ROOT.TTree("bkg", "bkg")
    f = TFile(fname, "RECREATE")

    x1 = []
    x2 = []

    for i in range(ntime):
        x1.append(ROOT.std.vector["float"](ndim))
        x2.append(ROOT.std.vector["float"](ndim))

    for i in range(ntime):
        bkg.Branch("vars_time" + str(i), "std::vector<float>", x1[i])
        sgn.Branch("vars_time" + str(i), "std::vector<float>", x2[i])

    mean1 = ROOT.std.vector["double"](ntime)
    mean2 = ROOT.std.vector["double"](ntime)
    sigma1 = ROOT.std.vector["double"](ntime)
    sigma2 = ROOT.std.vector["double"](ntime)

    for j in range(ntime):
        mean1[j] = 5.0 + 0.2 * ROOT.TMath.Sin(ROOT.TMath.Pi() * j / float(ntime))
        mean2[j] = 5.0 + 0.2 * ROOT.TMath.Cos(ROOT.TMath.Pi() * j / float(ntime))
        sigma1[j] = 4 + 0.3 * ROOT.TMath.Sin(ROOT.TMath.Pi() * j / float(ntime))
        sigma2[j] = 4 + 0.3 * ROOT.TMath.Cos(ROOT.TMath.Pi() * j / float(ntime))

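    # For example, at the first time step (j = 0) the two classes differ only in
    # the phase of the modulation: mean1 = 5.0 and sigma1 = 4.0 (the sin term
    # vanishes), while mean2 = 5.2 and sigma2 = 4.3 (the cos term is maximal).
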
    for i in range(n):
        if i % 1000 == 0:
            print("Generating event ... %d" % i)

        for j in range(ntime):
            h1 = v1[j]
            h2 = v2[j]
            h1.Reset()
            h2.Reset()

            f1.SetParameters(1, mean1[j], sigma1[j])
            f2.SetParameters(1, mean2[j], sigma2[j])

            h1.FillRandom("f1", 1000)
            h2.FillRandom("f2", 1000)

            for k in range(ndim):
                # copy the bin contents, smeared with Gaussian noise of sigma 10
                x1[j][k] = h1.GetBinContent(k + 1) + ROOT.gRandom.Gaus(0, 10)
                x2[j][k] = h2.GetBinContent(k + 1) + ROOT.gRandom.Gaus(0, 10)

        sgn.Fill()
        bkg.Fill()

    if n == 1:
        c1 = ROOT.TCanvas()
        c1.Divide(ntime, 2)
        for j in range(ntime):
            c1.cd(j + 1)
            v1[j].Draw()
        for j in range(ntime):
            c1.cd(ntime + j + 1)
            v2[j].Draw()

    if n > 1:
        sgn.Write()
        bkg.Write()
        sgn.Print()
        bkg.Print()
        f.Close()


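# For a quick visual check of the generated inputs (a sketch, not part of the
# training flow below): calling the helper with a single event draws the signal
# and background histograms for each time step instead of writing the trees.
# MakeTimeData(1, 10, 30)
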
## Macro for performing a classification using a Recurrent Neural Network
## @param use_type
##   use_type = 0   use a simple RNN network
##   use_type = 1   use an LSTM network
##   use_type = 2   use a GRU network
##   use_type = 3   build 3 different networks with RNN, LSTM and GRU


use_type = 1
ninput = 30
ntime = 10
batchSize = 100
maxepochs = 10

nTotEvts = 2000  # total events to be generated for signal or background

useKeras = False

useTMVA_RNN = True
useTMVA_DNN = True
useTMVA_BDT = False

if ROOT.gSystem.GetFromPipe("root-config --has-tmva-pymva") == "yes":
    useKeras = True


rnn_types = ["RNN", "LSTM", "GRU"]
use_rnn_type = [1, 1, 1]

if 0 <= use_type < 3:
    use_rnn_type = [0, 0, 0]
    use_rnn_type[use_type] = 1

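# With the default use_type = 1 above, use_rnn_type becomes [0, 1, 0], so only
# the LSTM variant is booked below; use_type = 3 keeps all three network types.
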
useGPU = True  # use GPU for TMVA if available

useGPU = "tmva-gpu" in ROOT.gROOT.GetConfigFeatures()
useTMVA_RNN = ("tmva-cpu" in ROOT.gROOT.GetConfigFeatures()) or useGPU

if not useTMVA_RNN:
    ROOT.Warning(
        "TMVA_RNN_Classification",
        "TMVA is not built with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for RNN",
    )

archString = "GPU" if useGPU else "CPU"

writeOutputFile = True

rnn_type = "RNN"

if "tmva-pymva" in ROOT.gROOT.GetConfigFeatures():
    TMVA.PyMethodBase.PyInitialize()
else:
    useKeras = False


inputFileName = "time_data_t10_d30.root"

fileDoesNotExist = ROOT.gSystem.AccessPathName(inputFileName)

# if the file does not exist, create it
if fileDoesNotExist:
    MakeTimeData(nTotEvts, ntime, ninput)

inputFile = TFile.Open(inputFileName)
if inputFile is None:
    raise RuntimeError("Error opening input file %s" % inputFileName)


print("--- RNNClassification : Using input file: {}".format(inputFile.GetName()))

# Create a ROOT output file where TMVA will store ntuples, histograms, etc.
outfileName = "data_RNN_" + archString + ".root"
outputFile = None

if writeOutputFile:
    outputFile = TFile.Open(outfileName, "RECREATE")


## Declare Factory

# Create the Factory class. Later you can choose the methods
# whose performance you'd like to investigate.

# The factory is the major TMVA object you have to interact with. Here is the list of parameters you
# need to pass:

# - The first argument is the base of the name of all the output
#   weight files in the directory weight/ that will be created with the
#   method parameters

# - The second argument is the output file for the training results

# - The third argument is a string option defining some general configuration for the TMVA session.
#   For example, all TMVA output can be suppressed by removing the "!" (not) in front of the "Silent" argument
#   in the option string


# Create the factory object
factory = TMVA.Factory(
    "TMVAClassification",
    outputFile,
    V=False,
    Silent=False,
    Color=True,
    DrawProgressBar=True,
    Transformations=None,
    Correlations=False,
    AnalysisType="Classification",
    ModelPersistence=True,
)
dataloader = TMVA.DataLoader("dataset")

signalTree = inputFile.Get("sgn")
background = inputFile.Get("bkg")

nvar = ninput * ntime

## add the variables - use the new AddVariablesArray function
for i in range(ntime):
    dataloader.AddVariablesArray("vars_time" + str(i), ninput)


dataloader.AddSignalTree(signalTree, 1.0)
dataloader.AddBackgroundTree(background, 1.0)

# check the given input
datainfo = dataloader.GetDataSetInfo()
vars = datainfo.GetListOfVariables()
print("number of variables is {}".format(vars.size()))


for v in vars:
    print(v)

nTrainSig = 0.8 * nTotEvts
nTrainBkg = 0.8 * nTotEvts

# Apply additional cuts on the signal and background samples (can be different)
mycuts = ""  # for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
mycutb = ""

# build the string options for DataLoader::PrepareTrainingAndTestTree
dataloader.PrepareTrainingAndTestTree(
    mycuts,
    mycutb,
    nTrain_Signal=nTrainSig,
    nTrain_Background=nTrainBkg,
    SplitMode="Random",
    SplitSeed=100,
    NormMode="NumEvents",
    V=False,
    CalcCorrelations=False,
)

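# With the defaults above (nTotEvts = 2000), this requests 1600 training events
# per class; the remaining 400 signal and 400 background events form the test set.
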
print("prepared DATA LOADER")


## Book TMVA recurrent models

# Book the different types of recurrent models in TMVA (SimpleRNN, LSTM or GRU)


if useTMVA_RNN:
    for i in range(3):
        if not use_rnn_type[i]:
            continue

        rnn_type = rnn_types[i]

        ## Define the RNN layer layout:
        ## LayerType (RNN, LSTM or GRU) | number of units | number of inputs | time steps | remember output (typically no = 0) | return the full sequence
        rnnLayout = str(rnn_type) + "|10|" + str(ninput) + "|" + str(ntime) + "|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR"

        ## Define the training strategy. Different training strings can be concatenated; here we use only one.
        trainingString1 = "LearningRate=1e-3,Momentum=0.0,Repetitions=1,ConvergenceSteps=5,BatchSize=" + str(batchSize)
        trainingString1 += ",TestRepetitions=1,WeightDecay=1e-2,Regularization=None,MaxEpochs=" + str(maxepochs)
        trainingString1 += ",Optimizer=ADAM,DropConfig=0.0+0.+0.+0."

        ## Define the InputLayout string for the RNN:
        ## the input data should be organized as: time steps x ndim.
        ## After the RNN, add a reshape layer (needed to flatten the output), then a dense layer with 64 units
        ## and a final one. Note the last layer is linear because when using CROSSENTROPY a Sigmoid is applied already.
        rnnName = "TMVA_" + str(rnn_type)
        factory.BookMethod(
            dataloader,
            TMVA.Types.kDL,
            rnnName,
            H=False,
            V=True,
            ErrorStrategy="CROSSENTROPY",
            VarTransform=None,
            WeightInitialization="XAVIERUNIFORM",
            ValidationSize=0.2,
            RandomSeed=1234,
            InputLayout=str(ntime) + "|" + str(ninput),
            Layout=rnnLayout,
            TrainingStrategy=trainingString1,
            Architecture=archString,
        )

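# For the default settings (use_type = 1, ninput = 30, ntime = 10), the option
# strings booked above expand to (shown only to illustrate the format):
#   InputLayout = "10|30"
#   Layout      = "LSTM|10|30|10|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR"
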

## Book TMVA fully connected dense layer models
if useTMVA_DNN:
    # Method DL with a Dense Layer
    # Training strategies.
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
        "ConvergenceSteps=10,BatchSize=256,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,MaxEpochs=20,"
        "DropConfig=0.0+0.+0.+0.,Optimizer=ADAM:"
    )  # + "|" + trainingString2
    # General Options.
    trainingString1.Append(archString)
    dnnName = "TMVA_DNN"
    factory.BookMethod(
        dataloader,
        TMVA.Types.kDL,
        dnnName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        RandomSeed=0,
        InputLayout="1|1|" + str(ntime * ninput),
        Layout="DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,LINEAR",
        TrainingStrategy=trainingString1,
    )

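# Note: InputLayout = "1|1|300" feeds the dense network the whole 10 x 30 time
# series flattened into a single vector of 300 inputs, so unlike the recurrent
# models above it carries no explicit notion of time ordering.
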

## Book Keras recurrent models

# Book the different types of recurrent models in Keras (SimpleRNN, LSTM or GRU)


if useKeras:
    for i in range(3):
        if use_rnn_type[i]:
            modelName = "model_" + rnn_types[i] + ".h5"
            trainedModelName = "trained_" + modelName
            print("Building recurrent keras model using a", rnn_types[i], "layer")
            # build the Keras model: a reshape, one recurrent layer and two dense layers
            from tensorflow.keras.models import Sequential
            from tensorflow.keras.optimizers import Adam

            # from keras.initializers import TruncatedNormal
            # from keras import initializations
            from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, SimpleRNN, GRU, LSTM, Reshape, BatchNormalization

            model = Sequential()
            model.add(Reshape((10, 30), input_shape=(10 * 30,)))
            # add the recurrent layer depending on the type; return the full output sequence
            if rnn_types[i] == "LSTM":
                model.add(LSTM(units=10, return_sequences=True))
            elif rnn_types[i] == "GRU":
                model.add(GRU(units=10, return_sequences=True))
            else:
                model.add(SimpleRNN(units=10, return_sequences=True))
            # optionally: model.add(BatchNormalization())
            model.add(Flatten())  # needed when returning the full time output sequence
            model.add(Dense(64, activation="tanh"))
            model.add(Dense(2, activation="sigmoid"))
            model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), weighted_metrics=["accuracy"])
            model.save(modelName)
            model.summary()
            print("saved recurrent model", modelName)

            if not os.path.exists(modelName):
                useKeras = False
                print("Error creating Keras recurrent model file - skip using Keras")
            else:
                # book the PyKeras method only if the Keras model could be created
                print("Booking Keras model", rnn_types[i])
                factory.BookMethod(
                    dataloader,
                    TMVA.Types.kPyKeras,
                    "PyKeras_" + rnn_types[i],
                    H=True,
                    V=False,
                    VarTransform=None,
                    FilenameModel=modelName,
                    FilenameTrainedModel=trainedModelName,
                    NumEpochs=maxepochs,
                    BatchSize=batchSize,
                    GpuOptions="allow_growth=True",
                )

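# Optional sanity check (an illustrative sketch, not in the original flow):
# reload a saved model to verify the .h5 file is readable before TMVA uses it.
# from tensorflow.keras.models import load_model
# load_model(modelName).summary()
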

# fall back to a BDT in case the Keras or TMVA DL methods are not available
if not useKeras or not useTMVA_RNN:
    useTMVA_BDT = True


## Book TMVA BDT


if useTMVA_BDT:
    factory.BookMethod(
        dataloader,
        TMVA.Types.kBDT,
        "BDTG",
        H=True,
        V=False,
        NTrees=100,
        MinNodeSize="2.5%",
        BoostType="Grad",
        Shrinkage=0.10,
        UseBaggedBoost=True,
        BaggedSampleFraction=0.5,
        nCuts=20,
        MaxDepth=2,
    )


## Train all methods
factory.TrainAllMethods()

print("nthreads = {}".format(ROOT.GetThreadPoolSize()))

# ---- Evaluate all MVAs using the set of test events
factory.TestAllMethods()

# ----- Evaluate and compare performance of all configured MVAs
factory.EvaluateAllMethods()

# check method

# plot the ROC curve
c1 = factory.GetROCCurve(dataloader)
c1.Draw()

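# To keep the ROC comparison beyond the interactive session, the canvas could
# be saved to an image file (illustrative, not part of the original macro):
# c1.Print("ROC_RNN_" + archString + ".png")
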
if outputFile:
    outputFile.Close()