using std::stringstream;

// standard constructor
TMVA::MethodFDA::MethodFDA( const TString& jobName,
                            const TString& methodTitle,
                            DataSetInfo& theData,
                            const TString& theOption,
                            TDirectory* theTargetDir )
   : MethodBase( jobName, Types::kFDA, methodTitle, theData, theOption, theTargetDir ),
     fConvergerFitter( 0 ),
     fSumOfWeightsSig( 0 ),
     fSumOfWeightsBkg( 0 ),
     fOutputDimensions( 0 )
// constructor from weight file: same member initialisation
     fConvergerFitter( 0 ),
     fSumOfWeightsSig( 0 ),
     fSumOfWeightsBkg( 0 ),
     fOutputDimensions( 0 )
// Init(): default initialisation
   fSumOfWeightsSig = 0;
   fSumOfWeightsBkg = 0;

   fFormulaStringP  = "";
   fParRangeStringP = "";
   fFormulaStringT  = "";
   fParRangeStringT = "";

   if (fMulticlassReturnVal == NULL) fMulticlassReturnVal = new std::vector<Float_t>();
// DeclareOptions(): define the options (their key words) that can be set in the option string
   DeclareOptionRef( fFormulaStringP  = "(0)", "Formula",   "The discrimination formula" );
   DeclareOptionRef( fParRangeStringP = "()",  "ParRanges", "Parameter ranges" );

   DeclareOptionRef( fFitMethod = "MINUIT", "FitMethod", "Optimisation Method");
   AddPreDefVal( TString("MINUIT") );

   DeclareOptionRef( fConverger = "None", "Converger", "FitMethod uses Converger to improve result");
   AddPreDefVal( TString("MINUIT") );
// CreateFormula(): translate the formula string into a TFormula
   fFormulaStringT = fFormulaStringP;

   // replace the parameter tokens "(i)" by TFormula parameters "[i]"
   for (UInt_t ipar=0; ipar<fNPars; ipar++) {
      fFormulaStringT.ReplaceAll( Form("(%i)",ipar), Form("[%i]",ipar) );
   }

   // sanity check: no parameter token beyond the declared ranges may remain
   for (Int_t ipar=fNPars; ipar<1000; ipar++) {
      if (fFormulaStringT.Contains( Form("(%i)",ipar) ))
         Log() << kFATAL
               << "<CreateFormula> Formula contains expression: \"" << Form("(%i)",ipar) << "\", "
               << "which cannot be attributed to a parameter; "
               << "it may be that the number of variable ranges given via \"ParRanges\" "
               << "does not match the number of parameters in the formula expression, please verify!"
               << Endl;
   }

   // replace the input variables "xi" by additional parameters "[i+fNPars]"
   for (Int_t ivar=GetNvar()-1; ivar >= 0; ivar--) {
      fFormulaStringT.ReplaceAll( Form("x%i",ivar), Form("[%i]",ivar+fNPars) );
   }

   // sanity check: no unknown variable token may remain
   for (UInt_t ivar=GetNvar(); ivar<1000; ivar++) {
      if (fFormulaStringT.Contains( Form("x%i",ivar) ))
         Log() << kFATAL
               << "<CreateFormula> Formula contains expression: \"" << Form("x%i",ivar) << "\", "
               << "which cannot be attributed to an input variable" << Endl;
   }
   Log() << "User-defined formula string : \"" << fFormulaStringP << "\"" << Endl;
   Log() << "TFormula-compatible formula string: \"" << fFormulaStringT << "\"" << Endl;
   Log() << "Creating and compiling formula" << Endl;

   if (fFormula) delete fFormula;
   fFormula = new TFormula( "FDA_Formula", fFormulaStringT );

   if (!fFormula->IsValid())
      Log() << kFATAL << "<ProcessOptions> Formula expression could not be properly compiled" << Endl;

   if (fFormula->GetNpar() > (Int_t)(fNPars + GetNvar()))
      Log() << kFATAL << "<ProcessOptions> Dubious number of parameters in formula expression: "
            << fFormula->GetNpar() << " - compared to maximum allowed: " << fNPars + GetNvar() << Endl;
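   // Illustrative example (not from the original source): with fNPars = 3 and two input
   // variables, the user string "(0)+(1)*x0+(2)*x1" would be translated into the
   // TFormula-compatible string "[0]+[1]*[3]+[2]*[4]"; fit parameters keep their indices,
   // while variable xN becomes parameter [N+fNPars].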
// ProcessOptions(): decode the options specified in the option string
   fParRangeStringT = fParRangeStringP;

   // strip whitespace and count the declared parameter ranges
   fParRangeStringT.ReplaceAll( " ", "" );
   fNPars = fParRangeStringT.CountChar( ')' );

   // parList: list of the ";"-separated "(min,max)" substrings of fParRangeStringT
   if ((UInt_t)parList->GetSize() != fNPars) {
      Log() << kFATAL << "<ProcessOptions> Mismatch in parameter string: "
            << "the number of parameters: " << fNPars << " != ranges defined: "
            << parList->GetSize() << "; the format of the \"ParRanges\" string "
            << "must be: \"(-1.2,3.4);(-2.3,4.55);...\", "
            << "where the numbers in \"(a,b)\" correspond to the a=min, b=max parameter ranges; "
            << "each parameter defined in the function string must have a corresponding range."
            << Endl;
   }

   fParRange.resize( fNPars );
   for (UInt_t ipar=0; ipar<fNPars; ipar++) fParRange[ipar] = 0;
   for (UInt_t ipar=0; ipar<fNPars; ipar++) {
      // pminS/pmaxS: TString substrings holding the min and max of the ipar-th "(min,max)" pair
      stringstream stmin; Float_t pmin=0; stmin << pminS.Data(); stmin >> pmin;
      stringstream stmax; Float_t pmax=0; stmax << pmaxS.Data(); stmax >> pmax;

      // sanity check
      if (TMath::Abs(pmax-pmin) < 1.e-30) pmax = pmin;
      if (pmin > pmax) Log() << kFATAL << "<ProcessOptions> min > max in interval for parameter: ["
                             << ipar << "] : [" << pmin << ", " << pmax << "] " << Endl;

      Log() << kINFO << "Create parameter interval for parameter " << ipar
            << " : [" << pmin << "," << pmax << "]" << Endl;
      fParRange[ipar] = new Interval( pmin, pmax );
   }
   fOutputDimensions = 1;
   if (DoRegression()) fOutputDimensions = DataInfo().GetNTargets();
   if (DoMulticlass()) fOutputDimensions = DataInfo().GetNClasses();

   // for multi-dimensional output, replicate the parameter ranges once per additional dimension
   for( Int_t dim = 1; dim < fOutputDimensions; ++dim ){
      for( UInt_t par = 0; par < fNPars; ++par ){
         fParRange.push_back( fParRange.at(par) );
      }
   }
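   // Illustrative note: with fNPars = 2 and fOutputDimensions = 3 the fit therefore runs over
   // 6 parameters, where the block [dim*fNPars, (dim+1)*fNPars) belongs to output dimension "dim"
   // (see CalculateMulticlassValues and GetRegressionValues below).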
   // create the (optional) converger and the fit method
   if (fConverger == "MINUIT") {
      fConvergerFitter = new MinuitFitter( *this, Form("%s_Converger_Minuit", GetName()), fParRange, GetOptions() );
      SetOptions(dynamic_cast<Configurable*>(fConvergerFitter)->GetOptions());
   }

   if (fFitMethod == "MC")
      fFitter = new MCFitter( *fConvergerFitter, Form("%s_Fitter_MC", GetName()), fParRange, GetOptions() );
   else if (fFitMethod == "GA")
      fFitter = new GeneticFitter( *fConvergerFitter, Form("%s_Fitter_GA", GetName()), fParRange, GetOptions() );
   else if (fFitMethod == "SA")
      fFitter = new SimulatedAnnealingFitter( *fConvergerFitter, Form("%s_Fitter_SA", GetName()), fParRange, GetOptions() );
   else if (fFitMethod == "MINUIT")
      fFitter = new MinuitFitter( *fConvergerFitter, Form("%s_Fitter_Minuit", GetName()), fParRange, GetOptions() );
   else
      Log() << kFATAL << "<Train> Do not understand fit method: " << fFitMethod << Endl;

   fFitter->CheckForUnusedOptions();
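   // Illustrative example (not part of the original source): a typical non-gradient setup is
   //
   //    "FitMethod=GA:Converger=MINUIT"
   //
   // in which case the GeneticFitter receives the MinuitFitter as its IFitterTarget, so the
   // parameter points proposed by the GA are handed to MINUIT for local refinement (see also
   // the help text in GetHelpMessage below).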
// ClearAll(): delete and clear all class members
   for (UInt_t ipar=0; ipar<fParRange.size() && ipar<fNPars; ipar++) {
      if (fParRange[ipar] != 0) { delete fParRange[ipar]; fParRange[ipar] = 0; }
   }

   if (fFormula != 0) { delete fFormula; fFormula = 0; }
// Train(): FDA training
   // cache the sums of event weights
   fSumOfWeightsSig = 0;
   fSumOfWeightsBkg = 0;

   for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
      // read the training event
      const Event* ev = GetEvent(ievt);
      Double_t w = ev->GetWeight();

      if (!DoRegression()) {
         if (DataInfo().IsSignal(ev)) { fSumOfWeightsSig += w; }
         else                         { fSumOfWeightsBkg += w; }
      }
      fSumOfWeights += w;
   }

   // sanity check
   if (!DoRegression()) {
      if (fSumOfWeightsSig <= 0 || fSumOfWeightsBkg <= 0) {
         Log() << kFATAL << "<Train> Troubles in sum of weights: "
               << fSumOfWeightsSig << " (S) : " << fSumOfWeightsBkg << " (B)" << Endl;
      }
   }
   else if (fSumOfWeights <= 0) {
      Log() << kFATAL << "<Train> Troubles in sum of weights: "
            << fSumOfWeights << Endl;
   }
   // starting values for the fit: the mean of each parameter interval
   for (std::vector<Interval*>::const_iterator parIt = fParRange.begin(); parIt != fParRange.end(); ++parIt) {
      fBestPars.push_back( (*parIt)->GetMean() );
   }

   // execute the fit
   Double_t estimator = fFitter->Run( fBestPars );

   // print the fit results
   PrintResults( fFitMethod, fBestPars, estimator );

   delete fFitter; fFitter = 0;
   if (fConvergerFitter != 0 && fConvergerFitter != (IFitterTarget*)this) {
      delete fConvergerFitter;
      fConvergerFitter = 0;
   }
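   // Design note (restating the code above rather than adding to it): MethodFDA itself is an
   // IFitterTarget, so fFitter->Run() repeatedly calls EstimatorFunction() below with trial
   // parameter sets; it returns the best estimator value found and leaves the corresponding
   // parameters in fBestPars.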
// PrintResults(): display the fit parameters
   Log() << "Results for parameter fit using \"" << fitter << "\" fitter:" << Endl;
   std::vector<TString> parNames;
   for (UInt_t ipar=0; ipar<pars.size(); ipar++) parNames.push_back( Form("Par(%i)",ipar) );
   Log() << "Discriminator expression: \"" << fFormulaStringP << "\"" << Endl;
   Log() << "Value of estimator at minimum: " << estimator << Endl;
// EstimatorFunction(): compute the estimator for a given parameter set (to be minimised by the fitter)
   const Double_t sumOfWeights[] = { fSumOfWeightsBkg, fSumOfWeightsSig, fSumOfWeights };
   Double_t       estimator[]    = { 0, 0, 0 };        // background, signal and total accumulators

   Double_t result = 0, deviation = 0, desired = 0;    // per-event working variables

   if( DoRegression() ){
      for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
         const Event* ev = GetEvent(ievt);

         for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
            desired   = ev->GetTarget( dim );
            result    = InterpretFormula( ev, pars.begin(), pars.end() );
            deviation = TMath::Power( result - desired, 2 );
            estimator[2] += deviation * ev->GetWeight();
         }
      }
      estimator[2] /= sumOfWeights[2];
      return estimator[2];
   }
   else if( DoMulticlass() ){
      for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
         const Event* ev = GetEvent(ievt);
         CalculateMulticlassValues( ev, pars, *fMulticlassReturnVal );

         Double_t crossEntropy = 0.0;
         for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
            Double_t y = fMulticlassReturnVal->at(dim);
            Double_t t = (ev->GetClass() == static_cast<UInt_t>(dim) ? 1.0 : 0.0);   // one-hot truth
            crossEntropy += t*log(y);
         }
         estimator[2] += ev->GetWeight()*crossEntropy;
      }
      estimator[2] /= sumOfWeights[2];
      return estimator[2];
   }
   else {
      for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
         const Event* ev = GetEvent(ievt);

         desired   = (DataInfo().IsSignal(ev) ? 1.0 : 0.0);
         result    = InterpretFormula( ev, pars.begin(), pars.end() );
         deviation = TMath::Power( result - desired, 2 );
         estimator[Int_t(desired)] += deviation * ev->GetWeight();   // [0]: background, [1]: signal
      }
      estimator[0] /= sumOfWeights[0];
      estimator[1] /= sumOfWeights[1];
   }

   // return value is the sum over the normalised signal and background contributions
   return estimator[0] + estimator[1];
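   // Summary of the two-class estimator implemented above (restating the code, not adding to it):
   //
   //    E(p) = (1/W_B) * Sum_{bkg events} w_i * ( F(x_i; p) - 0 )^2
   //         + (1/W_S) * Sum_{sig events} w_i * ( F(x_i; p) - 1 )^2
   //
   // i.e. the fit pushes the formula value towards 1 for signal and 0 for background, with
   // W_S and W_B the signal and background sums of weights.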
// InterpretFormula(): load the parameters [parBegin,parEnd) and the event variables into the
// TFormula, then evaluate it for the given event
   Int_t ipar = 0;
   for( std::vector<Double_t>::iterator it = parBegin; it != parEnd; ++it ){
      fFormula->SetParameter( ipar, (*it) );
      ++ipar;
   }
   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) fFormula->SetParameter( ivar+ipar, event->GetValue(ivar) );

// GetMvaValue(): returns the MVA value for the given event
   const Event* ev = GetEvent();

   // cannot determine error
   NoErrorCalc(err, errUpper);

   return InterpretFormula( ev, fBestPars.begin(), fBestPars.end() );
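   // Illustrative note: the evaluation step of InterpretFormula is not shown above. Once all
   // parameter slots are filled, the discriminant is simply the compiled formula evaluated with
   // these settings, along the lines of (hedged sketch; the exact call in the original code may
   // differ):
   //
   //    Double_t value = fFormula->Eval( 0 );   // argument unused, all inputs enter as parameters
   //    return value;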
// GetRegressionValues()
   if (fRegressionReturnVal == NULL) fRegressionReturnVal = new std::vector<Float_t>();
   fRegressionReturnVal->clear();

   const Event* ev = GetEvent();
   Event* evT = new Event(*ev);   // working copy that receives the computed targets

   for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
      Int_t offset = dim*fNPars;   // each output dimension owns its own block of fNPars parameters
      evT->SetTarget( dim, InterpretFormula( ev, fBestPars.begin()+offset, fBestPars.begin()+offset+fNPars ) );
   }
   const Event* evT2 = GetTransformationHandler().InverseTransform( evT );
   fRegressionReturnVal->push_back( evT2->GetTarget(0) );

   return (*fRegressionReturnVal);
// GetMulticlassValues()
   if (fMulticlassReturnVal == NULL) fMulticlassReturnVal = new std::vector<Float_t>();
   fMulticlassReturnVal->clear();
   std::vector<Float_t> temp;

   const TMVA::Event* evt = GetEvent();
   CalculateMulticlassValues( evt, fBestPars, temp );

   UInt_t nClasses = DataInfo().GetNClasses();
   for (UInt_t iClass=0; iClass<nClasses; iClass++) {
      Double_t norm = 0.0;
      for (UInt_t j=0; j<nClasses; j++) {
         if (iClass != j) norm += exp( temp[j] - temp[iClass] );
      }
      (*fMulticlassReturnVal).push_back( 1.0/(1.0+norm) );
   }

   return (*fMulticlassReturnVal);
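   // Summary of the multiclass response above: the value pushed back for each class is
   //
   //    P(iClass) = 1 / ( 1 + Sum_{j != iClass} exp(temp[j] - temp[iClass]) )
   //              = exp(temp[iClass]) / Sum_j exp(temp[j])
   //
   // i.e. a softmax over the per-class formula outputs computed by CalculateMulticlassValues.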
// CalculateMulticlassValues(): evaluate the formula once per output dimension
   for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
      Int_t offset = dim*fNPars;   // parameter block belonging to this output dimension
      Double_t value = InterpretFormula( evt, parameters.begin()+offset, parameters.begin()+offset+fNPars );
      values.push_back( value );
   }
// ReadWeightsFromStream(): read back the training results from a stream
   fBestPars.resize( fNPars );
   for (UInt_t ipar=0; ipar<fNPars; ipar++) istr >> fBestPars[ipar];
// AddWeightsXMLTo(): create the XML description of the fit result
   for (UInt_t ipar=0; ipar<fNPars*fOutputDimensions; ipar++) {
      // each fitted parameter is written as a child node of the weight node
   }
// ReadWeightsFromXML(): read the coefficients from the XML weight file
   if (gTools().HasAttr( wghtnode, "NDim")) {
      // the number of output dimensions is read from the "NDim" attribute
   }
   else {
      // older weight files do not store the number of output dimensions
      fOutputDimensions = 1;
   }

   fBestPars.resize( fNPars*fOutputDimensions );

   // ipar and par hold the index and value read from the current parameter node
   if (ipar >= fNPars*fOutputDimensions)
      Log() << kFATAL << "<ReadWeightsFromXML> index out of range: "
            << ipar << " >= " << fNPars*fOutputDimensions << Endl;
   fBestPars[ipar] = par;
// MakeClassSpecific(): write the FDA-specific classifier response
   fout << " double fParameter[" << fNPars << "];" << std::endl;
   fout << "};" << std::endl;
   fout << "" << std::endl;

   fout << "inline void " << className << "::Initialize() " << std::endl;
   fout << "{" << std::endl;
   for (UInt_t ipar=0; ipar<fNPars; ipar++) {
      fout << " fParameter[" << ipar << "] = " << fBestPars[ipar] << ";" << std::endl;
   }
   fout << "}" << std::endl;

   fout << "inline double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const" << std::endl;
   fout << "{" << std::endl;
   fout << " // interpret the formula" << std::endl;

   // str starts from the TFormula-compatible formula string and has its parameter and variable
   // tokens replaced by the corresponding C++ expressions (mirroring the substitutions in CreateFormula)
   TString str = fFormulaStringT;
   for (UInt_t ipar=0; ipar<fNPars; ipar++) {
      str.ReplaceAll( Form("[%i]",ipar), Form("fParameter[%i]",ipar) );
   }
   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      str.ReplaceAll( Form("[%i]",ivar+fNPars), Form("inputValues[%i]",ivar) );
   }

   fout << " double retval = " << str << ";" << std::endl;
   fout << " return retval; " << std::endl;
   fout << "}" << std::endl;

   fout << "// Clean up" << std::endl;
   fout << "inline void " << className << "::Clear() " << std::endl;
   fout << "{" << std::endl;
   fout << " // nothing to clear" << std::endl;
   fout << "}" << std::endl;
// GetHelpMessage(): get help message text
   Log() << "The function discriminant analysis (FDA) is a classifier suitable " << Endl;
   Log() << "to solve linear or simple nonlinear discrimination problems." << Endl;

   Log() << "The user provides the desired function with adjustable parameters" << Endl;
   Log() << "via the configuration option string, and FDA fits the parameters to" << Endl;
   Log() << "it, requiring the signal (background) function value to be as close" << Endl;
   Log() << "as possible to 1 (0). Its advantage over the more involved and" << Endl;
   Log() << "automatic nonlinear discriminators is the simplicity and transparency " << Endl;
   Log() << "of the discrimination expression. A shortcoming is that FDA will" << Endl;
   Log() << "underperform for involved problems with complicated, phase-space" << Endl;
   Log() << "dependent nonlinear correlations." << Endl;

   Log() << "Please consult the Users Guide for the format of the formula string" << Endl;
   Log() << "and the allowed parameter ranges:" << Endl;
   if (gConfig().WriteOptionsReference()) {
      Log() << "<a href=\"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf\">"
            << "http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf</a>" << Endl;
   }
   else Log() << "http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf" << Endl;

   Log() << "The FDA performance depends on the complexity and fidelity of the" << Endl;
   Log() << "user-defined discriminator function. As a general rule, it should" << Endl;
   Log() << "be able to reproduce the discrimination power of any linear" << Endl;
   Log() << "discriminant analysis. To reach into the nonlinear domain, it is" << Endl;
   Log() << "useful to inspect the correlation profiles of the input variables," << Endl;
   Log() << "and add quadratic and higher polynomial terms between variables as" << Endl;
   Log() << "necessary. Comparison with more involved nonlinear classifiers can" << Endl;
   Log() << "be used as a guide." << Endl;

   Log() << "Depending on the function used, the choice of \"FitMethod\" is" << Endl;
   Log() << "crucial for getting valuable solutions with FDA. As a guideline it" << Endl;
   Log() << "is recommended to start with \"FitMethod=MINUIT\". When more complex" << Endl;
   Log() << "functions are used where MINUIT does not converge to reasonable" << Endl;
   Log() << "results, the user should switch to non-gradient FitMethods such" << Endl;
   Log() << "as GeneticAlgorithm (GA) or Monte Carlo (MC). It might prove to be" << Endl;
   Log() << "useful to combine GA (or MC) with MINUIT by setting the option" << Endl;
   Log() << "\"Converger=MINUIT\". GA (MC) will then set the starting parameters" << Endl;
   Log() << "for MINUIT such that the basic quality of GA (MC) of finding global" << Endl;
   Log() << "minima is combined with the efficacy of MINUIT of finding local" << Endl;
   Log() << "minima." << Endl;
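// Illustrative usage sketch (an assumption, following the option names declared above and the
// standard TMVA::Factory::BookMethod interface; the variable names x0, x1 and the factory
// object are hypothetical):
//
//    factory->BookMethod( TMVA::Types::kFDA, "FDA_GA",
//                         "Formula=(0)+(1)*x0+(2)*x1:"
//                         "ParRanges=(-1,1);(-10,10);(-10,10):"
//                         "FitMethod=GA:Converger=MINUIT" );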