#ifndef ROOT_Math_KDTree
#define ROOT_Math_KDTree
#include <assert.h>
#include <vector>
#include <cmath>
#include "Rtypes.h"
namespace ROOT
{
namespace Math
{
//End_Html
template<class _DataPoint>
class KDTree
{
public:
// Point type the tree is instantiated with.
typedef _DataPoint point_type;
// Coordinate type, taken from the nested value_type of the point type.
typedef typename _DataPoint::value_type value_type;
// Number of dimensions; simply forwards to _DataPoint::Dimension().
static UInt_t Dimension()
{
   return _DataPoint::Dimension();
}
// Split option that can be set on the tree / its terminal nodes through
// SetSplitOption().
// NOTE(review): presumably kEffective selects TerminalNode::SplitEffectiveEntries()
// and kBinContent selects TerminalNode::SplitBinContent() -- confirm in the
// out-of-line implementation.
enum eSplitOption {
   kEffective  = 0,
   kBinContent = 1
};
private:
class ComparePoints
{
public:
Bool_t operator()(const point_type* pFirst,const point_type* pSecond) const;
UInt_t GetAxis() const {return fAxis;}
void SetAxis(UInt_t iAxis) {fAxis = iAxis;}
private:
UInt_t fAxis;
};
// A cut along one axis: an axis index paired with a cut value.
class Cut
{
public:
   // Default cut: axis 0, cut value 0.
   Cut()
      : fAxis(0),
        fCutValue(0)
   {
   }

   // Cut on axis iAxis at value fNewCutValue.
   Cut(UInt_t iAxis,Double_t fNewCutValue)
      : fAxis(iAxis),
        fCutValue(fNewCutValue)
   {
   }

   ~Cut() {}

   // Axis index of the cut.
   UInt_t GetAxis() const
   {
      return fAxis;
   }

   // Cut value; stored as Double_t and converted to value_type on return.
   value_type GetCutValue() const
   {
      return fCutValue;
   }

   // Replace the axis index.
   void SetAxis(UInt_t iAxis)
   {
      fAxis = iAxis;
   }

   // Replace the cut value.
   void SetCutValue(Double_t fNewCutValue)
   {
      fCutValue = fNewCutValue;
   }

