42 fSigs.resize(numFolds);
43 fSeps.resize(numFolds);
108 Double_t increment = 1.0 / (numSamples-1);
109 std::vector<Double_t>
x(numSamples),
y(numSamples);
113 for(
UInt_t iSample = 0; iSample < numSamples; iSample++) {
114 Double_t xPoint = iSample * increment;
117 for(
Int_t iGraph = 0; iGraph < rocCurveList->
GetSize(); iGraph++) {
118 TGraph *foldROC =
static_cast<TGraph *
>(rocCurveList->
At(iGraph));
119 rocSum += foldROC->
Eval(xPoint);
123 y[iSample] = rocSum/rocCurveList->
GetSize();
126 return new TGraph(numSamples, &
x[0], &
y[0]);
133 for(
auto &roc :
fROCs) {
136 return avg/
fROCs.size();
145 for(
auto &roc :
fROCs) {
158 fLogger << kHEADER <<
" ==== Results ====" <<
Endl;
159 for(
auto &item:
fROCs) {
160 fLogger << kINFO <<
TString::Format(
"Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
163 fLogger << kINFO <<
"------------------------" <<
Endl;
175 fROCCurves->GetXaxis()->SetTitle(
" Signal Efficiency ");
176 fROCCurves->GetYaxis()->SetTitle(
" Background Rejection ");
178 c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
179 c->SetTitle(
"Cross Validation ROC Curves");
193 for (
auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
194 TGraph * foldRocGraph =
dynamic_cast<TGraph *
>(foldRocObj->Clone());
197 rocs->
Add(foldRocGraph);
203 avgRocGraph->
SetTitle(
"Avg ROC Curve");
206 rocs->
Add(avgRocGraph);
212 title =
"Cross Validation Average ROC Curve";
227 leg->AddEntry(
static_cast<TGraph *
>(ROCCurveList->
At(nCurves-1)),
228 "Avg ROC Curve",
"l");
229 leg->AddEntry(
static_cast<TGraph *
>(ROCCurveList->
At(0)),
230 "Fold ROC Curves",
"l");
237 c->SetTitle(
"Cross Validation Average ROC Curve");
279 :
TMVA::
Envelope(jobName, dataloader, nullptr, options),
325 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
326 "class object (default: False)");
334 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
335 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
344 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
352 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
360 "Determines how many processes to use for evaluation. 1 means no"
361 " parallelisation. 2 means use 2 processes. 0 means figure out the"
362 " number automatically based on the number of cpus available. Default"
366 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
367 "specifed for the combined output with a _foldX suffix. (default: false)");
370 "Combines output from contained methods. If None, no combination is performed. (default None)");
383 Log() << kFATAL <<
"SplitExpr can only be used with Deterministic Splitting" <<
Endl;
445 Log() << kFATAL <<
"No output file given, cannot generate per fold output." <<
Endl;
513 Log() << kDEBUG <<
"Processing " << methodTitle <<
" fold " << iFold <<
Endl;
516 TFile *foldOutputFile =
nullptr;
521 Log() << kINFO <<
"Creating fold output at:" << path <<
Endl;
543 gr->SetLineColor(iFold + 1);
545 gr->SetTitle(foldTitle.Data());
567 foldOutputFile->
Close();
596 for (
auto & methodInfo :
fMethods) {
599 TString methodTypeName = methodInfo.GetValue<
TString>(
"MethodName");
600 TString methodTitle = methodInfo.GetValue<
TString>(
"MethodTitle");
602 if (methodTypeName ==
"") {
603 Log() << kFATAL <<
"No method booked for cross-validation" <<
Endl;
609 Log() << kINFO <<
"========================================" <<
Endl;
610 Log() << kINFO <<
"Processing folds for method " << methodTitle <<
Endl;
611 Log() << kINFO <<
"========================================" <<
Endl;
623 result.
Fill(fold_result);
628 std::vector<CrossValidationFoldResult> result_vector;
630 auto workItem = [
this, methodInfo](
UInt_t iFold) {
636 for (
auto && fold_result : result_vector) {
637 result.
Fill(fold_result);
647 ":EncapsulatedMethodName=%s"
648 ":EncapsulatedMethodTypeName=%s"
649 ":OutputEnsembling=%s",
664 Log() << kINFO <<
"========================================" <<
Endl;
665 Log() << kINFO <<
"Folds processed for all methods, evaluating." <<
Endl;
666 Log() << kINFO <<
"========================================" <<
Endl;
673 for (
auto & methodInfo :
fMethods) {
674 TString methodTypeName = methodInfo.GetValue<
TString>(
"MethodName");
675 TString methodTitle = methodInfo.GetValue<
TString>(
"MethodTitle");
681 fFactory->WriteDataInformation(method->fDataSetInfo);
685 method->TrainMethod();
695 Log() << kINFO <<
"Evaluation done." <<
Endl;
702 Log() << kFATAL <<
"No cross-validation results available" <<
Endl;
int Int_t
Signed integer 4 bytes (int).
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int).
bool Bool_t
Boolean (0=false, 1=true) (bool).
double Double_t
Double 8 bytes.
float Float_t
Float 4 bytes (float).
auto Map(F func, unsigned nTimes) -> std::vector< InvokeResult_t< F > >
Execute a function without arguments several times.
This class provides a simple interface to execute the same task multiple times in parallel,...
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
virtual void SetLineColor(Color_t lcolor)
Set the line color.
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
A file, usually with extension .root, that stores data and code in the form of serialized objects in ...
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
void Close(Option_t *option="") override
Close a file.
virtual Double_t Eval(Double_t x, TSpline *spline=nullptr, Option_t *option="") const
void SetTitle(const char *title="") override
Set the title of the TNamed.
TObject * At(Int_t idx) const override
Returns the object at position idx. Returns 0 if idx is out of range.
UInt_t GetNumWorkers() const
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void AddPreDefVal(const T &)
void CheckForUnusedOptions() const
checks for unused options in option string
Class to save the results of cross validation; the metric for the classification is ROC and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
Float_t GetROCAverage() const
std::vector< Double_t > fEffAreas
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a multigraph that contains an average ROC Curve.
TCanvas * Draw(const TString name="CrossValidation") const
void SetNumFolds(UInt_t i)
const std::vector< CrossValidationResult > & GetResults() const
std::vector< CrossValidationResult > fResults
!
std::unique_ptr< Factory > fFoldFactory
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
Bool_t fFoldStatus
! If true: dataset is prepared
std::unique_ptr< CvSplitKFolds > fSplit
void ParseOptions() override
options parser
Types::EAnalysisType fAnalysisType
Bool_t fFoldFileOutput
! If true: generate output file for each fold
TString fCvFactoryOptions
void SetSplitExpr(TString splitExpr)
TString fOutputFactoryOptions
std::unique_ptr< Factory > fFactory
void Evaluate() override
Does training, test set evaluation and performance evaluation using cross-validation.
UInt_t fNumFolds
! Number of folds to prepare
TString fOutputEnsembling
! How to combine output of individual folds
UInt_t fNumWorkerProcs
! Number of processes to use for fold evaluation. (Default, no parallel evaluation)
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
void ParseOptions() override
Method to parse the internal option string.
std::vector< OptionMap > fMethods
! Booked method information
std::shared_ptr< DataLoader > fDataLoader
! data
Envelope(const TString &name, DataLoader *dataloader=nullptr, TFile *file=nullptr, const TString options="")
Constructor for the initialization of Envelopes; different Envelopes may need different constructo...
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Interface for all concrete MVA method implementations.
Virtual base Class for all MVA method.
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
virtual Double_t GetSignificance() const
compute significance of mean difference
Types::EAnalysisType GetAnalysisType() const
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual Double_t GetTrainingEfficiency(const TString &)
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
ostringstream derivative to redirect and format output
static void EnableOutput()
Class to store options for the different methods.
T GetValue(const TString &key)
Singleton class for Global types used by TMVA.
virtual void Add(TGraph *graph, Option_t *chopt="")
TList * GetListOfGraphs() const
TObject * Clone(const char *newname="") const override
Make a clone of an object using the Streamer facility.
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
virtual void SetName(const char *name)
Set the name of the TNamed.
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad with: gROOT->SetSelectedPad(c1).
const char * Data() const
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Double_t Sqrt(Double_t x)
Returns the square root of x.
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Returns x raised to the power y.