44 fSigs.resize(numFolds);
45 fSeps.resize(numFolds);
80 fSigs[iFold] = fr.
fSig;
81 fSeps[iFold] = fr.
fSep;
82 fEff01s[iFold] = fr.
fEff01;
83 fEff10s[iFold] = fr.
fEff10;
84 fEff30s[iFold] = fr.
fEff30;
94 return fROCCurves.get();
110 Double_t increment = 1.0 / (numSamples-1);
111 std::vector<Double_t>
x(numSamples),
y(numSamples);
113 TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
115 for(
UInt_t iSample = 0; iSample < numSamples; iSample++) {
116 Double_t xPoint = iSample * increment;
119 for(
Int_t iGraph = 0; iGraph < rocCurveList->
GetSize(); iGraph++) {
120 TGraph *foldROC =
static_cast<TGraph *
>(rocCurveList->
At(iGraph));
121 rocSum += foldROC->
Eval(xPoint);
125 y[iSample] = rocSum/rocCurveList->
GetSize();
128 return new TGraph(numSamples, &
x[0], &
y[0]);
135 for(
auto &roc : fROCs) {
138 return avg/fROCs.size();
147 for(
auto &roc : fROCs) {
160 fLogger << kHEADER <<
" ==== Results ====" <<
Endl;
161 for(
auto &item:fROCs) {
162 fLogger << kINFO <<
Form(
"Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
165 fLogger << kINFO <<
"------------------------" <<
Endl;
166 fLogger << kINFO <<
Form(
"Average ROC-Int : %.4f",GetROCAverage()) <<
Endl;
167 fLogger << kINFO <<
Form(
"Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) <<
Endl;
176 fROCCurves->Draw(
"AL");
177 fROCCurves->GetXaxis()->SetTitle(
" Signal Efficiency ");
178 fROCCurves->GetYaxis()->SetTitle(
" Background Rejection ");
179 Float_t adjust=1+fROCs.size()*0.01;
180 c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
181 c->SetTitle(
"Cross Validation ROC Curves");
193 for (
auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
194 TGraph * foldRocGraph =
dynamic_cast<TGraph *
>(foldRocObj->Clone());
197 rocs.Add(foldRocGraph);
202 TGraph *avgRocGraph = GetAvgROCCurve(100);
203 avgRocGraph->
SetTitle(
"Avg ROC Curve");
206 rocs.Add(avgRocGraph);
212 title =
"Cross Validation Average ROC Curve";
215 rocs.SetTitle(title);
216 rocs.GetXaxis()->SetTitle(
"Signal Efficiency");
217 rocs.GetYaxis()->SetTitle(
"Background Rejection");
218 rocs.DrawClone(
"AL");
222 TList *ROCCurveList = rocs.GetListOfGraphs();
226 leg->AddEntry(
static_cast<TGraph *
>(ROCCurveList->
At(nCurves-1)),
227 "Avg ROC Curve",
"l");
228 leg->AddEntry(
static_cast<TGraph *
>(ROCCurveList->
At(0)),
229 "Fold ROC Curves",
"l");
236 c->SetTitle(
"Cross Validation Average ROC Curve");
278 :
TMVA::
Envelope(jobName, dataloader, nullptr, options),
279 fAnalysisType(
Types::kMaxAnalysisType),
280 fAnalysisTypeStr(
"Auto"),
281 fSplitTypeStr(
"Random"),
283 fCvFactoryOptions(
""),
290 fOutputFactoryOptions(
""),
291 fOutputFile(outputFile),
293 fSplitExprString(
""),
295 fTransformations(
""),
323 DeclareOptionRef(fSilent,
"Silent",
324 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
325 "class object (default: False)");
326 DeclareOptionRef(fVerbose,
"V",
"Verbose flag");
327 DeclareOptionRef(fVerboseLevel =
TString(
"Info"),
"VerboseLevel",
"VerboseLevel (Debug/Verbose/Info)");
328 AddPreDefVal(
TString(
"Debug"));
329 AddPreDefVal(
TString(
"Verbose"));
332 DeclareOptionRef(fTransformations,
"Transformations",
333 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
334 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
337 DeclareOptionRef(fDrawProgressBar,
"DrawProgressBar",
"Boolean to show draw progress bar");
338 DeclareOptionRef(fCorrelations,
"Correlations",
"Boolean to show correlation in output");
339 DeclareOptionRef(fROC,
"ROC",
"Boolean to show ROC in output");
342 DeclareOptionRef(fAnalysisTypeStr,
"AnalysisType",
343 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
344 AddPreDefVal(
TString(
"Classification"));
345 AddPreDefVal(
TString(
"Regression"));
346 AddPreDefVal(
TString(
"Multiclass"));
350 DeclareOptionRef(fSplitTypeStr,
"SplitType",
351 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
352 AddPreDefVal(
TString(
"Deterministic"));
353 AddPreDefVal(
TString(
"Random"));
354 AddPreDefVal(
TString(
"RandomStratified"));
356 DeclareOptionRef(fSplitExprString,
"SplitExpr",
"The expression used to assign events to folds");
357 DeclareOptionRef(fNumFolds,
"NumFolds",
"Number of folds to generate");
358 DeclareOptionRef(fNumWorkerProcs,
"NumWorkerProcs",
359 "Determines how many processes to use for evaluation. 1 means no"
360 " parallelisation. 2 means use 2 processes. 0 means figure out the"
361 " number automatically based on the number of cpus available. Default"
364 DeclareOptionRef(fFoldFileOutput,
"FoldFileOutput",
365 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
366 "specifed for the combined output with a _foldX suffix. (default: false)");
368 DeclareOptionRef(fOutputEnsembling =
TString(
"None"),
"OutputEnsembling",
369 "Combines output from contained methods. If None, no combination is performed. (default None)");
381 if (fSplitTypeStr !=
"Deterministic" && fSplitExprString !=
"") {
382 Log() << kFATAL <<
"SplitExpr can only be used with Deterministic Splitting" <<
Endl;
386 fAnalysisTypeStr.ToLower();
387 if (fAnalysisTypeStr ==
"classification") {
389 }
else if (fAnalysisTypeStr ==
"regression") {
391 }
else if (fAnalysisTypeStr ==
"multiclass") {
393 }
else if (fAnalysisTypeStr ==
"auto") {
398 fCvFactoryOptions +=
"V:";
399 fOutputFactoryOptions +=
"V:";
401 fCvFactoryOptions +=
"!V:";
402 fOutputFactoryOptions +=
"!V:";
405 fCvFactoryOptions +=
Form(
"VerboseLevel=%s:", fVerboseLevel.Data());
406 fOutputFactoryOptions +=
Form(
"VerboseLevel=%s:", fVerboseLevel.Data());
408 fCvFactoryOptions +=
Form(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
409 fOutputFactoryOptions +=
Form(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
411 if (!fDrawProgressBar) {
412 fCvFactoryOptions +=
"!DrawProgressBar:";
413 fOutputFactoryOptions +=
"!DrawProgressBar:";
416 if (fTransformations !=
"") {
417 fCvFactoryOptions +=
Form(
"Transformations=%s:", fTransformations.Data());
418 fOutputFactoryOptions +=
Form(
"Transformations=%s:", fTransformations.Data());
422 fCvFactoryOptions +=
"Correlations:";
423 fOutputFactoryOptions +=
"Correlations:";
425 fCvFactoryOptions +=
"!Correlations:";
426 fOutputFactoryOptions +=
"!Correlations:";
430 fCvFactoryOptions +=
"ROC:";
431 fOutputFactoryOptions +=
"ROC:";
433 fCvFactoryOptions +=
"!ROC:";
434 fOutputFactoryOptions +=
"!ROC:";
438 fCvFactoryOptions +=
Form(
"Silent:");
439 fOutputFactoryOptions +=
Form(
"Silent:");
443 if (fFoldFileOutput && fOutputFile ==
nullptr) {
444 Log() << kFATAL <<
"No output file given, cannot generate per fold output." <<
Endl;
449 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
454 if (fOutputFile ==
nullptr) {
455 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
457 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
460 if(fSplitTypeStr ==
"Random"){
461 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString,
kFALSE));
462 }
else if(fSplitTypeStr ==
"RandomStratified"){
463 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString,
kTRUE));
465 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString));
473 if (i != fNumFolds) {
475 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
476 fDataLoader->MakeKFoldDataSet(*fSplit);
486 if (splitExpr != fSplitExprString) {
487 fSplitExprString = splitExpr;
488 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
489 fDataLoader->MakeKFoldDataSet(*fSplit);
511 Log() << kDEBUG <<
"Processing " << methodTitle <<
" fold " << iFold <<
Endl;
514 TFile *foldOutputFile =
nullptr;
516 if (fFoldFileOutput && fOutputFile !=
nullptr) {
517 TString path = std::string(
"") +
gSystem->
DirName(fOutputFile->GetName()) +
"/" + foldTitle +
".root";
519 Log() << kINFO <<
"Creating fold output at:" << path <<
Endl;
520 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
524 MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
531 fFoldFactory->TestAllMethods();
532 fFoldFactory->EvaluateAllMethods();
538 result.
fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
540 TGraph *
gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle,
true);
564 if (fFoldFileOutput && foldOutputFile !=
nullptr) {
565 foldOutputFile->
Close();
574 fFoldFactory->DeleteAllMethods();
575 fFoldFactory->fMethodsMap.clear();
589 fDataLoader->MakeKFoldDataSet(*fSplit);
593 fResults.reserve(fMethods.size());
594 for (
auto & methodInfo : fMethods) {
597 TString methodTypeName = methodInfo.GetValue<
TString>(
"MethodName");
598 TString methodTitle = methodInfo.GetValue<
TString>(
"MethodTitle");
600 if (methodTypeName ==
"") {
601 Log() << kFATAL <<
"No method booked for cross-validation" <<
Endl;
607 Log() << kINFO <<
"========================================" <<
Endl;
608 Log() << kINFO <<
"Processing folds for method " << methodTitle <<
Endl;
609 Log() << kINFO <<
"========================================" <<
Endl;
613 auto nWorkers = fNumWorkerProcs;
619 for (
UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
620 auto fold_result = ProcessFold(iFold, methodInfo);
621 result.Fill(fold_result);
626 std::vector<CrossValidationFoldResult> result_vector;
628 auto workItem = [
this, methodInfo](
UInt_t iFold) {
629 return ProcessFold(iFold, methodInfo);
634 for (
auto && fold_result : result_vector) {
635 result.Fill(fold_result);
640 fResults.push_back(result);
644 Form(
"SplitExpr=%s:NumFolds=%i"
645 ":EncapsulatedMethodName=%s"
646 ":EncapsulatedMethodTypeName=%s"
647 ":OutputEnsembling=%s",
648 fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.
Data(), fOutputEnsembling.Data());
654 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
662 Log() << kINFO <<
"========================================" <<
Endl;
663 Log() << kINFO <<
"Folds processed for all methods, evaluating." <<
Endl;
664 Log() << kINFO <<
"========================================" <<
Endl;
668 fDataLoader->RecombineKFoldDataSet(*fSplit);
671 for (
auto & methodInfo : fMethods) {
672 TString methodTypeName = methodInfo.GetValue<
TString>(
"MethodName");
673 TString methodTitle = methodInfo.GetValue<
TString>(
"MethodTitle");
675 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
678 if (fOutputFile !=
nullptr) {
679 fFactory->WriteDataInformation(method->fDataSetInfo);
683 method->TrainMethod();
688 fFactory->TestAllMethods();
691 fFactory->EvaluateAllMethods();
693 Log() << kINFO <<
"Evaluation done." <<
Endl;
699 if (fResults.empty()) {
700 Log() << kFATAL <<
"No cross-validation results available" <<
Endl;
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
This class provides a simple interface to execute the same task multiple times in parallel,...
auto Map(F func, unsigned nTimes) -> std::vector< typename std::result_of< F()>::type >
Execute func (with no arguments) nTimes in parallel.
A pseudo container class which is a generator of indices.
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
virtual void SetLineColor(Color_t lcolor)
Set the line color.
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
void Close(Option_t *option="") override
Close a file.
A Graph is a graphics object made of two arrays X and Y with npoints each.
virtual void SetTitle(const char *title="")
Change (i.e.
virtual Double_t Eval(Double_t x, TSpline *spline=0, Option_t *option="") const
Interpolate points in this graph at x using a TSpline.
This class displays a legend box (TPaveText) containing several legend entries.
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
UInt_t GetNumWorkers() const
void CheckForUnusedOptions() const
checks for unused options in option string
Class to save the results of cross validation; the metric for the classification is ROC and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
Float_t GetROCAverage() const
std::vector< Double_t > fEffAreas
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a multigraph that contains an average ROC Curve.
TCanvas * Draw(const TString name="CrossValidation") const
Class to perform cross validation, splitting the dataloader into folds.
void SetNumFolds(UInt_t i)
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
void SetSplitExpr(TString splitExpr)
void Evaluate()
Does training, test set evaluation and performance evaluation using cross-validation.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
virtual void ParseOptions()
Method to parse the internal option string.
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Interface for all concrete MVA method implementations.
Virtual base Class for all MVA method.
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
virtual Double_t GetSignificance() const
compute significance of mean difference
Types::EAnalysisType GetAnalysisType() const
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual Double_t GetTrainingEfficiency(const TString &)
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
ostringstream derivative to redirect and format output
static void EnableOutput()
Class to store options for the different methods.
T GetValue(const TString &key)
Singleton class for Global types used by TMVA.
A TMultiGraph is a collection of TGraph (or derived) objects.
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
const char * Data() const
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
virtual const char * DirName(const char *pathname)
Return the directory name in pathname.
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Double_t Sqrt(Double_t x)
LongDouble_t Power(LongDouble_t x, LongDouble_t y)