if (fModule) delete fModule;   // destructor: release the kNN module
DeclareOptionRef(fnkNN = 20, "nkNN", "Number of k-nearest neighbors");
DeclareOptionRef(fBalanceDepth = 6, "BalanceDepth", "Binary tree balance depth");
DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used to compute variable width");
DeclareOptionRef(fSigmaFact = 1.0, "SigmaFact", "Scale factor for sigma in Gaussian kernel");
DeclareOptionRef(fKernel = "Gaus", "Kernel", "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
DeclareOptionRef(fUseWeight = kTRUE, "UseWeight", "Use weight to count kNN events");
DeclareOptionRef(fUseLDA = kFALSE, "UseLDA", "Use local linear discriminant - experimental feature");
DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
// Sanity checks on the user options; fall back to safe values where needed
if (!(fnkNN > 0)) {
   fnkNN = 10;
   Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
}
if (fScaleFrac < 0.0) {
   fScaleFrac = 0.0;
   Log() << kWARNING << "ScaleFrac cannot be negative: set ScaleFrac = " << fScaleFrac << Endl;
}
if (fScaleFrac > 1.0) {
   fScaleFrac = 1.0;
}
if (!(fBalanceDepth > 0)) {
   fBalanceDepth = 6;
   Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
}

Log() << kVERBOSE
      << " kNN = \n" << fnkNN
      << " UseKernel = \n" << fUseKernel
      << " SigmaFact = \n" << fSigmaFact
      << " ScaleFrac = \n" << fScaleFrac
      << " Kernel = \n" << fKernel
      << " Trim = \n" << fTrim
      << " Optimize = " << fBalanceDepth << Endl;
Log() << kFATAL << "ModulekNN is not created" << Endl;
if (fScaleFrac > 0.0) {
   ...
}

Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;

for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
   fModule->Add(*event);
}

// Build the kd-tree: balance depth, plus the fraction of events (as a percentage)
// used to compute the variable widths
fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
              static_cast<UInt_t>(100.0*fScaleFrac),
              ...);
Log() << kHEADER << "<Train> start..." << Endl;

if (IsNormalised()) {
   Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
   fScaleFrac = 0.0;
}

if (!fEvent.empty()) {
   Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
   fEvent.clear();
}
if (GetNVariables() < 1)
   Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;

Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;

// In the event loop: optionally skip events with non-positive weight
if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;

// Accumulate the signal and background weight sums
if (DataInfo().IsSignal(evt_)) {
   fSumOfWeightsS += weight;
}
else {
   fSumOfWeightsB += weight;
}

Log() << kINFO
      << "Number of signal events " << fSumOfWeightsS << Endl
      << "Number of background events " << fSumOfWeightsB << Endl;
const Int_t nvar = GetNVariables();

const kNN::List &rlist = fModule->GetkNNList();
if (rlist.size() != knn + 2) {
   Log() << kFATAL << "kNN result list is empty" << Endl;
}
if (fKernel == "Gaus") use_gaus = true;
else if (fKernel == "Poln") use_poln = true;
Log() << kFATAL << "kNN radius is not positive" << Endl;

Log() << kFATAL << "Failed to compute RMS vector" << Endl;
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

   // Sanity check: distances returned by the kd-tree must not be negative
   if (lit->second < 0.0) {
      Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
   }
   else if (!(lit->second > 0.0)) {
      Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
   }

   // Count the neighbor as signal (type 1) or background (type 2)
   if (node.GetEvent().GetType() == 1) {
      ...
   }
   else if (node.GetEvent().GetType() == 2) {
      ...
   }
   else {
      Log() << kFATAL << "Unknown type for training event" << Endl;
   }
}
Log() << kFATAL << "Size kNN result list is not positive" << Endl;

Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;

Log() << kFATAL << "kNN result total weight is not positive" << Endl;
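After the checks above, the classifier response reduces to the signal fraction of the accumulated neighbor weights. A minimal sketch of that final step, assuming weight_sig and weight_bac are the per-class sums built in the neighbor loop (names are illustrative, not the exact source variables):

// Sketch: kNN discriminant = signal weight fraction among the k neighbors.
double KnnResponse(double weight_sig, double weight_bac)
{
   const double weight_all = weight_sig + weight_bac;
   if (!(weight_all > 0.0)) return 0.5;   // no usable neighbors: undecided
   return weight_sig / weight_all;        // in [0,1]; values near 1 are signal-like
}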
if (fRegressionReturnVal == nullptr)
   fRegressionReturnVal = new std::vector<Float_t>;

fRegressionReturnVal->clear();

const Int_t nvar = GetNVariables();

const kNN::List &rlist = fModule->GetkNNList();
if (rlist.size() != knn + 2) {
   Log() << kFATAL << "kNN result list is empty" << Endl;
   return *fRegressionReturnVal;
}
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
   ...   // accumulate weighted sums of the neighbors' target values
}

if (!(weight_all > 0.0)) {
   Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
   return *fRegressionReturnVal;
}

...

return *fRegressionReturnVal;
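For regression the method returns, per target, a weighted average of the k neighbors' target values. A minimal self-contained sketch of that averaging, with hypothetical inputs (the real loop also folds in the optional kernel weight):

#include <cstddef>
#include <vector>

// Sketch: weighted average of the neighbors' regression targets.
// targets[i] is neighbor i's target vector; weights[i] its event weight.
std::vector<float> AverageTargets(const std::vector<std::vector<float>> &targets,
                                  const std::vector<double> &weights)
{
   std::vector<float> out(targets.empty() ? 0 : targets.front().size(), 0.f);
   double weight_all = 0.0;
   for (std::size_t i = 0; i < targets.size(); ++i) {
      weight_all += weights[i];
      for (std::size_t t = 0; t < out.size(); ++t)
         out[t] += static_cast<float>(weights[i] * targets[i][t]);
   }
   if (weight_all > 0.0)
      for (float &v : out) v /= static_cast<float>(weight_all);
   return out;
}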
for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {

   std::stringstream s("");

   // Write the input variables, separated by single spaces
   for (UInt_t ivar = 0; ivar < nvar; ++ivar) {
      if (ivar > 0) s << " ";
      s << std::scientific << event->GetVar(ivar);
   }

   // Append the regression targets
   for (UInt_t itgt = 0; itgt < ntgt; ++itgt) {
      s << " " << std::scientific << event->GetTgt(itgt);
   }
}
std::stringstream s(gTools().GetContent(ch));
Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;

if (!fEvent.empty()) {
   Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
}
// Skip empty lines and comment lines
if (line.empty() || line.find("#") != std::string::npos) {
   continue;
}

// Count the comma delimiters in this line
std::string::size_type pos = 0;
while ((pos = line.find(',', pos)) != std::string::npos) { count++; pos++; }

if (count < 3 || nvar != count - 2) {
   Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
}

std::string::size_type prev = 0;

if (!(ipos > prev)) {
   Log() << kFATAL << "Wrong substring limits" << Endl;
}

Log() << kFATAL << "Failed to parse string" << Endl;

weight = std::atof(vstring.c_str());

Log() << kFATAL << "Wrong variable count" << Endl;
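The fragments above implement a simple comma-separated parser: count the delimiters to validate the field count, then walk the string with find() and substr(). A self-contained sketch of the same technique, assuming the type,weight,var1,...,varN layout implied by the checks above:

#include <cstdlib>
#include <string>
#include <vector>

// Sketch: split one comma-separated line into numeric fields with std::atof.
std::vector<double> ParseLine(const std::string &line)
{
   std::vector<double> fields;
   std::string::size_type prev = 0, pos = 0;
   while ((pos = line.find(',', prev)) != std::string::npos) {
      fields.push_back(std::atof(line.substr(prev, pos - prev).c_str()));
      prev = pos + 1;
   }
   fields.push_back(std::atof(line.substr(prev).c_str()));   // last field
   return fields;
}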
Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;
Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;

if (fEvent.empty()) {
   Log() << kWARNING << "MethodKNN contains no events " << Endl;
   ...
}

tree->SetDirectory(nullptr);
tree->Branch("event", "TMVA::kNN::Event", &event);

for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
   ...
   size += tree->Fill();
}

rf.WriteTObject(tree, "knn", "Overwrite");
Log() << kINFO << "Wrote " << size << " MB and " << fEvent.size()
      << " events to ROOT file" << Endl;
Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;

if (!fEvent.empty()) {
   Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
   fEvent.clear();
}

Log() << kFATAL << "Failed to find knn tree" << Endl;

tree->SetBranchAddress("event", &event);
const Int_t nevent = tree->GetEntries();

for (Int_t i = 0; i < nevent; ++i) {
   size += tree->GetEntry(i);
   fEvent.push_back(*event);
}

Log() << kINFO << "Read " << size << " MB and " << fEvent.size()
      << " events from ROOT file" << Endl;
fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
fout << "};" << std::endl;
Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
      << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
      << "training events for which a classification category/regression target is known." << Endl
      << "The k-NN method compares a test event to all training events using a distance" << Endl
      << "function, which is the Euclidean distance in the space defined by the input variables." << Endl
      << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
      << "quick search for the k events with the shortest distance to the test event. The method" << Endl
      << "returns the fraction of signal events among the k neighbors. It is recommended" << Endl
      << "that the histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
      << "between 0 and 1." << Endl;
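A short aside on the k+1 bins recommendation: with unit event weights and no kernel, the decision variable can only take the k+1 discrete values 0/k, 1/k, ..., k/k, so a histogram with exactly k+1 bins between 0 and 1 resolves every possible response (for the default k = 20 these are 0, 0.05, ..., 1).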
Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
      << gTools().Color("reset") << Endl;
Log() << "The k-NN method estimates a density of signal and background events in a" << Endl
      << "neighborhood around the test event. The method assumes that the density of the" << Endl
      << "signal and background events is uniform and constant within the neighborhood." << Endl
      << "k is an adjustable parameter and it determines the average size of the" << Endl
      << "neighborhood. Small k values (less than 10) are sensitive to statistical" << Endl
      << "fluctuations and large values (greater than 100) might not sufficiently capture" << Endl
      << "local differences between events in the training set. The computational cost of" << Endl
      << "the k-NN method also increases with larger values of k." << Endl;
Log() << "The k-NN method assigns equal weight to all input variables. Differing scales" << Endl
      << "among the input variables are compensated for using the ScaleFrac parameter: the" << Endl
      << "input variables are scaled so that the widths of the central ScaleFrac*100% of" << Endl
      << "events are equal among all the input variables." << Endl;
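As a rough illustration of the idea (not the actual TMVA implementation, which lives in the kNN::ModulekNN helper), the scaling width of one variable could be computed as the range spanned by its central ScaleFrac fraction of events:

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch: width of the central 'frac' of a variable's distribution; each
// variable would then be divided by its width to equalize the scales.
double CentralWidth(std::vector<double> vals, double frac)
{
   std::sort(vals.begin(), vals.end());
   const std::size_t n = vals.size();
   if (n < 2) return 0.0;   // not enough events to define a width
   const std::size_t cut = static_cast<std::size_t>(0.5 * (1.0 - frac) * n);
   return vals[n - 1 - cut] - vals[cut];
}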
Log() << gTools().Color("bold") << "--- Additional configuration options: "
      << gTools().Color("reset") << Endl;

Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
      << "response. The kernel re-weights events using a distance to the test event." << Endl;
return (prod * prod * prod);
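The line above is the tail of PolnKernel; written out, it matches the tri-cube weighting w(x) = (1 - |x|^3)^3. A self-contained sketch under that reading (the clamping to zero outside |x| <= 1 is an assumption of the sketch):

#include <cmath>

// Sketch: tri-cube kernel; x is the neighbor distance divided by the kernel radius.
double TriCube(double x)
{
   const double ax = std::fabs(x);
   if (ax > 1.0) return 0.0;    // outside the kernel radius (assumed)
   const double prod = 1.0 - ax * ax * ax;
   return prod * prod * prod;   // the "return (prod * prod * prod);" seen above
}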
Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;

if (!(sigm_ > 0.0)) {
   Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
}
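The checks above suggest a per-variable Gaussian weight of the form exp(-sum_i (x_i - y_i)^2 / (2 sigma_i^2)), with sigma_i a per-variable width (scaled by SigmaFact) that must be positive. A minimal sketch under that assumption:

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch: Gaussian kernel weight of a neighbor 'evt' relative to the query
// point; sigma[i] is assumed positive (the source raises kFATAL otherwise).
double GaussWeight(const std::vector<double> &query,
                   const std::vector<double> &evt,
                   const std::vector<double> &sigma)
{
   double sum = 0.0;
   for (std::size_t i = 0; i < query.size(); ++i) {
      const double d = evt[i] - query[i];
      sum += d * d / (2.0 * sigma[i] * sigma[i]);
   }
   return std::exp(-sum);
}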
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
   // Ignore neighbors that sit exactly on the query point
   if (!(lit->second > 0.0)) continue;
   ...
}
std::vector<Double_t> rvec;

for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
   if (!(lit->second > 0.0)) continue;
   ...
}
Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;

Log() << kFATAL << "Bad event kcount = " << kcount << Endl;

Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

   // Sort neighbors into signal (type 1) and background (type 2) samples for the local LDA
   if (node.GetEvent().GetType() == 1) {
      ...
   }
   else if (node.GetEvent().GetType() == 2) {
      ...
   }
   else {
      Log() << kFATAL << "Unknown type for training event" << Endl;
   }
}

return fLDA.GetProb(event_knn.GetVars(), 1);
Analysis of k-nearest neighbor.
void MakeKNN(void)
create kNN
virtual ~MethodKNN(void)
destructor
const std::vector< Double_t > getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
Compute the RMS vector of the k-nearest neighbors (used as the Gaussian kernel widths).
const Ranking * CreateRanking() override
no ranking available
void DeclareOptions() override
MethodKNN options.
MethodKNN(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="KNN")
standard constructor
Double_t getKernelRadius(const kNN::List &rlist) const
Get polynomial kernel radius.
void Train(void) override
kNN training
double getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
void ProcessOptions() override
process the options specified by the user
Double_t PolnKernel(Double_t value) const
polynomial kernel
void DeclareCompatibilityOptions() override
options that are used ONLY for the READER to ensure backward compatibility
void ReadWeightsFromStream(std::istream &istr) override
read the weights
void GetHelpMessage() const override
get help message text
void MakeClassSpecific(std::ostream &, const TString &) const override
write specific classifier response
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr) override
Compute classifier response.
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) override
KNN can handle classification with two classes and regression with one regression target.
void Init(void) override
Initialization.
void WriteWeightsToStream(TFile &rf) const
save weights to ROOT file
Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector< Double_t > &svec) const
Gaussian kernel.
void ReadWeightsFromXML(void *wghtnode) override
const std::vector< Float_t > & GetRegressionValues() override
Return vector of averages for target values of k-nearest neighbors.
void AddWeightsXMLTo(void *parent) const override
write weights to XML