Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
BinarySearchTree.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : BinarySearchTree *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header file for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18 * *
19 * Copyright (c) 2005: *
20 * CERN, Switzerland *
21 * U. of Victoria, Canada *
22 * MPI-K Heidelberg, Germany *
23 * LAPP, Annecy, France *
24 * *
25 * Redistribution and use in source and binary forms, with or without *
26 * modification, are permitted according to the terms listed in LICENSE *
27 * (http://tmva.sourceforge.net/LICENSE) *
28 * *
29 **********************************************************************************/
30
31/*! \class TMVA::BinarySearchTree
32\ingroup TMVA
33
34A simple Binary search tree including a volume search method.
35
36*/
37
38#include <stdexcept>
39#include <cstdlib>
40#include <queue>
41#include <algorithm>
42
43#include "TMath.h"
44
45#include "TMatrixDBase.h"
46#include "TTree.h"
47
48#include "TMVA/MsgLogger.h"
49#include "TMVA/MethodBase.h"
50#include "TMVA/Tools.h"
51#include "TMVA/Event.h"
53
54#include "TMVA/BinaryTree.h"
55#include "TMVA/Types.h"
56#include "TMVA/Node.h"
57
59
60////////////////////////////////////////////////////////////////////////////////
61/// default constructor
62
65 fPeriod ( 1 ),
66 fCurrentDepth( 0 ),
67 fStatisticsIsValid( kFALSE ),
68 fSumOfWeights( 0 ),
69 fCanNormalize( kFALSE )
70{
71 fNEventsW[0]=fNEventsW[1]=0.;
72}
73
74////////////////////////////////////////////////////////////////////////////////
75/// copy constructor that creates a true copy, i.e. a completely independent tree
76
78 : BinaryTree(),
79 fPeriod ( b.fPeriod ),
80 fCurrentDepth( 0 ),
81 fStatisticsIsValid( kFALSE ),
82 fSumOfWeights( b.fSumOfWeights ),
83 fCanNormalize( kFALSE )
84{
85 fNEventsW[0]=fNEventsW[1]=0.;
86 Log() << kFATAL << " Copy constructor not implemented yet " << Endl;
87}
88
89////////////////////////////////////////////////////////////////////////////////
90/// destructor
91
93{
94 for(std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator pIt = fNormalizeTreeTable.begin();
95 pIt != fNormalizeTreeTable.end(); ++pIt) {
96 delete pIt->second;
97 }
98}
99
100////////////////////////////////////////////////////////////////////////////////
101/// re-create a new tree (decision tree or search tree) from XML
102
104 std::string type("");
105 gTools().ReadAttr(node,"type", type);
107 bt->ReadXML( node, tmva_Version_Code );
108 return bt;
109}
110
111////////////////////////////////////////////////////////////////////////////////
112/// insert a new "event" in the binary tree
113
115{
116 fCurrentDepth=0;
117 fStatisticsIsValid = kFALSE;
118
119 if (this->GetRoot() == NULL) { // If the list is empty...
120 this->SetRoot( new BinarySearchTreeNode(event)); //Make the new node the root.
121 // have to use "s" for start as "r" for "root" would be the same as "r" for "right"
122 this->GetRoot()->SetPos('s');
123 this->GetRoot()->SetDepth(0);
124 fNNodes = 1;
125 fSumOfWeights = event->GetWeight();
126 ((BinarySearchTreeNode*)this->GetRoot())->SetSelector((UInt_t)0);
127 this->SetPeriode(event->GetNVariables());
128 }
129 else {
130 // sanity check:
131 if (event->GetNVariables() != (UInt_t)this->GetPeriode()) {
132 Log() << kFATAL << "<Insert> event vector length != Periode specified in Binary Tree" << Endl
133 << "--- event size: " << event->GetNVariables() << " Periode: " << this->GetPeriode() << Endl
134 << "--- and all this when trying filling the "<<fNNodes+1<<"th Node" << Endl;
135 }
136 // insert a new node at the propper position
137 this->Insert(event, this->GetRoot());
138 }
139
140 // normalise the tree to speed up searches
141 if (fCanNormalize) fNormalizeTreeTable.push_back( std::make_pair(0.0,new const Event(*event)) );
142}
143
144////////////////////////////////////////////////////////////////////////////////
145/// private internal function to insert a event (node) at the proper position
146
148 Node *node )
149{
150 fCurrentDepth++;
151 fStatisticsIsValid = kFALSE;
152
153 if (node->GoesLeft(*event)){ // If the adding item is less than the current node's data...
154 if (node->GetLeft() != NULL){ // If there is a left node...
155 // Add the new event to the left node
156 this->Insert(event, node->GetLeft());
157 }
158 else { // If there is not a left node...
159 // Make the new node for the new event
161 fNNodes++;
162 fSumOfWeights += event->GetWeight();
163 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
164 current->SetParent(node); // Set the new node's previous node.
165 current->SetPos('l');
166 current->SetDepth( node->GetDepth() + 1 );
167 node->SetLeft(current); // Make it the left node of the current one.
168 }
169 }
170 else if (node->GoesRight(*event)) { // If the adding item is less than or equal to the current node's data...
171 if (node->GetRight() != NULL) { // If there is a right node...
172 // Add the new node to it.
173 this->Insert(event, node->GetRight());
174 }
175 else { // If there is not a right node...
176 // Make the new node.
178 fNNodes++;
179 fSumOfWeights += event->GetWeight();
180 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
181 current->SetParent(node); // Set the new node's previous node.
182 current->SetPos('r');
183 current->SetDepth( node->GetDepth() + 1 );
184 node->SetRight(current); // Make it the left node of the current one.
185 }
186 }
187 else Log() << kFATAL << "<Insert> neither left nor right :)" << Endl;
188}
189
190////////////////////////////////////////////////////////////////////////////////
191///search the tree to find the node matching "event"
192
194{
195 return this->Search( event, this->GetRoot() );
196}
197
198////////////////////////////////////////////////////////////////////////////////
199/// Private, recursive, function for searching.
200
202{
203 if (node != NULL) { // If the node is not NULL...
204 // If we have found the node...
205 if (((BinarySearchTreeNode*)(node))->EqualsMe(*event))
206 return (BinarySearchTreeNode*)node; // Return it
207 if (node->GoesLeft(*event)) // If the node's data is greater than the search item...
208 return this->Search(event, node->GetLeft()); //Search the left node.
209 else //If the node's data is less than the search item...
210 return this->Search(event, node->GetRight()); //Search the right node.
211 }
212 else return NULL; //If the node is NULL, return NULL.
213}
214
215////////////////////////////////////////////////////////////////////////////////
216/// return the sum of event (node) weights
217
219{
220 if (fSumOfWeights <= 0) {
221 Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
222 << " I call CalcStatistics which hopefully fixes things"
223 << Endl;
224 }
225 if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
226
227 return fSumOfWeights;
228}
229
230////////////////////////////////////////////////////////////////////////////////
231/// return the sum of event (node) weights
232
234{
235 if (fSumOfWeights <= 0) {
236 Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
237 << " I call CalcStatistics which hopefully fixes things"
238 << Endl;
239 }
240 if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
241
242 return fNEventsW[ ( theType == Types::kSignal) ? 0 : 1 ];
243}
244
245////////////////////////////////////////////////////////////////////////////////
246/// create the search tree from the event collection
247/// using ONLY the variables specified in "theVars"
248
249Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, const std::vector<Int_t>& theVars,
250 Int_t theType )
251{
252 fPeriod = theVars.size();
253 return Fill(events, theType);
254}
255
256////////////////////////////////////////////////////////////////////////////////
257/// create the search tree from the events in a TTree
258/// using ALL the variables specified included in the Event
259
260Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, Int_t theType )
261{
262 UInt_t n=events.size();
263
264 if (fSumOfWeights != 0) {
265 Log() << kWARNING
266 << "You are filling a search three that is not empty.. "
267 << " do you know what you are doing?"
268 << Endl;
269 }
270 for (UInt_t ievt=0; ievt<n; ievt++) {
271 // insert event into binary tree
272 if (theType == -1 || (Int_t(events[ievt]->GetClass()) == theType) ) {
273 this->Insert( events[ievt] );
274 fSumOfWeights += events[ievt]->GetWeight();
275 }
276 } // end of event loop
277 CalcStatistics(0);
278
279 return fSumOfWeights;
280}
281
282////////////////////////////////////////////////////////////////////////////////
283/// normalises the binary-search tree to reduce the branch length and hence speed up the
284/// search procedure (on average).
285
286void TMVA::BinarySearchTree::NormalizeTree ( std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftBound,
287 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightBound,
288 UInt_t actDim )
289{
290
291 if (leftBound == rightBound) return;
292
293 if (actDim == fPeriod) actDim = 0;
294 for (std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator i=leftBound; i!=rightBound; ++i) {
295 i->first = i->second->GetValue( actDim );
296 }
297
298 std::sort( leftBound, rightBound );
299
300 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftTemp = leftBound;
301 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightTemp = rightBound;
302
303 // meet in the middle
304 while (true) {
305 --rightTemp;
306 if (rightTemp == leftTemp ) {
307 break;
308 }
309 ++leftTemp;
310 if (leftTemp == rightTemp) {
311 break;
312 }
313 }
314
315 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator mid = leftTemp;
316 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator midTemp = mid;
317
318 if (mid!=leftBound)--midTemp;
319
320 while (mid != leftBound && mid->second->GetValue( actDim ) == midTemp->second->GetValue( actDim )) {
321 --mid;
322 --midTemp;
323 }
324
325 Insert( mid->second );
326
327 // Print(std::cout);
328 // std::cout << std::endl << std::endl;
329
330 NormalizeTree( leftBound, mid, actDim+1 );
331 ++mid;
332 // Print(std::cout);
333 // std::cout << std::endl << std::endl;
334 NormalizeTree( mid, rightBound, actDim+1 );
335
336
337 return;
338}
339
340////////////////////////////////////////////////////////////////////////////////
341/// Normalisation of tree
342
344{
345 SetNormalize( kFALSE );
346 Clear( NULL );
347 this->SetRoot(NULL);
348 NormalizeTree( fNormalizeTreeTable.begin(), fNormalizeTreeTable.end(), 0 );
349}
350
351////////////////////////////////////////////////////////////////////////////////
352/// clear nodes
353
355{
356 BinarySearchTreeNode* currentNode = (BinarySearchTreeNode*)(n == NULL ? this->GetRoot() : n);
357
358 if (currentNode->GetLeft() != 0) Clear( currentNode->GetLeft() );
359 if (currentNode->GetRight() != 0) Clear( currentNode->GetRight() );
360
361 if (n != NULL) delete n;
362
363 return;
364}
365
366////////////////////////////////////////////////////////////////////////////////
367/// search the whole tree and add up all weights of events that
368/// lie within the given volume
369
371 std::vector<const BinarySearchTreeNode*>* events )
372{
373 return SearchVolume( this->GetRoot(), volume, 0, events );
374}
375
376////////////////////////////////////////////////////////////////////////////////
377/// recursively walk through the daughter nodes and add up all weights of events that
378/// lie within the given volume
379
381 std::vector<const BinarySearchTreeNode*>* events )
382{
383 if (t==NULL) return 0; // Are we at an outer leave?
384
386
387 Double_t count = 0.0;
388 if (InVolume( st->GetEventV(), volume )) {
389 count += st->GetWeight();
390 if (NULL != events) events->push_back( st );
391 }
392 if (st->GetLeft()==NULL && st->GetRight()==NULL) {
393
394 return count; // Are we at an outer leave?
395 }
396
397 Bool_t tl, tr;
398 Int_t d = depth%this->GetPeriode();
399 if (d != st->GetSelector()) {
400 Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
401 << d << " != " << "node "<< st->GetSelector() << Endl;
402 }
403 tl = (*(volume->fLower))[d] < st->GetEventV()[d]; // Should we descend left?
404 tr = (*(volume->fUpper))[d] >= st->GetEventV()[d]; // Should we descend right?
405
406 if (tl) count += SearchVolume( st->GetLeft(), volume, (depth+1), events );
407 if (tr) count += SearchVolume( st->GetRight(), volume, (depth+1), events );
408
409 return count;
410}
411
412////////////////////////////////////////////////////////////////////////////////
413/// test if the data points are in the given volume
414
415Bool_t TMVA::BinarySearchTree::InVolume(const std::vector<Float_t>& event, Volume* volume ) const
416{
417
418 Bool_t result = false;
419 for (UInt_t ivar=0; ivar< fPeriod; ivar++) {
420 result = ( (*(volume->fLower))[ivar] < event[ivar] &&
421 (*(volume->fUpper))[ivar] >= event[ivar] );
422 if (!result) break;
423 }
424 return result;
425}
426
427////////////////////////////////////////////////////////////////////////////////
428/// calculate basic statistics (mean, rms for each variable)
429
431{
432 if (fStatisticsIsValid) return;
433
435
436 // default, start at the tree top, then descend recursively
437 if (n == NULL) {
438 fSumOfWeights = 0;
439 for (Int_t sb=0; sb<2; sb++) {
440 fNEventsW[sb] = 0;
441 fMeans[sb] = std::vector<Float_t>(fPeriod);
442 fRMS[sb] = std::vector<Float_t>(fPeriod);
443 fMin[sb] = std::vector<Float_t>(fPeriod);
444 fMax[sb] = std::vector<Float_t>(fPeriod);
445 fSum[sb] = std::vector<Double_t>(fPeriod);
446 fSumSq[sb] = std::vector<Double_t>(fPeriod);
447 for (UInt_t j=0; j<fPeriod; j++) {
448 fMeans[sb][j] = fRMS[sb][j] = fSum[sb][j] = fSumSq[sb][j] = 0;
449 fMin[sb][j] = FLT_MAX;
450 fMax[sb][j] = -FLT_MAX;
451 }
452 }
453 currentNode = (BinarySearchTreeNode*) this->GetRoot();
454 if (currentNode == NULL) return; // no root-node
455 }
456
457 const std::vector<Float_t> & evtVec = currentNode->GetEventV();
458 Double_t weight = currentNode->GetWeight();
459 // Int_t type = currentNode->IsSignal();
460 // Int_t type = currentNode->IsSignal() ? 0 : 1;
461 Int_t type = Int_t(currentNode->GetClass())== Types::kSignal ? 0 : 1;
462
463 fNEventsW[type] += weight;
464 fSumOfWeights += weight;
465
466 for (UInt_t j=0; j<fPeriod; j++) {
467 Float_t val = evtVec[j];
468 fSum[type][j] += val*weight;
469 fSumSq[type][j] += val*val*weight;
470 if (val < fMin[type][j]) fMin[type][j] = val;
471 if (val > fMax[type][j]) fMax[type][j] = val;
472 }
473
474 if ( (currentNode->GetLeft() != NULL) ) CalcStatistics( currentNode->GetLeft() );
475 if ( (currentNode->GetRight() != NULL) ) CalcStatistics( currentNode->GetRight() );
476
477 if (n == NULL) { // i.e. the root node
478 for (Int_t sb=0; sb<2; sb++) {
479 for (UInt_t j=0; j<fPeriod; j++) {
480 if (fNEventsW[sb] == 0) { fMeans[sb][j] = fRMS[sb][j] = 0; continue; }
481 fMeans[sb][j] = fSum[sb][j]/fNEventsW[sb];
482 fRMS[sb][j] = TMath::Sqrt(fSumSq[sb][j]/fNEventsW[sb] - fMeans[sb][j]*fMeans[sb][j]);
483 }
484 }
485 fStatisticsIsValid = kTRUE;
486 }
487
488 return;
489}
490
491////////////////////////////////////////////////////////////////////////////////
492/// recursively walk through the daughter nodes and add up all weights of events that
493/// lie within the given volume a maximum number of events can be given
494
495Int_t TMVA::BinarySearchTree::SearchVolumeWithMaxLimit( Volume *volume, std::vector<const BinarySearchTreeNode*>* events,
496 Int_t max_points )
497{
498 if (this->GetRoot() == NULL) return 0; // Are we at an outer leave?
499
500 std::queue< std::pair< const BinarySearchTreeNode*, Int_t > > queue;
501 std::pair< const BinarySearchTreeNode*, Int_t > st = std::make_pair( (const BinarySearchTreeNode*)this->GetRoot(), 0 );
502 queue.push( st );
503
504 Int_t count = 0;
505
506 while ( !queue.empty() ) {
507 st = queue.front(); queue.pop();
508
509 if (count == max_points)
510 return count;
511
512 if (InVolume( st.first->GetEventV(), volume )) {
513 count++;
514 if (NULL != events) events->push_back( st.first );
515 }
516
517 Bool_t tl, tr;
518 Int_t d = st.second;
519 if ( d == Int_t(this->GetPeriode()) ) d = 0;
520
521 if (d != st.first->GetSelector()) {
522 Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
523 << d << " != " << "node "<< st.first->GetSelector() << Endl;
524 }
525
526 tl = (*(volume->fLower))[d] < st.first->GetEventV()[d] && st.first->GetLeft() != NULL; // Should we descend left?
527 tr = (*(volume->fUpper))[d] >= st.first->GetEventV()[d] && st.first->GetRight() != NULL; // Should we descend right?
528
529 if (tl) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetLeft(), d+1 ) );
530 if (tr) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetRight(), d+1 ) );
531 }
532
533 return count;
534}
#define d(i)
Definition RSha256.hxx:102
#define b(i)
Definition RSha256.hxx:100
int Int_t
Definition RtypesCore.h:45
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassImp(name)
Definition Rtypes.h:377
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
Node for the BinarySearch or Decision Trees.
void SetSelector(Short_t i)
set index of variable used for discrimination at this node
const std::vector< Float_t > & GetEventV() const
Short_t GetSelector() const
return index of variable used for discrimination at this node
A simple Binary search tree including a volume search method.
Int_t SearchVolumeWithMaxLimit(TMVA::Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=nullptr, Int_t=-1)
recursively walk through the daughter nodes and add up all weights of events that lie within the give...
BinarySearchTreeNode * Search(Event *event) const
search the tree to find the node matching "event"
Double_t SearchVolume(Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=nullptr)
search the whole tree and add up all weights of events that lie within the given volume
void CalcStatistics(TMVA::Node *n=nullptr)
calculate basic statistics (mean, rms for each variable)
void Clear(TMVA::Node *n=nullptr)
clear nodes
Bool_t InVolume(const std::vector< Float_t > &, Volume *) const
test if the data points are in the given volume
virtual ~BinarySearchTree(void)
destructor
Double_t Fill(const std::vector< TMVA::Event * > &events, const std::vector< Int_t > &theVars, Int_t theType=-1)
create the search tree from the event collection using ONLY the variables specified in "theVars"
Double_t fNEventsW[2]
Number of events per class, taking into account event weights.
void NormalizeTree()
Normalisation of tree.
Double_t GetSumOfWeights(void) const
return the sum of event (node) weights
BinarySearchTree(void)
default constructor
void Insert(const Event *)
insert a new "event" in the binary tree
static BinarySearchTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=262657)
re-create a new tree (decision tree or search tree) from XML
Base class for BinarySearch and Decision Trees.
Definition BinaryTree.h:62
MsgLogger & Log() const
virtual void ReadXML(void *node, UInt_t tmva_Version_Code=262657)
read attributes from XML
Node for the BinarySearch or Decision Trees.
Definition Node.h:58
virtual Node * GetLeft() const
Definition Node.h:89
virtual void SetRight(Node *r)
Definition Node.h:95
virtual Bool_t GoesLeft(const Event &) const =0
virtual void SetLeft(Node *l)
Definition Node.h:94
void SetPos(char s)
Definition Node.h:119
void SetDepth(UInt_t d)
Definition Node.h:113
virtual void SetParent(Node *p)
Definition Node.h:96
virtual Bool_t GoesRight(const Event &) const =0
UInt_t GetDepth() const
Definition Node.h:116
virtual Node * GetRight() const
Definition Node.h:90
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
@ kSignal
Never change this number - it is elsewhere assumed to be zero !
Definition Types.h:135
Volume for BinarySearchTree.
Definition Volume.h:47
std::vector< Double_t > * fLower
vector with lower volume dimensions
Definition Volume.h:75
std::vector< Double_t > * fUpper
vector with upper volume dimensions
Definition Volume.h:76
const Int_t n
Definition legend1.C:16
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:662