Loading [MathJax]/extensions/tex2jax.js
Logo ROOT  
Reference Guide
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
TMVAMulticlassApplication.C File Reference

Detailed Description

View in nbviewer Open in SWAN This macro provides a simple example of how to use the trained multiclass classifiers within an analysis module.

  • Project : TMVA - a Root-integrated toolkit for multivariate data analysis
  • Package : TMVA
  • Root Macro: TMVAMulticlassApplication
==> Start TMVAMulticlassApp
: Booking "BDTG method" of type "BDT" from dataset/weights/TMVAMulticlass_BDTG.weights.xml.
: Reading weight file: dataset/weights/TMVAMulticlass_BDTG.weights.xml
<HEADER> DataSetInfo : [Default] : Added class "Signal"
<HEADER> DataSetInfo : [Default] : Added class "bg0"
<HEADER> DataSetInfo : [Default] : Added class "bg1"
<HEADER> DataSetInfo : [Default] : Added class "bg2"
: Booked classifier "BDTG" of type: "BDT"
: Booking "DL_CPU method" of type "DL" from dataset/weights/TMVAMulticlass_DL_CPU.weights.xml.
: Reading weight file: dataset/weights/TMVAMulticlass_DL_CPU.weights.xml
: Booked classifier "DL_CPU" of type: "DL"
TMVAMultiClassApplication: Skip DL_GPU method since it has not been trained !
TMVAMultiClassApplication: Skip FDA_GA method since it has not been trained !
: Booking "MLP method" of type "MLP" from dataset/weights/TMVAMulticlass_MLP.weights.xml.
: Reading weight file: dataset/weights/TMVAMulticlass_MLP.weights.xml
<HEADER> MLP : Building Network.
: Initializing weights
: Booked classifier "MLP" of type: "MLP"
: Booking "PDEFoam method" of type "PDEFoam" from dataset/weights/TMVAMulticlass_PDEFoam.weights.xml.
: Reading weight file: dataset/weights/TMVAMulticlass_PDEFoam.weights.xml
: Read foams from file: dataset/weights/TMVAMulticlass_PDEFoam.weights_foams.root
: Booked classifier "PDEFoam" of type: "PDEFoam"
--- TMVAMulticlassApp : Using input file: ./files/tmva_multiclass_example.root
--- Select signal sample
--- End of event loop: Real time 0:00:00, CP time 0.780
--- Created root file: "TMVMulticlassApp.root" containing the MVA output histograms
==> TMVAMulticlassApp is done!
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TSystem.h"
#include "TROOT.h"
#include "TStopwatch.h"
#include "TH1F.h"
#include "TMVA/Tools.h"
#include "TMVA/Reader.h"
using namespace TMVA;
/// Apply the trained TMVA multiclass classifiers to the signal tree ("TreeS")
/// of the example dataset and histogram, for every booked method, the
/// response for the first class (index 0 of EvaluateMulticlass).
///
/// @param myMethodList Comma-separated list of methods to apply (e.g.
///        "MLP,BDTG"). An empty string selects every method of the default
///        table below. An unknown name prints the accepted names and returns
///        without processing anything.
///
/// Methods whose weight file dataset/weights/TMVAMulticlass_<name>.weights.xml
/// does not exist are skipped. Histograms are written to
/// "TMVAMulticlassApp.root" (recreated). The process exits with status 1 if
/// the input data file or its "TreeS" tree cannot be opened.
void TMVAMulticlassApplication( TString myMethodList = "" )
{
   //---------------------------------------------------------------
   // Default MVA methods to be applied
   std::map<std::string,int> Use;
   Use["MLP"]     = 1;
   Use["BDTG"]    = 1;
   Use["DL_CPU"]  = 1;
   Use["DL_GPU"]  = 1;
   Use["FDA_GA"]  = 1;
   Use["PDEFoam"] = 1;
   //---------------------------------------------------------------

   std::cout << std::endl;
   std::cout << "==> Start TMVAMulticlassApp" << std::endl;

   // An explicit selection switches off every method it does not name.
   if (myMethodList != "") {
      for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); ++it) it->second = 0;

      std::vector<TString> mlist = gTools().SplitString( myMethodList, ',' );
      for (UInt_t i=0; i<mlist.size(); i++) {
         std::string regMethod(mlist[i]);

         if (Use.find(regMethod) == Use.end()) {
            std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
            for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); ++it) std::cout << it->first << " " << std::endl;
            std::cout << std::endl;
            return;
         }
         Use[regMethod] = 1;
      }
   }

   // create the Reader object
   TMVA::Reader *reader = new TMVA::Reader( "!Color:!Silent" );

   // create a set of variables and declare them to the reader
   // - the variable names must correspond in name and type to
   //   those given in the weight file(s) that you use
   Float_t var1, var2, var3, var4;
   reader->AddVariable( "var1", &var1 );
   reader->AddVariable( "var2", &var2 );
   reader->AddVariable( "var3", &var3 );
   reader->AddVariable( "var4", &var4 );

   // book the MVA methods
   TString dir    = "dataset/weights/";
   TString prefix = "TMVAMulticlass";

   for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); ++it) {
      if (it->second) {
         TString methodName = TString(it->first) + TString(" method");
         TString weightfile = dir + prefix + TString("_") + TString(it->first) + TString(".weights.xml");
         // AccessPathName returns kFALSE when the file exists, i.e. when the
         // method has been trained
         if (!gSystem->AccessPathName( weightfile )) {
            reader->BookMVA( methodName, weightfile );
         }
         else {
            std::cout << "TMVAMultiClassApplication: Skip " << methodName << " since it has not been trained !" << std::endl;
            it->second = 0;
         }
      }
   }

   // The selection is final from here on: read it once instead of doing six
   // map lookups for every event in the loop below.
   const bool useMLP     = Use["MLP"];
   const bool useBDTG    = Use["BDTG"];
   const bool useDLCPU   = Use["DL_CPU"];
   const bool useDLGPU   = Use["DL_GPU"];
   const bool useFDAGA   = Use["FDA_GA"];
   const bool usePDEFoam = Use["PDEFoam"];

   // book output histograms
   UInt_t nbin = 100;
   TH1F *histMLP_signal(0), *histBDTG_signal(0), *histFDAGA_signal(0), *histPDEFoam_signal(0);
   TH1F *histDLCPU_signal(0), *histDLGPU_signal(0);
   if (useMLP)
      histMLP_signal = new TH1F( "MVA_MLP_signal", "MVA_MLP_signal", nbin, 0., 1.1 );
   if (useBDTG)
      histBDTG_signal = new TH1F( "MVA_BDTG_signal", "MVA_BDTG_signal", nbin, 0., 1.1 );
   if (useDLCPU)
      histDLCPU_signal = new TH1F( "MVA_DLCPU_signal", "MVA_DLCPU_signal", nbin, 0., 1.1 );
   if (useDLGPU)
      histDLGPU_signal = new TH1F( "MVA_DLGPU_signal", "MVA_DLGPU_signal", nbin, 0., 1.1 );
   if (useFDAGA)
      histFDAGA_signal = new TH1F( "MVA_FDA_GA_signal", "MVA_FDA_GA_signal", nbin, 0., 1.1 );
   if (usePDEFoam)
      histPDEFoam_signal = new TH1F( "MVA_PDEFoam_signal", "MVA_PDEFoam_signal", nbin, 0., 1.1 );

   // Use a local copy of the input data if present; otherwise read the remote
   // file through the file cache, with the current directory as cache location.
   // NOTE(review): the local name differs from the remote file name
   // (tmva_multiclass_example.root) -- confirm which name a local copy should have.
   TFile *input(0);
   TString fname = "./tmva_example_multiclass.root";
   if (!gSystem->AccessPathName( fname )) {
      input = TFile::Open( fname ); // check if file in local directory exists
   }
   else {
      TFile::SetCacheFileDir(".");
      input = TFile::Open("http://root.cern.ch/files/tmva_multiclass_example.root", "CACHEREAD");
   }
   if (!input) {
      std::cout << "ERROR: could not open data file" << std::endl;
      exit(1);
   }
   std::cout << "--- TMVAMulticlassApp : Using input file: " << input->GetName() << std::endl;

   // prepare the tree
   // - here the variable names have to correspond to your tree
   // - you can use the same variables as above which is slightly faster,
   //   but of course you can use different ones and copy the values inside the event loop
   TTree* theTree = dynamic_cast<TTree*>(input->Get("TreeS"));
   if (!theTree) {
      std::cout << "ERROR: could not find tree \"TreeS\" in " << input->GetName() << std::endl;
      exit(1);
   }
   std::cout << "--- Select signal sample" << std::endl;
   theTree->SetBranchAddress( "var1", &var1 );
   theTree->SetBranchAddress( "var2", &var2 );
   theTree->SetBranchAddress( "var3", &var3 );
   theTree->SetBranchAddress( "var4", &var4 );

   const Long64_t nEntries = theTree->GetEntries();
   std::cout << "--- Processing: " << nEntries << " events" << std::endl;

   TStopwatch sw;
   sw.Start();
   for (Long64_t ievt=0; ievt<nEntries; ievt++) {
      if (ievt%1000 == 0) {
         std::cout << "--- ... Processing event: " << ievt << std::endl;
      }
      theTree->GetEntry(ievt);

      // EvaluateMulticlass returns one response per class; only the first
      // class (index 0) is histogrammed.
      if (useMLP)
         histMLP_signal->Fill((reader->EvaluateMulticlass( "MLP method" ))[0]);
      if (useBDTG)
         histBDTG_signal->Fill((reader->EvaluateMulticlass( "BDTG method" ))[0]);
      if (useDLCPU)
         histDLCPU_signal->Fill((reader->EvaluateMulticlass( "DL_CPU method" ))[0]);
      if (useDLGPU)
         histDLGPU_signal->Fill((reader->EvaluateMulticlass( "DL_GPU method" ))[0]);
      if (useFDAGA)
         histFDAGA_signal->Fill((reader->EvaluateMulticlass( "FDA_GA method" ))[0]);
      if (usePDEFoam)
         histPDEFoam_signal->Fill((reader->EvaluateMulticlass( "PDEFoam method" ))[0]);
   }

   // get elapsed time
   sw.Stop();
   std::cout << "--- End of event loop: "; sw.Print();

   // The input tree is no longer needed; release the input file.
   input->Close();
   delete input;

   // Write the histograms into a freshly (re)created output file.
   TFile *target = new TFile( "TMVAMulticlassApp.root","RECREATE" );
   if (target->IsZombie()) {
      std::cout << "ERROR: could not create output file \"TMVAMulticlassApp.root\"" << std::endl;
      delete target;
      delete reader;
      return;
   }
   if (useMLP)
      histMLP_signal->Write();
   if (useBDTG)
      histBDTG_signal->Write();
   if (useDLCPU)
      histDLCPU_signal->Write();
   if (useDLGPU)
      histDLGPU_signal->Write();
   if (useFDAGA)
      histFDAGA_signal->Write();
   if (usePDEFoam)
      histPDEFoam_signal->Write();

   target->Close();
   delete target;
   std::cout << "--- Created root file: \"TMVAMulticlassApp.root\" containing the MVA output histograms" << std::endl;

   delete reader;

   std::cout << "==> TMVAMulticlassApp is done!" << std::endl << std::endl;
}
int main( int argc, char** argv )
{
// Select methods (don't look at this code - not of interest)
TString methodList;
for (int i=1; i<argc; i++) {
TString regMethod(argv[i]);
if(regMethod=="-b" || regMethod=="--batch") continue;
if (!methodList.IsNull()) methodList += TString(",");
methodList += regMethod;
}
TMVAMulticlassApplication(methodList);
return 0;
}
unsigned int UInt_t
Definition: RtypesCore.h:44
long long Long64_t
Definition: RtypesCore.h:71
float Float_t
Definition: RtypesCore.h:55
R__EXTERN TSystem * gSystem
Definition: TSystem.h:556
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:53
static Bool_t SetCacheFileDir(ROOT::Internal::TStringView cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Definition: TFile.h:323
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3942
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:873
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:571
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:63
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
read method name from weight file
Definition: Reader.cxx:373
const std::vector< Float_t > & EvaluateMulticlass(const TString &methodTag, Double_t aux=0)
evaluates MVA for given set of input variables
Definition: Reader.cxx:635
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition: Reader.cxx:308
static Tools & Instance()
Definition: Tools.cxx:74
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and returns the resulting substrings as a vector
Definition: Tools.cxx:1210
Stopwatch class.
Definition: TStopwatch.h:28
void Start(Bool_t reset=kTRUE)
Start the stopwatch.
Definition: TStopwatch.cxx:58
void Stop()
Stop the stopwatch.
Definition: TStopwatch.cxx:77
void Print(Option_t *option="") const
Print the real and cpu time passed between the start and stop events.
Definition: TStopwatch.cxx:219
Basic string class.
Definition: TString.h:131
Bool_t IsNull() const
Definition: TString.h:402
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode (note the inverted logic: kTRUE means the path is NOT accessible).
Definition: TSystem.cxx:1291
A TTree represents a columnar dataset.
Definition: TTree.h:78
virtual Int_t SetBranchAddress(const char *bname, void *add, TBranch **ptr=0)
Change branch address, dealing with clone trees properly.
Definition: TTree.cxx:8237
virtual Long64_t GetEntries() const
Definition: TTree.h:457
virtual Int_t GetEntry(Long64_t entry=0, Int_t getall=0)
Read all branches of entry and return total number of bytes read.
Definition: TTree.cxx:5542
int main(int argc, char **argv)
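Command-line entry point: joins the arguments (except -b/--batch) into a comma-separated method list and runs TMVAMulticlassApplication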
Tools & gTools()
Author
Andreas Hoecker

Definition in file TMVAMulticlassApplication.C.