Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
tmva100_DataPreparation.py
Go to the documentation of this file.
1## \file
2## \ingroup tutorial_tmva
3## \notebook -nodraw
4## This tutorial illustrates how to prepare ROOT datasets to be nicely readable
5## by most machine learning methods. This requires filtering the initial complex
6## datasets and writing the data in a flat format.
7##
8## \macro_code
9## \macro_output
10##
11## \date August 2019
12## \author Stefan Wunsch
13
14import ROOT
15
16
def filter_events(df):
    """
    Keep only the events of the initial dataset that are used for training.

    Applies a single named filter that requires at least two electrons and
    at least two muons in the event.

    Args:
        df: RDataFrame node to be filtered.

    Returns:
        The node returned by ``df.Filter`` for this selection.
    """
    cut = "nElectron>=2 && nMuon>=2"
    # The second Filter argument names the cut, e.g. for a cutflow Report().
    cut_name = "At least two electrons and two muons"
    return df.Filter(cut, cut_name)
22
23
def define_variables(df):
    """
    Define the flat columns which shall be used for training.

    Adds the transverse momenta of the first two muons and the first two
    electrons as individual scalar columns.

    Args:
        df: RDataFrame node providing ``Muon_pt`` and ``Electron_pt``.

    Returns:
        The node obtained after defining all new columns, in order.

    Note:
        Indexing entries [0] and [1] requires at least two muons and two
        electrons per event; the script below applies ``filter_events``
        before calling this function to guarantee that.
    """
    new_columns = (
        ("Muon_pt_1", "Muon_pt[0]"),
        ("Muon_pt_2", "Muon_pt[1]"),
        ("Electron_pt_1", "Electron_pt[0]"),
        ("Electron_pt_2", "Electron_pt[1]"),
    )
    node = df
    for column_name, expression in new_columns:
        node = node.Define(column_name, expression)
    return node
32
33
# Names of the flat columns (defined in define_variables) that are written
# to the training and testing files.
variables = [
    "Muon_pt_1",
    "Muon_pt_2",
    "Electron_pt_1",
    "Electron_pt_2",
]
35
36
if __name__ == "__main__":
    # Location of the CMS open-data NanoAOD files processed below.
    base_url = "root://eospublic.cern.ch//eos/root-eos/cms_opendata_2012_nanoaod/"

    # Book the snapshots lazily. Eager snapshots would each start their own
    # event loop and read the (remote) input dataset twice per sample.
    snapshot_options = ROOT.RDF.RSnapshotOptions()
    snapshot_options.fLazy = True

    for filename, label in [["SMHiggsToZZTo4L.root", "signal"], ["ZZTo2e2mu.root", "background"]]:
        print(">>> Extract the training and testing events for {} from the {} dataset.".format(
            label, filename))

        # Load dataset, filter the required events and define the training variables
        filepath = base_url + filename
        df = ROOT.RDataFrame("Events", filepath)
        df = filter_events(df)
        df = define_variables(df)

        # Book cutflow report
        report = df.Report()

        # Split dataset by event number: even events for training, odd events for testing
        columns = ROOT.std.vector["string"](variables)
        train_df = df.Filter("event % 2 == 0", "Select events with even event number for training")
        test_df = df.Filter("event % 2 == 1", "Select events with odd event number for testing")
        # Keep the handles of the lazy snapshots referenced until the event loop
        # has run, so the booked actions are not released before execution.
        train_snapshot = train_df.Snapshot(
            "Events", "train_" + label + ".root", columns, snapshot_options)
        test_snapshot = test_df.Snapshot(
            "Events", "test_" + label + ".root", columns, snapshot_options)

        # Print cutflow report. Accessing the report runs the single event loop,
        # which also writes both lazily booked snapshot files.
        report.Print()
ROOT's RDataFrame offers a high-level interface for analyses of data stored in TTrees, ...