Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RuleCut.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : RuleCut *
8 * *
9 * Description: *
10 * A class describing a 'rule cut' *
11 * *
12 * *
13 * Authors (alphabetical): *
14 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
15 * *
16 * Copyright (c) 2005: *
17 * CERN, Switzerland *
18 * Iowa State U. *
19 * *
20 * Redistribution and use in source and binary forms, with or without *
21 * modification, are permitted according to the terms listed in LICENSE *
22 * (see tmva/doc/LICENSE) *
23 **********************************************************************************/
24
25/*! \class TMVA::RuleCut
26\ingroup TMVA
27A class describing a 'rule cut'
28*/
29
30#include <algorithm>
31#include <list>
32
33#include "TMVA/RuleCut.h"
34#include "TMVA/DecisionTree.h"
35#include "TMVA/MsgLogger.h"
36
37////////////////////////////////////////////////////////////////////////////////
38/// main constructor
39
40TMVA::RuleCut::RuleCut( const std::vector<const Node*> & nodes )
41 : fCutNeve(0),
42 fPurity(0),
43 fLogger(new MsgLogger("RuleFit"))
44{
45 MakeCuts( nodes );
46}
47
48////////////////////////////////////////////////////////////////////////////////
49/// empty constructor
50
52 : fCutNeve(0),
53 fPurity(0),
54 fLogger(new MsgLogger("RuleFit"))
55{
56}
57
58////////////////////////////////////////////////////////////////////////////////
59/// destructor
60
62 delete fLogger;
63}
64
65
66////////////////////////////////////////////////////////////////////////////////
67/// Construct the cuts from the given array of nodes
68
69void TMVA::RuleCut::MakeCuts( const std::vector<const Node*> & nodes )
70{
71 // At least 2 nodes are required
72 UInt_t nnodes = nodes.size();
73 if (nnodes<2) {
74 Log() << kWARNING << "<MakeCuts()> Empty cut created." << Endl;
75 return;
76 }
77
78 // Set number of events and S/S+B in last node
79 const DecisionTreeNode* dtn = dynamic_cast<const DecisionTreeNode*>(nodes.back());
80 if(!dtn) return;
81 fCutNeve = dtn->GetNEvents();
82 fPurity = dtn->GetPurity();
83
84 // some local typedefs
85 typedef std::pair<Double_t,Int_t> CutDir_t; // first is cut value, second is direction
86 typedef std::pair<Int_t,CutDir_t> SelCut_t;
87
88 // Clear vectors
89 fSelector.clear();
90 fCutMin.clear();
91 fCutMax.clear();
92 fCutDoMin.clear();
93 fCutDoMax.clear();
94
95 // Count the number of variables in cut
96 // Exclude last node since that does not lead to a cut
97 std::list<SelCut_t> allsel;
98 Int_t sel;
99 Double_t val;
100 Int_t dir;
101 const Node *nextNode;
102 for ( UInt_t i=0; i<nnodes-1; i++) {
103 nextNode = nodes[i+1];
104 const DecisionTreeNode* dtn_ = dynamic_cast<const DecisionTreeNode*>(nodes[i]);
105 if(!dtn_) return;
106 sel = dtn_->GetSelector();
107 val = dtn_->GetCutValue();
108 if (nodes[i]->GetRight() == nextNode) { // val>cut
109 dir = 1;
110 }
111 else if (nodes[i]->GetLeft() == nextNode) { // val<cut
112 dir = -1;
113 }
114 else {
115 Log() << kFATAL << "<MakeTheRule> BUG! Should not be here - an end-node before the end!" << Endl;
116 dir = 0;
117 }
118 allsel.push_back(SelCut_t(sel,CutDir_t(val,dir)));
119 }
120 // sort after the selector (first element of CutDir_t)
121 allsel.sort();
122 Int_t prevSel=-1;
123 Int_t nsel=0;
124 Bool_t firstMin=kTRUE;
125 Bool_t firstMax=kTRUE;
126 for ( std::list<SelCut_t>::const_iterator it = allsel.begin(); it!=allsel.end(); ++it ) {
127 sel = (*it).first;
128 val = (*it).second.first;
129 dir = (*it).second.second;
130 if (sel!=prevSel) { // a new selector!
131 firstMin = kTRUE;
132 firstMax = kTRUE;
133 nsel++;
134 fSelector.push_back(sel);
135 fCutMin.resize( fSelector.size(),0);
136 fCutMax.resize( fSelector.size(),0);
137 fCutDoMin.resize( fSelector.size(), kFALSE);
138 fCutDoMax.resize( fSelector.size(), kFALSE);
139 }
140 switch ( dir ) {
141 case 1:
142 if ((val<fCutMin[nsel-1]) || firstMin) {
143 fCutMin[nsel-1] = val;
144 fCutDoMin[nsel-1] = kTRUE;
145 firstMin = kFALSE;
146 }
147 break;
148 case -1:
149 if ((val>fCutMax[nsel-1]) || firstMax) {
150 fCutMax[nsel-1] = val;
151 fCutDoMax[nsel-1] = kTRUE;
152 firstMax = kFALSE;
153 }
154 default:
155 break;
156 }
157 prevSel = sel;
158 }
159}
160
161////////////////////////////////////////////////////////////////////////////////
162/// get number of cuts
163
165{
166 UInt_t rval=0;
167 for (UInt_t i=0; i<fSelector.size(); i++) {
168 if (fCutDoMin[i]) rval += 1;
169 if (fCutDoMax[i]) rval += 1;
170 }
171 return rval;
172}
173////////////////////////////////////////////////////////////////////////////////
174/// get cut range for a given selector
175
177{
178 dormin=kFALSE;
179 dormax=kFALSE;
180 Bool_t done=kFALSE;
181 Bool_t foundIt=kFALSE;
182 UInt_t ind=0;
183 while (!done) {
184 foundIt = (Int_t(fSelector[ind])==sel);
185 ind++;
186 done = (foundIt || (ind==fSelector.size()));
187 }
188 if (!foundIt) return kFALSE;
189 rmin = fCutMin[ind-1];
190 rmax = fCutMax[ind-1];
191 dormin = fCutDoMin[ind-1];
192 dormax = fCutDoMax[ind-1];
193 return kTRUE;
194}
int Int_t
Definition RtypesCore.h:45
constexpr Bool_t kFALSE
Definition RtypesCore.h:94
constexpr Bool_t kTRUE
Definition RtypesCore.h:93
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t sel
Short_t GetSelector() const
return index of variable used for discrimination at this node
Float_t GetCutValue(void) const
return the cut value applied at this node
Float_t GetNEvents(void) const
return the number of events that entered the node (during training), or -1 if traininfo undefined
Float_t GetPurity(void) const
return S/(S+B) (purity) at this node (from training)
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
Node for the BinarySearch or Decision Trees.
Definition Node.h:58
RuleCut()
empty constructor
Definition RuleCut.cxx:51
virtual ~RuleCut()
destructor
Definition RuleCut.cxx:61
UInt_t GetNcuts() const
get number of cuts
Definition RuleCut.cxx:164
void MakeCuts(const std::vector< const TMVA::Node * > &nodes)
Construct the cuts from the given array of nodes.
Definition RuleCut.cxx:69
Bool_t GetCutRange(Int_t sel, Double_t &rmin, Double_t &rmax, Bool_t &dormin, Bool_t &dormax) const
get cut range for a given selector
Definition RuleCut.cxx:176
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148