   Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
   Log() << kWARNING << "ScaleFrac cannot be negative: set ScaleFrac = " << fScaleFrac << Endl;
   Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
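   // Usage sketch (illustration, not part of this file): these data members are
   // filled from the booking option string. Assuming a standard TMVA setup with
   // a Factory and a DataLoader, the method might be booked as
   //
   //    factory->BookMethod(dataloader, TMVA::Types::kKNN, "KNN",
   //                        "nkNN=20:ScaleFrac=0.8:UseKernel=F:Trim=F");
   //
   // the option names and default values shown here are assumptions for the example.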
   Log() << kVERBOSE
         << " kNN = \n" << fnkNN
         << " Trim = \n" << fTrim
         << Endl;

   fModule = new kNN::ModulekNN();
   Log() << kFATAL << "ModulekNN is not created" << Endl;
   Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;
   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      fModule->Add(*event); // feed each stored training event into the kd-tree module
   }
   Log() << kHEADER << "<Train> start..." << Endl;
   Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
   Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
   Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;
   kNN::Event event_knn(vvec, weight, event_type);
   fEvent.push_back(event_knn);
   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = ev->GetValue(ivar); // copy the test event's input variables
   }
   const kNN::Event event_knn(vvec, weight, 3);
   fModule->Find(event_knn, knn + 2);
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
   }
   Bool_t use_gaus = false, use_poln = false;

   if      (fKernel == "Gaus") use_gaus = true;
   else if (fKernel == "Poln") use_poln = true;
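   // Option sketch (assumption, not shown in this excerpt): fKernel is set via
   // the booking string, e.g. "UseKernel=T:Kernel=Gaus" for Gaussian smoothing
   // or "UseKernel=T:Kernel=Poln" for the polynomial (tricube) kernel below.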
   if (!(kradius > 0.0)) {
      Log() << kFATAL << "kNN radius is not positive" << Endl;
   }
   std::vector<Double_t> rms_vec;
   if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
      Log() << kFATAL << "Failed to compute RMS vector" << Endl;
   }
   UInt_t count_all = 0;
   Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;
   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      const kNN::Node<kNN::Event> &node = *(lit->first);
      if (lit->second < 0.0) {
         Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
      }
      else if (!(lit->second > 0.0)) {
         Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
      }
      Double_t evweight = node.GetWeight();
      weight_all += (fUseWeight ? evweight : 1.0);

      if (node.GetEvent().GetType() == 1) {        // signal
         weight_sig += (fUseWeight ? evweight : 1.0);
      }
      else if (node.GetEvent().GetType() == 2) {   // background
         weight_bac += (fUseWeight ? evweight : 1.0);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
      if (++count_all >= knn) break; // use only the k nearest neighbors
   }
   if (!(count_all > 0)) {
      Log() << kFATAL << "Size of kNN result list is not positive" << Endl;
   }
   if (count_all < knn) {
      Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
   }
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "kNN result total weight is not positive" << Endl;
   }
   return weight_sig/weight_all;
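   // Worked example (illustration, not original code): with fUseWeight = false
   // and k = 20 neighbors of which 14 are signal, the method returns
   // weight_sig/weight_all = 14/20 = 0.7.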
   std::vector<float> reg_vec;
   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = evt->GetValue(ivar); // copy the regression event's input variables
   }
   const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
   fModule->Find(event_knn, knn + 2);
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
   }
   Double_t weight_all = 0;
   UInt_t count_all = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetTargets();
      const Double_t weight = node.GetEvent().GetWeight();
      if (reg_vec.empty()) {
         reg_vec = kNN::VarVec(tvec.size(), 0.0); // size on first neighbor
      }

      for (UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
         if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
         else            reg_vec[ivar] += tvec[ivar];
      }

      weight_all += (fUseWeight ? weight : 1.0);
      if (++count_all == knn) break; // stop after the k nearest neighbors
   }
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
   }
   for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
      reg_vec[ivar] /= weight_all;
   }
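   // Net effect: each regression output is the (optionally weighted) average of
   // the neighbor targets, reg_vec[ivar] = sum_i w_i*t_i[ivar] / sum_i w_i,
   // with w_i = 1 when fUseWeight is false. (Explanatory note, not original code.)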
   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      std::stringstream s("");
      for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
         if (ivar > 0) s << " ";
         s << std::scientific << event->GetVar(ivar);
      }
      for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
         s << " " << std::scientific << event->GetTgt(itgt);
      }
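      // Serialization sketch (illustration): variables first, then targets, all
      // space-separated in scientific notation; e.g. an event with variables
      // (1.0, 0.25) and target 3.5 could appear as
      //    "1.000000e+00 2.500000e-01 3.500000e+00"
      // (the digit count depends on the stream precision configured elsewhere).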
   UInt_t nvar = 0, ntgt = 0;
   std::stringstream s( gTools().GetContent(ch) );
   for (UInt_t ivar = 0; ivar < nvar; ivar++)
      s >> vvec[ivar]; // read the variables back in the order they were written

   for (UInt_t itgt = 0; itgt < ntgt; itgt++)
      s >> tvec[itgt];
   kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
   fEvent.push_back(event_knn);
   Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;
   Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
   std::string line;
   std::getline(is, line);
   if (line.empty() || line.find("#") != std::string::npos) {
      continue; // skip empty lines and comment lines
   }
   UInt_t count = 0;
   std::string::size_type pos = 0;
   while ((pos = line.find(',', pos)) != std::string::npos) { count++; pos++; }
   if (count < 3 || nvar != count - 2) {
      Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
   }
   UInt_t vcount = 0;
   std::string::size_type prev = 0;
   for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
      if (line[ipos] != ',' && ipos + 1 != line.size()) {
         continue; // advance until the next delimiter or the end of the line
      }
      if (!(ipos > prev)) {
         Log() << kFATAL << "Wrong substring limits" << Endl;
      }
      std::string vstring = line.substr(prev, ipos - prev);
      if (ipos + 1 == line.size()) {
         vstring = line.substr(prev, ipos - prev + 1);
      }
      if (vstring.empty()) {
         Log() << kFATAL << "Failed to parse string" << Endl;
      }
      else if (vcount == 1) {
         type = std::atoi(vstring.c_str());
      }
      else if (vcount == 2) {
         weight = std::atof(vstring.c_str());
      }
      else if (vcount - 3 < vvec.size()) {
         vvec[vcount - 3] = std::atof(vstring.c_str());
      }
      else {
         Log() << kFATAL << "Wrong variable count" << Endl;
      }

      prev = ipos + 1;
      ++vcount;
   }
   fEvent.push_back(kNN::Event(vvec, weight, type));
   Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;
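   // Input format sketch (inferred from the parser above): each non-comment line
   // reads "<index>,<type>,<weight>,<var1>,...,<varN>", for example
   //    0,1,1.0,0.25,-1.3
   // where the first field is apparently an event index unused here, type is
   // 1 (signal) or 2 (background), and the remaining fields are the variables.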
   Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;
   Log() << kWARNING << "MethodKNN contains no events" << Endl;
   kNN::Event *event = new kNN::Event();

   tree->Branch("event", "TMVA::kNN::Event", &event);
   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
      *event = *it; // copy the current event into the branch buffer
      size += tree->Fill();
   }
   Log() << kINFO << "Wrote " << size << "MB and " << fEvent.size()
         << " events to ROOT file" << Endl;
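   // Note (assumption about the elided lines): tree->Fill() returns a byte
   // count, so "size" is presumably scaled from bytes to megabytes before this
   // message is printed.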
   Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;
   Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
   Log() << kFATAL << "Failed to find knn tree" << Endl;
   kNN::Event *event = new kNN::Event();
   for (Int_t i = 0; i < nevent; ++i) {
      size += tree->GetEntry(i);   // read the i-th event into the branch buffer
      fEvent.push_back(*event);
   }
   Log() << kINFO << "Read " << size << "MB and " << fEvent.size()
         << " events from ROOT file" << Endl;
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
   Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
         << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
         << "training events for which a classification category/regression target is known." << Endl
         << "The k-NN method compares a test event to all training events using a distance" << Endl
         << "function, which is a Euclidean distance in a space defined by the input variables." << Endl
         << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
         << "quick search for the k events with shortest distance to the test event. The method" << Endl
         << "returns a fraction of signal events among the k neighbors. It is recommended" << Endl
         << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
         << "between 0 and 1." << Endl;
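   // Why k+1 bins (explanatory note, not part of the original message): with
   // unweighted events the response can only take the discrete values m/k for
   // m = 0, 1, ..., k, i.e. k+1 distinct values between 0 and 1, so k+1 bins
   // give each possible response its own bin.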
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
         << gTools().Color("reset") << Endl;

   Log() << "The k-NN method estimates a density of signal and background events in a" << Endl
         << "neighborhood around the test event. The method assumes that the density of the" << Endl
         << "signal and background events is uniform and constant within the neighborhood." << Endl
         << "k is an adjustable parameter and it determines an average size of the" << Endl
         << "neighborhood. Small k values (less than 10) are sensitive to statistical" << Endl
         << "fluctuations and large (greater than 100) values might not sufficiently capture" << Endl
         << "local differences between events in the training set. The running time of the" << Endl
         << "k-NN method also increases with larger values of k." << Endl;
   Log() << "The k-NN method assigns equal weight to all input variables. Different scales" << Endl
         << "among the input variables are compensated for using the ScaleFrac parameter: the" << Endl
         << "input variables are scaled so that the widths of the central ScaleFrac*100% of" << Endl
         << "events are equal among all the input variables." << Endl;
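   // Example (illustration): with ScaleFrac=0.8 each variable is rescaled so
   // that the interval holding its central 80% of training events has the same
   // width for every variable; a variable measured in, say, GeV then cannot
   // dominate the Euclidean distance over one measured in radians.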
   Log() << gTools().Color("bold") << "--- Additional configuration options: "
         << gTools().Color("reset") << Endl;

   Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
         << "response. The kernel re-weights events using a distance to the test event." << Endl;
   if (!(avalue < 1.0)) {
      return 0.0; // outside the unit radius the kernel vanishes
   }

   const Double_t prod = 1.0 - avalue * avalue * avalue;

   return (prod * prod * prod);
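   // This is the tricube kernel K(u) = (1 - |u|^3)^3 for |u| < 1, 0 otherwise;
   // e.g. K(0.5) = (1 - 0.125)^3 = 0.875^3 ≈ 0.670. (Explanatory note, not
   // original code.)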
Double_t TMVA::MethodKNN::GausKernel(const kNN::Event &event_knn,
                                     const kNN::Event &event,
                                     const std::vector<Double_t> &svec) const
{
   if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
      Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
   }
   double sum_exp = 0.0;

   for (unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {

      const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
      const Double_t sigm_ = svec[ivar];

      if (!(sigm_ > 0.0)) {
         Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
      }

      sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
   }
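   // The kernel weight is then exp(-sum_exp), i.e.
   //    K = exp( -sum_i (x_i - y_i)^2 / (2*sigma_i^2) ),
   // a product of one-dimensional Gaussians whose widths sigma_i come from the
   // RMS vector svec. (The return statement is elided from this excerpt.)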
   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      if (!(lit->second > 0.0)) continue;

      if (kradius < lit->second || kradius < 0.0) kradius = lit->second;

      if (++kcount >= knn) break;
   }
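   // kradius ends up holding the largest distance found among the first k valid
   // neighbors, which defines the radius of the polynomial kernel; the caller
   // is expected to normalize neighbor distances by this radius before calling
   // PolnKernel(). (Interpretation inferred from the surrounding code.)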
   std::vector<Double_t> rvec;
   UInt_t kcount = 0;
   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      if (!(lit->second > 0.0)) continue;

      const kNN::Node<kNN::Event> *node_ = lit->first;
      const kNN::Event &event_ = node_->GetEvent();
      if (rvec.empty()) {
         rvec.insert(rvec.end(), event_.GetNVar(), 0.0); // size on first neighbor
      }
      else if (rvec.size() != event_.GetNVar()) {
         Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
      }
      for (unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
         const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
         rvec[ivar] += diff_*diff_;
      }

      if (++kcount >= knn) break;
   }
   Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
   for (unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
      if (!(rvec[ivar] > 0.0)) {
         Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
      }
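      // Sketch of the elided tail of this loop (assumption based on the
      // accumulation above): each entry becomes an RMS spread scaled by the
      // SigmaFact option, roughly
      //    rvec[ivar] = std::abs(fSigmaFact) * std::sqrt(rvec[ivar]/kcount);
      // and these values are the sigma_i consumed by GausKernel().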
   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetVars();

      if (node.GetEvent().GetType() == 1) {        // signal
         sig_vec.push_back(tvec);
      }
      else if (node.GetEvent().GetType() == 2) {   // background
         bac_vec.push_back(tvec);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
   }
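   // Sketch of the elided tail (assumption, assuming the class holds a local
   // LDA helper fLDA): the two collections feed the linear discriminant, e.g.
   //    fLDA.Initialize(sig_vec, bac_vec);
   //    return fLDA.GetProb(event_knn.GetVars(), 1);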