64 using std::stringstream;
78 : MethodBase( jobName, Types::kFDA, methodTitle, theData, theOption, theTargetDir ),
83 fConvergerFitter( 0 ),
84 fSumOfWeightsSig( 0 ),
85 fSumOfWeightsBkg( 0 ),
87 fOutputDimensions( 0 )
102 fConvergerFitter( 0 ),
103 fSumOfWeightsSig( 0 ),
104 fSumOfWeightsBkg( 0 ),
106 fOutputDimensions( 0 )
120 fSumOfWeightsSig = 0;
121 fSumOfWeightsBkg = 0;
123 fFormulaStringP =
"";
124 fParRangeStringP =
"";
125 fFormulaStringT =
"";
126 fParRangeStringT =
"";
132 if (fMulticlassReturnVal ==
NULL) fMulticlassReturnVal =
new std::vector<Float_t>();
151 DeclareOptionRef( fFormulaStringP =
"(0)",
"Formula",
"The discrimination formula" );
152 DeclareOptionRef( fParRangeStringP =
"()",
"ParRanges",
"Parameter ranges" );
155 DeclareOptionRef( fFitMethod =
"MINUIT",
"FitMethod",
"Optimisation Method");
159 AddPreDefVal(
TString(
"MINUIT"));
161 DeclareOptionRef( fConverger =
"None",
"Converger",
"FitMethod uses Converger to improve result");
163 AddPreDefVal(
TString(
"MINUIT"));
172 fFormulaStringT = fFormulaStringP;
177 for (
UInt_t ipar=0; ipar<fNPars; ipar++) {
178 fFormulaStringT.ReplaceAll(
Form(
"(%i)",ipar),
Form(
"[%i]",ipar) );
182 for (
Int_t ipar=fNPars; ipar<1000; ipar++) {
183 if (fFormulaStringT.Contains(
Form(
"(%i)",ipar) ))
185 <<
"<CreateFormula> Formula contains expression: \"" <<
Form(
"(%i)",ipar) <<
"\", "
186 <<
"which cannot be attributed to a parameter; "
187 <<
"it may be that the number of variable ranges given via \"ParRanges\" "
188 <<
"does not match the number of parameters in the formula expression, please verify!"
193 for (
Int_t ivar=GetNvar()-1; ivar >= 0; ivar--) {
194 fFormulaStringT.ReplaceAll(
Form(
"x%i",ivar),
Form(
"[%i]",ivar+fNPars) );
198 for (
UInt_t ivar=GetNvar(); ivar<1000; ivar++) {
199 if (fFormulaStringT.Contains(
Form(
"x%i",ivar) ))
201 <<
"<CreateFormula> Formula contains expression: \"" <<
Form(
"x%i",ivar) <<
"\", "
202 <<
"which cannot be attributed to an input variable" <<
Endl;
205 Log() <<
"User-defined formula string : \"" << fFormulaStringP <<
"\"" <<
Endl;
206 Log() <<
"TFormula-compatible formula string: \"" << fFormulaStringT <<
"\"" <<
Endl;
207 Log() <<
"Creating and compiling formula" <<
Endl;
210 if (fFormula)
delete fFormula;
211 fFormula =
new TFormula(
"FDA_Formula", fFormulaStringT );
214 if (!fFormula->IsValid())
215 Log() <<
kFATAL <<
"<ProcessOptions> Formula expression could not be properly compiled" <<
Endl;
218 if (fFormula->GetNpar() > (
Int_t)(fNPars + GetNvar()))
219 Log() <<
kFATAL <<
"<ProcessOptions> Dubious number of parameters in formula expression: "
220 << fFormula->GetNpar() <<
" - compared to maximum allowed: " << fNPars + GetNvar() <<
Endl;
229 fParRangeStringT = fParRangeStringP;
232 fParRangeStringT.ReplaceAll(
" ",
"" );
233 fNPars = fParRangeStringT.CountChar(
')' );
237 Log() <<
kFATAL <<
"<ProcessOptions> Mismatch in parameter string: "
238 <<
"the number of parameters: " << fNPars <<
" != ranges defined: "
239 << parList->
GetSize() <<
"; the format of the \"ParRanges\" string "
240 <<
"must be: \"(-1.2,3.4);(-2.3,4.55);...\", "
241 <<
"where the numbers in \"(a,b)\" correspond to the a=min, b=max parameter ranges; "
242 <<
"each parameter defined in the function string must have a corresponding rang."
246 fParRange.resize( fNPars );
247 for (
UInt_t ipar=0; ipar<fNPars; ipar++) fParRange[ipar] = 0;
249 for (
UInt_t ipar=0; ipar<fNPars; ipar++) {
256 stringstream stmin;
Float_t pmin=0; stmin << pminS.
Data(); stmin >> pmin;
257 stringstream stmax;
Float_t pmax=0; stmax << pmaxS.
Data(); stmax >> pmax;
260 if (
TMath::Abs(pmax-pmin) < 1.e-30) pmax = pmin;
261 if (pmin > pmax)
Log() <<
kFATAL <<
"<ProcessOptions> max > min in interval for parameter: ["
262 << ipar <<
"] : [" << pmin <<
", " << pmax <<
"] " <<
Endl;
264 Log() <<
kINFO <<
"Create parameter interval for parameter " << ipar <<
" : [" << pmin <<
"," << pmax <<
"]" <<
Endl;
265 fParRange[ipar] =
new Interval( pmin, pmax );
274 fOutputDimensions = 1;
276 fOutputDimensions = DataInfo().GetNTargets();
278 fOutputDimensions = DataInfo().GetNClasses();
280 for(
Int_t dim = 1; dim < fOutputDimensions; ++dim ){
282 fParRange.push_back( fParRange.at(
par) );
289 if (fConverger ==
"MINUIT") {
290 fConvergerFitter =
new MinuitFitter( *
this,
Form(
"%s_Converger_Minuit", GetName()), fParRange, GetOptions() );
291 SetOptions(dynamic_cast<Configurable*>(fConvergerFitter)->GetOptions());
294 if(fFitMethod ==
"MC")
295 fFitter =
new MCFitter( *fConvergerFitter,
Form(
"%s_Fitter_MC", GetName()), fParRange, GetOptions() );
296 else if (fFitMethod ==
"GA")
297 fFitter =
new GeneticFitter( *fConvergerFitter,
Form(
"%s_Fitter_GA", GetName()), fParRange, GetOptions() );
298 else if (fFitMethod ==
"SA")
300 else if (fFitMethod ==
"MINUIT")
301 fFitter =
new MinuitFitter( *fConvergerFitter,
Form(
"%s_Fitter_Minuit", GetName()), fParRange, GetOptions() );
303 Log() <<
kFATAL <<
"<Train> Do not understand fit method:" << fFitMethod <<
Endl;
306 fFitter->CheckForUnusedOptions();
337 for (
UInt_t ipar=0; ipar<fParRange.size() && ipar<fNPars; ipar++) {
338 if (fParRange[ipar] != 0) {
delete fParRange[ipar]; fParRange[ipar] = 0; }
342 if (fFormula != 0) {
delete fFormula; fFormula = 0; }
353 fSumOfWeightsSig = 0;
354 fSumOfWeightsBkg = 0;
356 for (
UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
359 const Event* ev = GetEvent(ievt);
364 if (!DoRegression()) {
365 if (DataInfo().IsSignal(ev)) { fSumOfWeightsSig += w; }
366 else { fSumOfWeightsBkg += w; }
372 if (!DoRegression()) {
373 if (fSumOfWeightsSig <= 0 || fSumOfWeightsBkg <= 0) {
374 Log() <<
kFATAL <<
"<Train> Troubles in sum of weights: "
375 << fSumOfWeightsSig <<
" (S) : " << fSumOfWeightsBkg <<
" (B)" <<
Endl;
378 else if (fSumOfWeights <= 0) {
379 Log() <<
kFATAL <<
"<Train> Troubles in sum of weights: "
380 << fSumOfWeights <<
Endl;
385 for (std::vector<Interval*>::const_iterator parIt = fParRange.begin(); parIt != fParRange.end(); parIt++) {
386 fBestPars.push_back( (*parIt)->GetMean() );
390 Double_t estimator = fFitter->Run( fBestPars );
393 PrintResults( fFitMethod, fBestPars, estimator );
395 delete fFitter; fFitter = 0;
396 if (fConvergerFitter!=0 && fConvergerFitter!=(
IFitterTarget*)
this) {
397 delete fConvergerFitter;
398 fConvergerFitter = 0;
409 Log() <<
"Results for parameter fit using \"" << fitter <<
"\" fitter:" <<
Endl;
410 std::vector<TString> parNames;
411 for (
UInt_t ipar=0; ipar<pars.size(); ipar++) parNames.push_back(
Form(
"Par(%i)",ipar ) );
413 Log() <<
"Discriminator expression: \"" << fFormulaStringP <<
"\"" <<
Endl;
414 Log() <<
"Value of estimator at minimum: " << estimator <<
Endl;
424 const Double_t sumOfWeights[] = { fSumOfWeightsBkg, fSumOfWeightsSig, fSumOfWeights };
431 if( DoRegression() ){
432 for (
UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
436 for(
Int_t dim = 0; dim < fOutputDimensions; ++dim ){
438 result = InterpretFormula( ev, pars.begin(), pars.end() );
440 estimator[2] += deviation * ev->
GetWeight();
443 estimator[2] /= sumOfWeights[2];
447 }
else if( DoMulticlass() ){
448 for (
UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
452 CalculateMulticlassValues( ev, pars, *fMulticlassReturnVal );
455 for(
Int_t dim = 0; dim < fOutputDimensions; ++dim ){
456 Double_t y = fMulticlassReturnVal->at(dim);
458 crossEntropy += t*
log(y);
460 estimator[2] += ev->
GetWeight()*crossEntropy;
462 estimator[2] /= sumOfWeights[2];
467 for (
UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
471 desired = (DataInfo().IsSignal(ev) ? 1.0 : 0.0);
472 result = InterpretFormula( ev, pars.begin(), pars.end() );
476 estimator[0] /= sumOfWeights[0];
477 estimator[1] /= sumOfWeights[1];
479 return estimator[0] + estimator[1];
490 for( std::vector<Double_t>::iterator it = parBegin; it != parEnd; ++it ){
492 fFormula->SetParameter( ipar, (*it) );
495 for (
UInt_t ivar=0; ivar<GetNvar(); ivar++) fFormula->SetParameter( ivar+ipar, event->
GetValue(ivar) );
507 const Event* ev = GetEvent();
510 NoErrorCalc(err, errUpper);
512 return InterpretFormula( ev, fBestPars.begin(), fBestPars.end() );
519 if (fRegressionReturnVal ==
NULL) fRegressionReturnVal =
new std::vector<Float_t>();
520 fRegressionReturnVal->clear();
522 const Event* ev = GetEvent();
526 for(
Int_t dim = 0; dim < fOutputDimensions; ++dim ){
527 Int_t offset = dim*fNPars;
528 evT->
SetTarget(dim,InterpretFormula( ev, fBestPars.begin()+offset, fBestPars.begin()+offset+fNPars ) );
530 const Event* evT2 = GetTransformationHandler().InverseTransform( evT );
531 fRegressionReturnVal->push_back(evT2->
GetTarget(0));
535 return (*fRegressionReturnVal);
543 if (fMulticlassReturnVal ==
NULL) fMulticlassReturnVal =
new std::vector<Float_t>();
544 fMulticlassReturnVal->clear();
545 std::vector<Float_t> temp;
550 CalculateMulticlassValues( evt, fBestPars, temp );
552 UInt_t nClasses = DataInfo().GetNClasses();
553 for(
UInt_t iClass=0; iClass<nClasses; iClass++){
555 for(
UInt_t j=0;j<nClasses;j++){
557 norm+=
exp(temp[j]-temp[iClass]);
559 (*fMulticlassReturnVal).push_back(1.0/(1.0+norm));
562 return (*fMulticlassReturnVal);
580 for(
Int_t dim = 0; dim < fOutputDimensions; ++dim ){
581 Int_t offset = dim*fNPars;
582 Double_t value = InterpretFormula( evt, parameters.begin()+offset, parameters.begin()+offset+fNPars );
584 values.push_back( value );
604 fBestPars.resize( fNPars );
605 for (
UInt_t ipar=0; ipar<fNPars; ipar++) istr >> fBestPars[ipar];
617 for (
UInt_t ipar=0; ipar<fNPars*fOutputDimensions; ipar++) {
634 if(
gTools().HasAttr( wghtnode,
"NDim")) {
638 fOutputDimensions = 1;
642 fBestPars.resize( fNPars*fOutputDimensions );
652 if (ipar >= fNPars*fOutputDimensions)
Log() <<
kFATAL <<
"<ReadWeightsFromXML> index out of range: "
653 << ipar <<
" >= " << fNPars <<
Endl;
654 fBestPars[ipar] =
par;
671 fout <<
" double fParameter[" << fNPars <<
"];" << std::endl;
672 fout <<
"};" << std::endl;
673 fout <<
"" << std::endl;
674 fout <<
"inline void " << className <<
"::Initialize() " << std::endl;
675 fout <<
"{" << std::endl;
676 for(
UInt_t ipar=0; ipar<fNPars; ipar++) {
677 fout <<
" fParameter[" << ipar <<
"] = " << fBestPars[ipar] <<
";" << std::endl;
679 fout <<
"}" << std::endl;
681 fout <<
"inline double " << className <<
"::GetMvaValue__( const std::vector<double>& inputValues ) const" << std::endl;
682 fout <<
"{" << std::endl;
683 fout <<
" // interpret the formula" << std::endl;
687 for (
UInt_t ipar=0; ipar<fNPars; ipar++) {
692 for (
UInt_t ivar=0; ivar<GetNvar(); ivar++) {
696 fout <<
" double retval = " << str <<
";" << std::endl;
698 fout <<
" return retval; " << std::endl;
699 fout <<
"}" << std::endl;
701 fout <<
"// Clean up" << std::endl;
702 fout <<
"inline void " << className <<
"::Clear() " << std::endl;
703 fout <<
"{" << std::endl;
704 fout <<
" // nothing to clear" << std::endl;
705 fout <<
"}" << std::endl;
719 Log() <<
"The function discriminant analysis (FDA) is a classifier suitable " <<
Endl;
720 Log() <<
"to solve linear or simple nonlinear discrimination problems." <<
Endl;
722 Log() <<
"The user provides the desired function with adjustable parameters" <<
Endl;
723 Log() <<
"via the configuration option string, and FDA fits the parameters to" <<
Endl;
724 Log() <<
"it, requiring the signal (background) function value to be as close" <<
Endl;
725 Log() <<
"as possible to 1 (0). Its advantage over the more involved and" <<
Endl;
726 Log() <<
"automatic nonlinear discriminators is the simplicity and transparency " <<
Endl;
727 Log() <<
"of the discrimination expression. A shortcoming is that FDA will" <<
Endl;
728 Log() <<
"underperform for involved problems with complicated, phase space" <<
Endl;
729 Log() <<
"dependent nonlinear correlations." <<
Endl;
731 Log() <<
"Please consult the Users Guide for the format of the formula string" <<
Endl;
732 Log() <<
"and the allowed parameter ranges:" <<
Endl;
733 if (
gConfig().WriteOptionsReference()) {
734 Log() <<
"<a href=\"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf\">"
735 <<
"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf</a>" <<
Endl;
737 else Log() <<
"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf" <<
Endl;
741 Log() <<
"The FDA performance depends on the complexity and fidelity of the" <<
Endl;
742 Log() <<
"user-defined discriminator function. As a general rule, it should" <<
Endl;
743 Log() <<
"be able to reproduce the discrimination power of any linear" <<
Endl;
744 Log() <<
"discriminant analysis. To reach into the nonlinear domain, it is" <<
Endl;
745 Log() <<
"useful to inspect the correlation profiles of the input variables," <<
Endl;
746 Log() <<
"and add quadratic and higher polynomial terms between variables as" <<
Endl;
747 Log() <<
"necessary. Comparison with more involved nonlinear classifiers can" <<
Endl;
748 Log() <<
"be used as a guide." <<
Endl;
752 Log() <<
"Depending on the function used, the choice of \"FitMethod\" is" <<
Endl;
753 Log() <<
"crucial for getting valuable solutions with FDA. As a guideline it" <<
Endl;
754 Log() <<
"is recommended to start with \"FitMethod=MINUIT\". When more complex" <<
Endl;
755 Log() <<
"functions are used where MINUIT does not converge to reasonable" <<
Endl;
756 Log() <<
"results, the user should switch to non-gradient FitMethods such" <<
Endl;
757 Log() <<
"as GeneticAlgorithm (GA) or Monte Carlo (MC). It might prove to be" <<
Endl;
758 Log() <<
"useful to combine GA (or MC) with MINUIT by setting the option" <<
Endl;
759 Log() <<
"\"Converger=MINUIT\". GA (MC) will then set the starting parameters" <<
Endl;
760 Log() <<
"for MINUIT such that the basic quality of GA (MC) of finding global" <<
Endl;
761 Log() <<
"minima is combined with the efficacy of MINUIT of finding local" <<
Endl;
void Init(void)
default initialisation
void DeclareOptions()
define the options (their key words) that can be set in the option string
MsgLogger & Endl(MsgLogger &ml)
void ClearAll()
delete and clear all class members
Collectable string class.
TString & ReplaceAll(const TString &s1, const TString &s2)
Double_t InterpretFormula(const Event *, std::vector< Double_t >::iterator begin, std::vector< Double_t >::iterator end)
formula interpretation
void CreateFormula()
translate formula string into TFormula, and parameter string into par ranges
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is set or not...
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
void PrintResults(const TString &, std::vector< Double_t > &, const Double_t) const
display fit parameters; check maximum length of variable name
const char * Data() const
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns MVA value for given event
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
FDA can handle classification with 2 classes and regression with one regression-target.
ClassImp(TMVA::MethodFDA) TMVA
standard constructor
void CalculateMulticlassValues(const TMVA::Event *&evt, std::vector< Double_t > &parameters, std::vector< Float_t > &values)
calculate the values for multiclass
char * Form(const char *fmt,...)
void SetTarget(UInt_t itgt, Float_t value)
set the target value (dimension itgt) to value
void ReadWeightsFromXML(void *wghtnode)
read coefficients from xml weight file
void Train(void)
FDA training.
void ProcessOptions()
the option string is decoded; for available options see "DeclareOptions"
void MakeClassSpecific(std::ostream &, const TString &) const
write FDA-specific classifier response
virtual Int_t GetSize() const
Describe directory structure in memory.
Double_t EstimatorFunction(std::vector< Double_t > &)
compute estimator for given parameter set (to be minimised) const Double_t sumOfWeights[] = { fSumOfW...
void ReadWeightsFromStream(std::istream &i)
read back the training results from a file (stream)
MethodFDA(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="", TDirectory *theTargetDir=0)
Float_t GetTarget(UInt_t itgt) const
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
virtual const std::vector< Float_t > & GetRegressionValues()
virtual const std::vector< Float_t > & GetMulticlassValues()
void GetHelpMessage() const
get help message text
double norm(double *x, double *p)
virtual ~MethodFDA(void)
destructor
Ssiz_t First(char c) const
Find first occurrence of a character c.
void AddWeightsXMLTo(void *parent) const
create XML description for LD classification and regression (for arbitrary number of output classes/t...