Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Classification.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$ 2017
2// Authors: Omar Zapata, Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne,
3// Jan Therhaag
4
5#ifndef ROOT_TMVA_Classification
6#define ROOT_TMVA_Classification
7
8#include <TString.h>
9#include <TMultiGraph.h>
10#include <vector>
11#include <map>
12
13#include <TMVA/IMethod.h>
14#include <TMVA/MethodBase.h>
15#include <TMVA/Configurable.h>
16#include <TMVA/Types.h>
17#include <TMVA/DataSet.h>
18#include <TMVA/Event.h>
19#include <TMVA/Results.h>
22#include <TMVA/Factory.h>
23#include <TMVA/DataLoader.h>
24#include <TMVA/OptionMap.h>
25#include <TMVA/Envelope.h>
26
27/*! \class TMVA::ClassificationResult
28 * Class to save the results of the classifier.
29 * Every booked machine learning method has an object holding its results
30 * from the classification process; this class stores the MVA values,
31 * the data loader name and the ML method name and title.
32 * You can display the results by calling the method Show, get the ROC integral with the
33 * method GetROCIntegral, or get the TMVA::ROCCurve object by calling GetROC.
34\ingroup TMVA
35*/
36
37/*! \class TMVA::Classification
38 * Class to perform two class classification.
39 * The first step before any analysis is to prepare the data.
40 * To do that, create a TMVA::DataLoader object and
41 * configure in it the variables and the number of events
42 * to train/test.
43 * The class TMVA::Experimental::Classification needs a TMVA::DataLoader object,
44 * optionally a TFile object to save the results, and some extra options in a string
45 * like "V:Color:Transformations=I;D;P;U;G:Silent:DrawProgressBar:ModelPersistence:Jobs=2" where:
46 * V = verbose output
47 * Color = coloured screen output
48 * Silent = batch mode: boolean silent flag inhibiting any output from TMVA
49 * Transformations = list of transformations to test.
50 * DrawProgressBar = draw progress bar to display training and testing.
51 * ModelPersistence = to save the trained model in xml or serialized files.
52 * Jobs = number of ML methods to train/test in parallel using MultiProc; requires calling the Evaluate method.
53 * Basic example.
54 * \code
55void classification(UInt_t jobs = 2)
56{
57 TMVA::Tools::Instance();
58
59 TFile *input(0);
60 TString fname = "./tmva_class_example.root";
61 if (!gSystem->AccessPathName(fname)) {
62 input = TFile::Open(fname); // check if file in local directory exists
63 } else {
64 TFile::SetCacheFileDir(".");
65 input = TFile::Open("http://root.cern/files/tmva_class_example.root", "CACHEREAD");
66 }
67 if (!input) {
68 std::cout << "ERROR: could not open data file" << std::endl;
69 exit(1);
70 }
71
72 // Register the training and test trees
73
74 TTree *signalTree = (TTree *)input->Get("TreeS");
75 TTree *background = (TTree *)input->Get("TreeB");
76
77 TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
78
79 dataloader->AddVariable("myvar1 := var1+var2", 'F');
80 dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
81 dataloader->AddVariable("var3", "Variable 3", "units", 'F');
82 dataloader->AddVariable("var4", "Variable 4", "units", 'F');
83
84 dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
85 dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
86
87 // global event weights per tree (see below for setting event-wise weights)
88 Double_t signalWeight = 1.0;
89 Double_t backgroundWeight = 1.0;
90
91 dataloader->SetBackgroundWeightExpression("weight");
92
93 TMVA::Experimental::Classification *cl = new TMVA::Experimental::Classification(dataloader, Form("Jobs=%d", jobs));
94
95 cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
96 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
97 cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");
98
99 cl->Evaluate(); // Train and Test all methods
100
101 auto &results = cl->GetResults();
102
103 TCanvas *c = new TCanvas(Form("ROC"));
104 c->SetTitle("ROC-Integral Curve");
105
106 auto mg = new TMultiGraph();
107 for (UInt_t i = 0; i < results.size(); i++) {
108 auto roc = results[i].GetROCGraph();
109 roc->SetLineColorAlpha(i + 1, 0.1);
110 mg->Add(roc);
111 }
112 mg->Draw("AL");
113 mg->GetXaxis()->SetTitle(" Signal Efficiency ");
114 mg->GetYaxis()->SetTitle(" Background Rejection ");
115 c->BuildLegend(0.15, 0.15, 0.3, 0.3);
116 c->Draw();
117
118 delete cl;
119}
120 * \endcode
121 *
122\ingroup TMVA
123*/
124
125namespace TMVA {
126class ResultsClassification;
127namespace Experimental {
129 friend class Classification;
130
131private:
134 std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>> fMvaTrain; ///< Mvas for two-class classification
135 std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>> fMvaTest; ///< Mvas for two-class and multiclass classification
136 std::vector<TString> fClassNames;
137
139 Bool_t fIsCuts; ///< if it is a method cuts need special output
141
142public:
146
147 const TString GetMethodName() const { return fMethod.GetValue<TString>("MethodName"); }
148 const TString GetMethodTitle() const { return fMethod.GetValue<TString>("MethodTitle"); }
153
154 void Show();
155
158
160};
161
162class Classification : public Envelope {
163 std::vector<ClassificationResult> fResults; ///<! vector of ClassificationResult objects
164 std::vector<IMethod *> fIMethods; ///<! vector of objects with booked methods
168public:
169 explicit Classification(DataLoader *loader, TFile *file, TString options); ///< Constructor to create a two-class classifier.
170 explicit Classification(DataLoader *loader, TString options); ///< Same as above without a TFile argument.
172
173 virtual void Train(); ///< Method to train all booked ML methods.
176
177 virtual void Test(); ///< Perform test evaluation in all booked methods.
180
181 virtual void Evaluate(); ///< Method to perform Train/Test over all booked ML methods.
182
183 std::vector<ClassificationResult> &GetResults(); ///< Return the vector of TMVA::Experimental::ClassificationResult objects.
184
186
187protected:
195
197
199 void CopyFrom(TDirectory *src, TFile *file); ///< NOTE(review): undocumented here; presumably copies the contents of src into file -- confirm in the implementation.
200 void MergeFiles(); ///< NOTE(review): undocumented here; presumably merges the per-job output files -- confirm in the implementation.
201
203};
204} // namespace Experimental
205} // namespace TMVA
206
207#endif // ROOT_TMVA_Classification
#define ClassDef(name, id)
Definition Rtypes.h:342
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t src
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
Describe directory structure in memory.
Definition TDirectory.h:45
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Definition Envelope.h:44
Double_t GetROCIntegral(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get ROC-Integral value from mvas.
TGraph * GetROCGraph(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TGraph object with the ROC curve.
void Show()
Method to print the results in stdout.
Bool_t IsMethod(TString methodname, TString methodtitle)
Method to check if method was booked.
std::map< UInt_t, std::vector< std::tuple< Float_t, Float_t, Bool_t > > > fMvaTest
Mvas for two-class and multiclass classification.
ROCCurve * GetROC(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TMVA::ROCCurve Object.
Bool_t fIsCuts
True if the method is a Cuts method, which needs special output.
ClassificationResult & operator=(const ClassificationResult &r)
std::map< UInt_t, std::vector< std::tuple< Float_t, Float_t, Bool_t > > > fMvaTrain
Mvas for two-class classification.
std::vector< ClassificationResult > fResults
!
Classification(DataLoader *loader, TFile *file, TString options)
Constructor to create a two-class classifier.
Double_t GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass=0)
Method to get ROC-Integral value from mvas.
virtual void Test()
Perform test evaluation in all booked methods.
TString GetMethodOptions(TString methodname, TString methodtitle)
Return the options for the booked method.
MethodBase * GetMethod(TString methodname, TString methodtitle)
Return a TMVA::MethodBase object.
virtual void TrainMethod(TString methodname, TString methodtitle)
Trains a specific ML method.
Bool_t HasMethodObject(TString methodname, TString methodtitle, Int_t &index)
Allows to check if the TMVA::MethodBase was created and return the index in the vector.
std::vector< ClassificationResult > & GetResults()
Return the vector of TMVA::Experimental::ClassificationResult objects.
std::vector< IMethod * > fIMethods
! vector of objects with booked methods
virtual void Train()
Method to train all booked ml methods.
virtual void Evaluate()
Method to perform Train/Test over all booked ML methods.
Types::EAnalysisType fAnalysisType
!
TMVA::ROCCurve * GetROC(TMVA::MethodBase *method, UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TMVA::ROCCurve Object.
Bool_t IsCutsMethod(TMVA::MethodBase *method)
Allows to check if the ml method is a Cuts method.
void CopyFrom(TDirectory *src, TFile *file)
virtual void TestMethod(TString methodname, TString methodtitle)
Performs the test of a specific ML method.
Virtual base Class for all MVA method.
Definition MethodBase.h:111
Class to store options for the different methods.
Definition OptionMap.h:34
T GetValue(const TString &key)
Definition OptionMap.h:133
Mother of all ROOT objects.
Definition TObject.h:41
Basic string class.
Definition TString.h:139
create variable transformations