Logo ROOT   6.18/05
Reference Guide
ModulekNN.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Rustem Ospanov
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : ModulekNN *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Module for k-nearest neighbor algorithm *
12 * *
13 * Author: *
14 * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15 * *
16 * Copyright (c) 2007: *
17 * CERN, Switzerland *
18 * MPI-K Heidelberg, Germany *
19 * U. of Texas at Austin, USA *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * (http://tmva.sourceforge.net/LICENSE) *
24 **********************************************************************************/
25
26#ifndef ROOT_TMVA_ModulekNN
27#define ROOT_TMVA_ModulekNN
28
29//______________________________________________________________________
30/*
 kNN::Event describes a point in the input-variable vector space, with
 additional functionality such as computing the distance between points
33*/
34//______________________________________________________________________
35
36
// C++
#include <cassert>
#include <iosfwd>
#include <list>
#include <map>
#include <string>
#include <utility>
#include <vector>

// ROOT
#include "Rtypes.h"
#include "TRandom3.h"
#include "ThreadLocalStorage.h"
#include "TMVA/NodekNN.h"
49
50namespace TMVA {
51
52 class MsgLogger;
53
54 namespace kNN {
55
56 typedef Float_t VarType;
57 typedef std::vector<VarType> VarVec;
58
59 class Event {
60 public:
61
62 Event();
63 Event(const VarVec &vec, Double_t weight, Short_t type);
64 Event(const VarVec &vec, Double_t weight, Short_t type, const VarVec &tvec);
65 ~Event();
66
67 Double_t GetWeight() const;
68
69 VarType GetVar(UInt_t i) const;
70 VarType GetTgt(UInt_t i) const;
71
72 UInt_t GetNVar() const;
73 UInt_t GetNTgt() const;
74
75 Short_t GetType() const;
76
77 // keep these two function separate
78 VarType GetDist(VarType var, UInt_t ivar) const;
79 VarType GetDist(const Event &other) const;
80
81 void SetTargets(const VarVec &tvec);
82 const VarVec& GetTargets() const;
83 const VarVec& GetVars() const;
84
85 void Print() const;
86 void Print(std::ostream& os) const;
87
88 private:
89
90 VarVec fVar; // coordinates (variables) for knn search
91 VarVec fTgt; // targets for regression analysis
92
93 Double_t fWeight; // event weight
94 Short_t fType; // event type ==0 or == 1, expand it to arbitrary class types?
95 };
96
97 typedef std::vector<TMVA::kNN::Event> EventVec;
98 typedef std::pair<const Node<Event> *, VarType> Elem;
99 typedef std::list<Elem> List;
100
101 std::ostream& operator<<(std::ostream& os, const Event& event);
102
104 {
105 public:
106
107 typedef std::map<int, std::vector<Double_t> > VarMap;
108
109 public:
110
111 ModulekNN();
112 ~ModulekNN();
113
114 void Clear();
115
116 void Add(const Event &event);
117
118 Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option = "");
119
120 Bool_t Find(Event event, UInt_t nfind = 100, const std::string &option = "count") const;
121 Bool_t Find(UInt_t nfind, const std::string &option) const;
122
123 const EventVec& GetEventVec() const;
124
125 const List& GetkNNList() const;
126 const Event& GetkNNEvent() const;
127
128 const VarMap& GetVarMap() const;
129
130 const std::map<Int_t, Double_t>& GetMetric() const;
131
132 void Print() const;
133 void Print(std::ostream &os) const;
134
135 private:
136
137 Node<Event>* Optimize(UInt_t optimize_depth);
138
139 void ComputeMetric(UInt_t ifrac);
140
141 const Event Scale(const Event &event) const;
142
143 private:
144
145 // This is a workaround for OSx where static thread_local data members are
146 // not supported. The C++ solution would indeed be the following:
147 static TRandom3& GetRndmThreadLocal() {TTHREAD_TLS_DECL_ARG(TRandom3,fgRndm,1); return fgRndm;};
148
150
152
153 std::map<Int_t, Double_t> fVarScale;
154
155 mutable List fkNNList; // latest result from kNN search
156 mutable Event fkNNEvent; // latest event used for kNN search
157
158 std::map<Short_t, UInt_t> fCount; // count number of events of each type
159
160 EventVec fEvent; // vector of all events used to build tree and analysis
161 VarMap fVar; // sorted map of variables in each dimension for all event types
162
163 mutable MsgLogger* fLogger; // message logger
164 MsgLogger& Log() const { return *fLogger; }
165 };
166
167 //
168 // inlined functions for Event class
169 //
170 inline VarType Event::GetDist(const VarType var1, const UInt_t ivar) const
171 {
172 const VarType var2 = GetVar(ivar);
173 return (var1 - var2) * (var1 - var2);
174 }
176 {
177 return fWeight;
178 }
179 inline VarType Event::GetVar(const UInt_t i) const
180 {
181 return fVar[i];
182 }
183 inline VarType Event::GetTgt(const UInt_t i) const
184 {
185 return fTgt[i];
186 }
187
188 inline UInt_t Event::GetNVar() const
189 {
190 return fVar.size();
191 }
192 inline UInt_t Event::GetNTgt() const
193 {
194 return fTgt.size();
195 }
196 inline Short_t Event::GetType() const
197 {
198 return fType;
199 }
200
201 //
202 // inline functions for ModulekNN class
203 //
204 inline const List& ModulekNN::GetkNNList() const
205 {
206 return fkNNList;
207 }
208 inline const Event& ModulekNN::GetkNNEvent() const
209 {
210 return fkNNEvent;
211 }
212 inline const EventVec& ModulekNN::GetEventVec() const
213 {
214 return fEvent;
215 }
217 {
218 return fVar;
219 }
220 inline const std::map<Int_t, Double_t>& ModulekNN::GetMetric() const
221 {
222 return fVarScale;
223 }
224
225 } // end of kNN namespace
226} // end of TMVA namespace
227
228#endif
229
unsigned short UShort_t
Definition: RtypesCore.h:36
unsigned int UInt_t
Definition: RtypesCore.h:42
bool Bool_t
Definition: RtypesCore.h:59
short Short_t
Definition: RtypesCore.h:35
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
int type
Definition: TGX11.cxx:120
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
Short_t GetType() const
Definition: ModulekNN.h:196
VarType GetDist(VarType var, UInt_t ivar) const
Definition: ModulekNN.h:170
VarType GetTgt(UInt_t i) const
Definition: ModulekNN.h:183
void SetTargets(const VarVec &tvec)
Definition: ModulekNN.cxx:107
const VarVec & GetTargets() const
Definition: ModulekNN.cxx:114
Double_t GetWeight() const
Definition: ModulekNN.h:175
UInt_t GetNVar() const
Definition: ModulekNN.h:188
UInt_t GetNTgt() const
Definition: ModulekNN.h:192
Event()
default constructor
Definition: ModulekNN.cxx:50
~Event()
destructor
Definition: ModulekNN.cxx:81
void Print() const
print
Definition: ModulekNN.cxx:129
Short_t fType
Definition: ModulekNN.h:94
VarType GetVar(UInt_t i) const
Definition: ModulekNN.h:179
Double_t fWeight
Definition: ModulekNN.h:93
const VarVec & GetVars() const
Definition: ModulekNN.cxx:121
std::map< Int_t, Double_t > fVarScale
Definition: ModulekNN.h:153
MsgLogger * fLogger
Definition: ModulekNN.h:163
Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option="")
fill the tree
Definition: ModulekNN.cxx:245
const VarMap & GetVarMap() const
Definition: ModulekNN.h:216
Node< Event > * Optimize(UInt_t optimize_depth)
Optimize() balances the binary tree for the first odepth levels; for each depth we split sorted depth % dimens...
Definition: ModulekNN.cxx:449
const EventVec & GetEventVec() const
Definition: ModulekNN.h:212
static TRandom3 & GetRndmThreadLocal()
Definition: ModulekNN.h:147
std::map< int, std::vector< Double_t > > VarMap
Definition: ModulekNN.h:107
void Print() const
print
Definition: ModulekNN.cxx:662
void Clear()
clean up
Definition: ModulekNN.cxx:194
ModulekNN()
default constructor
Definition: ModulekNN.cxx:173
Bool_t Find(Event event, UInt_t nfind=100, const std::string &option="count") const
find in tree: if the tree has been filled, search for the nfind closest events; if the metric (fVarScale map) is...
Definition: ModulekNN.cxx:348
const Event Scale(const Event &event) const
scale each event variable so that the rms of each variable is approximately 1.0; this allows comparisons of va...
Definition: ModulekNN.cxx:628
void ComputeMetric(UInt_t ifrac)
compute scale factor for each variable (dimension) so that distance is computed uniformly along each ...
Definition: ModulekNN.cxx:542
Node< Event > * fTree
Definition: ModulekNN.h:151
const Event & GetkNNEvent() const
Definition: ModulekNN.h:208
const std::map< Int_t, Double_t > & GetMetric() const
Definition: ModulekNN.h:220
~ModulekNN()
destructor
Definition: ModulekNN.cxx:183
std::map< Short_t, UInt_t > fCount
Definition: ModulekNN.h:158
MsgLogger & Log() const
Definition: ModulekNN.h:164
void Add(const Event &event)
add an event to tree
Definition: ModulekNN.cxx:212
const List & GetkNNList() const
Definition: ModulekNN.h:204
This file contains a binary tree and a global function template that searches the tree for k-nearest neighbors...
Definition: NodekNN.h:67
Random number generator class based on M.
Definition: TRandom3.h:27
std::ostream & operator<<(std::ostream &os, const Event &event)
create variable transformations