#ifndef ROOT_TMVA_ModulekNN
#define ROOT_TMVA_ModulekNN
#include <cassert>
#include <iosfwd>
#include <list>
#include <map>
#include <string>
#include <utility>
#include <vector>
#ifndef ROOT_Rtypes
#include "Rtypes.h"
#endif
#ifndef ROOT_TRandom3
#include "TRandom3.h"
#endif
#ifndef ROOT_TMVA_NodekNN
#include "TMVA/NodekNN.h"
#endif
namespace TMVA {
class MsgLogger;
namespace kNN {
// Scalar type used for every event variable and target value.
typedef Float_t VarType;
// Per-event container of variable (or target) values.
typedef std::vector<VarType> VarVec;

// One event as seen by the kNN module: a vector of input variables,
// a vector of target values, a weight and a Short_t type code.
class Event {
public:
   // --- construction / destruction ---
   Event();
   Event(const VarVec &vec, Double_t weight, Short_t type);
   Event(const VarVec &vec, Double_t weight, Short_t type, const VarVec &tvec);
   ~Event();

   // --- scalar properties ---
   Double_t GetWeight() const;
   Short_t  GetType() const;

   // --- input variables ---
   UInt_t         GetNVar() const;
   VarType        GetVar(UInt_t i) const;
   const VarVec&  GetVars() const;

   // --- targets ---
   UInt_t         GetNTgt() const;
   VarType        GetTgt(UInt_t i) const;
   const VarVec&  GetTargets() const;
   void           SetTargets(const VarVec &tvec);

   // --- distances ---
   VarType GetDist(VarType var, UInt_t ivar) const;
   VarType GetDist(const Event &other) const;

   // --- printing ---
   void Print() const;
   void Print(std::ostream& os) const;

private:
   // NOTE: data-member order is kept identical to preserve object layout.
   VarVec   fVar;     // input variable values
   VarVec   fTgt;     // target values
   Double_t fWeight;  // event weight
   Short_t  fType;    // event type code
};
typedef std::vector<TMVA::kNN::Event> EventVec;
typedef std::pair<const Node<Event> *, VarType> Elem;
typedef std::list<Elem> List;
std::ostream& operator<<(std::ostream& os, const Event& event);
// k-nearest-neighbour module.
//
// Events are accumulated with Add() and stored in fEvent; Fill() prepares the
// module for searching (a Node<Event> tree is kept in fTree), and the const
// Find() overloads perform a search. Results appear to be kept in the mutable
// members fkNNList / fkNNEvent, which GetkNNList() / GetkNNEvent() expose.
//
// The class holds raw pointers (fTree, fLogger) and declares a destructor.
// Presumably that destructor releases them (see ModulekNN.cxx), in which case
// a compiler-generated copy would lead to a double delete. Copying is
// therefore disabled: the copy constructor and copy assignment are declared
// private and intentionally left undefined (C++98 non-copyable idiom).
class ModulekNN
{
public:
   // Map from an int key to a vector of Double_t values.
   typedef std::map<int, std::vector<Double_t> > VarMap;

public:
   ModulekNN();
   ~ModulekNN();

   void Clear();
   void Add(const Event &event);

   Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option = "");

   Bool_t Find(Event event, UInt_t nfind = 100, const std::string &option = "count") const;
   Bool_t Find(UInt_t nfind, const std::string &option) const;

   const EventVec& GetEventVec() const;
   const List& GetkNNList() const;
   const Event& GetkNNEvent() const;
   const VarMap& GetVarMap() const;
   const std::map<Int_t, Double_t>& GetMetric() const;

   void Print() const;
   void Print(std::ostream &os) const;

private:
   // Non-copyable: declared but not defined (see class comment).
   ModulekNN(const ModulekNN &);
   ModulekNN& operator=(const ModulekNN &);

   Node<Event>* Optimize(UInt_t optimize_depth);
   void ComputeMetric(UInt_t ifrac);
   const Event Scale(const Event &event) const;

private:
   static TRandom3 fgRndm;
   UInt_t fDimn;
   Node<Event> *fTree;
   std::map<Int_t, Double_t> fVarScale;   // returned by GetMetric()
   mutable List fkNNList;                 // written from const members (e.g. Find)
   mutable Event fkNNEvent;
   std::map<Short_t, UInt_t> fCount;
   EventVec fEvent;
   VarMap fVar;
   mutable MsgLogger* fLogger;
   MsgLogger& Log() const { return *fLogger; }
};
// Squared difference between var1 and this event's variable number ivar.
inline VarType Event::GetDist(const VarType var1, const UInt_t ivar) const
{
   const VarType delta = var1 - GetVar(ivar);
   return delta * delta;
}
// Weight attached to this event.
inline Double_t Event::GetWeight() const { return this->fWeight; }
// Value of input variable number i (zero-based).
// std::vector::operator[] does no bounds checking, so an index >= GetNVar()
// is undefined behaviour; the assert turns that into a diagnosable failure
// in debug builds and compiles away under NDEBUG.
inline VarType Event::GetVar(const UInt_t i) const
{
   assert(i < fVar.size());
   return fVar[i];
}
// Value of target number i (zero-based).
// An event built without a target vector has GetNTgt() == 0, and any index
// would then read past the end of fTgt. operator[] is unchecked, so the
// assert catches such misuse in debug builds and compiles away under NDEBUG.
inline VarType Event::GetTgt(const UInt_t i) const
{
   assert(i < fTgt.size());
   return fTgt[i];
}
// Number of input variables stored in this event.
inline UInt_t Event::GetNVar() const
{
   // size() is size_t; the narrowing to UInt_t is made explicit.
   return static_cast<UInt_t>(fVar.size());
}
// Number of target values stored in this event.
inline UInt_t Event::GetNTgt() const
{
   // size() is size_t; the narrowing to UInt_t is made explicit.
   return static_cast<UInt_t>(fTgt.size());
}
// Type code of this event.
inline Short_t Event::GetType() const { return this->fType; }
// Read-only view of the mutable neighbour list fkNNList
// (presumably filled by the const Find() overloads).
inline const List& ModulekNN::GetkNNList() const { return this->fkNNList; }
// Read-only view of the mutable member fkNNEvent.
inline const Event& ModulekNN::GetkNNEvent() const { return this->fkNNEvent; }
// All events stored in the module (the fEvent container).
inline const EventVec& ModulekNN::GetEventVec() const { return this->fEvent; }
// Read-only access to the fVar map (int key -> vector of Double_t).
inline const ModulekNN::VarMap& ModulekNN::GetVarMap() const { return this->fVar; }
inline const std::map<Int_t, Double_t>& ModulekNN::GetMetric() const
{
return fVarScale;
}
}
}
#endif