Logo ROOT   6.12/07
Reference Guide
MethodKNN.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Rustem Ospanov
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodKNN *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation *
12  * *
13  * Author: *
14  * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15  * *
16  * Copyright (c) 2007: *
17  * CERN, Switzerland *
18  * MPI-K Heidelberg, Germany *
19  * U. of Texas at Austin, USA *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/LICENSE) *
24  **********************************************************************************/
25 
26 /*! \class TMVA::MethodKNN
27 \ingroup TMVA
28 
29 Classification and regression with the k-nearest neighbor (k-NN) method.
30 
31 */
32 
33 #include "TMVA/MethodKNN.h"
34 
35 #include "TMVA/ClassifierFactory.h"
36 #include "TMVA/Configurable.h"
37 #include "TMVA/DataSetInfo.h"
38 #include "TMVA/Event.h"
39 #include "TMVA/LDA.h"
40 #include "TMVA/IMethod.h"
41 #include "TMVA/MethodBase.h"
42 #include "TMVA/MsgLogger.h"
43 #include "TMVA/Ranking.h"
44 #include "TMVA/Tools.h"
45 #include "TMVA/Types.h"
46 
47 #include "TFile.h"
48 #include "TMath.h"
49 #include "TTree.h"
50 
51 #include <cmath>
52 #include <string>
53 #include <cstdlib>
54 
// Presumably registers MethodKNN with the TMVA ClassifierFactory under the key "KNN"
// (macro expected from TMVA/ClassifierFactory.h) -- NOTE(review): confirm against the macro definition.
55 REGISTER_METHOD(KNN)
56 
58 
59 ////////////////////////////////////////////////////////////////////////////////
60 /// standard constructor
61 
63  const TString& methodTitle,
64  DataSetInfo& theData,
65  const TString& theOption )
66  : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption)
67  , fSumOfWeightsS(0)
68  , fSumOfWeightsB(0)
69  , fModule(0)
70  , fnkNN(0)
71  , fBalanceDepth(0)
72  , fScaleFrac(0)
73  , fSigmaFact(0)
74  , fTrim(kFALSE)
75  , fUseKernel(kFALSE)
76  , fUseWeight(kFALSE)
77  , fUseLDA(kFALSE)
78  , fTreeOptDepth(0)
79 {
80 }
81 
82 ////////////////////////////////////////////////////////////////////////////////
83 /// constructor from weight file
84 
86  const TString& theWeightFile)
87  : TMVA::MethodBase( Types::kKNN, theData, theWeightFile)
88  , fSumOfWeightsS(0)
89  , fSumOfWeightsB(0)
90  , fModule(0)
91  , fnkNN(0)
92  , fBalanceDepth(0)
93  , fScaleFrac(0)
94  , fSigmaFact(0)
95  , fTrim(kFALSE)
98  , fUseLDA(kFALSE)
99  , fTreeOptDepth(0)
100 {
101 }
102 
103 ////////////////////////////////////////////////////////////////////////////////
104 /// destructor
105 
107 {
108  if (fModule) delete fModule;
109 }
110 
111 ////////////////////////////////////////////////////////////////////////////////
112 /// MethodKNN options
113 ///
114 /// - fnkNN = 20; // number of k-nearest neighbors
115 /// - fBalanceDepth = 6; // number of binary tree levels used for tree balancing
116 /// - fScaleFrac = 0.8; // fraction of events used to compute variable width
117 /// - fSigmaFact = 1.0; // scale factor for Gaussian sigma
118 /// - fKernel = use polynomial (1-x^3)^3 or Gaussian kernel
119 /// - fTrim = false; // use equal number of signal and background events
120 /// - fUseKernel = false; // use kernel weight function (type selected by fKernel: Gaus or Poln)
121 /// - fUseWeight = true; // count events using weights
122 /// - fUseLDA = false
123 
125 {
126  DeclareOptionRef(fnkNN = 20, "nkNN", "Number of k-nearest neighbors");
127  DeclareOptionRef(fBalanceDepth = 6, "BalanceDepth", "Binary tree balance depth");
128  DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used to compute variable width");
129  DeclareOptionRef(fSigmaFact = 1.0, "SigmaFact", "Scale factor for sigma in Gaussian kernel");
130  DeclareOptionRef(fKernel = "Gaus", "Kernel", "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
131  DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
132  DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
133  DeclareOptionRef(fUseWeight = kTRUE, "UseWeight", "Use weight to count kNN events");
134  DeclareOptionRef(fUseLDA = kFALSE, "UseLDA", "Use local linear discriminant - experimental feature");
135 }
136 
137 ////////////////////////////////////////////////////////////////////////////////
138 /// options that are used ONLY for the READER to ensure backward compatibility
139 
142  DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
143 }
144 
145 ////////////////////////////////////////////////////////////////////////////////
146 /// process the options specified by the user
147 
149 {
150  if (!(fnkNN > 0)) {
151  fnkNN = 10;
152  Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
153  }
154  if (fScaleFrac < 0.0) {
155  fScaleFrac = 0.0;
156  Log() << kWARNING << "ScaleFrac can not be negative: set ScaleFrac = " << fScaleFrac << Endl;
157  }
158  if (fScaleFrac > 1.0) {
159  fScaleFrac = 1.0;
160  }
161  if (!(fBalanceDepth > 0)) {
162  fBalanceDepth = 6;
163  Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
164  }
165 
166  Log() << kVERBOSE
167  << "kNN options: \n"
168  << " kNN = \n" << fnkNN
169  << " UseKernel = \n" << fUseKernel
170  << " SigmaFact = \n" << fSigmaFact
171  << " ScaleFrac = \n" << fScaleFrac
172  << " Kernel = \n" << fKernel
173  << " Trim = \n" << fTrim
174  << " Optimize = " << fBalanceDepth << Endl;
175 }
176 
177 ////////////////////////////////////////////////////////////////////////////////
178 /// KNN can handle classification with 2 classes and regression with any number of regression targets
179 
181 {
182  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
183  if (type == Types::kRegression) return kTRUE;
184  return kFALSE;
185 }
186 
187 ////////////////////////////////////////////////////////////////////////////////
188 /// Initialization
189 
191 {
192  // fScaleFrac <= 0.0 then do not scale input variables
193  // fScaleFrac >= 1.0 then use all event coordinates to scale input variables
194 
195  fModule = new kNN::ModulekNN();
196  fSumOfWeightsS = 0;
197  fSumOfWeightsB = 0;
198 }
199 
200 ////////////////////////////////////////////////////////////////////////////////
201 /// create kNN
202 
204 {
205  if (!fModule) {
206  Log() << kFATAL << "ModulekNN is not created" << Endl;
207  }
208 
209  fModule->Clear();
210 
211  std::string option;
212  if (fScaleFrac > 0.0) {
213  option += "metric";
214  }
215  if (fTrim) {
216  option += "trim";
217  }
218 
219  Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;
220 
221  for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
222  fModule->Add(*event);
223  }
224 
225  // create binary tree
226  fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
227  static_cast<UInt_t>(100.0*fScaleFrac),
228  option);
229 }
230 
231 ////////////////////////////////////////////////////////////////////////////////
232 /// kNN training
233 
235 {
236  Log() << kHEADER << "<Train> start..." << Endl;
237 
238  if (IsNormalised()) {
239  Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
240  fScaleFrac = 0.0;
241  }
242 
243  if (!fEvent.empty()) {
244  Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
245  fEvent.clear();
246  }
247  if (GetNVariables() < 1)
248  Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;
249 
250 
251  Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;
252 
253  for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
254  // read the training event
255  const Event* evt_ = GetEvent(ievt);
256  Double_t weight = evt_->GetWeight();
257 
258  // in case event with neg weights are to be ignored
259  if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;
260 
261  kNN::VarVec vvec(GetNVariables(), 0.0);
262  for (UInt_t ivar = 0; ivar < evt_ -> GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);
263 
264  Short_t event_type = 0;
265 
266  if (DataInfo().IsSignal(evt_)) { // signal type = 1
267  fSumOfWeightsS += weight;
268  event_type = 1;
269  }
270  else { // background type = 2
271  fSumOfWeightsB += weight;
272  event_type = 2;
273  }
274 
275  //
276  // Create event and add classification variables, weight, type and regression variables
277  //
278  kNN::Event event_knn(vvec, weight, event_type);
279  event_knn.SetTargets(evt_->GetTargets());
280  fEvent.push_back(event_knn);
281 
282  }
283  Log() << kINFO
284  << "Number of signal events " << fSumOfWeightsS << Endl
285  << "Number of background events " << fSumOfWeightsB << Endl;
286 
287  // create kd-tree (binary tree) structure
288  MakeKNN();
289 
291 }
292 
293 ////////////////////////////////////////////////////////////////////////////////
294 /// Compute classifier response
295 
297 {
298  // cannot determine error
299  NoErrorCalc(err, errUpper);
300 
301  //
302  // Define local variables
303  //
304  const Event *ev = GetEvent();
305  const Int_t nvar = GetNVariables();
306  const Double_t weight = ev->GetWeight();
307  const UInt_t knn = static_cast<UInt_t>(fnkNN);
308 
309  kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);
310 
311  for (Int_t ivar = 0; ivar < nvar; ++ivar) {
312  vvec[ivar] = ev->GetValue(ivar);
313  }
314 
315  // search for fnkNN+2 nearest neighbors, pad with two
316  // events to avoid Monte-Carlo events with zero distance
317  // most of CPU time is spent in this recursive function
318  const kNN::Event event_knn(vvec, weight, 3);
319  fModule->Find(event_knn, knn + 2);
320 
321  const kNN::List &rlist = fModule->GetkNNList();
322  if (rlist.size() != knn + 2) {
323  Log() << kFATAL << "kNN result list is empty" << Endl;
324  return -100.0;
325  }
326 
327  if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);
328 
329  //
330  // Set flags for kernel option=Gaus, Poln
331  //
332  Bool_t use_gaus = false, use_poln = false;
333  if (fUseKernel) {
334  if (fKernel == "Gaus") use_gaus = true;
335  else if (fKernel == "Poln") use_poln = true;
336  }
337 
338  //
339  // Compute radius for polynomial kernel
340  //
341  Double_t kradius = -1.0;
342  if (use_poln) {
343  kradius = MethodKNN::getKernelRadius(rlist);
344 
345  if (!(kradius > 0.0)) {
346  Log() << kFATAL << "kNN radius is not positive" << Endl;
347  return -100.0;
348  }
349 
350  kradius = 1.0/TMath::Sqrt(kradius);
351  }
352 
353  //
354  // Compute RMS of variable differences for Gaussian sigma
355  //
356  std::vector<Double_t> rms_vec;
357  if (use_gaus) {
358  rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);
359 
360  if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
361  Log() << kFATAL << "Failed to compute RMS vector" << Endl;
362  return -100.0;
363  }
364  }
365 
366  UInt_t count_all = 0;
367  Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;
368 
369  for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
370 
371  // get reference to current node to make code more readable
372  const kNN::Node<kNN::Event> &node = *(lit->first);
373 
374  // Warn about Monte-Carlo event with zero distance
375  // this happens when this query event is also in learning sample
376  if (lit->second < 0.0) {
377  Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
378  }
379  else if (!(lit->second > 0.0)) {
380  Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
381  }
382 
383  // get event weight and scale weight by kernel function
384  Double_t evweight = node.GetWeight();
385  if (use_gaus) evweight *= MethodKNN::GausKernel(event_knn, node.GetEvent(), rms_vec);
386  else if (use_poln) evweight *= MethodKNN::PolnKernel(TMath::Sqrt(lit->second)*kradius);
387 
388  if (fUseWeight) weight_all += evweight;
389  else ++weight_all;
390 
391  if (node.GetEvent().GetType() == 1) { // signal type = 1
392  if (fUseWeight) weight_sig += evweight;
393  else ++weight_sig;
394  }
395  else if (node.GetEvent().GetType() == 2) { // background type = 2
396  if (fUseWeight) weight_bac += evweight;
397  else ++weight_bac;
398  }
399  else {
400  Log() << kFATAL << "Unknown type for training event" << Endl;
401  }
402 
403  // use only fnkNN events
404  ++count_all;
405 
406  if (count_all >= knn) {
407  break;
408  }
409  }
410 
411  // check that total number of events or total weight sum is positive
412  if (!(count_all > 0)) {
413  Log() << kFATAL << "Size kNN result list is not positive" << Endl;
414  return -100.0;
415  }
416 
417  // check that number of events matches number of k in knn
418  if (count_all < knn) {
419  Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
420  }
421 
422  // Check that total weight is positive
423  if (!(weight_all > 0.0)) {
424  Log() << kFATAL << "kNN result total weight is not positive" << Endl;
425  return -100.0;
426  }
427 
428  return weight_sig/weight_all;
429 }
430 
431 ////////////////////////////////////////////////////////////////////////////////
432 /// Return vector of averages for target values of k-nearest neighbors.
433 /// Use own copy of the regression vector, I do not like using a pointer to vector.
434 
435 const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
436 {
437  if( fRegressionReturnVal == 0 )
438  fRegressionReturnVal = new std::vector<Float_t>;
439  else
440  fRegressionReturnVal->clear();
441 
442  //
443  // Define local variables
444  //
445  const Event *evt = GetEvent();
446  const Int_t nvar = GetNVariables();
447  const UInt_t knn = static_cast<UInt_t>(fnkNN);
448  std::vector<float> reg_vec;
449 
450  kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);
451 
452  for (Int_t ivar = 0; ivar < nvar; ++ivar) {
453  vvec[ivar] = evt->GetValue(ivar);
454  }
455 
456  // search for fnkNN+2 nearest neighbors, pad with two
457  // events to avoid Monte-Carlo events with zero distance
458  // most of CPU time is spent in this recursive function
459  const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
460  fModule->Find(event_knn, knn + 2);
461 
462  const kNN::List &rlist = fModule->GetkNNList();
463  if (rlist.size() != knn + 2) {
464  Log() << kFATAL << "kNN result list is empty" << Endl;
465  return *fRegressionReturnVal;
466  }
467 
468  // compute regression values
469  Double_t weight_all = 0;
470  UInt_t count_all = 0;
471 
472  for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
473 
474  // get reference to current node to make code more readable
475  const kNN::Node<kNN::Event> &node = *(lit->first);
476  const kNN::VarVec &tvec = node.GetEvent().GetTargets();
477  const Double_t weight = node.GetEvent().GetWeight();
478 
479  if (reg_vec.empty()) {
480  reg_vec= kNN::VarVec(tvec.size(), 0.0);
481  }
482 
483  for(UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
484  if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
485  else reg_vec[ivar] += tvec[ivar];
486  }
487 
488  if (fUseWeight) weight_all += weight;
489  else ++weight_all;
490 
491  // use only fnkNN events
492  ++count_all;
493 
494  if (count_all == knn) {
495  break;
496  }
497  }
498 
499  // check that number of events matches number of k in knn
500  if (!(weight_all > 0.0)) {
501  Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
502  return *fRegressionReturnVal;
503  }
504 
505  for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
506  reg_vec[ivar] /= weight_all;
507  }
508 
509  // copy result
510  fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());
511 
512  return *fRegressionReturnVal;
513 }
514 
515 ////////////////////////////////////////////////////////////////////////////////
516 /// no ranking available
517 
519 {
520  return 0;
521 }
522 
523 ////////////////////////////////////////////////////////////////////////////////
524 /// write weights to XML
525 
526 void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
527  void* wght = gTools().AddChild(parent, "Weights");
528  gTools().AddAttr(wght,"NEvents",fEvent.size());
529  if (fEvent.size()>0) gTools().AddAttr(wght,"NVar",fEvent.begin()->GetNVar());
530  if (fEvent.size()>0) gTools().AddAttr(wght,"NTgt",fEvent.begin()->GetNTgt());
531 
532  for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
533 
534  std::stringstream s("");
535  s.precision( 16 );
536  for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
537  if (ivar>0) s << " ";
538  s << std::scientific << event->GetVar(ivar);
539  }
540 
541  for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
542  s << " " << std::scientific << event->GetTgt(itgt);
543  }
544 
545  void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
546  gTools().AddAttr(evt,"Type", event->GetType());
547  gTools().AddAttr(evt,"Weight", event->GetWeight());
548  }
549 }
550 
551 ////////////////////////////////////////////////////////////////////////////////
552 
553 void TMVA::MethodKNN::ReadWeightsFromXML( void* wghtnode ) {
554  void* ch = gTools().GetChild(wghtnode); // first event
555  UInt_t nvar = 0, ntgt = 0;
556  gTools().ReadAttr( wghtnode, "NVar", nvar );
557  gTools().ReadAttr( wghtnode, "NTgt", ntgt );
558 
559 
560  Short_t evtType(0);
561  Double_t evtWeight(0);
562 
563  while (ch) {
564  // build event
565  kNN::VarVec vvec(nvar, 0);
566  kNN::VarVec tvec(ntgt, 0);
567 
568  gTools().ReadAttr( ch, "Type", evtType );
569  gTools().ReadAttr( ch, "Weight", evtWeight );
570  std::stringstream s( gTools().GetContent(ch) );
571 
572  for(UInt_t ivar=0; ivar<nvar; ivar++)
573  s >> vvec[ivar];
574 
575  for(UInt_t itgt=0; itgt<ntgt; itgt++)
576  s >> tvec[itgt];
577 
578  ch = gTools().GetNextChild(ch);
579 
580  kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
581  fEvent.push_back(event_knn);
582  }
583 
584  // create kd-tree (binary tree) structure
585  MakeKNN();
586 }
587 
588 ////////////////////////////////////////////////////////////////////////////////
589 /// read the weights
590 
592 {
593  Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;
594 
595  if (!fEvent.empty()) {
596  Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
597  fEvent.clear();
598  }
599 
600  UInt_t nvar = 0;
601 
602  while (!is.eof()) {
603  std::string line;
604  std::getline(is, line);
605 
606  if (line.empty() || line.find("#") != std::string::npos) {
607  continue;
608  }
609 
610  UInt_t count = 0;
611  std::string::size_type pos=0;
612  while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }
613 
614  if (nvar == 0) {
615  nvar = count - 2;
616  }
617  if (count < 3 || nvar != count - 2) {
618  Log() << kFATAL << "Missing comma delimeter(s)" << Endl;
619  }
620 
621  // Int_t ievent = -1;
622  Int_t type = -1;
623  Double_t weight = -1.0;
624 
625  kNN::VarVec vvec(nvar, 0.0);
626 
627  UInt_t vcount = 0;
628  std::string::size_type prev = 0;
629 
630  for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
631  if (line[ipos] != ',' && ipos + 1 != line.size()) {
632  continue;
633  }
634 
635  if (!(ipos > prev)) {
636  Log() << kFATAL << "Wrong substring limits" << Endl;
637  }
638 
639  std::string vstring = line.substr(prev, ipos - prev);
640  if (ipos + 1 == line.size()) {
641  vstring = line.substr(prev, ipos - prev + 1);
642  }
643 
644  if (vstring.empty()) {
645  Log() << kFATAL << "Failed to parse string" << Endl;
646  }
647 
648  if (vcount == 0) {
649  // ievent = std::atoi(vstring.c_str());
650  }
651  else if (vcount == 1) {
652  type = std::atoi(vstring.c_str());
653  }
654  else if (vcount == 2) {
655  weight = std::atof(vstring.c_str());
656  }
657  else if (vcount - 3 < vvec.size()) {
658  vvec[vcount - 3] = std::atof(vstring.c_str());
659  }
660  else {
661  Log() << kFATAL << "Wrong variable count" << Endl;
662  }
663 
664  prev = ipos + 1;
665  ++vcount;
666  }
667 
668  fEvent.push_back(kNN::Event(vvec, weight, type));
669  }
670 
671  Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;
672 
673  // create kd-tree (binary tree) structure
674  MakeKNN();
675 }
676 
677 ////////////////////////////////////////////////////////////////////////////////
678 /// save weights to ROOT file
679 
681 {
682  Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;
683 
684  if (fEvent.empty()) {
685  Log() << kWARNING << "MethodKNN contains no events " << Endl;
686  return;
687  }
688 
689  kNN::Event *event = new kNN::Event();
690  TTree *tree = new TTree("knn", "event tree");
691  tree->SetDirectory(0);
692  tree->Branch("event", "TMVA::kNN::Event", &event);
693 
694  Double_t size = 0.0;
695  for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
696  (*event) = (*it);
697  size += tree->Fill();
698  }
699 
700  // !!! hard coded tree name !!!
701  rf.WriteTObject(tree, "knn", "Overwrite");
702 
703  // scale to MegaBytes
704  size /= 1048576.0;
705 
706  Log() << kINFO << "Wrote " << size << "MB and " << fEvent.size()
707  << " events to ROOT file" << Endl;
708 
709  delete tree;
710  delete event;
711 }
712 
713 ////////////////////////////////////////////////////////////////////////////////
714 /// read weights from ROOT file
715 
717 {
718  Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;
719 
720  if (!fEvent.empty()) {
721  Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
722  fEvent.clear();
723  }
724 
725  // !!! hard coded tree name !!!
726  TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
727  if (!tree) {
728  Log() << kFATAL << "Failed to find knn tree" << Endl;
729  return;
730  }
731 
732  kNN::Event *event = new kNN::Event();
733  tree->SetBranchAddress("event", &event);
734 
735  const Int_t nevent = tree->GetEntries();
736 
737  Double_t size = 0.0;
738  for (Int_t i = 0; i < nevent; ++i) {
739  size += tree->GetEntry(i);
740  fEvent.push_back(*event);
741  }
742 
743  // scale to MegaBytes
744  size /= 1048576.0;
745 
746  Log() << kINFO << "Read " << size << "MB and " << fEvent.size()
747  << " events from ROOT file" << Endl;
748 
749  delete event;
750 
751  // create kd-tree (binary tree) structure
752  MakeKNN();
753 }
754 
755 ////////////////////////////////////////////////////////////////////////////////
756 /// write specific classifier response
757 
758 void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
759 {
760  fout << " // not implemented for class: \"" << className << "\"" << std::endl;
761  fout << "};" << std::endl;
762 }
763 
764 ////////////////////////////////////////////////////////////////////////////////
765 /// get help message text
766 ///
767 /// typical length of text line:
768 /// "|--------------------------------------------------------------|"
769 
771 {
772  Log() << Endl;
773  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
774  Log() << Endl;
775  Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
776  << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
777  << "training events for which a classification category/regression target is known. " << Endl
778  << "The k-NN method compares a test event to all training events using a distance " << Endl
779  << "function, which is an Euclidean distance in a space defined by the input variables. "<< Endl
780  << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
781  << "quick search for the k events with shortest distance to the test event. The method" << Endl
782  << "returns a fraction of signal events among the k neighbors. It is recommended" << Endl
783  << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
784  << "between 0 and 1." << Endl;
785 
786  Log() << Endl;
787  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
788  << gTools().Color("reset") << Endl;
789  Log() << Endl;
790  Log() << "The k-NN method estimates a density of signal and background events in a "<< Endl
791  << "neighborhood around the test event. The method assumes that the density of the " << Endl
792  << "signal and background events is uniform and constant within the neighborhood. " << Endl
793  << "k is an adjustable parameter and it determines an average size of the " << Endl
794  << "neighborhood. Small k values (less than 10) are sensitive to statistical " << Endl
795  << "fluctuations and large (greater than 100) values might not sufficiently capture " << Endl
796  << "local differences between events in the training set. The speed of the k-NN" << Endl
797  << "method also increases with larger values of k. " << Endl;
798  Log() << Endl;
799  Log() << "The k-NN method assigns equal weight to all input variables. Different scales " << Endl
800  << "among the input variables is compensated using ScaleFrac parameter: the input " << Endl
801  << "variables are scaled so that the widths for central ScaleFrac*100% events are " << Endl
802  << "equal among all the input variables." << Endl;
803 
804  Log() << Endl;
805  Log() << gTools().Color("bold") << "--- Additional configuration options: "
806  << gTools().Color("reset") << Endl;
807  Log() << Endl;
808  Log() << "The method inclues an option to use a Gaussian kernel to smooth out the k-NN" << Endl
809  << "response. The kernel re-weights events using a distance to the test event." << Endl;
810 }
811 
812 ////////////////////////////////////////////////////////////////////////////////
813 /// polynomial kernel
814 
816 {
817  const Double_t avalue = TMath::Abs(value);
818 
819  if (!(avalue < 1.0)) {
820  return 0.0;
821  }
822 
823  const Double_t prod = 1.0 - avalue * avalue * avalue;
824 
825  return (prod * prod * prod);
826 }
827 
828 ////////////////////////////////////////////////////////////////////////////////
829 /// Gaussian kernel
830 
832  const kNN::Event &event, const std::vector<Double_t> &svec) const
833 {
834  if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
835  Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
836  return 0.0;
837  }
838 
839  //
840  // compute exponent
841  //
842  double sum_exp = 0.0;
843 
844  for(unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {
845 
846  const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
847  const Double_t sigm_ = svec[ivar];
848  if (!(sigm_ > 0.0)) {
849  Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
850  return 0.0;
851  }
852 
853  sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
854  }
855 
856  //
857  // Return unnormalized(!) Gaussian function, because normalization
858  // cancels for the ratio of weights.
859  //
860 
861  return std::exp(-sum_exp);
862 }
863 
864 ////////////////////////////////////////////////////////////////////////////////
865 ///
866 /// Get polynomial kernel radius
867 ///
868 
869 Double_t TMVA::MethodKNN::getKernelRadius(const kNN::List &rlist) const
870 {
871  Double_t kradius = -1.0;
872  UInt_t kcount = 0;
873  const UInt_t knn = static_cast<UInt_t>(fnkNN);
874 
875  for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
876  {
877  if (!(lit->second > 0.0)) continue;
878 
879  if (kradius < lit->second || kradius < 0.0) kradius = lit->second;
880 
881  ++kcount;
882  if (kcount >= knn) break;
883  }
884 
885  return kradius;
886 }
887 
888 ////////////////////////////////////////////////////////////////////////////////
889 ///
890 /// Get polynomial kernel radius
891 ///
892 
893 const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
894 {
895  std::vector<Double_t> rvec;
896  UInt_t kcount = 0;
897  const UInt_t knn = static_cast<UInt_t>(fnkNN);
898 
899  for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
900  {
901  if (!(lit->second > 0.0)) continue;
902 
903  const kNN::Node<kNN::Event> *node_ = lit -> first;
904  const kNN::Event &event_ = node_-> GetEvent();
905 
906  if (rvec.empty()) {
907  rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
908  }
909  else if (rvec.size() != event_.GetNVar()) {
910  Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
911  rvec.clear();
912  return rvec;
913  }
914 
915  for(unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
916  const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
917  rvec[ivar] += diff_*diff_;
918  }
919 
920  ++kcount;
921  if (kcount >= knn) break;
922  }
923 
924  if (kcount < 1) {
925  Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
926  rvec.clear();
927  return rvec;
928  }
929 
930  for(unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
931  if (!(rvec[ivar] > 0.0)) {
932  Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
933  rvec.clear();
934  return rvec;
935  }
936 
937  rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
938  }
939 
940  return rvec;
941 }
942 
943 ////////////////////////////////////////////////////////////////////////////////
944 
945 Double_t TMVA::MethodKNN::getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
946 {
947  LDAEvents sig_vec, bac_vec;
948 
949  for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
950 
951  // get reference to current node to make code more readable
952  const kNN::Node<kNN::Event> &node = *(lit->first);
953  const kNN::VarVec &tvec = node.GetEvent().GetVars();
954 
955  if (node.GetEvent().GetType() == 1) { // signal type = 1
956  sig_vec.push_back(tvec);
957  }
958  else if (node.GetEvent().GetType() == 2) { // background type = 2
959  bac_vec.push_back(tvec);
960  }
961  else {
962  Log() << kFATAL << "Unknown type for training event" << Endl;
963  }
964  }
965 
966  fLDA.Initialize(sig_vec, bac_vec);
967 
968  return fLDA.GetProb(event_knn.GetVars(), 1);
969 }
void ProcessOptions()
process the options specified by the user
Definition: MethodKNN.cxx:148
void AddWeightsXMLTo(void *parent) const
write weights to XML
Definition: MethodKNN.cxx:526
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Singleton class for Global types used by TMVA.
Definition: Types.h:73
Bool_t fUseLDA
Definition: MethodKNN.h:135
void MakeClassSpecific(std::ostream &, const TString &) const
write specific classifier response
Definition: MethodKNN.cxx:758
TLine * line
void DeclareOptions()
MethodKNN options.
Definition: MethodKNN.cxx:124
virtual Int_t Fill()
Fill all branches.
Definition: TTree.cxx:4364
const List & GetkNNList() const
Definition: ModulekNN.h:204
MsgLogger & Log() const
Definition: Configurable.h:122
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:46
EAnalysisType
Definition: Types.h:125
void Train(void)
kNN training
Definition: MethodKNN.cxx:234
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
This file contains binary tree and global function template that searches tree for k-nearest neighbors...
Definition: NodekNN.h:66
void MakeKNN(void)
create kNN
Definition: MethodKNN.cxx:203
virtual Int_t GetEntry(Long64_t entry=0, Int_t getall=0)
Read all branches of entry and return total number of bytes read.
Definition: TTree.cxx:5330
Basic string class.
Definition: TString.h:125
Ranking for variables in method (implementation)
Definition: Ranking.h:48
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
Bool_t Find(Event event, UInt_t nfind=100, const std::string &option="count") const
find in tree: if the tree has been filled, search for the nfind closest events; if the metric (fVarScale map) is...
Definition: ModulekNN.cxx:348
VarType GetVar(UInt_t i) const
Definition: ModulekNN.h:179
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:308
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1135
Short_t Abs(Short_t d)
Definition: TMathBase.h:108
Bool_t fUseKernel
Definition: MethodKNN.h:133
void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility
Definition: MethodKNN.cxx:140
virtual Int_t WriteTObject(const TObject *obj, const char *name=0, Option_t *option="", Int_t bufsize=0)
Write object obj to this directory.
void ReadWeightsFromStream(std::istream &istr)
read the weights
Definition: MethodKNN.cxx:591
const std::vector< Double_t > getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
Get polynomial kernel radius.
Definition: MethodKNN.cxx:893
double sqrt(double)
void Init(void)
Initialization.
Definition: MethodKNN.cxx:190
virtual Int_t SetBranchAddress(const char *bname, void *add, TBranch **ptr=0)
Change branch address, dealing with clone trees properly.
Definition: TTree.cxx:7898
TString fKernel
Definition: MethodKNN.h:130
Double_t GetWeight() const
Definition: NodekNN.h:180
Bool_t fUseWeight
Definition: MethodKNN.h:134
kNN::EventVec fEvent
Definition: MethodKNN.h:137
const Event * GetEvent() const
Definition: MethodBase.h:738
Float_t GetProb(const std::vector< Float_t > &x, Int_t k)
Signal probability with Gaussian approximation.
Definition: LDA.cxx:239
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1161
Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector< Double_t > &svec) const
Gaussian kernel.
Definition: MethodKNN.cxx:831
DataSetInfo & DataInfo() const
Definition: MethodBase.h:399
Class that contains all the data information.
Definition: DataSetInfo.h:60
LDA fLDA
(untouched) events used for learning
Definition: MethodKNN.h:139
void SetTargets(const VarVec &tvec)
Definition: ModulekNN.cxx:107
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not...
Definition: Event.cxx:382
Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option="")
fill the tree
Definition: ModulekNN.cxx:245
static constexpr double second
Double_t PolnKernel(Double_t value) const
polynomial kernel
Definition: MethodKNN.cxx:815
void WriteWeightsToStream(TFile &rf) const
save weights to ROOT file
Definition: MethodKNN.cxx:680
std::vector< Float_t > & GetTargets()
Definition: Event.h:98
UInt_t GetNEvents() const
temporary event when testing on a different DataSet than the own one
Definition: MethodBase.h:406
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
Compute classifier response.
Definition: MethodKNN.cxx:296
Float_t fSigmaFact
Definition: MethodKNN.h:128
UInt_t GetNVar() const
Definition: ModulekNN.h:188
virtual ~MethodKNN(void)
destructor
Definition: MethodKNN.cxx:106
Int_t fTreeOptDepth
Experimental feature for local knn analysis.
Definition: MethodKNN.h:142
unsigned int UInt_t
Definition: RtypesCore.h:42
Double_t fSumOfWeightsS
Definition: MethodKNN.h:119
short Short_t
Definition: RtypesCore.h:35
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:290
void Initialize(const LDAEvents &inputSignal, const LDAEvents &inputBackground)
Create LDA matrix using local events found by knn method.
Definition: LDA.cxx:68
Tools & gTools()
void Add(const Event &event)
add an event to tree
Definition: ModulekNN.cxx:212
const Ranking * CreateRanking()
no ranking available
Definition: MethodKNN.cxx:518
void GetHelpMessage() const
get help message text
Definition: MethodKNN.cxx:770
UInt_t GetNVariables() const
Definition: MethodBase.h:334
const Bool_t kFALSE
Definition: RtypesCore.h:88
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:237
Bool_t IgnoreEventsWithNegWeightsInTraining() const
Definition: MethodBase.h:673
void ReadWeightsFromXML(void *wghtnode)
Definition: MethodKNN.cxx:553
virtual void SetDirectory(TDirectory *dir)
Change the tree's directory.
Definition: TTree.cxx:8464
#define ClassImp(name)
Definition: Rtypes.h:359
Int_t fBalanceDepth
Definition: MethodKNN.h:125
double Double_t
Definition: RtypesCore.h:55
Analysis of k-nearest neighbor.
Definition: MethodKNN.h:54
Double_t fSumOfWeightsB
Definition: MethodKNN.h:120
Bool_t IsNormalised() const
Definition: MethodBase.h:483
int type
Definition: TGX11.cxx:120
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1173
static constexpr double s
void ExitFromTraining()
Definition: MethodBase.h:451
virtual Long64_t GetEntries() const
Definition: TTree.h:382
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
KNN can handle classification with 2 classes and regression with one regression-target.
Definition: MethodKNN.cxx:180
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:839
const VarVec & GetVars() const
Definition: ModulekNN.cxx:121
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1701
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:601
void Clear()
clean up
Definition: ModulekNN.cxx:194
std::vector< std::vector< Float_t > > LDAEvents
Definition: LDA.h:38
Double_t getKernelRadius(const kNN::List &rlist) const
Get polynomial kernel radius.
Definition: MethodKNN.cxx:869
Definition: tree.py:1
Float_t fScaleFrac
Definition: MethodKNN.h:127
std::vector< Float_t > * fRegressionReturnVal
Definition: MethodBase.h:584
A TTree object has a header with a name and a title.
Definition: TTree.h:70
kNN::ModulekNN * fModule
Definition: MethodKNN.h:122
Definition: first.py:1
Double_t Sqrt(Double_t x)
Definition: TMath.h:590
double exp(double)
const std::vector< Float_t > & GetRegressionValues()
Return vector of averages for target values of k-nearest neighbors.
Definition: MethodKNN.cxx:435
const Bool_t kTRUE
Definition: RtypesCore.h:87
MethodKNN(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="KNN")
standard constructor
Definition: MethodKNN.cxx:62
double getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
Definition: MethodKNN.cxx:945
Int_t fnkNN
module where all work is done
Definition: MethodKNN.h:124
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:829
const T & GetEvent() const
Definition: NodekNN.h:156