ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
NodekNN.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Rustem Ospanov
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : Node *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * kd-tree (binary tree) template *
12  * *
13  * Author: *
14  * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15  * *
16  * Copyright (c) 2007: *
17  * CERN, Switzerland *
18  * MPI-K Heidelberg, Germany *
19  * U. of Texas at Austin, USA *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/LICENSE) *
24  **********************************************************************************/
25 
26 #ifndef ROOT_TMVA_NodekNN
27 #define ROOT_TMVA_NodekNN
28 
// C++
#include <algorithm>
#include <cassert>
#include <iostream>
#include <list>
#include <string>
#include <utility>

// ROOT
#ifndef ROOT_Rtypes
#include "Rtypes.h"
#endif
39 
40 //////////////////////////////////////////////////////////////////////////
41 // //
42 // kNN::Node //
43 // //
44 // This file contains binary tree and global function template //
45 // that searches tree for k-nearest neighbors //
46 // //
47 // Node class template parameter T has to provide these functions: //
48 // rtype GetVar(UInt_t) const; //
49 // - rtype is any type convertible to Float_t //
50 // UInt_t GetNVar(void) const; //
51 // rtype GetWeight(void) const; //
52 // - rtype is any type convertible to Double_t //
53 // //
54 // Find function template parameter T has to provide these functions: //
55 // (in addition to above requirements) //
56 // rtype GetDist(Float_t, UInt_t) const; //
57 // - rtype is any type convertible to Float_t //
58 // rtype GetDist(const T &) const; //
59 // - rtype is any type convertible to Float_t //
60 // //
61 // where T::GetDist(Float_t, UInt_t) <= T::GetDist(const T &) //
62 // for any pair of events and any variable number for these events //
63 // //
64 //////////////////////////////////////////////////////////////////////////
65 
66 namespace TMVA
67 {
68  namespace kNN
69  {
70  template <class T>
71  class Node
72  {
73 
74  public:
75 
76  Node(const Node *parent, const T &event, Int_t mod);
77  ~Node();
78 
79  const Node* Add(const T &event, UInt_t depth);
80 
81  void SetNodeL(Node *node);
82  void SetNodeR(Node *node);
83 
84  const T& GetEvent() const;
85 
86  const Node* GetNodeL() const;
87  const Node* GetNodeR() const;
88  const Node* GetNodeP() const;
89 
90  Double_t GetWeight() const;
91 
92  Float_t GetVarDis() const;
93  Float_t GetVarMin() const;
94  Float_t GetVarMax() const;
95 
96  UInt_t GetMod() const;
97 
98  void Print() const;
99  void Print(std::ostream& os, const std::string &offset = "") const;
100 
101  private:
102 
103  // these methods are private and not implemented by design
104  // use provided public constructor for all uses of this template class
105  Node();
106  Node(const Node &);
107  const Node& operator=(const Node &);
108 
109  private:
110 
111  const Node* fNodeP;
112 
115 
116  const T fEvent;
117 
119 
122 
123  const UInt_t fMod;
124  };
125 
126  // recursive search for k-nearest neighbor: k = nfind
127  template<class T>
128  UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
129  const Node<T> *node, const T &event, UInt_t nfind);
130 
131  // recursive search for k-nearest neighbor
132  // find k events with sum of event weights >= nfind
133  template<class T>
134  UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
135  const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
136 
137  // recursively travel upward until root node is reached
138  template <class T>
139  UInt_t Depth(const Node<T> *node);
140 
141  // prInt_t node content and content of its children
142  //template <class T>
143  //std::ostream& operator<<(std::ostream& os, const Node<T> &node);
144 
145  //
146  // Inlined functions for Node template
147  //
148  template <class T>
149  inline void Node<T>::SetNodeL(Node<T> *node)
150  {
151  fNodeL = node;
152  }
153 
154  template <class T>
155  inline void Node<T>::SetNodeR(Node<T> *node)
156  {
157  fNodeR = node;
158  }
159 
160  template <class T>
161  inline const T& Node<T>::GetEvent() const
162  {
163  return fEvent;
164  }
165 
166  template <class T>
167  inline const Node<T>* Node<T>::GetNodeL() const
168  {
169  return fNodeL;
170  }
171 
172  template <class T>
173  inline const Node<T>* Node<T>::GetNodeR() const
174  {
175  return fNodeR;
176  }
177 
178  template <class T>
179  inline const Node<T>* Node<T>::GetNodeP() const
180  {
181  return fNodeP;
182  }
183 
184  template <class T>
186  {
187  return fEvent.GetWeight();
188  }
189 
190  template <class T>
192  {
193  return fVarDis;
194  }
195 
196  template <class T>
198  {
199  return fVarMin;
200  }
201 
202  template <class T>
204  {
205  return fVarMax;
206  }
207 
208  template <class T>
209  inline UInt_t Node<T>::GetMod() const
210  {
211  return fMod;
212  }
213 
214  //
215  // Inlined global function(s)
216  //
217  template <class T>
218  inline UInt_t Depth(const Node<T> *node)
219  {
220  if (!node) return 0;
221  else return Depth(node->GetNodeP()) + 1;
222  }
223 
224  } // end of kNN namespace
225 } // end of TMVA namespace
226 
227 //-------------------------------------------------------------------------------------------
228 template<class T>
229 TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod)
230  :fNodeP(parent),
231  fNodeL(0),
232  fNodeR(0),
233  fEvent(event),
234  fVarDis(event.GetVar(mod)),
235  fVarMin(fVarDis),
236  fVarMax(fVarDis),
237  fMod(mod)
238 {}
239 
240 //-------------------------------------------------------------------------------------------
241 template<class T>
243 {
244  if (fNodeL) delete fNodeL;
245  if (fNodeR) delete fNodeR;
246 }
247 
248 //-------------------------------------------------------------------------------------------
249 template<class T>
250 const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(const T &event, const UInt_t depth)
251 {
252  // This is Node member function that adds a new node to a binary tree.
253  // each node contains maximum and minimum values of splitting variable
254  // left or right nodes are added based on value of splitting variable
255 
256  assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
257 
258  const Float_t value = event.GetVar(fMod);
259 
260  fVarMin = std::min(fVarMin, value);
261  fVarMax = std::max(fVarMax, value);
262 
263  Node<T> *node = 0;
264  if (value < fVarDis) {
265  if (fNodeL)
266  {
267  return fNodeL->Add(event, depth + 1);
268  }
269  else {
270  fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
271  node = fNodeL;
272  }
273  }
274  else {
275  if (fNodeR) {
276  return fNodeR->Add(event, depth + 1);
277  }
278  else {
279  fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
280  node = fNodeR;
281  }
282  }
283 
284  return node;
285 }
286 
287 //-------------------------------------------------------------------------------------------
288 template<class T>
290 {
291  Print(std::cout);
292 }
293 
294 //-------------------------------------------------------------------------------------------
295 template<class T>
296 void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
297 {
298  os << offset << "-----------------------------------------------------------" << std::endl;
299  os << offset << "Node: mod " << fMod
300  << " at " << fVarDis
301  << " with weight: " << GetWeight() << std::endl
302  << offset << fEvent;
303 
304  if (fNodeL) {
305  os << offset << "Has left node " << std::endl;
306  }
307  if (fNodeR) {
308  os << offset << "Has right node" << std::endl;
309  }
310 
311  if (fNodeL) {
312  os << offset << "PrInt_t left node " << std::endl;
313  fNodeL->Print(os, offset + " ");
314  }
315  if (fNodeR) {
316  os << offset << "PrInt_t right node" << std::endl;
317  fNodeR->Print(os, offset + " ");
318  }
319 
320  if (!fNodeL && !fNodeR) {
321  os << std::endl;
322  }
323 }
324 
325 //-------------------------------------------------------------------------------------------
326 template<class T>
327 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
328  const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
329 {
330  // This is a global templated function that searches for k-nearest neighbors.
331  // list contains k or less nodes that are closest to event.
332  // only nodes with positive weights are added to list.
333  // each node contains maximum and minimum values of splitting variable
334  // for all its children - this range is checked to avoid descending into
335  // nodes that are defintely outside current minimum neighbourhood.
336  //
337  // This function should be modified with care.
338  //
339 
340  if (!node || nfind < 1) {
341  return 0;
342  }
343 
344  const Float_t value = event.GetVar(node->GetMod());
345 
346  if (node->GetWeight() > 0.0) {
347 
348  Float_t max_dist = 0.0;
349 
350  if (!nlist.empty()) {
351 
352  max_dist = nlist.back().second;
353 
354  if (nlist.size() == nfind) {
355  if (value > node->GetVarMax() &&
356  event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
357  return 0;
358  }
359  if (value < node->GetVarMin() &&
360  event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
361  return 0;
362  }
363  }
364  }
365 
366  const Float_t distance = event.GetDist(node->GetEvent());
367 
368  Bool_t insert_this = kFALSE;
369  Bool_t remove_back = kFALSE;
370 
371  if (nlist.size() < nfind) {
372  insert_this = kTRUE;
373  }
374  else if (nlist.size() == nfind) {
375  if (distance < max_dist) {
376  insert_this = kTRUE;
377  remove_back = kTRUE;
378  }
379  }
380  else {
381  std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
382  return 1;
383  }
384 
385  if (insert_this) {
386  // need typename keyword because qualified dependent names
387  // are not valid types unless preceded by 'typename'.
388  typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
389 
390  // find a place where current node should be inserted
391  for (; lit != nlist.end(); ++lit) {
392  if (distance < lit->second) {
393  break;
394  }
395  else {
396  continue;
397  }
398  }
399 
400  nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
401 
402  if (remove_back) {
403  nlist.pop_back();
404  }
405  }
406  }
407 
408  UInt_t count = 1;
409  if (node->GetNodeL() && node->GetNodeR()) {
410  if (value < node->GetVarDis()) {
411  count += Find(nlist, node->GetNodeL(), event, nfind);
412  count += Find(nlist, node->GetNodeR(), event, nfind);
413  }
414  else {
415  count += Find(nlist, node->GetNodeR(), event, nfind);
416  count += Find(nlist, node->GetNodeL(), event, nfind);
417  }
418  }
419  else {
420  if (node->GetNodeL()) {
421  count += Find(nlist, node->GetNodeL(), event, nfind);
422  }
423  if (node->GetNodeR()) {
424  count += Find(nlist, node->GetNodeR(), event, nfind);
425  }
426  }
427 
428  return count;
429 }
430 
431 
432 //-------------------------------------------------------------------------------------------
433 template<class T>
434 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
435  const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
436 {
437  // This is a global templated function that searches for k-nearest neighbors.
438  // list contains all nodes that are closest to event
439  // and have sum of event weights >= nfind.
440  // Only nodes with positive weights are added to list.
441  // Requirement for used classes:
442  // - each node contains maximum and minimum values of splitting variable
443  // for all its children
444  // - min and max range is checked to avoid descending into
445  // nodes that are defintely outside current minimum neighbourhood.
446  //
447  // This function should be modified with care.
448  //
449 
450  if (!node || !(nfind < 0.0)) {
451  return 0;
452  }
453 
454  const Float_t value = event.GetVar(node->GetMod());
455 
456  if (node->GetWeight() > 0.0) {
457 
458  Float_t max_dist = 0.0;
459 
460  if (!nlist.empty()) {
461 
462  max_dist = nlist.back().second;
463 
464  if (!(ncurr < nfind)) {
465  if (value > node->GetVarMax() &&
466  event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
467  return 0;
468  }
469  if (value < node->GetVarMin() &&
470  event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
471  return 0;
472  }
473  }
474  }
475 
476  const Float_t distance = event.GetDist(node->GetEvent());
477 
478  Bool_t insert_this = kFALSE;
479 
480  if (ncurr < nfind) {
481  insert_this = kTRUE;
482  }
483  else if (!nlist.empty()) {
484  if (distance < max_dist) {
485  insert_this = kTRUE;
486  }
487  }
488  else {
489  std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
490  return 1;
491  }
492 
493  if (insert_this) {
494  // (re)compute total current weight when inserting a new node
495  ncurr = 0;
496 
497  // need typename keyword because qualified dependent names
498  // are not valid types unless preceded by 'typename'.
499  typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
500 
501  // find a place where current node should be inserted
502  for (; lit != nlist.end(); ++lit) {
503  if (distance < lit->second) {
504  break;
505  }
506 
507  ncurr += lit -> first -> GetWeight();
508  }
509 
510  lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
511 
512  for (; lit != nlist.end(); ++lit) {
513  ncurr += lit -> first -> GetWeight();
514  if (!(ncurr < nfind)) {
515  ++lit;
516  break;
517  }
518  }
519 
520  if(lit != nlist.end())
521  {
522  nlist.erase(lit, nlist.end());
523  }
524  }
525  }
526 
527  UInt_t count = 1;
528  if (node->GetNodeL() && node->GetNodeR()) {
529  if (value < node->GetVarDis()) {
530  count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
531  count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
532  }
533  else {
534  count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
535  count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
536  }
537  }
538  else {
539  if (node->GetNodeL()) {
540  count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
541  }
542  if (node->GetNodeR()) {
543  count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
544  }
545  }
546 
547  return count;
548 }
549 
550 #endif
551 
const Node * GetNodeR() const
Definition: NodekNN.h:173
UInt_t GetMod() const
Definition: NodekNN.h:209
const Node * fNodeP
Definition: NodekNN.h:111
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
Definition: vector.h:433
float Float_t
Definition: RtypesCore.h:53
tuple offset
Definition: tree.py:93
#define assert(cond)
Definition: unittest.h:542
Node * fNodeR
Definition: NodekNN.h:114
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
const Node * Add(const T &event, UInt_t depth)
Definition: NodekNN.h:250
UInt_t Depth(const Node< T > *node)
Definition: NodekNN.h:218
TTree * T
const Node & operator=(const Node &)
Float_t GetVarMin() const
Definition: NodekNN.h:197
const T fEvent
Definition: NodekNN.h:116
void SetNodeL(Node *node)
Definition: NodekNN.h:149
const Float_t fVarDis
Definition: NodekNN.h:118
unsigned int UInt_t
Definition: RtypesCore.h:42
bool first
Definition: line3Dfit.C:48
Double_t GetWeight(Double_t x) const
const UInt_t fMod
Definition: NodekNN.h:123
const Node * GetNodeP() const
Definition: NodekNN.h:179
Float_t fVarMin
Definition: NodekNN.h:120
void Print() const
Definition: NodekNN.h:289
void Print(std::ostream &os, const OptionType &opt)
double Double_t
Definition: RtypesCore.h:55
const Node * GetNodeL() const
Definition: NodekNN.h:167
Float_t GetVarMax() const
Definition: NodekNN.h:203
Double_t GetWeight() const
Definition: NodekNN.h:185
Float_t fVarMax
Definition: NodekNN.h:121
const T & GetEvent() const
Definition: NodekNN.h:161
UInt_t Find(std::list< std::pair< const Node< T > *, Float_t > > &nlist, const Node< T > *node, const T &event, UInt_t nfind)
static Vc_ALWAYS_INLINE int_v max(const int_v &x, const int_v &y)
Definition: vector.h:440
void SetNodeR(Node *node)
Definition: NodekNN.h:155
Node * fNodeL
Definition: NodekNN.h:113
Float_t GetVarDis() const
Definition: NodekNN.h:191
const Bool_t kTRUE
Definition: Rtypes.h:91