108 if (fModule)
delete fModule;
126 DeclareOptionRef(fnkNN = 20,
"nkNN",
"Number of k-nearest neighbors");
127 DeclareOptionRef(fBalanceDepth = 6,
"BalanceDepth",
"Binary tree balance depth");
128 DeclareOptionRef(fScaleFrac = 0.80,
"ScaleFrac",
"Fraction of events used to compute variable width");
129 DeclareOptionRef(fSigmaFact = 1.0,
"SigmaFact",
"Scale factor for sigma in Gaussian kernel");
130 DeclareOptionRef(fKernel =
"Gaus",
"Kernel",
"Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
131 DeclareOptionRef(fTrim =
kFALSE,
"Trim",
"Use equal number of signal and background events");
132 DeclareOptionRef(fUseKernel =
kFALSE,
"UseKernel",
"Use polynomial kernel weight");
133 DeclareOptionRef(fUseWeight =
kTRUE,
"UseWeight",
"Use weight to count kNN events");
134 DeclareOptionRef(fUseLDA =
kFALSE,
"UseLDA",
"Use local linear discriminant - experimental feature");
142 DeclareOptionRef(fTreeOptDepth = 6,
"TreeOptDepth",
"Binary tree optimisation depth");
152 Log() << kWARNING <<
"kNN must be a positive integer: set kNN = " << fnkNN <<
Endl;
154 if (fScaleFrac < 0.0) {
156 Log() << kWARNING <<
"ScaleFrac can not be negative: set ScaleFrac = " << fScaleFrac <<
Endl;
158 if (fScaleFrac > 1.0) {
161 if (!(fBalanceDepth > 0)) {
163 Log() << kWARNING <<
"Optimize must be a positive integer: set Optimize = " << fBalanceDepth <<
Endl;
168 <<
" kNN = \n" << fnkNN
169 <<
" UseKernel = \n" << fUseKernel
170 <<
" SigmaFact = \n" << fSigmaFact
171 <<
" ScaleFrac = \n" << fScaleFrac
172 <<
" Kernel = \n" << fKernel
173 <<
" Trim = \n" << fTrim
174 <<
" Optimize = " << fBalanceDepth <<
Endl;
206 Log() << kFATAL <<
"ModulekNN is not created" <<
Endl;
212 if (fScaleFrac > 0.0) {
219 Log() << kINFO <<
"Creating kd-tree with " << fEvent.size() <<
" events" <<
Endl;
221 for (kNN::EventVec::const_iterator
event = fEvent.begin();
event != fEvent.end(); ++
event) {
222 fModule->Add(*
event);
226 fModule->Fill(
static_cast<UInt_t>(fBalanceDepth),
227 static_cast<UInt_t>(100.0*fScaleFrac),
236 Log() << kHEADER <<
"<Train> start..." <<
Endl;
238 if (IsNormalised()) {
239 Log() << kINFO <<
"Input events are normalized - setting ScaleFrac to 0" <<
Endl;
243 if (!fEvent.empty()) {
244 Log() << kINFO <<
"Erasing " << fEvent.size() <<
" previously stored events" <<
Endl;
247 if (GetNVariables() < 1)
248 Log() << kFATAL <<
"MethodKNN::Train() - mismatched or wrong number of event variables" <<
Endl;
251 Log() << kINFO <<
"Reading " << GetNEvents() <<
" events" <<
Endl;
253 for (
UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
255 const Event* evt_ = GetEvent(ievt);
259 if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0)
continue;
262 for (
UInt_t ivar = 0; ivar < evt_ -> GetNVariables(); ++ivar) vvec[ivar] = evt_->
GetValue(ivar);
266 if (DataInfo().IsSignal(evt_)) {
267 fSumOfWeightsS += weight;
271 fSumOfWeightsB += weight;
278 kNN::Event event_knn(vvec, weight, event_type);
280 fEvent.push_back(event_knn);
284 <<
"Number of signal events " << fSumOfWeightsS <<
Endl
285 <<
"Number of background events " << fSumOfWeightsB <<
Endl;
299 NoErrorCalc(err, errUpper);
304 const Event *ev = GetEvent();
305 const Int_t nvar = GetNVariables();
311 for (
Int_t ivar = 0; ivar < nvar; ++ivar) {
319 fModule->Find(event_knn, knn + 2);
321 const kNN::List &rlist = fModule->GetkNNList();
322 if (rlist.size() != knn + 2) {
323 Log() << kFATAL <<
"kNN result list is empty" <<
Endl;
332 Bool_t use_gaus =
false, use_poln =
false;
334 if (fKernel ==
"Gaus") use_gaus =
true;
335 else if (fKernel ==
"Poln") use_poln =
true;
345 if (!(kradius > 0.0)) {
346 Log() << kFATAL <<
"kNN radius is not positive" <<
Endl;
356 std::vector<Double_t> rms_vec;
360 if (rms_vec.empty() || rms_vec.size() != event_knn.
GetNVar()) {
361 Log() << kFATAL <<
"Failed to compute RMS vector" <<
Endl;
367 Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;
369 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
376 if (lit->second < 0.0) {
377 Log() << kFATAL <<
"A neighbor has negative distance to query event" <<
Endl;
379 else if (!(lit->second > 0.0)) {
380 Log() << kVERBOSE <<
"A neighbor has zero distance to query event" <<
Endl;
388 if (fUseWeight) weight_all += evweight;
391 if (node.
GetEvent().GetType() == 1) {
392 if (fUseWeight) weight_sig += evweight;
395 else if (node.
GetEvent().GetType() == 2) {
396 if (fUseWeight) weight_bac += evweight;
400 Log() << kFATAL <<
"Unknown type for training event" <<
Endl;
406 if (count_all >= knn) {
412 if (!(count_all > 0)) {
413 Log() << kFATAL <<
"Size kNN result list is not positive" <<
Endl;
418 if (count_all < knn) {
419 Log() << kDEBUG <<
"count_all and kNN have different size: " << count_all <<
" < " << knn <<
Endl;
423 if (!(weight_all > 0.0)) {
424 Log() << kFATAL <<
"kNN result total weight is not positive" <<
Endl;
428 return weight_sig/weight_all;
437 if( fRegressionReturnVal == 0 )
438 fRegressionReturnVal =
new std::vector<Float_t>;
440 fRegressionReturnVal->clear();
445 const Event *evt = GetEvent();
446 const Int_t nvar = GetNVariables();
448 std::vector<float> reg_vec;
452 for (
Int_t ivar = 0; ivar < nvar; ++ivar) {
460 fModule->Find(event_knn, knn + 2);
462 const kNN::List &rlist = fModule->GetkNNList();
463 if (rlist.size() != knn + 2) {
464 Log() << kFATAL <<
"kNN result list is empty" <<
Endl;
465 return *fRegressionReturnVal;
472 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
479 if (reg_vec.empty()) {
483 for(
UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
484 if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
485 else reg_vec[ivar] += tvec[ivar];
488 if (fUseWeight) weight_all += weight;
494 if (count_all == knn) {
500 if (!(weight_all > 0.0)) {
501 Log() << kFATAL <<
"Total weight sum is not positive: " << weight_all <<
Endl;
502 return *fRegressionReturnVal;
505 for (
UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
506 reg_vec[ivar] /= weight_all;
510 fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());
512 return *fRegressionReturnVal;
529 if (fEvent.size()>0)
gTools().
AddAttr(wght,
"NVar",fEvent.begin()->GetNVar());
530 if (fEvent.size()>0)
gTools().
AddAttr(wght,
"NTgt",fEvent.begin()->GetNTgt());
532 for (kNN::EventVec::const_iterator
event = fEvent.begin();
event != fEvent.end(); ++
event) {
534 std::stringstream s(
"");
536 for (
UInt_t ivar = 0; ivar <
event->GetNVar(); ++ivar) {
537 if (ivar>0) s <<
" ";
538 s << std::scientific <<
event->GetVar(ivar);
541 for (
UInt_t itgt = 0; itgt <
event->GetNTgt(); ++itgt) {
542 s <<
" " << std::scientific <<
event->GetTgt(itgt);
555 UInt_t nvar = 0, ntgt = 0;
570 std::stringstream s(
gTools().GetContent(ch) );
572 for(
UInt_t ivar=0; ivar<nvar; ivar++)
575 for(
UInt_t itgt=0; itgt<ntgt; itgt++)
580 kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
581 fEvent.push_back(event_knn);
593 Log() << kINFO <<
"Starting ReadWeightsFromStream(std::istream& is) function..." <<
Endl;
595 if (!fEvent.empty()) {
596 Log() << kINFO <<
"Erasing " << fEvent.size() <<
" previously stored events" <<
Endl;
604 std::getline(is,
line);
606 if (
line.empty() ||
line.find(
"#") != std::string::npos) {
611 std::string::size_type pos=0;
612 while( (pos=
line.find(
',',pos)) != std::string::npos ) { count++; pos++; }
617 if (count < 3 || nvar != count - 2) {
618 Log() << kFATAL <<
"Missing comma delimeter(s)" <<
Endl;
628 std::string::size_type prev = 0;
630 for (std::string::size_type ipos = 0; ipos <
line.size(); ++ipos) {
631 if (
line[ipos] !=
',' && ipos + 1 !=
line.size()) {
635 if (!(ipos > prev)) {
636 Log() << kFATAL <<
"Wrong substring limits" <<
Endl;
639 std::string vstring =
line.substr(prev, ipos - prev);
640 if (ipos + 1 ==
line.size()) {
641 vstring =
line.substr(prev, ipos - prev + 1);
644 if (vstring.empty()) {
645 Log() << kFATAL <<
"Failed to parse string" <<
Endl;
651 else if (vcount == 1) {
652 type = std::atoi(vstring.c_str());
654 else if (vcount == 2) {
655 weight = std::atof(vstring.c_str());
657 else if (vcount - 3 < vvec.size()) {
658 vvec[vcount - 3] = std::atof(vstring.c_str());
661 Log() << kFATAL <<
"Wrong variable count" <<
Endl;
671 Log() << kINFO <<
"Read " << fEvent.size() <<
" events from text file" <<
Endl;
682 Log() << kINFO <<
"Starting WriteWeightsToStream(TFile &rf) function..." <<
Endl;
684 if (fEvent.empty()) {
685 Log() << kWARNING <<
"MethodKNN contains no events " <<
Endl;
691 tree->SetDirectory(0);
692 tree->Branch(
"event",
"TMVA::kNN::Event", &
event);
695 for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
706 Log() << kINFO <<
"Wrote " <<
size <<
"MB and " << fEvent.size()
707 <<
" events to ROOT file" <<
Endl;
718 Log() << kINFO <<
"Starting ReadWeightsFromStream(TFile &rf) function..." <<
Endl;
720 if (!fEvent.empty()) {
721 Log() << kINFO <<
"Erasing " << fEvent.size() <<
" previously stored events" <<
Endl;
728 Log() << kFATAL <<
"Failed to find knn tree" <<
Endl;
738 for (
Int_t i = 0; i < nevent; ++i) {
740 fEvent.push_back(*
event);
746 Log() << kINFO <<
"Read " <<
size <<
"MB and " << fEvent.size()
747 <<
" events from ROOT file" <<
Endl;
760 fout <<
" // not implemented for class: \"" << className <<
"\"" << std::endl;
761 fout <<
"};" << std::endl;
775 Log() <<
"The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" <<
Endl
776 <<
"and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" <<
Endl
777 <<
"training events for which a classification category/regression target is known. " <<
Endl
778 <<
"The k-NN method compares a test event to all training events using a distance " <<
Endl
779 <<
"function, which is an Euclidean distance in a space defined by the input variables. "<<
Endl
780 <<
"The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" <<
Endl
781 <<
"quick search for the k events with shortest distance to the test event. The method" <<
Endl
782 <<
"returns a fraction of signal events among the k neighbors. It is recommended" <<
Endl
783 <<
"that a histogram which stores the k-NN decision variable is binned with k+1 bins" <<
Endl
784 <<
"between 0 and 1." <<
Endl;
787 Log() <<
gTools().
Color(
"bold") <<
"--- Performance tuning via configuration options: "
790 Log() <<
"The k-NN method estimates a density of signal and background events in a "<<
Endl
791 <<
"neighborhood around the test event. The method assumes that the density of the " <<
Endl
792 <<
"signal and background events is uniform and constant within the neighborhood. " <<
Endl
793 <<
"k is an adjustable parameter and it determines an average size of the " <<
Endl
794 <<
"neighborhood. Small k values (less than 10) are sensitive to statistical " <<
Endl
795 <<
"fluctuations and large (greater than 100) values might not sufficiently capture " <<
Endl
796 <<
"local differences between events in the training set. The speed of the k-NN" <<
Endl
797 <<
"method also increases with larger values of k. " <<
Endl;
799 Log() <<
"The k-NN method assigns equal weight to all input variables. Different scales " <<
Endl
800 <<
"among the input variables is compensated using ScaleFrac parameter: the input " <<
Endl
801 <<
"variables are scaled so that the widths for central ScaleFrac*100% events are " <<
Endl
802 <<
"equal among all the input variables." <<
Endl;
805 Log() <<
gTools().
Color(
"bold") <<
"--- Additional configuration options: "
808 Log() <<
"The method inclues an option to use a Gaussian kernel to smooth out the k-NN" <<
Endl
809 <<
"response. The kernel re-weights events using a distance to the test event." <<
Endl;
819 if (!(avalue < 1.0)) {
823 const Double_t prod = 1.0 - avalue * avalue * avalue;
825 return (prod * prod * prod);
835 Log() << kFATAL <<
"Mismatched vectors in Gaussian kernel function" <<
Endl;
842 double sum_exp = 0.0;
844 for(
unsigned int ivar = 0; ivar < event_knn.
GetNVar(); ++ivar) {
846 const Double_t diff_ =
event.GetVar(ivar) - event_knn.
GetVar(ivar);
848 if (!(sigm_ > 0.0)) {
849 Log() << kFATAL <<
"Bad sigma value = " << sigm_ <<
Endl;
853 sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
861 return std::exp(-sum_exp);
875 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
877 if (!(lit->second > 0.0))
continue;
879 if (kradius < lit->second || kradius < 0.0) kradius = lit->second;
882 if (kcount >= knn)
break;
895 std::vector<Double_t> rvec;
899 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
901 if (!(lit->second > 0.0))
continue;
904 const kNN::Event &event_ = node_-> GetEvent();
907 rvec.insert(rvec.end(), event_.
GetNVar(), 0.0);
909 else if (rvec.size() != event_.
GetNVar()) {
910 Log() << kFATAL <<
"Wrong number of variables, should never happen!" <<
Endl;
915 for(
unsigned int ivar = 0; ivar < event_.
GetNVar(); ++ivar) {
917 rvec[ivar] += diff_*diff_;
921 if (kcount >= knn)
break;
925 Log() << kFATAL <<
"Bad event kcount = " << kcount <<
Endl;
930 for(
unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
931 if (!(rvec[ivar] > 0.0)) {
932 Log() << kFATAL <<
"Bad RMS value = " << rvec[ivar] <<
Endl;
937 rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
949 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
955 if (node.
GetEvent().GetType() == 1) {
956 sig_vec.push_back(tvec);
958 else if (node.
GetEvent().GetType() == 2) {
959 bac_vec.push_back(tvec);
962 Log() << kFATAL <<
"Unknown type for training event" <<
Endl;
966 fLDA.Initialize(sig_vec, bac_vec);
968 return fLDA.GetProb(event_knn.
GetVars(), 1);
#define REGISTER_METHOD(CLASS)
for example
std::vector< std::vector< Float_t > > LDAEvents
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Int_t WriteTObject(const TObject *obj, const char *name=nullptr, Option_t *option="", Int_t bufsize=0) override
Write object obj to this directory.
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Class that contains all the data information.
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
std::vector< Float_t > & GetTargets()
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Virtual base Class for all MVA method.
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Analysis of k-nearest neighbor.
void Init(void)
Initialization.
void MakeKNN(void)
create kNN
virtual ~MethodKNN(void)
destructor
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
kNN can handle classification with 2 classes and regression.
const std::vector< Double_t > getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
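Compute the per-variable RMS of the k nearest neighbors' distances to the query event, scaled by |SigmaFact| (used as widths for the Gaussian kernel).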
const Ranking * CreateRanking()
no ranking available
void MakeClassSpecific(std::ostream &, const TString &) const
write specific classifier response
MethodKNN(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="KNN")
standard constructor
void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility
Double_t getKernelRadius(const kNN::List &rlist) const
Get polynomial kernel radius.
void GetHelpMessage() const
get help message text
const std::vector< Float_t > & GetRegressionValues()
Return vector of averages for target values of k-nearest neighbors.
void Train(void)
kNN training
void ReadWeightsFromStream(std::istream &istr)
read the weights
double getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
Double_t PolnKernel(Double_t value) const
polynomial kernel
void ProcessOptions()
process the options specified by the user
void ReadWeightsFromXML(void *wghtnode)
void AddWeightsXMLTo(void *parent) const
write weights to XML
void DeclareOptions()
MethodKNN options.
void WriteWeightsToStream(TFile &rf) const
save weights to ROOT file
Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector< Double_t > &svec) const
Gaussian kernel.
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
Compute classifier response.
Ranking for variables in method (implementation)
Singleton class for Global types used by TMVA.
void SetTargets(const VarVec &tvec)
VarType GetVar(UInt_t i) const
const VarVec & GetVars() const
This file contains binary tree and global function template that searches tree for k-nearest neighbors...
Double_t GetWeight() const
const T & GetEvent() const
std::vector< VarType > VarVec
virtual void Clear(Option_t *="")
A TTree represents a columnar dataset.
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Double_t Sqrt(Double_t x)