   // Comparisons of the cut against a point (defined out of line).
   Bool_t operator<(const point_type& rPoint) const;
   Bool_t operator>(const point_type& rPoint) const;

private:
   UInt_t   fAxis;       // axis index of the cut
   Double_t fCutValue;   // cut value on that axis
};
class BaseNode;
class HeadNode;
class SplitNode;
class BinNode;
class TerminalNode;
class BaseNode
{
public:
BaseNode(BaseNode* pParent = 0);
virtual ~BaseNode();
virtual BaseNode* Clone() = 0;
virtual const BinNode* FindNode(const point_type& rPoint) const = 0;
virtual void GetClosestPoints(const point_type& rRef,UInt_t nPoints,std::vector<std::pair<const _DataPoint*,Double_t> >& vFoundPoints) const = 0;
virtual void GetPointsWithinDist(const point_type& rRef,value_type fDist,std::vector<const point_type*>& vFoundPoints) const = 0;
virtual Bool_t Insert(const point_type& rPoint) = 0;
virtual void Print(int iRow = 0) const = 0;
BaseNode*& LeftChild() {return fLeftChild;}
const BaseNode* LeftChild() const {return fLeftChild;}
BaseNode*& Parent() {return fParent;}
const BaseNode* Parent() const {return fParent;}
BaseNode*& RightChild() {return fRightChild;}
const BaseNode* RightChild() const {return fRightChild;}
BaseNode*& GetParentPointer();
virtual Bool_t IsHeadNode() const {return false;}
Bool_t IsLeftChild() const;
private:
BaseNode(const BaseNode& ) {}
BaseNode& operator=(const BaseNode& ) {return *this;}
BaseNode* fParent;
BaseNode* fLeftChild;
BaseNode* fRightChild;
};
class HeadNode : public BaseNode
{
public:
HeadNode(BaseNode& rNode):BaseNode(&rNode) {}
virtual ~HeadNode() {delete Parent();}
virtual const BinNode* FindNode(const point_type& rPoint) const {return Parent()->FindNode(rPoint);}
virtual void GetClosestPoints(const point_type& rRef,UInt_t nPoints,std::vector<std::pair<const _DataPoint*,Double_t> >& vFoundPoints) const;
virtual void GetPointsWithinDist(const point_type& rRef,value_type fDist,std::vector<const _DataPoint*>& vFoundPoints) const;
virtual Bool_t Insert(const point_type& rPoint) {return Parent()->Insert(rPoint);}
virtual void Print(Int_t) const {Parent()->Print();}
private:
HeadNode(const HeadNode& ) {}
HeadNode& operator=(const HeadNode& ) {return *this;}
virtual HeadNode* Clone();
virtual bool IsHeadNode() const {return true;}
using BaseNode::Parent;
using BaseNode::LeftChild;
using BaseNode::RightChild;
using BaseNode::GetParentPointer;
using BaseNode::IsLeftChild;
};
class SplitNode : public BaseNode
{
public:
SplitNode(UInt_t iAxis,Double_t fCutValue,BaseNode& rLeft,BaseNode& rRight,BaseNode* pParent = 0);
virtual ~SplitNode();
const Cut* GetCut() const {return fCut;}
virtual void Print(Int_t iRow = 0) const;
private:
SplitNode(const SplitNode& ) {}
SplitNode& operator=(const SplitNode& ) {return *this;}
virtual SplitNode* Clone();
virtual const BinNode* FindNode(const point_type& rPoint) const;
virtual void GetClosestPoints(const point_type& rRef,UInt_t nPoints,std::vector<std::pair<const _DataPoint*,Double_t> >& vFoundPoints) const;
virtual void GetPointsWithinDist(const point_type& rRef,value_type fDist,std::vector<const _DataPoint*>& vFoundPoints) const;
virtual Bool_t Insert(const point_type& rPoint);
const Cut* fCut;
};
class BinNode : public BaseNode
{
protected:
typedef std::pair<value_type,value_type> tBoundary;
public:
BinNode(BaseNode* pParent = 0);
BinNode(const BinNode& copy);
virtual ~BinNode() {}
virtual void EmptyBin();
virtual const BinNode* FindNode(const point_type& rPoint) const;
point_type GetBinCenter() const;
Double_t GetBinContent() const {return GetSumw();}
#ifndef _AIX
virtual const std::vector<tBoundary>& GetBoundaries() const {return fBoundaries;}
#else
virtual void GetBoundaries() const { }
#endif
Double_t GetDensity() const {return GetBinContent()/GetVolume();}
Double_t GetEffectiveEntries() const {return (GetSumw2()) ? std::pow(GetSumw(),2)/GetSumw2() : 0;}
UInt_t GetEntries() const {return fEntries;}
Double_t GetVolume() const;
Double_t GetSumw() const {return fSumw;}
Double_t GetSumw2() const {return fSumw2;}
virtual Bool_t Insert(const point_type& rPoint);
Bool_t IsInBin(const point_type& rPoint) const;
virtual void Print(int iRow = 0) const;
protected:
virtual BinNode* Clone();
std::vector<tBoundary> fBoundaries;
Double_t fSumw;
Double_t fSumw2;
UInt_t fEntries;
private:
BinNode& operator=(const BinNode& rhs);
virtual void GetClosestPoints(const point_type&,UInt_t,std::vector<std::pair<const _DataPoint*,Double_t> >&) const {}
virtual void GetPointsWithinDist(const point_type&,value_type,std::vector<const point_type*>&) const {}
using BaseNode::LeftChild;
using BaseNode::RightChild;
};
// Bin node that additionally keeps a vector of pointers to data points
// (fDataPoints) plus split-related settings (fSplitOption, fSplitAxis,
// fBucketSize) and an ownership flag (fOwnData).
class TerminalNode : public BinNode
{
   // The enclosing tree gets access to the private members below.
   friend class KDTree<_DataPoint>;
   typedef std::pair<value_type,value_type> tBoundary;

public:
   // Constructor and destructor are defined out of line.
   TerminalNode(Double_t iBucketSize,BaseNode* pParent = 0);
   virtual ~TerminalNode();

   // Defined out of line.
   virtual void EmptyBin();
#ifndef _AIX
   virtual const std::vector<tBoundary>& GetBoundaries() const;
#else
   virtual void GetBoundaries() const;
#endif
   virtual void GetClosestPoints(const point_type& rRef,
                                 UInt_t nPoints,
                                 std::vector<std::pair<const _DataPoint*,Double_t> >& vFoundPoints) const;

   // Read-only access to the stored point pointers.
   const std::vector<const point_type*>& GetPoints() const
   {
      return fDataPoints;
   }

   // Defined out of line.
   virtual void GetPointsWithinDist(const point_type& rRef,
                                    value_type fDist,
                                    std::vector<const _DataPoint*>& vFoundPoints) const;
   virtual void Print(int iRow = 0) const;

private:
   // Copying is disabled: private, empty implementations.
   TerminalNode(const TerminalNode& ) {}
   TerminalNode& operator=(const TerminalNode& )
   {
      return *this;
   }

   // Iterator types over the stored point pointers.
   typedef typename std::vector<const point_type* >::iterator data_it;
   typedef typename std::vector<const point_type* >::const_iterator const_data_it;

   // Private constructor taking an iterator range; defined out of line.
   TerminalNode(Double_t iBucketSize,UInt_t iSplitAxis,data_it first,data_it end);

   // Cloning a terminal node returns the result of ConvertToBinNode().
   virtual BinNode* Clone()
   {
      return ConvertToBinNode();
   }
   BinNode* ConvertToBinNode();

