108 if (fModule)
delete fModule;
126 DeclareOptionRef(fnkNN = 20,
"nkNN",
"Number of k-nearest neighbors");
127 DeclareOptionRef(fBalanceDepth = 6,
"BalanceDepth",
"Binary tree balance depth");
128 DeclareOptionRef(fScaleFrac = 0.80,
"ScaleFrac",
"Fraction of events used to compute variable width");
129 DeclareOptionRef(fSigmaFact = 1.0,
"SigmaFact",
"Scale factor for sigma in Gaussian kernel");
130 DeclareOptionRef(fKernel =
"Gaus",
"Kernel",
"Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
131 DeclareOptionRef(fTrim =
kFALSE,
"Trim",
"Use equal number of signal and background events");
132 DeclareOptionRef(fUseKernel =
kFALSE,
"UseKernel",
"Use polynomial kernel weight");
133 DeclareOptionRef(fUseWeight =
kTRUE,
"UseWeight",
"Use weight to count kNN events");
134 DeclareOptionRef(fUseLDA =
kFALSE,
"UseLDA",
"Use local linear discriminant - experimental feature");
142 DeclareOptionRef(fTreeOptDepth = 6,
"TreeOptDepth",
"Binary tree optimisation depth");
152 Log() << kWARNING <<
"kNN must be a positive integer: set kNN = " << fnkNN <<
Endl;
154 if (fScaleFrac < 0.0) {
156 Log() << kWARNING <<
"ScaleFrac can not be negative: set ScaleFrac = " << fScaleFrac <<
Endl;
158 if (fScaleFrac > 1.0) {
161 if (!(fBalanceDepth > 0)) {
163 Log() << kWARNING <<
"Optimize must be a positive integer: set Optimize = " << fBalanceDepth <<
Endl;
168 <<
" kNN = \n" << fnkNN
169 <<
" UseKernel = \n" << fUseKernel
170 <<
" SigmaFact = \n" << fSigmaFact
171 <<
" ScaleFrac = \n" << fScaleFrac
172 <<
" Kernel = \n" << fKernel
173 <<
" Trim = \n" << fTrim
174 <<
" Optimize = " << fBalanceDepth <<
Endl;
206 Log() << kFATAL <<
"ModulekNN is not created" <<
Endl;
212 if (fScaleFrac > 0.0) {
219 Log() << kINFO <<
"Creating kd-tree with " << fEvent.size() <<
" events" <<
Endl;
221 for (kNN::EventVec::const_iterator event = fEvent.begin();
event != fEvent.end(); ++event) {
222 fModule->Add(*event);
226 fModule->Fill(
static_cast<UInt_t>(fBalanceDepth),
227 static_cast<UInt_t>(100.0*fScaleFrac),
236 Log() << kHEADER <<
"<Train> start..." <<
Endl;
238 if (IsNormalised()) {
239 Log() << kINFO <<
"Input events are normalized - setting ScaleFrac to 0" <<
Endl;
243 if (!fEvent.empty()) {
244 Log() << kINFO <<
"Erasing " << fEvent.size() <<
" previously stored events" <<
Endl;
247 if (GetNVariables() < 1)
248 Log() << kFATAL <<
"MethodKNN::Train() - mismatched or wrong number of event variables" <<
Endl;
251 Log() << kINFO <<
"Reading " << GetNEvents() <<
" events" <<
Endl;
259 if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0)
continue;
266 if (DataInfo().IsSignal(
evt_)) {
267 fSumOfWeightsS += weight;
271 fSumOfWeightsB += weight;
284 <<
"Number of signal events " << fSumOfWeightsS <<
Endl
285 <<
"Number of background events " << fSumOfWeightsB <<
Endl;
305 const Int_t nvar = GetNVariables();
321 const kNN::List &rlist = fModule->GetkNNList();
322 if (rlist.size() !=
knn + 2) {
323 Log() << kFATAL <<
"kNN result list is empty" <<
Endl;
334 if (fKernel ==
"Gaus")
use_gaus =
true;
335 else if (fKernel ==
"Poln")
use_poln =
true;
346 Log() << kFATAL <<
"kNN radius is not positive" <<
Endl;
361 Log() << kFATAL <<
"Failed to compute RMS vector" <<
Endl;
369 for (kNN::List::const_iterator
lit = rlist.begin();
lit != rlist.end(); ++
lit) {
376 if (
lit->second < 0.0) {
377 Log() << kFATAL <<
"A neighbor has negative distance to query event" <<
Endl;
379 else if (!(
lit->second > 0.0)) {
380 Log() << kVERBOSE <<
"A neighbor has zero distance to query event" <<
Endl;
391 if (node.
GetEvent().GetType() == 1) {
395 else if (node.
GetEvent().GetType() == 2) {
398 Log() << kFATAL <<
"Unknown type for training event" <<
Endl;
411 Log() << kFATAL <<
"Size kNN result list is not positive" <<
Endl;
417 Log() << kDEBUG <<
"count_all and kNN have different size: " <<
count_all <<
" < " <<
knn <<
Endl;
422 Log() << kFATAL <<
"kNN result total weight is not positive" <<
Endl;
435 if( fRegressionReturnVal == 0 )
436 fRegressionReturnVal =
new std::vector<Float_t>;
438 fRegressionReturnVal->clear();
444 const Int_t nvar = GetNVariables();
460 const kNN::List &rlist = fModule->GetkNNList();
461 if (rlist.size() !=
knn + 2) {
462 Log() << kFATAL <<
"kNN result list is empty" <<
Endl;
463 return *fRegressionReturnVal;
470 for (kNN::List::const_iterator
lit = rlist.begin();
lit != rlist.end(); ++
lit) {
499 Log() << kFATAL <<
"Total weight sum is not positive: " <<
weight_all <<
Endl;
500 return *fRegressionReturnVal;
510 return *fRegressionReturnVal;
530 for (kNN::EventVec::const_iterator event = fEvent.begin();
event != fEvent.end(); ++event) {
532 std::stringstream s(
"");
535 if (
ivar>0) s <<
" ";
536 s << std::scientific <<
event->GetVar(
ivar);
540 s <<
" " << std::scientific <<
event->GetTgt(
itgt);
568 std::stringstream s(
gTools().GetContent(ch) );
591 Log() << kINFO <<
"Starting ReadWeightsFromStream(std::istream& is) function..." <<
Endl;
593 if (!fEvent.empty()) {
594 Log() << kINFO <<
"Erasing " << fEvent.size() <<
" previously stored events" <<
Endl;
604 if (
line.empty() ||
line.find(
"#") != std::string::npos) {
609 std::string::size_type pos=0;
610 while( (pos=
line.find(
',',pos)) != std::string::npos ) { count++; pos++; }
615 if (count < 3 || nvar != count - 2) {
616 Log() << kFATAL <<
"Missing comma delimeter(s)" <<
Endl;
626 std::string::size_type prev = 0;
633 if (!(
ipos > prev)) {
634 Log() << kFATAL <<
"Wrong substring limits" <<
Endl;
643 Log() << kFATAL <<
"Failed to parse string" <<
Endl;
653 weight = std::atof(
vstring.c_str());
659 Log() << kFATAL <<
"Wrong variable count" <<
Endl;
669 Log() << kINFO <<
"Read " << fEvent.size() <<
" events from text file" <<
Endl;
680 Log() << kINFO <<
"Starting WriteWeightsToStream(TFile &rf) function..." <<
Endl;
682 if (fEvent.empty()) {
683 Log() << kWARNING <<
"MethodKNN contains no events " <<
Endl;
689 tree->SetDirectory(
nullptr);
690 tree->Branch(
"event",
"TMVA::kNN::Event", &event);
693 for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
695 size += tree->Fill();
699 rf.WriteTObject(tree,
"knn",
"Overwrite");
704 Log() << kINFO <<
"Wrote " <<
size <<
"MB and " << fEvent.size()
705 <<
" events to ROOT file" <<
Endl;
716 Log() << kINFO <<
"Starting ReadWeightsFromStream(TFile &rf) function..." <<
Endl;
718 if (!fEvent.empty()) {
719 Log() << kINFO <<
"Erasing " << fEvent.size() <<
" previously stored events" <<
Endl;
726 Log() << kFATAL <<
"Failed to find knn tree" <<
Endl;
731 tree->SetBranchAddress(
"event", &event);
733 const Int_t nevent = tree->GetEntries();
736 for (
Int_t i = 0; i < nevent; ++i) {
737 size += tree->GetEntry(i);
738 fEvent.push_back(*event);
744 Log() << kINFO <<
"Read " <<
size <<
"MB and " << fEvent.size()
745 <<
" events from ROOT file" <<
Endl;
758 fout <<
" // not implemented for class: \"" << className <<
"\"" << std::endl;
759 fout <<
"};" << std::endl;
773 Log() <<
"The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" <<
Endl
774 <<
"and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" <<
Endl
775 <<
"training events for which a classification category/regression target is known. " <<
Endl
776 <<
"The k-NN method compares a test event to all training events using a distance " <<
Endl
777 <<
"function, which is an Euclidean distance in a space defined by the input variables. "<<
Endl
778 <<
"The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" <<
Endl
779 <<
"quick search for the k events with shortest distance to the test event. The method" <<
Endl
780 <<
"returns a fraction of signal events among the k neighbors. It is recommended" <<
Endl
781 <<
"that a histogram which stores the k-NN decision variable is binned with k+1 bins" <<
Endl
782 <<
"between 0 and 1." <<
Endl;
785 Log() <<
gTools().
Color(
"bold") <<
"--- Performance tuning via configuration options: "
788 Log() <<
"The k-NN method estimates a density of signal and background events in a "<<
Endl
789 <<
"neighborhood around the test event. The method assumes that the density of the " <<
Endl
790 <<
"signal and background events is uniform and constant within the neighborhood. " <<
Endl
791 <<
"k is an adjustable parameter and it determines an average size of the " <<
Endl
792 <<
"neighborhood. Small k values (less than 10) are sensitive to statistical " <<
Endl
793 <<
"fluctuations and large (greater than 100) values might not sufficiently capture " <<
Endl
794 <<
"local differences between events in the training set. The speed of the k-NN" <<
Endl
795 <<
"method also increases with larger values of k. " <<
Endl;
797 Log() <<
"The k-NN method assigns equal weight to all input variables. Different scales " <<
Endl
798 <<
"among the input variables is compensated using ScaleFrac parameter: the input " <<
Endl
799 <<
"variables are scaled so that the widths for central ScaleFrac*100% events are " <<
Endl
800 <<
"equal among all the input variables." <<
Endl;
803 Log() <<
gTools().
Color(
"bold") <<
"--- Additional configuration options: "
806 Log() <<
"The method inclues an option to use a Gaussian kernel to smooth out the k-NN" <<
Endl
807 <<
"response. The kernel re-weights events using a distance to the test event." <<
Endl;
823 return (prod * prod * prod);
833 Log() << kFATAL <<
"Mismatched vectors in Gaussian kernel function" <<
Endl;
846 if (!(
sigm_ > 0.0)) {
847 Log() << kFATAL <<
"Bad sigma value = " <<
sigm_ <<
Endl;
873 for (kNN::List::const_iterator
lit = rlist.begin();
lit != rlist.end(); ++
lit)
875 if (!(
lit->second > 0.0))
continue;
893 std::vector<Double_t>
rvec;
897 for (kNN::List::const_iterator
lit = rlist.begin();
lit != rlist.end(); ++
lit)
899 if (!(
lit->second > 0.0))
continue;
908 Log() << kFATAL <<
"Wrong number of variables, should never happen!" <<
Endl;
923 Log() << kFATAL <<
"Bad event kcount = " <<
kcount <<
Endl;
930 Log() << kFATAL <<
"Bad RMS value = " <<
rvec[
ivar] <<
Endl;
947 for (kNN::List::const_iterator
lit = rlist.begin();
lit != rlist.end(); ++
lit) {
953 if (node.
GetEvent().GetType() == 1) {
956 else if (node.
GetEvent().GetType() == 2) {
960 Log() << kFATAL <<
"Unknown type for training event" <<
Endl;
966 return fLDA.GetProb(
event_knn.GetVars(), 1);
#define REGISTER_METHOD(CLASS)
for example
std::vector< std::vector< Float_t > > LDAEvents
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
const_iterator begin() const
const_iterator end() const
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-like structure.
Class that contains all the data information.
Virtual base Class for all MVA method.
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Analysis of k-nearest neighbor.
void Init(void)
Initialization.
void MakeKNN(void)
create kNN
virtual ~MethodKNN(void)
destructor
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
KNN can handle classification with 2 classes and regression with one regression-target.
const std::vector< Double_t > getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
Compute the RMS (width) vector of the k-nearest-neighbor events, used to scale the Gaussian kernel.
const Ranking * CreateRanking()
no ranking available
void MakeClassSpecific(std::ostream &, const TString &) const
write specific classifier response
MethodKNN(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="KNN")
standard constructor
void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility
Double_t getKernelRadius(const kNN::List &rlist) const
Get polynomial kernel radius.
void GetHelpMessage() const
get help message text
const std::vector< Float_t > & GetRegressionValues()
Return vector of averages for target values of k-nearest neighbors.
void Train(void)
kNN training
void ReadWeightsFromStream(std::istream &istr)
read the weights
double getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
Double_t PolnKernel(Double_t value) const
polynomial kernel
void ProcessOptions()
process the options specified by the user
void ReadWeightsFromXML(void *wghtnode)
void AddWeightsXMLTo(void *parent) const
write weights to XML
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr)
Compute classifier response.
void DeclareOptions()
MethodKNN options.
void WriteWeightsToStream(TFile &rf) const
save weights to ROOT file
Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector< Double_t > &svec) const
Gaussian kernel.
Ranking for variables in method (implementation)
Singleton class for Global types used by TMVA.
This file contains binary tree and global function template that searches tree for k-nearest neighbors...
Double_t GetWeight() const
const T & GetEvent() const
std::vector< VarType > VarVec
virtual void Clear(Option_t *="")
A TTree represents a columnar dataset.
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Double_t Sqrt(Double_t x)
Returns the square root of x.
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.