44 fSigs.resize(numFolds);
45 fSeps.resize(numFolds);
80 fSigs[iFold] = fr.
fSig;
81 fSeps[iFold] = fr.
fSep;
82 fEff01s[iFold] = fr.
fEff01;
83 fEff10s[iFold] = fr.
fEff10;
84 fEff30s[iFold] = fr.
fEff30;
94 return fROCCurves.get();
110 Double_t increment = 1.0 / (numSamples-1);
114 TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
116 for(
UInt_t iSample = 0; iSample < numSamples; iSample++) {
117 Double_t xPoint = iSample * increment;
120 for(
Int_t iGraph = 0; iGraph < rocCurveList->
GetSize(); iGraph++) {
121 TGraph *foldROC =
static_cast<TGraph *
>(rocCurveList->
At(iGraph));
122 rocSum += foldROC->
Eval(xPoint);
126 y[iSample] = rocSum/rocCurveList->
GetSize();
129 return new TGraph(numSamples,
x,
y);
136 for(
auto &roc : fROCs) {
139 return avg/fROCs.size();
148 for(
auto &roc : fROCs) {
161 fLogger << kHEADER <<
" ==== Results ====" <<
Endl;
162 for(
auto &item:fROCs) {
163 fLogger << kINFO <<
Form(
"Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
166 fLogger << kINFO <<
"------------------------" <<
Endl;
167 fLogger << kINFO <<
Form(
"Average ROC-Int : %.4f",GetROCAverage()) <<
Endl;
168 fLogger << kINFO <<
Form(
"Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) <<
Endl;
177 fROCCurves->Draw(
"AL");
178 fROCCurves->GetXaxis()->SetTitle(
" Signal Efficiency ");
179 fROCCurves->GetYaxis()->SetTitle(
" Background Rejection ");
180 Float_t adjust=1+fROCs.size()*0.01;
181 c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
182 c->SetTitle(
"Cross Validation ROC Curves");
194 for (
auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
195 TGraph * foldRocGraph =
dynamic_cast<TGraph *
>(foldRocObj->Clone());
198 rocs.Add(foldRocGraph);
203 TGraph *avgRocGraph = GetAvgROCCurve(100);
204 avgRocGraph->
SetTitle(
"Avg ROC Curve");
207 rocs.Add(avgRocGraph);
213 title =
"Cross Validation Average ROC Curve";
216 rocs.SetTitle(title);
217 rocs.GetXaxis()->SetTitle(
"Signal Efficiency");
218 rocs.GetYaxis()->SetTitle(
"Background Rejection");
219 rocs.DrawClone(
"AL");
223 TList *ROCCurveList = rocs.GetListOfGraphs();
227 leg->AddEntry(
static_cast<TGraph *
>(ROCCurveList->
At(nCurves-1)),
228 "Avg ROC Curve",
"l");
229 leg->AddEntry(
static_cast<TGraph *
>(ROCCurveList->
At(0)),
230 "Fold ROC Curves",
"l");
237 c->SetTitle(
"Cross Validation Average ROC Curve");
280 fAnalysisType(
Types::kMaxAnalysisType),
281 fAnalysisTypeStr(
"Auto"),
282 fSplitTypeStr(
"Random"),
284 fCvFactoryOptions(
""),
291 fOutputFactoryOptions(
""),
292 fOutputFile(outputFile),
294 fSplitExprString(
""),
296 fTransformations(
""),
324 DeclareOptionRef(fSilent,
"Silent",
325 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
326 "class object (default: False)");
327 DeclareOptionRef(fVerbose,
"V",
"Verbose flag");
328 DeclareOptionRef(fVerboseLevel =
TString(
"Info"),
"VerboseLevel",
"VerboseLevel (Debug/Verbose/Info)");
329 AddPreDefVal(
TString(
"Debug"));
330 AddPreDefVal(
TString(
"Verbose"));
333 DeclareOptionRef(fTransformations,
"Transformations",
334 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
335 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
338 DeclareOptionRef(fDrawProgressBar,
"DrawProgressBar",
"Boolean to show draw progress bar");
339 DeclareOptionRef(fCorrelations,
"Correlations",
"Boolean to show correlation in output");
340 DeclareOptionRef(fROC,
"ROC",
"Boolean to show ROC in output");
343 DeclareOptionRef(fAnalysisTypeStr,
"AnalysisType",
344 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
345 AddPreDefVal(
TString(
"Classification"));
346 AddPreDefVal(
TString(
"Regression"));
347 AddPreDefVal(
TString(
"Multiclass"));
351 DeclareOptionRef(fSplitTypeStr,
"SplitType",
352 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
353 AddPreDefVal(
TString(
"Deterministic"));
354 AddPreDefVal(
TString(
"Random"));
355 AddPreDefVal(
TString(
"RandomStratified"));
357 DeclareOptionRef(fSplitExprString,
"SplitExpr",
"The expression used to assign events to folds");
358 DeclareOptionRef(fNumFolds,
"NumFolds",
"Number of folds to generate");
359 DeclareOptionRef(fNumWorkerProcs,
"NumWorkerProcs",
360 "Determines how many processes to use for evaluation. 1 means no"
361 " parallelisation. 2 means use 2 processes. 0 means figure out the"
362 " number automatically based on the number of cpus available. Default"
365 DeclareOptionRef(fFoldFileOutput,
"FoldFileOutput",
366 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
367 "specifed for the combined output with a _foldX suffix. (default: false)");
369 DeclareOptionRef(fOutputEnsembling =
TString(
"None"),
"OutputEnsembling",
370 "Combines output from contained methods. If None, no combination is performed. (default None)");
382 if (fSplitTypeStr !=
"Deterministic" and fSplitExprString !=
"") {
383 Log() << kFATAL <<
"SplitExpr can only be used with Deterministic Splitting" <<
Endl;
387 fAnalysisTypeStr.ToLower();
388 if (fAnalysisTypeStr ==
"classification") {
390 }
else if (fAnalysisTypeStr ==
"regression") {
392 }
else if (fAnalysisTypeStr ==
"multiclass") {
394 }
else if (fAnalysisTypeStr ==
"auto") {
399 fCvFactoryOptions +=
"V:";
400 fOutputFactoryOptions +=
"V:";
402 fCvFactoryOptions +=
"!V:";
403 fOutputFactoryOptions +=
"!V:";
406 fCvFactoryOptions +=
Form(
"VerboseLevel=%s:", fVerboseLevel.Data());
407 fOutputFactoryOptions +=
Form(
"VerboseLevel=%s:", fVerboseLevel.Data());
409 fCvFactoryOptions +=
Form(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
410 fOutputFactoryOptions +=
Form(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
412 if (not fDrawProgressBar) {
413 fOutputFactoryOptions +=
"!DrawProgressBar:";
416 if (fTransformations !=
"") {
417 fCvFactoryOptions +=
Form(
"Transformations=%s:", fTransformations.Data());
418 fOutputFactoryOptions +=
Form(
"Transformations=%s:", fTransformations.Data());
423 fOutputFactoryOptions +=
"Correlations:";
426 fOutputFactoryOptions +=
"!Correlations:";
431 fOutputFactoryOptions +=
"ROC:";
434 fOutputFactoryOptions +=
"!ROC:";
439 fOutputFactoryOptions +=
Form(
"Silent:");
442 fCvFactoryOptions +=
"!Correlations:!ROC:!Color:!DrawProgressBar:Silent";
445 if (fFoldFileOutput and fOutputFile ==
nullptr) {
446 Log() << kFATAL <<
"No output file given, cannot generate per fold output." <<
Endl;
451 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
456 if (fOutputFile ==
nullptr) {
457 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
459 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
462 if(fSplitTypeStr ==
"Random"){
463 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString,
kFALSE));
464 }
else if(fSplitTypeStr ==
"RandomStratified"){
465 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString,
kTRUE));
467 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString));
475 if (i != fNumFolds) {
477 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
478 fDataLoader->MakeKFoldDataSet(*fSplit);
488 if (splitExpr != fSplitExprString) {
489 fSplitExprString = splitExpr;
490 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
491 fDataLoader->MakeKFoldDataSet(*fSplit);
513 Log() << kDEBUG <<
"Processing " << methodTitle <<
" fold " << iFold <<
Endl;
516 TFile *foldOutputFile =
nullptr;
518 if (fFoldFileOutput and fOutputFile !=
nullptr) {
519 TString path = std::string(
"") +
gSystem->
DirName(fOutputFile->GetName()) +
"/" + foldTitle +
".root";
520 std::cout <<
"PATH: " << path << std::endl;
522 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
526 MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
533 fFoldFactory->TestAllMethods();
534 fFoldFactory->EvaluateAllMethods();
540 result.
fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
542 TGraph *
gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle,
true);
566 if (fFoldFileOutput and foldOutputFile !=
nullptr) {
567 foldOutputFile->
Close();
576 fFoldFactory->DeleteAllMethods();
577 fFoldFactory->fMethodsMap.clear();
591 fDataLoader->MakeKFoldDataSet(*fSplit);
595 fResults.reserve(fMethods.size());
596 for (
auto & methodInfo : fMethods) {
599 TString methodTypeName = methodInfo.GetValue<
TString>(
"MethodName");
600 TString methodTitle = methodInfo.GetValue<
TString>(
"MethodTitle");
602 if (methodTypeName ==
"") {
603 Log() << kFATAL <<
"No method booked for cross-validation" <<
Endl;
607 Log() << kINFO <<
"Evaluate method: " << methodTitle <<
Endl;
610 auto nWorkers = fNumWorkerProcs;
616 for (
UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
617 auto fold_result = ProcessFold(iFold, methodInfo);
618 result.Fill(fold_result);
622 std::vector<CrossValidationFoldResult> result_vector;
624 auto workItem = [
this, methodInfo](
UInt_t iFold) {
625 return ProcessFold(iFold, methodInfo);
630 for (
auto && fold_result : result_vector) {
631 result.Fill(fold_result);
635 fResults.push_back(result);
639 Form(
"SplitExpr=%s:NumFolds=%i"
640 ":EncapsulatedMethodName=%s"
641 ":EncapsulatedMethodTypeName=%s"
642 ":OutputEnsembling=%s",
643 fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.
Data(), fOutputEnsembling.Data());
649 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
656 fDataLoader->RecombineKFoldDataSet(*fSplit);
659 for (
auto & methodInfo : fMethods) {
660 TString methodTypeName = methodInfo.GetValue<
TString>(
"MethodName");
661 TString methodTitle = methodInfo.GetValue<
TString>(
"MethodTitle");
663 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
666 if (fOutputFile !=
nullptr) {
667 fFactory->WriteDataInformation(method->fDataSetInfo);
671 method->TrainMethod();
676 fFactory->TestAllMethods();
679 fFactory->EvaluateAllMethods();
681 Log() << kINFO <<
"Evaluation done." <<
Endl;
687 if (fResults.empty()) {
688 Log() << kFATAL <<
"No cross-validation results available" <<
Endl;
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
This class provides a simple interface to execute the same task multiple times in parallel,...
auto Map(F func, unsigned nTimes) -> std::vector< typename std::result_of< F()>::type >
Execute func (with no arguments) nTimes in parallel.
A pseudo container class which is a generator of indices.
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
virtual void SetLineColor(Color_t lcolor)
Set the line color.
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
virtual void Close(Option_t *option="")
Close a file.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose, Int_t netopt=0)
Create / open a file.
A Graph is a graphics object made of two arrays X and Y with npoints each.
virtual void SetTitle(const char *title="")
Set graph title.
virtual Double_t Eval(Double_t x, TSpline *spline=0, Option_t *option="") const
Interpolate points in this graph at x using a TSpline.
This class displays a legend box (TPaveText) containing several legend entries.
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
UInt_t GetNumWorkers() const
void CheckForUnusedOptions() const
checks for unused options in option string
Class to save the results of cross validation; the metric for the classification is ROC and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
Float_t GetROCAverage() const
std::vector< Double_t > fEffAreas
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a graph that contains an average ROC curve.
TCanvas * Draw(const TString name="CrossValidation") const
Class to perform cross validation, splitting the dataloader into folds.
void SetNumFolds(UInt_t i)
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
void SetSplitExpr(TString splitExpr)
void Evaluate()
Does training, test set evaluation and performance evaluation using cross-validation.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
virtual void ParseOptions()
Method to parse the internal option string.
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Interface for all concrete MVA method implementations.
Virtual base class for all MVA methods.
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
virtual Double_t GetSignificance() const
compute significance of mean difference
Types::EAnalysisType GetAnalysisType() const
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual Double_t GetTrainingEfficiency(const TString &)
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
ostringstream derivative to redirect and format output
static void EnableOutput()
Class to store options for the different methods.
T GetValue(const TString &key)
Singleton class for Global types used by TMVA.
A TMultiGraph is a collection of TGraph (or derived) objects.
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
const char * Data() const
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
virtual const char * DirName(const char *pathname)
Return the directory name in pathname.
Abstract ClassifierFactory template that handles arbitrary types.
MsgLogger & Endl(MsgLogger &ml)
Double_t Sqrt(Double_t x)
LongDouble_t Power(LongDouble_t x, LongDouble_t y)