TMVA_CNN_Classification.py
## \file
## \ingroup tutorial_tmva
## \notebook
## TMVA Classification Example Using a Convolutional Neural Network
##
## This is an example of using a CNN in TMVA. We do classification using a toy image data set
## that is generated when running the example macro
##
## \macro_image
## \macro_output
## \macro_code
##
## \author Harshal Shende


# TMVA Classification Example Using a Convolutional Neural Network

## Helper function to create the input image data:
## we create signal and background 2D histograms from 2D gaussians
## whose location (the means in X and Y) differs event by event.
## The difference between signal and background is in the gaussian width:
## the background width is slightly larger than the signal width, by a few per cent.

import os
import importlib.util

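# flags controlling which methods to run below, in the order
# [TMVA CNN, Keras CNN, TMVA DNN, TMVA BDT, PyTorch CNN]; set an entry to 0 to skip that method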
opt = [1, 1, 1, 1, 1]
useTMVACNN = opt[0] if len(opt) > 0 else False
useKerasCNN = opt[1] if len(opt) > 1 else False
useTMVADNN = opt[2] if len(opt) > 2 else False
useTMVABDT = opt[3] if len(opt) > 3 else False
usePyTorchCNN = opt[4] if len(opt) > 4 else False

tf_spec = importlib.util.find_spec("tensorflow")
if tf_spec is None:
    useKerasCNN = False
    print("TMVA_CNN_Classification", "Skip using Keras since tensorflow is not installed")
else:
    import tensorflow

# PyTorch has to be imported before ROOT to avoid crashes because of clashing
# std::regexp symbols that are exported by cppyy.
# See also: https://github.com/wlav/cppyy/issues/227
torch_spec = importlib.util.find_spec("torch")
if torch_spec is None:
    usePyTorchCNN = False
    print("TMVA_CNN_Classification", "Skip using PyTorch since torch is not installed")
else:
    import torch


import ROOT

# switch off MT in OpenMP (BLAS)
ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")

TMVA = ROOT.TMVA
TFile = ROOT.TFile

TMVA.Tools.Instance()

def MakeImagesTree(n, nh, nw):
    # image size (nh x nw)
    ntot = nh * nw
    fileOutName = "images_data_16x16.root"
    nRndmEvts = 10000  # number of events we use to fill each image
    delta_sigma = 0.1  # small (few %) difference in the sigma between signal and background
    pixelNoise = 5

    sX1 = 3
    sY1 = 3
    sX2 = sX1 + delta_sigma
    sY2 = sY1 - delta_sigma
    h1 = ROOT.TH2D("h1", "h1", nh, 0, 10, nw, 0, 10)
    h2 = ROOT.TH2D("h2", "h2", nh, 0, 10, nw, 0, 10)
    f1 = ROOT.TF2("f1", "xygaus")
    f2 = ROOT.TF2("f2", "xygaus")
    sgn = ROOT.TTree("sig_tree", "signal_tree")
    bkg = ROOT.TTree("bkg_tree", "background_tree")

    f = TFile(fileOutName, "RECREATE")
    x1 = ROOT.std.vector["float"](ntot)
    x2 = ROOT.std.vector["float"](ntot)

    # create signal and background trees with a single branch
    # an std::vector<float> of size nh x nw containing the image data
    bkg.Branch("vars", "std::vector<float>", x1)
    sgn.Branch("vars", "std::vector<float>", x2)

    sgn.SetDirectory(f)
    bkg.SetDirectory(f)

    f1.SetParameters(1, 5, sX1, 5, sY1)
    f2.SetParameters(1, 5, sX2, 5, sY2)
    ROOT.gRandom.SetSeed(0)
    ROOT.Info("TMVA_CNN_Classification", "Filling ROOT tree \n")
    for i in range(n):
        if i % 1000 == 0:
            print("Generating image event ...", i)

        h1.Reset()
        h2.Reset()
        # generate random means in range [3,7] to be not too much on the border
        f1.SetParameter(1, ROOT.gRandom.Uniform(3, 7))
        f1.SetParameter(3, ROOT.gRandom.Uniform(3, 7))
        f2.SetParameter(1, ROOT.gRandom.Uniform(3, 7))
        f2.SetParameter(3, ROOT.gRandom.Uniform(3, 7))

        h1.FillRandom("f1", nRndmEvts)
        h2.FillRandom("f2", nRndmEvts)

        for k in range(nh):
            for l in range(nw):
                m = k * nw + l
                # add some noise in each bin
                x1[m] = h1.GetBinContent(k + 1, l + 1) + ROOT.gRandom.Gaus(0, pixelNoise)
                x2[m] = h2.GetBinContent(k + 1, l + 1) + ROOT.gRandom.Gaus(0, pixelNoise)

        sgn.Fill()
        bkg.Fill()

    sgn.Write()
    bkg.Write()

    print("Signal and background tree with images data written to the file %s" % f.GetName())
    sgn.Print()
    bkg.Print()
    f.Close()

hasGPU = ROOT.gSystem.GetFromPipe("root-config --has-tmva-gpu") == "yes"
hasCPU = ROOT.gSystem.GetFromPipe("root-config --has-tmva-cpu") == "yes"

nevt = 1000  # use a larger value to get better results

if not hasCPU and not hasGPU:
    ROOT.Warning("TMVA_CNN_Classification", "ROOT was built without tmva-cpu and tmva-gpu support - skip using TMVA-DNN and TMVA-CNN")
    useTMVACNN = False
    useTMVADNN = False

if ROOT.gSystem.GetFromPipe("root-config --has-tmva-pymva") != "yes":
    useKerasCNN = False
    usePyTorchCNN = False
else:
    TMVA.PyMethodBase.PyInitialize()

if not useTMVACNN:
    ROOT.Warning(
        "TMVA_CNN_Classification",
        "TMVA is not built with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for CNN",
    )

writeOutputFile = True

num_threads = 4  # number of threads to use (passed to ROOT.EnableImplicitMT below)
max_epochs = 10  # maximum number of epochs used for training


# do enable MT running
if num_threads >= 0:
    ROOT.EnableImplicitMT(num_threads)

print("Running with nthreads = ", ROOT.GetThreadPoolSize())



outputFile = None
if writeOutputFile:
    outputFile = TFile.Open("TMVA_CNN_ClassificationOutput.root", "RECREATE")

## Create TMVA Factory

# Create the Factory class. Later you can choose the methods
# whose performance you'd like to investigate.

# The factory is the main TMVA object you have to interact with. Here is the list of parameters you need to pass:

# - The first argument is the base name of all the output
#   weight files (in the directory weight/) that will be created with the
#   method parameters

# - The second argument is the output file for the training results

# - The remaining arguments define the general configuration for the TMVA session.
#   For example, all TMVA output can be suppressed by setting the Silent option to True.

# - Note that we disable any pre-transformation of the input variables and we avoid computing
#   correlations between input variables

factory = TMVA.Factory(
    "TMVA_CNN_Classification",
    outputFile,
    V=False,
    ROC=True,
    Silent=False,
    Color=True,
    AnalysisType="Classification",
    Transformations=None,
    Correlations=False,
)

## Declare DataLoader(s)

# The next step is to declare the DataLoader class that deals with input variables

# Define the input variables that shall be used for the MVA training
# note that you may also use variable expressions, which can be parsed by TTree::Draw( "expression" )

# In this case the input data consists of an image of 16x16 pixels. Each pixel is an element of a
# single std::vector<float> branch of the ROOT TTree

loader = TMVA.DataLoader("dataset")


## Setup Dataset(s)

# Define input data file and signal and background trees


imgSize = 16 * 16
inputFileName = "images_data_16x16.root"

# if the input file does not exist create it
if ROOT.gSystem.AccessPathName(inputFileName):
    MakeImagesTree(nevt, 16, 16)

inputFile = TFile.Open(inputFileName)
if inputFile is None:
    ROOT.Warning("TMVA_CNN_Classification", "Error opening input file %s - exit", inputFileName)


# inputFileName = "tmva_class_example.root"

# --- Register the training and test trees

signalTree = inputFile.Get("sig_tree")
backgroundTree = inputFile.Get("bkg_tree")

nEventsSig = signalTree.GetEntries()
nEventsBkg = backgroundTree.GetEntries()

# global event weights per tree (see below for setting event-wise weights)
signalWeight = 1.0
backgroundWeight = 1.0

# You can add an arbitrary number of signal or background trees
loader.AddSignalTree(signalTree, signalWeight)
loader.AddBackgroundTree(backgroundTree, backgroundWeight)

## add event variables (image)
## use the new method (available from ROOT 6.20) to add a variable array holding all the image data
loader.AddVariablesArray("vars", imgSize)

# Set individual event weights (the variables must exist in the original TTree)
# for signal : factory->SetSignalWeightExpression ("weight1*weight2");
# for background: factory->SetBackgroundWeightExpression("weight1*weight2");
# loader->SetBackgroundWeightExpression( "weight" );

# Apply additional cuts on the signal and background samples (can be different)
mycuts = ""  # for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
mycutb = ""  # for example: TCut mycutb = "abs(var1)<0.5";

# Tell the factory how to use the training and testing events
# If no numbers of events are given, half of the events in the tree are used
# for training, and the other half for testing:
#     loader.PrepareTrainingAndTestTree( mycut, "SplitMode=random:!V" )
# It is also possible to specify the number of training and testing events;
# note that we disable the computation of the correlation matrix of the input variables

nTrainSig = 0.8 * nEventsSig
nTrainBkg = 0.8 * nEventsBkg

# build the string options for DataLoader::PrepareTrainingAndTestTree

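# use 80% of the events for training and keep the remaining 20% for testing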
loader.PrepareTrainingAndTestTree(
    mycuts,
    mycutb,
    nTrain_Signal=nTrainSig,
    nTrain_Background=nTrainBkg,
    SplitMode="Random",
    SplitSeed=100,
    NormMode="NumEvents",
    V=False,
    CalcCorrelations=False,
)

# DataSetInfo : [dataset] : Added class "Signal"
#             : Add Tree sig_tree of type Signal with 10000 events
# DataSetInfo : [dataset] : Added class "Background"
#             : Add Tree bkg_tree of type Background with 10000 events

# signalTree.Print();

## Booking Methods

# Here we book the TMVA methods. We book a Boosted Decision Tree method (BDT)


# Boosted Decision Trees
if useTMVABDT:
    factory.BookMethod(
        loader,
        TMVA.Types.kBDT,
        "BDT",
        V=False,
        NTrees=400,
        MinNodeSize="2.5%",
        MaxDepth=2,
        BoostType="AdaBoost",
        AdaBoostBeta=0.5,
        UseBaggedBoost=True,
        BaggedSampleFraction=0.5,
        SeparationType="GiniIndex",
        nCuts=20,
    )

#### Booking Deep Neural Network

# Here we book the DNN of TMVA. See the example TMVA_Higgs_Classification.C for a detailed description of the
# options

if useTMVADNN:
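    # network layout: four fully connected layers of 100 ReLU units (with batch
    # normalisation after the first three) followed by a single linear output node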
    layoutString = ROOT.TString(
        "DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,DENSE|1|LINEAR"
    )

    # Training strategies
    # one can concatenate several training strings with different parameters (e.g. learning rates or
    # regularization parameters); the training strings must be concatenated with the `|` delimiter
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
        "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,"
        "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0.0"
    )  # + "|" + trainingString2 + ...
    trainingString1 += ",MaxEpochs=" + str(max_epochs)

    # Build now the full DNN Option string
    dnnMethodName = "TMVA_DNN_CPU"

    # use GPU if available
    dnnOptions = "CPU"
    if hasGPU:
        dnnOptions = "GPU"
        dnnMethodName = "TMVA_DNN_GPU"

    factory.BookMethod(
        loader,
        TMVA.Types.kDL,
        dnnMethodName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        Layout=layoutString,
        TrainingStrategy=trainingString1,
        Architecture=dnnOptions,
    )

### Book Convolutional Neural Network in TMVA

# For building a CNN one needs to define

# - Input Layout : number of channels (in this case = 1) | image height | image width
# - Batch Layout : batch size | number of channels | image size = (height*width)

# Then one adds Convolutional layers and MaxPool layers.

# - For a Convolutional layer the option string has to be:
#   - CONV | number of units | filter height | filter width | stride height | stride width | padding height | padding
#     width | activation function

# - note that in this case we are using a 3x3 filter with padding=1 and stride=1, so the output dimension of the
#   conv layer is equal to the input

# - note that we use a batch normalization layer after the first convolutional layer. This seems to help the
#   convergence significantly

# - For the MaxPool layer:
#   - MAXPOOL | pool height | pool width | stride height | stride width

# The RESHAPE layer is needed to flatten the output before the Dense layer

# Note that running the TMVA CNN requires CPU or GPU support


if useTMVACNN:
    # Training strategies.
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
        "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,"
        "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0.0"
    )
    trainingString1 += ",MaxEpochs=" + str(max_epochs)

    ## New DL (CNN)
    cnnMethodName = "TMVA_CNN_CPU"
    cnnOptions = "CPU"
    # use GPU if available
    if hasGPU:
        cnnOptions = "GPU"
        cnnMethodName = "TMVA_CNN_GPU"

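    # InputLayout "1|16|16" means 1 channel and a 16x16 image. The Layout string below chains the
    # layers: two 3x3 CONV layers with 10 filters, stride 1 and padding 1 (so the 16x16 size is
    # preserved), a BNORM layer after the first convolution, a 2x2 MAXPOOL layer with stride 1,
    # a RESHAPE|FLAT to flatten the feature maps, a DENSE|100|RELU hidden layer and one linear output unit.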
    factory.BookMethod(
        loader,
        TMVA.Types.kDL,
        cnnMethodName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        InputLayout="1|16|16",
        Layout="CONV|10|3|3|1|1|1|1|RELU,BNORM,CONV|10|3|3|1|1|1|1|RELU,MAXPOOL|2|2|1|1,RESHAPE|FLAT,DENSE|100|RELU,DENSE|1|LINEAR",
        TrainingStrategy=trainingString1,
        Architecture=cnnOptions,
    )

### Book Convolutional Neural Network in PyTorch using a generated model


if usePyTorchCNN:
    ROOT.Info("TMVA_CNN_Classification", "Using Convolutional PyTorch Model")
    pyTorchFileName = str(ROOT.gROOT.GetTutorialDir())
    pyTorchFileName += "/tmva/PyTorch_Generate_CNN_Model.py"
    # check that pytorch can be imported and the file defining the model exists
    torch_spec = importlib.util.find_spec("torch")
    if torch_spec is not None and os.path.exists(pyTorchFileName):
        # cmd = str(ROOT.TMVA.Python_Executable()) + " " + pyTorchFileName
        # os.system(cmd)
        # import PyTorch_Generate_CNN_Model
        ROOT.Info("TMVA_CNN_Classification", "Booking PyTorch CNN model")
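        # the script passed as UserCode (PyTorch_Generate_CNN_Model.py) is run by the PyTorch method
        # and is expected to define the model plus its training/prediction helpers and to write the
        # initial model to the file given in FilenameModel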
        factory.BookMethod(
            loader,
            TMVA.Types.kPyTorch,
            "PyTorch",
            H=True,
            V=False,
            VarTransform=None,
            FilenameModel="PyTorchModelCNN.pt",
            FilenameTrainedModel="PyTorchTrainedModelCNN.pt",
            NumEpochs=max_epochs,
            BatchSize=100,
            UserCode=str(pyTorchFileName),
        )
    else:
        ROOT.Warning(
            "TMVA_CNN_Classification",
            "PyTorch is not installed or the model building file does not exist - skip using PyTorch",
        )

### Book Convolutional Neural Network in Keras using a generated model

if useKerasCNN:
    ROOT.Info("TMVA_CNN_Classification", "Building convolutional keras model")
    # build the Keras model directly here:
    # 2 Conv2D layers + MaxPooling + Dense
    import tensorflow
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    # from keras.initializers import TruncatedNormal
    # from keras import initializations
    from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, Conv2D, MaxPooling2D, Reshape

    # from keras.callbacks import ReduceLROnPlateau
    model = Sequential()
    model.add(Reshape((16, 16, 1), input_shape=(256,)))
    model.add(Conv2D(10, kernel_size=(3, 3), kernel_initializer="TruncatedNormal", activation="relu", padding="same"))
    model.add(Conv2D(10, kernel_size=(3, 3), kernel_initializer="TruncatedNormal", activation="relu", padding="same"))
    # stride for maxpool is equal to pool size
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation="tanh"))
    # model.add(Dropout(0.2))
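    # two output nodes, one per class, as used by the TMVA PyKeras method for binary classification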
    model.add(Dense(2, activation="sigmoid"))
    model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), weighted_metrics=["accuracy"])
    model.save("model_cnn.h5")
    model.summary()

    if not os.path.exists("model_cnn.h5"):
        raise FileNotFoundError("Error creating Keras model file - skip using Keras")
    else:
        # book PyKeras method only if the Keras model could be created
        ROOT.Info("TMVA_CNN_Classification", "Booking convolutional keras model")
        factory.BookMethod(
            loader,
            TMVA.Types.kPyKeras,
            "PyKeras",
            H=True,
            V=False,
            VarTransform=None,
            FilenameModel="model_cnn.h5",
            FilenameTrainedModel="trained_model_cnn.h5",
            NumEpochs=max_epochs,
            BatchSize=100,
            GpuOptions="allow_growth=True",
        )  # GpuOptions is needed for RTX NVidia cards to avoid that TF allocates all GPU memory


## Train Methods

factory.TrainAllMethods()

## Test and Evaluate Methods

factory.TestAllMethods()

factory.EvaluateAllMethods()

## Plot ROC Curve

c1 = factory.GetROCCurve(loader)
c1.Draw()
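# the ROC canvas can also be saved to an image file if desired, e.g. with c1.Print("ROC_curve.png")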

# close the output file to save the results
outputFile.Close()