Logo ROOT   6.16/01
Reference Guide
Classification Class Reference

Class to perform two class classification.

The first step before any analysis is to prepare the data. To do that, create a TMVA::DataLoader object and configure in it the variables and the number of events used for training and testing. The class TMVA::Experimental::Classification needs a TMVA::DataLoader object, optionally a TFile object to save the results, and an option string such as "V:Color:Transformations=I;D;P;U;G:Silent:DrawProgressBar:ModelPersistence:Jobs=2", where: V = verbose output; Color = coloured screen output; Silent = batch mode, a boolean silent flag inhibiting any output from TMVA; Transformations = list of transformations to test; DrawProgressBar = draw a progress bar during training and testing; ModelPersistence = save the trained model in XML or serialized files; Jobs = number of ML methods to train/test in parallel using MultiProc (requires calling the Evaluate method). Basic example:

/// Basic two-class classification example using TMVA::Experimental::Classification.
///
/// Opens the TMVA example data set. The local copy is used if present;
/// otherwise the file is fetched from root.cern.ch with CACHEREAD, caching it
/// in the current directory. The macro then:
///   - registers the signal/background trees and the input variables in a
///     TMVA::DataLoader;
///   - books a gradient-boosted BDT and an SVM;
///   - trains and tests all booked methods;
///   - draws their ROC curves together on one canvas.
///
/// @param jobs number of ML methods trained/tested in parallel; forwarded to
///             the Classification object as the "Jobs=<jobs>" option.
void classification(UInt_t jobs = 2)
{
   TMVA::Tools::Instance(); // initialise the TMVA tools singleton before any other TMVA call

   TFile *input = nullptr;
   TString fname = "./tmva_class_example.root";
   // NB: AccessPathName() returns kFALSE when the file IS accessible.
   if (!gSystem->AccessPathName(fname)) {
      input = TFile::Open(fname); // the file exists in the local directory
   } else {
      TFile::SetCacheFileDir("."); // CACHEREAD stores the downloaded copy here
      input = TFile::Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD");
   }
   if (!input) {
      std::cout << "ERROR: could not open data file" << std::endl;
      exit(1);
   }

   // Register the training and test trees
   TTree *signalTree = dynamic_cast<TTree *>(input->Get("TreeS"));
   TTree *background = dynamic_cast<TTree *>(input->Get("TreeB"));
   if (!signalTree || !background) {
      std::cout << "ERROR: could not find TreeS/TreeB in " << input->GetName() << std::endl;
      exit(1);
   }

   // NOTE(review): the data loader is not deleted here; it is presumably
   // owned by the Classification envelope it is handed to - verify.
   TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");

   dataloader->AddVariable("myvar1 := var1+var2", 'F');
   dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
   dataloader->AddVariable("var3", "Variable 3", "units", 'F');
   dataloader->AddVariable("var4", "Variable 4", "units", 'F');

   dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
   dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');

   // global event weights per tree (see below for setting event-wise weights)
   Double_t signalWeight = 1.0;
   Double_t backgroundWeight = 1.0;

   // event-wise weights for the background are read from the "weight" branch
   dataloader->SetBackgroundWeightExpression("weight");

   // Without these calls the data loader has no events to train/test on.
   dataloader->AddSignalTree(signalTree, signalWeight);
   dataloader->AddBackgroundTree(background, backgroundWeight);

   dataloader->PrepareTrainingAndTestTree(
      "", "", "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V");

   // %u: jobs is an unsigned UInt_t
   TMVA::Experimental::Classification *cl =
      new TMVA::Experimental::Classification(dataloader, Form("Jobs=%u", jobs));

   cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
                                             "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
   cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");

   cl->Evaluate(); // Train and Test all methods

   auto &results = cl->GetResults();

   // The canvas and multigraph are intentionally not deleted so the drawn
   // plot stays on screen after the macro returns.
   TCanvas *c = new TCanvas("ROC");
   c->SetTitle("ROC-Integral Curve");

   auto mg = new TMultiGraph();
   for (UInt_t i = 0; i < results.size(); i++) {
      auto roc = results[i].GetROCGraph();
      roc->SetLineColorAlpha(i + 1, 0.1); // one colour per booked method
      mg->Add(roc);
   }
   mg->Draw("AL");
   mg->GetXaxis()->SetTitle(" Signal Efficiency ");
   mg->GetYaxis()->SetTitle(" Background Rejection ");
   c->BuildLegend(0.15, 0.15, 0.3, 0.3);
   c->Draw();

   delete cl;
}
#define c(i)
Definition: RSha256.hxx:101
unsigned int UInt_t
Definition: RtypesCore.h:42
double Double_t
Definition: RtypesCore.h:55
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
The Canvas class.
Definition: TCanvas.h:31
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
static Bool_t SetCacheFileDir(ROOT::Internal::TStringView cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Definition: TFile.h:316
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3975
virtual void BookMethod(TString methodname, TString methodtitle, TString options="")
Books a machine learning method to be trained and tested.
Definition: Envelope.cxx:168
std::vector< ClassificationResult > & GetResults()
Returns the vector of TMVA::Experimental::ClassificationResult objects.
virtual void Evaluate()
Performs training and testing of all booked ML methods.
static Tools & Instance()
Definition: Tools.cxx:75
@ kBDT
Definition: Types.h:88
@ kSVM
Definition: Types.h:91
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Basic string class.
Definition: TString.h:131
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode (note the inverted logic: kFALSE means the file is accessible).
Definition: TSystem.cxx:1286
A TTree object has a header with a name and a title.
Definition: TTree.h:71
void classification(UInt_t jobs=4)
Definition: classification.C:6
static constexpr double mg

#include <TMVA/Classification.h>


The documentation for this class was generated from the following file: