Logo ROOT  
Reference Guide
NodekNN.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Rustem Ospanov
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Node *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * kd-tree (binary tree) template *
12 * *
13 * Author: *
14 * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15 * *
16 * Copyright (c) 2007: *
17 * CERN, Switzerland *
18 * MPI-K Heidelberg, Germany *
19 * U. of Texas at Austin, USA *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * (http://tmva.sourceforge.net/LICENSE) *
24 **********************************************************************************/
25
26#ifndef ROOT_TMVA_NodekNN
27#define ROOT_TMVA_NodekNN
28
// C++
#include <algorithm>
#include <cassert>
#include <iostream>
#include <list>
#include <string>
#include <utility>

// ROOT
#include "Rtypes.h"
37
/*! \class TMVA::kNN::Node
\ingroup TMVA
This file contains a binary (kd-) tree node template and global function templates
that search the tree for the k-nearest neighbors.

The Node class template parameter T has to provide these functions:
  rtype GetVar(UInt_t) const;
   - rtype is any type convertible to Float_t
  UInt_t GetNVar(void) const;
  rtype GetWeight(void) const;
   - rtype is any type convertible to Double_t
Node::Print additionally requires std::ostream& operator<<(std::ostream&, const T&).

The Find function template parameter T has to provide these functions
(in addition to the above requirements):
  rtype GetDist(Float_t, UInt_t) const;
   - rtype is any type convertible to Float_t
  rtype GetDist(const T &) const;
   - rtype is any type convertible to Float_t

  where T::GetDist(Float_t, UInt_t) <= T::GetDist(const T &)
  for any pair of events and any variable number for these events,
  i.e. the distance along a single variable is a lower bound of the full
  distance; Find relies on this to skip subtrees that cannot hold closer events.
*/
60
61namespace TMVA
62{
63 namespace kNN
64 {
65 template <class T>
66 class Node
67 {
68
69 public:
70
71 Node(const Node *parent, const T &event, Int_t mod);
73
74 const Node* Add(const T &event, UInt_t depth);
75
76 void SetNodeL(Node *node);
77 void SetNodeR(Node *node);
78
79 const T& GetEvent() const;
80
81 const Node* GetNodeL() const;
82 const Node* GetNodeR() const;
83 const Node* GetNodeP() const;
84
86
90
91 UInt_t GetMod() const;
92
93 void Print() const;
94 void Print(std::ostream& os, const std::string &offset = "") const;
95
96 private:
97
98 // these methods are private and not implemented by design
99 // use provided public constructor for all uses of this template class
101 Node(const Node &);
102 const Node& operator=(const Node &);
103
104 private:
105
106 const Node* fNodeP;
107
110
111 const T fEvent;
112
114
117
119 };
120
121 // recursive search for k-nearest neighbor: k = nfind
122 template<class T>
123 UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
124 const Node<T> *node, const T &event, UInt_t nfind);
125
126 // recursive search for k-nearest neighbor
127 // find k events with sum of event weights >= nfind
128 template<class T>
129 UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
130 const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
131
132 // recursively travel upward until root node is reached
133 template <class T>
134 UInt_t Depth(const Node<T> *node);
135
136 // prInt_t node content and content of its children
137 //template <class T>
138 //std::ostream& operator<<(std::ostream& os, const Node<T> &node);
139
140 //
141 // Inlined functions for Node template
142 //
143 template <class T>
144 inline void Node<T>::SetNodeL(Node<T> *node)
145 {
146 fNodeL = node;
147 }
148
149 template <class T>
150 inline void Node<T>::SetNodeR(Node<T> *node)
151 {
152 fNodeR = node;
153 }
154
155 template <class T>
156 inline const T& Node<T>::GetEvent() const
157 {
158 return fEvent;
159 }
160
161 template <class T>
162 inline const Node<T>* Node<T>::GetNodeL() const
163 {
164 return fNodeL;
165 }
166
167 template <class T>
168 inline const Node<T>* Node<T>::GetNodeR() const
169 {
170 return fNodeR;
171 }
172
173 template <class T>
174 inline const Node<T>* Node<T>::GetNodeP() const
175 {
176 return fNodeP;
177 }
178
179 template <class T>
181 {
182 return fEvent.GetWeight();
183 }
184
185 template <class T>
187 {
188 return fVarDis;
189 }
190
191 template <class T>
193 {
194 return fVarMin;
195 }
196
197 template <class T>
199 {
200 return fVarMax;
201 }
202
203 template <class T>
204 inline UInt_t Node<T>::GetMod() const
205 {
206 return fMod;
207 }
208
209 //
210 // Inlined global function(s)
211 //
212 template <class T>
213 inline UInt_t Depth(const Node<T> *node)
214 {
215 if (!node) return 0;
216 else return Depth(node->GetNodeP()) + 1;
217 }
218
219 } // end of kNN namespace
220} // end of TMVA namespace
221
222////////////////////////////////////////////////////////////////////////////////
223template<class T>
224TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod)
225:fNodeP(parent),
226 fNodeL(0),
227 fNodeR(0),
228 fEvent(event),
229 fVarDis(event.GetVar(mod)),
230 fVarMin(fVarDis),
231 fVarMax(fVarDis),
232 fMod(mod)
233{}
234
235////////////////////////////////////////////////////////////////////////////////
236template<class T>
238{
239 if (fNodeL) delete fNodeL;
240 if (fNodeR) delete fNodeR;
241}
242
243////////////////////////////////////////////////////////////////////////////////
244/// This is Node member function that adds a new node to a binary tree.
245/// each node contains maximum and minimum values of splitting variable
246/// left or right nodes are added based on value of splitting variable
247
248template<class T>
249const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(const T &event, const UInt_t depth)
250{
251
252 assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
253
254 const Float_t value = event.GetVar(fMod);
255
256 fVarMin = std::min(fVarMin, value);
257 fVarMax = std::max(fVarMax, value);
258
259 Node<T> *node = 0;
260 if (value < fVarDis) {
261 if (fNodeL)
262 {
263 return fNodeL->Add(event, depth + 1);
264 }
265 else {
266 fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
267 node = fNodeL;
268 }
269 }
270 else {
271 if (fNodeR) {
272 return fNodeR->Add(event, depth + 1);
273 }
274 else {
275 fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
276 node = fNodeR;
277 }
278 }
279
280 return node;
281}
282
283////////////////////////////////////////////////////////////////////////////////
284template<class T>
286{
287 Print(std::cout);
288}
289
290////////////////////////////////////////////////////////////////////////////////
291template<class T>
292void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
293{
294 os << offset << "-----------------------------------------------------------" << std::endl;
295 os << offset << "Node: mod " << fMod
296 << " at " << fVarDis
297 << " with weight: " << GetWeight() << std::endl
298 << offset << fEvent;
299
300 if (fNodeL) {
301 os << offset << "Has left node " << std::endl;
302 }
303 if (fNodeR) {
304 os << offset << "Has right node" << std::endl;
305 }
306
307 if (fNodeL) {
308 os << offset << "PrInt_t left node " << std::endl;
309 fNodeL->Print(os, offset + " ");
310 }
311 if (fNodeR) {
312 os << offset << "PrInt_t right node" << std::endl;
313 fNodeR->Print(os, offset + " ");
314 }
315
316 if (!fNodeL && !fNodeR) {
317 os << std::endl;
318 }
319}
320
321////////////////////////////////////////////////////////////////////////////////
322/// This is a global templated function that searches for k-nearest neighbors.
323/// list contains k or less nodes that are closest to event.
324/// only nodes with positive weights are added to list.
325/// each node contains maximum and minimum values of splitting variable
326/// for all its children - this range is checked to avoid descending into
327/// nodes that are definitely outside current minimum neighbourhood.
328///
329/// This function should be modified with care.
330
331template<class T>
332UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
333 const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
334{
335 if (!node || nfind < 1) {
336 return 0;
337 }
338
339 const Float_t value = event.GetVar(node->GetMod());
340
341 if (node->GetWeight() > 0.0) {
342
343 Float_t max_dist = 0.0;
344
345 if (!nlist.empty()) {
346
347 max_dist = nlist.back().second;
348
349 if (nlist.size() == nfind) {
350 if (value > node->GetVarMax() &&
351 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
352 return 0;
353 }
354 if (value < node->GetVarMin() &&
355 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
356 return 0;
357 }
358 }
359 }
360
361 const Float_t distance = event.GetDist(node->GetEvent());
362
363 Bool_t insert_this = kFALSE;
364 Bool_t remove_back = kFALSE;
365
366 if (nlist.size() < nfind) {
367 insert_this = kTRUE;
368 }
369 else if (nlist.size() == nfind) {
370 if (distance < max_dist) {
371 insert_this = kTRUE;
372 remove_back = kTRUE;
373 }
374 }
375 else {
376 std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
377 return 1;
378 }
379
380 if (insert_this) {
381 // need typename keyword because qualified dependent names
382 // are not valid types unless preceded by 'typename'.
383 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
384
385 // find a place where current node should be inserted
386 for (; lit != nlist.end(); ++lit) {
387 if (distance < lit->second) {
388 break;
389 }
390 else {
391 continue;
392 }
393 }
394
395 nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
396
397 if (remove_back) {
398 nlist.pop_back();
399 }
400 }
401 }
402
403 UInt_t count = 1;
404 if (node->GetNodeL() && node->GetNodeR()) {
405 if (value < node->GetVarDis()) {
406 count += Find(nlist, node->GetNodeL(), event, nfind);
407 count += Find(nlist, node->GetNodeR(), event, nfind);
408 }
409 else {
410 count += Find(nlist, node->GetNodeR(), event, nfind);
411 count += Find(nlist, node->GetNodeL(), event, nfind);
412 }
413 }
414 else {
415 if (node->GetNodeL()) {
416 count += Find(nlist, node->GetNodeL(), event, nfind);
417 }
418 if (node->GetNodeR()) {
419 count += Find(nlist, node->GetNodeR(), event, nfind);
420 }
421 }
422
423 return count;
424}
425
426////////////////////////////////////////////////////////////////////////////////
427/// This is a global templated function that searches for k-nearest neighbors.
428/// list contains all nodes that are closest to event
429/// and have sum of event weights >= nfind.
430/// Only nodes with positive weights are added to list.
431/// Requirement for used classes:
432/// - each node contains maximum and minimum values of splitting variable
433/// for all its children
434/// - min and max range is checked to avoid descending into
435/// nodes that are definitely outside current minimum neighbourhood.
436///
437/// This function should be modified with care.
438
439template<class T>
440UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
441 const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
442{
443
444 if (!node || !(nfind < 0.0)) {
445 return 0;
446 }
447
448 const Float_t value = event.GetVar(node->GetMod());
449
450 if (node->GetWeight() > 0.0) {
451
452 Float_t max_dist = 0.0;
453
454 if (!nlist.empty()) {
455
456 max_dist = nlist.back().second;
457
458 if (!(ncurr < nfind)) {
459 if (value > node->GetVarMax() &&
460 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
461 return 0;
462 }
463 if (value < node->GetVarMin() &&
464 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
465 return 0;
466 }
467 }
468 }
469
470 const Float_t distance = event.GetDist(node->GetEvent());
471
472 Bool_t insert_this = kFALSE;
473
474 if (ncurr < nfind) {
475 insert_this = kTRUE;
476 }
477 else if (!nlist.empty()) {
478 if (distance < max_dist) {
479 insert_this = kTRUE;
480 }
481 }
482 else {
483 std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
484 return 1;
485 }
486
487 if (insert_this) {
488 // (re)compute total current weight when inserting a new node
489 ncurr = 0;
490
491 // need typename keyword because qualified dependent names
492 // are not valid types unless preceded by 'typename'.
493 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
494
495 // find a place where current node should be inserted
496 for (; lit != nlist.end(); ++lit) {
497 if (distance < lit->second) {
498 break;
499 }
500
501 ncurr += lit -> first -> GetWeight();
502 }
503
504 lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
505
506 for (; lit != nlist.end(); ++lit) {
507 ncurr += lit -> first -> GetWeight();
508 if (!(ncurr < nfind)) {
509 ++lit;
510 break;
511 }
512 }
513
514 if(lit != nlist.end())
515 {
516 nlist.erase(lit, nlist.end());
517 }
518 }
519 }
520
521 UInt_t count = 1;
522 if (node->GetNodeL() && node->GetNodeR()) {
523 if (value < node->GetVarDis()) {
524 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
525 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
526 }
527 else {
528 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
529 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
530 }
531 }
532 else {
533 if (node->GetNodeL()) {
534 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
535 }
536 if (node->GetNodeR()) {
537 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
538 }
539 }
540
541 return count;
542}
543
544#endif
545
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
This file contains a binary tree and global function templates that search the tree for k-nearest neighbors...
Definition: NodekNN.h:67
Float_t GetVarMin() const
Definition: NodekNN.h:192
void SetNodeL(Node *node)
Definition: NodekNN.h:144
const Node * fNodeP
Definition: NodekNN.h:106
void SetNodeR(Node *node)
Definition: NodekNN.h:150
Double_t GetWeight() const
Definition: NodekNN.h:180
Float_t GetVarDis() const
Definition: NodekNN.h:186
Float_t GetVarMax() const
Definition: NodekNN.h:198
void Print(std::ostream &os, const std::string &offset="") const
Definition: NodekNN.h:292
const Node * GetNodeL() const
Definition: NodekNN.h:162
Node(const Node *parent, const T &event, Int_t mod)
Definition: NodekNN.h:224
Float_t fVarMin
Definition: NodekNN.h:115
const Node * GetNodeP() const
Definition: NodekNN.h:174
const Float_t fVarDis
Definition: NodekNN.h:113
const UInt_t fMod
Definition: NodekNN.h:118
void Print() const
Definition: NodekNN.h:285
Node * fNodeR
Definition: NodekNN.h:109
const Node & operator=(const Node &)
Node(const Node &)
const Node * GetNodeR() const
Definition: NodekNN.h:168
UInt_t GetMod() const
Definition: NodekNN.h:204
const T fEvent
Definition: NodekNN.h:111
const Node * Add(const T &event, UInt_t depth)
This is Node member function that adds a new node to a binary tree.
Definition: NodekNN.h:249
Node * fNodeL
Definition: NodekNN.h:108
Float_t fVarMax
Definition: NodekNN.h:116
const T & GetEvent() const
Definition: NodekNN.h:156
double T(double x)
Definition: ChebyshevPol.h:34
void Print(std::ostream &os, const OptionType &opt)
static constexpr double second
UInt_t Find(std::list< std::pair< const Node< T > *, Float_t > > &nlist, const Node< T > *node, const T &event, UInt_t nfind)
UInt_t Depth(const Node< T > *node)
Definition: NodekNN.h:213
create variable transformations
Definition: first.py:1