Logo ROOT   6.12/07
Reference Guide
classification.C
Go to the documentation of this file.
#include <cstdlib>
#include <iostream>

#include "TAxis.h"
#include "TCanvas.h"
#include "TFile.h"
#include "TMultiGraph.h"
#include "TString.h"
#include "TSystem.h"
#include "TTree.h"

#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Tools.h"
#include "TMVA/Classification.h"

6 void classification(UInt_t jobs = 4)
7 {
9 
10  TFile *input(0);
11  TString fname = "./tmva_class_example.root";
12  if (!gSystem->AccessPathName(fname)) {
13  input = TFile::Open(fname); // check if file in local directory exists
14  } else {
16  input = TFile::Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD");
17  }
18  if (!input) {
19  std::cout << "ERROR: could not open data file" << std::endl;
20  exit(1);
21  }
22 
23  // Register the training and test trees
24 
25  TTree *signalTree = (TTree *)input->Get("TreeS");
26  TTree *background = (TTree *)input->Get("TreeB");
27 
29  // If you wish to modify default settings
30  // (please check "src/Config.h" to see all available global options)
31  //
32  // (TMVA::gConfig().GetVariablePlotting()).fTimesRMS = 8.0;
33  // (TMVA::gConfig().GetIONames()).fWeightFileDir = "myWeightDirectory";
34 
35  // Define the input variables that shall be used for the MVA training
36  // note that you may also use variable expressions, such as: "3*var1/var2*abs(var3)"
37  // [all types of expressions that can also be parsed by TTree::Draw( "expression" )]
38  dataloader->AddVariable("myvar1 := var1+var2", 'F');
39  dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
40  dataloader->AddVariable("var3", "Variable 3", "units", 'F');
41  dataloader->AddVariable("var4", "Variable 4", "units", 'F');
42 
43  // You can add so-called "Spectator variables", which are not used in the MVA training,
44  // but will appear in the final "TestTree" produced by TMVA. This TestTree will contain the
45  // input variables, the response values of all trained MVAs, and the spectator variables
46 
47  dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
48  dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
49 
50  // global event weights per tree (see below for setting event-wise weights)
51  Double_t signalWeight = 1.0;
52  Double_t backgroundWeight = 1.0;
53 
54  // You can add an arbitrary number of signal or background trees
55  dataloader->AddSignalTree(signalTree, signalWeight);
56  dataloader->AddBackgroundTree(background, backgroundWeight);
57 
58  // Set individual event weights (the variables must exist in the original TTree)
59  // - for signal : `dataloader->SetSignalWeightExpression ("weight1*weight2");`
60  // - for background: `dataloader->SetBackgroundWeightExpression("weight1*weight2");`
61  dataloader->SetBackgroundWeightExpression("weight");
62 
63  TFile *outputFile = TFile::Open("TMVAClass.root", "RECREATE");
64 
65  TMVA::Experimental::Classification *cl = new TMVA::Experimental::Classification(dataloader, Form("Jobs=%d", jobs));
66 
67  cl->BookMethod(TMVA::Types::kBDT, "BDT", "!H:!V:NTrees=2000:MinNodeSize=2.5%:MaxDepth=3:BoostType=AdaBoost:"
68  "AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType="
69  "GiniIndex:nCuts=20");
70  cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
71  "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
72  cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");
73 
74  cl->BookMethod(TMVA::Types::kBDT, "BDTB", "!H:!V:NTrees=2000:BoostType=Bagging:SeparationType=GiniIndex:nCuts=20");
75 
76  cl->Evaluate(); // Train and Test all methods
77 
78  auto &results = cl->GetResults();
79 
80  TCanvas *c = new TCanvas(Form("ROC"));
81  c->SetTitle("ROC-Integral Curve");
82 
83  auto mg = new TMultiGraph();
84  for (UInt_t i = 0; i < results.size(); i++) {
85  auto roc = results[i].GetROCGraph();
86  roc->SetLineColorAlpha(i + 1, 0.1);
87  mg->Add(roc);
88  }
89  mg->Draw("AL");
90  mg->GetXaxis()->SetTitle(" Signal Efficiency ");
91  mg->GetYaxis()->SetTitle(" Background Rejection ");
92  c->BuildLegend(0.15, 0.15, 0.3, 0.3);
93  c->Draw();
94 
95  outputFile->Close();
96  delete cl;
97 }
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Add a background tree to the data set, with an optional global event weight and tree type.
Definition: DataLoader.cxx:408
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1276
static Tools & Instance()
Definition: Tools.cxx:75
virtual void BookMethod(TString methodname, TString methodtitle, TString options="")
Method to book the machine learning method to perform the algorithm.
Definition: Envelope.cxx:158
void SetTitle(const char *title="")
Set canvas title.
Definition: TCanvas.cxx:1956
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:46
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
static Bool_t SetCacheFileDir(ROOT::Internal::TStringView cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Definition: TFile.h:303
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
Basic string class.
Definition: TString.h:125
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:491
static constexpr double mg
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3950
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
The Canvas class.
Definition: TCanvas.h:31
double Double_t
Definition: RtypesCore.h:55
void SetBackgroundWeightExpression(const TString &variable)
Definition: DataLoader.cxx:553
virtual void Draw(Option_t *option="")
Draw a canvas.
Definition: TCanvas.cxx:826
void classification(UInt_t jobs=4)
Definition: classification.C:6
virtual TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="")
Build a legend from the graphical objects in the pad.
Definition: TPad.cxx:485
std::vector< ClassificationResult > & GetResults()
A TTree object has a header with a name and a title.
Definition: TTree.h:70
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Add a signal tree to the data set, with an optional global event weight and tree type.
Definition: DataLoader.cxx:377
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts spectator variable in data set info
Definition: DataLoader.cxx:521
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:916