// Standard constructor (initializer list): delegate to the MethodBase base class.
   : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption, theTargetDir)
// Destructor body: release the kNN search module.
   if (fModule) delete fModule;
// MethodKNN::DeclareOptions(): register the configurable MethodKNN options.
   DeclareOptionRef(fnkNN = 20, "nkNN", "Number of k-nearest neighbors");
   DeclareOptionRef(fBalanceDepth = 6, "BalanceDepth", "Binary tree balance depth");
   DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used to compute variable width");
   DeclareOptionRef(fSigmaFact = 1.0, "SigmaFact", "Scale factor for sigma in Gaussian kernel");
   DeclareOptionRef(fKernel = "Gaus", "Kernel", "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
   DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
   DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
   DeclareOptionRef(fUseWeight = kTRUE, "UseWeight", "Use weight to count kNN events");
   DeclareOptionRef(fUseLDA = kFALSE, "UseLDA", "Use local linear discriminant - experimental feature");
// MethodKNN::DeclareCompatibilityOptions(): options kept only so the Reader can
// parse old weight files.
   DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
// MethodKNN::ProcessOptions(): validate the options specified by the user.
// Each out-of-range option is reset to a default before the warning is
// printed; the reset statements are elided from this excerpt.
   Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;

   if (fScaleFrac < 0.0) {
      Log() << kWARNING << "ScaleFrac can not be negative: set ScaleFrac = " << fScaleFrac << Endl;
   }
   if (fScaleFrac > 1.0) {
      // clamp ScaleFrac to 1.0 (elided)
   }
   if (!(fBalanceDepth > 0)) {
      Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
   }
   // dump the final option values
   Log() << kVERBOSE
         << "  kNN       = " << fnkNN         << "\n"
         << "  UseKernel = " << fUseKernel    << "\n"
         << "  SigmaFact = " << fSigmaFact    << "\n"
         << "  ScaleFrac = " << fScaleFrac    << "\n"
         << "  Kernel    = " << fKernel       << "\n"
         << "  Trim      = " << fTrim         << "\n"
         << "  Optimize  = " << fBalanceDepth << Endl;
// MethodKNN::MakeKNN(): create the kNN search module and fill the kd-tree.
   if (fScaleFrac > 0.0) {
      // compute the variable widths used for scaling (elided)
   }

   Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      fModule->Add(*event);
   }

   // balance the binary tree and optimise the variable scaling
   fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
                 static_cast<UInt_t>(100.0*fScaleFrac));
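// Note (summary, not code from this file): MakeKNN drives the kNN module in
// two visible steps -- Add() registers every stored kNN::Event, then Fill()
// builds and balances the kd-tree to depth fBalanceDepth, using the central
// ScaleFrac fraction of events (passed here as a percentage) to equalise the
// per-variable widths.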
// MethodKNN::Train(): kNN training -- read the training events and store them
// for the evaluation phase.
   if (IsNormalised()) {
      Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
   }

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
   }
   if (GetNVariables() < 1)
      Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;

   Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;

   for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
      // read the training event
      const Event* evt_ = GetEvent(ievt);

      if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;

      for (UInt_t ivar = 0; ivar < evt_->GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);

      if (DataInfo().IsSignal(evt_)) {
         fSumOfWeightsS += weight;
      }
      else {
         fSumOfWeightsB += weight;
      }

      // create the kNN event and store it
      kNN::Event event_knn(vvec, weight, event_type);
      fEvent.push_back(event_knn);
   }

   Log() << kINFO
         << "Number of signal events " << fSumOfWeightsS << Endl
         << "Number of background events " << fSumOfWeightsB << Endl;
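// Note: k-NN is a lazy learner -- Train() only stores the events, and the
// kd-tree is assembled in MakeKNN() above, so essentially all of the
// computational cost is deferred to evaluation time.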
// MethodKNN::GetMvaValue(): compute the classifier response for the current
// event. (Declarations of knn, kradius, count_all, evweight etc. are elided.)
   NoErrorCalc(err, errUpper);

   const Event *ev = GetEvent();
   const Int_t nvar = GetNVariables();

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      // copy the input variables into the kNN query event (elided)
   }

   // search for knn + 2 nearest neighbors; the extra slots guard against
   // zero-distance matches (e.g. the query event itself appearing in the list)
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      // result list has unexpected size (error handling elided)
   }

   Bool_t use_gaus = false, use_poln = false;
   if      (fKernel == "Gaus") use_gaus = true;
   else if (fKernel == "Poln") use_poln = true;

   if (!(kradius > 0.0)) {
      // kernel radius must be positive (error handling elided)
   }

   std::vector<Double_t> rms_vec;
   if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
      // failed to compute the RMS vector (error handling elided)
   }

   Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
      // check the neighbor distance
      if (lit->second < 0.0) {
         Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
      }
      else if (!(lit->second > 0.0)) {
         Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
      }

      // accumulate the (kernel-weighted) event weights; the unweighted
      // counting branches are elided
      if (fUseWeight) weight_all += evweight;

      if (/* signal neighbor */) {
         if (fUseWeight) weight_sig += evweight;
      }
      else if (/* background neighbor */) {
         if (fUseWeight) weight_bac += evweight;
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }

      if (count_all >= knn) {
         break;   // use only the first knn neighbors
      }
   }

   if (!(count_all > 0)) {
      Log() << kFATAL << "Size kNN result list is not positive" << Endl;
   }

   if (count_all < knn) {
      Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
   }

   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "kNN result total weight is not positive" << Endl;
   }

   return weight_sig/weight_all;
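// The value returned above is the weighted signal fraction among the k
// nearest neighbors,
//
//    D(x) = sum_{i in kNN(x), signal} w_i  /  sum_{i in kNN(x)} w_i ,
//
// which lies in [0, 1]; this is why the help text further down recommends
// binning the response histogram with k+1 bins between 0 and 1.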
// MethodKNN::GetRegressionValues(): return the vector of averages for the
// target values of the k-nearest neighbors.
   if (fRegressionReturnVal == 0) fRegressionReturnVal = new std::vector<Float_t>;
   fRegressionReturnVal->clear();

   const Event *evt = GetEvent();
   const Int_t nvar = GetNVariables();

   std::vector<float> reg_vec;

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      // copy the input variables into the kNN query event (elided)
   }

   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      // result list has unexpected size
      return *fRegressionReturnVal;
   }

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
      // fetch the neighbor's target vector tvec and weight (access elided)
      if (reg_vec.empty()) {
         // size reg_vec to the target dimension on the first neighbor (elided)
      }
      for (UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
         if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
         else            reg_vec[ivar] += tvec[ivar];
      }
      if (fUseWeight) weight_all += weight;

      if (count_all == knn) {
         break;   // use only the first knn neighbors
      }
   }

   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
      return *fRegressionReturnVal;
   }

   for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
      reg_vec[ivar] /= weight_all;
   }

   // copy the averaged targets into the return vector
   fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());

   return *fRegressionReturnVal;
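// Usage sketch (assumption: `reader` is a TMVA::Reader with booked variables,
// and the weight-file path is illustrative; neither appears in this file):
//
//    reader->BookMVA("KNN", "weights/TMVARegression_KNN.weights.xml");
//    const std::vector<Float_t> &targets = reader->EvaluateRegression("KNN");
//
// EvaluateRegression returns the neighbor-averaged target vector computed above.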
// MethodKNN::AddWeightsXMLTo(): write the stored events to XML.
   if (fEvent.size()>0) gTools().AddAttr(wght, "NVar", fEvent.begin()->GetNVar());
   if (fEvent.size()>0) gTools().AddAttr(wght, "NTgt", fEvent.begin()->GetNTgt());

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      std::stringstream s("");

      // write the variable values
      for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
         if (ivar>0) s << " ";
         s << std::scientific << event->GetVar(ivar);
      }

      // write the target values
      for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
         s << " " << std::scientific << event->GetTgt(itgt);
      }
   }
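// Illustrative weight-file layout (a sketch only: apart from the NVar/NTgt
// attributes written above, the element and attribute names are assumptions,
// not read from this file):
//
//    <Weights NVar="2" NTgt="1">
//       <Event Type="1" Weight="9.8e-01"> 1.234567e+00 5.678901e-01 ... </Event>
//       ...
//    </Weights>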
// MethodKNN::ReadWeightsFromXML(): restore the stored events from XML.
   UInt_t nvar = 0, ntgt = 0;

   std::stringstream s( gTools().GetContent(ch) );

   for (UInt_t ivar=0; ivar<nvar; ivar++)
      s >> vvec[ivar];   // read the variable values

   for (UInt_t itgt=0; itgt<ntgt; itgt++)
      s >> tvec[itgt];   // read the target values

   kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
   fEvent.push_back(event_knn);
// MethodKNN::ReadWeightsFromStream(std::istream&): read the weights from a
// text stream. (Declarations of line buffers and counters are elided.)
   Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
   }

   // loop over lines (loop header elided)
   std::getline(is, line);

   // skip empty lines and comment lines
   if (line.empty() || line.find("#") != std::string::npos) {
      continue;
   }

   // count the comma separators
   std::string::size_type pos=0;
   while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }

   // each record must carry a type, a weight and one value per input variable
   if (count < 3 || nvar != count - 2) {
      // malformed record (error handling elided)
   }

   // split the line at the commas
   std::string::size_type prev = 0;

   for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
      if (line[ipos] != ',' && ipos + 1 != line.size()) {
         continue;
      }

      if (!(ipos > prev)) {
         // empty token (error handling elided)
      }

      std::string vstring = line.substr(prev, ipos - prev);
      if (ipos + 1 == line.size()) {
         vstring = line.substr(prev, ipos - prev + 1);
      }

      if (vstring.empty()) {
         // empty value (error handling elided)
      }
      else if (vcount == 1) {
         type = std::atoi(vstring.c_str());
      }
      else if (vcount == 2) {
         weight = std::atof(vstring.c_str());
      }
      else if (vcount - 3 < vvec.size()) {
         vvec[vcount - 3] = std::atof(vstring.c_str());
      }
   }

   fEvent.push_back(kNN::Event(vvec, weight, type));

   Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;
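// Expected text format, as inferred from the parsing above: one comma-separated
// record per line, with '#' marking comment lines. The checks `count < 3` and
// `nvar == count - 2` imply nvar + 3 fields per record -- a leading field,
// the event type (vcount == 1), the event weight (vcount == 2), then one value
// per input variable. For nvar = 2 (illustrative values):
//
//    0,1,0.98,1.234,5.678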
// MethodKNN::WriteWeightsToStream(TFile&): save the weights to a ROOT file.
   Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;

   if (fEvent.empty()) {
      // nothing to write (elided)
   }

   // book a branch holding one kNN::Event per entry
   tree->Branch("event", "TMVA::kNN::Event", &event);

   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
      size += tree->Fill();
   }

   Log() << kINFO << "Wrote " << size << "MB and " << fEvent.size()
         << " events to ROOT file" << Endl;
// MethodKNN::ReadWeightsFromStream(TFile&): read the weights from a ROOT file.
   Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
   }

   for (Int_t i = 0; i < nevent; ++i) {
      // tree->GetEntry(i) fills `event` (elided)
      fEvent.push_back(*event);
   }

   Log() << kINFO << "Read " << size << "MB and " << fEvent.size()
         << " events from ROOT file" << Endl;
// MethodKNN::MakeClassSpecific(): write a standalone classifier response --
// not implemented for k-NN.
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
// MethodKNN::GetHelpMessage(): print the help text for the method.
   Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
         << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
         << "training events for which a classification category/regression target is known." << Endl
         << "The k-NN method compares a test event to all training events using a distance" << Endl
         << "function, which is a Euclidean distance in a space defined by the input variables." << Endl
         << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
         << "quick search for the k events with shortest distance to the test event. The method" << Endl
         << "returns a fraction of signal events among the k neighbors. It is recommended" << Endl
         << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
         << "between 0 and 1." << Endl;

   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
         << gTools().Color("reset") << Endl;

   Log() << "The k-NN method estimates a density of signal and background events in a" << Endl
         << "neighborhood around the test event. The method assumes that the density of the" << Endl
         << "signal and background events is uniform and constant within the neighborhood." << Endl
         << "k is an adjustable parameter and it determines an average size of the" << Endl
         << "neighborhood. Small k values (less than 10) are sensitive to statistical" << Endl
         << "fluctuations and large (greater than 100) values might not sufficiently capture" << Endl
         << "local differences between events in the training set. The speed of the k-NN" << Endl
         << "method also increases with larger values of k." << Endl;

   Log() << "The k-NN method assigns equal weight to all input variables. Different scales" << Endl
         << "among the input variables are compensated using the ScaleFrac parameter: the input" << Endl
         << "variables are scaled so that the widths for central ScaleFrac*100% events are" << Endl
         << "equal among all the input variables." << Endl;

   Log() << gTools().Color("bold") << "--- Additional configuration options: "
         << gTools().Color("reset") << Endl;

   Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
         << "response. The kernel re-weights events using a distance to the test event." << Endl;
// MethodKNN::PolnKernel(): polynomial kernel.
   if (!(avalue < 1.0)) {
      return 0.0;   // kernel vanishes outside the unit interval
   }
   const Double_t prod = 1.0 - avalue * avalue * avalue;
   return (prod * prod * prod);
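// The expression above is the tricube weight
//
//    w(u) = (1 - |u|^3)^3   for |u| < 1,   w(u) = 0   otherwise,
//
// where u is the neighbor distance scaled by the kernel radius (see
// getKernelRadius below), so events at the edge of the neighborhood get
// weight 0 and very close events get weight near 1.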
// MethodKNN::GausKernel(): Gaussian kernel weight for one neighbor.
Double_t TMVA::MethodKNN::GausKernel(const kNN::Event &event_knn,
                                     const kNN::Event &event,
                                     const std::vector<Double_t> &svec) const
{
   if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
      Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
   }

   // accumulate the Gaussian exponent from per-variable distances and widths
   double sum_exp = 0.0;

   for (unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {
      const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
      const Double_t sigm_ = fSigmaFact*svec[ivar];   // width scaled by the SigmaFact option

      if (!(sigm_ > 0.0)) {
         // non-positive width (error handling elided)
      }

      sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
   }
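// The accumulated exponent corresponds to a product of per-variable Gaussians,
//
//    k(x, x_i) = exp( - sum_v (x_v - x_{i,v})^2 / (2 sigma_v^2) ),
//
// with sigma_v = SigmaFact * (per-variable width from getRMS below); the final
// exp(-sum_exp) return statement is elided from this excerpt.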
// MethodKNN::getKernelRadius(): the kernel radius is the largest distance
// among the (up to) knn nearest neighbors.
   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
      if (!(lit->second > 0.0)) continue;

      if (kradius < lit->second || kradius < 0.0) kradius = lit->second;

      if (kcount >= knn) break;   // (kcount increment elided)
   }
// MethodKNN::getRMS(): compute the per-variable spread of the k nearest
// neighbors around the query event; used as the Gaussian kernel width.
   std::vector<Double_t> rvec;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
      if (!(lit->second > 0.0)) continue;

      const kNN::Event &event_ = node_->GetEvent();   // (node_ = lit->first, elided)

      if (rvec.empty()) {
         rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
      }
      else if (rvec.size() != event_.GetNVar()) {
         Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
      }

      for (unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
         // squared deviation of the neighbor from the query event
         rvec[ivar] += diff_*diff_;
      }

      if (kcount >= knn) break;
   }

   for (unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
      if (!(rvec[ivar] > 0.0)) {
         // zero spread in this variable (error handling elided)
      }
   }
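// Note: the spread vector computed here supplies the per-variable widths
// sigma_v used by GausKernel above, so the Gaussian smoothing adapts to the
// local scatter of the k neighbors in each input variable.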
// MethodKNN::getLDAValue(): experimental local linear discriminant built from
// the k nearest neighbors.
   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
      // sort each neighbor's variable vector tvec by event type
      if (/* signal neighbor */) {
         sig_vec.push_back(tvec);
      }
      else if (/* background neighbor */) {
         bac_vec.push_back(tvec);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
   }

   fLDA.Initialize(sig_vec, bac_vec);

   return fLDA.GetProb(event_knn.GetVars(), 1);