ROOT Reference Guide
 
TMVACrossValidationRegression.C File Reference

Detailed Description

This macro provides an example of how to use TMVA for k-fold cross evaluation in a regression task.

The input data is a toy-MC sample consisting of two Gaussian distributions.

The output file "TMVARegCv.root" can be analysed with dedicated macros (run them with: root -l <macro.C>), which can be conveniently invoked through a GUI that appears at the end of the run of this macro. Launch the GUI via the command:

root -l -e 'TMVA::TMVAGui("TMVARegCv.root")'

Cross Evaluation

Cross evaluation is a special case of k-fold cross validation where the splitting into k folds is computed deterministically. This ensures that a given event always ends up in the same fold.
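The fold logic described above can be sketched in plain C++. This is a hypothetical illustration, not TMVA code: FoldSplit and splitForFold are made-up names, and the fold index of every event is assumed to have been computed already by a deterministic split function. For each fold k, the model is trained on all events outside fold k and tested on the events inside it, so across all folds every event is tested exactly once.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of TMVA): training and testing event indices
// for the model associated with one fold.
struct FoldSplit {
    std::vector<std::size_t> train;
    std::vector<std::size_t> test;
};

// foldOfEvent[i] is the (precomputed, deterministic) fold index of event i.
// The model for fold k trains on every event outside fold k and is tested on
// the events inside fold k.
FoldSplit splitForFold(const std::vector<unsigned>& foldOfEvent, unsigned k) {
    FoldSplit split;
    for (std::size_t i = 0; i < foldOfEvent.size(); ++i) {
        if (foldOfEvent[i] == k) {
            split.test.push_back(i);   // held out: used to evaluate model k
        } else {
            split.train.push_back(i);  // used to train model k
        }
    }
    return split;
}
```

Because the fold indices are a deterministic function of each event, calling splitForFold again on the same data reproduces exactly the same partition.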

In addition, all resulting classifiers (one per fold) are saved and can be applied to new data using MethodCrossValidation. This requires a splitting function that is evaluated for each event to determine which fold it goes into (for training and evaluation) or which classifier it is passed to (for application).
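A minimal sketch of the application step, in plain C++ rather than the TMVA API: the event's fold index (assumed to come from the same split function used during training) selects which per-fold model evaluates it. FoldModel and applyCrossEvaluated are hypothetical names, and models[k] is assumed to be the model trained on all folds except fold k, so an event is never evaluated by a model that saw it during training.

```cpp
#include <functional>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a trained per-fold regression model:
// maps the event's input variables to a predicted target value.
using FoldModel = std::function<double(const std::vector<double>&)>;

// Route an event to the model whose held-out fold contains it.
// models[k] is assumed to be trained on every fold except fold k.
double applyCrossEvaluated(const std::vector<FoldModel>& models,
                           unsigned foldIndex,
                           const std::vector<double>& inputs) {
    if (foldIndex >= models.size()) {
        throw std::out_of_range("fold index outside [0, NumFolds)");
    }
    return models[foldIndex](inputs);
}
```

The range check mirrors why the split function must only produce indices between 0 and NumFolds - 1: any other value has no classifier to route to.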

Split Expression

Cross evaluation partitions the data into folds using a deterministic split, called the split expression. The expression can be any valid TFormula as long as every variable or spectator it references is defined.

For each event, the split expression is evaluated to a number, and the event is put into the fold corresponding to that number.

It is recommended to always end the expression with %int([NumFolds]), as in the example below, so that the result is a valid fold index between 0 and NumFolds - 1.

The split expression has access to all spectators and variables defined in the dataloader. Additionally, the number of folds in the split can be accessed with NumFolds (or numFolds).

Example

"int(fabs([eventID]))%int([NumFolds])"
  • Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
  • Package : TMVA
  • Root Macro: TMVACrossValidationRegression
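To make the example split expression int(fabs([eventID]))%int([NumFolds]) concrete, here is a plain C++ equivalent. It is an illustration only: TMVA evaluates the expression as a TFormula, foldFromEventID is a hypothetical name, and eventID stands in for a spectator of that name, which the dataloader would have to define for the real expression to work.

```cpp
#include <cmath>
#include <stdexcept>

// Plain C++ equivalent of "int(fabs([eventID]))%int([NumFolds])".
// eventID is assumed to be a spectator defined in the dataloader.
int foldFromEventID(double eventID, int numFolds) {
    if (numFolds <= 0) {
        throw std::invalid_argument("numFolds must be positive");
    }
    // int(...) truncates toward zero; fabs makes negative IDs usable too.
    const int id = static_cast<int>(std::fabs(eventID));
    // The trailing modulo keeps the result in [0, numFolds - 1].
    return id % numFolds;
}
```

For example, with two folds an event with eventID 7 lands in fold 1 and one with eventID -4 lands in fold 0, every time the expression is evaluated.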
DataSetInfo : [dataset] : Added class "Regression"
: Add Tree TreeR of type Regression with 10000 events
--- TMVACrossValidationRegression: Using input file: ./files/tmva_reg_example.root
: Dataset[dataset] : Class index : 0 name : Regression
<HEADER> Factory : You are running ROOT Version: 6.24/09, Sep 29, 2022
:
:
: ___________TMVA Version 4.2.1, Feb 5, 2015
:
: Building event vectors for type 2 Regression
: Dataset[dataset] : create input formulas for tree TreeR
<HEADER> DataSetFactory : [dataset] : Number of events in input trees
:
: Number of training and testing events
: ---------------------------------------------------------------------------
: Regression -- training events : 9999
: Regression -- testing events : 1
: Regression -- training and testing events: 10000
:
<HEADER> DataSetInfo : Correlation matrix (Regression):
: ------------------------
: var1 var2
: var1: +1.000 +0.002
: var2: +0.002 +1.000
: ------------------------
<HEADER> DataSetFactory : [dataset] :
:
:
:
: ========================================
: ========================================
:
<HEADER> Factory : Booking method: BDTG_fold1
:
: the option NegWeightTreatment=InverseBoostNegWeights does not exist for BoostType=Grad
: --> change to new default NegWeightTreatment=Pray
: Regression Loss Function: Huber
: Training 500 Decision Trees ... patience please
: Elapsed time for training with 4999 events: 1.3 sec
: Dataset[dataset] : Create results for training
: Dataset[dataset] : Evaluation of BDTG_fold1 on training sample
: Dataset[dataset] : Elapsed time for evaluation of 4999 events: 0.206 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
: Creating xml weight file: dataset/weights/TMVACrossValidationRegression_BDTG_fold1.weights.xml
<HEADER> Factory : Test all methods
<HEADER> Factory : Test method: BDTG_fold1 for Regression performance
:
: Dataset[dataset] : Create results for testing
: Dataset[dataset] : Evaluation of BDTG_fold1 on testing sample
: Dataset[dataset] : Elapsed time for evaluation of 5000 events: 0.206 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
<HEADER> Factory : Evaluate all methods
: Evaluate regression method: BDTG_fold1
: TestRegression (testing)
: Calculate regression for all events
: Elapsed time for evaluation of 5000 events: 0.204 sec
: TestRegression (training)
: Calculate regression for all events
: Elapsed time for evaluation of 4999 events: 0.205 sec
:
: Evaluation results ranked by smallest RMS on test sample:
: ("Bias" quotes the mean deviation of the regression from true target.
: "MutInf" is the "Mutual Information" between regression and target.
: Indicated by "_T" are the corresponding "truncated" quantities ob-
: tained when removing events deviating more than 2sigma from average.)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG_fold1 : 0.133 0.0851 2.22 1.67 | 3.123 3.198
: --------------------------------------------------------------------------------------------------
:
: Evaluation results ranked by smallest RMS on training sample:
: (overtraining check)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG_fold1 : 0.0474 -0.00861 2.09 1.52 | 3.136 3.206
: --------------------------------------------------------------------------------------------------
:
<HEADER> Factory : Thank you for using TMVA!
: For citation information, please visit: http://tmva.sf.net/citeTMVA.html
<HEADER> Factory : Booking method: BDTG_fold2
:
: the option NegWeightTreatment=InverseBoostNegWeights does not exist for BoostType=Grad
: --> change to new default NegWeightTreatment=Pray
: Regression Loss Function: Huber
: Training 500 Decision Trees ... patience please
: Elapsed time for training with 5000 events: 1.31 sec
: Dataset[dataset] : Create results for training
: Dataset[dataset] : Evaluation of BDTG_fold2 on training sample
: Dataset[dataset] : Elapsed time for evaluation of 5000 events: 0.208 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
: Creating xml weight file: dataset/weights/TMVACrossValidationRegression_BDTG_fold2.weights.xml
<HEADER> Factory : Test all methods
<HEADER> Factory : Test method: BDTG_fold2 for Regression performance
:
: Dataset[dataset] : Create results for testing
: Dataset[dataset] : Evaluation of BDTG_fold2 on testing sample
: Dataset[dataset] : Elapsed time for evaluation of 4999 events: 0.208 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
<HEADER> Factory : Evaluate all methods
: Evaluate regression method: BDTG_fold2
: TestRegression (testing)
: Calculate regression for all events
: Elapsed time for evaluation of 4999 events: 0.207 sec
: TestRegression (training)
: Calculate regression for all events
: Elapsed time for evaluation of 5000 events: 0.207 sec
:
: Evaluation results ranked by smallest RMS on test sample:
: ("Bias" quotes the mean deviation of the regression from true target.
: "MutInf" is the "Mutual Information" between regression and target.
: Indicated by "_T" are the corresponding "truncated" quantities ob-
: tained when removing events deviating more than 2sigma from average.)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG_fold2 : -0.0428 -0.0362 2.33 1.72 | 3.109 3.188
: --------------------------------------------------------------------------------------------------
:
: Evaluation results ranked by smallest RMS on training sample:
: (overtraining check)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG_fold2 : 0.00417 0.0137 2.05 1.51 | 3.145 3.215
: --------------------------------------------------------------------------------------------------
:
<HEADER> Factory : Thank you for using TMVA!
: For citation information, please visit: http://tmva.sf.net/citeTMVA.html
<HEADER> Factory : Booking method: BDTG
:
: Reading weightfile: dataset/weights/TMVACrossValidationRegression_BDTG_fold1.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidationRegression_BDTG_fold1.weights.xml
: Reading weightfile: dataset/weights/TMVACrossValidationRegression_BDTG_fold2.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidationRegression_BDTG_fold2.weights.xml
:
:
: ========================================
: Folds processed for all methods, evaluating.
: ========================================
:
<HEADER> Factory : [dataset] : Create Transformation "I" with events from all classes.
:
<HEADER> : Transformation, Variable selection :
: Input : variable 'var1' <---> Output : variable 'var1'
: Input : variable 'var2' <---> Output : variable 'var2'
<HEADER> TFHandler_Factory : Variable Mean RMS [ Min Max ]
: -----------------------------------------------------------
: var1: 2.4948 1.4515 [ 0.00020069 5.0000 ]
: var2: 2.4837 1.4409 [ 0.00071490 5.0000 ]
: fvalue: 134.53 84.778 [ 1.6186 394.84 ]
: -----------------------------------------------------------
: Ranking input variables (method unspecific)...
<HEADER> IdTransformation : Ranking result (top variable is best ranked)
: --------------------------------------------
: Rank : Variable : |Correlation with target|
: --------------------------------------------
: 1 : var2 : 7.607e-01
: 2 : var1 : 5.995e-01
: --------------------------------------------
<HEADER> IdTransformation : Ranking result (top variable is best ranked)
: -------------------------------------
: Rank : Variable : Mutual information
: -------------------------------------
: 1 : var1 : 2.253e+00
: 2 : var2 : 2.100e+00
: -------------------------------------
<HEADER> IdTransformation : Ranking result (top variable is best ranked)
: ------------------------------------
: Rank : Variable : Correlation Ratio
: ------------------------------------
: 1 : var2 : 2.458e+00
: 2 : var1 : 2.336e+00
: ------------------------------------
<HEADER> IdTransformation : Ranking result (top variable is best ranked)
: ----------------------------------------
: Rank : Variable : Correlation Ratio (T)
: ----------------------------------------
: 1 : var1 : 5.362e-01
: 2 : var2 : 5.109e-01
: ----------------------------------------
: Elapsed time for training with 9999 events: 5.01e-06 sec
: Dataset[dataset] : Create results for training
: Dataset[dataset] : Evaluation of BDTG on training sample
: Dataset[dataset] : Elapsed time for evaluation of 9999 events: 0.37 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
: Creating xml weight file: dataset/weights/TMVACrossValidationRegression_BDTG.weights.xml
<HEADER> Factory : Test all methods
<HEADER> Factory : Test method: BDTG for Regression performance
:
: Dataset[dataset] : Create results for testing
: Dataset[dataset] : Evaluation of BDTG on testing sample
: Dataset[dataset] : Elapsed time for evaluation of 9999 events: 0.369 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
<HEADER> Factory : Evaluate all methods
: Evaluate regression method: BDTG
: TestRegression (testing)
: Calculate regression for all events
: Elapsed time for evaluation of 9999 events: 0.369 sec
: TestRegression (training)
: Calculate regression for all events
: Elapsed time for evaluation of 9999 events: 0.371 sec
<HEADER> TFHandler_BDTG : Variable Mean RMS [ Min Max ]
: -----------------------------------------------------------
: var1: 2.4948 1.4515 [ 0.00020069 5.0000 ]
: var2: 2.4837 1.4409 [ 0.00071490 5.0000 ]
: fvalue: 134.53 84.778 [ 1.6186 394.84 ]
: -----------------------------------------------------------
:
: Evaluation results ranked by smallest RMS on test sample:
: ("Bias" quotes the mean deviation of the regression from true target.
: "MutInf" is the "Mutual Information" between regression and target.
: Indicated by "_T" are the corresponding "truncated" quantities ob-
: tained when removing events deviating more than 2sigma from average.)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG : 0.0449 0.0259 2.28 1.70 | 3.108 3.190
: --------------------------------------------------------------------------------------------------
:
: Evaluation results ranked by smallest RMS on training sample:
: (overtraining check)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG : 0.0449 0.0259 2.28 1.70 | 3.108 3.190
: --------------------------------------------------------------------------------------------------
:
<HEADER> Dataset:dataset : Created tree 'TestTree' with 9999 events
:
<HEADER> Dataset:dataset : Created tree 'TrainTree' with 9999 events
:
<HEADER> Factory : Thank you for using TMVA!
: For citation information, please visit: http://tmva.sf.net/citeTMVA.html
: Evaluation done.
==> Wrote root file: TMVARegCv.root
==> TMVACrossValidationRegression is done!
(int) 0
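The evaluation tables above report the bias (mean deviation of the regression from the true target), the RMS, and "truncated" (_T) versions computed after removing events that deviate more than 2 sigma from the average. The sketch below reproduces that idea in plain C++. It is an assumption-laden illustration, not TMVA's implementation: it uses unit event weights, takes RMS as the standard deviation of d = prediction - target around the bias, and the names DeviationStats, deviationStats and truncatedDeviationStats are hypothetical.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Summary of the regression deviation d = prediction - target.
struct DeviationStats {
    double bias;  // mean of d
    double rms;   // spread of d around the bias (standard deviation)
};

// Unweighted bias and RMS of d over all events (two passes for stability).
DeviationStats deviationStats(const std::vector<double>& pred,
                              const std::vector<double>& target) {
    if (pred.empty() || pred.size() != target.size()) {
        throw std::invalid_argument("need equally sized, non-empty inputs");
    }
    const double n = static_cast<double>(pred.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < pred.size(); ++i) {
        sum += pred[i] - target[i];
    }
    const double bias = sum / n;
    double sumSq = 0.0;
    for (std::size_t i = 0; i < pred.size(); ++i) {
        const double r = (pred[i] - target[i]) - bias;
        sumSq += r * r;
    }
    return {bias, std::sqrt(sumSq / n)};
}

// Truncated (_T) variant: drop events whose deviation lies more than
// 2 sigma (2 * rms) from the bias, then recompute on the remaining events.
DeviationStats truncatedDeviationStats(const std::vector<double>& pred,
                                       const std::vector<double>& target) {
    const DeviationStats full = deviationStats(pred, target);
    std::vector<double> keptPred;
    std::vector<double> keptTarget;
    for (std::size_t i = 0; i < pred.size(); ++i) {
        const double d = pred[i] - target[i];
        if (std::fabs(d - full.bias) <= 2.0 * full.rms) {
            keptPred.push_back(pred[i]);
            keptTarget.push_back(target[i]);
        }
    }
    return deviationStats(keptPred, keptTarget);
}
```

With nine perfect predictions and one outlier off by 10, the full bias is 1 and the RMS is 3; the outlier lies 9 units from the bias (more than 2 x 3), so the truncated bias and RMS both drop to 0. This is why the _T columns in the tables are smaller than their untruncated counterparts.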
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>
#include "TChain.h"
#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TObjString.h"
#include "TSystem.h"
#include "TROOT.h"
#include "TCut.h"
#include "TMVA/CrossValidation.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Factory.h"
#include "TMVA/Tools.h"
#include "TMVA/TMVAGui.h"
TFile * getDataFile(TString fname) {
   TFile *input = nullptr;
   // Check if the file exists in the local directory
   if (!gSystem->AccessPathName(fname)) {
      input = TFile::Open(fname);
   } else {
      // If not: download it from the ROOT server and cache it locally
      TFile::SetCacheFileDir(".");
      input = TFile::Open("http://root.cern.ch/files/tmva_reg_example.root", "CACHEREAD");
   }
   if (!input) {
      std::cout << "ERROR: could not open data file " << fname << std::endl;
      exit(1);
   }
   return input;
}
int TMVACrossValidationRegression()
{
   // This loads the library
   TMVA::Tools::Instance();

   // --------------------------------------------------------------------------

   // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
   TString outfileName("TMVARegCv.root");
   TFile * outputFile = TFile::Open(outfileName, "RECREATE");

   TString infileName("./files/tmva_reg_example.root");
   TFile * inputFile = getDataFile(infileName);

   TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");

   dataloader->AddVariable("var1", "Variable 1", "units", 'F');
   dataloader->AddVariable("var2", "Variable 2", "units", 'F');

   // Add the variable carrying the regression target
   dataloader->AddTarget("fvalue");

   TTree * regTree = (TTree*)inputFile->Get("TreeR");
   dataloader->AddRegressionTree(regTree, 1.0);

   // Individual events can be weighted
   // dataloader->SetWeightExpression("weight", "Regression");

   std::cout << "--- TMVACrossValidationRegression: Using input file: " << inputFile->GetName() << std::endl;

   // This bypasses the normal splitting mechanism; CV uses its own system for
   // this. Unfortunately the old system is unhappy if we leave the test set
   // empty, so we ensure that there is at least one event by placing the
   // first event in it.
   // With the selection cut you can place a global cut on the defined
   // variables. Only events passing the cut will be used in training/testing.
   // Example: `TCut selectionCut = "var1 < 1";`
   TCut selectionCut = "";
   dataloader->PrepareTrainingAndTestTree(selectionCut, "nTest_Regression=1"
                                                        ":SplitMode=Block"
                                                        ":NormMode=NumEvents"
                                                        ":!V");

   // --------------------------------------------------------------------------
   //
   // This sets up a CrossValidation class (which wraps a TMVA::Factory
   // internally) for 2-fold cross validation. The data will be split into the
   // two folds randomly if `splitExpr` is `""`.
   //
   // One can also give a deterministic split using spectator variables, for
   // example `"int(fabs([spec1]))%int([NumFolds])"`.
   //
   UInt_t numFolds = 2;
   TString analysisType = "Regression";
   TString splitExpr = "";

   TString cvOptions = Form("!V"
                            ":!Silent"
                            ":ModelPersistence"
                            ":!FoldFileOutput"
                            ":AnalysisType=%s"
                            ":NumFolds=%i"
                            ":SplitExpr=%s",
                            analysisType.Data(), numFolds, splitExpr.Data());

   TMVA::CrossValidation cv{"TMVACrossValidationRegression", dataloader, outputFile, cvOptions};

   // --------------------------------------------------------------------------
   //
   // Books a method to use for evaluation
   //
   cv.BookMethod(TMVA::Types::kBDT, "BDTG",
                 "!H:!V:NTrees=500:BoostType=Grad:Shrinkage=0.1:"
                 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=3");

   // --------------------------------------------------------------------------
   //
   // Train, test and evaluate the booked methods.
   // Evaluates the booked methods once for each fold and aggregates the result
   // in the specified output file.
   //
   cv.Evaluate();

   // --------------------------------------------------------------------------
   //
   // Save the output
   //
   outputFile->Close();

   std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
   std::cout << "==> TMVACrossValidationRegression is done!" << std::endl;

   // --------------------------------------------------------------------------
   //
   // Launch the GUI for the root macros
   //
   if (!gROOT->IsBatch()) {
      TMVA::TMVAGui(outfileName);
   }

   return 0;
}
//
// This is used if the macro is compiled. If run through ROOT with
// `root -l -b -q MACRO.C` or similar it is unused.
//
int main(int argc, char **argv)
{
   return TMVACrossValidationRegression();
}
Author
Kim Albertsson (adapted from code originally by Andreas Hoecker)

Definition in file TMVACrossValidationRegression.C.