   // A terminal node is its own search result.
   virtual const BinNode* FindNode(const point_type&) const
   {
      return this;
   }

   // Defined out of line.
   virtual Bool_t Insert(const point_type& rPoint);
   void Split();

   // Simple setters for the ownership flag and the split option.
   void SetOwner(Bool_t bIsOwner = true)
   {
      fOwnData = bIsOwner;
   }
   void SetSplitOption(eSplitOption opt)
   {
      fSplitOption = opt;
   }

   // Defined out of line.
   data_it SplitEffectiveEntries();
   data_it SplitBinContent();
   void UpdateBoundaries();

   Bool_t                         fOwnData;       // set through SetOwner()
   eSplitOption                   fSplitOption;   // set through SetSplitOption()
   Double_t                       fBucketSize;    // NOTE(review): presumably the bucket size passed to the constructor -- confirm
   UInt_t                         fSplitAxis;     // NOTE(review): presumably the axis used by Split() -- confirm
   std::vector<const _DataPoint*> fDataPoints;    // pointers returned by GetPoints()
};
public:
// Public name of the bin node type.
typedef BinNode Bin;

// Iterator over the bins of the tree; it wraps a single Bin pointer.
// fBin is declared mutable and const-qualified overloads of the increment /
// decrement operators are provided (their definitions are out of line).
class iterator
{
   // The enclosing tree may use the private constructor and helpers.
   friend class KDTree<_DataPoint>;
public:
   // Default iterator: wraps a null bin pointer.
   iterator(): fBin(0) {}
   // Copy: both iterators refer to the same bin afterwards.
   iterator(const iterator& copy): fBin(copy.fBin) {}
   ~iterator() {}

   // Increment / decrement operators (defined out of line).
   iterator& operator++();
   const iterator& operator++() const;
   iterator operator++(int);
   const iterator operator++(int) const;
   iterator& operator--();
   const iterator& operator--() const;
   iterator operator--(int);
   const iterator operator--(int) const;

   // Two iterators are equal when they wrap the same bin pointer.
   bool operator==(const iterator& rIterator) const {return (fBin == rIterator.fBin);}
   bool operator!=(const iterator& rIterator) const {return !(*this == rIterator);}

   // Assignment (defined out of line).
   iterator& operator=(const iterator& rhs);

   // Access to the wrapped bin; no null check is performed.
   Bin& operator*() {return *fBin;}
   const Bin& operator*() const {return *fBin;}
   Bin* operator->() {return fBin;}
   const Bin* operator->() const {return fBin;}

   // Wrapped bin viewed as a TerminalNode. The dynamic_cast inside the assert
   // checks the node type (only when asserts are enabled); the conversion
   // itself is a named static_cast instead of a C-style cast, so it can only
   // ever perform the intended downcast within the class hierarchy.
   TerminalNode* TN()
   {
      assert(dynamic_cast<TerminalNode*>(fBin));
      return static_cast<TerminalNode*>(fBin);
   }
private:
   // Wrap an existing bin node; reachable by the tree through friendship.
   iterator(BinNode* pNode): fBin(pNode) {}
   // Neighbouring bins (defined out of line).
   Bin* Next() const;
   Bin* Previous() const;
   mutable Bin* fBin;   // wrapped bin pointer; mutable, so const members may change it
};
KDTree(UInt_t iBucketSize);
~KDTree();
void EmptyBins();
iterator End();
const iterator End() const;
const Bin* FindBin(const point_type& rPoint) const {return fHead->FindNode(rPoint);}
iterator First();
const iterator First() const;
void Freeze();
Double_t GetBucketSize() const {return fBucketSize;}
void GetClosestPoints(const point_type& rRef,UInt_t nPoints,std::vector<std::pair<const _DataPoint*,Double_t> >& vFoundPoints) const;
Double_t GetEffectiveEntries() const;
KDTree<_DataPoint>* GetFrozenCopy();
UInt_t GetNBins() const;
UInt_t GetEntries() const;
void GetPointsWithinDist(const point_type& rRef,value_type fDist,std::vector<const point_type*>& vFoundPoints) const;
Double_t GetTotalSumw() const;
Double_t GetTotalSumw2() const;
Bool_t Insert(const point_type& rData) {return fHead->Parent()->Insert(rData);}
Bool_t IsFrozen() const {return fIsFrozen;}
iterator Last();
const iterator Last() const;
void Print() {fHead->Parent()->Print();}
void Reset();
void SetOwner(Bool_t bIsOwner = true);
void SetSplitOption(eSplitOption opt);
private:
KDTree();
KDTree(const KDTree<point_type>& ) {}
KDTree<point_type>& operator=(const KDTree<point_type>& ) {return *this;}
BaseNode* fHead;
Double_t fBucketSize;
Bool_t fIsFrozen;
};
}
}
#include "Math/KDTree.icc"
#endif // ROOT_Math_KDTree