ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
BinarySearchTree.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : BinarySearchTree *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header file for description) *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18  * *
19  * Copyright (c) 2005: *
20  * CERN, Switzerland *
21  * U. of Victoria, Canada *
22  * MPI-K Heidelberg, Germany *
23  * LAPP, Annecy, France *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  * *
29  **********************************************************************************/
30 
31 //////////////////////////////////////////////////////////////////////////
32 // //
33 // BinarySearchTree //
34 // //
35 // A simple Binary search tree including a volume search method //
36 // //
37 //////////////////////////////////////////////////////////////////////////
38 
39 #include <stdexcept>
40 #include <cstdlib>
41 #include <queue>
42 #include <algorithm>
43 
44 // #if ROOT_VERSION_CODE >= 364802
45 // #ifndef ROOT_TMathBase
46 // #include "TMathBase.h"
47 // #endif
48 // #else
49 // #ifndef ROOT_TMath
50 #include "TMath.h"
51 // #endif
52 // #endif
53 
54 #include "TMatrixDBase.h"
55 #include "TObjString.h"
56 #include "TTree.h"
57 
58 #ifndef ROOT_TMVA_MsgLogger
59 #include "TMVA/MsgLogger.h"
60 #endif
61 #ifndef ROOT_TMVA_MethodBase
62 #include "TMVA/MethodBase.h"
63 #endif
64 #ifndef ROOT_TMVA_Tools
65 #include "TMVA/Tools.h"
66 #endif
67 #ifndef ROOT_TMVA_Event
68 #include "TMVA/Event.h"
69 #endif
70 #ifndef ROOT_TMVA_BinarySearchTree
71 #include "TMVA/BinarySearchTree.h"
72 #endif
73 
74 #include "TMVA/Types.h"
75 #include "TMVA/Node.h"
76 
78 
79 ////////////////////////////////////////////////////////////////////////////////
80 /// default constructor
81 
82 TMVA::BinarySearchTree::BinarySearchTree( void ) :
83  BinaryTree(),
84  fPeriod ( 1 ),
85  fCurrentDepth( 0 ),
86  fStatisticsIsValid( kFALSE ),
87  fSumOfWeights( 0 ),
88  fCanNormalize( kFALSE )
89 {
90  fNEventsW[0]=fNEventsW[1]=0.;
91 }
92 
93 ////////////////////////////////////////////////////////////////////////////////
94 /// copy constructor that creates a true copy, i.e. a completely independent tree
95 
97  : BinaryTree(),
98  fPeriod ( b.fPeriod ),
99  fCurrentDepth( 0 ),
100  fStatisticsIsValid( kFALSE ),
101  fSumOfWeights( b.fSumOfWeights ),
102  fCanNormalize( kFALSE )
103 {
104  fNEventsW[0]=fNEventsW[1]=0.;
105  Log() << kFATAL << " Copy constructor not implemented yet " << Endl;
106 }
107 
108 ////////////////////////////////////////////////////////////////////////////////
109 /// destructor
110 
112 {
113  for(std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator pIt = fNormalizeTreeTable.begin();
114  pIt != fNormalizeTreeTable.end(); pIt++) {
115  delete pIt->second;
116  }
117 }
118 
119 ////////////////////////////////////////////////////////////////////////////////
120 /// re-create a new tree (decision tree or search tree) from XML
121 
123  std::string type("");
124  gTools().ReadAttr(node,"type", type);
126  bt->ReadXML( node, tmva_Version_Code );
127  return bt;
128 }
129 
130 ////////////////////////////////////////////////////////////////////////////////
131 /// insert a new "event" in the binary tree
132 
134 {
135  fCurrentDepth=0;
136  fStatisticsIsValid = kFALSE;
137 
138  if (this->GetRoot() == NULL) { // If the list is empty...
139  this->SetRoot( new BinarySearchTreeNode(event)); //Make the new node the root.
140  // have to use "s" for start as "r" for "root" would be the same as "r" for "right"
141  this->GetRoot()->SetPos('s');
142  this->GetRoot()->SetDepth(0);
143  fNNodes = 1;
144  fSumOfWeights = event->GetWeight();
145  ((BinarySearchTreeNode*)this->GetRoot())->SetSelector((UInt_t)0);
146  this->SetPeriode(event->GetNVariables());
147  }
148  else {
149  // sanity check:
150  if (event->GetNVariables() != (UInt_t)this->GetPeriode()) {
151  Log() << kFATAL << "<Insert> event vector length != Periode specified in Binary Tree" << Endl
152  << "--- event size: " << event->GetNVariables() << " Periode: " << this->GetPeriode() << Endl
153  << "--- and all this when trying filling the "<<fNNodes+1<<"th Node" << Endl;
154  }
155  // insert a new node at the propper position
156  this->Insert(event, this->GetRoot());
157  }
158 
159  // normalise the tree to speed up searches
160  if (fCanNormalize) fNormalizeTreeTable.push_back( std::make_pair(0.0,new const Event(*event)) );
161 }
162 
163 ////////////////////////////////////////////////////////////////////////////////
164 /// private internal function to insert a event (node) at the proper position
165 
167  Node *node )
168 {
169  fCurrentDepth++;
170  fStatisticsIsValid = kFALSE;
171 
172  if (node->GoesLeft(*event)){ // If the adding item is less than the current node's data...
173  if (node->GetLeft() != NULL){ // If there is a left node...
174  // Add the new event to the left node
175  this->Insert(event, node->GetLeft());
176  }
177  else { // If there is not a left node...
178  // Make the new node for the new event
180  fNNodes++;
181  fSumOfWeights += event->GetWeight();
182  current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
183  current->SetParent(node); // Set the new node's previous node.
184  current->SetPos('l');
185  current->SetDepth( node->GetDepth() + 1 );
186  node->SetLeft(current); // Make it the left node of the current one.
187  }
188  }
189  else if (node->GoesRight(*event)) { // If the adding item is less than or equal to the current node's data...
190  if (node->GetRight() != NULL) { // If there is a right node...
191  // Add the new node to it.
192  this->Insert(event, node->GetRight());
193  }
194  else { // If there is not a right node...
195  // Make the new node.
197  fNNodes++;
198  fSumOfWeights += event->GetWeight();
199  current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
200  current->SetParent(node); // Set the new node's previous node.
201  current->SetPos('r');
202  current->SetDepth( node->GetDepth() + 1 );
203  node->SetRight(current); // Make it the left node of the current one.
204  }
205  }
206  else Log() << kFATAL << "<Insert> neither left nor right :)" << Endl;
207 }
208 
209 ////////////////////////////////////////////////////////////////////////////////
210 ///search the tree to find the node matching "event"
211 
213 {
214  return this->Search( event, this->GetRoot() );
215 }
216 
217 ////////////////////////////////////////////////////////////////////////////////
218 /// Private, recursive, function for searching.
219 
221 {
222  if (node != NULL) { // If the node is not NULL...
223  // If we have found the node...
224  if (((BinarySearchTreeNode*)(node))->EqualsMe(*event))
225  return (BinarySearchTreeNode*)node; // Return it
226  if (node->GoesLeft(*event)) // If the node's data is greater than the search item...
227  return this->Search(event, node->GetLeft()); //Search the left node.
228  else //If the node's data is less than the search item...
229  return this->Search(event, node->GetRight()); //Search the right node.
230  }
231  else return NULL; //If the node is NULL, return NULL.
232 }
233 
234 ////////////////////////////////////////////////////////////////////////////////
235 ///return the sum of event (node) weights
236 
238 {
239  if (fSumOfWeights <= 0) {
240  Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
241  << " I call CalcStatistics which hopefully fixes things"
242  << Endl;
243  }
244  if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
245 
246  return fSumOfWeights;
247 }
248 
249 ////////////////////////////////////////////////////////////////////////////////
250 /// return the sum of event (node) weights of the requested type (signal, or everything else)
251 
253 {
254  if (fSumOfWeights <= 0) {
255  Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
256  << " I call CalcStatistics which hopefully fixes things"
257  << Endl;
258  }
259  if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
260 
261  return fNEventsW[ ( theType == Types::kSignal) ? 0 : 1 ];
262 }
263 
264 ////////////////////////////////////////////////////////////////////////////////
265 /// create the search tree from the event collection
266 /// using ONLY the variables specified in "theVars"
267 
268 Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, const std::vector<Int_t>& theVars,
269  Int_t theType )
270 {
271  fPeriod = theVars.size();
272  return Fill(events, theType);
273 }
274 
275 ////////////////////////////////////////////////////////////////////////////////
276 /// create the search tree from the events in the given vector
277 /// using ALL the variables included in the Event
278 
279 Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, Int_t theType )
280 {
281  UInt_t n=events.size();
282 
283  UInt_t nevents = 0;
284  if (fSumOfWeights != 0) {
285  Log() << kWARNING
286  << "You are filling a search three that is not empty.. "
287  << " do you know what you are doing?"
288  << Endl;
289  }
290  for (UInt_t ievt=0; ievt<n; ievt++) {
291  // insert event into binary tree
292  if (theType == -1 || (Int_t(events[ievt]->GetClass()) == theType) ) {
293  this->Insert( events[ievt] );
294  nevents++;
295  fSumOfWeights += events[ievt]->GetWeight();
296  }
297  } // end of event loop
298  CalcStatistics(0);
299 
300  return fSumOfWeights;
301 }
302 
303 ////////////////////////////////////////////////////////////////////////////////
304 
305 void TMVA::BinarySearchTree::NormalizeTree ( std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftBound,
306  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightBound,
307  UInt_t actDim )
308 {
309  // normalises the binary-search tree to reduce the branch length and hence speed up the
310  // search procedure (on average)
311  if (leftBound == rightBound) return;
312 
313  if (actDim == fPeriod) actDim = 0;
314  for (std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator i=leftBound; i!=rightBound; i++) {
315  i->first = i->second->GetValue( actDim );
316  }
317 
318  std::sort( leftBound, rightBound );
319 
320  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftTemp = leftBound;
321  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightTemp = rightBound;
322 
323  // meet in the middle
324  while (true) {
325  rightTemp--;
326  if (rightTemp == leftTemp ) {
327  break;
328  }
329  leftTemp++;
330  if (leftTemp == rightTemp) {
331  break;
332  }
333  }
334 
335  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator mid = leftTemp;
336  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator midTemp = mid;
337 
338  if (mid!=leftBound) midTemp--;
339 
340  while (mid != leftBound && mid->second->GetValue( actDim ) == midTemp->second->GetValue( actDim )) {
341  mid--;
342  midTemp--;
343  }
344 
345  Insert( mid->second );
346 
347  // Print(std::cout);
348  // std::cout << std::endl << std::endl;
349 
350  NormalizeTree( leftBound, mid, actDim+1 );
351  mid++;
352  // Print(std::cout);
353  // std::cout << std::endl << std::endl;
354  NormalizeTree( mid, rightBound, actDim+1 );
355 
356 
357  return;
358 }
359 
360 ////////////////////////////////////////////////////////////////////////////////
361 /// Normalisation of tree
362 
364 {
365  SetNormalize( kFALSE );
366  Clear( NULL );
367  this->SetRoot(NULL);
368  NormalizeTree( fNormalizeTreeTable.begin(), fNormalizeTreeTable.end(), 0 );
369 }
370 
371 ////////////////////////////////////////////////////////////////////////////////
372 /// clear nodes
373 
375 {
376  BinarySearchTreeNode* currentNode = (BinarySearchTreeNode*)(n == NULL ? this->GetRoot() : n);
377 
378  if (currentNode->GetLeft() != 0) Clear( currentNode->GetLeft() );
379  if (currentNode->GetRight() != 0) Clear( currentNode->GetRight() );
380 
381  if (n != NULL) delete n;
382 
383  return;
384 }
385 
386 ////////////////////////////////////////////////////////////////////////////////
387 /// search the whole tree and add up all weights of events that
388 /// lie within the given volume
389 
391  std::vector<const BinarySearchTreeNode*>* events )
392 {
393  return SearchVolume( this->GetRoot(), volume, 0, events );
394 }
395 
396 ////////////////////////////////////////////////////////////////////////////////
397 /// recursively walk through the daughter nodes and add up all weights of events that
398 /// lie within the given volume
399 
401  std::vector<const BinarySearchTreeNode*>* events )
402 {
403  if (t==NULL) return 0; // Are we at an outer leave?
404 
406 
407  Double_t count = 0.0;
408  if (InVolume( st->GetEventV(), volume )) {
409  count += st->GetWeight();
410  if (NULL != events) events->push_back( st );
411  }
412  if (st->GetLeft()==NULL && st->GetRight()==NULL) {
413 
414  return count; // Are we at an outer leave?
415  }
416 
417  Bool_t tl, tr;
418  Int_t d = depth%this->GetPeriode();
419  if (d != st->GetSelector()) {
420  Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
421  << d << " != " << "node "<< st->GetSelector() << Endl;
422  }
423  tl = (*(volume->fLower))[d] < st->GetEventV()[d]; // Should we descend left?
424  tr = (*(volume->fUpper))[d] >= st->GetEventV()[d]; // Should we descend right?
425 
426  if (tl) count += SearchVolume( st->GetLeft(), volume, (depth+1), events );
427  if (tr) count += SearchVolume( st->GetRight(), volume, (depth+1), events );
428 
429  return count;
430 }
431 
432 Bool_t TMVA::BinarySearchTree::InVolume(const std::vector<Float_t>& event, Volume* volume ) const
433 {
434  // test if the data points are in the given volume
435 
436  Bool_t result = false;
437  for (UInt_t ivar=0; ivar< fPeriod; ivar++) {
438  result = ( (*(volume->fLower))[ivar] < event[ivar] &&
439  (*(volume->fUpper))[ivar] >= event[ivar] );
440  if (!result) break;
441  }
442  return result;
443 }
444 
445 ////////////////////////////////////////////////////////////////////////////////
446 /// calculate basic statistics (mean, rms for each variable)
447 
449 {
450  if (fStatisticsIsValid) return;
451 
452  BinarySearchTreeNode * currentNode = (BinarySearchTreeNode*)n;
453 
454  // default, start at the tree top, then descend recursively
455  if (n == NULL) {
456  fSumOfWeights = 0;
457  for (Int_t sb=0; sb<2; sb++) {
458  fNEventsW[sb] = 0;
459  fMeans[sb] = std::vector<Float_t>(fPeriod);
460  fRMS[sb] = std::vector<Float_t>(fPeriod);
461  fMin[sb] = std::vector<Float_t>(fPeriod);
462  fMax[sb] = std::vector<Float_t>(fPeriod);
463  fSum[sb] = std::vector<Double_t>(fPeriod);
464  fSumSq[sb] = std::vector<Double_t>(fPeriod);
465  for (UInt_t j=0; j<fPeriod; j++) {
466  fMeans[sb][j] = fRMS[sb][j] = fSum[sb][j] = fSumSq[sb][j] = 0;
467  fMin[sb][j] = FLT_MAX;
468  fMax[sb][j] = -FLT_MAX;
469  }
470  }
471  currentNode = (BinarySearchTreeNode*) this->GetRoot();
472  if (currentNode == NULL) return; // no root-node
473  }
474 
475  const std::vector<Float_t> & evtVec = currentNode->GetEventV();
476  Double_t weight = currentNode->GetWeight();
477 // Int_t type = currentNode->IsSignal();
478 // Int_t type = currentNode->IsSignal() ? 0 : 1;
479  Int_t type = Int_t(currentNode->GetClass())== Types::kSignal ? 0 : 1;
480 
481  fNEventsW[type] += weight;
482  fSumOfWeights += weight;
483 
484  for (UInt_t j=0; j<fPeriod; j++) {
485  Float_t val = evtVec[j];
486  fSum[type][j] += val*weight;
487  fSumSq[type][j] += val*val*weight;
488  if (val < fMin[type][j]) fMin[type][j] = val;
489  if (val > fMax[type][j]) fMax[type][j] = val;
490  }
491 
492  if ( (currentNode->GetLeft() != NULL) ) CalcStatistics( currentNode->GetLeft() );
493  if ( (currentNode->GetRight() != NULL) ) CalcStatistics( currentNode->GetRight() );
494 
495  if (n == NULL) { // i.e. the root node
496  for (Int_t sb=0; sb<2; sb++) {
497  for (UInt_t j=0; j<fPeriod; j++) {
498  if (fNEventsW[sb] == 0) { fMeans[sb][j] = fRMS[sb][j] = 0; continue; }
499  fMeans[sb][j] = fSum[sb][j]/fNEventsW[sb];
500  fRMS[sb][j] = TMath::Sqrt(fSumSq[sb][j]/fNEventsW[sb] - fMeans[sb][j]*fMeans[sb][j]);
501  }
502  }
503  fStatisticsIsValid = kTRUE;
504  }
505 
506  return;
507 }
508 
509 Int_t TMVA::BinarySearchTree::SearchVolumeWithMaxLimit( Volume *volume, std::vector<const BinarySearchTreeNode*>* events,
510  Int_t max_points )
511 {
512  // recursively walk through the daughter nodes and add up all weigths of events that
513  // lie within the given volume a maximum number of events can be given
514  if (this->GetRoot() == NULL) return 0; // Are we at an outer leave?
515 
516  std::queue< std::pair< const BinarySearchTreeNode*, Int_t > > queue;
517  std::pair< const BinarySearchTreeNode*, Int_t > st = std::make_pair( (const BinarySearchTreeNode*)this->GetRoot(), 0 );
518  queue.push( st );
519 
520  Int_t count = 0;
521 
522  while ( !queue.empty() ) {
523  st = queue.front(); queue.pop();
524 
525  if (count == max_points)
526  return count;
527 
528  if (InVolume( st.first->GetEventV(), volume )) {
529  count++;
530  if (NULL != events) events->push_back( st.first );
531  }
532 
533  Bool_t tl, tr;
534  Int_t d = st.second;
535  if ( d == Int_t(this->GetPeriode()) ) d = 0;
536 
537  if (d != st.first->GetSelector()) {
538  Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
539  << d << " != " << "node "<< st.first->GetSelector() << Endl;
540  }
541 
542  tl = (*(volume->fLower))[d] < st.first->GetEventV()[d] && st.first->GetLeft() != NULL; // Should we descend left?
543  tr = (*(volume->fUpper))[d] >= st.first->GetEventV()[d] && st.first->GetRight() != NULL; // Should we descend right?
544 
545  if (tl) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetLeft(), d+1 ) );
546  if (tr) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetRight(), d+1 ) );
547  }
548 
549  return count;
550 }
std::vector< Double_t > * fLower
Definition: Volume.h:75
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
virtual ~BinarySearchTree(void)
destructor
Bool_t InVolume(const std::vector< Float_t > &, Volume *) const
float Float_t
Definition: RtypesCore.h:53
const char * current
Definition: demos.C:12
std::vector< Double_t > * fUpper
Definition: Volume.h:76
Int_t SearchVolumeWithMaxLimit(TMVA::Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=0, Int_t=-1)
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
virtual void SetRight(Node *r)
Definition: Node.h:97
const Bool_t kFALSE
Definition: Rtypes.h:92
void NormalizeTree()
Normalisation of tree.
void SetDepth(UInt_t d)
Definition: Node.h:115
TClass * GetClass(T *)
Definition: TClass.h:554
Tools & gTools()
Definition: Tools.cxx:79
UInt_t GetDepth() const
Definition: Node.h:118
virtual Bool_t GoesLeft(const Event &) const =0
virtual void SetLeft(Node *l)
Definition: Node.h:96
int d
Definition: tornado.py:11
void Clear(TMVA::Node *n=0)
clear nodes
ClassImp(TMVA::BinarySearchTree) TMVA
default constructor
Double_t SearchVolume(Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=0)
search the whole tree and add up all weights of events that lie within the given volume ...
UInt_t GetNVariables() const
accessor to the number of variables
Definition: Event.cxx:303
Double_t GetSumOfWeights(void) const
return the sum of event (node) weights
void CalcStatistics(TMVA::Node *n=0)
calculate basic statistics (mean, rms for each variable)
static BinarySearchTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
re-create a new tree (decision tree or search tree) from XML
TThread * t[5]
Definition: threadsh1.C:13
unsigned int UInt_t
Definition: RtypesCore.h:42
void ReadAttr(void *node, const char *, T &value)
Definition: Tools.h:295
void Insert(const Event *)
insert a new "event" in the binary tree
virtual void ReadXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
read attributes from XML
Definition: BinaryTree.cxx:143
virtual void SetParent(Node *p)
Definition: Node.h:98
double Double_t
Definition: RtypesCore.h:55
int type
Definition: TGX11.cxx:120
virtual Node * GetRight() const
Definition: Node.h:92
virtual Bool_t GoesRight(const Event &) const =0
Double_t Fill(const std::vector< TMVA::Event * > &events, const std::vector< Int_t > &theVars, Int_t theType=-1)
create the search tree from the event collection using ONLY the variables specified in "theVars" ...
const std::vector< Float_t > & GetEventV() const
#define NULL
Definition: Rtypes.h:82
double result[121]
Double_t Sqrt(Double_t x)
Definition: TMath.h:464
MsgLogger & Log() const
Definition: BinaryTree.cxx:234
const Bool_t kTRUE
Definition: Rtypes.h:91
BinarySearchTreeNode * Search(Event *event) const
search the tree to find the node matching "event"
const Int_t n
Definition: legend1.C:16
Definition: math.cpp:60
void SetPos(char s)
Definition: Node.h:121
virtual Node * GetLeft() const
Definition: Node.h:91