Logo ROOT   6.16/01
Reference Guide
BinarySearchTree.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : BinarySearchTree *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header file for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18 * *
19 * Copyright (c) 2005: *
20 * CERN, Switzerland *
21 * U. of Victoria, Canada *
22 * MPI-K Heidelberg, Germany *
23 * LAPP, Annecy, France *
24 * *
25 * Redistribution and use in source and binary forms, with or without *
26 * modification, are permitted according to the terms listed in LICENSE *
27 * (http://tmva.sourceforge.net/LICENSE) *
28 * *
29 **********************************************************************************/
30
31/*! \class TMVA::BinarySearchTree
32\ingroup TMVA
33
34A simple Binary search tree including a volume search method.
35
36*/
37
38#include <stdexcept>
39#include <cstdlib>
40#include <queue>
41#include <algorithm>
42
43// #if ROOT_VERSION_CODE >= 364802
44// #ifndef ROOT_TMathBase
45// #include "TMathBase.h"
46// #endif
47// #else
48// #ifndef ROOT_TMath
49#include "TMath.h"
50// #endif
51// #endif
52
53#include "TMatrixDBase.h"
54#include "TObjString.h"
55#include "TTree.h"
56
57#include "TMVA/MsgLogger.h"
58#include "TMVA/MethodBase.h"
59#include "TMVA/Tools.h"
60#include "TMVA/Event.h"
62
63#include "TMVA/BinaryTree.h"
64#include "TMVA/Types.h"
65#include "TMVA/Node.h"
66
68
69////////////////////////////////////////////////////////////////////////////////
70/// default constructor
71
74 fPeriod ( 1 ),
75 fCurrentDepth( 0 ),
76 fStatisticsIsValid( kFALSE ),
77 fSumOfWeights( 0 ),
78 fCanNormalize( kFALSE )
79{
80 fNEventsW[0]=fNEventsW[1]=0.;
81}
82
83////////////////////////////////////////////////////////////////////////////////
84/// copy constructor; NOT implemented yet: it only copies fPeriod and fSumOfWeights and then issues a kFATAL message
85
87 : BinaryTree(),
88 fPeriod ( b.fPeriod ),
89 fCurrentDepth( 0 ),
90 fStatisticsIsValid( kFALSE ),
91 fSumOfWeights( b.fSumOfWeights ),
92 fCanNormalize( kFALSE )
93{
94 fNEventsW[0]=fNEventsW[1]=0.;
95 Log() << kFATAL << " Copy constructor not implemented yet " << Endl;
96}
97
98////////////////////////////////////////////////////////////////////////////////
99/// destructor
100
102{
103 for(std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator pIt = fNormalizeTreeTable.begin();
104 pIt != fNormalizeTreeTable.end(); ++pIt) {
105 delete pIt->second;
106 }
107}
108
109////////////////////////////////////////////////////////////////////////////////
110/// re-create a new tree (decision tree or search tree) from XML
111
113 std::string type("");
114 gTools().ReadAttr(node,"type", type);
116 bt->ReadXML( node, tmva_Version_Code );
117 return bt;
118}
119
120////////////////////////////////////////////////////////////////////////////////
121/// insert a new "event" in the binary tree
122
124{
125 fCurrentDepth=0;
126 fStatisticsIsValid = kFALSE;
127
128 if (this->GetRoot() == NULL) { // If the list is empty...
129 this->SetRoot( new BinarySearchTreeNode(event)); //Make the new node the root.
130 // have to use "s" for start as "r" for "root" would be the same as "r" for "right"
131 this->GetRoot()->SetPos('s');
132 this->GetRoot()->SetDepth(0);
133 fNNodes = 1;
134 fSumOfWeights = event->GetWeight();
135 ((BinarySearchTreeNode*)this->GetRoot())->SetSelector((UInt_t)0);
136 this->SetPeriode(event->GetNVariables());
137 }
138 else {
139 // sanity check:
140 if (event->GetNVariables() != (UInt_t)this->GetPeriode()) {
141 Log() << kFATAL << "<Insert> event vector length != Periode specified in Binary Tree" << Endl
142 << "--- event size: " << event->GetNVariables() << " Periode: " << this->GetPeriode() << Endl
143 << "--- and all this when trying filling the "<<fNNodes+1<<"th Node" << Endl;
144 }
145 // insert a new node at the proper position
146 this->Insert(event, this->GetRoot());
147 }
148
149 // normalise the tree to speed up searches
150 if (fCanNormalize) fNormalizeTreeTable.push_back( std::make_pair(0.0,new const Event(*event)) );
151}
152
153////////////////////////////////////////////////////////////////////////////////
154/// private internal function to insert an event (node) at the proper position
155
157 Node *node )
158{
159 fCurrentDepth++;
160 fStatisticsIsValid = kFALSE;
161
162 if (node->GoesLeft(*event)){ // If the adding item is less than the current node's data...
163 if (node->GetLeft() != NULL){ // If there is a left node...
164 // Add the new event to the left node
165 this->Insert(event, node->GetLeft());
166 }
167 else { // If there is not a left node...
168 // Make the new node for the new event
169 BinarySearchTreeNode* current = new BinarySearchTreeNode(event);
170 fNNodes++;
171 fSumOfWeights += event->GetWeight();
172 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
173 current->SetParent(node); // Set the new node's previous node.
174 current->SetPos('l');
175 current->SetDepth( node->GetDepth() + 1 );
176 node->SetLeft(current); // Make it the left node of the current one.
177 }
178 }
179 else if (node->GoesRight(*event)) { // If the adding item is less than or equal to the current node's data...
180 if (node->GetRight() != NULL) { // If there is a right node...
181 // Add the new node to it.
182 this->Insert(event, node->GetRight());
183 }
184 else { // If there is not a right node...
185 // Make the new node.
186 BinarySearchTreeNode* current = new BinarySearchTreeNode(event);
187 fNNodes++;
188 fSumOfWeights += event->GetWeight();
189 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
190 current->SetParent(node); // Set the new node's previous node.
191 current->SetPos('r');
192 current->SetDepth( node->GetDepth() + 1 );
193 node->SetRight(current); // Make it the left node of the current one.
194 }
195 }
196 else Log() << kFATAL << "<Insert> neither left nor right :)" << Endl;
197}
198
199////////////////////////////////////////////////////////////////////////////////
200///search the tree to find the node matching "event"
201
203{
204 return this->Search( event, this->GetRoot() );
205}
206
207////////////////////////////////////////////////////////////////////////////////
208/// Private, recursive, function for searching.
209
211{
212 if (node != NULL) { // If the node is not NULL...
213 // If we have found the node...
214 if (((BinarySearchTreeNode*)(node))->EqualsMe(*event))
215 return (BinarySearchTreeNode*)node; // Return it
216 if (node->GoesLeft(*event)) // If the node's data is greater than the search item...
217 return this->Search(event, node->GetLeft()); //Search the left node.
218 else //If the node's data is less than the search item...
219 return this->Search(event, node->GetRight()); //Search the right node.
220 }
221 else return NULL; //If the node is NULL, return NULL.
222}
223
224////////////////////////////////////////////////////////////////////////////////
225/// return the sum of event (node) weights
226
228{
229 if (fSumOfWeights <= 0) {
230 Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
231 << " I call CalcStatistics which hopefully fixes things"
232 << Endl;
233 }
234 if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
235
236 return fSumOfWeights;
237}
238
239////////////////////////////////////////////////////////////////////////////////
240/// return the sum of weights of the events of the given type (signal if theType == Types::kSignal, otherwise the other category)
241
243{
244 if (fSumOfWeights <= 0) {
245 Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
246 << " I call CalcStatistics which hopefully fixes things"
247 << Endl;
248 }
249 if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
250
251 return fNEventsW[ ( theType == Types::kSignal) ? 0 : 1 ];
252}
253
254////////////////////////////////////////////////////////////////////////////////
255/// create the search tree from the event collection
256/// using ONLY the variables specified in "theVars"
257
258Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, const std::vector<Int_t>& theVars,
259 Int_t theType )
260{
261 fPeriod = theVars.size();
262 return Fill(events, theType);
263}
264
265////////////////////////////////////////////////////////////////////////////////
266/// create the search tree from the events in a TTree
267/// using ALL the variables specified included in the Event
268
269Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, Int_t theType )
270{
271 UInt_t n=events.size();
272
273 UInt_t nevents = 0;
274 if (fSumOfWeights != 0) {
275 Log() << kWARNING
276 << "You are filling a search three that is not empty.. "
277 << " do you know what you are doing?"
278 << Endl;
279 }
280 for (UInt_t ievt=0; ievt<n; ievt++) {
281 // insert event into binary tree
282 if (theType == -1 || (Int_t(events[ievt]->GetClass()) == theType) ) {
283 this->Insert( events[ievt] );
284 nevents++;
285 fSumOfWeights += events[ievt]->GetWeight();
286 }
287 } // end of event loop
288 CalcStatistics(0);
289
290 return fSumOfWeights;
291}
292
293////////////////////////////////////////////////////////////////////////////////
294/// normalises the binary-search tree to reduce the branch length and hence speed up the
295/// search procedure (on average).
296
297void TMVA::BinarySearchTree::NormalizeTree ( std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftBound,
298 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightBound,
299 UInt_t actDim )
300{
301
302 if (leftBound == rightBound) return;
303
304 if (actDim == fPeriod) actDim = 0;
305 for (std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator i=leftBound; i!=rightBound; ++i) {
306 i->first = i->second->GetValue( actDim );
307 }
308
309 std::sort( leftBound, rightBound );
310
311 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftTemp = leftBound;
312 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightTemp = rightBound;
313
314 // meet in the middle
315 while (true) {
316 --rightTemp;
317 if (rightTemp == leftTemp ) {
318 break;
319 }
320 ++leftTemp;
321 if (leftTemp == rightTemp) {
322 break;
323 }
324 }
325
326 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator mid = leftTemp;
327 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator midTemp = mid;
328
329 if (mid!=leftBound)--midTemp;
330
331 while (mid != leftBound && mid->second->GetValue( actDim ) == midTemp->second->GetValue( actDim )) {
332 --mid;
333 --midTemp;
334 }
335
336 Insert( mid->second );
337
338 // Print(std::cout);
339 // std::cout << std::endl << std::endl;
340
341 NormalizeTree( leftBound, mid, actDim+1 );
342 ++mid;
343 // Print(std::cout);
344 // std::cout << std::endl << std::endl;
345 NormalizeTree( mid, rightBound, actDim+1 );
346
347
348 return;
349}
350
351////////////////////////////////////////////////////////////////////////////////
352/// Normalisation of tree
353
355{
356 SetNormalize( kFALSE );
357 Clear( NULL );
358 this->SetRoot(NULL);
359 NormalizeTree( fNormalizeTreeTable.begin(), fNormalizeTreeTable.end(), 0 );
360}
361
362////////////////////////////////////////////////////////////////////////////////
363/// clear nodes
364
366{
367 BinarySearchTreeNode* currentNode = (BinarySearchTreeNode*)(n == NULL ? this->GetRoot() : n);
368
369 if (currentNode->GetLeft() != 0) Clear( currentNode->GetLeft() );
370 if (currentNode->GetRight() != 0) Clear( currentNode->GetRight() );
371
372 if (n != NULL) delete n;
373
374 return;
375}
376
377////////////////////////////////////////////////////////////////////////////////
378/// search the whole tree and add up all weights of events that
379/// lie within the given volume
380
382 std::vector<const BinarySearchTreeNode*>* events )
383{
384 return SearchVolume( this->GetRoot(), volume, 0, events );
385}
386
387////////////////////////////////////////////////////////////////////////////////
388/// recursively walk through the daughter nodes and add up all weights of events that
389/// lie within the given volume
390
392 std::vector<const BinarySearchTreeNode*>* events )
393{
394 if (t==NULL) return 0; // Are we at an outer leave?
395
397
398 Double_t count = 0.0;
399 if (InVolume( st->GetEventV(), volume )) {
400 count += st->GetWeight();
401 if (NULL != events) events->push_back( st );
402 }
403 if (st->GetLeft()==NULL && st->GetRight()==NULL) {
404
405 return count; // Are we at an outer leave?
406 }
407
408 Bool_t tl, tr;
409 Int_t d = depth%this->GetPeriode();
410 if (d != st->GetSelector()) {
411 Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
412 << d << " != " << "node "<< st->GetSelector() << Endl;
413 }
414 tl = (*(volume->fLower))[d] < st->GetEventV()[d]; // Should we descend left?
415 tr = (*(volume->fUpper))[d] >= st->GetEventV()[d]; // Should we descend right?
416
417 if (tl) count += SearchVolume( st->GetLeft(), volume, (depth+1), events );
418 if (tr) count += SearchVolume( st->GetRight(), volume, (depth+1), events );
419
420 return count;
421}
422
423////////////////////////////////////////////////////////////////////////////////
424/// test if the data points are in the given volume
425
426Bool_t TMVA::BinarySearchTree::InVolume(const std::vector<Float_t>& event, Volume* volume ) const
427{
428
429 Bool_t result = false;
430 for (UInt_t ivar=0; ivar< fPeriod; ivar++) {
431 result = ( (*(volume->fLower))[ivar] < event[ivar] &&
432 (*(volume->fUpper))[ivar] >= event[ivar] );
433 if (!result) break;
434 }
435 return result;
436}
437
438////////////////////////////////////////////////////////////////////////////////
439/// calculate basic statistics (mean, rms for each variable)
440
442{
443 if (fStatisticsIsValid) return;
444
446
447 // default, start at the tree top, then descend recursively
448 if (n == NULL) {
449 fSumOfWeights = 0;
450 for (Int_t sb=0; sb<2; sb++) {
451 fNEventsW[sb] = 0;
452 fMeans[sb] = std::vector<Float_t>(fPeriod);
453 fRMS[sb] = std::vector<Float_t>(fPeriod);
454 fMin[sb] = std::vector<Float_t>(fPeriod);
455 fMax[sb] = std::vector<Float_t>(fPeriod);
456 fSum[sb] = std::vector<Double_t>(fPeriod);
457 fSumSq[sb] = std::vector<Double_t>(fPeriod);
458 for (UInt_t j=0; j<fPeriod; j++) {
459 fMeans[sb][j] = fRMS[sb][j] = fSum[sb][j] = fSumSq[sb][j] = 0;
460 fMin[sb][j] = FLT_MAX;
461 fMax[sb][j] = -FLT_MAX;
462 }
463 }
464 currentNode = (BinarySearchTreeNode*) this->GetRoot();
465 if (currentNode == NULL) return; // no root-node
466 }
467
468 const std::vector<Float_t> & evtVec = currentNode->GetEventV();
469 Double_t weight = currentNode->GetWeight();
470 // Int_t type = currentNode->IsSignal();
471 // Int_t type = currentNode->IsSignal() ? 0 : 1;
472 Int_t type = Int_t(currentNode->GetClass())== Types::kSignal ? 0 : 1;
473
474 fNEventsW[type] += weight;
475 fSumOfWeights += weight;
476
477 for (UInt_t j=0; j<fPeriod; j++) {
478 Float_t val = evtVec[j];
479 fSum[type][j] += val*weight;
480 fSumSq[type][j] += val*val*weight;
481 if (val < fMin[type][j]) fMin[type][j] = val;
482 if (val > fMax[type][j]) fMax[type][j] = val;
483 }
484
485 if ( (currentNode->GetLeft() != NULL) ) CalcStatistics( currentNode->GetLeft() );
486 if ( (currentNode->GetRight() != NULL) ) CalcStatistics( currentNode->GetRight() );
487
488 if (n == NULL) { // i.e. the root node
489 for (Int_t sb=0; sb<2; sb++) {
490 for (UInt_t j=0; j<fPeriod; j++) {
491 if (fNEventsW[sb] == 0) { fMeans[sb][j] = fRMS[sb][j] = 0; continue; }
492 fMeans[sb][j] = fSum[sb][j]/fNEventsW[sb];
493 fRMS[sb][j] = TMath::Sqrt(fSumSq[sb][j]/fNEventsW[sb] - fMeans[sb][j]*fMeans[sb][j]);
494 }
495 }
496 fStatisticsIsValid = kTRUE;
497 }
498
499 return;
500}
501
502////////////////////////////////////////////////////////////////////////////////
503/// recursively walk through the daughter nodes and add up all weights of events that
504/// lie within the given volume a maximum number of events can be given
505
506Int_t TMVA::BinarySearchTree::SearchVolumeWithMaxLimit( Volume *volume, std::vector<const BinarySearchTreeNode*>* events,
507 Int_t max_points )
508{
509 if (this->GetRoot() == NULL) return 0; // Are we at an outer leave?
510
511 std::queue< std::pair< const BinarySearchTreeNode*, Int_t > > queue;
512 std::pair< const BinarySearchTreeNode*, Int_t > st = std::make_pair( (const BinarySearchTreeNode*)this->GetRoot(), 0 );
513 queue.push( st );
514
515 Int_t count = 0;
516
517 while ( !queue.empty() ) {
518 st = queue.front(); queue.pop();
519
520 if (count == max_points)
521 return count;
522
523 if (InVolume( st.first->GetEventV(), volume )) {
524 count++;
525 if (NULL != events) events->push_back( st.first );
526 }
527
528 Bool_t tl, tr;
529 Int_t d = st.second;
530 if ( d == Int_t(this->GetPeriode()) ) d = 0;
531
532 if (d != st.first->GetSelector()) {
533 Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
534 << d << " != " << "node "<< st.first->GetSelector() << Endl;
535 }
536
537 tl = (*(volume->fLower))[d] < st.first->GetEventV()[d] && st.first->GetLeft() != NULL; // Should we descend left?
538 tr = (*(volume->fUpper))[d] >= st.first->GetEventV()[d] && st.first->GetRight() != NULL; // Should we descend right?
539
540 if (tl) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetLeft(), d+1 ) );
541 if (tr) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetRight(), d+1 ) );
542 }
543
544 return count;
545}
#define d(i)
Definition: RSha256.hxx:102
#define b(i)
Definition: RSha256.hxx:100
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:363
int type
Definition: TGX11.cxx:120
Node for the BinarySearch or Decision Trees.
const std::vector< Float_t > & GetEventV() const
A simple Binary search tree including a volume search method.
BinarySearchTreeNode * Search(Event *event) const
search the tree to find the node matching "event"
Bool_t InVolume(const std::vector< Float_t > &, Volume *) const
test if the data points are in the given volume
virtual ~BinarySearchTree(void)
destructor
Double_t Fill(const std::vector< TMVA::Event * > &events, const std::vector< Int_t > &theVars, Int_t theType=-1)
create the search tree from the event collection using ONLY the variables specified in "theVars"
void NormalizeTree()
Normalisation of tree.
Double_t GetSumOfWeights(void) const
return the sum of event (node) weights
static BinarySearchTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
re-create a new tree (decision tree or search tree) from XML
void Clear(TMVA::Node *n=0)
clear nodes
BinarySearchTree(void)
default constructor
void CalcStatistics(TMVA::Node *n=0)
calculate basic statistics (mean, rms for each variable)
Double_t SearchVolume(Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=0)
search the whole tree and add up all weights of events that lie within the given volume
void Insert(const Event *)
insert a new "event" in the binary tree
Int_t SearchVolumeWithMaxLimit(TMVA::Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=0, Int_t=-1)
recursively walk through the daughter nodes and add up all weights of events that lie within the given volume; a maximum number of events can be given
Base class for BinarySearch and Decision Trees.
Definition: BinaryTree.h:62
virtual void ReadXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
read attributes from XML
Definition: BinaryTree.cxx:144
MsgLogger & Log() const
Definition: BinaryTree.cxx:235
UInt_t GetNVariables() const
accessor to the number of variables
Definition: Event.cxx:309
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
virtual Node * GetLeft() const
Definition: Node.h:87
virtual void SetRight(Node *r)
Definition: Node.h:93
virtual Bool_t GoesLeft(const Event &) const =0
virtual void SetLeft(Node *l)
Definition: Node.h:92
void SetPos(char s)
Definition: Node.h:117
void SetDepth(UInt_t d)
Definition: Node.h:111
virtual void SetParent(Node *p)
Definition: Node.h:94
virtual Bool_t GoesRight(const Event &) const =0
UInt_t GetDepth() const
Definition: Node.h:114
virtual Node * GetRight() const
Definition: Node.h:88
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:337
@ kSignal
Definition: Types.h:136
Volume for BinarySearchTree.
Definition: Volume.h:48
std::vector< Double_t > * fLower
Definition: Volume.h:73
std::vector< Double_t > * fUpper
Definition: Volume.h:74
const Int_t n
Definition: legend1.C:16
TClass * GetClass(T *)
Definition: TClass.h:582
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:748
Double_t Sqrt(Double_t x)
Definition: TMath.h:679