Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Classification Class Reference

Class to perform two class classification.

The first step before any analysis is to prepare the data, to do that you need to create an object of TMVA::DataLoader, in this object you need to configure the variables and the number of events to train/test. The class TMVA::Experimental::Classification needs a TMVA::DataLoader object, optional a TFile object to save the results and some extra options in a string like "V:Color:Transformations=I;D;P;U;G:Silent:DrawProgressBar:ModelPersistence:Jobs=2" where: V = verbose output Color = coloured screen output Silent = batch mode: boolean silent flag inhibiting any output from TMVA Transformations = list of transformations to test. DrawProgressBar = draw progress bar to display training and testing. ModelPersistence = to save the trained model in xml or serialized files. Jobs = number of ml methods to test/train in parallel using MultiProc, requires to call Evaluate method. Basic example.

void classification(UInt_t jobs = 2)
{
TFile *input(0);
TString fname = "./tmva_class_example.root";
if (!gSystem->AccessPathName(fname)) {
input = TFile::Open(fname); // check if file in local directory exists
} else {
input = TFile::Open("http://root.cern/files/tmva_class_example.root", "CACHEREAD");
}
if (!input) {
std::cout << "ERROR: could not open data file" << std::endl;
exit(1);
}
// Register the training and test trees
TTree *signalTree = (TTree *)input->Get("TreeS");
TTree *background = (TTree *)input->Get("TreeB");
TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
dataloader->AddVariable("myvar1 := var1+var2", 'F');
dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
dataloader->AddVariable("var3", "Variable 3", "units", 'F');
dataloader->AddVariable("var4", "Variable 4", "units", 'F');
dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
// global event weights per tree (see below for setting event-wise weights)
Double_t signalWeight = 1.0;
Double_t backgroundWeight = 1.0;
dataloader->SetBackgroundWeightExpression("weight");
cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
"UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");
cl->Evaluate(); // Train and Test all methods
auto &results = cl->GetResults();
TCanvas *c = new TCanvas(Form("ROC"));
c->SetTitle("ROC-Integral Curve");
auto mg = new TMultiGraph();
for (UInt_t i = 0; i < results.size(); i++) {
auto roc = results[i].GetROCGraph();
roc->SetLineColorAlpha(i + 1, 0.1);
mg->Add(roc);
}
mg->Draw("AL");
mg->GetXaxis()->SetTitle(" Signal Efficiency ");
mg->GetYaxis()->SetTitle(" Background Rejection ");
c->BuildLegend(0.15, 0.15, 0.3, 0.3);
c->Draw();
delete cl;
}
#define c(i)
Definition RSha256.hxx:101
unsigned int UInt_t
Definition RtypesCore.h:46
double Double_t
Definition RtypesCore.h:59
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2489
R__EXTERN TSystem * gSystem
Definition TSystem.h:561
The Canvas class.
Definition TCanvas.h:23
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4086
static Bool_t SetCacheFileDir(std::string_view cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Sets the directory where to locally stage/cache remote files.
Definition TFile.cxx:4623
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
void SetBackgroundWeightExpression(const TString &variable)
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
virtual void BookMethod(TString methodname, TString methodtitle, TString options="")
Method to book the machine learning method to perform the algorithm.
Definition Envelope.cxx:163
std::vector< ClassificationResult > & GetResults()
Return the vector of TMVA::Experimental::ClassificationResult objects.
virtual void Evaluate()
Method to perform Train/Test over all ml method booked.
static Tools & Instance()
Definition Tools.cxx:71
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:34
Basic string class.
Definition TString.h:139
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1296
A TTree represents a columnar dataset.
Definition TTree.h:79
void classification(UInt_t jobs=4)

#include <TMVA/Classification.h>


The documentation for this class was generated from the following file: