26#ifndef ROOT_TMVA_NodekNN
27#define ROOT_TMVA_NodekNN
94 void Print(std::ostream& os,
const std::string &offset =
"")
const;
182 return fEvent.GetWeight();
229 fVarDis(event.GetVar(mod)),
239 if (fNodeL)
delete fNodeL;
240 if (fNodeR)
delete fNodeR;
252 assert(fMod == depth % event.GetNVar() &&
"Wrong recursive depth in Node<>::Add");
254 const Float_t value =
event.GetVar(fMod);
256 fVarMin = std::min(fVarMin, value);
257 fVarMax = std::max(fVarMax, value);
260 if (value < fVarDis) {
263 return fNodeL->
Add(event, depth + 1);
266 fNodeL =
new Node<T>(
this, event, (depth + 1) % event.GetNVar());
272 return fNodeR->
Add(event, depth + 1);
275 fNodeR =
new Node<T>(
this, event, (depth + 1) % event.GetNVar());
294 os << offset <<
"-----------------------------------------------------------" << std::endl;
295 os << offset <<
"Node: mod " << fMod
297 <<
" with weight: " << GetWeight() << std::endl
301 os << offset <<
"Has left node " << std::endl;
304 os << offset <<
"Has right node" << std::endl;
308 os << offset <<
"PrInt_t left node " << std::endl;
309 fNodeL->Print(os, offset +
" ");
312 os << offset <<
"PrInt_t right node" << std::endl;
313 fNodeR->Print(os, offset +
" ");
316 if (!fNodeL && !fNodeR) {
335 if (!node || nfind < 1) {
345 if (!nlist.empty()) {
347 max_dist = nlist.back().second;
349 if (nlist.size() == nfind) {
354 if (value < node->GetVarMin() &&
366 if (nlist.size() < nfind) {
369 else if (nlist.size() == nfind) {
370 if (distance < max_dist) {
376 std::cerr <<
"TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
383 typename std::list<std::pair<const Node<T> *,
Float_t> >::iterator lit = nlist.begin();
386 for (; lit != nlist.end(); ++lit) {
387 if (distance < lit->
second) {
395 nlist.insert(lit, std::pair<
const Node<T> *,
Float_t>(node, distance));
405 if (value < node->GetVarDis()) {
444 if (!node || !(nfind < 0.0)) {
454 if (!nlist.empty()) {
456 max_dist = nlist.back().second;
458 if (!(ncurr < nfind)) {
463 if (value < node->GetVarMin() &&
477 else if (!nlist.empty()) {
478 if (distance < max_dist) {
483 std::cerr <<
"TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
493 typename std::list<std::pair<const Node<T> *,
Float_t> >::iterator lit = nlist.begin();
496 for (; lit != nlist.end(); ++lit) {
497 if (distance < lit->
second) {
501 ncurr += lit ->
first -> GetWeight();
504 lit = nlist.insert(lit, std::pair<
const Node<T> *,
Float_t>(node, distance));
506 for (; lit != nlist.end(); ++lit) {
507 ncurr += lit ->
first -> GetWeight();
508 if (!(ncurr < nfind)) {
514 if(lit != nlist.end())
516 nlist.erase(lit, nlist.end());
523 if (value < node->GetVarDis()) {
524 count +=
Find(nlist, node->
GetNodeL(), event, nfind, ncurr);
525 count +=
Find(nlist, node->
GetNodeR(), event, nfind, ncurr);
528 count +=
Find(nlist, node->
GetNodeR(), event, nfind, ncurr);
529 count +=
Find(nlist, node->
GetNodeL(), event, nfind, ncurr);
534 count +=
Find(nlist, node->
GetNodeL(), event, nfind, ncurr);
537 count +=
Find(nlist, node->
GetNodeR(), event, nfind, ncurr);
This file contains a binary tree and a global function template that searches the tree for k-nearest neighbors...
Float_t GetVarMin() const
void SetNodeL(Node *node)
void SetNodeR(Node *node)
Double_t GetWeight() const
Float_t GetVarDis() const
Float_t GetVarMax() const
const Node * GetNodeL() const
const Node * GetNodeP() const
const Node & operator=(const Node &)
const Node * GetNodeR() const
const Node * Add(const T &event, UInt_t depth)
This is the Node member function that adds a new node to the binary tree.
const T & GetEvent() const
void Print(std::ostream &os, const OptionType &opt)
static constexpr double second
UInt_t Find(std::list< std::pair< const Node< T > *, Float_t > > &nlist, const Node< T > *node, const T &event, UInt_t nfind)
UInt_t Depth(const Node< T > *node)
Abstract ClassifierFactory template that handles arbitrary types.