TMVA_RNN_Classification.py
## \file
## \ingroup tutorial_tmva
## \notebook
## TMVA Classification Example Using a Recurrent Neural Network
##
## This is an example of using an RNN in TMVA. We do classification using a toy
## time-dependent data set that is generated when running this example macro.
##
## \macro_image
## \macro_output
## \macro_code
##
## \author Harshal Shende


# TMVA Classification Example Using a Recurrent Neural Network

# This is an example of using an RNN in TMVA.
# We do the classification using a toy data set containing a time series of
# ntime samples, each of dimension ndim, generated by the provided function
# `MakeTimeData(nevents, ntime, ndim)`.

import ROOT

num_threads = 4  # use max 4 threads
# do enable MT running
if "imt" in ROOT.gROOT.GetConfigFeatures():
    ROOT.EnableImplicitMT(num_threads)
    ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")  # switch OFF MT in OpenBLAS
    print("Running with nthreads = {}".format(ROOT.GetThreadPoolSize()))
else:
    print("Running in serial mode since ROOT does not support MT")


TMVA = ROOT.TMVA
TFile = ROOT.TFile

import os
import importlib.util

TMVA.Tools.Instance()
TMVA.Config.Instance()


## Helper function to generate the time-dependent data set:
## for each of the ntime time steps, fill a histogram of ndim bins from a
## Gaussian whose mean and sigma evolve with time, and store the (smeared)
## bin contents as the input features.

def MakeTimeData(n, ntime, ndim):
    # e.g. ntime = 10, ndim = 30 (number of dimensions per time step)

    fname = "time_data_t" + str(ntime) + "_d" + str(ndim) + ".root"
    v1 = []
    v2 = []

    for i in range(ntime):
        v1.append(ROOT.TH1D("h1_" + str(i), "h1", ndim, 0, 10))
        v2.append(ROOT.TH1D("h2_" + str(i), "h2", ndim, 0, 10))

    f1 = ROOT.TF1("f1", "gaus")
    f2 = ROOT.TF1("f2", "gaus")

    sgn = ROOT.TTree("sgn", "sgn")
    bkg = ROOT.TTree("bkg", "bkg")
    f = TFile(fname, "RECREATE")

    x1 = []
    x2 = []

    for i in range(ntime):
        x1.append(ROOT.std.vector["float"](ndim))
        x2.append(ROOT.std.vector["float"](ndim))

    for i in range(ntime):
        bkg.Branch("vars_time" + str(i), "std::vector<float>", x1[i])
        sgn.Branch("vars_time" + str(i), "std::vector<float>", x2[i])

    sgn.SetDirectory(f)
    bkg.SetDirectory(f)
    ROOT.gRandom.SetSeed(0)

    mean1 = ROOT.std.vector["double"](ntime)
    mean2 = ROOT.std.vector["double"](ntime)
    sigma1 = ROOT.std.vector["double"](ntime)
    sigma2 = ROOT.std.vector["double"](ntime)

    for j in range(ntime):
        mean1[j] = 5.0 + 0.2 * ROOT.TMath.Sin(ROOT.TMath.Pi() * j / float(ntime))
        mean2[j] = 5.0 + 0.2 * ROOT.TMath.Cos(ROOT.TMath.Pi() * j / float(ntime))
        sigma1[j] = 4 + 0.3 * ROOT.TMath.Sin(ROOT.TMath.Pi() * j / float(ntime))
        sigma2[j] = 4 + 0.3 * ROOT.TMath.Cos(ROOT.TMath.Pi() * j / float(ntime))

    for i in range(n):
        if i % 1000 == 0:
            print("Generating event ... %d" % i)

        for j in range(ntime):
            h1 = v1[j]
            h2 = v2[j]
            h1.Reset()
            h2.Reset()

            f1.SetParameters(1, mean1[j], sigma1[j])
            f2.SetParameters(1, mean2[j], sigma2[j])

            h1.FillRandom("f1", 1000)
            h2.FillRandom("f2", 1000)

            # store the (smeared) contents of the ndim bins as features
            for k in range(ndim):
                x1[j][k] = h1.GetBinContent(k + 1) + ROOT.gRandom.Gaus(0, 10)
                x2[j][k] = h2.GetBinContent(k + 1) + ROOT.gRandom.Gaus(0, 10)

        sgn.Fill()
        bkg.Fill()

    if n == 1:
        c1 = ROOT.TCanvas()
        c1.Divide(ntime, 2)
        for j in range(ntime):
            c1.cd(j + 1)
            v1[j].Draw()
        for j in range(ntime):
            c1.cd(ntime + j + 1)
            v2[j].Draw()

        ROOT.gPad.Update()

    if n > 1:
        sgn.Write()
        bkg.Write()
        sgn.Print()
        bkg.Print()
        f.Close()


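# To eyeball the generated inputs, the helper can be run in "inspection mode":
# with a single event it draws the ntime signal and background histograms on a
# canvas instead of writing the trees (a usage sketch, not executed here):
#
#   MakeTimeData(1, 10, 30)
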
## Macro for performing a classification using a Recurrent Neural Network
## @param use_type
##   use_type = 0: use a simple RNN network
##   use_type = 1: use an LSTM network
##   use_type = 2: use a GRU network
##   use_type = 3: build three different networks with RNN, LSTM and GRU


use_type = 1
ninput = 30
ntime = 10
batchSize = 100
maxepochs = 10

nTotEvts = 2000  # total events to be generated for signal or background

useKeras = True

useTMVA_RNN = True
useTMVA_DNN = True
useTMVA_BDT = False

tf_spec = importlib.util.find_spec("tensorflow")
if tf_spec is None:
    useKeras = False
    ROOT.Warning("TMVA_RNN_Classification", "Skip using Keras since tensorflow is not installed")


rnn_types = ["RNN", "LSTM", "GRU"]
use_rnn_type = [1, 1, 1]

# if a single network type is selected, build only that one
if 0 <= use_type < 3:
    use_rnn_type = [0, 0, 0]
    use_rnn_type[use_type] = 1
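# e.g. with the default use_type = 1 only the LSTM network is built:
#   use_rnn_type == [0, 1, 0]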

useGPU = "tmva-gpu" in ROOT.gROOT.GetConfigFeatures()  # use GPU for TMVA if available
useTMVA_RNN = ("tmva-cpu" in ROOT.gROOT.GetConfigFeatures()) or useGPU

if not useTMVA_RNN:
    ROOT.Warning(
        "TMVA_RNN_Classification",
        "TMVA is not built with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for RNN",
    )

archString = "GPU" if useGPU else "CPU"

writeOutputFile = True

rnn_type = "RNN"

if "tmva-pymva" in ROOT.gROOT.GetConfigFeatures():
    TMVA.PyMethodBase.PyInitialize()
else:
    useKeras = False


inputFileName = "time_data_t10_d30.root"

fileDoesNotExist = ROOT.gSystem.AccessPathName(inputFileName)

# if the file does not exist, create it
if fileDoesNotExist:
    MakeTimeData(nTotEvts, ntime, ninput)


inputFile = TFile.Open(inputFileName)
if inputFile is None:
    raise RuntimeError("Error opening input file %s - exit" % inputFileName)


print("--- RNNClassification : Using input file: {}".format(inputFile.GetName()))

# Create a ROOT output file where TMVA will store ntuples, histograms, etc.
outfileName = "data_RNN_" + archString + ".root"
outputFile = None


if writeOutputFile:
    outputFile = TFile.Open(outfileName, "RECREATE")


## Declare Factory

# Create the Factory class. Later you can choose the methods
# whose performance you'd like to investigate.

# The factory is the major TMVA object you have to interact with. Here is the list of parameters
# you need to pass:

# - The first argument is the base name of all the output
#   weight files in the directory weight/ that will be created with the
#   method parameters.

# - The second argument is the output file for the training results.

# - The third argument is a string option defining some general configuration for the TMVA session.
#   For example, all TMVA output can be suppressed by removing the "!" (not) in front of the "Silent"
#   argument in the option string. In this PyROOT version the options are passed as keyword
#   arguments instead, e.g. Silent=False corresponds to "!Silent".


# Create the factory object
factory = TMVA.Factory(
    "TMVAClassification",
    outputFile,
    V=False,
    Silent=False,
    Color=True,
    DrawProgressBar=True,
    Transformations=None,
    Correlations=False,
    AnalysisType="Classification",
    ModelPersistence=True,
)
dataloader = TMVA.DataLoader("dataset")

signalTree = inputFile.Get("sgn")
background = inputFile.Get("bkg")

nvar = ninput * ntime

## add variables - use the AddVariablesArray function
for i in range(ntime):
    dataloader.AddVariablesArray("vars_time" + str(i), ninput)


dataloader.AddSignalTree(signalTree, 1.0)
dataloader.AddBackgroundTree(background, 1.0)

# check the given input
datainfo = dataloader.GetDataSetInfo()
vars = datainfo.GetListOfVariables()
print("number of variables is {}".format(vars.size()))


for v in vars:
    print(v)

nTrainSig = 0.8 * nTotEvts
nTrainBkg = 0.8 * nTotEvts

# Apply additional cuts on the signal and background samples (can be different)
mycuts = ""  # for example: mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1"
mycutb = ""

# build the options for DataLoader::PrepareTrainingAndTestTree
dataloader.PrepareTrainingAndTestTree(
    mycuts,
    mycutb,
    nTrain_Signal=nTrainSig,
    nTrain_Background=nTrainBkg,
    SplitMode="Random",
    SplitSeed=100,
    NormMode="NumEvents",
    V=False,
    CalcCorrelations=False,
)
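# With the defaults above (nTotEvts = 2000, fraction 0.8), 1600 signal and 1600
# background events are used for training; the remaining events form the test set.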

print("prepared DATA LOADER ")


## Book TMVA recurrent models

# Book the different types of recurrent models in TMVA (SimpleRNN, LSTM or GRU)


if useTMVA_RNN:
    for i in range(3):
        if not use_rnn_type[i]:
            continue

        rnn_type = rnn_types[i]

        ## Define the RNN layer layout:
        ## LayerType (RNN, LSTM or GRU) | number of units | number of inputs |
        ## time steps | remember state (typically no = 0) | return full output sequence (1)
        rnnLayout = str(rnn_type) + "|10|" + str(ninput) + "|" + str(ntime) + "|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR"

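        # For example, with the defaults above (use_type = 1, ninput = 30, ntime = 10)
        # rnnLayout expands to
        #   "LSTM|10|30|10|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR"
        # i.e. an LSTM with 10 units reading 30 inputs over 10 time steps, returning the
        # full sequence, which is flattened and fed through a 64-unit TANH dense layer
        # to a final linear output.
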
        ## Define the training strategy. Different training strings can be concatenated
        ## with "|"; here we use only one.
        trainingString1 = "LearningRate=1e-3,Momentum=0.0,Repetitions=1,ConvergenceSteps=5,BatchSize=" + str(batchSize)
        trainingString1 += ",TestRepetitions=1,WeightDecay=1e-2,Regularization=None,MaxEpochs=" + str(maxepochs)
        trainingString1 += ",Optimizer=ADAM,DropConfig=0.0+0.+0.+0."

        ## Define the InputLayout string for the RNN: the input data are organized as time x ndim.
        ## After the RNN, add a reshape layer (needed to flatten the output), a dense layer with
        ## 64 units and a final linear layer.
        ## Note the last layer is linear because, when using CrossEntropy, a Sigmoid is applied already.
        rnnName = "TMVA_" + str(rnn_type)
        factory.BookMethod(
            dataloader,
            TMVA.Types.kDL,
            rnnName,
            H=False,
            V=True,
            ErrorStrategy="CROSSENTROPY",
            VarTransform=None,
            WeightInitialization="XAVIERUNIFORM",
            ValidationSize=0.2,
            RandomSeed=1234,
            InputLayout=str(ntime) + "|" + str(ninput),
            Layout=rnnLayout,
            TrainingStrategy=trainingString1,
            Architecture=archString,
        )


## Book TMVA fully connected dense layer models
if useTMVA_DNN:
    # Method DL with a Dense Layer
    # Training strategy
    trainingString1 = (
        "LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
        "ConvergenceSteps=10,BatchSize=256,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,MaxEpochs=20,"
        "DropConfig=0.0+0.+0.+0.,Optimizer=ADAM"
    )
    # General options
    dnnName = "TMVA_DNN"
    factory.BookMethod(
        dataloader,
        TMVA.Types.kDL,
        dnnName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        RandomSeed=0,
        InputLayout="1|1|" + str(ntime * ninput),
        Layout="DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,LINEAR",
        TrainingStrategy=trainingString1,
        Architecture=archString,
    )


## Book Keras recurrent models

# Book the different types of recurrent models in Keras (SimpleRNN, LSTM or GRU)


if useKeras:
    for i in range(3):
        if use_rnn_type[i]:
            modelName = "model_" + rnn_types[i] + ".h5"
            trainedModelName = "trained_" + modelName
            print("Building recurrent keras model using a", rnn_types[i], "layer")
            # build the Keras model: Reshape + recurrent layer + Flatten + Dense layers
            from tensorflow.keras.models import Sequential
            from tensorflow.keras.optimizers import Adam

            # from keras.initializers import TruncatedNormal
            # from keras import initializations
            from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, SimpleRNN, GRU, LSTM, Reshape, BatchNormalization

            model = Sequential()
            model.add(Reshape((ntime, ninput), input_shape=(ntime * ninput,)))
            # add the recurrent layer depending on the type; return the full output sequence
            if rnn_types[i] == "LSTM":
                model.add(LSTM(units=10, return_sequences=True))
            elif rnn_types[i] == "GRU":
                model.add(GRU(units=10, return_sequences=True))
            else:
                model.add(SimpleRNN(units=10, return_sequences=True))
            # model.add(BatchNormalization())
            model.add(Flatten())  # needed when returning the full time output sequence
            model.add(Dense(64, activation="tanh"))
            model.add(Dense(2, activation="sigmoid"))
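            # Note (assumption based on TMVA's PyKeras convention): classification
            # models provide one output node per class (signal, background), so the
            # binary cross-entropy is computed against one-hot class labels.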
            model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), weighted_metrics=["accuracy"])
            model.save(modelName)
            model.summary()
            print("saved recurrent model", modelName)

            if not os.path.exists(modelName):
                useKeras = False
                print("Error creating Keras recurrent model file - Skip using Keras")
            else:
                # book the PyKeras method only if the Keras model could be created
                print("Booking Keras model ", rnn_types[i])
                factory.BookMethod(
                    dataloader,
                    TMVA.Types.kPyKeras,
                    "PyKeras_" + rnn_types[i],
                    H=True,
                    V=False,
                    VarTransform=None,
                    FilenameModel=modelName,
                    FilenameTrainedModel="trained_" + modelName,
                    NumEpochs=maxepochs,
                    BatchSize=batchSize,
                    GpuOptions="allow_growth=True",
                )


# use a BDT in case Keras or the TMVA DL method is not available
if not useKeras or not useTMVA_RNN:
    useTMVA_BDT = True


## Book TMVA BDT


if useTMVA_BDT:
    factory.BookMethod(
        dataloader,
        TMVA.Types.kBDT,
        "BDTG",
        H=True,
        V=False,
        NTrees=100,
        MinNodeSize="2.5%",
        BoostType="Grad",
        Shrinkage=0.10,
        UseBaggedBoost=True,
        BaggedSampleFraction=0.5,
        nCuts=20,
        MaxDepth=2,
    )


## Train all methods
factory.TrainAllMethods()

print("nthreads = {}".format(ROOT.GetThreadPoolSize()))

# ---- Evaluate all MVAs using the set of test events
factory.TestAllMethods()

# ----- Evaluate and compare performance of all configured MVAs
factory.EvaluateAllMethods()

# plot the ROC curve
c1 = factory.GetROCCurve(dataloader)
c1.Draw()

if outputFile:
    outputFile.Close()
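
# After the run, the results stored in the output file can be inspected with the
# standard TMVA GUI (a minimal sketch, assuming a ROOT build that includes the
# TMVA GUI); uncomment to launch it:
#
#   ROOT.TMVA.TMVAGui(outfileName)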