Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CrossValidation.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou
3// Modified: Kim Albertsson 2017
4
5/*************************************************************************
6 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
7 * All rights reserved. *
8 * *
9 * For the licensing terms see $ROOTSYS/LICENSE. *
10 * For the list of contributors see $ROOTSYS/README/CREDITS. *
11 *************************************************************************/
12
14
16#include "TMVA/Config.h"
17#include "TMVA/CvSplit.h"
18#include "TMVA/DataSet.h"
19#include "TMVA/Event.h"
20#include "TMVA/MethodBase.h"
22#include "TMVA/MsgLogger.h"
25#include "TMVA/ROCCurve.h"
26#include "TMVA/Types.h"
27
28#include "TSystem.h"
29#include "TAxis.h"
30#include "TCanvas.h"
31#include "TGraph.h"
32#include "TLegend.h"
33#include "TMath.h"
34
35#include <iostream>
36#include <memory>
37
38//_______________________________________________________________________
40:fROCCurves(new TMultiGraph())
41{
42 fSigs.resize(numFolds);
43 fSeps.resize(numFolds);
44 fEff01s.resize(numFolds);
45 fEff10s.resize(numFolds);
46 fEff30s.resize(numFolds);
47 fEffAreas.resize(numFolds);
48 fTrainEff01s.resize(numFolds);
49 fTrainEff10s.resize(numFolds);
50 fTrainEff30s.resize(numFolds);
51}
52
53//_______________________________________________________________________
55{
56 fROCs=obj.fROCs;
57 fROCCurves = obj.fROCCurves;
58
59 fSigs = obj.fSigs;
60 fSeps = obj.fSeps;
61 fEff01s = obj.fEff01s;
62 fEff10s = obj.fEff10s;
63 fEff30s = obj.fEff30s;
64 fEffAreas = obj.fEffAreas;
65 fTrainEff01s = obj.fTrainEff01s;
66 fTrainEff10s = obj.fTrainEff10s;
67 fTrainEff30s = obj.fTrainEff30s;
68}
69
70//_______________________________________________________________________
72{
73 UInt_t iFold = fr.fFold;
74
75 fROCs[iFold] = fr.fROCIntegral;
76 fROCCurves->Add(dynamic_cast<TGraph *>(fr.fROC.Clone()));
77
78 fSigs[iFold] = fr.fSig;
79 fSeps[iFold] = fr.fSep;
80 fEff01s[iFold] = fr.fEff01;
81 fEff10s[iFold] = fr.fEff10;
82 fEff30s[iFold] = fr.fEff30;
83 fEffAreas[iFold] = fr.fEffArea;
84 fTrainEff01s[iFold] = fr.fTrainEff01;
85 fTrainEff10s[iFold] = fr.fTrainEff10;
86 fTrainEff30s[iFold] = fr.fTrainEff30;
87}
88
89//_______________________________________________________________________
91{
92 return fROCCurves.get();
93}
94
95////////////////////////////////////////////////////////////////////////////////
96/// \brief Generates a graph of the ROC curve averaged over all folds.
97///
98/// \note You own the returned pointer.
99///
100/// \param[in] numSamples Number of samples used for generating the average ROC
101/// Curve. The samples are evenly spaced on [0, 1]; the avg. curve is
102/// evaluated only at these points (using interpolation if necessary).
103///
104
106{
107 // `numSamples * increment` should equal 1.0!
108 Double_t increment = 1.0 / (numSamples-1);
109 std::vector<Double_t> x(numSamples), y(numSamples);
110
111 TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
112
113 for(UInt_t iSample = 0; iSample < numSamples; iSample++) {
114 Double_t xPoint = iSample * increment;
115 Double_t rocSum = 0;
116
117 for(Int_t iGraph = 0; iGraph < rocCurveList->GetSize(); iGraph++) {
118 TGraph *foldROC = static_cast<TGraph *>(rocCurveList->At(iGraph));
119 rocSum += foldROC->Eval(xPoint);
120 }
121
122 x[iSample] = xPoint;
123 y[iSample] = rocSum/rocCurveList->GetSize();
124 }
125
126 return new TGraph(numSamples, &x[0], &y[0]);
127}
128
129//_______________________________________________________________________
131{
132 Float_t avg=0;
133 for(auto &roc : fROCs) {
134 avg+=roc.second;
135 }
136 return avg/fROCs.size();
137}
138
139//_______________________________________________________________________
141{
142 // NOTE: We are using here the unbiased estimation of the standard deviation.
143 Float_t std=0;
144 Float_t avg=GetROCAverage();
145 for(auto &roc : fROCs) {
146 std+=TMath::Power(roc.second-avg, 2);
147 }
148 return TMath::Sqrt(std/float(fROCs.size()-1.0));
149}
150
151//_______________________________________________________________________
153{
// Prints the ROC integral of every fold, followed by their average and
// standard deviation (GetROCAverage / GetROCStandardDeviation), through a
// local MsgLogger named "CrossValidation".
// NOTE(review): listing lines 154-155 and 167 are absent from this copy of
// the file (the line numbers jump), so their statements are not shown here.
156
157 MsgLogger fLogger("CrossValidation");
158 fLogger << kHEADER << " ==== Results ====" << Endl;
159 for(auto &item:fROCs) {
// NOTE(review): this line is terminated with std::endl, while every other
// line in this function uses TMVA's Endl manipulator; confirm whether Endl
// was intended so each fold line is emitted as its own log message.
160 fLogger << kINFO << Form("Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
161 }
162
163 fLogger << kINFO << "------------------------" << Endl;
164 fLogger << kINFO << Form("Average ROC-Int : %.4f",GetROCAverage()) << Endl;
165 fLogger << kINFO << Form("Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << Endl;
166
168}
169
170//_______________________________________________________________________
172{
173 auto *c = new TCanvas(name.Data());
174 fROCCurves->Draw("AL");
175 fROCCurves->GetXaxis()->SetTitle(" Signal Efficiency ");
176 fROCCurves->GetYaxis()->SetTitle(" Background Rejection ");
177 Float_t adjust=1+fROCs.size()*0.01;
178 c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
179 c->SetTitle("Cross Validation ROC Curves");
180 c->Draw();
181 return c;
182}
183
184//
186{
187 // note this function will create memory leak for the TMultiGraph
188 // but it needs to be kept alive in order to display the canvas
189 TMultiGraph *rocs = new TMultiGraph();
190
191 // Potentially add the folds
192 if (drawFolds) {
193 for (auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
194 TGraph * foldRocGraph = dynamic_cast<TGraph *>(foldRocObj->Clone());
195 foldRocGraph->SetLineColor(1);
196 foldRocGraph->SetLineWidth(1);
197 rocs->Add(foldRocGraph);
198 }
199 }
200
201 // Add the average roc curve
202 TGraph *avgRocGraph = GetAvgROCCurve(100);
203 avgRocGraph->SetTitle("Avg ROC Curve");
204 avgRocGraph->SetLineColor(2);
205 avgRocGraph->SetLineWidth(3);
206 rocs->Add(avgRocGraph);
207
208 // Draw
209 TCanvas *c = new TCanvas();
210
211 if (title != "") {
212 title = "Cross Validation Average ROC Curve";
213 }
214
215 rocs->SetName("cv_rocs");
216 rocs->SetTitle(title);
217 rocs->GetXaxis()->SetTitle("Signal Efficiency");
218 rocs->GetYaxis()->SetTitle("Background Rejection");
219 rocs->DrawClone("AL");
220
221 // Build legend
222 TLegend *leg = new TLegend();
223 TList *ROCCurveList = rocs->GetListOfGraphs();
224
225 if (drawFolds) {
226 Int_t nCurves = ROCCurveList->GetSize();
227 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(nCurves-1)),
228 "Avg ROC Curve", "l");
229 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(0)),
230 "Fold ROC Curves", "l");
231 leg->Draw();
232 } else {
233 c->BuildLegend();
234 }
235
236 // Draw Canvas
237 c->SetTitle("Cross Validation Average ROC Curve");
238 c->Draw();
239 return c;
240}
241
242/**
243* \class TMVA::CrossValidation
244* \ingroup TMVA
245* \brief Class to perform cross validation, splitting the dataloader into folds.
246
247Typical usage, where `ce` is a TMVA::CrossValidation instance:
248
249
250
251~~~{.cpp}
252ce->BookMethod(dataloader, options);
253ce->Evaluate();
254~~~
255
256Cross-evaluation will generate a new training and a test set dynamically from
257from `K` folds. These `K` folds are generated by splitting the input training
258set. The input test set is currently ignored.
259
260This means that when you specify your DataSet you should include all events
261in your training set. One way of doing this would be the following:
262
263~~~{.cpp}
264dataloader->AddTree( signalTree, "cls1" );
265dataloader->AddTree( background, "cls2" );
266dataloader->PrepareTrainingAndTestTree( "", "", "nTest_cls1=1:nTest_cls2=1" );
267~~~
268
269## Split Expression
270See CVSplit documentation?
271
272*/
273
274////////////////////////////////////////////////////////////////////////////////
275///
276
278 TString options)
279 : TMVA::Envelope(jobName, dataloader, nullptr, options),
280 fAnalysisType(Types::kMaxAnalysisType),
281 fAnalysisTypeStr("Auto"),
282 fSplitTypeStr("Random"),
283 fCorrelations(kFALSE),
284 fCvFactoryOptions(""),
285 fDrawProgressBar(kFALSE),
286 fFoldFileOutput(kFALSE),
287 fFoldStatus(kFALSE),
288 fJobName(jobName),
289 fNumFolds(2),
290 fNumWorkerProcs(1),
291 fOutputFactoryOptions(""),
292 fOutputFile(outputFile),
293 fSilent(kFALSE),
294 fSplitExprString(""),
295 fROC(kTRUE),
296 fTransformations(""),
297 fVerbose(kFALSE),
298 fVerboseLevel(kINFO)
299{
// Builds a cross-validation envelope around `dataloader`.
// Defaults set above: 2 folds, "Random" split type, analysis type "Auto",
// one worker process, ROC output enabled, no per-fold output files.
// The Envelope base receives nullptr as its file argument. The output file
// is kept in fOutputFile and handed to the output factory in ParseOptions().
// InitOptions() declares the user-configurable options.
// NOTE(review): listing lines 277 (start of the signature) and 301-302 are
// absent from this copy of the file; their statements are not shown here.
300 InitOptions();
303}
304
305////////////////////////////////////////////////////////////////////////////////
306///
307
309 : CrossValidation(jobName, dataloader, nullptr, options)
310{
311}
312
313////////////////////////////////////////////////////////////////////////////////
314///
315
317
318////////////////////////////////////////////////////////////////////////////////
319///
320
322{
323 // Forwarding of Factory options
324 DeclareOptionRef(fSilent, "Silent",
325 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
326 "class object (default: False)");
327 DeclareOptionRef(fVerbose, "V", "Verbose flag");
328 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
329 AddPreDefVal(TString("Debug"));
330 AddPreDefVal(TString("Verbose"));
331 AddPreDefVal(TString("Info"));
332
333 DeclareOptionRef(fTransformations, "Transformations",
334 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
335 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
336 "transformations");
337
338 DeclareOptionRef(fDrawProgressBar, "DrawProgressBar", "Boolean to show draw progress bar");
339 DeclareOptionRef(fCorrelations, "Correlations", "Boolean to show correlation in output");
340 DeclareOptionRef(fROC, "ROC", "Boolean to show ROC in output");
341
342 TString analysisType("Auto");
343 DeclareOptionRef(fAnalysisTypeStr, "AnalysisType",
344 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
345 AddPreDefVal(TString("Classification"));
346 AddPreDefVal(TString("Regression"));
347 AddPreDefVal(TString("Multiclass"));
348 AddPreDefVal(TString("Auto"));
349
350 // Options specific to CE
351 DeclareOptionRef(fSplitTypeStr, "SplitType",
352 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
353 AddPreDefVal(TString("Deterministic"));
354 AddPreDefVal(TString("Random"));
355 AddPreDefVal(TString("RandomStratified"));
356
357 DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
358 DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
359 DeclareOptionRef(fNumWorkerProcs, "NumWorkerProcs",
360 "Determines how many processes to use for evaluation. 1 means no"
361 " parallelisation. 2 means use 2 processes. 0 means figure out the"
362 " number automatically based on the number of cpus available. Default"
363 " 1.");
364
365 DeclareOptionRef(fFoldFileOutput, "FoldFileOutput",
366 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
367 "specifed for the combined output with a _foldX suffix. (default: false)");
368
369 DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
370 "Combines output from contained methods. If None, no combination is performed. (default None)");
371 AddPreDefVal(TString("None"));
372 AddPreDefVal(TString("Avg"));
373}
374
375////////////////////////////////////////////////////////////////////////////////
376///
377
379{
// Validates the declared options and builds the objects they configure:
//  - rejects SplitExpr unless SplitType is Deterministic;
//  - maps AnalysisType onto fAnalysisType;
//  - assembles the option strings of the per-fold factory (fCvFactoryOptions)
//    and the output factory (fOutputFactoryOptions);
//  - creates fFoldFactory, fFactory (with or without fOutputFile) and the
//    fold splitter fSplit.
// NOTE(review): listing line 380 is absent from this copy of the file; its
// statement is not shown here.
// NOTE(review): the option strings are extended with `+=` and never reset in
// the visible lines, so running this function twice would duplicate entries.
381
382 if (fSplitTypeStr != "Deterministic" && fSplitExprString != "") {
383 Log() << kFATAL << "SplitExpr can only be used with Deterministic Splitting" << Endl;
384 }
385
386 // Factory options
// The analysis type is lower-cased first. It is compared, and later
// forwarded to both factories, in lower case.
387 fAnalysisTypeStr.ToLower();
388 if (fAnalysisTypeStr == "classification") {
389 fAnalysisType = Types::kClassification;
390 } else if (fAnalysisTypeStr == "regression") {
391 fAnalysisType = Types::kRegression;
392 } else if (fAnalysisTypeStr == "multiclass") {
393 fAnalysisType = Types::kMulticlass;
394 } else if (fAnalysisTypeStr == "auto") {
395 fAnalysisType = Types::kNoAnalysisType;
396 }
397
398 if (fVerbose) {
399 fCvFactoryOptions += "V:";
400 fOutputFactoryOptions += "V:";
401 } else {
402 fCvFactoryOptions += "!V:";
403 fOutputFactoryOptions += "!V:";
404 }
405
406 fCvFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
407 fOutputFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
408
409 fCvFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
410 fOutputFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
411
412 if (!fDrawProgressBar) {
413 fCvFactoryOptions += "!DrawProgressBar:";
414 fOutputFactoryOptions += "!DrawProgressBar:";
415 }
416
417 if (fTransformations != "") {
418 fCvFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
419 fOutputFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
420 }
421
422 if (fCorrelations) {
423 fCvFactoryOptions += "Correlations:";
424 fOutputFactoryOptions += "Correlations:";
425 } else {
426 fCvFactoryOptions += "!Correlations:";
427 fOutputFactoryOptions += "!Correlations:";
428 }
429
430 if (fROC) {
431 fCvFactoryOptions += "ROC:";
432 fOutputFactoryOptions += "ROC:";
433 } else {
434 fCvFactoryOptions += "!ROC:";
435 fOutputFactoryOptions += "!ROC:";
436 }
437
438 if (fSilent) {
439 fCvFactoryOptions += Form("Silent:");
440 fOutputFactoryOptions += Form("Silent:");
441 }
442
443 // CE specific options
// Per-fold output files are derived from the combined output file's
// directory (see ProcessFold), so one must be given.
444 if (fFoldFileOutput && fOutputFile == nullptr) {
445 Log() << kFATAL << "No output file given, cannot generate per fold output." << Endl;
446 }
447
448 // Initialisations
449
450 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
451
452 // The output factory (fFactory) should always have !ModelPersistence set since we use a custom code path for this.
453 // In this case we create a special method (MethodCrossValidation) that can only be used by
454 // CrossValidation and the Reader.
// NOTE(review): no visible line in this function adds "!ModelPersistence"
// to fOutputFactoryOptions; confirm where that option is actually set.
455 if (fOutputFile == nullptr) {
456 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
457 } else {
458 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
459 }
460
// Splitter selection: Random -> third ctor argument kFALSE, RandomStratified
// -> kTRUE. Deterministic uses the two-argument form, i.e. CvSplitKFolds'
// default for the third argument.
461 if(fSplitTypeStr == "Random"){
462 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kFALSE));
463 } else if(fSplitTypeStr == "RandomStratified"){
464 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kTRUE));
465 } else {
466 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString));
467 }
468
469}
470
471//_______________________________________________________________________
473{
474 if (i != fNumFolds) {
475 fNumFolds = i;
476 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
477 fDataLoader->MakeKFoldDataSet(*fSplit);
478 fFoldStatus = kTRUE;
479 }
480}
481
482////////////////////////////////////////////////////////////////////////////////
483///
484
486{
487 if (splitExpr != fSplitExprString) {
488 fSplitExprString = splitExpr;
489 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
490 fDataLoader->MakeKFoldDataSet(*fSplit);
491 fFoldStatus = kTRUE;
492 }
493}
494
495////////////////////////////////////////////////////////////////////////////////
496/// Evaluates each fold in turn.
497/// - Prepares train and test data sets
498/// - Trains method
499/// - Evaluates on test set
500/// - Stores the evaluation internally
501///
502/// @param iFold fold to evaluate
503///
504
506{
// Trains and evaluates the method described by `methodInfo` on fold `iFold`:
//  - prepares the fold's training data;
//  - books the method on fFoldFactory as "<title>_fold<iFold+1>";
//  - trains, tests and evaluates it;
//  - for classification/multiclass, collects the ROC integral, ROC curve,
//    significance and separation, plus efficiencies for classification;
//  - deletes the fold's results and methods from fFoldFactory.
// NOTE(review): listing lines 528, 530 and 535 are absent from this copy of
// the file. `result`, used and returned below, is not declared in any
// visible line; its declaration presumably sits on one of the missing lines.
507 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
508 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
509 TString methodOptions = methodInfo.GetValue<TString>("MethodOptions");
510 TString foldTitle = methodTitle + TString("_fold") + TString::Format("%i", iFold + 1);
511
512 Log() << kDEBUG << "Processing " << methodTitle << " fold " << iFold << Endl;
513
514 // Only used if fFoldFileOutput == true
515 TFile *foldOutputFile = nullptr;
516
// Per-fold output: "<dir of combined output>/<foldTitle>.root", written by a
// fresh fold factory that replaces fFoldFactory.
// NOTE(review): the TFile::Open result is passed on without a null check.
517 if (fFoldFileOutput && fOutputFile != nullptr) {
518 TString path = gSystem->GetDirName(fOutputFile->GetName()) + "/" + foldTitle + ".root";
519 foldOutputFile = TFile::Open(path, "RECREATE");
520 Log() << kINFO << "Creating fold output at:" << path << Endl;
521 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
522 }
523
524 fDataLoader->PrepareFoldDataSet(*fSplit, iFold, TMVA::Types::kTraining);
525 MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
526
527 // Train method (train method and eval train set)
529 smethod->TrainMethod();
531
532 fFoldFactory->TestAllMethods();
533 fFoldFactory->EvaluateAllMethods();
534
536
537 // Results for aggregation (ROC integral, efficiencies etc.)
// NOTE(review): with AnalysisType=Auto, ParseOptions() sets fAnalysisType to
// kNoAnalysisType, so this whole block is skipped unless the type is
// resolved elsewhere. Confirm.
538 if (fAnalysisType == Types::kClassification || fAnalysisType == Types::kMulticlass) {
539 result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
540
// NOTE(review): `gr` is dereferenced without a null check and is never
// deleted after being copied into result.fROC. If Factory::GetROCCurve
// returns a newly allocated graph (confirm), this leaks one TGraph per fold.
541 TGraph *gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle, true);
542 gr->SetLineColor(iFold + 1);
543 gr->SetLineWidth(2);
544 gr->SetTitle(foldTitle.Data());
545 result.fROC = *gr;
546
547 result.fSig = smethod->GetSignificance();
548 result.fSep = smethod->GetSeparation();
549
550 if (fAnalysisType == Types::kClassification) {
// `err` receives the error estimate from GetEfficiency; it is not stored.
551 Double_t err;
552 result.fEff01 = smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err);
553 result.fEff10 = smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err);
554 result.fEff30 = smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err);
555 result.fEffArea = smethod->GetEfficiency("", Types::kTesting, err);
556 result.fTrainEff01 = smethod->GetTrainingEfficiency("Efficiency:0.01");
557 result.fTrainEff10 = smethod->GetTrainingEfficiency("Efficiency:0.10");
558 result.fTrainEff30 = smethod->GetTrainingEfficiency("Efficiency:0.30");
559 } else if (fAnalysisType == Types::kMulticlass) {
560 // Nothing here for now
561 }
562 }
563
564 // Per-fold file output
// NOTE(review): the file is closed here, but the TFile object itself is not
// deleted in this function.
565 if (fFoldFileOutput && foldOutputFile != nullptr) {
566 foldOutputFile->Close();
567 }
568
569 // Clean-up for this fold
570 {
571 smethod->Data()->DeleteAllResults(Types::kTraining, smethod->GetAnalysisType());
572 smethod->Data()->DeleteAllResults(Types::kTesting, smethod->GetAnalysisType());
573 }
574
// Remove this fold's method from fFoldFactory before the next fold is booked.
575 fFoldFactory->DeleteAllMethods();
576 fFoldFactory->fMethodsMap.clear();
577
578 return result;
579}
580
581////////////////////////////////////////////////////////////////////////////////
582/// Does training, test set evaluation and performance evaluation using
583/// cross-validation.
584///
585
587{
// Runs the full cross-validation:
//  1. Creates the K-fold dataset if that has not happened yet.
//  2. For every booked method:
//     - processes all folds, either serially or with a TProcessExecutor;
//     - stores the aggregated CrossValidationResult in fResults;
//     - books a Types::kCrossValidation method on fFactory wrapping the
//       per-fold models.
//  3. Recombines the folds, then trains, tests and evaluates those wrapper
//     methods on fFactory.
// NOTE(review): listing lines 605, 683 and 685 are absent from this copy of
// the file; their statements are not shown here.
588 // Generate K folds on given dataset
589 if (!fFoldStatus) {
590 fDataLoader->MakeKFoldDataSet(*fSplit);
591 fFoldStatus = kTRUE;
592 }
593
594 fResults.reserve(fMethods.size());
595 for (auto & methodInfo : fMethods) {
596 CrossValidationResult result{fNumFolds};
597
598 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
599 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
600
601 if (methodTypeName == "") {
602 Log() << kFATAL << "No method booked for cross-validation" << Endl;
603 }
604
606 Log() << kINFO << Endl;
607 Log() << kINFO << Endl;
608 Log() << kINFO << "========================================" << Endl;
609 Log() << kINFO << "Processing folds for method " << methodTitle << Endl;
610 Log() << kINFO << "========================================" << Endl;
611 Log() << kINFO << Endl;
612
613 // Process K folds
// NumWorkerProcs=1 defers to the global TMVA config's worker count; folds
// run serially only if that is also 1.
614 auto nWorkers = fNumWorkerProcs;
615 if (nWorkers == 1) {
616 // Fall back to global config
617 nWorkers = TMVA::gConfig().GetNumWorkers();
618 }
619 if (nWorkers == 1) {
620 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
621 auto fold_result = ProcessFold(iFold, methodInfo);
622 result.Fill(fold_result);
623 }
624 } else {
// NOTE(review): when compiled with MSVC this branch is empty, so with
// nWorkers != 1 no fold is processed and an unfilled `result` is pushed
// below. A serial fallback is probably intended.
625#ifndef _MSC_VER
626 ROOT::TProcessExecutor workers(nWorkers);
627 std::vector<CrossValidationFoldResult> result_vector;
628
// methodInfo is captured by copy for the worker processes.
629 auto workItem = [this, methodInfo](UInt_t iFold) {
630 return ProcessFold(iFold, methodInfo);
631 };
632
633 result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
634
635 for (auto && fold_result : result_vector) {
636 result.Fill(fold_result);
637 }
638#endif
639 }
640
641 fResults.push_back(result);
642
643 // Serialise the cross evaluated method
644 TString options =
645 Form("SplitExpr=%s:NumFolds=%i"
646 ":EncapsulatedMethodName=%s"
647 ":EncapsulatedMethodTypeName=%s"
648 ":OutputEnsembling=%s",
649 fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.Data(), fOutputEnsembling.Data());
650
651 fFactory->BookMethod(fDataLoader.get(), Types::kCrossValidation, methodTitle, options);
652
653 // Feed EventToFold mapping used when random fold assignments are used
654 // (when splitExpr="").
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
655 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
656 auto *method = dynamic_cast<MethodCrossValidation *>(method_interface);
657
658 method->fEventToFoldMapping = fSplit->fEventToFoldMapping;
659 }
660
661 Log() << kINFO << Endl;
662 Log() << kINFO << Endl;
663 Log() << kINFO << "========================================" << Endl;
664 Log() << kINFO << "Folds processed for all methods, evaluating." << Endl;
665 Log() << kINFO << "========================================" << Endl;
666 Log() << kINFO << Endl;
667
668 // Recombination of data (making sure there is data in training and testing trees).
669 fDataLoader->RecombineKFoldDataSet(*fSplit);
670
671 // "Eval" on training set
672 for (auto & methodInfo : fMethods) {
673 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
674 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
675
// NOTE(review): as above, `method` is used without a null check.
676 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
677 auto method = dynamic_cast<MethodCrossValidation *>(method_interface);
678
679 if (fOutputFile != nullptr) {
680 fFactory->WriteDataInformation(method->fDataSetInfo);
681 }
682
684 method->TrainMethod();
686 }
687
688 // Eval on Testing set
689 fFactory->TestAllMethods();
690
691 // Calc statistics
692 fFactory->EvaluateAllMethods();
693
694 Log() << kINFO << "Evaluation done." << Endl;
695}
696
697//_______________________________________________________________________
698const std::vector<TMVA::CrossValidationResult> &TMVA::CrossValidation::GetResults() const
699{
700 if (fResults.empty()) {
701 Log() << kFATAL << "No cross-validation results available" << Endl;
702 }
703 return fResults;
704}
#define c(i)
Definition RSha256.hxx:101
const Bool_t kFALSE
Definition RtypesCore.h:101
double Double_t
Definition RtypesCore.h:59
float Float_t
Definition RtypesCore.h:57
const Bool_t kTRUE
Definition RtypesCore.h:100
char name[80]
Definition TGX11.cxx:110
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition TSystem.h:559
auto Map(F func, unsigned nTimes) -> std::vector< typename std::result_of< F()>::type >
Execute a function without arguments several times.
This class provides a simple interface to execute the same task multiple times in parallel,...
A pseudo container class which is a generator of indices.
Definition TSeq.hxx:66
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
Definition TAttLine.h:43
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition TAttLine.h:40
The Canvas class.
Definition TCanvas.h:23
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:54
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4025
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:899
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
virtual void SetTitle(const char *title="")
Change (i.e.
Definition TGraph.cxx:2353
virtual Double_t Eval(Double_t x, TSpline *spline=nullptr, Option_t *option="") const
Interpolate points in this graph at x using a TSpline.
Definition TGraph.cxx:887
This class displays a legend box (TPaveText) containing several legend entries.
Definition TLegend.h:23
A doubly linked list.
Definition TList.h:38
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Definition TList.cxx:357
UInt_t GetNumWorkers() const
Definition Config.h:72
void SetSilent(Bool_t s)
Definition Config.h:63
void CheckForUnusedOptions() const
checks for unused options in option string
Class to save the results of cross validation, the metric for the classification ins ROC and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > fEffAreas
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a multigraph that contains an average ROC Curve.
TCanvas * Draw(const TString name="CrossValidation") const
Class to perform cross validation, splitting the dataloader into folds.
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
void SetSplitExpr(TString splitExpr)
void Evaluate()
Does training, test set evaluation and performance evaluation of using cross-evalution.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
Definition DataSet.cxx:343
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Definition Envelope.h:44
virtual void ParseOptions()
Method to parse the internal option string.
Definition Envelope.cxx:182
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition Event.cxx:399
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
virtual Double_t GetSignificance() const
compute significance of mean difference
Types::EAnalysisType GetAnalysisType() const
Definition MethodBase.h:437
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual Double_t GetTrainingEfficiency(const TString &)
DataSet * Data() const
Definition MethodBase.h:409
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
static void EnableOutput()
Definition MsgLogger.cxx:68
class to storage options for the differents methods
Definition OptionMap.h:34
T GetValue(const TString &key)
Definition OptionMap.h:133
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kCrossValidation
Definition Types.h:109
@ kMulticlass
Definition Types.h:129
@ kNoAnalysisType
Definition Types.h:130
@ kClassification
Definition Types.h:127
@ kRegression
Definition Types.h:128
@ kTraining
Definition Types.h:143
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:36
TList * GetListOfGraphs() const
Definition TMultiGraph.h:70
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
TAxis * GetYaxis()
Get y axis of the graph.
TAxis * GetXaxis()
Get x axis of the graph.
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition TNamed.cxx:164
virtual void SetName(const char *name)
Set the name of the TNamed.
Definition TNamed.cxx:140
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
Definition TNamed.cxx:74
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad for instance with: gROOT->SetSelectedPad(gPad...
Definition TObject.cxx:291
Basic string class.
Definition TString.h:136
const char * Data() const
Definition TString.h:369
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2336
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
Definition TSystem.cxx:1032
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
TGraphErrors * gr
Definition legend1.C:25
leg
Definition legend1.C:34
create variable transformations
Config & gConfig()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Double_t Sqrt(Double_t x)
Definition TMath.h:641
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Definition TMath.h:685