Logo ROOT   6.10/09
Reference Guide
RuleCut.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : RuleCut *
8  * *
9  * Description: *
10  * A class describing a 'rule cut' *
11  * *
12  * *
13  * Authors (alphabetical): *
14  * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
15  * *
16  * Copyright (c) 2005: *
17  * CERN, Switzerland *
18  * Iowa State U. *
19  * *
20  * Redistribution and use in source and binary forms, with or without *
21  * modification, are permitted according to the terms listed in LICENSE *
22  * (http://tmva.sourceforge.net/LICENSE) *
23  **********************************************************************************/
24 
25 /*! \class TMVA::RuleCut
26 \ingroup TMVA
27 A class describing a 'rule cut'
28 */
29 
30 #include <algorithm>
31 #include <list>
32 
33 #include "TMVA/RuleCut.h"
34 #include "TMVA/DecisionTree.h"
35 #include "TMVA/MsgLogger.h"
36 
37 ////////////////////////////////////////////////////////////////////////////////
38 /// main constructor
39 
40 TMVA::RuleCut::RuleCut( const std::vector<const Node*> & nodes )
41  : fCutNeve(0),
42  fPurity(0),
43  fLogger(new MsgLogger("RuleFit"))
44 {
45  MakeCuts( nodes );
46 }
47 
48 ////////////////////////////////////////////////////////////////////////////////
49 /// empty constructor
50 
52  : fCutNeve(0),
53  fPurity(0),
54  fLogger(new MsgLogger("RuleFit"))
55 {
56 }
57 
58 ////////////////////////////////////////////////////////////////////////////////
59 /// destructor
60 
62  delete fLogger;
63 }
64 
65 
66 ////////////////////////////////////////////////////////////////////////////////
67 /// Construct the cuts from the given array of nodes
68 
69 void TMVA::RuleCut::MakeCuts( const std::vector<const Node*> & nodes )
70 {
71  // At least 2 nodes are required
72  UInt_t nnodes = nodes.size();
73  if (nnodes<2) {
74  Log() << kWARNING << "<MakeCuts()> Empty cut created." << Endl;
75  return;
76  }
77 
78  // Set number of events and S/S+B in last node
79  const DecisionTreeNode* dtn = dynamic_cast<const DecisionTreeNode*>(nodes.back());
80  if(!dtn) return;
81  fCutNeve = dtn->GetNEvents();
82  fPurity = dtn->GetPurity();
83 
84  // some local typedefs
85  typedef std::pair<Double_t,Int_t> CutDir_t; // first is cut value, second is direction
86  typedef std::pair<Int_t,CutDir_t> SelCut_t;
87 
88  // Clear vectors
89  fSelector.clear();
90  fCutMin.clear();
91  fCutMax.clear();
92  fCutDoMin.clear();
93  fCutDoMax.clear();
94 
95  // Count the number of variables in cut
96  // Exclude last node since that does not lead to a cut
97  std::list<SelCut_t> allsel;
98  Int_t sel;
99  Double_t val;
100  Int_t dir;
101  const Node *nextNode;
102  for ( UInt_t i=0; i<nnodes-1; i++) {
103  nextNode = nodes[i+1];
104  const DecisionTreeNode* dtn_ = dynamic_cast<const DecisionTreeNode*>(nodes[i]);
105  if(!dtn_) return;
106  sel = dtn_->GetSelector();
107  val = dtn_->GetCutValue();
108  if (nodes[i]->GetRight() == nextNode) { // val>cut
109  dir = 1;
110  }
111  else if (nodes[i]->GetLeft() == nextNode) { // val<cut
112  dir = -1;
113  }
114  else {
115  Log() << kFATAL << "<MakeTheRule> BUG! Should not be here - an end-node before the end!" << Endl;
116  dir = 0;
117  }
118  allsel.push_back(SelCut_t(sel,CutDir_t(val,dir)));
119  }
120  // sort after the selector (first element of CutDir_t)
121  allsel.sort();
122  Int_t prevSel=-1;
123  Int_t nsel=0;
124  Bool_t firstMin=kTRUE;
125  Bool_t firstMax=kTRUE;
126  for ( std::list<SelCut_t>::const_iterator it = allsel.begin(); it!=allsel.end(); it++ ) {
127  sel = (*it).first;
128  val = (*it).second.first;
129  dir = (*it).second.second;
130  if (sel!=prevSel) { // a new selector!
131  firstMin = kTRUE;
132  firstMax = kTRUE;
133  nsel++;
134  fSelector.push_back(sel);
135  fCutMin.resize( fSelector.size(),0);
136  fCutMax.resize( fSelector.size(),0);
137  fCutDoMin.resize( fSelector.size(), kFALSE);
138  fCutDoMax.resize( fSelector.size(), kFALSE);
139  }
140  switch ( dir ) {
141  case 1:
142  if ((val<fCutMin[nsel-1]) || firstMin) {
143  fCutMin[nsel-1] = val;
144  fCutDoMin[nsel-1] = kTRUE;
145  firstMin = kFALSE;
146  }
147  break;
148  case -1:
149  if ((val>fCutMax[nsel-1]) || firstMax) {
150  fCutMax[nsel-1] = val;
151  fCutDoMax[nsel-1] = kTRUE;
152  firstMax = kFALSE;
153  }
154  default:
155  break;
156  }
157  prevSel = sel;
158  }
159 }
160 
161 ////////////////////////////////////////////////////////////////////////////////
162 /// get number of cuts
163 
165 {
166  UInt_t rval=0;
167  for (UInt_t i=0; i<fSelector.size(); i++) {
168  if (fCutDoMin[i]) rval += 1;
169  if (fCutDoMax[i]) rval += 1;
170  }
171  return rval;
172 }
173 ////////////////////////////////////////////////////////////////////////////////
174 /// get cut range for a given selector
175 
176 Bool_t TMVA::RuleCut::GetCutRange(Int_t sel,Double_t &rmin, Double_t &rmax, Bool_t &dormin, Bool_t &dormax) const
177 {
178  dormin=kFALSE;
179  dormax=kFALSE;
180  Bool_t done=kFALSE;
181  Bool_t foundIt=kFALSE;
182  UInt_t ind=0;
183  while (!done) {
184  foundIt = (Int_t(fSelector[ind])==sel);
185  ind++;
186  done = (foundIt || (ind==fSelector.size()));
187  }
188  if (!foundIt) return kFALSE;
189  rmin = fCutMin[ind-1];
190  rmax = fCutMax[ind-1];
191  dormin = fCutDoMin[ind-1];
192  dormax = fCutDoMax[ind-1];
193  return kTRUE;
194 }
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Float_t GetNEvents(void) const
Double_t fPurity
Definition: RuleCut.h:92
Bool_t GetCutRange(Int_t sel, Double_t &rmin, Double_t &rmax, Bool_t &dormin, Bool_t &dormax) const
get cut range for a given selector
Definition: RuleCut.cxx:176
MsgLogger * fLogger
Definition: RuleCut.h:95
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
std::vector< Double_t > fCutMax
Definition: RuleCut.h:88
std::vector< Char_t > fCutDoMin
Definition: RuleCut.h:89
MsgLogger & Log() const
Definition: RuleCut.h:96
Float_t GetCutValue(void) const
MsgLogger * fLogger
Definition: Rule.h:181
RuleCut()
empty constructor
Definition: RuleCut.cxx:51
std::vector< Char_t > fCutDoMax
Definition: RuleCut.h:90
std::vector< UInt_t > fSelector
Definition: RuleCut.h:86
UInt_t GetNcuts() const
get number of cuts
Definition: RuleCut.cxx:164
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:92
Float_t GetPurity(void) const
double Double_t
Definition: RtypesCore.h:55
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
void MakeCuts(const std::vector< const TMVA::Node * > &nodes)
Construct the cuts from the given array of nodes.
Definition: RuleCut.cxx:69
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
std::vector< Double_t > fCutMin
Definition: RuleCut.h:87
virtual ~RuleCut()
destructor
Definition: RuleCut.cxx:61
Short_t GetSelector() const
const Bool_t kTRUE
Definition: RtypesCore.h:91
Double_t fCutNeve
Definition: RuleCut.h:91