ROOT 6.08/07 Reference Guide
MethodKNN.cxx
// @(#)root/tmva $Id$
// Author: Rustem Ospanov

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodKNN                                                             *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation                                                            *
 *                                                                                *
 * Author:                                                                        *
 *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
 *                                                                                *
 * Copyright (c) 2007:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Texas at Austin, USA                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodKNN                                                            //
//                                                                      //
// Analysis of k-nearest neighbor                                       //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#include "TMVA/MethodKNN.h"

#include "TMVA/ClassifierFactory.h"
#include "TMVA/Configurable.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/Event.h"
#include "TMVA/LDA.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodBase.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Ranking.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"

#include "TFile.h"
#include "TMath.h"
#include "TTree.h"

#include <cmath>
#include <sstream>   // std::stringstream, used by the XML weight I/O below
#include <string>
#include <cstdlib>

REGISTER_METHOD(KNN)

ClassImp(TMVA::MethodKNN)

////////////////////////////////////////////////////////////////////////////////
/// standard constructor

TMVA::MethodKNN::MethodKNN( const TString& jobName,
                            const TString& methodTitle,
                            DataSetInfo& theData,
                            const TString& theOption )
   : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// constructor from weight file

TMVA::MethodKNN::MethodKNN( DataSetInfo& theData,
                            const TString& theWeightFile)
   : TMVA::MethodBase( Types::kKNN, theData, theWeightFile)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::MethodKNN::~MethodKNN( void )
{
   if (fModule) delete fModule;
}

////////////////////////////////////////////////////////////////////////////////
/// MethodKNN options

void TMVA::MethodKNN::DeclareOptions()
{
   // fnkNN         = 20;    // number of k-nearest neighbors
   // fBalanceDepth = 6;     // number of binary tree levels used for tree balancing
   // fScaleFrac    = 0.8;   // fraction of events used to compute variable width
   // fSigmaFact    = 1.0;   // scale factor for Gaussian sigma
   // fKernel       = use polynomial (1-x^3)^3 or Gaussian kernel
   // fTrim         = false; // use equal number of signal and background events
   // fUseKernel    = false; // use polynomial kernel weight function
   // fUseWeight    = true;  // count events using weights
   // fUseLDA       = false

   DeclareOptionRef(fnkNN = 20, "nkNN", "Number of k-nearest neighbors");
   DeclareOptionRef(fBalanceDepth = 6, "BalanceDepth", "Binary tree balance depth");
   DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used to compute variable width");
   DeclareOptionRef(fSigmaFact = 1.0, "SigmaFact", "Scale factor for sigma in Gaussian kernel");
   DeclareOptionRef(fKernel = "Gaus", "Kernel", "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
   DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
   DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
   DeclareOptionRef(fUseWeight = kTRUE, "UseWeight", "Use weight to count kNN events");
   DeclareOptionRef(fUseLDA = kFALSE, "UseLDA", "Use local linear discriminant - experimental feature");
}
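
// Usage sketch: these options are typically set via the configuration string
// when booking the method with a TMVA::Factory; the names `factory` and
// `dataloader` below are hypothetical.
//
//    factory->BookMethod(dataloader, TMVA::Types::kKNN, "KNN",
//       "nkNN=20:ScaleFrac=0.8:SigmaFact=1.0:Kernel=Gaus:UseKernel=F:UseWeight=T:Trim=F");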

////////////////////////////////////////////////////////////////////////////////
/// options that are used ONLY for the READER to ensure backward compatibility

void TMVA::MethodKNN::DeclareCompatibilityOptions() {
   MethodBase::DeclareCompatibilityOptions();
   DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
}

////////////////////////////////////////////////////////////////////////////////
/// process the options specified by the user

void TMVA::MethodKNN::ProcessOptions()
{
   if (!(fnkNN > 0)) {
      fnkNN = 10;
      Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
   }
   if (fScaleFrac < 0.0) {
      fScaleFrac = 0.0;
      Log() << kWARNING << "ScaleFrac can not be negative: set ScaleFrac = " << fScaleFrac << Endl;
   }
   if (fScaleFrac > 1.0) {
      fScaleFrac = 1.0;
   }
   if (!(fBalanceDepth > 0)) {
      fBalanceDepth = 6;
      Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
   }

   Log() << kVERBOSE
         << "kNN options: \n"
         << "  kNN = " << fnkNN << "\n"
         << "  UseKernel = " << fUseKernel << "\n"
         << "  SigmaFact = " << fSigmaFact << "\n"
         << "  ScaleFrac = " << fScaleFrac << "\n"
         << "  Kernel = " << fKernel << "\n"
         << "  Trim = " << fTrim << "\n"
         << "  Optimize = " << fBalanceDepth << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// kNN can handle classification with 2 classes and regression with one regression-target

Bool_t TMVA::MethodKNN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   if (type == Types::kRegression) return kTRUE;
   return kFALSE;
}

////////////////////////////////////////////////////////////////////////////////
/// Initialization

void TMVA::MethodKNN::Init( void )
{
   // fScaleFrac <= 0.0 then do not scale input variables
   // fScaleFrac >= 1.0 then use all event coordinates to scale input variables

   fModule = new kNN::ModulekNN();
   fSumOfWeightsS = 0;
   fSumOfWeightsB = 0;
}

////////////////////////////////////////////////////////////////////////////////
/// create kNN

void TMVA::MethodKNN::MakeKNN( void )
{
   if (!fModule) {
      Log() << kFATAL << "ModulekNN is not created" << Endl;
   }

   fModule->Clear();

   std::string option;
   if (fScaleFrac > 0.0) {
      option += "metric";
   }
   if (fTrim) {
      option += "trim";
   }

   Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      fModule->Add(*event);
   }

   // create binary tree
   fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
                 static_cast<UInt_t>(100.0*fScaleFrac),
                 option);
}

////////////////////////////////////////////////////////////////////////////////
/// kNN training

void TMVA::MethodKNN::Train()
{
   Log() << kHEADER << "<Train> start..." << Endl;

   if (IsNormalised()) {
      Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
      fScaleFrac = 0.0;
   }

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }
   if (GetNVariables() < 1)
      Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;

   Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;

   for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
      // read the training event
      const Event* evt_ = GetEvent(ievt);
      Double_t weight = evt_->GetWeight();

      // in case events with negative weights are to be ignored
      if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;

      kNN::VarVec vvec(GetNVariables(), 0.0);
      for (UInt_t ivar = 0; ivar < evt_->GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);

      Short_t event_type = 0;

      if (DataInfo().IsSignal(evt_)) { // signal type = 1
         fSumOfWeightsS += weight;
         event_type = 1;
      }
      else { // background type = 2
         fSumOfWeightsB += weight;
         event_type = 2;
      }

      //
      // Create event and add classification variables, weight, type and regression variables
      //
      kNN::Event event_knn(vvec, weight, event_type);
      event_knn.SetTargets(evt_->GetTargets());
      fEvent.push_back(event_knn);
   }

   Log() << kINFO
         << "Number of signal events " << fSumOfWeightsS << Endl
         << "Number of background events " << fSumOfWeightsB << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();

   ExitFromTraining();
}

////////////////////////////////////////////////////////////////////////////////
/// Compute classifier response

Double_t TMVA::MethodKNN::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // cannot determine error
   NoErrorCalc(err, errUpper);

   //
   // Define local variables
   //
   const Event *ev = GetEvent();
   const Int_t nvar = GetNVariables();
   const Double_t weight = ev->GetWeight();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = ev->GetValue(ivar);
   }

   // search for fnkNN+2 nearest neighbors, pad with two
   // events to avoid Monte-Carlo events with zero distance
   // most of CPU time is spent in this recursive function
   const kNN::Event event_knn(vvec, weight, 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
      return -100.0;
   }

   if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);

   //
   // Set flags for kernel option=Gaus, Poln
   //
   Bool_t use_gaus = false, use_poln = false;
   if (fUseKernel) {
      if (fKernel == "Gaus") use_gaus = true;
      else if (fKernel == "Poln") use_poln = true;
   }

   //
   // Compute radius for polynomial kernel
   //
   Double_t kradius = -1.0;
   if (use_poln) {
      kradius = MethodKNN::getKernelRadius(rlist);

      if (!(kradius > 0.0)) {
         Log() << kFATAL << "kNN radius is not positive" << Endl;
         return -100.0;
      }

      kradius = 1.0/TMath::Sqrt(kradius);
   }

   //
   // Compute RMS of variable differences for Gaussian sigma
   //
   std::vector<Double_t> rms_vec;
   if (use_gaus) {
      rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);

      if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
         Log() << kFATAL << "Failed to compute RMS vector" << Endl;
         return -100.0;
      }
   }

   UInt_t count_all = 0;
   Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);

      // Warn about Monte-Carlo event with zero distance
      // this happens when this query event is also in learning sample
      if (lit->second < 0.0) {
         Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
      }
      else if (!(lit->second > 0.0)) {
         Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
      }

      // get event weight and scale weight by kernel function
      Double_t evweight = node.GetWeight();
      if (use_gaus) evweight *= MethodKNN::GausKernel(event_knn, node.GetEvent(), rms_vec);
      else if (use_poln) evweight *= MethodKNN::PolnKernel(TMath::Sqrt(lit->second)*kradius);

      if (fUseWeight) weight_all += evweight;
      else ++weight_all;

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         if (fUseWeight) weight_sig += evweight;
         else ++weight_sig;
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         if (fUseWeight) weight_bac += evweight;
         else ++weight_bac;
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }

      // use only fnkNN events
      ++count_all;

      if (count_all >= knn) {
         break;
      }
   }

   // check that total number of events or total weight sum is positive
   if (!(count_all > 0)) {
      Log() << kFATAL << "Size kNN result list is not positive" << Endl;
      return -100.0;
   }

   // check that number of events matches number of k in knn
   if (count_all < knn) {
      Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
   }

   // Check that total weight is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "kNN result total weight is not positive" << Endl;
      return -100.0;
   }

   return weight_sig/weight_all;
}
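
// Note: with UseKernel=F and UseWeight=F the value returned above reduces to
// the plain signal fraction among the k nearest neighbors; for example, k=20
// with 14 signal neighbors gives weight_sig/weight_all = 14/20 = 0.7.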

////////////////////////////////////////////////////////////////////////////////
///
/// Return vector of averages for target values of k-nearest neighbors.
/// Uses its own copy of the regression vector rather than a pointer to a vector.
///

const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
{
   if (fRegressionReturnVal == 0)
      fRegressionReturnVal = new std::vector<Float_t>;
   else
      fRegressionReturnVal->clear();

   //
   // Define local variables
   //
   const Event *evt = GetEvent();
   const Int_t nvar = GetNVariables();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);
   std::vector<float> reg_vec;

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = evt->GetValue(ivar);
   }

   // search for fnkNN+2 nearest neighbors, pad with two
   // events to avoid Monte-Carlo events with zero distance
   // most of CPU time is spent in this recursive function
   const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
      return *fRegressionReturnVal;
   }

   // compute regression values
   Double_t weight_all = 0;
   UInt_t count_all = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetTargets();
      const Double_t weight = node.GetEvent().GetWeight();

      if (reg_vec.empty()) {
         reg_vec = kNN::VarVec(tvec.size(), 0.0);
      }

      for (UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
         if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
         else reg_vec[ivar] += tvec[ivar];
      }

      if (fUseWeight) weight_all += weight;
      else ++weight_all;

      // use only fnkNN events
      ++count_all;

      if (count_all == knn) {
         break;
      }
   }

   // check that total weight sum is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
      return *fRegressionReturnVal;
   }

   for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
      reg_vec[ivar] /= weight_all;
   }

   // copy result
   fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());

   return *fRegressionReturnVal;
}

////////////////////////////////////////////////////////////////////////////////
/// no ranking available

const TMVA::Ranking* TMVA::MethodKNN::CreateRanking()
{
   return 0;
}

////////////////////////////////////////////////////////////////////////////////
/// write weights to XML

void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
   void* wght = gTools().AddChild(parent, "Weights");
   gTools().AddAttr(wght,"NEvents",fEvent.size());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NVar",fEvent.begin()->GetNVar());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NTgt",fEvent.begin()->GetNTgt());

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {

      std::stringstream s("");
      s.precision( 16 );
      for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
         if (ivar>0) s << " ";
         s << std::scientific << event->GetVar(ivar);
      }

      for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
         s << " " << std::scientific << event->GetTgt(itgt);
      }

      void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
      gTools().AddAttr(evt,"Type", event->GetType());
      gTools().AddAttr(evt,"Weight", event->GetWeight());
   }
}
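
// Schematic of the XML fragment produced above (values elided):
//
//    <Weights NEvents="..." NVar="..." NTgt="...">
//       <Event Type="1" Weight="...">var0 var1 ... tgt0 ...</Event>
//       ...
//    </Weights>
//
// ReadWeightsFromXML below parses exactly this structure.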

////////////////////////////////////////////////////////////////////////////////
/// read the weights from XML

void TMVA::MethodKNN::ReadWeightsFromXML( void* wghtnode ) {
   void* ch = gTools().GetChild(wghtnode); // first event
   UInt_t nvar = 0, ntgt = 0;
   gTools().ReadAttr( wghtnode, "NVar", nvar );
   gTools().ReadAttr( wghtnode, "NTgt", ntgt );

   Short_t evtType(0);
   Double_t evtWeight(0);

   while (ch) {
      // build event
      kNN::VarVec vvec(nvar, 0);
      kNN::VarVec tvec(ntgt, 0);

      gTools().ReadAttr( ch, "Type", evtType );
      gTools().ReadAttr( ch, "Weight", evtWeight );
      std::stringstream s( gTools().GetContent(ch) );

      for (UInt_t ivar = 0; ivar < nvar; ivar++)
         s >> vvec[ivar];

      for (UInt_t itgt = 0; itgt < ntgt; itgt++)
         s >> tvec[itgt];

      ch = gTools().GetNextChild(ch);

      kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
      fEvent.push_back(event_knn);
   }

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// read the weights

void TMVA::MethodKNN::ReadWeightsFromStream(std::istream& is)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;
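
   // Expected input: one event per comma-separated line of the form
   //    ievent,type,weight,var0,var1,...
   // lines that are empty or contain '#' are skipped.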

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   UInt_t nvar = 0;

   while (!is.eof()) {
      std::string line;
      std::getline(is, line);

      if (line.empty() || line.find("#") != std::string::npos) {
         continue;
      }

      UInt_t count = 0;
      std::string::size_type pos = 0;
      while ((pos = line.find(',', pos)) != std::string::npos) { count++; pos++; }

      if (nvar == 0) {
         nvar = count - 2;
      }
      if (count < 3 || nvar != count - 2) {
         Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
      }

      // Int_t ievent = -1;
      Int_t type = -1;
      Double_t weight = -1.0;

      kNN::VarVec vvec(nvar, 0.0);

      UInt_t vcount = 0;
      std::string::size_type prev = 0;

      for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
         if (line[ipos] != ',' && ipos + 1 != line.size()) {
            continue;
         }

         if (!(ipos > prev)) {
            Log() << kFATAL << "Wrong substring limits" << Endl;
         }

         std::string vstring = line.substr(prev, ipos - prev);
         if (ipos + 1 == line.size()) {
            vstring = line.substr(prev, ipos - prev + 1);
         }

         if (vstring.empty()) {
            Log() << kFATAL << "Failed to parse string" << Endl;
         }

         if (vcount == 0) {
            // ievent = std::atoi(vstring.c_str());
         }
         else if (vcount == 1) {
            type = std::atoi(vstring.c_str());
         }
         else if (vcount == 2) {
            weight = std::atof(vstring.c_str());
         }
         else if (vcount - 3 < vvec.size()) {
            vvec[vcount - 3] = std::atof(vstring.c_str());
         }
         else {
            Log() << kFATAL << "Wrong variable count" << Endl;
         }

         prev = ipos + 1;
         ++vcount;
      }

      fEvent.push_back(kNN::Event(vvec, weight, type));
   }

   Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// save weights to ROOT file

void TMVA::MethodKNN::WriteWeightsToStream(TFile &rf) const
{
   Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;

   if (fEvent.empty()) {
      Log() << kWARNING << "MethodKNN contains no events " << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   TTree *tree = new TTree("knn", "event tree");
   tree->SetDirectory(0);
   tree->Branch("event", "TMVA::kNN::Event", &event);

   Double_t size = 0.0;
   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
      (*event) = (*it);
      size += tree->Fill();
   }

   // !!! hard coded tree name !!!
   rf.WriteTObject(tree, "knn", "Overwrite");

   // scale to MegaBytes
   size /= 1048576.0;

   Log() << kINFO << "Wrote " << size << "MB and " << fEvent.size()
         << " events to ROOT file" << Endl;

   delete tree;
   delete event;
}

////////////////////////////////////////////////////////////////////////////////
/// read weights from ROOT file

void TMVA::MethodKNN::ReadWeightsFromStream(TFile &rf)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   // !!! hard coded tree name !!!
   TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
   if (!tree) {
      Log() << kFATAL << "Failed to find knn tree" << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   tree->SetBranchAddress("event", &event);

   const Int_t nevent = tree->GetEntries();

   Double_t size = 0.0;
   for (Int_t i = 0; i < nevent; ++i) {
      size += tree->GetEntry(i);
      fEvent.push_back(*event);
   }

   // scale to MegaBytes
   size /= 1048576.0;

   Log() << kINFO << "Read " << size << "MB and " << fEvent.size()
         << " events from ROOT file" << Endl;

   delete event;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// write specific classifier response

void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
}

////////////////////////////////////////////////////////////////////////////////
/// get help message text
///
/// typical length of text line:
/// "|--------------------------------------------------------------|"

void TMVA::MethodKNN::GetHelpMessage() const
{
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
         << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
         << "training events for which a classification category/regression target is known. " << Endl
         << "The k-NN method compares a test event to all training events using a distance " << Endl
         << "function, which is a Euclidean distance in a space defined by the input variables. "<< Endl
         << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
         << "quick search for the k events with shortest distance to the test event. The method" << Endl
         << "returns a fraction of signal events among the k neighbors. It is recommended" << Endl
         << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
         << "between 0 and 1." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-NN method estimates a density of signal and background events in a "<< Endl
         << "neighborhood around the test event. The method assumes that the density of the " << Endl
         << "signal and background events is uniform and constant within the neighborhood. " << Endl
         << "k is an adjustable parameter and it determines an average size of the " << Endl
         << "neighborhood. Small k values (less than 10) are sensitive to statistical " << Endl
         << "fluctuations and large (greater than 100) values might not sufficiently capture " << Endl
         << "local differences between events in the training set. The computation time of the " << Endl
         << "k-NN method also increases with larger values of k. " << Endl;
   Log() << Endl;
   Log() << "The k-NN method assigns equal weight to all input variables. Different scales " << Endl
         << "among the input variables are compensated for using the ScaleFrac parameter: the " << Endl
         << "input variables are scaled so that the widths of the central ScaleFrac*100% of " << Endl
         << "events are equal among all the input variables." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Additional configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
         << "response. The kernel re-weights events using a distance to the test event." << Endl;
}
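
// Usage sketch: applying a trained k-NN classifier with TMVA::Reader;
// the variable, weight-file and method names below are hypothetical.
//
//    TMVA::Reader reader("!Color:!Silent");
//    Float_t var1 = 0, var2 = 0;
//    reader.AddVariable("var1", &var1);
//    reader.AddVariable("var2", &var2);
//    reader.BookMVA("KNN", "dataset/weights/TMVAClassification_KNN.weights.xml");
//    // ... fill var1, var2 for the event under test, then:
//    Double_t mva = reader.EvaluateMVA("KNN");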

////////////////////////////////////////////////////////////////////////////////
/// polynomial kernel

Double_t TMVA::MethodKNN::PolnKernel(Double_t value) const
{
   const Double_t avalue = TMath::Abs(value);

   if (!(avalue < 1.0)) {
      return 0.0;
   }

   const Double_t prod = 1.0 - avalue * avalue * avalue;

   return (prod * prod * prod);
}
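
// This implements the kernel K(x) = (1 - |x|^3)^3 for |x| < 1 and 0 otherwise;
// for example, PolnKernel(0.5) = (1 - 0.125)^3 = 0.875^3 ≈ 0.670.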

////////////////////////////////////////////////////////////////////////////////
/// Gaussian kernel

Double_t TMVA::MethodKNN::GausKernel(const kNN::Event &event_knn,
                                     const kNN::Event &event, const std::vector<Double_t> &svec) const
{
   if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
      Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
      return 0.0;
   }

   //
   // compute exponent
   //
   double sum_exp = 0.0;

   for (unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {

      const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
      const Double_t sigm_ = svec[ivar];
      if (!(sigm_ > 0.0)) {
         Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
         return 0.0;
      }

      sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
   }

   //
   // Return unnormalized(!) Gaussian function, because normalization
   // cancels for the ratio of weights.
   //

   return std::exp(-sum_exp);
}
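
// The value returned above is the unnormalized Gaussian
// K = exp( -sum_i (x_i - y_i)^2 / (2 sigma_i^2) ), where the per-variable
// sigmas are the RMS values computed by getRMS() below.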

////////////////////////////////////////////////////////////////////////////////
///
/// Get polynomial kernel radius
///

Double_t TMVA::MethodKNN::getKernelRadius(const kNN::List &rlist) const
{
   Double_t kradius = -1.0;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
   {
      if (!(lit->second > 0.0)) continue;

      if (kradius < lit->second || kradius < 0.0) kradius = lit->second;

      ++kcount;
      if (kcount >= knn) break;
   }

   return kradius;
}

////////////////////////////////////////////////////////////////////////////////
///
/// Get RMS of the variable differences between the k-nearest neighbors and the
/// query event, used as the per-variable sigma of the Gaussian kernel
///

const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
{
   std::vector<Double_t> rvec;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
   {
      if (!(lit->second > 0.0)) continue;

      const kNN::Node<kNN::Event> *node_ = lit->first;
      const kNN::Event &event_ = node_->GetEvent();

      if (rvec.empty()) {
         rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
      }
      else if (rvec.size() != event_.GetNVar()) {
         Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
         rvec.clear();
         return rvec;
      }

      for (unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
         const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
         rvec[ivar] += diff_*diff_;
      }

      ++kcount;
      if (kcount >= knn) break;
   }

   if (kcount < 1) {
      Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
      rvec.clear();
      return rvec;
   }

   for (unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
      if (!(rvec[ivar] > 0.0)) {
         Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
         rvec.clear();
         return rvec;
      }

      rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
   }

   return rvec;
}
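
// For each input variable i the sigma returned above is
//    sigma_i = |SigmaFact| * sqrt( (1/k) * sum_k (x_{k,i} - q_i)^2 ),
// the RMS of the differences between the k nearest neighbors x_k and the
// query event q; these sigmas feed the Gaussian kernel in GausKernel().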

////////////////////////////////////////////////////////////////////////////////
/// compute the local LDA classification response using the k-nearest neighbor events

Double_t TMVA::MethodKNN::getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
{
   LDAEvents sig_vec, bac_vec;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetVars();

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         sig_vec.push_back(tvec);
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         bac_vec.push_back(tvec);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
   }

   fLDA.Initialize(sig_vec, bac_vec);

   return fLDA.GetProb(event_knn.GetVars(), 1);
}