Logo ROOT  
Reference Guide
Rule.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Rule *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * A class describing a 'rule' *
12 * Each internal node of a tree defines a rule from all the parental nodes. *
13 * A rule consists of at least 2 nodes. *
14 * Input: a decision tree (in the constructor) *
15 * *
16 * Authors (alphabetical): *
17 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
18 * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, Ger. *
19 * *
20 * Copyright (c) 2005: *
21 * CERN, Switzerland *
22 * Iowa State U. *
23 * MPI-K Heidelberg, Germany *
24 * *
25 * Redistribution and use in source and binary forms, with or without *
26 * modification, are permitted according to the terms listed in LICENSE *
27 * (http://tmva.sourceforge.net/LICENSE) *
28 **********************************************************************************/
29
30/*! \class TMVA::Rule
31\ingroup TMVA
32
33Implementation of a rule.
34
35A rule is simply a branch or a part of a branch in a tree.
36It fulfills the following:
37
38 - First node is the root node of the originating tree
39 - Consists of a minimum of 2 nodes
40 - A rule returns for a given event:
41 - 0 : if the event fails at any node
42 - 1 : otherwise
43 - If the rule contains fewer than 2 nodes, it returns 0 (this should never happen).
44
45The coefficient is found by either brute force or some sort of
46intelligent fitting. See the RuleEnsemble class for more info.
47*/
48
49#include "TMVA/Rule.h"
50
51#include "TMVA/Event.h"
52#include "TMVA/MethodBase.h"
53#include "TMVA/MethodRuleFit.h"
54#include "TMVA/MsgLogger.h"
55#include "TMVA/RuleCut.h"
56#include "TMVA/RuleFit.h"
57#include "TMVA/RuleEnsemble.h"
58#include "TMVA/Tools.h"
59#include "TMVA/Types.h"
60
61////////////////////////////////////////////////////////////////////////////////
62/// Main constructor for a Rule.
///
/// Builds the rule's cut from the given tree nodes and initialises fSSB from
/// the purity reported by that cut (printed as "S/(S+B)" by Print()).
///
/// \param re    rule ensemble this rule belongs to (stored as fRuleEnsemble;
///              not deleted by the destructor)
/// \param nodes tree nodes defining this rule; passed unchanged to RuleCut
63
65 const std::vector< const Node * >& nodes )
66 : fCut ( 0 )
67 , fNorm ( 1.0 )
68 , fSupport ( 0.0 )
69 , fSigma ( 0.0 )
70 , fCoefficient ( 0.0 )
71 , fImportance ( 0.0 )
72 , fImportanceRef ( 1.0 )
73 , fRuleEnsemble ( re )
74 , fSSB ( 0 )
75 , fSSBNeve ( 0 )
76 , fLogger( new MsgLogger("RuleFit") )
77{
78 //
79 // input:
80 // nodes - the tree nodes from which this rule's RuleCut is built
81 //
82 //
83
84 fCut = new RuleCut( nodes ); // released in the destructor
85 fSSB = fCut->GetPurity(); // purity of the cut, used as S/(S+B)
87}
88
89////////////////////////////////////////////////////////////////////////////////
90/// Constructor that only attaches the rule to a rule ensemble.
///
/// No cut is created here (fCut stays null). The norm and the importance
/// reference start at 1; all other statistics start at 0.
91
93 : fCut ( 0 )
94 , fNorm ( 1.0 )
95 , fSupport ( 0.0 )
96 , fSigma ( 0.0 )
97 , fCoefficient ( 0.0 )
98 , fImportance ( 0.0 )
99 , fImportanceRef ( 1.0 )
100 , fRuleEnsemble ( re )
101 , fSSB ( 0 )
102 , fSSBNeve ( 0 )
103 , fLogger( new MsgLogger("RuleFit") )
104{
   // all members are set in the initializer list above
105}
106
107////////////////////////////////////////////////////////////////////////////////
108/// the simple constructor
109
111 : fCut ( 0 )
112 , fNorm ( 1.0 )
113 , fSupport ( 0.0 )
114 , fSigma ( 0.0 )
115 , fCoefficient ( 0.0 )
116 , fImportance ( 0.0 )
117 , fImportanceRef ( 1.0 )
118 , fRuleEnsemble ( 0 )
119 , fSSB ( 0 )
120 , fSSBNeve ( 0 )
121 , fLogger( new MsgLogger("RuleFit") )
122{
123}
124
125////////////////////////////////////////////////////////////////////////////////
126/// destructor
127
129{
130 delete fCut;
131 delete fLogger;
132}
133
134////////////////////////////////////////////////////////////////////////////////
135/// check if variable in node
136
138{
139 Bool_t found = kFALSE;
140 Bool_t doneLoop = kFALSE;
141 UInt_t nvars = fCut->GetNvars();
142 UInt_t i = 0;
143 //
144 while (!doneLoop) {
145 found = (fCut->GetSelector(i) == iv);
146 i++;
147 doneLoop = (found || (i==nvars));
148 }
149 return found;
150}
151
152////////////////////////////////////////////////////////////////////////////////
153
154void TMVA::Rule::SetMsgType( EMsgType t )
155{
156 fLogger->SetMinType(t);
157}
158
159
160////////////////////////////////////////////////////////////////////////////////
161/// Compare two rules.
162///
163/// - useCutValue:
164/// - true -> calculate a distance between the two rules based on the cut values
165/// if the rule cuts are not equal, the distance is < 0 (-1.0)
166/// return true if d<mindist
167/// - false-> ignore mindist, return true if rules are equal, ignoring cut values
168/// - mindist: min distance allowed between rules; if < 0 => set useCutValue=false;
169
170Bool_t TMVA::Rule::Equal( const Rule& other, Bool_t useCutValue, Double_t mindist ) const
171{
172 Bool_t rval=kFALSE;
173 if (mindist<0) useCutValue=kFALSE;
174 Double_t d = RuleDist( other, useCutValue );
175 // cut value used - return true if 0<=d<mindist
176 if (useCutValue) rval = ( (!(d<0)) && (d<mindist) );
177 else rval = (!(d<0));
178 // cut value not used, return true if <> -1
179 return rval;
180}
181
182////////////////////////////////////////////////////////////////////////////////
183/// Returns:
184///
185/// * -1.0 : rules are NOT equal, i.e, variables and/or cut directions are wrong
186/// * >=0: rules are equal apart from the cutvalue, returns \f$ d = \sqrt{\sum(c1-c2)^2} \f$
187///
188/// If not useCutValue, the distance is exactly zero if they are equal
189
190Double_t TMVA::Rule::RuleDist( const Rule& other, Bool_t useCutValue ) const
191{
192 if (fCut->GetNvars()!=other.GetRuleCut()->GetNvars()) return -1.0; // check number of cuts
193 //
194 const UInt_t nvars = fCut->GetNvars();
195 //
196 Int_t sel; // cut variable
197 Double_t rms; // rms of cut variable
198 Double_t smin; // distance between the lower range
199 Double_t smax; // distance between the upper range
200 Double_t vminA,vmaxA; // min,max range of cut A (cut from this Rule)
201 Double_t vminB,vmaxB; // idem from other Rule
202 //
203 // compare nodes
204 // A 'distance' is assigned if the two rules has exactly the same set of cuts but with
205 // different cut values.
206 // The distance is given in number of sigmas
207 //
208 UInt_t in = 0; // cut index
209 Double_t sumdc2 = 0; // sum of 'distances'
210 Bool_t equal = true; // flag if cut are equal
211 //
212 const RuleCut *otherCut = other.GetRuleCut();
213 while ((equal) && (in<nvars)) {
214 // check equality in cut topology
215 equal = ( (fCut->GetSelector(in) == (otherCut->GetSelector(in))) &&
216 (fCut->GetCutDoMin(in) == (otherCut->GetCutDoMin(in))) &&
217 (fCut->GetCutDoMax(in) == (otherCut->GetCutDoMax(in))) );
218 // if equal topology, check cut values
219 if (equal) {
220 if (useCutValue) {
221 sel = fCut->GetSelector(in);
222 vminA = fCut->GetCutMin(in);
223 vmaxA = fCut->GetCutMax(in);
224 vminB = other.GetRuleCut()->GetCutMin(in);
225 vmaxB = other.GetRuleCut()->GetCutMax(in);
226 // messy - but ok...
227 rms = fRuleEnsemble->GetRuleFit()->GetMethodBase()->GetRMS(sel);
228 smin=0;
229 smax=0;
230 if (fCut->GetCutDoMin(in))
231 smin = ( rms>0 ? (vminA-vminB)/rms : 0 );
232 if (fCut->GetCutDoMax(in))
233 smax = ( rms>0 ? (vmaxA-vmaxB)/rms : 0 );
234 sumdc2 += smin*smin + smax*smax;
235 // sumw += 1.0/(rms*rms); // TODO: probably not needed
236 }
237 }
238 in++;
239 }
240 if (!useCutValue) sumdc2 = (equal ? 0.0:-1.0); // ignore cut values
241 else sumdc2 = (equal ? sqrt(sumdc2) : -1.0);
242
243 return sumdc2;
244}
245
246////////////////////////////////////////////////////////////////////////////////
247/// comparison operator ==
248
249Bool_t TMVA::Rule::operator==( const Rule& other ) const
250{
251 return this->Equal( other, kTRUE, 1e-3 );
252}
253
254////////////////////////////////////////////////////////////////////////////////
255/// comparison operator <
256
257Bool_t TMVA::Rule::operator<( const Rule& other ) const
258{
259 return (fImportance < other.GetImportance());
260}
261
262////////////////////////////////////////////////////////////////////////////////
263/// std::ostream operator
264
265std::ostream& TMVA::operator<< ( std::ostream& os, const Rule& rule )
266{
267 rule.Print( os );
268 return os;
269}
270
271////////////////////////////////////////////////////////////////////////////////
272/// returns the name of a rule
273
275{
276 return fRuleEnsemble->GetMethodBase()->GetInputLabel(i);
277}
278
279////////////////////////////////////////////////////////////////////////////////
280/// Copy the state of `other` into this rule (self-copy is a no-op).
///
/// Copies the following from `other`:
/// - the rule-ensemble pointer (shallow copy);
/// - the RuleCut (deep copy);
/// - fSSB and fSSBNeve;
/// - the coefficient, support, sigma, norm and importance reference.
///
/// The importance itself is recomputed by CalcImportance() rather than
/// copied. The message logger is not touched.
281
282void TMVA::Rule::Copy( const Rule& other )
283{
284 if(this != &other) {
285 SetRuleEnsemble( other.GetRuleEnsemble() );
// NOTE(review): the RuleCut previously held in fCut (if any) is not deleted,
// so calling Copy() on an already-initialised Rule leaks it. Adding
// `delete fCut;` is only safe once every caller (e.g. a copy constructor
// declared in Rule.h, not visible here) initialises fCut first — confirm.
// Also assumes other.GetRuleCut() is non-null.
286 fCut = new RuleCut( *(other.GetRuleCut()) );
287 fSSB = other.GetSSB();
288 fSSBNeve = other.GetSSBNeve();
289 SetCoefficient(other.GetCoefficient());
290 SetSupport( other.GetSupport() );
291 SetSigma( other.GetSigma() );
292 SetNorm( other.GetNorm() );
// presumably derived from the values set above — keep after the setters (TODO confirm)
293 CalcImportance();
294 SetImportanceRef( other.GetImportanceRef() );
295 }
296}
297
298////////////////////////////////////////////////////////////////////////////////
299/// print function
300
301void TMVA::Rule::Print( std::ostream& os ) const
302{
303 const UInt_t nvars = fCut->GetNvars();
304 if (nvars<1) os << " *** WARNING - <EMPTY RULE> ***" << std::endl; // TODO: Fix this, use fLogger
305 //
306 Int_t sel;
307 Double_t valmin, valmax;
308 //
309 os << " Importance = " << Form("%1.4f", fImportance/fImportanceRef) << std::endl;
310 os << " Coefficient = " << Form("%1.4f", fCoefficient) << std::endl;
311 os << " Support = " << Form("%1.4f", fSupport) << std::endl;
312 os << " S/(S+B) = " << Form("%1.4f", fSSB) << std::endl;
313
314 for ( UInt_t i=0; i<nvars; i++) {
315 os << " ";
316 sel = fCut->GetSelector(i);
317 valmin = fCut->GetCutMin(i);
318 valmax = fCut->GetCutMax(i);
319 //
320 os << Form("* Cut %2d",i+1) << " : " << std::flush;
321 if (fCut->GetCutDoMin(i)) os << Form("%10.3g",valmin) << " < " << std::flush;
322 else os << " " << std::flush;
323 os << GetVarName(sel) << std::flush;
324 if (fCut->GetCutDoMax(i)) os << " < " << Form("%10.3g",valmax) << std::flush;
325 else os << " " << std::flush;
326 os << std::endl;
327 }
328}
329
330////////////////////////////////////////////////////////////////////////////////
331/// print function
332
333void TMVA::Rule::PrintLogger(const char *title) const
334{
335 const UInt_t nvars = fCut->GetNvars();
336 if (nvars<1) Log() << kWARNING << "BUG TRAP: EMPTY RULE!!!" << Endl;
337 //
338 Int_t sel;
339 Double_t valmin, valmax;
340 //
341 if (title) Log() << kINFO << title;
342 Log() << kINFO
343 << "Importance = " << Form("%1.4f", fImportance/fImportanceRef) << Endl;
344
345 for ( UInt_t i=0; i<nvars; i++) {
346
347 Log() << kINFO << " ";
348 sel = fCut->GetSelector(i);
349 valmin = fCut->GetCutMin(i);
350 valmax = fCut->GetCutMax(i);
351 //
352 Log() << kINFO << Form("Cut %2d",i+1) << " : ";
353 if (fCut->GetCutDoMin(i)) Log() << kINFO << Form("%10.3g",valmin) << " < ";
354 else Log() << kINFO << " ";
355 Log() << kINFO << GetVarName(sel);
356 if (fCut->GetCutDoMax(i)) Log() << kINFO << " < " << Form("%10.3g",valmax);
357 else Log() << kINFO << " ";
358 Log() << Endl;
359 }
360}
361
362////////////////////////////////////////////////////////////////////////////////
363/// extensive print function used to print info for the weight file
364
365void TMVA::Rule::PrintRaw( std::ostream& os ) const
366{
367 Int_t dp = os.precision();
368 const UInt_t nvars = fCut->GetNvars();
369 os << "Parameters: "
370 << std::setprecision(10)
371 << fImportance << " "
372 << fImportanceRef << " "
373 << fCoefficient << " "
374 << fSupport << " "
375 << fSigma << " "
376 << fNorm << " "
377 << fSSB << " "
378 << fSSBNeve << " "
379 << std::endl; \
380 os << "N(cuts): " << nvars << std::endl; // mark end of nodes
381 for ( UInt_t i=0; i<nvars; i++) {
382 os << "Cut " << i << " : " << std::flush;
383 os << fCut->GetSelector(i)
384 << std::setprecision(10)
385 << " " << fCut->GetCutMin(i)
386 << " " << fCut->GetCutMax(i)
387 << " " << (fCut->GetCutDoMin(i) ? "T":"F")
388 << " " << (fCut->GetCutDoMax(i) ? "T":"F")
389 << std::endl;
390 }
391 os << std::setprecision(dp);
392}
393
394////////////////////////////////////////////////////////////////////////////////
395
396void* TMVA::Rule::AddXMLTo( void* parent ) const
397{
398 void* rule = gTools().AddChild( parent, "Rule" );
399 const UInt_t nvars = fCut->GetNvars();
400
401 gTools().AddAttr( rule, "Importance", fImportance );
402 gTools().AddAttr( rule, "Ref", fImportanceRef );
403 gTools().AddAttr( rule, "Coeff", fCoefficient );
404 gTools().AddAttr( rule, "Support", fSupport );
405 gTools().AddAttr( rule, "Sigma", fSigma );
406 gTools().AddAttr( rule, "Norm", fNorm );
407 gTools().AddAttr( rule, "SSB", fSSB );
408 gTools().AddAttr( rule, "SSBNeve", fSSBNeve );
409 gTools().AddAttr( rule, "Nvars", nvars );
410
411 for (UInt_t i=0; i<nvars; i++) {
412 void* cut = gTools().AddChild( rule, "Cut" );
413 gTools().AddAttr( cut, "Selector", fCut->GetSelector(i) );
414 gTools().AddAttr( cut, "Min", fCut->GetCutMin(i) );
415 gTools().AddAttr( cut, "Max", fCut->GetCutMax(i) );
416 gTools().AddAttr( cut, "DoMin", (fCut->GetCutDoMin(i) ? "T":"F") );
417 gTools().AddAttr( cut, "DoMax", (fCut->GetCutDoMax(i) ? "T":"F") );
418 }
419
420 return rule;
421}
422
423////////////////////////////////////////////////////////////////////////////////
424/// read rule from XML
425
426void TMVA::Rule::ReadFromXML( void* wghtnode )
427{
428 TString nodeName = TString( gTools().GetName(wghtnode) );
429 if (nodeName != "Rule") Log() << kFATAL << "<ReadFromXML> Unexpected node name: " << nodeName << Endl;
430
431 gTools().ReadAttr( wghtnode, "Importance", fImportance );
432 gTools().ReadAttr( wghtnode, "Ref", fImportanceRef );
433 gTools().ReadAttr( wghtnode, "Coeff", fCoefficient );
434 gTools().ReadAttr( wghtnode, "Support", fSupport );
435 gTools().ReadAttr( wghtnode, "Sigma", fSigma );
436 gTools().ReadAttr( wghtnode, "Norm", fNorm );
437 gTools().ReadAttr( wghtnode, "SSB", fSSB );
438 gTools().ReadAttr( wghtnode, "SSBNeve", fSSBNeve );
439
440 UInt_t nvars;
441 gTools().ReadAttr( wghtnode, "Nvars", nvars );
442 if (fCut) delete fCut;
443 fCut = new RuleCut();
444 fCut->SetNvars( nvars );
445
446 // read Cut
447 void* ch = gTools().GetChild( wghtnode );
448 UInt_t i = 0;
449 UInt_t ui;
450 Double_t d;
451 Char_t c;
452 while (ch) {
453 gTools().ReadAttr( ch, "Selector", ui );
454 fCut->SetSelector( i, ui );
455 gTools().ReadAttr( ch, "Min", d );
456 fCut->SetCutMin ( i, d );
457 gTools().ReadAttr( ch, "Max", d );
458 fCut->SetCutMax ( i, d );
459 gTools().ReadAttr( ch, "DoMin", c );
460 fCut->SetCutDoMin( i, (c == 'T' ? kTRUE : kFALSE ) );
461 gTools().ReadAttr( ch, "DoMax", c );
462 fCut->SetCutDoMax( i, (c == 'T' ? kTRUE : kFALSE ) );
463
464 i++;
465 ch = gTools().GetNextChild(ch);
466 }
467
468 // sanity check
469 if (i != nvars) Log() << kFATAL << "<ReadFromXML> Mismatch in number of cuts: " << i << " != " << nvars << Endl;
470}
471
472////////////////////////////////////////////////////////////////////////////////
473/// read function (format is the same as written by PrintRaw)
///
/// Expected layout, as written by PrintRaw():
///
///     Parameters: <importance> <importanceRef> <coefficient> <support> <sigma> <norm> <ssb> <ssbNeve>
///     N(cuts): <n>
///     Cut <i> : <selector> <min> <max> <T|F> <T|F>      (n times)
///
/// Any existing cut is replaced by a new RuleCut sized for n cuts.
///
/// NOTE(review): `dummy` is declared on a source line not shown in this
/// listing. It is presumably a string used to skip the text labels; confirm.
/// NOTE(review): the stream state is never checked. On malformed or
/// truncated input, the members and `nvars` may hold partial or
/// indeterminate values.
474
475void TMVA::Rule::ReadRaw( std::istream& istr )
476{
478 UInt_t nvars;
479 istr >> dummy
480 >> fImportance
481 >> fImportanceRef
482 >> fCoefficient
483 >> fSupport
484 >> fSigma
485 >> fNorm
486 >> fSSB
487 >> fSSBNeve;
488 // coverity[tainted_data_argument]
489 istr >> dummy >> nvars;
490 Double_t cutmin,cutmax;
491 UInt_t sel,idum;
492 Char_t bA, bB;
493 //
494 if (fCut) delete fCut;
495 fCut = new RuleCut();
496 fCut->SetNvars( nvars );
497 for ( UInt_t i=0; i<nvars; i++) {
498 istr >> dummy >> idum; // skip the 'Cut' label and read the (unused) cut index
499 istr >> dummy; // get ':'
500 istr >> sel >> cutmin >> cutmax >> bA >> bB;
501 fCut->SetSelector(i,sel);
502 fCut->SetCutMin(i,cutmin);
503 fCut->SetCutMax(i,cutmax);
504 fCut->SetCutDoMin(i,(bA=='T' ? kTRUE:kFALSE));
505 fCut->SetCutDoMax(i,(bB=='T' ? kTRUE:kFALSE));
506 }
507}
#define d(i)
Definition: RSha256.hxx:102
#define c(i)
Definition: RSha256.hxx:101
#define e(i)
Definition: RSha256.hxx:103
static RooMathCoreReg dummy
int Int_t
Definition: RtypesCore.h:41
char Char_t
Definition: RtypesCore.h:29
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
const Bool_t kTRUE
Definition: RtypesCore.h:87
double sqrt(double)
char * Form(const char *fmt,...)
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
A class describing a 'rule cut'.
Definition: RuleCut.h:34
Double_t GetPurity() const
Definition: RuleCut.h:77
UInt_t GetNvars() const
Definition: RuleCut.h:70
Double_t GetCutMin(Int_t is) const
Definition: RuleCut.h:72
UInt_t GetSelector(Int_t is) const
Definition: RuleCut.h:71
Double_t GetCutNeve() const
Definition: RuleCut.h:76
Char_t GetCutDoMin(Int_t is) const
Definition: RuleCut.h:74
Char_t GetCutDoMax(Int_t is) const
Definition: RuleCut.h:75
Double_t GetCutMax(Int_t is) const
Definition: RuleCut.h:73
Implementation of a rule.
Definition: Rule.h:48
void SetMsgType(EMsgType t)
Definition: Rule.cxx:154
void Copy(const Rule &other)
copy function
Definition: Rule.cxx:282
Double_t GetImportanceRef() const
Definition: Rule.h:144
Double_t GetSSBNeve() const
Definition: Rule.h:116
Double_t GetSupport() const
Definition: Rule.h:140
Bool_t Equal(const Rule &other, Bool_t useCutValue, Double_t maxdist) const
Compare two rules.
Definition: Rule.cxx:170
void ReadRaw(std::istream &os)
read function (format is the same as written by PrintRaw)
Definition: Rule.cxx:475
void * AddXMLTo(void *parent) const
Definition: Rule.cxx:396
void PrintLogger(const char *title=0) const
print function
Definition: Rule.cxx:333
Bool_t operator==(const Rule &other) const
comparison operator ==
Definition: Rule.cxx:249
const RuleCut * GetRuleCut() const
Definition: Rule.h:137
Double_t GetCoefficient() const
Definition: Rule.h:139
void Print(std::ostream &os) const
print function
Definition: Rule.cxx:301
const RuleEnsemble * GetRuleEnsemble() const
Definition: Rule.h:138
virtual ~Rule()
destructor
Definition: Rule.cxx:128
Double_t GetSSB() const
Definition: Rule.h:115
Double_t GetNorm() const
Definition: Rule.h:142
void ReadFromXML(void *wghtnode)
read rule from XML
Definition: Rule.cxx:426
Double_t GetImportance() const
Definition: Rule.h:143
void PrintRaw(std::ostream &os) const
extensive print function used to print info for the weight file
Definition: Rule.cxx:365
Double_t GetSigma() const
Definition: Rule.h:141
const TString & GetVarName(Int_t i) const
returns the name of a rule
Definition: Rule.cxx:274
Bool_t operator<(const Rule &other) const
comparison operator <
Definition: Rule.cxx:257
Double_t RuleDist(const Rule &other, Bool_t useCutValue) const
Returns:
Definition: Rule.cxx:190
Double_t fSSB
Definition: Rule.h:178
RuleCut * fCut
Definition: Rule.h:170
Bool_t ContainsVariable(UInt_t iv) const
check if variable in node
Definition: Rule.cxx:137
Rule()
the simple constructor
Definition: Rule.cxx:110
Double_t fSSBNeve
Definition: Rule.h:179
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:337
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:355
Basic string class.
Definition: TString.h:131
std::string GetName(const std::string &scope_name)
Definition: Cppyy.cxx:150
Tools & gTools()
std::ostream & operator<<(std::ostream &os, const BinaryTree &tree)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750