115 fModelPersistence(
kTRUE)
147 DeclareOptionRef(color,
"Color",
"Flag for coloured screen output (default: True, if in batch mode: False)");
150 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, "
151 "decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations");
155 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
156 "class object (default: False)");
158 "Draw progress bar to display training, testing and evaluation schedule (default: True)");
160 "Option to save the trained model in xml file or using serialization");
164 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
188 if (analysisType ==
"classification")
190 else if (analysisType ==
"regression")
192 else if (analysisType ==
"multiclass")
194 else if (analysisType ==
"auto")
237 DeclareOptionRef(color,
"Color",
"Flag for coloured screen output (default: True, if in batch mode: False)");
240 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, "
241 "decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations");
245 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
246 "class object (default: False)");
248 "Draw progress bar to display training, testing and evaluation schedule (default: True)");
250 "Option to save the trained model in xml file or using serialization");
254 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
278 if (analysisType ==
"classification")
280 else if (analysisType ==
"regression")
282 else if (analysisType ==
"multiclass")
284 else if (analysisType ==
"auto")
307 std::vector<TMVA::VariableTransformBase *>::iterator
trfIt = fDefaultTrfs.
begin();
311 this->DeleteAllMethods();
325 std::map<TString, MVector *>::iterator
itrMap;
332 Log() << kDEBUG <<
"Delete method: " << (*itrMethod)->GetName() <<
Endl;
360 if (fModelPersistence)
366 if (
loader->GetDataSetInfo().GetNClasses() == 2 &&
loader->GetDataSetInfo().GetClassInfo(
"Signal") !=
NULL &&
367 loader->GetDataSetInfo().GetClassInfo(
"Background") !=
NULL) {
369 }
else if (
loader->GetDataSetInfo().GetNClasses() >= 2) {
372 Log() << kFATAL <<
"No analysis type for " <<
loader->GetDataSetInfo().GetNClasses() <<
" classes and "
373 <<
loader->GetDataSetInfo().GetNTargets() <<
" regression targets." <<
Endl;
379 if (fMethodsMap.find(
datasetname) != fMethodsMap.end()) {
381 Log() << kFATAL <<
"Booking failed since method with title <" << methodTitle <<
"> already exists "
382 <<
"in with DataSet Name <" <<
loader->GetName() <<
"> " <<
Endl;
386 Log() << kHEADER <<
"Booking method: " <<
gTools().
Color(
"bold")
394 conf->DeclareOptionRef(
boostNum = 0,
"Boost_num",
"Number of times the classifier will be boosted");
395 conf->ParseOptions();
399 if (fModelPersistence) {
416 Log() << kDEBUG <<
"Boost Number is " <<
boostNum <<
" > 0: train boosted classifier" <<
Endl;
420 Log() << kFATAL <<
"Method with type kBoost cannot be casted to MethodCategory. /Factory" <<
Endl;
423 if (fModelPersistence)
425 methBoost->SetModelPersistence(fModelPersistence);
427 methBoost->fDataSetManager =
loader->GetDataSetInfo().GetDataSetManager();
429 methBoost->SetSilentFile(IsSilentFile());
440 Log() << kFATAL <<
"Method with type kCategory cannot be casted to MethodCategory. /Factory"
444 if (fModelPersistence)
446 methCat->SetModelPersistence(fModelPersistence);
447 methCat->fDataSetManager =
loader->GetDataSetInfo().GetDataSetManager();
448 methCat->SetFile(fgTargetFile);
449 methCat->SetSilentFile(IsSilentFile());
452 if (!
method->HasAnalysisType(fAnalysisType,
loader->GetDataSetInfo().GetNClasses(),
453 loader->GetDataSetInfo().GetNTargets())) {
454 Log() << kWARNING <<
"Method " <<
method->GetMethodTypeName() <<
" is not capable of handling ";
456 Log() <<
"regression with " <<
loader->GetDataSetInfo().GetNTargets() <<
" targets." <<
Endl;
458 Log() <<
"multiclass classification with " <<
loader->GetDataSetInfo().GetNClasses() <<
" classes." <<
Endl;
460 Log() <<
"classification with " <<
loader->GetDataSetInfo().GetNClasses() <<
" classes." <<
Endl;
465 if (fModelPersistence)
467 method->SetModelPersistence(fModelPersistence);
468 method->SetAnalysisType(fAnalysisType);
472 method->SetFile(fgTargetFile);
473 method->SetSilentFile(IsSilentFile());
478 if (fMethodsMap.find(
datasetname) == fMethodsMap.end()) {
508 Log() << kERROR <<
"Cannot handle category methods for now." <<
Endl;
512 if (fModelPersistence) {
523 if (fModelPersistence)
525 method->SetModelPersistence(fModelPersistence);
526 method->SetAnalysisType(fAnalysisType);
528 method->SetFile(fgTargetFile);
529 method->SetSilentFile(IsSilentFile());
531 method->DeclareCompatibilityOptions();
534 method->ReadStateFromFile();
540 Log() << kFATAL <<
"Booking failed since method with title <" << methodTitle <<
"> already exists "
541 <<
"in with DataSet Name <" <<
loader->GetName() <<
"> " <<
Endl;
544 Log() << kINFO <<
"Booked classifier \"" <<
method->GetMethodName() <<
"\" of type: \""
545 <<
method->GetMethodTypeName() <<
"\"" <<
Endl;
562 if (fMethodsMap.find(
datasetname) == fMethodsMap.end())
571 if ((
mva->GetMethodName()) == methodTitle)
582 if (fMethodsMap.find(
datasetname) == fMethodsMap.end())
585 std::string methodName = methodTitle.
Data();
600 if (!RootBaseDir()->GetDirectory(fDataSetInfo.
GetName()))
601 RootBaseDir()->mkdir(fDataSetInfo.
GetName());
605 RootBaseDir()->cd(fDataSetInfo.
GetName());
654 std::vector<TMVA::TransformationHandler *>
trfs;
664 Log() << kDEBUG <<
"current transformation string: '" <<
trfS.Data() <<
"'" <<
Endl;
667 if (
trfS.BeginsWith(
'I'))
674 std::vector<TMVA::TransformationHandler *>::iterator
trfIt =
trfs.
begin();
678 (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.
GetName()));
698 std::map<TString, MVector *>::iterator
itrMap;
710 Log() << kFATAL <<
"Dynamic cast to MethodBase failed" <<
Endl;
715 Log() << kWARNING <<
"Method " <<
mva->GetMethodName() <<
" not trained (training tree has less entries ["
720 Log() << kINFO <<
"Optimize method: " <<
mva->GetMethodName() <<
" for "
723 : (fAnalysisType ==
Types::kMulticlass ?
"Multiclass classification" :
"Classification"))
727 Log() << kINFO <<
"Optimization of tuning parameters finished for Method:" <<
mva->GetName() <<
Endl;
758 if (fMethodsMap.find(
datasetname) == fMethodsMap.end()) {
759 Log() << kERROR <<
Form(
"DataSet = %s not found in methods map.",
datasetname.Data()) <<
Endl;
771 Log() << kERROR <<
Form(
"Can only generate ROC curves for analysis type kClassification and kMulticlass.")
778 dataset->SetCurrentType(
type);
784 <<
Form(
"Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
iClass,
806 std::vector<Float_t>
mvaRes;
858 if (fMethodsMap.find(
datasetname) == fMethodsMap.end()) {
859 Log() << kERROR <<
Form(
"DataSet = %s not found in methods map.",
datasetname.Data()) <<
Endl;
871 Log() << kERROR <<
Form(
"Can only generate ROC integral for analysis type kClassification. and kMulticlass.")
879 <<
Form(
"ROCCurve object was not created in Method = %s not found with Dataset = %s ",
theMethodName.Data(),
929 if (fMethodsMap.find(
datasetname) == fMethodsMap.end()) {
930 Log() << kERROR <<
Form(
"DataSet = %s not found in methods map.",
datasetname.Data()) <<
Endl;
942 Log() << kERROR <<
Form(
"Can only generate ROC curves for analysis type kClassification and kMulticlass.")
952 <<
Form(
"ROCCurve object was not created in Method = %s not found with Dataset = %s ",
theMethodName.Data(),
1017 <<
Form(
"Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
iClass,
1035 if (
multigraph->GetListOfGraphs() ==
nullptr) {
1036 Log() << kERROR <<
Form(
"No metohds have class %i defined.",
iClass) <<
Endl;
1073 if (fMethodsMap.find(
datasetname) == fMethodsMap.end()) {
1074 Log() << kERROR <<
Form(
"DataSet = %s not found in methods map.",
datasetname.Data()) <<
Endl;
1087 multigraph->GetYaxis()->SetTitle(
"Background rejection (Specificity)");
1088 multigraph->GetXaxis()->SetTitle(
"Signal efficiency (Sensitivity)");
1099 canvas->
BuildLegend(0.15, 0.15, 0.35, 0.3,
"MVA Method");
1114 if (fMethodsMap.empty()) {
1115 Log() << kINFO <<
"...nothing found to train" <<
Endl;
1121 Log() << kDEBUG <<
"Train all methods for "
1127 std::map<TString, MVector *>::iterator
itrMap;
1141 if (
mva->DataInfo().GetDataSetManager()->DataInput().GetEntries() <=
1143 Log() << kFATAL <<
"No input data for the training provided!" <<
Endl;
1147 Log() << kFATAL <<
"You want to do regression training without specifying a target." <<
Endl;
1149 mva->DataInfo().GetNClasses() < 2)
1150 Log() << kFATAL <<
"You want to do classification training, but specified less than two classes." <<
Endl;
1153 if (!IsSilentFile())
1154 WriteDataInformation(
mva->fDataSetInfo);
1157 Log() << kWARNING <<
"Method " <<
mva->GetMethodName() <<
" not trained (training tree has less entries ["
1162 Log() << kHEADER <<
"Train method: " <<
mva->GetMethodName() <<
" for "
1165 : (fAnalysisType ==
Types::kMulticlass ?
"Multiclass classification" :
"Classification"))
1168 Log() << kHEADER <<
"Training finished" <<
Endl <<
Endl;
1175 Log() << kINFO <<
"Ranking input variables (method specific)..." <<
Endl;
1185 Log() << kINFO <<
"No variable ranking supplied by classifier: "
1192 if (!IsSilentFile()) {
1198 m->fTrainHistory.SaveHistory(
m->GetMethodName());
1206 if (fModelPersistence) {
1208 Log() << kHEADER <<
"=== Destroy and recreate all methods via weight files for testing ===" <<
Endl <<
Endl;
1210 if (!IsSilentFile())
1211 RootBaseDir()->cd();
1237 Log() << kFATAL <<
"Method with type kCategory cannot be casted to MethodCategory. /Factory" <<
Endl;
1239 methCat->fDataSetManager =
m->DataInfo().GetDataSetManager();
1246 m->SetModelPersistence(fModelPersistence);
1247 m->SetSilentFile(IsSilentFile());
1248 m->SetAnalysisType(fAnalysisType);
1250 m->ReadStateFromFile();
1270 if (fMethodsMap.empty()) {
1271 Log() << kINFO <<
"...nothing found to test" <<
Endl;
1274 std::map<TString, MVector *>::iterator
itrMap;
1287 Log() << kHEADER <<
"Test method: " <<
mva->GetMethodName() <<
" for "
1290 : (analysisType ==
Types::kMulticlass ?
"Multiclass classification" :
"Classification"))
1301 if (methodTitle !=
"") {
1306 Log() << kWARNING <<
"<MakeClass> Could not find classifier \"" << methodTitle <<
"\" in list" <<
Endl;
1317 Log() << kINFO <<
"Make response class for classifier: " <<
method->GetMethodName() <<
Endl;
1329 if (methodTitle !=
"") {
1332 method->PrintHelpMessage();
1334 Log() << kWARNING <<
"<PrintHelpMessage> Could not find classifier \"" << methodTitle <<
"\" in list" <<
Endl;
1345 Log() << kINFO <<
"Print help message for classifier: " <<
method->GetMethodName() <<
Endl;
1346 method->PrintHelpMessage();
1356 Log() << kINFO <<
"Evaluating all variables..." <<
Endl;
1359 for (
UInt_t i = 0; i <
loader->GetDataSetInfo().GetNVariables(); i++) {
1360 TString s =
loader->GetDataSetInfo().GetVariableInfo(i).GetLabel();
1363 this->BookMethod(
loader,
"Variable", s);
1375 if (fMethodsMap.empty()) {
1376 Log() << kINFO <<
"...nothing found to evaluate" <<
Endl;
1379 std::map<TString, MVector *>::iterator
itrMap;
1394 std::vector<std::vector<TString>>
mname(2);
1395 std::vector<std::vector<Double_t>> sig(2), sep(2),
roc(2);
1415 std::vector<std::vector<Double_t>>
biastrain(1);
1416 std::vector<std::vector<Double_t>>
biastest(1);
1417 std::vector<std::vector<Double_t>>
devtrain(1);
1418 std::vector<std::vector<Double_t>>
devtest(1);
1419 std::vector<std::vector<Double_t>>
rmstrain(1);
1420 std::vector<std::vector<Double_t>>
rmstest(1);
1421 std::vector<std::vector<Double_t>>
minftrain(1);
1422 std::vector<std::vector<Double_t>>
minftest(1);
1423 std::vector<std::vector<Double_t>>
rhotrain(1);
1424 std::vector<std::vector<Double_t>>
rhotest(1);
1427 std::vector<std::vector<Double_t>>
biastrainT(1);
1428 std::vector<std::vector<Double_t>>
biastestT(1);
1429 std::vector<std::vector<Double_t>>
devtrainT(1);
1430 std::vector<std::vector<Double_t>>
devtestT(1);
1431 std::vector<std::vector<Double_t>>
rmstrainT(1);
1432 std::vector<std::vector<Double_t>>
rmstestT(1);
1433 std::vector<std::vector<Double_t>>
minftrainT(1);
1434 std::vector<std::vector<Double_t>>
minftestT(1);
1449 theMethod->SetSilentFile(IsSilentFile());
1456 Log() << kINFO <<
"Evaluate regression method: " <<
theMethod->GetMethodName() <<
Endl;
1461 Log() << kINFO <<
"TestRegression (testing)" <<
Endl;
1473 Log() << kINFO <<
"TestRegression (training)" <<
Endl;
1487 if (!IsSilentFile()) {
1488 Log() << kDEBUG <<
"\tWrite evaluation histograms to file" <<
Endl;
1497 Log() << kINFO <<
"Evaluate multiclass classification method: " <<
theMethod->GetMethodName() <<
Endl;
1516 if (!IsSilentFile()) {
1517 Log() << kDEBUG <<
"\tWrite evaluation histograms to file" <<
Endl;
1526 Log() << kHEADER <<
"Evaluate classifier: " <<
theMethod->GetMethodName() <<
Endl <<
Endl;
1527 isel = (
theMethod->GetMethodTypeName().Contains(
"Variable")) ? 1 : 0;
1548 theMethod->GetTrainingEfficiency(
"Efficiency:0.01"));
1554 if (!IsSilentFile()) {
1555 Log() << kDEBUG <<
"\tWrite evaluation histograms to file" <<
Endl;
1564 std::vector<std::vector<Double_t>>
vtmp;
1610 for (
Int_t k = 0; k < 2; k++) {
1611 std::vector<std::vector<Double_t>>
vtemp;
1622 vtemp.push_back(sig[k]);
1623 vtemp.push_back(sep[k]);
1650 if (fCorrelations) {
1653 const Int_t nvar =
method->fDataSetInfo.GetNVariables();
1660 std::vector<Double_t>
rvec;
1668 std::vector<TString> *
theVars =
new std::vector<TString>;
1669 std::vector<ResultsClassification *>
mvaRes;
1675 theVars->push_back(
m->GetTestvarName());
1676 rvec.push_back(
m->GetSignalReferenceCut());
1677 theVars->back().ReplaceAll(
"MVA_",
"");
1700 Log() << kWARNING <<
"Found NaN return value in event: " <<
ievt <<
" for method \""
1708 if (
method->fDataSetInfo.IsSignal(
ev)) {
1720 (*theMat)(
im,
jm)++;
1722 (*theMat)(
jm,
im)++;
1729 (*overlapS) *= (1.0 /
defDs->GetNEvtSigTest());
1730 (*overlapB) *= (1.0 /
defDs->GetNEvtBkgdTest());
1732 tpSig->MakePrincipals();
1733 tpBkg->MakePrincipals();
1767 Log() << kINFO <<
Endl;
1768 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1769 <<
"Inter-MVA correlation matrix (signal):" <<
Endl;
1771 Log() << kINFO <<
Endl;
1773 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1774 <<
"Inter-MVA correlation matrix (background):" <<
Endl;
1776 Log() << kINFO <<
Endl;
1779 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1780 <<
"Correlations between input variables and MVA response (signal):" <<
Endl;
1782 Log() << kINFO <<
Endl;
1784 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1785 <<
"Correlations between input variables and MVA response (background):" <<
Endl;
1787 Log() << kINFO <<
Endl;
1789 Log() << kWARNING <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1790 <<
"<TestAllMethods> cannot compute correlation matrices" <<
Endl;
1793 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1794 <<
"The following \"overlap\" matrices contain the fraction of events for which " <<
Endl;
1795 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1796 <<
"the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" <<
Endl;
1797 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1798 <<
"An event is signal-like, if its MVA output exceeds the following value:" <<
Endl;
1800 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1801 <<
"which correspond to the working point: eff(signal) = 1 - eff(background)" <<
Endl;
1805 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1806 <<
"Note: no correlations and overlap with cut method are provided at present" <<
Endl;
1809 Log() << kINFO <<
Endl;
1810 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1811 <<
"Inter-MVA overlap matrix (signal):" <<
Endl;
1813 Log() << kINFO <<
Endl;
1815 Log() << kINFO <<
Form(
"Dataset[%s] : ",
method->fDataSetInfo.GetName())
1816 <<
"Inter-MVA overlap matrix (background):" <<
Endl;
1839 Log() << kINFO <<
Endl;
1841 "--------------------------------------------------------------------------------------------------";
1842 Log() << kINFO <<
"Evaluation results ranked by smallest RMS on test sample:" <<
Endl;
1843 Log() << kINFO <<
"(\"Bias\" quotes the mean deviation of the regression from true target." <<
Endl;
1844 Log() << kINFO <<
" \"MutInf\" is the \"Mutual Information\" between regression and target." <<
Endl;
1845 Log() << kINFO <<
" Indicated by \"_T\" are the corresponding \"truncated\" quantities ob-" <<
Endl;
1846 Log() << kINFO <<
" tained when removing events deviating more than 2sigma from average.)" <<
Endl;
1858 <<
Form(
"%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
theMethod->fDataSetInfo.GetName(),
1864 Log() << kINFO <<
Endl;
1865 Log() << kINFO <<
"Evaluation results ranked by smallest RMS on training sample:" <<
Endl;
1866 Log() << kINFO <<
"(overtraining check)" <<
Endl;
1869 <<
"DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T"
1878 <<
Form(
"%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
theMethod->fDataSetInfo.GetName(),
1884 Log() << kINFO <<
Endl;
1891 "-------------------------------------------------------------------------------------------------------";
1931 "Sig eff@B=0.10",
"Sig eff@B=0.30");
1933 "test (train)",
"test (train)");
1934 Log() << kINFO <<
Endl;
1935 Log() << kINFO <<
"1-vs-rest performance metrics per class" <<
Endl;
1937 Log() << kINFO <<
Endl;
1938 Log() << kINFO <<
"Considers the listed class as signal and the other classes" <<
Endl;
1939 Log() << kINFO <<
"as background, reporting the resulting binary performance." <<
Endl;
1940 Log() << kINFO <<
"A score of 0.820 (0.850) means 0.820 was acheived on the" <<
Endl;
1941 Log() << kINFO <<
"test set and 0.850 on the training set." <<
Endl;
1943 Log() << kINFO <<
Endl;
1946 for (
Int_t k = 0; k < 2; k++) {
1949 mname[k][i].ReplaceAll(
"Variable_",
"");
1960 Log() << kINFO <<
Endl;
1962 Log() << kINFO << row <<
Endl;
1963 Log() << kINFO <<
"------------------------------" <<
Endl;
1986 Log() << kINFO << row <<
Endl;
1993 Log() << kINFO <<
Endl;
1995 Log() << kINFO <<
Endl;
2013 stream << kINFO << header <<
Endl;
2029 stream << kINFO <<
Endl;
2033 Log() << kINFO <<
Endl;
2034 Log() << kINFO <<
"Confusion matrices for all methods" <<
Endl;
2036 Log() << kINFO <<
Endl;
2037 Log() << kINFO <<
"Does a binary comparison between the two classes given by a " <<
Endl;
2038 Log() << kINFO <<
"particular row-column combination. In each case, the class " <<
Endl;
2039 Log() << kINFO <<
"given by the row is considered signal while the class given " <<
Endl;
2040 Log() << kINFO <<
"by the column index is considered background." <<
Endl;
2041 Log() << kINFO <<
Endl;
2054 <<
"=== Showing confusion matrix for method : " <<
Form(
"%-15s", (
const char *)
mname[0][
iMethod])
2056 Log() << kINFO <<
"(Signal Efficiency for Background Efficiency 0.01%)" <<
Endl;
2057 Log() << kINFO <<
"---------------------------------------------------" <<
Endl;
2060 Log() << kINFO <<
Endl;
2062 Log() << kINFO <<
"(Signal Efficiency for Background Efficiency 0.10%)" <<
Endl;
2063 Log() << kINFO <<
"---------------------------------------------------" <<
Endl;
2066 Log() << kINFO <<
Endl;
2068 Log() << kINFO <<
"(Signal Efficiency for Background Efficiency 0.30%)" <<
Endl;
2069 Log() << kINFO <<
"---------------------------------------------------" <<
Endl;
2072 Log() << kINFO <<
Endl;
2075 Log() << kINFO <<
Endl;
2080 Log().EnableOutput();
2083 TString hLine =
"------------------------------------------------------------------------------------------"
2084 "-------------------------";
2085 Log() << kINFO <<
"Evaluation results ranked by best signal efficiency and purity (area)" <<
Endl;
2087 Log() << kINFO <<
"DataSet MVA " <<
Endl;
2088 Log() << kINFO <<
"Name: Method: ROC-integ" <<
Endl;
2094 for (
Int_t k = 0; k < 2; k++) {
2097 Log() << kINFO <<
"Input Variables: " <<
Endl <<
hLine <<
Endl;
2122 if (sep[k][i] < 0 || sig[k][i] < 0) {
2151 Log() << kINFO <<
Endl;
2152 Log() << kINFO <<
"Testing efficiency compared to training efficiency (overtraining check)" <<
Endl;
2155 <<
"DataSet MVA Signal efficiency: from test sample (from training sample) "
2157 Log() << kINFO <<
"Name: Method: @B=0.01 @B=0.10 @B=0.30 "
2160 for (
Int_t k = 0; k < 2; k++) {
2163 Log() << kINFO <<
"Input Variables: " <<
Endl <<
hLine <<
Endl;
2167 mname[k][i].ReplaceAll(
"Variable_",
"");
2173 <<
Form(
"%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
2180 Log() << kINFO <<
Endl;
2182 if (
gTools().CheckForSilentOption(GetOptions()))
2183 Log().InhibitOutput();
2186 if (!IsSilentFile()) {
2187 std::list<TString> datasets;
2188 for (
Int_t k = 0; k < 2; k++) {
2194 RootBaseDir()->cd(
theMethod->fDataSetInfo.GetName());
2195 if (std::find(datasets.begin(), datasets.end(),
theMethod->fDataSetInfo.GetName()) == datasets.end()) {
2198 datasets.push_back(
theMethod->fDataSetInfo.GetName());
2214 fModelPersistence =
kFALSE;
2215 fSilentFile =
kTRUE;
2218 const int nbits =
loader->GetDataSetInfo().GetNVariables();
2219 if (
vitype == VIType::kShort)
2221 else if (
vitype == VIType::kAll)
2223 else if (
vitype == VIType::kRandom) {
2227 }
else if (
nbits < 10) {
2228 Log() << kERROR <<
"Error in Variable Importance: Random mode require more that 10 variables in the dataset."
2230 }
else if (
nbits > 30) {
2231 Log() << kERROR <<
"Error in Variable Importance: Number of variables is too large for Random mode"
2248 const int nbits =
loader->GetDataSetInfo().GetNVariables();
2249 std::vector<TString>
varNames =
loader->GetDataSetInfo().GetListOfVariables();
2252 Log() << kERROR <<
"Number of combinations is too large , is 2^" <<
nbits <<
Endl;
2256 Log() << kWARNING <<
"Number of combinations is very large , is 2^" <<
nbits <<
Endl;
2258 uint64_t
range =
static_cast<uint64_t
>(pow(2,
nbits));
2266 for (
int i = 0; i <
nbits; i++)
2286 seedloader->PrepareTrainingAndTestTree(
loader->GetDataSetInfo().GetCut(
"Signal"),
2287 loader->GetDataSetInfo().GetCut(
"Background"),
2288 loader->GetDataSetInfo().GetSplitOptions());
2296 EvaluateAllMethods();
2299 ROC[
x] = GetROCIntegral(
xbitset.to_string(), methodTitle);
2307 this->DeleteAllMethods();
2309 fMethodsMap.clear();
2315 for (uint32_t i = 0; i <
VIBITS; ++i) {
2316 if (
x & (uint64_t(1) << i)) {
2322 uint32_t
ny =
static_cast<uint32_t
>( log(
x -
y) / 0.693147 ) ;
2335 std::cout <<
"--- Variable Importance Results (All)" << std::endl;
2339static uint64_t
sum(uint64_t i)
2342 if (i > 62)
return 0;
2343 return static_cast<uint64_t
>( std::pow(2, i + 1)) - 1;
2359 const int nbits =
loader->GetDataSetInfo().GetNVariables();
2360 std::vector<TString>
varNames =
loader->GetDataSetInfo().GetListOfVariables();
2363 Log() << kERROR <<
"Number of combinations is too large , is 2^" <<
nbits <<
Endl;
2370 for (
int i = 0; i <
nbits; i++)
2379 Log() << kFATAL <<
"Error: need at least one variable.";
2399 EvaluateAllMethods();
2402 SROC = GetROCIntegral(
xbitset.to_string(), methodTitle);
2410 this->DeleteAllMethods();
2411 fMethodsMap.clear();
2415 for (uint32_t i = 0; i <
VIBITS; ++i) {
2417 y =
x & ~(uint64_t(1) << i);
2422 uint32_t
ny =
static_cast<uint32_t
>(log(
x -
y) / 0.693147);
2445 EvaluateAllMethods();
2448 SSROC = GetROCIntegral(
ybitset.to_string(), methodTitle);
2457 this->DeleteAllMethods();
2458 fMethodsMap.clear();
2461 std::cout <<
"--- Variable Importance Results (Short)" << std::endl;
2476 const int nbits =
loader->GetDataSetInfo().GetNVariables();
2477 std::vector<TString>
varNames =
loader->GetDataSetInfo().GetListOfVariables();
2483 for (
int i = 0; i <
nbits; i++)
2512 EvaluateAllMethods();
2515 SROC = GetROCIntegral(
xbitset.to_string(), methodTitle);
2524 this->DeleteAllMethods();
2525 fMethodsMap.clear();
2529 for (uint32_t i = 0; i < 32; ++i) {
2530 if (
x & (uint64_t(1) << i)) {
2560 EvaluateAllMethods();
2563 SSROC = GetROCIntegral(
ybitset.to_string(), methodTitle);
2574 this->DeleteAllMethods();
2575 fMethodsMap.clear();
2579 std::cout <<
"--- Variable Importance Results (Random)" << std::endl;
2592 for (
int i = 0; i <
nbits; i++) {
2603 x_ie[i - 1] = (i - 1) * 1.;
2606 std::cout <<
"--- " <<
varNames[i - 1] <<
" = " <<
roc <<
" %" << std::endl;
2607 vih1->GetXaxis()->SetBinLabel(i,
varNames[i - 1].Data());
2613 vih1->LabelsOption(
"v >",
"X");
2614 vih1->SetBarWidth(0.97);
2619 vih1->GetYaxis()->SetTitle(
"Importance (%)");
2620 vih1->GetYaxis()->SetTitleSize(0.045);
2621 vih1->GetYaxis()->CenterTitle();
2622 vih1->GetYaxis()->SetTitleOffset(1.24);
2624 vih1->GetYaxis()->SetRangeUser(-7, 50);
2625 vih1->SetDirectory(
nullptr);
#define MinNoTrainingEvents
void printMatrix(const TMatrixD &mat)
write a matrix
int Int_t
Signed integer 4 bytes (int)
float Float_t
Float 4 bytes (float)
double Double_t
Double 8 bytes.
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
TMatrixT< Double_t > TMatrixD
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
R__EXTERN TStyle * gStyle
R__EXTERN TSystem * gSystem
const_iterator begin() const
const_iterator end() const
virtual void SetFillColor(Color_t fcolor)
Set the fill area color.
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
virtual void SetLineColor(Color_t lcolor)
Set the line color.
static Int_t GetColor(const char *hexcolor)
Static method returning color number for color specified by hex color string of form: "#rrggbb",...
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
A TGraph is an object made of two arrays X and Y with npoints each.
TAxis * GetXaxis() const
Get x axis of the graph.
TAxis * GetYaxis() const
Get y axis of the graph.
void SetTitle(const char *title="") override
Change (i.e.
1-D histogram with a float per channel (see TH1 documentation)
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Service class for 2-D histogram classes.
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDirPrefix
void SetDrawProgressBar(Bool_t d)
void SetUseColor(Bool_t uc)
class TMVA::Config::VariablePlotting fVariablePlotting
void SetConfigDescription(const char *d)
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void AddPreDefVal(const T &)
void SetConfigName(const char *n)
virtual void ParseOptions()
options parser
const TString & GetOptions() const
MsgLogger * fLogger
! message logger
void CheckForUnusedOptions() const
checks for unused options in option string
Class that contains all the data information.
const TMatrixD * CorrelationMatrix(const TString &className) const
UInt_t GetNClasses() const
DataSet * GetDataSet() const
returns data set
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
const char * GetName() const override
Returns name of object.
ClassInfo * GetClassInfo(Int_t clNum) const
Class that contains all the data information.
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
void PrintHelpMessage(const TString &datasetname, const TString &methodTitle="") const
Print predefined help message of classifier.
Bool_t fCorrelations
! enable to calculate correlations
std::vector< IMethod * > MVector
void TrainAllMethods()
Iterates through all booked methods and calls training.
Bool_t Verbose(void) const
void WriteDataInformation(DataSetInfo &fDataSetInfo)
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
Standard constructor.
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Bool_t fVerbose
! verbose mode
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
TH1F * EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
TH1F * GetImportance(const int nbits, std::vector< Double_t > importances, std::vector< TString > varNames)
Bool_t fROC
! enable to calculate ROC values
void EvaluateAllVariables(DataLoader *loader, TString options="")
Iterates over all MVA input variables and evaluates them.
TString fVerboseLevel
! verbosity level, controls granularity of logging
TMultiGraph * GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass, Types::ETreeType type=Types::kTesting)
Generate a collection of graphs, for all methods for a given class.
TH1F * EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Evaluate Variable Importance.
Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Calculate the integral of the ROC curve, also known as the area under curve (AUC),...
virtual ~Factory()
Destructor.
MethodBase * BookMethod(DataLoader *loader, MethodName theMethodName, TString methodTitle, TString theOption="")
Books an MVA classifier or regression method.
virtual void MakeClass(const TString &datasetname, const TString &methodTitle="") const
MethodBase * BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile)
Adds an already constructed method to be managed by this factory.
Bool_t fModelPersistence
! option to save the trained model in xml file or using serialization
std::map< TString, Double_t > OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
Iterates through all booked methods and sees if they use parameter tuning and if so does just that,...
ROCCurve * GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Private method to generate a ROCCurve instance for a given method.
TH1F * EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Types::EAnalysisType fAnalysisType
! the training type
Bool_t HasMethod(const TString &datasetname, const TString &title) const
Checks whether a given method name is defined for a given dataset.
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
TH1F * EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
void SetVerbose(Bool_t v=kTRUE)
TFile * fgTargetFile
! ROOT output file
IMethod * GetMethod(const TString &datasetname, const TString &title) const
Returns pointer to MVA that corresponds to given method title.
void DeleteAllMethods(void)
Delete methods.
TString fTransformations
! list of transformations to test
void Greetings()
Print welcome message.
Interface for all concrete MVA method implementations.
Virtual base Class for all MVA method.
const TString & GetMethodName() const
Class for boosting a TMVA method.
Class for categorizing the phase space.
ostringstream derivative to redirect and format output
void SetMinType(EMsgType minType)
void SetSource(const std::string &source)
static void InhibitOutput()
Ranking for variables in method (implementation)
Class that is the base-class for a vector of result.
Class which takes the results of a multiclass classification.
Class that is the base-class for a vector of result.
Singleton class for Global types used by TMVA.
static Types & Instance()
The single instance of "Types" if existing already, or create it (Singleton)
A TMultiGraph is a collection of TGraph (or derived) objects.
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
const char * GetName() const override
Returns name of object.
@ kOverwrite
overwrite existing object with same name
virtual const char * GetName() const
Returns name of object.
virtual Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
void SetGrid(Int_t valuex=1, Int_t valuey=1) override
TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="") override
Build a legend from the graphical objects in the pad.
Principal Components Analysis (PCA)
Random number generator class based on M.
void ToLower()
Change string to lower-case.
int CompareTo(const char *cs, ECaseCompare cmp=kExact) const
Compare a string to char *cs2.
const char * Data() const
TString & ReplaceAll(const TString &s1, const TString &s2)
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mod...
void SetTitleXOffset(Float_t offset=1)
virtual int MakeDirectory(const char *name)
Make a directory.
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
static uint64_t sum(uint64_t i)
const Int_t MinNoTrainingEvents