92 return fROCCurves.get();
133 for(
auto &
roc : fROCs) {
136 return avg/fROCs.size();
145 for(
auto &
roc : fROCs) {
158 fLogger << kHEADER <<
" ==== Results ====" <<
Endl;
159 for(
auto &item:fROCs) {
160 fLogger << kINFO <<
TString::Format(
"Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
163 fLogger << kINFO <<
"------------------------" <<
Endl;
165 fLogger << kINFO <<
TString::Format(
"Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) <<
Endl;
174 fROCCurves->Draw(
"AL");
175 fROCCurves->GetXaxis()->SetTitle(
" Signal Efficiency ");
176 fROCCurves->GetYaxis()->SetTitle(
" Background Rejection ");
179 c->SetTitle(
"Cross Validation ROC Curves");
193 for (
auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
212 title =
"Cross Validation Average ROC Curve";
215 rocs->SetName(
"cv_rocs");
216 rocs->SetTitle(title);
217 rocs->GetXaxis()->SetTitle(
"Signal Efficiency");
218 rocs->GetYaxis()->SetTitle(
"Background Rejection");
219 rocs->DrawClone(
"AL");
228 "Avg ROC Curve",
"l");
230 "Fold ROC Curves",
"l");
237 c->SetTitle(
"Cross Validation Average ROC Curve");
280 fAnalysisType(
Types::kMaxAnalysisType),
281 fAnalysisTypeStr(
"Auto"),
282 fSplitTypeStr(
"Random"),
284 fCvFactoryOptions(
""),
291 fOutputFactoryOptions(
""),
294 fSplitExprString(
""),
296 fTransformations(
""),
324 DeclareOptionRef(fSilent,
"Silent",
325 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
326 "class object (default: False)");
327 DeclareOptionRef(fVerbose,
"V",
"Verbose flag");
328 DeclareOptionRef(fVerboseLevel =
TString(
"Info"),
"VerboseLevel",
"VerboseLevel (Debug/Verbose/Info)");
329 AddPreDefVal(
TString(
"Debug"));
330 AddPreDefVal(
TString(
"Verbose"));
333 DeclareOptionRef(fTransformations,
"Transformations",
334 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
335 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
338 DeclareOptionRef(fDrawProgressBar,
"DrawProgressBar",
"Boolean to show draw progress bar");
339 DeclareOptionRef(fCorrelations,
"Correlations",
"Boolean to show correlation in output");
340 DeclareOptionRef(fROC,
"ROC",
"Boolean to show ROC in output");
343 DeclareOptionRef(fAnalysisTypeStr,
"AnalysisType",
344 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
345 AddPreDefVal(
TString(
"Classification"));
346 AddPreDefVal(
TString(
"Regression"));
347 AddPreDefVal(
TString(
"Multiclass"));
351 DeclareOptionRef(fSplitTypeStr,
"SplitType",
352 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
353 AddPreDefVal(
TString(
"Deterministic"));
354 AddPreDefVal(
TString(
"Random"));
355 AddPreDefVal(
TString(
"RandomStratified"));
357 DeclareOptionRef(fSplitExprString,
"SplitExpr",
"The expression used to assign events to folds");
358 DeclareOptionRef(fNumFolds,
"NumFolds",
"Number of folds to generate");
359 DeclareOptionRef(fNumWorkerProcs,
"NumWorkerProcs",
360 "Determines how many processes to use for evaluation. 1 means no"
361 " parallelisation. 2 means use 2 processes. 0 means figure out the"
362 " number automatically based on the number of cpus available. Default"
365 DeclareOptionRef(fFoldFileOutput,
"FoldFileOutput",
366 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
367 "specifed for the combined output with a _foldX suffix. (default: false)");
369 DeclareOptionRef(fOutputEnsembling =
TString(
"None"),
"OutputEnsembling",
370 "Combines output from contained methods. If None, no combination is performed. (default None)");
382 if (fSplitTypeStr !=
"Deterministic" && fSplitExprString !=
"") {
383 Log() << kFATAL <<
"SplitExpr can only be used with Deterministic Splitting" <<
Endl;
387 fAnalysisTypeStr.ToLower();
388 if (fAnalysisTypeStr ==
"classification") {
390 }
else if (fAnalysisTypeStr ==
"regression") {
392 }
else if (fAnalysisTypeStr ==
"multiclass") {
394 }
else if (fAnalysisTypeStr ==
"auto") {
399 fCvFactoryOptions +=
"V:";
400 fOutputFactoryOptions +=
"V:";
402 fCvFactoryOptions +=
"!V:";
403 fOutputFactoryOptions +=
"!V:";
406 fCvFactoryOptions +=
TString::Format(
"VerboseLevel=%s:", fVerboseLevel.Data());
407 fOutputFactoryOptions +=
TString::Format(
"VerboseLevel=%s:", fVerboseLevel.Data());
409 fCvFactoryOptions +=
TString::Format(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
410 fOutputFactoryOptions +=
TString::Format(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
412 if (!fDrawProgressBar) {
413 fCvFactoryOptions +=
"!DrawProgressBar:";
414 fOutputFactoryOptions +=
"!DrawProgressBar:";
417 if (fTransformations !=
"") {
418 fCvFactoryOptions +=
TString::Format(
"Transformations=%s:", fTransformations.Data());
419 fOutputFactoryOptions +=
TString::Format(
"Transformations=%s:", fTransformations.Data());
423 fCvFactoryOptions +=
"Correlations:";
424 fOutputFactoryOptions +=
"Correlations:";
426 fCvFactoryOptions +=
"!Correlations:";
427 fOutputFactoryOptions +=
"!Correlations:";
431 fCvFactoryOptions +=
"ROC:";
432 fOutputFactoryOptions +=
"ROC:";
434 fCvFactoryOptions +=
"!ROC:";
435 fOutputFactoryOptions +=
"!ROC:";
439 fCvFactoryOptions +=
"Silent:";
440 fOutputFactoryOptions +=
"Silent:";
444 if (fFoldFileOutput && fOutputFile ==
nullptr) {
445 Log() << kFATAL <<
"No output file given, cannot generate per fold output." <<
Endl;
450 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
455 if (fOutputFile ==
nullptr) {
456 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
458 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
461 if(fSplitTypeStr ==
"Random"){
462 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString,
kFALSE));
463 }
else if(fSplitTypeStr ==
"RandomStratified"){
464 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString,
kTRUE));
466 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString));
474 if (i != fNumFolds) {
476 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
477 fDataLoader->MakeKFoldDataSet(*fSplit);
489 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
490 fDataLoader->MakeKFoldDataSet(*fSplit);
513 Log() << kDEBUG <<
"Processing " << methodTitle <<
" fold " <<
iFold <<
Endl;
518 if (fFoldFileOutput && fOutputFile !=
nullptr) {
521 Log() << kINFO <<
"Creating fold output at:" << path <<
Endl;
522 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName,
foldOutputFile, fCvFactoryOptions);
533 fFoldFactory->TestAllMethods();
534 fFoldFactory->EvaluateAllMethods();
540 result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(),
foldTitle);
557 result.fTrainEff01 =
smethod->GetTrainingEfficiency(
"Efficiency:0.01");
558 result.fTrainEff10 =
smethod->GetTrainingEfficiency(
"Efficiency:0.10");
559 result.fTrainEff30 =
smethod->GetTrainingEfficiency(
"Efficiency:0.30");
576 fFoldFactory->DeleteAllMethods();
577 fFoldFactory->fMethodsMap.clear();
591 fDataLoader->MakeKFoldDataSet(*fSplit);
595 fResults.reserve(fMethods.size());
603 Log() << kFATAL <<
"No method booked for cross-validation" <<
Endl;
607 Log() << kINFO <<
Endl;
608 Log() << kINFO <<
Endl;
609 Log() << kINFO <<
"========================================" <<
Endl;
610 Log() << kINFO <<
"Processing folds for method " << methodTitle <<
Endl;
611 Log() << kINFO <<
"========================================" <<
Endl;
612 Log() << kINFO <<
Endl;
642 fResults.push_back(
result);
647 ":EncapsulatedMethodName=%s"
648 ":EncapsulatedMethodTypeName=%s"
649 ":OutputEnsembling=%s",
650 fSplitExprString.Data(), fNumFolds, methodTitle.Data(),
methodTypeName.Data(), fOutputEnsembling.Data());
659 method->fEventToFoldMapping = fSplit->fEventToFoldMapping;
662 Log() << kINFO <<
Endl;
663 Log() << kINFO <<
Endl;
664 Log() << kINFO <<
"========================================" <<
Endl;
665 Log() << kINFO <<
"Folds processed for all methods, evaluating." <<
Endl;
666 Log() << kINFO <<
"========================================" <<
Endl;
667 Log() << kINFO <<
Endl;
670 fDataLoader->RecombineKFoldDataSet(*fSplit);
680 if (fOutputFile !=
nullptr) {
681 fFactory->WriteDataInformation(
method->fDataSetInfo);
690 fFactory->TestAllMethods();
693 fFactory->EvaluateAllMethods();
695 Log() << kINFO <<
"Evaluation done." <<
Endl;
701 if (fResults.empty()) {
702 Log() << kFATAL <<
"No cross-validation results available" <<
Endl;
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
R__EXTERN TSystem * gSystem
This class provides a simple interface to execute the same task multiple times in parallel,...
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
virtual void SetLineColor(Color_t lcolor)
Set the line color.
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
A TGraph is an object made of two arrays X and Y with npoints each.
void SetTitle(const char *title="") override
Change (i.e.
This class displays a legend box (TPaveText) containing several legend entries.
UInt_t GetNumWorkers() const
void CheckForUnusedOptions() const
checks for unused options in option string
Class to save the results of cross validation, the metric for the classification ins ROC and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
Float_t GetROCAverage() const
std::vector< Double_t > fEffAreas
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a multigraph that contains an average ROC Curve.
TCanvas * Draw(const TString name="CrossValidation") const
Class to perform cross validation, splitting the dataloader into folds.
void SetNumFolds(UInt_t i)
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
void SetSplitExpr(TString splitExpr)
void Evaluate()
Does training, test set evaluation and performance evaluation of using cross-evalution.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
virtual void ParseOptions()
Method to parse the internal option string.
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Interface for all concrete MVA method implementations.
Virtual base Class for all MVA method.
ostringstream derivative to redirect and format output
static void EnableOutput()
class to storage options for the differents methods
Singleton class for Global types used by TMVA.
A TMultiGraph is a collection of TGraph (or derived) objects.
TObject * Clone(const char *newname="") const override
Make a clone of an object using the Streamer facility.
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Double_t Sqrt(Double_t x)
Returns the square root of x.
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Returns x raised to the power y.