77using std::stringstream;
95 fConvergerFitter( 0 ),
96 fSumOfWeightsSig( 0 ),
97 fSumOfWeightsBkg( 0 ),
99 fOutputDimensions( 0 )
113 fConvergerFitter( 0 ),
114 fSumOfWeightsSig( 0 ),
115 fSumOfWeightsBkg( 0 ),
117 fOutputDimensions( 0 )
131 fSumOfWeightsSig = 0;
132 fSumOfWeightsBkg = 0;
134 fFormulaStringP =
"";
135 fParRangeStringP =
"";
136 fFormulaStringT =
"";
137 fParRangeStringT =
"";
143 if (fMulticlassReturnVal == NULL) fMulticlassReturnVal =
new std::vector<Float_t>();
165 DeclareOptionRef( fFormulaStringP =
"(0)",
"Formula",
"The discrimination formula" );
166 DeclareOptionRef( fParRangeStringP =
"()",
"ParRanges",
"Parameter ranges" );
169 DeclareOptionRef( fFitMethod =
"MINUIT",
"FitMethod",
"Optimisation Method");
173 AddPreDefVal(
TString(
"MINUIT"));
175 DeclareOptionRef( fConverger =
"None",
"Converger",
"FitMethod uses Converger to improve result");
177 AddPreDefVal(
TString(
"MINUIT"));
186 fFormulaStringT = fFormulaStringP;
191 for (
UInt_t ipar=0; ipar<fNPars; ipar++) {
192 fFormulaStringT.ReplaceAll(
Form(
"(%i)",ipar),
Form(
"[%i]",ipar) );
196 for (
Int_t ipar=fNPars; ipar<1000; ipar++) {
197 if (fFormulaStringT.Contains(
Form(
"(%i)",ipar) ))
199 <<
"<CreateFormula> Formula contains expression: \"" <<
Form(
"(%i)",ipar) <<
"\", "
200 <<
"which cannot be attributed to a parameter; "
201 <<
"it may be that the number of variable ranges given via \"ParRanges\" "
202 <<
"does not match the number of parameters in the formula expression, please verify!"
207 for (
Int_t ivar=GetNvar()-1; ivar >= 0; ivar--) {
208 fFormulaStringT.ReplaceAll(
Form(
"x%i",ivar),
Form(
"[%i]",ivar+fNPars) );
212 for (
UInt_t ivar=GetNvar(); ivar<1000; ivar++) {
213 if (fFormulaStringT.Contains(
Form(
"x%i",ivar) ))
215 <<
"<CreateFormula> Formula contains expression: \"" <<
Form(
"x%i",ivar) <<
"\", "
216 <<
"which cannot be attributed to an input variable" <<
Endl;
219 Log() <<
"User-defined formula string : \"" << fFormulaStringP <<
"\"" <<
Endl;
220 Log() <<
"TFormula-compatible formula string: \"" << fFormulaStringT <<
"\"" <<
Endl;
221 Log() << kDEBUG <<
"Creating and compiling formula" <<
Endl;
224 if (fFormula)
delete fFormula;
225 fFormula =
new TFormula(
"FDA_Formula", fFormulaStringT );
228 if (!fFormula->IsValid())
229 Log() << kFATAL <<
"<ProcessOptions> Formula expression could not be properly compiled" <<
Endl;
232 if (fFormula->GetNpar() > (
Int_t)(fNPars + GetNvar()))
233 Log() << kFATAL <<
"<ProcessOptions> Dubious number of parameters in formula expression: "
234 << fFormula->GetNpar() <<
" - compared to maximum allowed: " << fNPars + GetNvar() <<
Endl;
243 fParRangeStringT = fParRangeStringP;
246 fParRangeStringT.ReplaceAll(
" ",
"" );
247 fNPars = fParRangeStringT.CountChar(
')' );
251 Log() << kFATAL <<
"<ProcessOptions> Mismatch in parameter string: "
252 <<
"the number of parameters: " << fNPars <<
" != ranges defined: "
253 << parList->
GetSize() <<
"; the format of the \"ParRanges\" string "
254 <<
"must be: \"(-1.2,3.4);(-2.3,4.55);...\", "
255 <<
"where the numbers in \"(a,b)\" correspond to the a=min, b=max parameter ranges; "
256 <<
"each parameter defined in the function string must have a corresponding rang."
260 fParRange.resize( fNPars );
261 for (
UInt_t ipar=0; ipar<fNPars; ipar++) fParRange[ipar] = 0;
263 for (
UInt_t ipar=0; ipar<fNPars; ipar++) {
270 stringstream stmin;
Float_t pmin=0; stmin << pminS.
Data(); stmin >> pmin;
271 stringstream stmax;
Float_t pmax=0; stmax << pmaxS.
Data(); stmax >> pmax;
274 if (
TMath::Abs(pmax-pmin) < 1.e-30) pmax = pmin;
275 if (pmin > pmax)
Log() << kFATAL <<
"<ProcessOptions> max > min in interval for parameter: ["
276 << ipar <<
"] : [" << pmin <<
", " << pmax <<
"] " <<
Endl;
278 Log() << kINFO <<
"Create parameter interval for parameter " << ipar <<
" : [" << pmin <<
"," << pmax <<
"]" <<
Endl;
279 fParRange[ipar] =
new Interval( pmin, pmax );
288 fOutputDimensions = 1;
290 fOutputDimensions = DataInfo().GetNTargets();
292 fOutputDimensions = DataInfo().GetNClasses();
294 for(
Int_t dim = 1; dim < fOutputDimensions; ++dim ){
295 for(
UInt_t par = 0; par < fNPars; ++par ){
296 fParRange.push_back( fParRange.at(par) );
303 if (fConverger ==
"MINUIT") {
305 SetOptions(
dynamic_cast<Configurable*
>(fConvergerFitter)->GetOptions());
308 if(fFitMethod ==
"MC")
309 fFitter =
new MCFitter( *fConvergerFitter,
Form(
"%s_Fitter_MC",
GetName()), fParRange, GetOptions() );
310 else if (fFitMethod ==
"GA")
312 else if (fFitMethod ==
"SA")
314 else if (fFitMethod ==
"MINUIT")
317 Log() << kFATAL <<
"<Train> Do not understand fit method:" << fFitMethod <<
Endl;
320 fFitter->CheckForUnusedOptions();
351 for (
UInt_t ipar=0; ipar<fParRange.size() && ipar<fNPars; ipar++) {
352 if (fParRange[ipar] != 0) {
delete fParRange[ipar]; fParRange[ipar] = 0; }
356 if (fFormula != 0) {
delete fFormula; fFormula = 0; }
367 fSumOfWeightsSig = 0;
368 fSumOfWeightsBkg = 0;
370 for (
UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
373 const Event* ev = GetEvent(ievt);
378 if (!DoRegression()) {
379 if (DataInfo().IsSignal(ev)) { fSumOfWeightsSig += w; }
380 else { fSumOfWeightsBkg += w; }
386 if (!DoRegression()) {
387 if (fSumOfWeightsSig <= 0 || fSumOfWeightsBkg <= 0) {
388 Log() << kFATAL <<
"<Train> Troubles in sum of weights: "
389 << fSumOfWeightsSig <<
" (S) : " << fSumOfWeightsBkg <<
" (B)" <<
Endl;
392 else if (fSumOfWeights <= 0) {
393 Log() << kFATAL <<
"<Train> Troubles in sum of weights: "
394 << fSumOfWeights <<
Endl;
399 for (std::vector<Interval*>::const_iterator parIt = fParRange.begin(); parIt != fParRange.end(); ++parIt) {
400 fBestPars.push_back( (*parIt)->GetMean() );
404 Double_t estimator = fFitter->Run( fBestPars );
407 PrintResults( fFitMethod, fBestPars, estimator );
409 delete fFitter; fFitter = 0;
410 if (fConvergerFitter!=0 && fConvergerFitter!=(
IFitterTarget*)
this) {
411 delete fConvergerFitter;
412 fConvergerFitter = 0;
424 Log() << kHEADER <<
"Results for parameter fit using \"" << fitter <<
"\" fitter:" <<
Endl;
425 std::vector<TString> parNames;
426 for (
UInt_t ipar=0; ipar<pars.size(); ipar++) parNames.push_back(
Form(
"Par(%i)",ipar ) );
428 Log() <<
"Discriminator expression: \"" << fFormulaStringP <<
"\"" <<
Endl;
429 Log() <<
"Value of estimator at minimum: " << estimator <<
Endl;
437 const Double_t sumOfWeights[] = { fSumOfWeightsBkg, fSumOfWeightsSig, fSumOfWeights };
444 if( DoRegression() ){
445 for (
UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
449 for(
Int_t dim = 0; dim < fOutputDimensions; ++dim ){
451 result = InterpretFormula( ev, pars.begin(), pars.end() );
453 estimator[2] += deviation * ev->
GetWeight();
456 estimator[2] /= sumOfWeights[2];
460 }
else if( DoMulticlass() ){
461 for (
UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
465 CalculateMulticlassValues( ev, pars, *fMulticlassReturnVal );
468 for(
Int_t dim = 0; dim < fOutputDimensions; ++dim ){
469 Double_t y = fMulticlassReturnVal->at(dim);
475 estimator[2] /= sumOfWeights[2];
480 for (
UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
484 desired = (DataInfo().IsSignal(ev) ? 1.0 : 0.0);
485 result = InterpretFormula( ev, pars.begin(), pars.end() );
489 estimator[0] /= sumOfWeights[0];
490 estimator[1] /= sumOfWeights[1];
492 return estimator[0] + estimator[1];
503 for( std::vector<Double_t>::iterator it = parBegin; it != parEnd; ++it ){
505 fFormula->SetParameter( ipar, (*it) );
508 for (
UInt_t ivar=0; ivar<GetNvar(); ivar++) fFormula->SetParameter( ivar+ipar, event->
GetValue(ivar) );
510 Double_t result = fFormula->Eval( 0 );
520 const Event* ev = GetEvent();
523 NoErrorCalc(err, errUpper);
525 return InterpretFormula( ev, fBestPars.begin(), fBestPars.end() );
532 if (fRegressionReturnVal == NULL) fRegressionReturnVal =
new std::vector<Float_t>();
533 fRegressionReturnVal->clear();
535 const Event* ev = GetEvent();
539 for(
Int_t dim = 0; dim < fOutputDimensions; ++dim ){
540 Int_t offset = dim*fNPars;
541 evT->
SetTarget(dim,InterpretFormula( ev, fBestPars.begin()+offset, fBestPars.begin()+offset+fNPars ) );
543 const Event* evT2 = GetTransformationHandler().InverseTransform( evT );
544 fRegressionReturnVal->push_back(evT2->
GetTarget(0));
548 return (*fRegressionReturnVal);
555 if (fMulticlassReturnVal == NULL) fMulticlassReturnVal =
new std::vector<Float_t>();
556 fMulticlassReturnVal->clear();
557 std::vector<Float_t> temp;
562 CalculateMulticlassValues( evt, fBestPars, temp );
564 UInt_t nClasses = DataInfo().GetNClasses();
565 for(
UInt_t iClass=0; iClass<nClasses; iClass++){
567 for(
UInt_t j=0;j<nClasses;j++){
569 norm+=
exp(temp[j]-temp[iClass]);
571 (*fMulticlassReturnVal).push_back(1.0/(1.0+norm));
574 return (*fMulticlassReturnVal);
592 for(
Int_t dim = 0; dim < fOutputDimensions; ++dim ){
593 Int_t offset = dim*fNPars;
594 Double_t value = InterpretFormula( evt, parameters.begin()+offset, parameters.begin()+offset+fNPars );
596 values.push_back( value );
611 fBestPars.resize( fNPars );
612 for (
UInt_t ipar=0; ipar<fNPars; ipar++) istr >> fBestPars[ipar];
624 for (
UInt_t ipar=0; ipar<fNPars*fOutputDimensions; ipar++) {
641 if(
gTools().HasAttr( wghtnode,
"NDim")) {
645 fOutputDimensions = 1;
649 fBestPars.resize( fNPars*fOutputDimensions );
659 if (ipar >= fNPars*fOutputDimensions)
Log() << kFATAL <<
"<ReadWeightsFromXML> index out of range: "
660 << ipar <<
" >= " << fNPars <<
Endl;
661 fBestPars[ipar] = par;
678 fout <<
" double fParameter[" << fNPars <<
"];" << std::endl;
679 fout <<
"};" << std::endl;
680 fout <<
"" << std::endl;
681 fout <<
"inline void " << className <<
"::Initialize() " << std::endl;
682 fout <<
"{" << std::endl;
683 for(
UInt_t ipar=0; ipar<fNPars; ipar++) {
684 fout <<
" fParameter[" << ipar <<
"] = " << fBestPars[ipar] <<
";" << std::endl;
686 fout <<
"}" << std::endl;
688 fout <<
"inline double " << className <<
"::GetMvaValue__( const std::vector<double>& inputValues ) const" << std::endl;
689 fout <<
"{" << std::endl;
690 fout <<
" // interpret the formula" << std::endl;
694 for (
UInt_t ipar=0; ipar<fNPars; ipar++) {
699 for (
UInt_t ivar=0; ivar<GetNvar(); ivar++) {
703 fout <<
" double retval = " << str <<
";" << std::endl;
705 fout <<
" return retval; " << std::endl;
706 fout <<
"}" << std::endl;
708 fout <<
"// Clean up" << std::endl;
709 fout <<
"inline void " << className <<
"::Clear() " << std::endl;
710 fout <<
"{" << std::endl;
711 fout <<
" // nothing to clear" << std::endl;
712 fout <<
"}" << std::endl;
726 Log() <<
"The function discriminant analysis (FDA) is a classifier suitable " <<
Endl;
727 Log() <<
"to solve linear or simple nonlinear discrimination problems." <<
Endl;
729 Log() <<
"The user provides the desired function with adjustable parameters" <<
Endl;
730 Log() <<
"via the configuration option string, and FDA fits the parameters to" <<
Endl;
731 Log() <<
"it, requiring the signal (background) function value to be as close" <<
Endl;
732 Log() <<
"as possible to 1 (0). Its advantage over the more involved and" <<
Endl;
733 Log() <<
"automatic nonlinear discriminators is the simplicity and transparency " <<
Endl;
734 Log() <<
"of the discrimination expression. A shortcoming is that FDA will" <<
Endl;
735 Log() <<
"underperform for involved problems with complicated, phase space" <<
Endl;
736 Log() <<
"dependent nonlinear correlations." <<
Endl;
738 Log() <<
"Please consult the Users Guide for the format of the formula string" <<
Endl;
739 Log() <<
"and the allowed parameter ranges:" <<
Endl;
740 if (
gConfig().WriteOptionsReference()) {
741 Log() <<
"<a href=\"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf\">"
742 <<
"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf</a>" <<
Endl;
744 else Log() <<
"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf" <<
Endl;
748 Log() <<
"The FDA performance depends on the complexity and fidelity of the" <<
Endl;
749 Log() <<
"user-defined discriminator function. As a general rule, it should" <<
Endl;
750 Log() <<
"be able to reproduce the discrimination power of any linear" <<
Endl;
751 Log() <<
"discriminant analysis. To reach into the nonlinear domain, it is" <<
Endl;
752 Log() <<
"useful to inspect the correlation profiles of the input variables," <<
Endl;
753 Log() <<
"and add quadratic and higher polynomial terms between variables as" <<
Endl;
754 Log() <<
"necessary. Comparison with more involved nonlinear classifiers can" <<
Endl;
755 Log() <<
"be used as a guide." <<
Endl;
759 Log() <<
"Depending on the function used, the choice of \"FitMethod\" is" <<
Endl;
760 Log() <<
"crucial for getting valuable solutions with FDA. As a guideline it" <<
Endl;
761 Log() <<
"is recommended to start with \"FitMethod=MINUIT\". When more complex" <<
Endl;
762 Log() <<
"functions are used where MINUIT does not converge to reasonable" <<
Endl;
763 Log() <<
"results, the user should switch to non-gradient FitMethods such" <<
Endl;
764 Log() <<
"as GeneticAlgorithm (GA) or Monte Carlo (MC). It might prove to be" <<
Endl;
765 Log() <<
"useful to combine GA (or MC) with MINUIT by setting the option" <<
Endl;
766 Log() <<
"\"Converger=MINUIT\". GA (MC) will then set the starting parameters" <<
Endl;
767 Log() <<
"for MINUIT such that the basic quality of GA (MC) of finding global" <<
Endl;
768 Log() <<
"minima is combined with the efficacy of MINUIT of finding local" <<
Endl;
#define REGISTER_METHOD(CLASS)
for example
char * Form(const char *fmt,...)
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Class that contains all the data information.
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
void SetTarget(UInt_t itgt, Float_t value)
set the target value (dimension itgt) to value
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Float_t GetTarget(UInt_t itgt) const
Fitter using a Genetic Algorithm.
Interface for a fitter 'target'.
The TMVA::Interval Class.
Fitter using Monte Carlo sampling of parameters.
Virtual base Class for all MVA method.
Function discriminant analysis (FDA).
void Train(void)
FDA training.
void AddWeightsXMLTo(void *parent) const
create XML description for LD classification and regression (for arbitrary number of output classes/t...
Double_t EstimatorFunction(std::vector< Double_t > &)
compute estimator for given parameter set (to be minimised)
virtual ~MethodFDA(void)
destructor
Double_t InterpretFormula(const Event *, std::vector< Double_t >::iterator begin, std::vector< Double_t >::iterator end)
formula interpretation
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
FDA can handle classification with 2 classes and regression with one regression-target.
void ReadWeightsFromXML(void *wghtnode)
read coefficients from xml weight file
void CalculateMulticlassValues(const TMVA::Event *&evt, std::vector< Double_t > ¶meters, std::vector< Float_t > &values)
calculate the values for multiclass
void ReadWeightsFromStream(std::istream &i)
read back the training results from a file (stream)
virtual const std::vector< Float_t > & GetMulticlassValues()
MethodFDA(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
standard constructor
void Init(void)
default initialisation
void ClearAll()
delete and clear all class members
void PrintResults(const TString &, std::vector< Double_t > &, const Double_t) const
display fit parameters check maximum length of variable name
void MakeClassSpecific(std::ostream &, const TString &) const
write FDA-specific classifier response
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns MVA value for given event
virtual const std::vector< Float_t > & GetRegressionValues()
void ProcessOptions()
the option string is decoded, for available options see "DeclareOptions"
void CreateFormula()
translate formula string into TFormula, and parameter string into par ranges
void DeclareOptions()
define the options (their key words) that can be set in the option string
void GetHelpMessage() const
get help message text
Fitter using a Simulated Annealing Algorithm.
Singleton class for Global types used by TMVA.
Collectable string class.
Ssiz_t First(char c) const
Find first occurrence of a character c.
const char * Data() const
TString & ReplaceAll(const TString &s1, const TString &s2)
std::string GetName(const std::string &scope_name)
double crossEntropy(ItProbability itProbabilityBegin, ItProbability itProbabilityEnd, ItTruth itTruthBegin, ItTruth itTruthEnd, ItDelta itDelta, ItDelta itDeltaEnd, ItInvActFnc itInvActFnc, double patternWeight)
cross entropy error function
MsgLogger & Endl(MsgLogger &ml)
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
static long int sum(long int i)