26 #ifndef ROOT_TMVA_NodekNN
27 #define ROOT_TMVA_NodekNN
99 void Print(std::ostream& os,
const std::string &
offset =
"")
const;
187 return fEvent.GetWeight();
234 fVarDis(event.GetVar(mod)),
244 if (fNodeL)
delete fNodeL;
245 if (fNodeR)
delete fNodeR;
256 assert(fMod == depth % event.GetNVar() &&
"Wrong recursive depth in Node<>::Add");
258 const Float_t value =
event.GetVar(fMod);
264 if (value < fVarDis) {
267 return fNodeL->
Add(event, depth + 1);
270 fNodeL =
new Node<T>(
this, event, (depth + 1) % event.GetNVar());
276 return fNodeR->
Add(event, depth + 1);
279 fNodeR =
new Node<T>(
this, event, (depth + 1) % event.GetNVar());
298 os << offset <<
"-----------------------------------------------------------" << std::endl;
299 os << offset <<
"Node: mod " << fMod
301 <<
" with weight: " <<
GetWeight() << std::endl
305 os << offset <<
"Has left node " << std::endl;
308 os << offset <<
"Has right node" << std::endl;
312 os << offset <<
"PrInt_t left node " << std::endl;
313 fNodeL->Print(os, offset +
" ");
316 os << offset <<
"PrInt_t right node" << std::endl;
317 fNodeR->Print(os, offset +
" ");
320 if (!fNodeL && !fNodeR) {
340 if (!node || nfind < 1) {
350 if (!nlist.empty()) {
352 max_dist = nlist.back().second;
354 if (nlist.size() == nfind) {
359 if (value < node->GetVarMin() &&
371 if (nlist.size() < nfind) {
374 else if (nlist.size() == nfind) {
375 if (distance < max_dist) {
381 std::cerr <<
"TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
388 typename std::list<std::pair<const Node<T> *,
Float_t> >::iterator lit = nlist.begin();
391 for (; lit != nlist.end(); ++lit) {
392 if (distance < lit->second) {
400 nlist.insert(lit, std::pair<
const Node<T> *,
Float_t>(node, distance));
410 if (value < node->GetVarDis()) {
450 if (!node || !(nfind < 0.0)) {
460 if (!nlist.empty()) {
462 max_dist = nlist.back().second;
464 if (!(ncurr < nfind)) {
469 if (value < node->GetVarMin() &&
483 else if (!nlist.empty()) {
484 if (distance < max_dist) {
489 std::cerr <<
"TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
499 typename std::list<std::pair<const Node<T> *,
Float_t> >::iterator lit = nlist.begin();
502 for (; lit != nlist.end(); ++lit) {
503 if (distance < lit->second) {
510 lit = nlist.insert(lit, std::pair<
const Node<T> *,
Float_t>(node, distance));
512 for (; lit != nlist.end(); ++lit) {
514 if (!(ncurr < nfind)) {
520 if(lit != nlist.end())
522 nlist.erase(lit, nlist.end());
529 if (value < node->GetVarDis()) {
530 count +=
Find(nlist, node->
GetNodeL(), event, nfind, ncurr);
531 count +=
Find(nlist, node->
GetNodeR(), event, nfind, ncurr);
534 count +=
Find(nlist, node->
GetNodeR(), event, nfind, ncurr);
535 count +=
Find(nlist, node->
GetNodeL(), event, nfind, ncurr);
540 count +=
Find(nlist, node->
GetNodeL(), event, nfind, ncurr);
543 count +=
Find(nlist, node->
GetNodeR(), event, nfind, ncurr);
const Node * GetNodeR() const
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
const Node * Add(const T &event, UInt_t depth)
UInt_t Depth(const Node< T > *node)
const Node & operator=(const Node &)
Float_t GetVarMin() const
void SetNodeL(Node *node)
Double_t GetWeight(Double_t x) const
const Node * GetNodeP() const
void Print(std::ostream &os, const OptionType &opt)
const Node * GetNodeL() const
Float_t GetVarMax() const
Double_t GetWeight() const
const T & GetEvent() const
UInt_t Find(std::list< std::pair< const Node< T > *, Float_t > > &nlist, const Node< T > *node, const T &event, UInt_t nfind)
static Vc_ALWAYS_INLINE int_v max(const int_v &x, const int_v &y)
void SetNodeR(Node *node)
Float_t GetVarDis() const