MethodKNN.cxx
// @(#)root/tmva $Id$
// Author: Rustem Ospanov

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodKNN                                                             *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation                                                            *
 *                                                                                *
 * Author:                                                                        *
 *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
 *                                                                                *
 * Copyright (c) 2007:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Texas at Austin, USA                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodKNN                                                            //
//                                                                      //
// Analysis of k-nearest neighbor                                       //
//                                                                      //
//////////////////////////////////////////////////////////////////////////
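
// Illustrative usage (a minimal sketch, not part of this file; the exact
// Factory interface depends on the ROOT version in use): the method is booked
// through the TMVA Factory under the key registered below, with the options
// declared in DeclareOptions(), e.g.
//
//    // factory is an existing TMVA::Factory; the option values shown are the defaults
//    factory->BookMethod( TMVA::Types::kKNN, "KNN",
//                         "nkNN=20:ScaleFrac=0.8:UseKernel=F:UseWeight=T:!Trim" );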

#include "TMVA/MethodKNN.h"

// C/C++
#include <cmath>
#include <string>
#include <sstream>  // std::stringstream is used by the weight I/O below
#include <cstdlib>

// ROOT
#include "TFile.h"
#include "TMath.h"
#include "TTree.h"

// TMVA
#include "TMVA/ClassifierFactory.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/Event.h"
#include "TMVA/LDA.h"
#include "TMVA/MethodBase.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Ranking.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"

REGISTER_METHOD(KNN)

ClassImp(TMVA::MethodKNN)

////////////////////////////////////////////////////////////////////////////////
/// standard constructor

TMVA::MethodKNN::MethodKNN( const TString& jobName,
                            const TString& methodTitle,
                            DataSetInfo& theData,
                            const TString& theOption,
                            TDirectory* theTargetDir )
   : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption, theTargetDir)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// constructor from weight file

TMVA::MethodKNN::MethodKNN( DataSetInfo& theData,
                            const TString& theWeightFile,
                            TDirectory* theTargetDir )
   : TMVA::MethodBase( Types::kKNN, theData, theWeightFile, theTargetDir)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::MethodKNN::~MethodKNN( void )
{
   if (fModule) delete fModule;
}

////////////////////////////////////////////////////////////////////////////////
/// MethodKNN options

void TMVA::MethodKNN::DeclareOptions()
{
   // fnkNN         = 20;     // number of k-nearest neighbors
   // fBalanceDepth = 6;      // number of binary tree levels used for tree balancing
   // fScaleFrac    = 0.8;    // fraction of events used to compute variable width
   // fSigmaFact    = 1.0;    // scale factor for Gaussian sigma
   // fKernel       = "Gaus"; // use polynomial (1-x^3)^3 or Gaussian kernel
   // fTrim         = false;  // use equal number of signal and background events
   // fUseKernel    = false;  // use polynomial kernel weight function
   // fUseWeight    = true;   // count events using weights
   // fUseLDA       = false;  // use local linear discriminant analysis

   DeclareOptionRef(fnkNN = 20, "nkNN", "Number of k-nearest neighbors");
   DeclareOptionRef(fBalanceDepth = 6, "BalanceDepth", "Binary tree balance depth");
   DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used to compute variable width");
   DeclareOptionRef(fSigmaFact = 1.0, "SigmaFact", "Scale factor for sigma in Gaussian kernel");
   DeclareOptionRef(fKernel = "Gaus", "Kernel", "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
   DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
   DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
   DeclareOptionRef(fUseWeight = kTRUE, "UseWeight", "Use weight to count kNN events");
   DeclareOptionRef(fUseLDA = kFALSE, "UseLDA", "Use local linear discriminant - experimental feature");
}
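
// For reference (a sketch reconstructed from the declarations above): the same
// options are normally passed as a single TMVA option string, e.g.
//
//    "nkNN=20:BalanceDepth=6:ScaleFrac=0.8:SigmaFact=1.0:Kernel=Gaus:UseKernel=F:UseWeight=T:!Trim:!UseLDA"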

////////////////////////////////////////////////////////////////////////////////
/// options that are used ONLY for the READER to ensure backward compatibility

void TMVA::MethodKNN::DeclareCompatibilityOptions() {
   MethodBase::DeclareCompatibilityOptions();
   DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
}

////////////////////////////////////////////////////////////////////////////////
/// process the options specified by the user

void TMVA::MethodKNN::ProcessOptions()
{
   if (!(fnkNN > 0)) {
      fnkNN = 10;
      Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
   }
   if (fScaleFrac < 0.0) {
      fScaleFrac = 0.0;
      Log() << kWARNING << "ScaleFrac cannot be negative: set ScaleFrac = " << fScaleFrac << Endl;
   }
   if (fScaleFrac > 1.0) {
      fScaleFrac = 1.0;
   }
   if (!(fBalanceDepth > 0)) {
      fBalanceDepth = 6;
      Log() << kWARNING << "BalanceDepth must be a positive integer: set BalanceDepth = " << fBalanceDepth << Endl;
   }

   Log() << kVERBOSE
         << "kNN options:\n"
         << "  kNN = " << fnkNN << "\n"
         << "  UseKernel = " << fUseKernel << "\n"
         << "  SigmaFact = " << fSigmaFact << "\n"
         << "  ScaleFrac = " << fScaleFrac << "\n"
         << "  Kernel = " << fKernel << "\n"
         << "  Trim = " << fTrim << "\n"
         << "  BalanceDepth = " << fBalanceDepth << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// KNN can handle classification with 2 classes and regression with one regression-target

Bool_t TMVA::MethodKNN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   if (type == Types::kRegression) return kTRUE;
   return kFALSE;
}

////////////////////////////////////////////////////////////////////////////////
/// Initialization

void TMVA::MethodKNN::Init( void )
{
   // fScaleFrac <= 0.0 then do not scale input variables
   // fScaleFrac >= 1.0 then use all event coordinates to scale input variables

   fModule = new kNN::ModulekNN();
   fSumOfWeightsS = 0;
   fSumOfWeightsB = 0;
}

////////////////////////////////////////////////////////////////////////////////
/// create kNN

void TMVA::MethodKNN::MakeKNN( void )
{
   if (!fModule) {
      Log() << kFATAL << "ModulekNN is not created" << Endl;
   }

   fModule->Clear();

   std::string option;
   if (fScaleFrac > 0.0) {
      option += "metric";
   }
   if (fTrim) {
      option += "trim";
   }

   Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      fModule->Add(*event);
   }

   // create binary tree
   fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
                 static_cast<UInt_t>(100.0*fScaleFrac),
                 option);
}

////////////////////////////////////////////////////////////////////////////////
/// kNN training

void TMVA::MethodKNN::Train()
{
   Log() << kINFO << "<Train> start..." << Endl;

   if (IsNormalised()) {
      Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
      fScaleFrac = 0.0;
   }

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }
   if (GetNVariables() < 1)
      Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;

   Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;

   for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
      // read the training event
      const Event* evt_ = GetEvent(ievt);
      Double_t weight = evt_->GetWeight();

      // in case events with negative weights are to be ignored
      if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;

      kNN::VarVec vvec(GetNVariables(), 0.0);
      for (UInt_t ivar = 0; ivar < evt_->GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);

      Short_t event_type = 0;

      if (DataInfo().IsSignal(evt_)) { // signal type = 1
         fSumOfWeightsS += weight;
         event_type = 1;
      }
      else { // background type = 2
         fSumOfWeightsB += weight;
         event_type = 2;
      }

      //
      // Create event and add classification variables, weight, type and regression variables
      //
      kNN::Event event_knn(vvec, weight, event_type);
      event_knn.SetTargets(evt_->GetTargets());
      fEvent.push_back(event_knn);
   }

   Log() << kINFO
         << "Number of signal events " << fSumOfWeightsS << Endl
         << "Number of background events " << fSumOfWeightsB << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// Compute classifier response

Double_t TMVA::MethodKNN::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // cannot determine error
   NoErrorCalc(err, errUpper);

   //
   // Define local variables
   //
   const Event *ev = GetEvent();
   const Int_t nvar = GetNVariables();
   const Double_t weight = ev->GetWeight();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = ev->GetValue(ivar);
   }

   // search for fnkNN+2 nearest neighbors, pad with two
   // events to avoid Monte-Carlo events with zero distance
   // most of CPU time is spent in this recursive function
   const kNN::Event event_knn(vvec, weight, 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list has unexpected size" << Endl;
      return -100.0;
   }

   if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);

   //
   // Set flags for kernel option=Gaus, Poln
   //
   Bool_t use_gaus = false, use_poln = false;
   if (fUseKernel) {
      if (fKernel == "Gaus") use_gaus = true;
      else if (fKernel == "Poln") use_poln = true;
   }

   //
   // Compute radius for polynomial kernel
   //
   Double_t kradius = -1.0;
   if (use_poln) {
      kradius = MethodKNN::getKernelRadius(rlist);

      if (!(kradius > 0.0)) {
         Log() << kFATAL << "kNN radius is not positive" << Endl;
         return -100.0;
      }

      kradius = 1.0/TMath::Sqrt(kradius);
   }

   //
   // Compute RMS of variable differences for Gaussian sigma
   //
   std::vector<Double_t> rms_vec;
   if (use_gaus) {
      rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);

      if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
         Log() << kFATAL << "Failed to compute RMS vector" << Endl;
         return -100.0;
      }
   }

   UInt_t count_all = 0;
   Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);

      // warn about Monte-Carlo events with zero distance:
      // this happens when the query event is also in the learning sample
      if (lit->second < 0.0) {
         Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
      }
      else if (!(lit->second > 0.0)) {
         Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
      }

      // get event weight and scale weight by kernel function
      Double_t evweight = node.GetWeight();
      if (use_gaus) evweight *= MethodKNN::GausKernel(event_knn, node.GetEvent(), rms_vec);
      else if (use_poln) evweight *= MethodKNN::PolnKernel(TMath::Sqrt(lit->second)*kradius);

      if (fUseWeight) weight_all += evweight;
      else ++weight_all;

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         if (fUseWeight) weight_sig += evweight;
         else ++weight_sig;
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         if (fUseWeight) weight_bac += evweight;
         else ++weight_bac;
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }

      // use only fnkNN events
      ++count_all;

      if (count_all >= knn) {
         break;
      }
   }

   // check that total number of events is positive
   if (!(count_all > 0)) {
      Log() << kFATAL << "Size of kNN result list is not positive" << Endl;
      return -100.0;
   }

   // check that the number of events matches k in kNN
   if (count_all < knn) {
      Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
   }

   // check that total weight is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "kNN result total weight is not positive" << Endl;
      return -100.0;
   }

   return weight_sig/weight_all;
}
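
// Note on the returned value (a summary of the computation above): the
// classifier response is the kernel-weighted signal fraction among the k
// nearest neighbors,
//
//    D = sum_{i in kNN, type==1} w_i / sum_{i in kNN} w_i,
//
// so it lies in [0,1]; with UseWeight=F every w_i is 1 and D reduces to a
// plain count fraction, which is why binning the output histogram with k+1
// bins (see GetHelpMessage) matches the discrete response values.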

////////////////////////////////////////////////////////////////////////////////
///
/// Return vector of averages for target values of k-nearest neighbors.
/// Use own copy of the regression vector, I do not like using a pointer to vector.
///

const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
{
   if( fRegressionReturnVal == 0 )
      fRegressionReturnVal = new std::vector<Float_t>;
   else
      fRegressionReturnVal->clear();

   //
   // Define local variables
   //
   const Event *evt = GetEvent();
   const Int_t nvar = GetNVariables();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);
   std::vector<float> reg_vec;

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = evt->GetValue(ivar);
   }

   // search for fnkNN+2 nearest neighbors, pad with two
   // events to avoid Monte-Carlo events with zero distance
   // most of CPU time is spent in this recursive function
   const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list has unexpected size" << Endl;
      return *fRegressionReturnVal;
   }

   // compute regression values
   Double_t weight_all = 0;
   UInt_t count_all = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetTargets();
      const Double_t weight = node.GetEvent().GetWeight();

      if (reg_vec.empty()) {
         reg_vec = kNN::VarVec(tvec.size(), 0.0);
      }

      for (UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
         if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
         else            reg_vec[ivar] += tvec[ivar];
      }

      if (fUseWeight) weight_all += weight;
      else ++weight_all;

      // use only fnkNN events
      ++count_all;

      if (count_all == knn) {
         break;
      }
   }

   // check that the total weight sum is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
      return *fRegressionReturnVal;
   }

   for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
      reg_vec[ivar] /= weight_all;
   }

   // copy result
   fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());

   return *fRegressionReturnVal;
}
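
// Summary of the estimate above: each regression output j is the weighted
// average of the neighbor targets,
//
//    t_j = sum_{i in kNN} w_i * t_{ij} / sum_{i in kNN} w_i,
//
// with w_i = 1 for all neighbors when UseWeight=F.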

////////////////////////////////////////////////////////////////////////////////
/// no ranking available

const TMVA::Ranking* TMVA::MethodKNN::CreateRanking()
{
   return 0;
}

////////////////////////////////////////////////////////////////////////////////
/// write weights to XML

void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
   void* wght = gTools().AddChild(parent, "Weights");
   gTools().AddAttr(wght,"NEvents",fEvent.size());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NVar",fEvent.begin()->GetNVar());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NTgt",fEvent.begin()->GetNTgt());

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {

      std::stringstream s("");
      s.precision( 16 );
      for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
         if (ivar>0) s << " ";
         s << std::scientific << event->GetVar(ivar);
      }

      for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
         s << " " << std::scientific << event->GetTgt(itgt);
      }

      void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
      gTools().AddAttr(evt,"Type", event->GetType());
      gTools().AddAttr(evt,"Weight", event->GetWeight());
   }
}
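
// The XML produced above has the following shape (a sketch with made-up
// values; the numbers use 16-digit scientific notation, and target values,
// if any, follow the variables inside each <Event> body):
//
//   <Weights NEvents="1000" NVar="2" NTgt="0">
//     <Event Type="1" Weight="1.0000000000000000e+00">
//       2.5000000000000000e-01 1.7500000000000000e+00
//     </Event>
//     ...
//   </Weights>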

////////////////////////////////////////////////////////////////////////////////
/// read weights from XML

void TMVA::MethodKNN::ReadWeightsFromXML( void* wghtnode ) {
   void* ch = gTools().GetChild(wghtnode); // first event
   UInt_t nvar = 0, ntgt = 0;
   gTools().ReadAttr( wghtnode, "NVar", nvar );
   gTools().ReadAttr( wghtnode, "NTgt", ntgt );

   Short_t evtType(0);
   Double_t evtWeight(0);

   while (ch) {
      // build event
      kNN::VarVec vvec(nvar, 0);
      kNN::VarVec tvec(ntgt, 0);

      gTools().ReadAttr( ch, "Type", evtType );
      gTools().ReadAttr( ch, "Weight", evtWeight );
      std::stringstream s( gTools().GetContent(ch) );

      for(UInt_t ivar=0; ivar<nvar; ivar++)
         s >> vvec[ivar];

      for(UInt_t itgt=0; itgt<ntgt; itgt++)
         s >> tvec[itgt];

      ch = gTools().GetNextChild(ch);

      kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
      fEvent.push_back(event_knn);
   }

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// read the weights

void TMVA::MethodKNN::ReadWeightsFromStream( std::istream& is )
{
   Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   UInt_t nvar = 0;

   while (!is.eof()) {
      std::string line;
      std::getline(is, line);

      if (line.empty() || line.find("#") != std::string::npos) {
         continue;
      }

      UInt_t count = 0;
      std::string::size_type pos=0;
      while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }

      if (nvar == 0) {
         nvar = count - 2;
      }
      if (count < 3 || nvar != count - 2) {
         Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
      }

      // Int_t ievent = -1;
      Int_t type = -1;
      Double_t weight = -1.0;

      kNN::VarVec vvec(nvar, 0.0);

      UInt_t vcount = 0;
      std::string::size_type prev = 0;

      for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
         if (line[ipos] != ',' && ipos + 1 != line.size()) {
            continue;
         }

         if (!(ipos > prev)) {
            Log() << kFATAL << "Wrong substring limits" << Endl;
         }

         std::string vstring = line.substr(prev, ipos - prev);
         if (ipos + 1 == line.size()) {
            vstring = line.substr(prev, ipos - prev + 1);
         }

         if (vstring.empty()) {
            Log() << kFATAL << "Failed to parse string" << Endl;
         }

         if (vcount == 0) {
            // ievent = std::atoi(vstring.c_str());
         }
         else if (vcount == 1) {
            type = std::atoi(vstring.c_str());
         }
         else if (vcount == 2) {
            weight = std::atof(vstring.c_str());
         }
         else if (vcount - 3 < vvec.size()) {
            vvec[vcount - 3] = std::atof(vstring.c_str());
         }
         else {
            Log() << kFATAL << "Wrong variable count" << Endl;
         }

         prev = ipos + 1;
         ++vcount;
      }

      fEvent.push_back(kNN::Event(vvec, weight, type));
   }

   Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();
}
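
// Expected text format, reconstructed from the parsing above (the records
// shown are illustrative, not from a real weight file): one comma-separated
// record per line, lines containing '#' are skipped; field 0 is an event
// index (ignored), field 1 the type (1=signal, 2=background), field 2 the
// weight, and the remaining fields are the input variables:
//
//   # ievent,type,weight,var1,var2
//   0,1,1.0,0.25,1.75
//   1,2,0.5,-0.50,2.25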

////////////////////////////////////////////////////////////////////////////////
/// save weights to ROOT file

void TMVA::MethodKNN::WriteWeightsToStream( TFile &rf ) const
{
   Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;

   if (fEvent.empty()) {
      Log() << kWARNING << "MethodKNN contains no events" << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   TTree *tree = new TTree("knn", "event tree");
   tree->SetDirectory(0);
   tree->Branch("event", "TMVA::kNN::Event", &event);

   Double_t size = 0.0;
   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
      (*event) = (*it);
      size += tree->Fill();
   }

   // !!! hard coded tree name !!!
   rf.WriteTObject(tree, "knn", "Overwrite");

   // scale to megabytes
   size /= 1048576.0;

   Log() << kINFO << "Wrote " << size << " MB and " << fEvent.size()
         << " events to ROOT file" << Endl;

   delete tree;
   delete event;
}

////////////////////////////////////////////////////////////////////////////////
/// read weights from ROOT file

void TMVA::MethodKNN::ReadWeightsFromStream( TFile &rf )
{
   Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   // !!! hard coded tree name !!!
   TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
   if (!tree) {
      Log() << kFATAL << "Failed to find knn tree" << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   tree->SetBranchAddress("event", &event);

   const Int_t nevent = tree->GetEntries();

   Double_t size = 0.0;
   for (Int_t i = 0; i < nevent; ++i) {
      size += tree->GetEntry(i);
      fEvent.push_back(*event);
   }

   // scale to megabytes
   size /= 1048576.0;

   Log() << kINFO << "Read " << size << " MB and " << fEvent.size()
         << " events from ROOT file" << Endl;

   delete event;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// write specific classifier response

void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
}

////////////////////////////////////////////////////////////////////////////////
/// get help message text
///
/// typical length of text line:
///   "|--------------------------------------------------------------|"

void TMVA::MethodKNN::GetHelpMessage() const
{
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
         << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
         << "training events for which a classification category/regression target is known." << Endl
         << "The k-NN method compares a test event to all training events using a distance" << Endl
         << "function, which is a Euclidean distance in a space defined by the input variables." << Endl
         << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
         << "quick search for the k events with shortest distance to the test event. The method" << Endl
         << "returns a fraction of signal events among the k neighbors. It is recommended" << Endl
         << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
         << "between 0 and 1." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-NN method estimates a density of signal and background events in a" << Endl
         << "neighborhood around the test event. The method assumes that the density of the" << Endl
         << "signal and background events is uniform and constant within the neighborhood." << Endl
         << "k is an adjustable parameter and it determines an average size of the" << Endl
         << "neighborhood. Small k values (less than 10) are sensitive to statistical" << Endl
         << "fluctuations and large (greater than 100) values might not sufficiently capture" << Endl
         << "local differences between events in the training set. The speed of the k-NN" << Endl
         << "method also increases with larger values of k." << Endl;
   Log() << Endl;
   Log() << "The k-NN method assigns equal weight to all input variables. Different scales" << Endl
         << "among the input variables are compensated using the ScaleFrac parameter: the" << Endl
         << "input variables are scaled so that the widths of the central ScaleFrac*100% of" << Endl
         << "events are equal among all the input variables." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Additional configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
         << "response. The kernel re-weights events using a distance to the test event." << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// polynomial kernel

Double_t TMVA::MethodKNN::PolnKernel( Double_t value ) const
{
   const Double_t avalue = TMath::Abs(value);

   if (!(avalue < 1.0)) {
      return 0.0;
   }

   const Double_t prod = 1.0 - avalue * avalue * avalue;

   return (prod * prod * prod);
}
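
// In formula form (a restatement of the code above): the kernel weight is
//
//    W(x) = (1 - |x|^3)^3  for |x| < 1,  and W(x) = 0 otherwise,
//
// where x is the neighbor distance scaled by the inverse kernel radius
// computed in getKernelRadius().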

////////////////////////////////////////////////////////////////////////////////
/// Gaussian kernel

Double_t TMVA::MethodKNN::GausKernel( const kNN::Event &event_knn,
                                      const kNN::Event &event, const std::vector<Double_t> &svec ) const
{
   if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
      Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
      return 0.0;
   }

   //
   // compute exponent
   //
   double sum_exp = 0.0;

   for (unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {

      const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
      const Double_t sigm_ = svec[ivar];
      if (!(sigm_ > 0.0)) {
         Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
         return 0.0;
      }

      sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
   }

   //
   // Return unnormalized(!) Gaussian function, because normalization
   // cancels for the ratio of weights.
   //

   return std::exp(-sum_exp);
}
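
// Equivalently (restating the code above): the returned weight is the
// unnormalized product of per-variable Gaussians,
//
//    W = exp( - sum_i (x_i - y_i)^2 / (2 sigma_i^2) ),
//
// with the widths sigma_i supplied by getRMS(); the Gaussian normalization
// constant is dropped because it cancels in the signal/total weight ratio.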

////////////////////////////////////////////////////////////////////////////////
///
/// Get polynomial kernel radius
///

Double_t TMVA::MethodKNN::getKernelRadius( const kNN::List &rlist ) const
{
   Double_t kradius = -1.0;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
   {
      if (!(lit->second > 0.0)) continue;

      if (kradius < lit->second || kradius < 0.0) kradius = lit->second;

      ++kcount;
      if (kcount >= knn) break;
   }

   return kradius;
}

////////////////////////////////////////////////////////////////////////////////
///
/// Compute RMS of the variable differences between the query event and its
/// k-nearest neighbors, used as the per-variable sigma of the Gaussian kernel
///

const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
{
   std::vector<Double_t> rvec;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
   {
      if (!(lit->second > 0.0)) continue;

      const kNN::Node<kNN::Event> *node_ = lit->first;
      const kNN::Event &event_ = node_->GetEvent();

      if (rvec.empty()) {
         rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
      }
      else if (rvec.size() != event_.GetNVar()) {
         Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
         rvec.clear();
         return rvec;
      }

      for (unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
         const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
         rvec[ivar] += diff_*diff_;
      }

      ++kcount;
      if (kcount >= knn) break;
   }

   if (kcount < 1) {
      Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
      rvec.clear();
      return rvec;
   }

   for (unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
      if (!(rvec[ivar] > 0.0)) {
         Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
         rvec.clear();
         return rvec;
      }

      rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
   }

   return rvec;
}
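
// The width computed above for each variable i is therefore
//
//    sigma_i = |SigmaFact| * sqrt( sum_{k in kNN} (x_{ki} - x_i)^2 / k ),
//
// i.e. the RMS of the neighbor-to-query differences, scaled by the SigmaFact
// option declared in DeclareOptions().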

////////////////////////////////////////////////////////////////////////////////
/// compute local LDA-based discriminant from the k-nearest neighbors

Double_t TMVA::MethodKNN::getLDAValue( const kNN::List &rlist, const kNN::Event &event_knn )
{
   LDAEvents sig_vec, bac_vec;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetVars();

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         sig_vec.push_back(tvec);
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         bac_vec.push_back(tvec);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
   }

   fLDA.Initialize(sig_vec, bac_vec);

   return fLDA.GetProb(event_knn.GetVars(), 1);
}