Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodKNN.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Rustem Ospanov
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodKNN *
8 * *
9 * *
10 * Description: *
11 * Implementation *
12 * *
13 * Author: *
14 * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15 * *
16 * Copyright (c) 2007: *
17 * CERN, Switzerland *
18 * MPI-K Heidelberg, Germany *
19 * U. of Texas at Austin, USA *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * (see tmva/doc/LICENSE) *
24 **********************************************************************************/
25
26/*! \class TMVA::MethodKNN
27\ingroup TMVA
28
29Analysis of k-nearest neighbor.
30
31*/
32
33#include "TMVA/MethodKNN.h"
34
36#include "TMVA/Configurable.h"
37#include "TMVA/DataSetInfo.h"
38#include "TMVA/Event.h"
39#include "TMVA/LDA.h"
40#include "TMVA/IMethod.h"
41#include "TMVA/MethodBase.h"
42#include "TMVA/MsgLogger.h"
43#include "TMVA/Ranking.h"
44#include "TMVA/Tools.h"
45#include "TMVA/Types.h"
46
47#include "TFile.h"
48#include "TMath.h"
49#include "TTree.h"
50
51#include <cmath>
52#include <string>
53#include <cstdlib>
54
56
58
59////////////////////////////////////////////////////////////////////////////////
60/// standard constructor
61
63 const TString& methodTitle,
64 DataSetInfo& theData,
65 const TString& theOption )
66 : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption)
67 , fSumOfWeightsS(0)
68 , fSumOfWeightsB(0)
69 , fModule(0)
70 , fnkNN(0)
71 , fBalanceDepth(0)
72 , fScaleFrac(0)
73 , fSigmaFact(0)
74 , fTrim(kFALSE)
75 , fUseKernel(kFALSE)
76 , fUseWeight(kFALSE)
77 , fUseLDA(kFALSE)
78 , fTreeOptDepth(0)
79{
80}
81
82////////////////////////////////////////////////////////////////////////////////
83/// constructor from weight file
84
86 const TString& theWeightFile)
87 : TMVA::MethodBase( Types::kKNN, theData, theWeightFile)
88 , fSumOfWeightsS(0)
89 , fSumOfWeightsB(0)
90 , fModule(0)
91 , fnkNN(0)
92 , fBalanceDepth(0)
93 , fScaleFrac(0)
94 , fSigmaFact(0)
95 , fTrim(kFALSE)
96 , fUseKernel(kFALSE)
97 , fUseWeight(kFALSE)
98 , fUseLDA(kFALSE)
99 , fTreeOptDepth(0)
100{
101}
102
103////////////////////////////////////////////////////////////////////////////////
104/// destructor
105
107{
108 if (fModule) delete fModule;
109}
110
111////////////////////////////////////////////////////////////////////////////////
112/// MethodKNN options
113///
114/// - fnkNN = 20; // number of k-nearest neighbors
115/// - fBalanceDepth = 6; // number of binary tree levels used for tree balancing
116/// - fScaleFrac = 0.8; // fraction of events used to compute variable width
117/// - fSigmaFact = 1.0; // scale factor for Gaussian sigma
118/// - fKernel = use polynomial (1-x^3)^3 or Gaussian kernel
119/// - fTrim = false; // use equal number of signal and background events
120/// - fUseKernel = false; // use polynomial kernel weight function
121/// - fUseWeight = true; // count events using weights
122/// - fUseLDA = false
123
125{
// Each call assigns the member its default value (the assignment inside the first
// argument) and hands the member, an option name and a description to DeclareOptionRef.
// The registration order is kept as-is.
126 DeclareOptionRef(fnkNN = 20, "nkNN", "Number of k-nearest neighbors");
127 DeclareOptionRef(fBalanceDepth = 6, "BalanceDepth", "Binary tree balance depth");
128 DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used to compute variable width");
129 DeclareOptionRef(fSigmaFact = 1.0, "SigmaFact", "Scale factor for sigma in Gaussian kernel");
// Only "Gaus" and "Poln" are recognised by GetMvaValue(); any other value disables
// kernel weighting even when UseKernel is set.
130 DeclareOptionRef(fKernel = "Gaus", "Kernel", "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
131 DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
// NOTE(review): the description mentions only the polynomial kernel, but GetMvaValue()
// applies whichever kernel "Kernel" selects (Gaus or Poln) when this flag is set.
132 DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
133 DeclareOptionRef(fUseWeight = kTRUE, "UseWeight", "Use weight to count kNN events");
134 DeclareOptionRef(fUseLDA = kFALSE, "UseLDA", "Use local linear discriminant - experimental feature");
135}
136
137////////////////////////////////////////////////////////////////////////////////
138/// options that are used ONLY for the READER to ensure backward compatibility
139
// Reader-only option kept for backward compatibility. In the visible source,
// fTreeOptDepth is only initialised and declared; its value is not otherwise read.
142 DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
143}
144
145////////////////////////////////////////////////////////////////////////////////
146/// process the options specified by the user
147
149{
150 if (!(fnkNN > 0)) {
151 fnkNN = 10;
152 Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
153 }
154 if (fScaleFrac < 0.0) {
155 fScaleFrac = 0.0;
156 Log() << kWARNING << "ScaleFrac can not be negative: set ScaleFrac = " << fScaleFrac << Endl;
157 }
158 if (fScaleFrac > 1.0) {
159 fScaleFrac = 1.0;
160 }
161 if (!(fBalanceDepth > 0)) {
162 fBalanceDepth = 6;
163 Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
164 }
165
166 Log() << kVERBOSE
167 << "kNN options: \n"
168 << " kNN = \n" << fnkNN
169 << " UseKernel = \n" << fUseKernel
170 << " SigmaFact = \n" << fSigmaFact
171 << " ScaleFrac = \n" << fScaleFrac
172 << " Kernel = \n" << fKernel
173 << " Trim = \n" << fTrim
174 << " Optimize = " << fBalanceDepth << Endl;
175}
176
177////////////////////////////////////////////////////////////////////////////////
178/// kNN can handle classification with 2 classes and regression with any number of regression targets
179
181{
182 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
183 if (type == Types::kRegression) return kTRUE;
184 return kFALSE;
185}
186
187////////////////////////////////////////////////////////////////////////////////
188/// Initialization
189
191{
192 // fScaleFrac <= 0.0 then do not scale input variables
193 // fScaleFrac >= 1.0 then use all event coordinates to scale input variables
194
195 fModule = new kNN::ModulekNN();
196 fSumOfWeightsS = 0;
197 fSumOfWeightsB = 0;
198}
199
200////////////////////////////////////////////////////////////////////////////////
201/// create kNN
202
204{
205 if (!fModule) {
206 Log() << kFATAL << "ModulekNN is not created" << Endl;
207 }
208
209 fModule->Clear();
210
211 std::string option;
212 if (fScaleFrac > 0.0) {
213 option += "metric";
214 }
215 if (fTrim) {
216 option += "trim";
217 }
218
219 Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;
220
221 for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
222 fModule->Add(*event);
223 }
224
225 // create binary tree
226 fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
227 static_cast<UInt_t>(100.0*fScaleFrac),
228 option);
229}
230
231////////////////////////////////////////////////////////////////////////////////
232/// kNN training
233
235{
236 Log() << kHEADER << "<Train> start..." << Endl;
237
238 if (IsNormalised()) {
239 Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
240 fScaleFrac = 0.0;
241 }
242
243 if (!fEvent.empty()) {
244 Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
245 fEvent.clear();
246 }
247 if (GetNVariables() < 1)
248 Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;
249
250
251 Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;
252
253 for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
254 // read the training event
255 const Event* evt_ = GetEvent(ievt);
256 Double_t weight = evt_->GetWeight();
257
258 // in case event with neg weights are to be ignored
259 if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;
260
261 kNN::VarVec vvec(GetNVariables(), 0.0);
262 for (UInt_t ivar = 0; ivar < evt_ -> GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);
263
264 Short_t event_type = 0;
265
266 if (DataInfo().IsSignal(evt_)) { // signal type = 1
267 fSumOfWeightsS += weight;
268 event_type = 1;
269 }
270 else { // background type = 2
271 fSumOfWeightsB += weight;
272 event_type = 2;
273 }
274
275 //
276 // Create event and add classification variables, weight, type and regression variables
277 //
278 kNN::Event event_knn(vvec, weight, event_type);
279 event_knn.SetTargets(evt_->GetTargets());
280 fEvent.push_back(event_knn);
281
282 }
283 Log() << kINFO
284 << "Number of signal events " << fSumOfWeightsS << Endl
285 << "Number of background events " << fSumOfWeightsB << Endl;
286
287 // create kd-tree (binary tree) structure
288 MakeKNN();
289
290 ExitFromTraining();
291}
292
293////////////////////////////////////////////////////////////////////////////////
294/// Compute classifier response
295
297{
298 // cannot determine error
299 NoErrorCalc(err, errUpper);
300
301 //
302 // Define local variables
303 //
304 const Event *ev = GetEvent();
305 const Int_t nvar = GetNVariables();
306 const Double_t weight = ev->GetWeight();
307 const UInt_t knn = static_cast<UInt_t>(fnkNN);
308
309 kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);
310
311 for (Int_t ivar = 0; ivar < nvar; ++ivar) {
312 vvec[ivar] = ev->GetValue(ivar);
313 }
314
315 // search for fnkNN+2 nearest neighbors, pad with two
316 // events to avoid Monte-Carlo events with zero distance
317 // most of CPU time is spent in this recursive function
318 const kNN::Event event_knn(vvec, weight, 3);
319 fModule->Find(event_knn, knn + 2);
320
321 const kNN::List &rlist = fModule->GetkNNList();
322 if (rlist.size() != knn + 2) {
323 Log() << kFATAL << "kNN result list is empty" << Endl;
324 return -100.0;
325 }
326
327 if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);
328
329 //
330 // Set flags for kernel option=Gaus, Poln
331 //
332 Bool_t use_gaus = false, use_poln = false;
333 if (fUseKernel) {
334 if (fKernel == "Gaus") use_gaus = true;
335 else if (fKernel == "Poln") use_poln = true;
336 }
337
338 //
339 // Compute radius for polynomial kernel
340 //
341 Double_t kradius = -1.0;
342 if (use_poln) {
343 kradius = MethodKNN::getKernelRadius(rlist);
344
345 if (!(kradius > 0.0)) {
346 Log() << kFATAL << "kNN radius is not positive" << Endl;
347 return -100.0;
348 }
349
350 kradius = 1.0/TMath::Sqrt(kradius);
351 }
352
353 //
354 // Compute RMS of variable differences for Gaussian sigma
355 //
356 std::vector<Double_t> rms_vec;
357 if (use_gaus) {
358 rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);
359
360 if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
361 Log() << kFATAL << "Failed to compute RMS vector" << Endl;
362 return -100.0;
363 }
364 }
365
366 UInt_t count_all = 0;
367 Double_t weight_all = 0, weight_sig = 0;
368
369 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
370
371 // get reference to current node to make code more readable
372 const kNN::Node<kNN::Event> &node = *(lit->first);
373
374 // Warn about Monte-Carlo event with zero distance
375 // this happens when this query event is also in learning sample
376 if (lit->second < 0.0) {
377 Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
378 }
379 else if (!(lit->second > 0.0)) {
380 Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
381 }
382
383 // get event weight and scale weight by kernel function
384 Double_t evweight = node.GetWeight();
385 if (use_gaus) evweight *= MethodKNN::GausKernel(event_knn, node.GetEvent(), rms_vec);
386 else if (use_poln) evweight *= MethodKNN::PolnKernel(TMath::Sqrt(lit->second)*kradius);
387
388 if (fUseWeight) weight_all += evweight;
389 else ++weight_all;
390
391 if (node.GetEvent().GetType() == 1) { // signal type = 1
392 if (fUseWeight) weight_sig += evweight;
393 else ++weight_sig;
394 }
395 else if (node.GetEvent().GetType() == 2) { // background type = 2
396 }
397 else {
398 Log() << kFATAL << "Unknown type for training event" << Endl;
399 }
400
401 // use only fnkNN events
402 ++count_all;
403
404 if (count_all >= knn) {
405 break;
406 }
407 }
408
409 // check that total number of events or total weight sum is positive
410 if (!(count_all > 0)) {
411 Log() << kFATAL << "Size kNN result list is not positive" << Endl;
412 return -100.0;
413 }
414
415 // check that number of events matches number of k in knn
416 if (count_all < knn) {
417 Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
418 }
419
420 // Check that total weight is positive
421 if (!(weight_all > 0.0)) {
422 Log() << kFATAL << "kNN result total weight is not positive" << Endl;
423 return -100.0;
424 }
425
426 return weight_sig/weight_all;
427}
428
429////////////////////////////////////////////////////////////////////////////////
430/// Return vector of averages for target values of k-nearest neighbors.
431/// Use own copy of the regression vector, I do not like using a pointer to vector.
432
433const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
434{
435 if( fRegressionReturnVal == 0 )
436 fRegressionReturnVal = new std::vector<Float_t>;
437 else
438 fRegressionReturnVal->clear();
439
440 //
441 // Define local variables
442 //
443 const Event *evt = GetEvent();
444 const Int_t nvar = GetNVariables();
445 const UInt_t knn = static_cast<UInt_t>(fnkNN);
446 std::vector<float> reg_vec;
447
448 kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);
449
450 for (Int_t ivar = 0; ivar < nvar; ++ivar) {
451 vvec[ivar] = evt->GetValue(ivar);
452 }
453
454 // search for fnkNN+2 nearest neighbors, pad with two
455 // events to avoid Monte-Carlo events with zero distance
456 // most of CPU time is spent in this recursive function
457 const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
458 fModule->Find(event_knn, knn + 2);
459
460 const kNN::List &rlist = fModule->GetkNNList();
461 if (rlist.size() != knn + 2) {
462 Log() << kFATAL << "kNN result list is empty" << Endl;
463 return *fRegressionReturnVal;
464 }
465
466 // compute regression values
467 Double_t weight_all = 0;
468 UInt_t count_all = 0;
469
470 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
471
472 // get reference to current node to make code more readable
473 const kNN::Node<kNN::Event> &node = *(lit->first);
474 const kNN::VarVec &tvec = node.GetEvent().GetTargets();
475 const Double_t weight = node.GetEvent().GetWeight();
476
477 if (reg_vec.empty()) {
478 reg_vec= kNN::VarVec(tvec.size(), 0.0);
479 }
480
481 for(UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
482 if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
483 else reg_vec[ivar] += tvec[ivar];
484 }
485
486 if (fUseWeight) weight_all += weight;
487 else ++weight_all;
488
489 // use only fnkNN events
490 ++count_all;
491
492 if (count_all == knn) {
493 break;
494 }
495 }
496
497 // check that number of events matches number of k in knn
498 if (!(weight_all > 0.0)) {
499 Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
500 return *fRegressionReturnVal;
501 }
502
503 for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
504 reg_vec[ivar] /= weight_all;
505 }
506
507 // copy result
508 fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());
509
510 return *fRegressionReturnVal;
511}
512
513////////////////////////////////////////////////////////////////////////////////
514/// no ranking available
515
517{
518 return 0;
519}
520
521////////////////////////////////////////////////////////////////////////////////
522/// write weights to XML
523
524void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
525 void* wght = gTools().AddChild(parent, "Weights");
526 gTools().AddAttr(wght,"NEvents",fEvent.size());
527 if (fEvent.size()>0) gTools().AddAttr(wght,"NVar",fEvent.begin()->GetNVar());
528 if (fEvent.size()>0) gTools().AddAttr(wght,"NTgt",fEvent.begin()->GetNTgt());
529
530 for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
531
532 std::stringstream s("");
533 s.precision( 16 );
534 for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
535 if (ivar>0) s << " ";
536 s << std::scientific << event->GetVar(ivar);
537 }
538
539 for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
540 s << " " << std::scientific << event->GetTgt(itgt);
541 }
542
543 void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
544 gTools().AddAttr(evt,"Type", event->GetType());
545 gTools().AddAttr(evt,"Weight", event->GetWeight());
546 }
547}
548
549////////////////////////////////////////////////////////////////////////////////
550
552 void* ch = gTools().GetChild(wghtnode); // first event
553 UInt_t nvar = 0, ntgt = 0;
554 gTools().ReadAttr( wghtnode, "NVar", nvar );
555 gTools().ReadAttr( wghtnode, "NTgt", ntgt );
556
557
558 Short_t evtType(0);
559 Double_t evtWeight(0);
560
561 while (ch) {
562 // build event
563 kNN::VarVec vvec(nvar, 0);
564 kNN::VarVec tvec(ntgt, 0);
565
566 gTools().ReadAttr( ch, "Type", evtType );
567 gTools().ReadAttr( ch, "Weight", evtWeight );
568 std::stringstream s( gTools().GetContent(ch) );
569
570 for(UInt_t ivar=0; ivar<nvar; ivar++)
571 s >> vvec[ivar];
572
573 for(UInt_t itgt=0; itgt<ntgt; itgt++)
574 s >> tvec[itgt];
575
576 ch = gTools().GetNextChild(ch);
577
578 kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
579 fEvent.push_back(event_knn);
580 }
581
582 // create kd-tree (binary tree) structure
583 MakeKNN();
584}
585
586////////////////////////////////////////////////////////////////////////////////
587/// read the weights
588
590{
// Read stored kNN events from a comma-separated text stream and rebuild the kd-tree.
// Expected data-line layout (one event per line):
//    <event index (ignored)>,<type>,<weight>,<var_1>,...,<var_nvar>
// Lines that are empty or contain '#' anywhere are skipped.
591 Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;
592
593 if (!fEvent.empty()) {
594 Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
595 fEvent.clear();
596 }
597
598 UInt_t nvar = 0;
599
// `line` is declared fresh on every pass, so a failed final getline leaves it empty;
// that iteration is skipped and the loop condition then ends the read.
600 while (is) {
601 std::string line;
602 std::getline(is, line);
603
604 if (line.empty() || line.find("#") != std::string::npos) {
605 continue;
606 }
607
// Count commas: a line with nvar variables has nvar + 2 commas.
608 UInt_t count = 0;
609 std::string::size_type pos=0;
610 while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }
611
// The first data line fixes nvar; every later line must have the same field count.
// NOTE(review): count - 2 wraps around (unsigned) for count < 2; the count < 3 test
// below is what reports such lines.
612 if (nvar == 0) {
613 nvar = count - 2;
614 }
615 if (count < 3 || nvar != count - 2) {
616 Log() << kFATAL << "Missing comma delimeter(s)" << Endl;
617 }
618
619 // Int_t ievent = -1;
620 Int_t type = -1;
621 Double_t weight = -1.0;
622
623 kNN::VarVec vvec(nvar, 0.0);
624
// vcount = index of the current field in the line; prev = start of that field.
625 UInt_t vcount = 0;
626 std::string::size_type prev = 0;
627
628 for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
// stop only on a comma or on the last character of the line
629 if (line[ipos] != ',' && ipos + 1 != line.size()) {
630 continue;
631 }
632
633 if (!(ipos > prev)) {
634 Log() << kFATAL << "Wrong substring limits" << Endl;
635 }
636
637 std::string vstring = line.substr(prev, ipos - prev);
// on the last character there is no terminating comma: include that character
638 if (ipos + 1 == line.size()) {
639 vstring = line.substr(prev, ipos - prev + 1);
640 }
641
642 if (vstring.empty()) {
643 Log() << kFATAL << "Failed to parse string" << Endl;
644 }
645
// field 0: event index (ignored), 1: type, 2: weight, 3..: variable values
646 if (vcount == 0) {
647 // ievent = std::atoi(vstring.c_str());
648 }
649 else if (vcount == 1) {
650 type = std::atoi(vstring.c_str());
651 }
652 else if (vcount == 2) {
653 weight = std::atof(vstring.c_str());
654 }
655 else if (vcount - 3 < vvec.size()) {
656 vvec[vcount - 3] = std::atof(vstring.c_str());
657 }
658 else {
659 Log() << kFATAL << "Wrong variable count" << Endl;
660 }
661
662 prev = ipos + 1;
663 ++vcount;
664 }
665
// text format carries no regression targets: the event is built from vars, weight, type
666 fEvent.push_back(kNN::Event(vvec, weight, type));
667 }
668
669 Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;
670
671 // create kd-tree (binary tree) structure
672 MakeKNN();
673}
674
675////////////////////////////////////////////////////////////////////////////////
676/// save weights to ROOT file
677
679{
680 Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;
681
682 if (fEvent.empty()) {
683 Log() << kWARNING << "MethodKNN contains no events " << Endl;
684 return;
685 }
686
687 kNN::Event *event = new kNN::Event();
688 TTree *tree = new TTree("knn", "event tree");
689 tree->SetDirectory(nullptr);
690 tree->Branch("event", "TMVA::kNN::Event", &event);
691
692 Double_t size = 0.0;
693 for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
694 (*event) = (*it);
695 size += tree->Fill();
696 }
697
698 // !!! hard coded tree name !!!
699 rf.WriteTObject(tree, "knn", "Overwrite");
700
701 // scale to MegaBytes
702 size /= 1048576.0;
703
704 Log() << kINFO << "Wrote " << size << "MB and " << fEvent.size()
705 << " events to ROOT file" << Endl;
706
707 delete tree;
708 delete event;
709}
710
711////////////////////////////////////////////////////////////////////////////////
712/// read weights from ROOT file
713
715{
716 Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;
717
718 if (!fEvent.empty()) {
719 Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
720 fEvent.clear();
721 }
722
723 // !!! hard coded tree name !!!
724 TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
725 if (!tree) {
726 Log() << kFATAL << "Failed to find knn tree" << Endl;
727 return;
728 }
729
730 kNN::Event *event = new kNN::Event();
731 tree->SetBranchAddress("event", &event);
732
733 const Int_t nevent = tree->GetEntries();
734
735 Double_t size = 0.0;
736 for (Int_t i = 0; i < nevent; ++i) {
737 size += tree->GetEntry(i);
738 fEvent.push_back(*event);
739 }
740
741 // scale to MegaBytes
742 size /= 1048576.0;
743
744 Log() << kINFO << "Read " << size << "MB and " << fEvent.size()
745 << " events from ROOT file" << Endl;
746
747 delete event;
748
749 // create kd-tree (binary tree) structure
750 MakeKNN();
751}
752
753////////////////////////////////////////////////////////////////////////////////
754/// write specific classifier response
755
756void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
757{
758 fout << " // not implemented for class: \"" << className << "\"" << std::endl;
759 fout << "};" << std::endl;
760}
761
762////////////////////////////////////////////////////////////////////////////////
763/// get help message text
764///
765/// typical length of text line:
766/// "|--------------------------------------------------------------|"
767
769{
770 Log() << Endl;
771 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
772 Log() << Endl;
773 Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
774 << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
775 << "training events for which a classification category/regression target is known. " << Endl
776 << "The k-NN method compares a test event to all training events using a distance " << Endl
777 << "function, which is an Euclidean distance in a space defined by the input variables. "<< Endl
778 << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
779 << "quick search for the k events with shortest distance to the test event. The method" << Endl
780 << "returns a fraction of signal events among the k neighbors. It is recommended" << Endl
781 << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
782 << "between 0 and 1." << Endl;
783
784 Log() << Endl;
785 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
786 << gTools().Color("reset") << Endl;
787 Log() << Endl;
788 Log() << "The k-NN method estimates a density of signal and background events in a "<< Endl
789 << "neighborhood around the test event. The method assumes that the density of the " << Endl
790 << "signal and background events is uniform and constant within the neighborhood. " << Endl
791 << "k is an adjustable parameter and it determines an average size of the " << Endl
792 << "neighborhood. Small k values (less than 10) are sensitive to statistical " << Endl
793 << "fluctuations and large (greater than 100) values might not sufficiently capture " << Endl
794 << "local differences between events in the training set. The speed of the k-NN" << Endl
795 << "method also increases with larger values of k. " << Endl;
796 Log() << Endl;
797 Log() << "The k-NN method assigns equal weight to all input variables. Different scales " << Endl
798 << "among the input variables is compensated using ScaleFrac parameter: the input " << Endl
799 << "variables are scaled so that the widths for central ScaleFrac*100% events are " << Endl
800 << "equal among all the input variables." << Endl;
801
802 Log() << Endl;
803 Log() << gTools().Color("bold") << "--- Additional configuration options: "
804 << gTools().Color("reset") << Endl;
805 Log() << Endl;
806 Log() << "The method inclues an option to use a Gaussian kernel to smooth out the k-NN" << Endl
807 << "response. The kernel re-weights events using a distance to the test event." << Endl;
808}
809
810////////////////////////////////////////////////////////////////////////////////
811/// polynomial kernel
812
814{
815 const Double_t avalue = TMath::Abs(value);
816
817 if (!(avalue < 1.0)) {
818 return 0.0;
819 }
820
821 const Double_t prod = 1.0 - avalue * avalue * avalue;
822
823 return (prod * prod * prod);
824}
825
826////////////////////////////////////////////////////////////////////////////////
827/// Gaussian kernel
828
830 const kNN::Event &event, const std::vector<Double_t> &svec) const
831{
832 if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
833 Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
834 return 0.0;
835 }
836
837 //
838 // compute exponent
839 //
840 double sum_exp = 0.0;
841
842 for(unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {
843
844 const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
845 const Double_t sigm_ = svec[ivar];
846 if (!(sigm_ > 0.0)) {
847 Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
848 return 0.0;
849 }
850
851 sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
852 }
853
854 //
855 // Return unnormalized(!) Gaussian function, because normalization
856 // cancels for the ratio of weights.
857 //
858
859 return std::exp(-sum_exp);
860}
861
862////////////////////////////////////////////////////////////////////////////////
863///
864/// Get polynomial kernel radius
865///
866
868{
869 Double_t kradius = -1.0;
870 UInt_t kcount = 0;
871 const UInt_t knn = static_cast<UInt_t>(fnkNN);
872
873 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
874 {
875 if (!(lit->second > 0.0)) continue;
876
877 if (kradius < lit->second || kradius < 0.0) kradius = lit->second;
878
879 ++kcount;
880 if (kcount >= knn) break;
881 }
882
883 return kradius;
884}
885
886////////////////////////////////////////////////////////////////////////////////
887///
888/// Get RMS of variable differences between the query event and its nearest neighbors (up to kNN neighbors with non-zero distance)
889///
890
891const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
892{
893 std::vector<Double_t> rvec;
894 UInt_t kcount = 0;
895 const UInt_t knn = static_cast<UInt_t>(fnkNN);
896
897 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
898 {
899 if (!(lit->second > 0.0)) continue;
900
901 const kNN::Node<kNN::Event> *node_ = lit -> first;
902 const kNN::Event &event_ = node_-> GetEvent();
903
904 if (rvec.empty()) {
905 rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
906 }
907 else if (rvec.size() != event_.GetNVar()) {
908 Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
909 rvec.clear();
910 return rvec;
911 }
912
913 for(unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
914 const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
915 rvec[ivar] += diff_*diff_;
916 }
917
918 ++kcount;
919 if (kcount >= knn) break;
920 }
921
922 if (kcount < 1) {
923 Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
924 rvec.clear();
925 return rvec;
926 }
927
928 for(unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
929 if (!(rvec[ivar] > 0.0)) {
930 Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
931 rvec.clear();
932 return rvec;
933 }
934
935 rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
936 }
937
938 return rvec;
939}
940
941////////////////////////////////////////////////////////////////////////////////
942
944{
945 LDAEvents sig_vec, bac_vec;
946
947 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
948
949 // get reference to current node to make code more readable
950 const kNN::Node<kNN::Event> &node = *(lit->first);
951 const kNN::VarVec &tvec = node.GetEvent().GetVars();
952
953 if (node.GetEvent().GetType() == 1) { // signal type = 1
954 sig_vec.push_back(tvec);
955 }
956 else if (node.GetEvent().GetType() == 2) { // background type = 2
957 bac_vec.push_back(tvec);
958 }
959 else {
960 Log() << kFATAL << "Unknown type for training event" << Endl;
961 }
962 }
963
964 fLDA.Initialize(sig_vec, bac_vec);
965
966 return fLDA.GetProb(event_knn.GetVars(), 1);
967}
#define REGISTER_METHOD(CLASS)
for example
std::vector< std::vector< Float_t > > LDAEvents
Definition LDA.h:38
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
short Short_t
Definition RtypesCore.h:39
constexpr Bool_t kFALSE
Definition RtypesCore.h:94
constexpr Bool_t kTRUE
Definition RtypesCore.h:93
#define ClassImp(name)
Definition Rtypes.h:382
Option_t Option_t option
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
Int_t WriteTObject(const TObject *obj, const char *name=nullptr, Option_t *option="", Int_t bufsize=0) override
Write object obj to this directory.
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
Class that contains all the data information.
Definition DataSetInfo.h:62
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition Event.cxx:236
std::vector< Float_t > & GetTargets()
Definition Event.h:103
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition Event.cxx:389
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Analysis of k-nearest neighbor.
Definition MethodKNN.h:54
void Init(void)
Initialization.
void MakeKNN(void)
create kNN
virtual ~MethodKNN(void)
destructor
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
kNN can handle classification with 2 classes and regression with any number of regression targets.
const std::vector< Double_t > getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
Get RMS of variable differences between the query event and its nearest neighbors.
const Ranking * CreateRanking()
no ranking available
void MakeClassSpecific(std::ostream &, const TString &) const
write specific classifier response
MethodKNN(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="KNN")
standard constructor
Definition MethodKNN.cxx:62
void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility
Double_t getKernelRadius(const kNN::List &rlist) const
Get polynomial kernel radius.
void GetHelpMessage() const
get help message text
const std::vector< Float_t > & GetRegressionValues()
Return vector of averages for target values of k-nearest neighbors.
void Train(void)
kNN training
void ReadWeightsFromStream(std::istream &istr)
read the weights
double getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
Double_t PolnKernel(Double_t value) const
polynomial kernel
void ProcessOptions()
process the options specified by the user
void ReadWeightsFromXML(void *wghtnode)
void AddWeightsXMLTo(void *parent) const
write weights to XML
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr)
Compute classifier response.
void DeclareOptions()
MethodKNN options.
void WriteWeightsToStream(TFile &rf) const
save weights to ROOT file
Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector< Double_t > &svec) const
Gaussian kernel.
Ranking for variables in method (implementation)
Definition Ranking.h:48
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1150
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1162
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kClassification
Definition Types.h:127
@ kRegression
Definition Types.h:128
void SetTargets(const VarVec &tvec)
UInt_t GetNVar() const
Definition ModulekNN.h:188
VarType GetVar(UInt_t i) const
Definition ModulekNN.h:179
const VarVec & GetVars() const
This file contains a binary tree and a global function template that searches the tree for k-nearest neighbors...
Definition NodekNN.h:68
Double_t GetWeight() const
Definition NodekNN.h:181
const T & GetEvent() const
Definition NodekNN.h:157
std::list< Elem > List
Definition ModulekNN.h:99
std::vector< VarType > VarVec
Definition ModulekNN.h:57
virtual void Clear(Option_t *="")
Definition TObject.h:119
Basic string class.
Definition TString.h:139
A TTree represents a columnar dataset.
Definition TTree.h:79
TLine * line
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:666
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:123