Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVARegressionApplication.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This macro provides a simple example on how to use the trained regression MVAs
5/// within an analysis module
6///
7/// - Project : TMVA - a Root-integrated toolkit for multivariate data analysis
8/// - Package : TMVA
9/// - Executable: TMVARegressionApplication
10///
11/// \macro_output
12/// \macro_code
13/// \author Andreas Hoecker
14
15#include <cstdlib>
16#include <vector>
17#include <iostream>
18#include <map>
19#include <string>
20
21#include "TFile.h"
22#include "TTree.h"
23#include "TString.h"
24#include "TSystem.h"
25#include "TROOT.h"
26#include "TStopwatch.h"
27
28#include "TMVA/Tools.h"
29#include "TMVA/Reader.h"
30
31using namespace TMVA;
32
33void TMVARegressionApplication( TString myMethodList = "" )
34{
35 //---------------------------------------------------------------
36 // This loads the library
38
39 // Default MVA methods to be trained + tested
40 std::map<std::string,int> Use;
41
42 // --- Mutidimensional likelihood and Nearest-Neighbour methods
43 Use["PDERS"] = 0;
44 Use["PDEFoam"] = 1;
45 Use["KNN"] = 1;
46 //
47 // --- Linear Discriminant Analysis
48 Use["LD"] = 1;
49 //
50 // --- Function Discriminant analysis
51 Use["FDA_GA"] = 0;
52 Use["FDA_MC"] = 0;
53 Use["FDA_MT"] = 0;
54 Use["FDA_GAMT"] = 0;
55 //
56 // --- Neural Network
57 Use["MLP"] = 0;
58 // Deep neural network
59#ifdef R__HAS_TMVAGPU
60 Use["DNN_GPU"] = 1;
61 Use["DNN_CPU"] = 0;
62#else
63 Use["DNN_GPU"] = 0;
64#ifdef R__HAS_TMVACPU
65 Use["DNN_CPU"] = 1;
66#else
67 Use["DNN_CPU"] = 0;
68#endif
69#endif
70 //
71 // --- Support Vector Machine
72 Use["SVM"] = 0;
73 //
74 // --- Boosted Decision Trees
75 Use["BDT"] = 0;
76 Use["BDTG"] = 1;
77 // ---------------------------------------------------------------
78
79 std::cout << std::endl;
80 std::cout << "==> Start TMVARegressionApplication" << std::endl;
81
82 // Select methods (don't look at this code - not of interest)
83 if (myMethodList != "") {
84 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;
85
86 std::vector<TString> mlist = gTools().SplitString( myMethodList, ',' );
87 for (UInt_t i=0; i<mlist.size(); i++) {
88 std::string regMethod(mlist[i]);
89
90 if (Use.find(regMethod) == Use.end()) {
91 std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
92 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " ";
93 std::cout << std::endl;
94 return;
95 }
96 Use[regMethod] = 1;
97 }
98 }
99
100 // --------------------------------------------------------------------------------------------------
101
102 // --- Create the Reader object
103
104 TMVA::Reader *reader = new TMVA::Reader( "!Color:!Silent" );
105
106 // Create a set of variables and declare them to the reader
107 // - the variable names MUST corresponds in name and type to those given in the weight file(s) used
108 Float_t var1, var2;
109 reader->AddVariable( "var1", &var1 );
110 reader->AddVariable( "var2", &var2 );
111
112 // Spectator variables declared in the training have to be added to the reader, too
113 Float_t spec1,spec2;
114 reader->AddSpectator( "spec1:=var1*2", &spec1 );
115 reader->AddSpectator( "spec2:=var1*3", &spec2 );
116
117 // --- Book the MVA methods
118
119 TString dir = "datasetreg/weights/";
120 TString prefix = "TMVARegression";
121
122 // Book method(s)
123 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) {
124 if (it->second) {
125 TString methodName = it->first + " method";
126 TString weightfile = dir + prefix + "_" + TString(it->first) + ".weights.xml";
127 reader->BookMVA( methodName, weightfile );
128 }
129 }
130
131 // Book output histograms
132 TH1* hists[100];
133 Int_t nhists = -1;
134 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) {
135 TH1* h = new TH1F( it->first.c_str(), TString(it->first) + " method", 100, -100, 600 );
136 if (it->second) hists[++nhists] = h;
137 }
138 nhists++;
139
140 // Prepare input tree (this must be replaced by your data source)
141 // in this example, there is a toy tree with signal and one with background events
142 // we'll later on use only the "signal" events for the test in this example.
143 //
144 TFile *input(0);
145 TString fname = "./tmva_reg_example.root";
146 if (!gSystem->AccessPathName( fname )) {
147 input = TFile::Open( fname ); // check if file in local directory exists
148 }
149 else {
151 input = TFile::Open("http://root.cern/files/tmva_reg_example.root", "CACHEREAD"); // if not: download from ROOT server
152 }
153 if (!input) {
154 std::cout << "ERROR: could not open data file" << std::endl;
155 exit(1);
156 }
157 std::cout << "--- TMVARegressionApp : Using input file: " << input->GetName() << std::endl;
158
159 // --- Event loop
160
161 // Prepare the tree
162 // - here the variable names have to corresponds to your tree
163 // - you can use the same variables as above which is slightly faster,
164 // but of course you can use different ones and copy the values inside the event loop
165 //
166 TTree* theTree = (TTree*)input->Get("TreeR");
167 std::cout << "--- Select signal sample" << std::endl;
168 theTree->SetBranchAddress( "var1", &var1 );
169 theTree->SetBranchAddress( "var2", &var2 );
170
171 std::cout << "--- Processing: " << theTree->GetEntries() << " events" << std::endl;
172 TStopwatch sw;
173 sw.Start();
174 for (Long64_t ievt=0; ievt<theTree->GetEntries();ievt++) {
175
176 if (ievt%1000 == 0) {
177 std::cout << "--- ... Processing event: " << ievt << std::endl;
178 }
179
180 theTree->GetEntry(ievt);
181
182 // Retrieve the MVA target values (regression outputs) and fill into histograms
183 // NOTE: EvaluateRegression(..) returns a vector for multi-target regression
184
185 for (Int_t ih=0; ih<nhists; ih++) {
186 TString title = hists[ih]->GetTitle();
187 Float_t val = (reader->EvaluateRegression( title ))[0];
188 hists[ih]->Fill( val );
189 }
190 }
191 sw.Stop();
192 std::cout << "--- End of event loop: "; sw.Print();
193
194 // --- Write histograms
195
196 TFile *target = new TFile( "TMVARegApp.root","RECREATE" );
197 for (Int_t ih=0; ih<nhists; ih++) hists[ih]->Write();
198 target->Close();
199
200 std::cout << "--- Created root file: \"" << target->GetName()
201 << "\" containing the MVA output histograms" << std::endl;
202
203 delete reader;
204
205 std::cout << "==> TMVARegressionApplication is done!" << std::endl << std::endl;
206}
207
208int main( int argc, char** argv )
209{
210 // Select methods (don't look at this code - not of interest)
211 TString methodList;
212 for (int i=1; i<argc; i++) {
213 TString regMethod(argv[i]);
214 if(regMethod=="-b" || regMethod=="--batch") continue;
215 if (!methodList.IsNull()) methodList += TString(",");
216 methodList += regMethod;
217 }
218 TMVARegressionApplication(methodList);
219 return 0;
220}
int main()
Definition Prototype.cxx:12
#define h(i)
Definition RSha256.hxx:106
int Int_t
Definition RtypesCore.h:45
unsigned int UInt_t
Definition RtypesCore.h:46
float Float_t
Definition RtypesCore.h:57
long long Long64_t
Definition RtypesCore.h:69
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t target
R__EXTERN TSystem * gSystem
Definition TSystem.h:561
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-like logical structure, possibly including subdirectory hierarchies.
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4089
static Bool_t SetCacheFileDir(std::string_view cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Sets the directory where to locally stage/cache remote files.
Definition TFile.cxx:4626
1-D histogram with a float per channel (see TH1 documentation)
Definition TH1.h:622
TH1 is the base class of all histogram classes in ROOT.
Definition TH1.h:59
virtual Int_t Fill(Double_t x)
Increment bin with abscissa X by 1.
Definition TH1.cxx:3344
The Reader class serves to use the MVAs in a specific analysis context.
Definition Reader.h:64
const std::vector< Float_t > & EvaluateRegression(const TString &methodTag, Double_t aux=0)
evaluates MVA for given set of input variables
Definition Reader.cxx:565
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
read method name from weight file
Definition Reader.cxx:368
void AddSpectator(const TString &expression, Float_t *)
Add a float spectator or expression to the reader.
Definition Reader.cxx:321
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition Reader.cxx:303
static Tools & Instance()
Definition Tools.cxx:71
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition Tools.cxx:1199
const char * GetTitle() const override
Returns title of object.
Definition TNamed.h:48
Stopwatch class.
Definition TStopwatch.h:28
void Start(Bool_t reset=kTRUE)
Start the stopwatch.
void Stop()
Stop the stopwatch.
void Print(Option_t *option="") const override
Print the real and cpu time passed between the start and stop events.
Basic string class.
Definition TString.h:139
Bool_t IsNull() const
Definition TString.h:414
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1296
A TTree represents a columnar dataset.
Definition TTree.h:79
virtual Int_t GetEntry(Long64_t entry, Int_t getall=0)
Read all branches of entry and return total number of bytes read.
Definition TTree.cxx:5638
virtual Int_t SetBranchAddress(const char *bname, void *add, TBranch **ptr=nullptr)
Change branch address, dealing with clone trees properly.
Definition TTree.cxx:8385
virtual Long64_t GetEntries() const
Definition TTree.h:463
create variable transformations
Tools & gTools()