Logo ROOT  
Reference Guide
Rule.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Rule *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * A class describing a 'rule' *
12 * Each internal node of a tree defines a rule from all the parental nodes. *
13 * A rule consists of at least 2 nodes. *
14 * Input: a decision tree (in the constructor) *
15 * *
16 * Authors (alphabetical): *
17 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
18 * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, Ger. *
19 * *
20 * Copyright (c) 2005: *
21 * CERN, Switzerland *
22 * Iowa State U. *
23 * MPI-K Heidelberg, Germany *
24 * *
25 * Redistribution and use in source and binary forms, with or without *
26 * modification, are permitted according to the terms listed in LICENSE *
27 * (http://tmva.sourceforge.net/LICENSE) *
28 **********************************************************************************/
29
30/*! \class TMVA::Rule
31\ingroup TMVA
32
33Implementation of a rule.
34
35A rule is simply a branch or a part of a branch in a tree.
36It fulfills the following:
37
38 - First node is the root node of the originating tree
39 - Consists of a minimum of 2 nodes
 - A rule returns for a given event:
    - 0 : if the event fails at any node
    - 1 : otherwise
 - If the rule contains fewer than 2 nodes it returns 0 (this should not happen!)
44
45The coefficient is found by either brute force or some sort of
46intelligent fitting. See the RuleEnsemble class for more info.
47*/
48
49#include "TMVA/Rule.h"
50
51#include "TMVA/Event.h"
52#include "TMVA/MethodBase.h"
53#include "TMVA/MethodRuleFit.h"
54#include "TMVA/MsgLogger.h"
55#include "TMVA/RuleCut.h"
56#include "TMVA/RuleFit.h"
57#include "TMVA/RuleEnsemble.h"
58#include "TMVA/Tools.h"
59#include "TMVA/Types.h"
60
61#include <iomanip>
62
63////////////////////////////////////////////////////////////////////////////////
64/// the main constructor for a Rule
65
67 const std::vector< const Node * >& nodes )
68 : fCut ( 0 )
69 , fNorm ( 1.0 )
70 , fSupport ( 0.0 )
71 , fSigma ( 0.0 )
72 , fCoefficient ( 0.0 )
73 , fImportance ( 0.0 )
74 , fImportanceRef ( 1.0 )
75 , fRuleEnsemble ( re )
76 , fSSB ( 0 )
77 , fSSBNeve ( 0 )
78 , fLogger( new MsgLogger("RuleFit") )
79{
80 //
81 // input:
82 // nodes - a vector of Node; from these all possible rules will be created
83 //
84 //
85
86 fCut = new RuleCut( nodes );
87 fSSB = fCut->GetPurity();
89}
90
91////////////////////////////////////////////////////////////////////////////////
92/// the simple constructor
93
95 : fCut ( 0 )
96 , fNorm ( 1.0 )
97 , fSupport ( 0.0 )
98 , fSigma ( 0.0 )
99 , fCoefficient ( 0.0 )
100 , fImportance ( 0.0 )
101 , fImportanceRef ( 1.0 )
102 , fRuleEnsemble ( re )
103 , fSSB ( 0 )
104 , fSSBNeve ( 0 )
105 , fLogger( new MsgLogger("RuleFit") )
106{
107}
108
109////////////////////////////////////////////////////////////////////////////////
/// the default constructor: creates an empty rule (no cut, no rule ensemble attached)
111
113 : fCut ( 0 )
114 , fNorm ( 1.0 )
115 , fSupport ( 0.0 )
116 , fSigma ( 0.0 )
117 , fCoefficient ( 0.0 )
118 , fImportance ( 0.0 )
119 , fImportanceRef ( 1.0 )
120 , fRuleEnsemble ( 0 )
121 , fSSB ( 0 )
122 , fSSBNeve ( 0 )
123 , fLogger( new MsgLogger("RuleFit") )
124{
125}
126
127////////////////////////////////////////////////////////////////////////////////
128/// destructor
129
131{
132 delete fCut;
133 delete fLogger;
134}
135
136////////////////////////////////////////////////////////////////////////////////
137/// check if variable in node
138
140{
141 Bool_t found = kFALSE;
142 Bool_t doneLoop = kFALSE;
143 UInt_t nvars = fCut->GetNvars();
144 UInt_t i = 0;
145 //
146 while (!doneLoop) {
147 found = (fCut->GetSelector(i) == iv);
148 i++;
149 doneLoop = (found || (i==nvars));
150 }
151 return found;
152}
153
154////////////////////////////////////////////////////////////////////////////////
155
157{
158 fLogger->SetMinType(t);
159}
160
161
162////////////////////////////////////////////////////////////////////////////////
163/// Compare two rules.
164///
165/// - useCutValue:
166/// - true -> calculate a distance between the two rules based on the cut values
167/// if the rule cuts are not equal, the distance is < 0 (-1.0)
168/// return true if d<mindist
169/// - false-> ignore mindist, return true if rules are equal, ignoring cut values
170/// - mindist: min distance allowed between rules; if < 0 => set useCutValue=false;
171
172Bool_t TMVA::Rule::Equal( const Rule& other, Bool_t useCutValue, Double_t mindist ) const
173{
174 Bool_t rval=kFALSE;
175 if (mindist<0) useCutValue=kFALSE;
176 Double_t d = RuleDist( other, useCutValue );
177 // cut value used - return true if 0<=d<mindist
178 if (useCutValue) rval = ( (!(d<0)) && (d<mindist) );
179 else rval = (!(d<0));
180 // cut value not used, return true if <> -1
181 return rval;
182}
183
184////////////////////////////////////////////////////////////////////////////////
185/// Returns:
186///
187/// * -1.0 : rules are NOT equal, i.e, variables and/or cut directions are wrong
188/// * >=0: rules are equal apart from the cutvalue, returns \f$ d = \sqrt{\sum(c1-c2)^2} \f$
189///
190/// If not useCutValue, the distance is exactly zero if they are equal
191
192Double_t TMVA::Rule::RuleDist( const Rule& other, Bool_t useCutValue ) const
193{
194 if (fCut->GetNvars()!=other.GetRuleCut()->GetNvars()) return -1.0; // check number of cuts
195 //
196 const UInt_t nvars = fCut->GetNvars();
197 //
198 Int_t sel; // cut variable
199 Double_t rms; // rms of cut variable
200 Double_t smin; // distance between the lower range
201 Double_t smax; // distance between the upper range
202 Double_t vminA,vmaxA; // min,max range of cut A (cut from this Rule)
203 Double_t vminB,vmaxB; // idem from other Rule
204 //
205 // compare nodes
206 // A 'distance' is assigned if the two rules has exactly the same set of cuts but with
207 // different cut values.
208 // The distance is given in number of sigmas
209 //
210 UInt_t in = 0; // cut index
211 Double_t sumdc2 = 0; // sum of 'distances'
212 Bool_t equal = true; // flag if cut are equal
213 //
214 const RuleCut *otherCut = other.GetRuleCut();
215 while ((equal) && (in<nvars)) {
216 // check equality in cut topology
217 equal = ( (fCut->GetSelector(in) == (otherCut->GetSelector(in))) &&
218 (fCut->GetCutDoMin(in) == (otherCut->GetCutDoMin(in))) &&
219 (fCut->GetCutDoMax(in) == (otherCut->GetCutDoMax(in))) );
220 // if equal topology, check cut values
221 if (equal) {
222 if (useCutValue) {
223 sel = fCut->GetSelector(in);
224 vminA = fCut->GetCutMin(in);
225 vmaxA = fCut->GetCutMax(in);
226 vminB = other.GetRuleCut()->GetCutMin(in);
227 vmaxB = other.GetRuleCut()->GetCutMax(in);
228 // messy - but ok...
229 rms = fRuleEnsemble->GetRuleFit()->GetMethodBase()->GetRMS(sel);
230 smin=0;
231 smax=0;
232 if (fCut->GetCutDoMin(in))
233 smin = ( rms>0 ? (vminA-vminB)/rms : 0 );
234 if (fCut->GetCutDoMax(in))
235 smax = ( rms>0 ? (vmaxA-vmaxB)/rms : 0 );
236 sumdc2 += smin*smin + smax*smax;
237 // sumw += 1.0/(rms*rms); // TODO: probably not needed
238 }
239 }
240 in++;
241 }
242 if (!useCutValue) sumdc2 = (equal ? 0.0:-1.0); // ignore cut values
243 else sumdc2 = (equal ? sqrt(sumdc2) : -1.0);
244
245 return sumdc2;
246}
247
248////////////////////////////////////////////////////////////////////////////////
249/// comparison operator ==
250
251Bool_t TMVA::Rule::operator==( const Rule& other ) const
252{
253 return this->Equal( other, kTRUE, 1e-3 );
254}
255
256////////////////////////////////////////////////////////////////////////////////
257/// comparison operator <
258
259Bool_t TMVA::Rule::operator<( const Rule& other ) const
260{
261 return (fImportance < other.GetImportance());
262}
263
264////////////////////////////////////////////////////////////////////////////////
265/// std::ostream operator
266
267std::ostream& TMVA::operator<< ( std::ostream& os, const Rule& rule )
268{
269 rule.Print( os );
270 return os;
271}
272
273////////////////////////////////////////////////////////////////////////////////
274/// returns the name of a rule
275
277{
278 return fRuleEnsemble->GetMethodBase()->GetInputLabel(i);
279}
280
281////////////////////////////////////////////////////////////////////////////////
282/// copy function
283
284void TMVA::Rule::Copy( const Rule& other )
285{
286 if(this != &other) {
287 SetRuleEnsemble( other.GetRuleEnsemble() );
288 fCut = new RuleCut( *(other.GetRuleCut()) );
289 fSSB = other.GetSSB();
290 fSSBNeve = other.GetSSBNeve();
291 SetCoefficient(other.GetCoefficient());
292 SetSupport( other.GetSupport() );
293 SetSigma( other.GetSigma() );
294 SetNorm( other.GetNorm() );
295 CalcImportance();
296 SetImportanceRef( other.GetImportanceRef() );
297 }
298}
299
300////////////////////////////////////////////////////////////////////////////////
301/// print function
302
303void TMVA::Rule::Print( std::ostream& os ) const
304{
305 const UInt_t nvars = fCut->GetNvars();
306 if (nvars<1) os << " *** WARNING - <EMPTY RULE> ***" << std::endl; // TODO: Fix this, use fLogger
307 //
308 Int_t sel;
309 Double_t valmin, valmax;
310 //
311 os << " Importance = " << Form("%1.4f", fImportance/fImportanceRef) << std::endl;
312 os << " Coefficient = " << Form("%1.4f", fCoefficient) << std::endl;
313 os << " Support = " << Form("%1.4f", fSupport) << std::endl;
314 os << " S/(S+B) = " << Form("%1.4f", fSSB) << std::endl;
315
316 for ( UInt_t i=0; i<nvars; i++) {
317 os << " ";
318 sel = fCut->GetSelector(i);
319 valmin = fCut->GetCutMin(i);
320 valmax = fCut->GetCutMax(i);
321 //
322 os << Form("* Cut %2d",i+1) << " : " << std::flush;
323 if (fCut->GetCutDoMin(i)) os << Form("%10.3g",valmin) << " < " << std::flush;
324 else os << " " << std::flush;
325 os << GetVarName(sel) << std::flush;
326 if (fCut->GetCutDoMax(i)) os << " < " << Form("%10.3g",valmax) << std::flush;
327 else os << " " << std::flush;
328 os << std::endl;
329 }
330}
331
332////////////////////////////////////////////////////////////////////////////////
333/// print function
334
335void TMVA::Rule::PrintLogger(const char *title) const
336{
337 const UInt_t nvars = fCut->GetNvars();
338 if (nvars<1) Log() << kWARNING << "BUG TRAP: EMPTY RULE!!!" << Endl;
339 //
340 Int_t sel;
341 Double_t valmin, valmax;
342 //
343 if (title) Log() << kINFO << title;
344 Log() << kINFO
345 << "Importance = " << Form("%1.4f", fImportance/fImportanceRef) << Endl;
346
347 for ( UInt_t i=0; i<nvars; i++) {
348
349 Log() << kINFO << " ";
350 sel = fCut->GetSelector(i);
351 valmin = fCut->GetCutMin(i);
352 valmax = fCut->GetCutMax(i);
353 //
354 Log() << kINFO << Form("Cut %2d",i+1) << " : ";
355 if (fCut->GetCutDoMin(i)) Log() << kINFO << Form("%10.3g",valmin) << " < ";
356 else Log() << kINFO << " ";
357 Log() << kINFO << GetVarName(sel);
358 if (fCut->GetCutDoMax(i)) Log() << kINFO << " < " << Form("%10.3g",valmax);
359 else Log() << kINFO << " ";
360 Log() << Endl;
361 }
362}
363
364////////////////////////////////////////////////////////////////////////////////
365/// extensive print function used to print info for the weight file
366
367void TMVA::Rule::PrintRaw( std::ostream& os ) const
368{
369 Int_t dp = os.precision();
370 const UInt_t nvars = fCut->GetNvars();
371 os << "Parameters: "
372 << std::setprecision(10)
373 << fImportance << " "
374 << fImportanceRef << " "
375 << fCoefficient << " "
376 << fSupport << " "
377 << fSigma << " "
378 << fNorm << " "
379 << fSSB << " "
380 << fSSBNeve << " "
381 << std::endl; \
382 os << "N(cuts): " << nvars << std::endl; // mark end of nodes
383 for ( UInt_t i=0; i<nvars; i++) {
384 os << "Cut " << i << " : " << std::flush;
385 os << fCut->GetSelector(i)
386 << std::setprecision(10)
387 << " " << fCut->GetCutMin(i)
388 << " " << fCut->GetCutMax(i)
389 << " " << (fCut->GetCutDoMin(i) ? "T":"F")
390 << " " << (fCut->GetCutDoMax(i) ? "T":"F")
391 << std::endl;
392 }
393 os << std::setprecision(dp);
394}
395
396////////////////////////////////////////////////////////////////////////////////
397
398void* TMVA::Rule::AddXMLTo( void* parent ) const
399{
400 void* rule = gTools().AddChild( parent, "Rule" );
401 const UInt_t nvars = fCut->GetNvars();
402
403 gTools().AddAttr( rule, "Importance", fImportance );
404 gTools().AddAttr( rule, "Ref", fImportanceRef );
405 gTools().AddAttr( rule, "Coeff", fCoefficient );
406 gTools().AddAttr( rule, "Support", fSupport );
407 gTools().AddAttr( rule, "Sigma", fSigma );
408 gTools().AddAttr( rule, "Norm", fNorm );
409 gTools().AddAttr( rule, "SSB", fSSB );
410 gTools().AddAttr( rule, "SSBNeve", fSSBNeve );
411 gTools().AddAttr( rule, "Nvars", nvars );
412
413 for (UInt_t i=0; i<nvars; i++) {
414 void* cut = gTools().AddChild( rule, "Cut" );
415 gTools().AddAttr( cut, "Selector", fCut->GetSelector(i) );
416 gTools().AddAttr( cut, "Min", fCut->GetCutMin(i) );
417 gTools().AddAttr( cut, "Max", fCut->GetCutMax(i) );
418 gTools().AddAttr( cut, "DoMin", (fCut->GetCutDoMin(i) ? "T":"F") );
419 gTools().AddAttr( cut, "DoMax", (fCut->GetCutDoMax(i) ? "T":"F") );
420 }
421
422 return rule;
423}
424
425////////////////////////////////////////////////////////////////////////////////
426/// read rule from XML
427
428void TMVA::Rule::ReadFromXML( void* wghtnode )
429{
430 TString nodeName = TString( gTools().GetName(wghtnode) );
431 if (nodeName != "Rule") Log() << kFATAL << "<ReadFromXML> Unexpected node name: " << nodeName << Endl;
432
433 gTools().ReadAttr( wghtnode, "Importance", fImportance );
434 gTools().ReadAttr( wghtnode, "Ref", fImportanceRef );
435 gTools().ReadAttr( wghtnode, "Coeff", fCoefficient );
436 gTools().ReadAttr( wghtnode, "Support", fSupport );
437 gTools().ReadAttr( wghtnode, "Sigma", fSigma );
438 gTools().ReadAttr( wghtnode, "Norm", fNorm );
439 gTools().ReadAttr( wghtnode, "SSB", fSSB );
440 gTools().ReadAttr( wghtnode, "SSBNeve", fSSBNeve );
441
442 UInt_t nvars;
443 gTools().ReadAttr( wghtnode, "Nvars", nvars );
444 if (fCut) delete fCut;
445 fCut = new RuleCut();
446 fCut->SetNvars( nvars );
447
448 // read Cut
449 void* ch = gTools().GetChild( wghtnode );
450 UInt_t i = 0;
451 UInt_t ui;
452 Double_t d;
453 Char_t c;
454 while (ch) {
455 gTools().ReadAttr( ch, "Selector", ui );
456 fCut->SetSelector( i, ui );
457 gTools().ReadAttr( ch, "Min", d );
458 fCut->SetCutMin ( i, d );
459 gTools().ReadAttr( ch, "Max", d );
460 fCut->SetCutMax ( i, d );
461 gTools().ReadAttr( ch, "DoMin", c );
462 fCut->SetCutDoMin( i, (c == 'T' ? kTRUE : kFALSE ) );
463 gTools().ReadAttr( ch, "DoMax", c );
464 fCut->SetCutDoMax( i, (c == 'T' ? kTRUE : kFALSE ) );
465
466 i++;
467 ch = gTools().GetNextChild(ch);
468 }
469
470 // sanity check
471 if (i != nvars) Log() << kFATAL << "<ReadFromXML> Mismatch in number of cuts: " << i << " != " << nvars << Endl;
472}
473
474////////////////////////////////////////////////////////////////////////////////
475/// read function (format is the same as written by PrintRaw)
476
477void TMVA::Rule::ReadRaw( std::istream& istr )
478{
479 TString dummy;
480 UInt_t nvars;
481 istr >> dummy
482 >> fImportance
483 >> fImportanceRef
484 >> fCoefficient
485 >> fSupport
486 >> fSigma
487 >> fNorm
488 >> fSSB
489 >> fSSBNeve;
490 // coverity[tainted_data_argument]
491 istr >> dummy >> nvars;
492 Double_t cutmin,cutmax;
493 UInt_t sel,idum;
494 Char_t bA, bB;
495 //
496 if (fCut) delete fCut;
497 fCut = new RuleCut();
498 fCut->SetNvars( nvars );
499 for ( UInt_t i=0; i<nvars; i++) {
500 istr >> dummy >> idum; // get 'Node' and index
501 istr >> dummy; // get ':'
502 istr >> sel >> cutmin >> cutmax >> bA >> bB;
503 fCut->SetSelector(i,sel);
504 fCut->SetCutMin(i,cutmin);
505 fCut->SetCutMax(i,cutmax);
506 fCut->SetCutDoMin(i,(bA=='T' ? kTRUE:kFALSE));
507 fCut->SetCutDoMax(i,(bB=='T' ? kTRUE:kFALSE));
508 }
509}
#define d(i)
Definition: RSha256.hxx:102
#define c(i)
Definition: RSha256.hxx:101
#define e(i)
Definition: RSha256.hxx:103
int Int_t
Definition: RtypesCore.h:45
char Char_t
Definition: RtypesCore.h:33
unsigned int UInt_t
Definition: RtypesCore.h:46
const Bool_t kFALSE
Definition: RtypesCore.h:101
bool Bool_t
Definition: RtypesCore.h:63
double Double_t
Definition: RtypesCore.h:59
const Bool_t kTRUE
Definition: RtypesCore.h:100
double sqrt(double)
char * Form(const char *fmt,...)
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
A class describing a 'rule cut'.
Definition: RuleCut.h:36
Double_t GetPurity() const
Definition: RuleCut.h:79
UInt_t GetNvars() const
Definition: RuleCut.h:72
Double_t GetCutMin(Int_t is) const
Definition: RuleCut.h:74
UInt_t GetSelector(Int_t is) const
Definition: RuleCut.h:73
Double_t GetCutNeve() const
Definition: RuleCut.h:78
Char_t GetCutDoMin(Int_t is) const
Definition: RuleCut.h:76
Char_t GetCutDoMax(Int_t is) const
Definition: RuleCut.h:77
Double_t GetCutMax(Int_t is) const
Definition: RuleCut.h:75
Implementation of a rule.
Definition: Rule.h:50
void SetMsgType(EMsgType t)
Definition: Rule.cxx:156
void Copy(const Rule &other)
copy function
Definition: Rule.cxx:284
Double_t GetImportanceRef() const
Definition: Rule.h:146
Double_t GetSSBNeve() const
Definition: Rule.h:118
Double_t GetSupport() const
Definition: Rule.h:142
Bool_t Equal(const Rule &other, Bool_t useCutValue, Double_t maxdist) const
Compare two rules.
Definition: Rule.cxx:172
void ReadRaw(std::istream &os)
read function (format is the same as written by PrintRaw)
Definition: Rule.cxx:477
void * AddXMLTo(void *parent) const
Definition: Rule.cxx:398
void PrintLogger(const char *title=0) const
print function
Definition: Rule.cxx:335
Bool_t operator==(const Rule &other) const
comparison operator ==
Definition: Rule.cxx:251
const RuleCut * GetRuleCut() const
Definition: Rule.h:139
Double_t GetCoefficient() const
Definition: Rule.h:141
void Print(std::ostream &os) const
print function
Definition: Rule.cxx:303
const RuleEnsemble * GetRuleEnsemble() const
Definition: Rule.h:140
virtual ~Rule()
destructor
Definition: Rule.cxx:130
Double_t GetSSB() const
Definition: Rule.h:117
Double_t GetNorm() const
Definition: Rule.h:144
void ReadFromXML(void *wghtnode)
read rule from XML
Definition: Rule.cxx:428
Double_t GetImportance() const
Definition: Rule.h:145
void PrintRaw(std::ostream &os) const
extensive print function used to print info for the weight file
Definition: Rule.cxx:367
Double_t GetSigma() const
Definition: Rule.h:143
const TString & GetVarName(Int_t i) const
returns the name of a rule
Definition: Rule.cxx:276
Bool_t operator<(const Rule &other) const
comparison operator <
Definition: Rule.cxx:259
Double_t RuleDist(const Rule &other, Bool_t useCutValue) const
Returns:
Definition: Rule.cxx:192
Double_t fSSB
Definition: Rule.h:180
RuleCut * fCut
Definition: Rule.h:172
Bool_t ContainsVariable(UInt_t iv) const
check if variable in node
Definition: Rule.cxx:139
Rule()
the simple constructor
Definition: Rule.cxx:112
Double_t fSSBNeve
Definition: Rule.h:181
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
EMsgType
Definition: Types.h:57
@ kINFO
Definition: Types.h:60
@ kWARNING
Definition: Types.h:61
@ kFATAL
Definition: Types.h:63
Basic string class.
Definition: TString.h:136
Tools & gTools()
std::ostream & operator<<(std::ostream &os, const BinaryTree &tree)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:760