MethodKNN.cxx
// @(#)root/tmva $Id$
// Author: Rustem Ospanov

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodKNN                                                             *
 *                                                                                *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation                                                            *
 *                                                                                *
 * Author:                                                                        *
 *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
 *                                                                                *
 * Copyright (c) 2007:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Texas at Austin, USA                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (see tmva/doc/LICENSE)                                                        *
 **********************************************************************************/

/*! \class TMVA::MethodKNN
\ingroup TMVA

Analysis of k-nearest neighbor.

*/
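// Example booking of MethodKNN through the TMVA Factory, in the style of the
// TMVA tutorial macros (a minimal sketch; "factory" and "dataloader" are assumed
// to be an already configured TMVA::Factory and TMVA::DataLoader):
//
//    factory->BookMethod(dataloader, TMVA::Types::kKNN, "KNN",
//                        "nkNN=20:ScaleFrac=0.8:SigmaFact=1.0:Kernel=Gaus:UseKernel=F:UseWeight=T:!Trim");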

#include "TMVA/MethodKNN.h"

#include "TMVA/ClassifierFactory.h"
#include "TMVA/Configurable.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/Event.h"
#include "TMVA/LDA.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodBase.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Ranking.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"

#include "TFile.h"
#include "TMath.h"
#include "TTree.h"

#include <cmath>
#include <string>
#include <cstdlib>
#include <sstream> // std::stringstream, used for the XML weight I/O below

REGISTER_METHOD(KNN)

ClassImp(TMVA::MethodKNN);

////////////////////////////////////////////////////////////////////////////////
/// standard constructor

TMVA::MethodKNN::MethodKNN( const TString& jobName,
                            const TString& methodTitle,
                            DataSetInfo& theData,
                            const TString& theOption )
   : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// constructor from weight file

TMVA::MethodKNN::MethodKNN( DataSetInfo& theData,
                            const TString& theWeightFile)
   : TMVA::MethodBase( Types::kKNN, theData, theWeightFile)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::MethodKNN::~MethodKNN(void)
{
   if (fModule) delete fModule;
}

////////////////////////////////////////////////////////////////////////////////
/// MethodKNN options
///
/// - fnkNN = 20;          // number of k-nearest neighbors
/// - fBalanceDepth = 6;   // number of binary tree levels used for tree balancing
/// - fScaleFrac = 0.8;    // fraction of events used to compute variable width
/// - fSigmaFact = 1.0;    // scale factor for Gaussian sigma
/// - fKernel = use polynomial (1-x^3)^3 or Gaussian kernel
/// - fTrim = false;       // use equal number of signal and background events
/// - fUseKernel = false;  // use polynomial kernel weight function
/// - fUseWeight = true;   // count events using weights
/// - fUseLDA = false
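///
/// Example option string as it would be passed when booking the method
/// (a sketch mirroring the TMVA tutorial macros; the values shown are the
/// declared defaults):
///
///     "nkNN=20:ScaleFrac=0.8:SigmaFact=1.0:Kernel=Gaus:UseKernel=F:UseWeight=T:!Trim"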

void TMVA::MethodKNN::DeclareOptions()
{
   DeclareOptionRef(fnkNN = 20, "nkNN", "Number of k-nearest neighbors");
   DeclareOptionRef(fBalanceDepth = 6, "BalanceDepth", "Binary tree balance depth");
   DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used to compute variable width");
   DeclareOptionRef(fSigmaFact = 1.0, "SigmaFact", "Scale factor for sigma in Gaussian kernel");
   DeclareOptionRef(fKernel = "Gaus", "Kernel", "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
   DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
   DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
   DeclareOptionRef(fUseWeight = kTRUE, "UseWeight", "Use weight to count kNN events");
   DeclareOptionRef(fUseLDA = kFALSE, "UseLDA", "Use local linear discriminant - experimental feature");
}

////////////////////////////////////////////////////////////////////////////////
/// options that are used ONLY for the READER to ensure backward compatibility

void TMVA::MethodKNN::DeclareCompatibilityOptions() {
   MethodBase::DeclareCompatibilityOptions();
   DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
}

////////////////////////////////////////////////////////////////////////////////
/// process the options specified by the user

void TMVA::MethodKNN::ProcessOptions()
{
   if (!(fnkNN > 0)) {
      fnkNN = 10;
      Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
   }
   if (fScaleFrac < 0.0) {
      fScaleFrac = 0.0;
      Log() << kWARNING << "ScaleFrac cannot be negative: set ScaleFrac = " << fScaleFrac << Endl;
   }
   if (fScaleFrac > 1.0) {
      fScaleFrac = 1.0;
   }
   if (!(fBalanceDepth > 0)) {
      fBalanceDepth = 6;
      Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
   }

   Log() << kVERBOSE
         << "kNN options:\n"
         << "  kNN = " << fnkNN << "\n"
         << "  UseKernel = " << fUseKernel << "\n"
         << "  SigmaFact = " << fSigmaFact << "\n"
         << "  ScaleFrac = " << fScaleFrac << "\n"
         << "  Kernel = " << fKernel << "\n"
         << "  Trim = " << fTrim << "\n"
         << "  Optimize = " << fBalanceDepth << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// KNN can handle classification with 2 classes and regression with one regression-target

Bool_t TMVA::MethodKNN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   if (type == Types::kRegression) return kTRUE;
   return kFALSE;
}

////////////////////////////////////////////////////////////////////////////////
/// Initialization

void TMVA::MethodKNN::Init(void)
{
   // fScaleFrac <= 0.0 then do not scale input variables
   // fScaleFrac >= 1.0 then use all event coordinates to scale input variables

   fModule = new kNN::ModulekNN();
   fSumOfWeightsS = 0;
   fSumOfWeightsB = 0;
}

////////////////////////////////////////////////////////////////////////////////
/// create kNN
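/// Fills the kNN module with all stored events and builds the kd-tree.
/// The "metric" option (enabled when ScaleFrac > 0) turns on variable scaling,
/// and "trim" equalizes the number of signal and background events.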

void TMVA::MethodKNN::MakeKNN(void)
{
   if (!fModule) {
      Log() << kFATAL << "ModulekNN is not created" << Endl;
   }

   fModule->Clear();

   std::string option;
   if (fScaleFrac > 0.0) {
      option += "metric";
   }
   if (fTrim) {
      option += "trim";
   }

   Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      fModule->Add(*event);
   }

   // create binary tree
   fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
                 static_cast<UInt_t>(100.0*fScaleFrac),
                 option);
}

////////////////////////////////////////////////////////////////////////////////
/// kNN training

void TMVA::MethodKNN::Train(void)
{
   Log() << kHEADER << "<Train> start..." << Endl;

   if (IsNormalised()) {
      Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
      fScaleFrac = 0.0;
   }

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }
   if (GetNVariables() < 1)
      Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;

   Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;

   for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
      // read the training event
      const Event* evt_ = GetEvent(ievt);
      Double_t weight = evt_->GetWeight();

      // in case events with negative weights are to be ignored
      if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;

      kNN::VarVec vvec(GetNVariables(), 0.0);
      for (UInt_t ivar = 0; ivar < evt_->GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);

      Short_t event_type = 0;

      if (DataInfo().IsSignal(evt_)) { // signal type = 1
         fSumOfWeightsS += weight;
         event_type = 1;
      }
      else { // background type = 2
         fSumOfWeightsB += weight;
         event_type = 2;
      }

      //
      // Create event and add classification variables, weight, type and regression variables
      //
      kNN::Event event_knn(vvec, weight, event_type);
      event_knn.SetTargets(evt_->GetTargets());
      fEvent.push_back(event_knn);

   }
   Log() << kINFO
         << "Number of signal events " << fSumOfWeightsS << Endl
         << "Number of background events " << fSumOfWeightsB << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();

   ExitFromTraining();
}

////////////////////////////////////////////////////////////////////////////////
/// Compute classifier response
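/// Returns the fraction (optionally kernel- and weight-corrected) of signal
/// events among the k nearest neighbors of the current event.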

Double_t TMVA::MethodKNN::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // cannot determine error
   NoErrorCalc(err, errUpper);

   //
   // Define local variables
   //
   const Event *ev = GetEvent();
   const Int_t nvar = GetNVariables();
   const Double_t weight = ev->GetWeight();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = ev->GetValue(ivar);
   }

   // search for fnkNN+2 nearest neighbors, pad with two
   // events to avoid Monte-Carlo events with zero distance
   // most of CPU time is spent in this recursive function
   const kNN::Event event_knn(vvec, weight, 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
      return -100.0;
   }

   if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);

   //
   // Set flags for kernel option=Gaus, Poln
   //
   Bool_t use_gaus = false, use_poln = false;
   if (fUseKernel) {
      if      (fKernel == "Gaus") use_gaus = true;
      else if (fKernel == "Poln") use_poln = true;
   }

   //
   // Compute radius for polynomial kernel
   //
   Double_t kradius = -1.0;
   if (use_poln) {
      kradius = MethodKNN::getKernelRadius(rlist);

      if (!(kradius > 0.0)) {
         Log() << kFATAL << "kNN radius is not positive" << Endl;
         return -100.0;
      }

      kradius = 1.0/std::sqrt(kradius);
   }

   //
   // Compute RMS of variable differences for Gaussian sigma
   //
   std::vector<Double_t> rms_vec;
   if (use_gaus) {
      rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);

      if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
         Log() << kFATAL << "Failed to compute RMS vector" << Endl;
         return -100.0;
      }
   }

   UInt_t count_all = 0;
   Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);

      // Warn about Monte-Carlo event with zero distance
      // this happens when this query event is also in learning sample
      if (lit->second < 0.0) {
         Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
      }
      else if (!(lit->second > 0.0)) {
         Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
      }

      // get event weight and scale weight by kernel function
      Double_t evweight = node.GetWeight();
      if (use_gaus) evweight *= GausKernel(event_knn, node.GetEvent(), rms_vec);
      else if (use_poln) evweight *= PolnKernel(TMath::Sqrt(lit->second)*kradius);

      if (fUseWeight) weight_all += evweight;
      else            ++weight_all;

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         if (fUseWeight) weight_sig += evweight;
         else            ++weight_sig;
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         if (fUseWeight) weight_bac += evweight;
         else            ++weight_bac;
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }

      // use only fnkNN events
      ++count_all;

      if (count_all >= knn) {
         break;
      }
   }

   // check that total number of events or total weight sum is positive
   if (!(count_all > 0)) {
      Log() << kFATAL << "kNN result list size is not positive" << Endl;
      return -100.0;
   }

   // check that number of events matches number of k in knn
   if (count_all < knn) {
      Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
   }

   // Check that total weight is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "kNN result total weight is not positive" << Endl;
      return -100.0;
   }

   return weight_sig/weight_all;
}

////////////////////////////////////////////////////////////////////////////////
/// Return vector of averages for target values of k-nearest neighbors.
/// Uses its own copy of the regression vector rather than a pointer to a vector.
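/// Each target is averaged over the k nearest neighbors; when UseWeight is set,
/// the average is weighted by the neighbor event weights.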

const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
{
   if( fRegressionReturnVal == 0 )
      fRegressionReturnVal = new std::vector<Float_t>;
   else
      fRegressionReturnVal->clear();

   //
   // Define local variables
   //
   const Event *evt = GetEvent();
   const Int_t nvar = GetNVariables();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);
   std::vector<float> reg_vec;

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = evt->GetValue(ivar);
   }

   // search for fnkNN+2 nearest neighbors, pad with two
   // events to avoid Monte-Carlo events with zero distance
   // most of CPU time is spent in this recursive function
   const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
      return *fRegressionReturnVal;
   }

   // compute regression values
   Double_t weight_all = 0;
   UInt_t count_all = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetTargets();
      const Double_t weight = node.GetEvent().GetWeight();

      if (reg_vec.empty()) {
         reg_vec = kNN::VarVec(tvec.size(), 0.0);
      }

      for(UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
         if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
         else            reg_vec[ivar] += tvec[ivar];
      }

      if (fUseWeight) weight_all += weight;
      else            ++weight_all;

      // use only fnkNN events
      ++count_all;

      if (count_all == knn) {
         break;
      }
   }

   // check that total weight sum is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
      return *fRegressionReturnVal;
   }

   for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
      reg_vec[ivar] = reg_vec[ivar]/weight_all;
   }

   // copy result
   fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());

   return *fRegressionReturnVal;
}

////////////////////////////////////////////////////////////////////////////////
/// no ranking available

const TMVA::Ranking* TMVA::MethodKNN::CreateRanking()
{
   return 0;
}

////////////////////////////////////////////////////////////////////////////////
/// write weights to XML

void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
   void* wght = gTools().AddChild(parent, "Weights");
   gTools().AddAttr(wght,"NEvents",fEvent.size());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NVar",fEvent.begin()->GetNVar());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NTgt",fEvent.begin()->GetNTgt());

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {

      std::stringstream s("");
      s.precision( 16 );
      for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
         if (ivar>0) s << " ";
         s << std::scientific << event->GetVar(ivar);
      }

      for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
         s << " " << std::scientific << event->GetTgt(itgt);
      }

      void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
      gTools().AddAttr(evt,"Type", event->GetType());
      gTools().AddAttr(evt,"Weight", event->GetWeight());
   }
}

////////////////////////////////////////////////////////////////////////////////
/// read weights from XML

void TMVA::MethodKNN::ReadWeightsFromXML( void* wghtnode ) {
   void* ch = gTools().GetChild(wghtnode); // first event
   UInt_t nvar = 0, ntgt = 0;
   gTools().ReadAttr( wghtnode, "NVar", nvar );
   gTools().ReadAttr( wghtnode, "NTgt", ntgt );

   Short_t evtType(0);
   Double_t evtWeight(0);

   while (ch) {
      // build event
      kNN::VarVec vvec(nvar, 0);
      kNN::VarVec tvec(ntgt, 0);

      gTools().ReadAttr( ch, "Type",   evtType );
      gTools().ReadAttr( ch, "Weight", evtWeight );
      std::stringstream s( gTools().GetContent(ch) );

      for(UInt_t ivar=0; ivar<nvar; ivar++)
         s >> vvec[ivar];

      for(UInt_t itgt=0; itgt<ntgt; itgt++)
         s >> tvec[itgt];

      ch = gTools().GetNextChild(ch);

      kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
      fEvent.push_back(event_knn);
   }

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// read the weights
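/// Expects one comma-separated line per event, "ievent,type,weight,var0,var1,...";
/// lines containing '#' are skipped as comments.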

void TMVA::MethodKNN::ReadWeightsFromStream(std::istream& is)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   UInt_t nvar = 0;

   while (is) {
      std::string line;
      std::getline(is, line);

      if (line.empty() || line.find("#") != std::string::npos) {
         continue;
      }

      UInt_t count = 0;
      std::string::size_type pos=0;
      while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }

      if (nvar == 0) {
         nvar = count - 2;
      }
      if (count < 3 || nvar != count - 2) {
         Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
      }

      // Int_t ievent = -1;
      Int_t type = -1;
      Double_t weight = -1.0;

      kNN::VarVec vvec(nvar, 0.0);

      UInt_t vcount = 0;
      std::string::size_type prev = 0;

      for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
         if (line[ipos] != ',' && ipos + 1 != line.size()) {
            continue;
         }

         if (!(ipos > prev)) {
            Log() << kFATAL << "Wrong substring limits" << Endl;
         }

         std::string vstring = line.substr(prev, ipos - prev);
         if (ipos + 1 == line.size()) {
            vstring = line.substr(prev, ipos - prev + 1);
         }

         if (vstring.empty()) {
            Log() << kFATAL << "Failed to parse string" << Endl;
         }

         if (vcount == 0) {
            // ievent = std::atoi(vstring.c_str());
         }
         else if (vcount == 1) {
            type = std::atoi(vstring.c_str());
         }
         else if (vcount == 2) {
            weight = std::atof(vstring.c_str());
         }
         else if (vcount - 3 < vvec.size()) {
            vvec[vcount - 3] = std::atof(vstring.c_str());
         }
         else {
            Log() << kFATAL << "Wrong variable count" << Endl;
         }

         prev = ipos + 1;
         ++vcount;
      }

      fEvent.push_back(kNN::Event(vvec, weight, type));
   }

   Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// save weights to ROOT file

void TMVA::MethodKNN::WriteWeightsToStream(TFile &rf) const
{
   Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;

   if (fEvent.empty()) {
      Log() << kWARNING << "MethodKNN contains no events " << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   TTree *tree = new TTree("knn", "event tree");
   tree->SetDirectory(nullptr);
   tree->Branch("event", "TMVA::kNN::Event", &event);

   Double_t size = 0.0;
   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
      (*event) = (*it);
      size += tree->Fill();
   }

   // !!! hard coded tree name !!!
   rf.WriteTObject(tree, "knn", "Overwrite");

   // scale to MegaBytes
   size /= 1048576.0;

   Log() << kINFO << "Wrote " << size << "MB and " << fEvent.size()
         << " events to ROOT file" << Endl;

   delete tree;
   delete event;
}

////////////////////////////////////////////////////////////////////////////////
/// read weights from ROOT file

void TMVA::MethodKNN::ReadWeightsFromStream(TFile &rf)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   // !!! hard coded tree name !!!
   TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
   if (!tree) {
      Log() << kFATAL << "Failed to find knn tree" << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   tree->SetBranchAddress("event", &event);

   const Int_t nevent = tree->GetEntries();

   Double_t size = 0.0;
   for (Int_t i = 0; i < nevent; ++i) {
      size += tree->GetEntry(i);
      fEvent.push_back(*event);
   }

   // scale to MegaBytes
   size /= 1048576.0;

   Log() << kINFO << "Read " << size << "MB and " << fEvent.size()
         << " events from ROOT file" << Endl;

   delete event;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// write specific classifier response

void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
}

////////////////////////////////////////////////////////////////////////////////
/// get help message text
///
/// typical length of text line:
///         "|--------------------------------------------------------------|"

void TMVA::MethodKNN::GetHelpMessage() const
{
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
         << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
         << "training events for which a classification category/regression target is known. " << Endl
         << "The k-NN method compares a test event to all training events using a distance " << Endl
         << "function, which is a Euclidean distance in a space defined by the input variables. "<< Endl
         << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
         << "quick search for the k events with shortest distance to the test event. The method" << Endl
         << "returns a fraction of signal events among the k neighbors. It is recommended" << Endl
         << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
         << "between 0 and 1." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-NN method estimates a density of signal and background events in a "<< Endl
         << "neighborhood around the test event. The method assumes that the density of the " << Endl
         << "signal and background events is uniform and constant within the neighborhood. " << Endl
         << "k is an adjustable parameter and it determines an average size of the " << Endl
         << "neighborhood. Small k values (less than 10) are sensitive to statistical " << Endl
         << "fluctuations and large (greater than 100) values might not sufficiently capture " << Endl
         << "local differences between events in the training set. The evaluation time of " << Endl
         << "the k-NN method also increases with larger values of k. " << Endl;
   Log() << Endl;
   Log() << "The k-NN method assigns equal weight to all input variables. Different scales " << Endl
         << "among the input variables are compensated using the ScaleFrac parameter: the input " << Endl
         << "variables are scaled so that the widths for the central ScaleFrac*100% of events " << Endl
         << "are equal among all the input variables." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Additional configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
         << "response. The kernel re-weights events using a distance to the test event." << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// polynomial kernel
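/// Evaluates w(x) = (1 - |x|^3)^3 for |x| < 1, and 0 otherwise.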

Double_t TMVA::MethodKNN::PolnKernel(const Double_t value) const
{
   const Double_t avalue = TMath::Abs(value);

   if (!(avalue < 1.0)) {
      return 0.0;
   }

   const Double_t prod = 1.0 - avalue * avalue * avalue;

   return (prod * prod * prod);
}

////////////////////////////////////////////////////////////////////////////////
/// Gaussian kernel
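/// Evaluates the unnormalized Gaussian exp(-sum_i (x_i - y_i)^2 / (2*sigma_i^2)),
/// where sigma_i is taken from svec; the normalization is dropped because it
/// cancels in the ratio of weights.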

Double_t TMVA::MethodKNN::GausKernel(const kNN::Event &event_knn,
                                     const kNN::Event &event, const std::vector<Double_t> &svec) const
{
   if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
      Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
      return 0.0;
   }

   //
   // compute exponent
   //
   double sum_exp = 0.0;

   for(unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {

      const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
      const Double_t sigm_ = svec[ivar];
      if (!(sigm_ > 0.0)) {
         Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
         return 0.0;
      }

      sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
   }

   //
   // Return unnormalized(!) Gaussian function, because normalization
   // cancels for the ratio of weights.
   //

   return std::exp(-sum_exp);
}

////////////////////////////////////////////////////////////////////////////////
///
/// Get polynomial kernel radius
///

Double_t TMVA::MethodKNN::getKernelRadius(const kNN::List &rlist) const
{
   Double_t kradius = -1.0;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
   {
      if (!(lit->second > 0.0)) continue;

      if (kradius < lit->second || kradius < 0.0) kradius = lit->second;

      ++kcount;
      if (kcount >= knn) break;
   }

   return kradius;
}

////////////////////////////////////////////////////////////////////////////////
///
/// Get per-variable RMS of the distances to the k-nearest neighbors,
/// scaled by |SigmaFact| (used as the Gaussian kernel sigma)
///

const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
{
   std::vector<Double_t> rvec;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
   {
      if (!(lit->second > 0.0)) continue;

      const kNN::Node<kNN::Event> *node_ = lit->first;
      const kNN::Event &event_ = node_->GetEvent();

      if (rvec.empty()) {
         rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
      }
      else if (rvec.size() != event_.GetNVar()) {
         Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
         rvec.clear();
         return rvec;
      }

      for(unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
         const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
         rvec[ivar] += diff_*diff_;
      }

      ++kcount;
      if (kcount >= knn) break;
   }

   if (kcount < 1) {
      Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
      rvec.clear();
      return rvec;
   }

   for(unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
      if (!(rvec[ivar] > 0.0)) {
         Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
         rvec.clear();
         return rvec;
      }

      rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
   }

   return rvec;
}

////////////////////////////////////////////////////////////////////////////////
/// compute local LDA response from the k-nearest neighbors

double TMVA::MethodKNN::getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
{
   LDAEvents sig_vec, bac_vec;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetVars();

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         sig_vec.push_back(tvec);
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         bac_vec.push_back(tvec);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
   }

   fLDA.Initialize(sig_vec, bac_vec);

   return fLDA.GetProb(event_knn.GetVars(), 1);
}