////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::MethodKNN::~MethodKNN(void)
{
   if (fModule) delete fModule;
}
////////////////////////////////////////////////////////////////////////////////
/// MethodKNN options

void TMVA::MethodKNN::DeclareOptions()
{
   DeclareOptionRef(fnkNN         = 20,     "nkNN",         "Number of k-nearest neighbors");
   DeclareOptionRef(fBalanceDepth = 6,      "BalanceDepth", "Binary tree balance depth");
   DeclareOptionRef(fScaleFrac    = 0.80,   "ScaleFrac",    "Fraction of events used to compute variable width");
   DeclareOptionRef(fSigmaFact    = 1.0,    "SigmaFact",    "Scale factor for sigma in Gaussian kernel");
   DeclareOptionRef(fKernel       = "Gaus", "Kernel",       "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
   DeclareOptionRef(fTrim         = kFALSE, "Trim",         "Use equal number of signal and background events");
   DeclareOptionRef(fUseKernel    = kFALSE, "UseKernel",    "Use polynomial kernel weight");
   DeclareOptionRef(fUseWeight    = kTRUE,  "UseWeight",    "Use weight to count kNN events");
   DeclareOptionRef(fUseLDA       = kFALSE, "UseLDA",       "Use local linear discriminant - experimental feature");
}
////////////////////////////////////////////////////////////////////////////////
/// options that are used ONLY for the READER to ensure backward compatibility

void TMVA::MethodKNN::DeclareCompatibilityOptions()
{
   MethodBase::DeclareCompatibilityOptions();
   DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
}
////////////////////////////////////////////////////////////////////////////////
/// process the options specified by the user

void TMVA::MethodKNN::ProcessOptions()
{
   if (!(fnkNN > 0)) {
      fnkNN = 10;
      Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
   }
   if (fScaleFrac < 0.0) {
      fScaleFrac = 0.0;
      Log() << kWARNING << "ScaleFrac cannot be negative: set ScaleFrac = " << fScaleFrac << Endl;
   }
   if (fScaleFrac > 1.0) {
      fScaleFrac = 1.0;
   }
   if (!(fBalanceDepth > 0)) {
      fBalanceDepth = 6;
      Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
   }

   Log() << kVERBOSE
         << "kNN options:\n"
         << "  kNN       = " << fnkNN         << "\n"
         << "  UseKernel = " << fUseKernel    << "\n"
         << "  SigmaFact = " << fSigmaFact    << "\n"
         << "  ScaleFrac = " << fScaleFrac    << "\n"
         << "  Kernel    = " << fKernel       << "\n"
         << "  Trim      = " << fTrim         << "\n"
         << "  Optimize  = " << fBalanceDepth << Endl;
}
////////////////////////////////////////////////////////////////////////////////
/// create kNN

void TMVA::MethodKNN::MakeKNN()
{
   if (!fModule) {
      Log() << kFATAL << "ModulekNN is not created" << Endl;
   }

   std::string option;
   if (fScaleFrac > 0.0) {
      option += "metric";
   }
   if (fTrim) {
      option += "trim";
   }

   Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      fModule->Add(*event);
   }

   // create binary tree
   fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
                 static_cast<UInt_t>(100.0*fScaleFrac),
                 option);
}
////////////////////////////////////////////////////////////////////////////////
/// kNN training

void TMVA::MethodKNN::Train()
{
   Log() << kHEADER << "<Train> start..." << Endl;

   if (IsNormalised()) {
      Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
      fScaleFrac = 0.0;
   }

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }
   if (GetNVariables() < 1)
      Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;

   Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;

   for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
      // read the training event
      const Event* evt_ = GetEvent(ievt);
      Double_t weight = evt_->GetWeight();

      // skip events with negative weights if requested
      if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;

      kNN::VarVec vvec(GetNVariables(), 0.0);
      for (UInt_t ivar = 0; ivar < evt_->GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);

      Short_t event_type = 0;
      if (DataInfo().IsSignal(evt_)) { // signal type = 1
         fSumOfWeightsS += weight;
         event_type = 1;
      }
      else { // background type = 2
         fSumOfWeightsB += weight;
         event_type = 2;
      }

      // create the kNN event and preserve the full variable vector
      kNN::Event event_knn(vvec, weight, event_type);
      fEvent.push_back(event_knn);
   }

   Log() << kINFO
         << "Number of signal events " << fSumOfWeightsS << Endl
         << "Number of background events " << fSumOfWeightsB << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();
}
////////////////////////////////////////////////////////////////////////////////
/// Compute classifier response

Double_t TMVA::MethodKNN::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // cannot determine error
   NoErrorCalc(err, errUpper);

   const Event *ev = GetEvent();
   const Int_t nvar = GetNVariables();
   const Double_t weight = ev->GetWeight();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = ev->GetValue(ivar);
   }

   // search for fnkNN + 2 nearest neighbors: the query event itself may be
   // part of the training sample and is returned with zero distance
   const kNN::Event event_knn(vvec, weight, 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
      return -100.0;
   }

   if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);

   // determine which kernel (if any) re-weights the neighbors
   Bool_t use_gaus = false, use_poln = false;
   if (fUseKernel) {
      if      (fKernel == "Gaus") use_gaus = true;
      else if (fKernel == "Poln") use_poln = true;
   }

   // compute the polynomial kernel radius from the largest neighbor distance
   Double_t kradius = -1.0;
   if (use_poln) {
      kradius = MethodKNN::getKernelRadius(rlist);

      if (!(kradius > 0.0)) {
         Log() << kFATAL << "kNN radius is not positive" << Endl;
         return -100.0;
      }

      kradius = 1.0/std::sqrt(kradius);
   }

   // compute the per-variable RMS of the neighbors for the Gaussian kernel
   std::vector<Double_t> rms_vec;
   if (use_gaus) {
      rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);

      if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
         Log() << kFATAL << "Failed to compute RMS vector" << Endl;
         return -100.0;
      }
   }

   UInt_t count_all = 0;
   Double_t weight_all = 0, weight_sig = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);

      // a zero distance means the query event is also in the training sample
      if (lit->second < 0.0) {
         Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
      }
      else if (!(lit->second > 0.0)) {
         Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
      }

      // get event weight and scale weight by kernel function
      Double_t evweight = node.GetWeight();
      if      (use_gaus) evweight *= GausKernel(event_knn, node.GetEvent(), rms_vec);
      else if (use_poln) evweight *= PolnKernel(std::sqrt(lit->second)*kradius);

      if (fUseWeight) weight_all += evweight;
      else            ++weight_all;

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         if (fUseWeight) weight_sig += evweight;
         else            ++weight_sig;
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }

      // use only fnkNN events
      ++count_all;

      if (count_all >= knn) {
         break;
      }
   }

   // check that the total number of events and the weight sum are positive
   if (!(count_all > 0)) {
      Log() << kFATAL << "Size kNN result list is not positive" << Endl;
      return -100.0;
   }

   if (count_all < knn) {
      Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
   }

   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "kNN result total weight is not positive" << Endl;
      return -100.0;
   }

   return weight_sig/weight_all;
}
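In formula form, the loop above returns the kernel-weighted signal fraction among the k nearest neighbors of the query event:

\[
\mathrm{MVA}(\mathbf{x}) \;=\; \frac{\sum_{n \in k\mathrm{NN}(\mathbf{x}),\, y_n = \mathrm{signal}} k(d_n)\, w_n}{\sum_{n \in k\mathrm{NN}(\mathbf{x})} k(d_n)\, w_n}
\]

where \(w_n\) is the training-event weight and \(k(d_n)\) the optional Gaussian or polynomial kernel factor; with UseKernel=F the kernel factor is 1, and with UseWeight=F the sums reduce to plain neighbor counts.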
////////////////////////////////////////////////////////////////////////////////
/// Return vector of averages for target values of k-nearest neighbors.

const std::vector<Float_t>& TMVA::MethodKNN::GetRegressionValues()
{
   if (fRegressionReturnVal == 0) fRegressionReturnVal = new std::vector<Float_t>;
   else                           fRegressionReturnVal->clear();

   const Event *evt = GetEvent();
   const Int_t nvar = GetNVariables();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);
   std::vector<float> reg_vec;

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = evt->GetValue(ivar);
   }

   // search for fnkNN + 2 nearest neighbors
   const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
      return *fRegressionReturnVal;
   }

   // accumulate the weighted sum of the neighbor target vectors
   UInt_t count_all = 0;
   Double_t weight_all = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetTargets();
      const Double_t weight = node.GetEvent().GetWeight();

      if (reg_vec.empty()) {
         reg_vec = kNN::VarVec(tvec.size(), 0.0);
      }

      for (UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
         if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
         else            reg_vec[ivar] += tvec[ivar];
      }

      if (fUseWeight) weight_all += weight;
      else            ++weight_all;

      // use only fnkNN events
      ++count_all;

      if (count_all == knn) {
         break;
      }
   }

   // check that the total weight sum is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
      return *fRegressionReturnVal;
   }

   // normalize to get the average target vector
   for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
      reg_vec[ivar] /= weight_all;
   }

   // copy result
   fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());

   return *fRegressionReturnVal;
}
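The regression output is therefore the weighted average of the target vectors of the k nearest neighbors:

\[
\hat{t}_j(\mathbf{x}) \;=\; \frac{\sum_{n \in k\mathrm{NN}(\mathbf{x})} w_n\, t_{n,j}}{\sum_{n \in k\mathrm{NN}(\mathbf{x})} w_n},
\]

with \(w_n = 1\) for all neighbors when UseWeight=F.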
////////////////////////////////////////////////////////////////////////////////
/// write weights to XML

void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const
{
   void* wght = gTools().AddChild(parent, "Weights");
   gTools().AddAttr(wght, "NEvents", fEvent.size());
   if (fEvent.size()>0) gTools().AddAttr(wght, "NVar", fEvent.begin()->GetNVar());
   if (fEvent.size()>0) gTools().AddAttr(wght, "NTgt", fEvent.begin()->GetNTgt());

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {

      std::stringstream s("");
      for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
         if (ivar>0) s << " ";
         s << std::scientific << event->GetVar(ivar);
      }

      for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
         s << " " << std::scientific << event->GetTgt(itgt);
      }

      void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
      gTools().AddAttr(evt, "Type",   event->GetType());
      gTools().AddAttr(evt, "Weight", event->GetWeight());
   }
}
////////////////////////////////////////////////////////////////////////////////
/// read weights from XML (excerpt: the per-event loop over the child nodes
/// and the declarations of ch, vvec, tvec, evtWeight and evtType are elided)

void TMVA::MethodKNN::ReadWeightsFromXML( void* wghtnode )
{
   UInt_t nvar = 0, ntgt = 0;
   // ...
   std::stringstream s( gTools().GetContent(ch) );

   for (UInt_t ivar = 0; ivar < nvar; ivar++)
      s >> vvec[ivar];

   for (UInt_t itgt = 0; itgt < ntgt; itgt++)
      s >> tvec[itgt];
   // ...
   kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
   fEvent.push_back(event_knn);
}
////////////////////////////////////////////////////////////////////////////////
/// read the weights

void TMVA::MethodKNN::ReadWeightsFromStream(std::istream& is)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   UInt_t nvar = 0;

   while (!is.eof()) {
      std::string line;
      std::getline(is, line);

      // skip empty lines and comments
      if (line.empty() || line.find("#") != std::string::npos) {
         continue;
      }

      // count the comma delimiters
      UInt_t count = 0;
      std::string::size_type pos = 0;
      while( (pos = line.find(',', pos)) != std::string::npos ) { count++; pos++; }

      if (nvar == 0) {
         nvar = count - 2;
      }
      if (count < 3 || nvar != count - 2) {
         Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
      }

      Int_t ievent = -1, type = -1;
      Double_t weight = -1.0;

      kNN::VarVec vvec(nvar, 0.0);

      UInt_t vcount = 0;
      std::string::size_type prev = 0;

      for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
         if (line[ipos] != ',' && ipos + 1 != line.size()) {
            continue;
         }

         if (!(ipos > prev)) {
            Log() << kFATAL << "Wrong substring limits" << Endl;
         }

         std::string vstring = line.substr(prev, ipos - prev);
         if (ipos + 1 == line.size()) {
            vstring = line.substr(prev, ipos - prev + 1);
         }

         if (vstring.empty()) {
            Log() << kFATAL << "Failed to parse string" << Endl;
         }

         if (vcount == 0) {
            ievent = std::atoi(vstring.c_str());
         }
         else if (vcount == 1) {
            type = std::atoi(vstring.c_str());
         }
         else if (vcount == 2) {
            weight = std::atof(vstring.c_str());
         }
         else if (vcount - 3 < vvec.size()) {
            vvec[vcount - 3] = std::atof(vstring.c_str());
         }
         else {
            Log() << kFATAL << "Wrong variable count" << Endl;
         }

         prev = ipos + 1;
         ++vcount;
      }

      (void) ievent; // event index is read but not stored

      fEvent.push_back(kNN::Event(vvec, weight, type));
   }

   Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;
}
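From the parsing logic above, each non-comment line of the text file carries comma-separated fields in the order: event index, event type (1 = signal, 2 = background), event weight, then the input variables. A hypothetical two-variable file would look like:

   # ievent,type,weight,var1,var2
   0,1,1.0,0.25,1.37
   1,2,0.8,-0.11,2.04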
////////////////////////////////////////////////////////////////////////////////
/// save weights to ROOT file

void TMVA::MethodKNN::WriteWeightsToStream(TFile& rf) const
{
   Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;

   if (fEvent.empty()) {
      Log() << kWARNING << "MethodKNN contains no events " << Endl;
      return;
   }

   kNN::Event* event = new kNN::Event();
   TTree* tree = new TTree("knn", "event tree");
   tree->SetDirectory(nullptr);
   tree->Branch("event", "TMVA::kNN::Event", &event);

   Double_t size = 0.0;
   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
      (*event) = (*it);
      size += tree->Fill();
   }

   // !!! hard coded tree name !!!
   rf.WriteTObject(tree, "knn", "Overwrite");

   // scale to MegaBytes
   size /= 1048576.0;

   Log() << kINFO << "Wrote " << size << " MB and " << fEvent.size()
         << " events to ROOT file" << Endl;

   delete tree;
   delete event;
}
////////////////////////////////////////////////////////////////////////////////
/// read weights from ROOT file

void TMVA::MethodKNN::ReadWeightsFromStream(TFile& rf)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   // !!! hard coded tree name !!!
   TTree* tree = dynamic_cast<TTree*>(rf.Get("knn"));
   if (!tree) {
      Log() << kFATAL << "Failed to find knn tree" << Endl;
   }

   kNN::Event* event = new kNN::Event();
   tree->SetBranchAddress("event", &event);

   const Int_t nevent = tree->GetEntries();

   Double_t size = 0.0;
   for (Int_t i = 0; i < nevent; ++i) {
      size += tree->GetEntry(i);
      fEvent.push_back(*event);
   }

   // scale to MegaBytes
   size /= 1048576.0;

   Log() << kINFO << "Read " << size << " MB and " << fEvent.size()
         << " events from ROOT file" << Endl;

   delete event;
}
////////////////////////////////////////////////////////////////////////////////
/// write specific classifier response

void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
}
////////////////////////////////////////////////////////////////////////////////
/// get help message text

void TMVA::MethodKNN::GetHelpMessage() const
{
   Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
         << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
         << "training events for which a classification category/regression target is known." << Endl
         << "The k-NN method compares a test event to all training events using a distance" << Endl
         << "function, which is a Euclidean distance in a space defined by the input variables." << Endl
         << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
         << "quick search for the k events with the shortest distance to the test event. The" << Endl
         << "method returns a fraction of signal events among the k neighbors. It is recommended" << Endl
         << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
         << "between 0 and 1." << Endl;

   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
         << gTools().Color("reset") << Endl;

   Log() << "The k-NN method estimates a density of signal and background events in a" << Endl
         << "neighborhood around the test event. The method assumes that the density of the" << Endl
         << "signal and background events is uniform and constant within the neighborhood." << Endl
         << "k is an adjustable parameter and it determines the average size of the" << Endl
         << "neighborhood. Small k values (less than 10) are sensitive to statistical" << Endl
         << "fluctuations, and large values (greater than 100) might not sufficiently capture" << Endl
         << "local differences between events in the training set. The speed of the k-NN" << Endl
         << "method also increases with larger values of k." << Endl;

   Log() << "The k-NN method assigns equal weight to all input variables. Different scales" << Endl
         << "among the input variables are compensated using the ScaleFrac parameter: the" << Endl
         << "input variables are scaled so that the widths for the central ScaleFrac*100%" << Endl
         << "of events are equal among all the input variables." << Endl;

   Log() << gTools().Color("bold") << "--- Additional configuration options: "
         << gTools().Color("reset") << Endl;

   Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
         << "response. The kernel re-weights events using a distance to the test event." << Endl;
}
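The k+1 bin recommendation follows from the discrete response: with UseWeight=F the classifier returns weight_sig/weight_all = m/k for m = 0, ..., k, i.e. exactly k+1 distinct values. A hypothetical sketch for k = 20 (histogram name and k value are illustrative, not part of this source):

   // Sketch only: histogram with k+1 bins in [0, 1] for the kNN
   // response, following the recommendation above.
   const Int_t k = 20;
   TH1F hKNN("hKNN", "kNN response;response;events", k + 1, 0.0, 1.0);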
////////////////////////////////////////////////////////////////////////////////
/// polynomial kernel

Double_t TMVA::MethodKNN::PolnKernel(const Double_t value) const
{
   const Double_t avalue = TMath::Abs(value);

   if (!(avalue < 1.0)) {
      return 0.0;
   }

   const Double_t prod = 1.0 - avalue * avalue * avalue;

   return (prod * prod * prod);
}
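This is the tricube kernel of the scaled distance \(u\) (in GetMvaValue above, the neighbor distance is divided by the kernel radius, so \(u < 1\) inside the neighborhood):

\[
K(u) =
\begin{cases}
\left(1 - |u|^3\right)^3, & |u| < 1,\\
0, & \text{otherwise.}
\end{cases}
\]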
////////////////////////////////////////////////////////////////////////////////
/// Gaussian kernel

Double_t TMVA::MethodKNN::GausKernel(const kNN::Event &event_knn,
                                     const kNN::Event &event,
                                     const std::vector<Double_t> &svec) const
{
   if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
      Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
      return 0.0;
   }

   // compute the exponent of the unnormalized Gaussian
   double sum_exp = 0.0;

   for (unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {

      const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
      const Double_t sigm_ = svec[ivar];
      if (!(sigm_ > 0.0)) {
         Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
         return 0.0;
      }

      sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
   }

   // the normalization factor cancels in the ratio of weights
   return std::exp(-sum_exp);
}
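GausKernel thus evaluates an unnormalized product of per-variable Gaussians, with the widths taken from svec (the neighbor RMS vector computed by getRMS):

\[
K(\mathbf{x}, \mathbf{y}) = \exp\!\left(-\sum_{i} \frac{(x_i - y_i)^2}{2\sigma_i^2}\right)
\]

The missing normalization factor is intentional: it cancels in the ratio weight_sig/weight_all.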
////////////////////////////////////////////////////////////////////////////////
/// Get polynomial kernel radius

Double_t TMVA::MethodKNN::getKernelRadius(const kNN::List &rlist) const
{
   Double_t kradius = -1.0;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      if (!(lit->second > 0.0)) continue;

      if (kradius < lit->second || kradius < 0.0) kradius = lit->second;

      ++kcount;
      if (kcount >= knn) break;
   }

   return kradius;
}
////////////////////////////////////////////////////////////////////////////////
/// Get per-variable RMS of the k-nearest neighbors around the query event.

const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
{
   std::vector<Double_t> rvec;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // skip the query event itself (zero distance)
      if (!(lit->second > 0.0)) continue;

      const kNN::Node<kNN::Event> *node_ = lit->first;
      const kNN::Event &event_ = node_->GetEvent();

      if (rvec.empty()) {
         rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
      }
      else if (rvec.size() != event_.GetNVar()) {
         Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
         rvec.clear();
         return rvec;
      }

      for (unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
         const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
         rvec[ivar] += diff_*diff_;
      }

      ++kcount;
      if (kcount >= knn) break;
   }

   if (kcount < 1) {
      Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
      rvec.clear();
      return rvec;
   }

   for (unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
      if (!(rvec[ivar] > 0.0)) {
         Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
         rvec.clear();
         return rvec;
      }

      rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
   }

   return rvec;
}
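The kernel width for variable \(i\) is therefore the SigmaFact-scaled RMS of the accepted neighbors \(x_n\) around the query event \(q\):

\[
\sigma_i = \left|\texttt{SigmaFact}\right| \sqrt{\frac{1}{k_{\mathrm{count}}} \sum_{n=1}^{k_{\mathrm{count}}} \left(x_{n,i} - q_i\right)^2}
\]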
////////////////////////////////////////////////////////////////////////////////
/// compute the local linear discriminant (LDA) classifier value

double TMVA::MethodKNN::getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
{
   LDAEvents sig_vec, bac_vec;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetVars();

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         sig_vec.push_back(tvec);
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         bac_vec.push_back(tvec);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
   }

   fLDA.Initialize(sig_vec, bac_vec);

   return fLDA.GetProb(event_knn.GetVars(), 1);
}