42   fSigs.resize(numFolds);
 
   43   fSeps.resize(numFolds);
 
   78   fSigs[iFold] = fr.
fSig;
 
   79   fSeps[iFold] = fr.
fSep;
 
   80   fEff01s[iFold] = fr.
fEff01;
 
   81   fEff10s[iFold] = fr.
fEff10;
 
   82   fEff30s[iFold] = fr.
fEff30;
 
   92   return fROCCurves.get();
 
  108   Double_t increment = 1.0 / (numSamples-1);
 
  109   std::vector<Double_t> 
x(numSamples), 
y(numSamples);
 
  111   TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
 
  113   for(
UInt_t iSample = 0; iSample < numSamples; iSample++) {
 
  114      Double_t xPoint = iSample * increment;
 
  117      for(
Int_t iGraph = 0; iGraph < rocCurveList->
GetSize(); iGraph++) {
 
  118        TGraph *foldROC = 
static_cast<TGraph *
>(rocCurveList->
At(iGraph));
 
  119        rocSum += foldROC->
Eval(xPoint);
 
  123      y[iSample] = rocSum/rocCurveList->
GetSize();
 
  126   return new TGraph(numSamples, &
x[0], &
y[0]);
 
  133   for(
auto &roc : fROCs) {
 
  136   return avg/fROCs.size();
 
  145   for(
auto &roc : fROCs) {
 
  158   fLogger << kHEADER << 
" ==== Results ====" << 
Endl;
 
  159   for(
auto &item:fROCs) {
 
  160      fLogger << kINFO << 
Form(
"Fold  %i ROC-Int : %.4f",item.first,item.second) << std::endl;
 
  163   fLogger << kINFO << 
"------------------------" << 
Endl;
 
  164   fLogger << kINFO << 
Form(
"Average ROC-Int : %.4f",GetROCAverage()) << 
Endl;
 
  165   fLogger << kINFO << 
Form(
"Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << 
Endl;
 
  174   fROCCurves->Draw(
"AL");
 
  175   fROCCurves->GetXaxis()->SetTitle(
" Signal Efficiency ");
 
  176   fROCCurves->GetYaxis()->SetTitle(
" Background Rejection ");
 
  177   Float_t adjust=1+fROCs.size()*0.01;
 
  178   c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
 
  179   c->SetTitle(
"Cross Validation ROC Curves");
 
  193      for (
auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
 
  194         TGraph * foldRocGraph = 
dynamic_cast<TGraph *
>(foldRocObj->Clone());
 
  197         rocs->
Add(foldRocGraph);
 
  202   TGraph *avgRocGraph = GetAvgROCCurve(100);
 
  203   avgRocGraph->
SetTitle(
"Avg ROC Curve");
 
  206   rocs->
Add(avgRocGraph);
 
  212      title = 
"Cross Validation Average ROC Curve";
 
  227      leg->AddEntry(
static_cast<TGraph *
>(ROCCurveList->
At(nCurves-1)),
 
  228                    "Avg ROC Curve", 
"l");
 
  229      leg->AddEntry(
static_cast<TGraph *
>(ROCCurveList->
At(0)),
 
  230                    "Fold ROC Curves", 
"l");
 
  237   c->SetTitle(
"Cross Validation Average ROC Curve");
 
  279   : 
TMVA::
Envelope(jobName, dataloader, nullptr, options),
 
  280     fAnalysisType(
Types::kMaxAnalysisType),
 
  281     fAnalysisTypeStr(
"Auto"),
 
  282     fSplitTypeStr(
"Random"),
 
  284     fCvFactoryOptions(
""),
 
  291     fOutputFactoryOptions(
""),
 
  292     fOutputFile(outputFile),
 
  294     fSplitExprString(
""),
 
  296     fTransformations(
""),
 
  324   DeclareOptionRef(fSilent, 
"Silent",
 
  325                    "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory " 
  326                    "class object (default: False)");
 
  327   DeclareOptionRef(fVerbose, 
"V", 
"Verbose flag");
 
  328   DeclareOptionRef(fVerboseLevel = 
TString(
"Info"), 
"VerboseLevel", 
"VerboseLevel (Debug/Verbose/Info)");
 
  329   AddPreDefVal(
TString(
"Debug"));
 
  330   AddPreDefVal(
TString(
"Verbose"));
 
  333   DeclareOptionRef(fTransformations, 
"Transformations",
 
  334                    "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for " 
  335                    "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation " 
  338   DeclareOptionRef(fDrawProgressBar, 
"DrawProgressBar", 
"Boolean to show draw progress bar");
 
  339   DeclareOptionRef(fCorrelations, 
"Correlations", 
"Boolean to show correlation in output");
 
  340   DeclareOptionRef(fROC, 
"ROC", 
"Boolean to show ROC in output");
 
  343   DeclareOptionRef(fAnalysisTypeStr, 
"AnalysisType",
 
  344                    "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
 
  345   AddPreDefVal(
TString(
"Classification"));
 
  346   AddPreDefVal(
TString(
"Regression"));
 
  347   AddPreDefVal(
TString(
"Multiclass"));
 
  351   DeclareOptionRef(fSplitTypeStr, 
"SplitType",
 
  352                    "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
 
  353   AddPreDefVal(
TString(
"Deterministic"));
 
  354   AddPreDefVal(
TString(
"Random"));
 
  355   AddPreDefVal(
TString(
"RandomStratified"));
 
  357   DeclareOptionRef(fSplitExprString, 
"SplitExpr", 
"The expression used to assign events to folds");
 
  358   DeclareOptionRef(fNumFolds, 
"NumFolds", 
"Number of folds to generate");
 
  359   DeclareOptionRef(fNumWorkerProcs, 
"NumWorkerProcs",
 
  360      "Determines how many processes to use for evaluation. 1 means no" 
  361      " parallelisation. 2 means use 2 processes. 0 means figure out the" 
  362      " number automatically based on the number of cpus available. Default" 
  365   DeclareOptionRef(fFoldFileOutput, 
"FoldFileOutput",
 
  366                    "If given a TMVA output file will be generated for each fold. Filename will be the same as " 
  367                    "specifed for the combined output with a _foldX suffix. (default: false)");
 
  369   DeclareOptionRef(fOutputEnsembling = 
TString(
"None"), 
"OutputEnsembling",
 
  370                    "Combines output from contained methods. If None, no combination is performed. (default None)");
 
  382   if (fSplitTypeStr != 
"Deterministic" && fSplitExprString != 
"") {
 
  383      Log() << kFATAL << 
"SplitExpr can only be used with Deterministic Splitting" << 
Endl;
 
  387   fAnalysisTypeStr.ToLower();
 
  388   if (fAnalysisTypeStr == 
"classification") {
 
  390   } 
else if (fAnalysisTypeStr == 
"regression") {
 
  392   } 
else if (fAnalysisTypeStr == 
"multiclass") {
 
  394   } 
else if (fAnalysisTypeStr == 
"auto") {
 
  399      fCvFactoryOptions += 
"V:";
 
  400      fOutputFactoryOptions += 
"V:";
 
  402      fCvFactoryOptions += 
"!V:";
 
  403      fOutputFactoryOptions += 
"!V:";
 
  406   fCvFactoryOptions += 
Form(
"VerboseLevel=%s:", fVerboseLevel.Data());
 
  407   fOutputFactoryOptions += 
Form(
"VerboseLevel=%s:", fVerboseLevel.Data());
 
  409   fCvFactoryOptions += 
Form(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
 
  410   fOutputFactoryOptions += 
Form(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
 
  412   if (!fDrawProgressBar) {
 
  413      fCvFactoryOptions += 
"!DrawProgressBar:";
 
  414      fOutputFactoryOptions += 
"!DrawProgressBar:";
 
  417   if (fTransformations != 
"") {
 
  418      fCvFactoryOptions += 
Form(
"Transformations=%s:", fTransformations.Data());
 
  419      fOutputFactoryOptions += 
Form(
"Transformations=%s:", fTransformations.Data());
 
  423      fCvFactoryOptions += 
"Correlations:";
 
  424      fOutputFactoryOptions += 
"Correlations:";
 
  426      fCvFactoryOptions += 
"!Correlations:";
 
  427      fOutputFactoryOptions += 
"!Correlations:";
 
  431      fCvFactoryOptions += 
"ROC:";
 
  432      fOutputFactoryOptions += 
"ROC:";
 
  434      fCvFactoryOptions += 
"!ROC:";
 
  435      fOutputFactoryOptions += 
"!ROC:";
 
  439      fCvFactoryOptions += 
Form(
"Silent:");
 
  440      fOutputFactoryOptions += 
Form(
"Silent:");
 
  444   if (fFoldFileOutput && fOutputFile == 
nullptr) {
 
  445      Log() << kFATAL << 
"No output file given, cannot generate per fold output." << 
Endl;
 
  450   fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
 
  455   if (fOutputFile == 
nullptr) {
 
  456      fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
 
  458      fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
 
  461   if(fSplitTypeStr == 
"Random"){
 
  462      fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString, 
kFALSE));
 
  463   } 
else if(fSplitTypeStr == 
"RandomStratified"){
 
  464      fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString, 
kTRUE));
 
  466      fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString));
 
  474   if (i != fNumFolds) {
 
  476      fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
 
  477      fDataLoader->MakeKFoldDataSet(*fSplit);
 
  487   if (splitExpr != fSplitExprString) {
 
  488      fSplitExprString = splitExpr;
 
  489      fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
 
  490      fDataLoader->MakeKFoldDataSet(*fSplit);
 
  513   Log() << kDEBUG << 
"Processing  " << methodTitle << 
" fold " << iFold << 
Endl;
 
  516   TFile *foldOutputFile = 
nullptr;
 
  518   if (fFoldFileOutput && fOutputFile != 
nullptr) {
 
  521      Log() << kINFO << 
"Creating fold output at:" << path << 
Endl;
 
  522      fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
 
  526   MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
 
  533   fFoldFactory->TestAllMethods();
 
  534   fFoldFactory->EvaluateAllMethods();
 
  540      result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
 
  542      TGraph *
gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle, 
true);
 
  566   if (fFoldFileOutput && foldOutputFile != 
nullptr) {
 
  567      foldOutputFile->
Close();
 
  576   fFoldFactory->DeleteAllMethods();
 
  577   fFoldFactory->fMethodsMap.clear();
 
  591      fDataLoader->MakeKFoldDataSet(*fSplit);
 
  595   fResults.reserve(fMethods.size());
 
  596   for (
auto & methodInfo : fMethods) {
 
  599      TString methodTypeName = methodInfo.GetValue<
TString>(
"MethodName");
 
  600      TString methodTitle = methodInfo.GetValue<
TString>(
"MethodTitle");
 
  602      if (methodTypeName == 
"") {
 
  603         Log() << kFATAL << 
"No method booked for cross-validation" << 
Endl;
 
  607      Log() << kINFO << 
Endl;
 
  608      Log() << kINFO << 
Endl;
 
  609      Log() << kINFO << 
"========================================" << 
Endl;
 
  610      Log() << kINFO << 
"Processing folds for method " << methodTitle << 
Endl;
 
  611      Log() << kINFO << 
"========================================" << 
Endl;
 
  612      Log() << kINFO << 
Endl;
 
  615      auto nWorkers = fNumWorkerProcs;
 
  621         for (
UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
 
  622            auto fold_result = ProcessFold(iFold, methodInfo);
 
  628         std::vector<CrossValidationFoldResult> result_vector;
 
  630         auto workItem = [
this, methodInfo](
UInt_t iFold) {
 
  631            return ProcessFold(iFold, methodInfo);
 
  636         for (
auto && fold_result : result_vector) {
 
  642      fResults.push_back(
result);
 
  646         Form(
"SplitExpr=%s:NumFolds=%i" 
  647              ":EncapsulatedMethodName=%s" 
  648              ":EncapsulatedMethodTypeName=%s" 
  649              ":OutputEnsembling=%s",
 
  650              fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.
Data(), fOutputEnsembling.Data());
 
  656      IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
 
  662   Log() << kINFO << 
Endl;
 
  663   Log() << kINFO << 
Endl;
 
  664   Log() << kINFO << 
"========================================" << 
Endl;
 
  665   Log() << kINFO << 
"Folds processed for all methods, evaluating." << 
Endl;
 
  666   Log() << kINFO << 
"========================================" << 
Endl;
 
  667   Log() << kINFO << 
Endl;
 
  670   fDataLoader->RecombineKFoldDataSet(*fSplit);
 
  673   for (
auto & methodInfo : fMethods) {
 
  674      TString methodTypeName = methodInfo.GetValue<
TString>(
"MethodName");
 
  675      TString methodTitle = methodInfo.GetValue<
TString>(
"MethodTitle");
 
  677      IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
 
  680      if (fOutputFile != 
nullptr) {
 
  681         fFactory->WriteDataInformation(method->fDataSetInfo);
 
  685      method->TrainMethod();
 
  690   fFactory->TestAllMethods();
 
  693   fFactory->EvaluateAllMethods();
 
  695   Log() << kINFO << 
"Evaluation done." << 
Endl;
 
  701   if (fResults.empty()) {
 
  702      Log() << kFATAL << 
"No cross-validation results available" << 
Endl;
 
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
 
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
 
R__EXTERN TSystem * gSystem
 
auto Map(F func, unsigned nTimes) -> std::vector< InvokeResult_t< F > >
Execute a function without arguments several times.
 
This class provides a simple interface to execute the same task multiple times in parallel,...
 
A pseudo container class which is a generator of indices.
 
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
 
virtual void SetLineColor(Color_t lcolor)
Set the line color.
 
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
 
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
 
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
 
void Close(Option_t *option="") override
Close a file.
 
A TGraph is an object made of two arrays X and Y with npoints each.
 
virtual Double_t Eval(Double_t x, TSpline *spline=nullptr, Option_t *option="") const
Interpolate points in this graph at x using a TSpline.
 
void SetTitle(const char *title="") override
Change (i.e.
 
This class displays a legend box (TPaveText) containing several legend entries.
 
TObject * At(Int_t idx) const override
Returns the object at position idx. Returns 0 if idx is out of range.
 
UInt_t GetNumWorkers() const
 
void CheckForUnusedOptions() const
checks for unused options in option string
 
Class to save the results of cross validation; the metric for classification is ROC, and you can ...
 
std::vector< Double_t > fSeps
 
std::vector< Double_t > fEff01s
 
CrossValidationResult(UInt_t numFolds)
 
std::vector< Double_t > fTrainEff30s
 
std::shared_ptr< TMultiGraph > fROCCurves
 
std::vector< Double_t > fSigs
 
std::vector< Double_t > fEff30s
 
void Fill(CrossValidationFoldResult const &fr)
 
Float_t GetROCStandardDeviation() const
 
std::vector< Double_t > fEff10s
 
std::vector< Double_t > fTrainEff01s
 
std::map< UInt_t, Float_t > fROCs
 
std::vector< Double_t > fTrainEff10s
 
Float_t GetROCAverage() const
 
std::vector< Double_t > fEffAreas
 
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
 
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
 
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a graph (TGraph) containing the average of the per-fold ROC curves, sampled at numSamples points.
 
TCanvas * Draw(const TString name="CrossValidation") const
 
Class to perform cross validation, splitting the dataloader into folds.
 
void SetNumFolds(UInt_t i)
 
void ParseOptions()
Method to parse the internal option string.
 
const std::vector< CrossValidationResult > & GetResults() const
 
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
 
void SetSplitExpr(TString splitExpr)
 
void Evaluate()
Does training, test-set evaluation and performance evaluation using cross-evaluation.
 
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
 
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
 
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
 
virtual void ParseOptions()
Method to parse the internal option string.
 
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
 
Interface for all concrete MVA method implementations.
 
Virtual base Class for all MVA method.
 
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
 
virtual Double_t GetSignificance() const
compute significance of mean difference
 
Types::EAnalysisType GetAnalysisType() const
 
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
 
virtual Double_t GetTrainingEfficiency(const TString &)
 
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
 
ostringstream derivative to redirect and format output
 
static void EnableOutput()
 
Class to store options for the different methods.
 
T GetValue(const TString &key)
 
Singleton class for Global types used by TMVA.
 
A TMultiGraph is a collection of TGraph (or derived) objects.
 
TList * GetListOfGraphs() const
 
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
 
TAxis * GetYaxis()
Get y axis of the graph.
 
TAxis * GetXaxis()
Get x axis of the graph.
 
TObject * Clone(const char *newname="") const override
Make a clone of an object using the Streamer facility.
 
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
 
virtual void SetName(const char *name)
Set the name of the TNamed.
 
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad with: gROOT->SetSelectedPad(c1).
 
const char * Data() const
 
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
 
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
 
create variable transformations
 
MsgLogger & Endl(MsgLogger &ml)
 
Double_t Sqrt(Double_t x)
Returns the square root of x.
 
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Returns x raised to the power y.