ROOT  6.06/09
Reference Guide
NodekNN.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Rustem Ospanov
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : Node *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * kd-tree (binary tree) template *
12  * *
13  * Author: *
14  * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15  * *
16  * Copyright (c) 2007: *
17  * CERN, Switzerland *
18  * MPI-K Heidelberg, Germany *
19  * U. of Texas at Austin, USA *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/LICENSE) *
24  **********************************************************************************/
25 
26 #ifndef ROOT_TMVA_NodekNN
27 #define ROOT_TMVA_NodekNN
28 
29 // C++
30 #include <list>
31 #include <string>
32 #include <iostream>
33 
34 // ROOT
35 #ifndef ROOT_Rtypes
36 #include "Rtypes.h"
37 #endif
38 
39 //////////////////////////////////////////////////////////////////////////
40 // //
41 // kNN::Node //
42 // //
43 // This file contains binary tree and global function template //
44 // that searches tree for k-nearest neighbors //
45 // //
46 // Node class template parameter T has to provide these functions: //
47 // rtype GetVar(UInt_t) const; //
48 // - rtype is any type convertible to Float_t //
49 // UInt_t GetNVar(void) const; //
50 // rtype GetWeight(void) const; //
51 // - rtype is any type convertible to Double_t //
52 // //
53 // Find function template parameter T has to provide these functions: //
54 // (in addition to above requirements) //
55 // rtype GetDist(Float_t, UInt_t) const; //
56 // - rtype is any type convertible to Float_t //
57 // rtype GetDist(const T &) const; //
58 // - rtype is any type convertible to Float_t //
59 // //
60 // where T::GetDist(Float_t, UInt_t) <= T::GetDist(const T &) //
61 // for any pair of events and any variable number for these events //
62 // //
63 //////////////////////////////////////////////////////////////////////////
64 
65 namespace TMVA
66 {
67  namespace kNN
68  {
69  template <class T>
70  class Node
71  {
72 
73  public:
74 
75  Node(const Node *parent, const T &event, Int_t mod);
76  ~Node();
77 
78  const Node* Add(const T &event, UInt_t depth);
79 
80  void SetNodeL(Node *node);
81  void SetNodeR(Node *node);
82 
83  const T& GetEvent() const;
84 
85  const Node* GetNodeL() const;
86  const Node* GetNodeR() const;
87  const Node* GetNodeP() const;
88 
89  Double_t GetWeight() const;
90 
91  Float_t GetVarDis() const;
92  Float_t GetVarMin() const;
93  Float_t GetVarMax() const;
94 
95  UInt_t GetMod() const;
96 
97  void Print() const;
98  void Print(std::ostream& os, const std::string &offset = "") const;
99 
100  private:
101 
102  // these methods are private and not implemented by design
103  // use provided public constructor for all uses of this template class
104  Node();
105  Node(const Node &);
106  const Node& operator=(const Node &);
107 
108  private:
109 
110  const Node* fNodeP;
111 
114 
115  const T fEvent;
116 
118 
121 
122  const UInt_t fMod;
123  };
124 
125  // recursive search for the k nearest neighbors of event (k = nfind); matches are stored in nlist as (node, distance) pairs
126  template<class T>
127  UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
128  const Node<T> *node, const T &event, UInt_t nfind);
129 
130  // recursive search for nearest neighbors selected by weight:
131  // collect the closest events until the sum of their weights is >= nfind (ncurr is the weight collected so far)
132  template<class T>
133  UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
134  const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
135 
136  // walk upward through parent pointers and count the nodes passed, including node itself (0 for a null node)
137  template <class T>
138  UInt_t Depth(const Node<T> *node);
139 
140  // print node content and content of its children
141  //template <class T>
142  //std::ostream& operator<<(std::ostream& os, const Node<T> &node);
143 
144  //
145  // Inlined functions for Node template
146  //
147  template <class T>
148  inline void Node<T>::SetNodeL(Node<T> *node)
149  {
150  fNodeL = node;
151  }
152 
153  template <class T>
154  inline void Node<T>::SetNodeR(Node<T> *node)
155  {
156  fNodeR = node;
157  }
158 
159  template <class T>
160  inline const T& Node<T>::GetEvent() const
161  {
162  return fEvent;
163  }
164 
165  template <class T>
166  inline const Node<T>* Node<T>::GetNodeL() const
167  {
168  return fNodeL;
169  }
170 
171  template <class T>
172  inline const Node<T>* Node<T>::GetNodeR() const
173  {
174  return fNodeR;
175  }
176 
177  template <class T>
178  inline const Node<T>* Node<T>::GetNodeP() const
179  {
180  return fNodeP;
181  }
182 
183  template <class T>
185  {
186  return fEvent.GetWeight();
187  }
188 
189  template <class T>
191  {
192  return fVarDis;
193  }
194 
195  template <class T>
197  {
198  return fVarMin;
199  }
200 
201  template <class T>
203  {
204  return fVarMax;
205  }
206 
207  template <class T>
208  inline UInt_t Node<T>::GetMod() const
209  {
210  return fMod;
211  }
212 
213  //
214  // Inlined global function(s)
215  //
216  template <class T>
217  inline UInt_t Depth(const Node<T> *node)
218  {
219  if (!node) return 0;
220  else return Depth(node->GetNodeP()) + 1;
221  }
222 
223  } // end of kNN namespace
224 } // end of TMVA namespace
225 
226 //-------------------------------------------------------------------------------------------
227 template<class T>
228 TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod)
229  :fNodeP(parent),
230  fNodeL(0),
231  fNodeR(0),
232  fEvent(event),
233  fVarDis(event.GetVar(mod)),
234  fVarMin(fVarDis),
235  fVarMax(fVarDis),
236  fMod(mod)
237 {}
238 
239 //-------------------------------------------------------------------------------------------
240 template<class T>
242 {
243  if (fNodeL) delete fNodeL;
244  if (fNodeR) delete fNodeR;
245 }
246 
247 //-------------------------------------------------------------------------------------------
248 template<class T>
249 const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(const T &event, const UInt_t depth)
250 {
251  // This is Node member function that adds a new node to a binary tree.
252  // each node contains maximum and minimum values of splitting variable
253  // left or right nodes are added based on value of splitting variable
254 
255  assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
256 
257  const Float_t value = event.GetVar(fMod);
258 
259  fVarMin = std::min(fVarMin, value);
260  fVarMax = std::max(fVarMax, value);
261 
262  Node<T> *node = 0;
263  if (value < fVarDis) {
264  if (fNodeL)
265  {
266  return fNodeL->Add(event, depth + 1);
267  }
268  else {
269  fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
270  node = fNodeL;
271  }
272  }
273  else {
274  if (fNodeR) {
275  return fNodeR->Add(event, depth + 1);
276  }
277  else {
278  fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
279  node = fNodeR;
280  }
281  }
282 
283  return node;
284 }
285 
286 //-------------------------------------------------------------------------------------------
287 template<class T>
289 {
290  Print(std::cout);
291 }
292 
293 //-------------------------------------------------------------------------------------------
294 template<class T>
295 void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
296 {
297  os << offset << "-----------------------------------------------------------" << std::endl;
298  os << offset << "Node: mod " << fMod
299  << " at " << fVarDis
300  << " with weight: " << GetWeight() << std::endl
301  << offset << fEvent;
302 
303  if (fNodeL) {
304  os << offset << "Has left node " << std::endl;
305  }
306  if (fNodeR) {
307  os << offset << "Has right node" << std::endl;
308  }
309 
310  if (fNodeL) {
311  os << offset << "PrInt_t left node " << std::endl;
312  fNodeL->Print(os, offset + " ");
313  }
314  if (fNodeR) {
315  os << offset << "PrInt_t right node" << std::endl;
316  fNodeR->Print(os, offset + " ");
317  }
318 
319  if (!fNodeL && !fNodeR) {
320  os << std::endl;
321  }
322 }
323 
324 //-------------------------------------------------------------------------------------------
325 template<class T>
326 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
327  const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
328 {
329  // This is a global templated function that searches for k-nearest neighbors.
330  // list contains k or less nodes that are closest to event.
331  // only nodes with positive weights are added to list.
332  // each node contains maximum and minimum values of splitting variable
333  // for all its children - this range is checked to avoid descending into
334  // nodes that are defintely outside current minimum neighbourhood.
335  //
336  // This function should be modified with care.
337  //
338 
339  if (!node || nfind < 1) {
340  return 0;
341  }
342 
343  const Float_t value = event.GetVar(node->GetMod());
344 
345  if (node->GetWeight() > 0.0) {
346 
347  Float_t max_dist = 0.0;
348 
349  if (!nlist.empty()) {
350 
351  max_dist = nlist.back().second;
352 
353  if (nlist.size() == nfind) {
354  if (value > node->GetVarMax() &&
355  event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
356  return 0;
357  }
358  if (value < node->GetVarMin() &&
359  event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
360  return 0;
361  }
362  }
363  }
364 
365  const Float_t distance = event.GetDist(node->GetEvent());
366 
367  Bool_t insert_this = kFALSE;
368  Bool_t remove_back = kFALSE;
369 
370  if (nlist.size() < nfind) {
371  insert_this = kTRUE;
372  }
373  else if (nlist.size() == nfind) {
374  if (distance < max_dist) {
375  insert_this = kTRUE;
376  remove_back = kTRUE;
377  }
378  }
379  else {
380  std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
381  return 1;
382  }
383 
384  if (insert_this) {
385  // need typename keyword because qualified dependent names
386  // are not valid types unless preceded by 'typename'.
387  typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
388 
389  // find a place where current node should be inserted
390  for (; lit != nlist.end(); ++lit) {
391  if (distance < lit->second) {
392  break;
393  }
394  else {
395  continue;
396  }
397  }
398 
399  nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
400 
401  if (remove_back) {
402  nlist.pop_back();
403  }
404  }
405  }
406 
407  UInt_t count = 1;
408  if (node->GetNodeL() && node->GetNodeR()) {
409  if (value < node->GetVarDis()) {
410  count += Find(nlist, node->GetNodeL(), event, nfind);
411  count += Find(nlist, node->GetNodeR(), event, nfind);
412  }
413  else {
414  count += Find(nlist, node->GetNodeR(), event, nfind);
415  count += Find(nlist, node->GetNodeL(), event, nfind);
416  }
417  }
418  else {
419  if (node->GetNodeL()) {
420  count += Find(nlist, node->GetNodeL(), event, nfind);
421  }
422  if (node->GetNodeR()) {
423  count += Find(nlist, node->GetNodeR(), event, nfind);
424  }
425  }
426 
427  return count;
428 }
429 
430 
431 //-------------------------------------------------------------------------------------------
432 template<class T>
433 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
434  const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
435 {
436  // This is a global templated function that searches for k-nearest neighbors.
437  // list contains all nodes that are closest to event
438  // and have sum of event weights >= nfind.
439  // Only nodes with positive weights are added to list.
440  // Requirement for used classes:
441  // - each node contains maximum and minimum values of splitting variable
442  // for all its children
443  // - min and max range is checked to avoid descending into
444  // nodes that are defintely outside current minimum neighbourhood.
445  //
446  // This function should be modified with care.
447  //
448 
449  if (!node || !(nfind < 0.0)) {
450  return 0;
451  }
452 
453  const Float_t value = event.GetVar(node->GetMod());
454 
455  if (node->GetWeight() > 0.0) {
456 
457  Float_t max_dist = 0.0;
458 
459  if (!nlist.empty()) {
460 
461  max_dist = nlist.back().second;
462 
463  if (!(ncurr < nfind)) {
464  if (value > node->GetVarMax() &&
465  event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
466  return 0;
467  }
468  if (value < node->GetVarMin() &&
469  event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
470  return 0;
471  }
472  }
473  }
474 
475  const Float_t distance = event.GetDist(node->GetEvent());
476 
477  Bool_t insert_this = kFALSE;
478 
479  if (ncurr < nfind) {
480  insert_this = kTRUE;
481  }
482  else if (!nlist.empty()) {
483  if (distance < max_dist) {
484  insert_this = kTRUE;
485  }
486  }
487  else {
488  std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
489  return 1;
490  }
491 
492  if (insert_this) {
493  // (re)compute total current weight when inserting a new node
494  ncurr = 0;
495 
496  // need typename keyword because qualified dependent names
497  // are not valid types unless preceded by 'typename'.
498  typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
499 
500  // find a place where current node should be inserted
501  for (; lit != nlist.end(); ++lit) {
502  if (distance < lit->second) {
503  break;
504  }
505 
506  ncurr += lit -> first -> GetWeight();
507  }
508 
509  lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
510 
511  for (; lit != nlist.end(); ++lit) {
512  ncurr += lit -> first -> GetWeight();
513  if (!(ncurr < nfind)) {
514  ++lit;
515  break;
516  }
517  }
518 
519  if(lit != nlist.end())
520  {
521  nlist.erase(lit, nlist.end());
522  }
523  }
524  }
525 
526  UInt_t count = 1;
527  if (node->GetNodeL() && node->GetNodeR()) {
528  if (value < node->GetVarDis()) {
529  count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
530  count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
531  }
532  else {
533  count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
534  count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
535  }
536  }
537  else {
538  if (node->GetNodeL()) {
539  count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
540  }
541  if (node->GetNodeR()) {
542  count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
543  }
544  }
545 
546  return count;
547 }
548 
549 #endif
550 
const Node * GetNodeR() const
Definition: NodekNN.h:172
UInt_t GetMod() const
Definition: NodekNN.h:208
const Node * fNodeP
Definition: NodekNN.h:110
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
Definition: vector.h:433
float Float_t
Definition: RtypesCore.h:53
double T(double x)
Definition: ChebyshevPol.h:34
#define assert(cond)
Definition: unittest.h:542
Node * fNodeR
Definition: NodekNN.h:113
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
const Node * Add(const T &event, UInt_t depth)
Definition: NodekNN.h:249
UInt_t Depth(const Node< T > *node)
Definition: NodekNN.h:217
const Node & operator=(const Node &)
Float_t GetVarMin() const
Definition: NodekNN.h:196
const T fEvent
Definition: NodekNN.h:115
void SetNodeL(Node *node)
Definition: NodekNN.h:148
const Float_t fVarDis
Definition: NodekNN.h:117
unsigned int UInt_t
Definition: RtypesCore.h:42
Double_t GetWeight(Double_t x) const
const UInt_t fMod
Definition: NodekNN.h:122
const Node * GetNodeP() const
Definition: NodekNN.h:178
Float_t fVarMin
Definition: NodekNN.h:119
void Print() const
Definition: NodekNN.h:288
void Print(std::ostream &os, const OptionType &opt)
double Double_t
Definition: RtypesCore.h:55
const Node * GetNodeL() const
Definition: NodekNN.h:166
Float_t GetVarMax() const
Definition: NodekNN.h:202
Double_t GetWeight() const
Definition: NodekNN.h:184
Float_t fVarMax
Definition: NodekNN.h:120
const T & GetEvent() const
Definition: NodekNN.h:160
UInt_t Find(std::list< std::pair< const Node< T > *, Float_t > > &nlist, const Node< T > *node, const T &event, UInt_t nfind)
static Vc_ALWAYS_INLINE int_v max(const int_v &x, const int_v &y)
Definition: vector.h:440
Abstract ClassifierFactory template that handles arbitrary types.
void SetNodeR(Node *node)
Definition: NodekNN.h:154
Node * fNodeL
Definition: NodekNN.h:112
Float_t GetVarDis() const
Definition: NodekNN.h:190
const Bool_t kTRUE
Definition: Rtypes.h:91