Logo ROOT   6.10/09
Reference Guide
ModulekNN.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Rustem Ospanov
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ModulekNN *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Module for k-nearest neighbor algorithm *
12  * *
13  * Author: *
14  * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15  * *
16  * Copyright (c) 2007: *
17  * CERN, Switzerland *
18  * MPI-K Heidelberg, Germany *
19  * U. of Texas at Austin, USA *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/LICENSE) *
24  **********************************************************************************/
25 
26 #ifndef ROOT_TMVA_ModulekNN
27 #define ROOT_TMVA_ModulekNN
28 
29 //______________________________________________________________________
30 /*
31  kNN::Event describes a point in the input-variable vector space, with
32  additional functionality such as the distance between two points.
33 */
34 //______________________________________________________________________
35 
36 
// C++
#include <cassert>
#include <iosfwd>
#include <list>
#include <map>
#include <string>
#include <utility>
#include <vector>

// ROOT
#include "Rtypes.h"
#include "TRandom3.h"
#include "ThreadLocalStorage.h"
#include "TMVA/NodekNN.h"
49 
50 namespace TMVA {
51 
52  class MsgLogger;
53 
54  namespace kNN {
55 
56  typedef Float_t VarType;
57  typedef std::vector<VarType> VarVec;
58 
59  class Event {
60  public:
61 
62  Event();
63  Event(const VarVec &vec, Double_t weight, Short_t type);
64  Event(const VarVec &vec, Double_t weight, Short_t type, const VarVec &tvec);
65  ~Event();
66 
67  Double_t GetWeight() const;
68 
69  VarType GetVar(UInt_t i) const;
70  VarType GetTgt(UInt_t i) const;
71 
72  UInt_t GetNVar() const;
73  UInt_t GetNTgt() const;
74 
75  Short_t GetType() const;
76 
77  // keep these two function separate
78  VarType GetDist(VarType var, UInt_t ivar) const;
79  VarType GetDist(const Event &other) const;
80 
81  void SetTargets(const VarVec &tvec);
82  const VarVec& GetTargets() const;
83  const VarVec& GetVars() const;
84 
85  void Print() const;
86  void Print(std::ostream& os) const;
87 
88  private:
89 
90  VarVec fVar; // coordinates (variables) for knn search
91  VarVec fTgt; // targets for regression analysis
92 
93  Double_t fWeight; // event weight
94  Short_t fType; // event type ==0 or == 1, expand it to arbitrary class types?
95  };
96 
97  typedef std::vector<TMVA::kNN::Event> EventVec;
98  typedef std::pair<const Node<Event> *, VarType> Elem;
99  typedef std::list<Elem> List;
100 
101  std::ostream& operator<<(std::ostream& os, const Event& event);
102 
103  class ModulekNN
104  {
105  public:
106 
107  typedef std::map<int, std::vector<Double_t> > VarMap;
108 
109  public:
110 
111  ModulekNN();
112  ~ModulekNN();
113 
114  void Clear();
115 
116  void Add(const Event &event);
117 
118  Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option = "");
119 
120  Bool_t Find(Event event, UInt_t nfind = 100, const std::string &option = "count") const;
121  Bool_t Find(UInt_t nfind, const std::string &option) const;
122 
123  const EventVec& GetEventVec() const;
124 
125  const List& GetkNNList() const;
126  const Event& GetkNNEvent() const;
127 
128  const VarMap& GetVarMap() const;
129 
130  const std::map<Int_t, Double_t>& GetMetric() const;
131 
132  void Print() const;
133  void Print(std::ostream &os) const;
134 
135  private:
136 
137  Node<Event>* Optimize(UInt_t optimize_depth);
138 
139  void ComputeMetric(UInt_t ifrac);
140 
141  const Event Scale(const Event &event) const;
142 
143  private:
144 
145  // This is a workaround for OSx where static thread_local data members are
146  // not supported. The C++ solution would indeed be the following:
147  static TRandom3& GetRndmThreadLocal() {TTHREAD_TLS_DECL_ARG(TRandom3,fgRndm,1); return fgRndm;};
148 
149  UInt_t fDimn;
150 
152 
153  std::map<Int_t, Double_t> fVarScale;
154 
155  mutable List fkNNList; // latest result from kNN search
156  mutable Event fkNNEvent; // latest event used for kNN search
157 
158  std::map<Short_t, UInt_t> fCount; // count number of events of each type
159 
160  EventVec fEvent; // vector of all events used to build tree and analysis
161  VarMap fVar; // sorted map of variables in each dimension for all event types
162 
163  mutable MsgLogger* fLogger; // message logger
164  MsgLogger& Log() const { return *fLogger; }
165  };
166 
167  //
168  // inline functions for Event class
169  //
170  inline VarType Event::GetDist(const VarType var1, const UInt_t ivar) const
171  {
172  const VarType var2 = GetVar(ivar);
173  return (var1 - var2) * (var1 - var2);
174  }
175  inline Double_t Event::GetWeight() const
176  {
177  return fWeight;
178  }
179  inline VarType Event::GetVar(const UInt_t i) const
180  {
181  return fVar[i];
182  }
183  inline VarType Event::GetTgt(const UInt_t i) const
184  {
185  return fTgt[i];
186  }
187 
188  inline UInt_t Event::GetNVar() const
189  {
190  return fVar.size();
191  }
192  inline UInt_t Event::GetNTgt() const
193  {
194  return fTgt.size();
195  }
196  inline Short_t Event::GetType() const
197  {
198  return fType;
199  }
200 
201  //
202  // inline functions for ModulekNN class
203  //
204  inline const List& ModulekNN::GetkNNList() const
205  {
206  return fkNNList;
207  }
208  inline const Event& ModulekNN::GetkNNEvent() const
209  {
210  return fkNNEvent;
211  }
212  inline const EventVec& ModulekNN::GetEventVec() const
213  {
214  return fEvent;
215  }
217  {
218  return fVar;
219  }
220  inline const std::map<Int_t, Double_t>& ModulekNN::GetMetric() const
221  {
222  return fVarScale;
223  }
224 
225  } // end of kNN namespace
226 } // end of TMVA namespace
227 
228 #endif
229 
Event()
default constructor
Definition: ModulekNN.cxx:50
RooCmdArg Optimize(Int_t flag=2)
Random number generator class based on M. Matsumoto and T. Nishimura, Mersenne Twister generator.
Definition: TRandom3.h:27
Float_t VarType
Definition: ModulekNN.h:56
float Float_t
Definition: RtypesCore.h:53
unsigned short UShort_t
Definition: RtypesCore.h:36
const List & GetkNNList() const
Definition: ModulekNN.h:204
bool Bool_t
Definition: RtypesCore.h:59
VarType GetVar(UInt_t i) const
Definition: ModulekNN.h:179
Double_t GetWeight() const
Definition: ModulekNN.h:175
std::map< int, std::vector< Double_t > > VarMap
Definition: ModulekNN.h:107
std::map< Int_t, Double_t > fVarScale
Definition: ModulekNN.h:153
const std::map< Int_t, Double_t > & GetMetric() const
Definition: ModulekNN.h:220
MsgLogger & Log() const
Definition: ModulekNN.h:164
Short_t GetType() const
Definition: ModulekNN.h:196
const VarVec & GetTargets() const
Definition: ModulekNN.cxx:114
const EventVec & GetEventVec() const
Definition: ModulekNN.h:212
Double_t fWeight
Definition: ModulekNN.h:93
std::vector< VarType > VarVec
Definition: ModulekNN.h:57
void SetTargets(const VarVec &tvec)
Definition: ModulekNN.cxx:107
VarType GetDist(VarType var, UInt_t ivar) const
Definition: ModulekNN.h:170
UInt_t GetNVar() const
Definition: ModulekNN.h:188
unsigned int UInt_t
Definition: RtypesCore.h:42
short Short_t
Definition: RtypesCore.h:35
VarType GetTgt(UInt_t i) const
Definition: ModulekNN.h:183
const VarMap & GetVarMap() const
Definition: ModulekNN.h:216
std::list< Elem > List
Definition: ModulekNN.h:99
void Add(THist< DIMENSIONS, PRECISION_TO, STAT_TO... > &to, const THist< DIMENSIONS, PRECISION_FROM, STAT_FROM... > &from)
Add two histograms.
Definition: THist.hxx:336
std::pair< const Node< Event > *, VarType > Elem
Definition: ModulekNN.h:98
double Double_t
Definition: RtypesCore.h:55
std::ostream & operator<<(std::ostream &os, const Event &event)
int type
Definition: TGX11.cxx:120
UInt_t GetNTgt() const
Definition: ModulekNN.h:192
static TRandom3 & GetRndmThreadLocal()
Definition: ModulekNN.h:147
std::vector< TMVA::kNN::Event > EventVec
Definition: ModulekNN.h:97
UInt_t Find(std::list< std::pair< const Node< T > *, Float_t > > &nlist, const Node< T > *node, const T &event, UInt_t nfind)
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
MsgLogger * fLogger
Definition: ModulekNN.h:163
const VarVec & GetVars() const
Definition: ModulekNN.cxx:121
const Event & GetkNNEvent() const
Definition: ModulekNN.h:208
Abstract ClassifierFactory template that handles arbitrary types.
Short_t fType
Definition: ModulekNN.h:94
void Print() const
print
Definition: ModulekNN.cxx:129
std::map< Short_t, UInt_t > fCount
Definition: ModulekNN.h:158
Node< Event > * fTree
Definition: ModulekNN.h:151
static void Fill(TTree *tree, int init, int count)
~Event()
destructor
Definition: ModulekNN.cxx:81