This is an example of using an RNN in TMVA. We perform classification on a toy time-dependent data set that is generated when running this example macro.
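The macro can be run directly from the command line; for example (the second argument selects the recurrent layer type, 0 = RNN, 1 = LSTM, 2 = GRU):

   root -l 'TMVA_RNN_Classification.C(2000, 1)'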
void MakeTimeData(int n, int ntime, int ndim)
{
   // output file name encodes the number of time steps and dimensions
   TString fname = TString::Format("time_data_t%d_d%d.root", ntime, ndim);
   std::vector<TH1 *> v1(ntime);
   std::vector<TH1 *> v2(ntime);
   for (int i = 0; i < ntime; ++i) {
      v1[i] = new TH1D(TString::Format("h1_%d", i), "h1", ndim, 0, 10);
      v2[i] = new TH1D(TString::Format("h2_%d", i), "h2", ndim, 0, 10);
   }
   auto f1 = new TF1("f1", "gaus");
   auto f2 = new TF1("f2", "gaus");

   // output file and the trees holding the signal and background events
   TFile f(fname, "RECREATE");
   TTree sgn("sgn", "sgn");
   TTree bkg("bkg", "bkg");

   std::vector<std::vector<float>> x1(ntime);
   std::vector<std::vector<float>> x2(ntime);
   for (int i = 0; i < ntime; ++i) {
      x1[i] = std::vector<float>(ndim);
      x2[i] = std::vector<float>(ndim);
   }
   // one branch (a vector<float> of size ndim) per time step
   for (auto i = 0; i < ntime; i++) {
      bkg.Branch(Form("vars_time%d", i), "std::vector<float>", &x1[i]);
      sgn.Branch(Form("vars_time%d", i), "std::vector<float>", &x2[i]);
   }
   // time-dependent means and widths of the two Gaussians (sin vs. cos modulation)
   std::vector<double> mean1(ntime);
   std::vector<double> mean2(ntime);
   std::vector<double> sigma1(ntime);
   std::vector<double> sigma2(ntime);
   for (int j = 0; j < ntime; ++j) {
      mean1[j] = 5. + 0.2 * sin(TMath::Pi() * j / double(ntime));
      mean2[j] = 5. + 0.2 * cos(TMath::Pi() * j / double(ntime));
      sigma1[j] = 4 + 0.3 * sin(TMath::Pi() * j / double(ntime));
      sigma2[j] = 4 + 0.3 * cos(TMath::Pi() * j / double(ntime));
   }
   for (int i = 0; i < n; ++i) {
if (i % 1000 == 0)
std::cout << "Generating event ... " << i << std::endl;
      for (int j = 0; j < ntime; ++j) {
         auto h1 = v1[j];
         auto h2 = v2[j];
         h1->Reset();
         h2->Reset();
         f1->SetParameters(1, mean1[j], sigma1[j]);
         f2->SetParameters(1, mean2[j], sigma2[j]);
         h1->FillRandom("f1", 1000);
h2->FillRandom("f2", 1000);
         for (int k = 0; k < ndim; ++k) {
            // smear each bin content with Gaussian noise
            x1[j][k] = h1->GetBinContent(k + 1) + gRandom->Gaus(0, 10);
            x2[j][k] = h2->GetBinContent(k + 1) + gRandom->Gaus(0, 10);
         }
}
sgn.Fill();
bkg.Fill();
   } // end event loop
   sgn.Write();
   bkg.Write();
   sgn.Print();
   bkg.Print();
   f.Close();
}
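// For a quick sanity check of the generated toy data, the trees can be
// inspected interactively; a minimal sketch, assuming the default file
// produced with ntime = 10 and ndim = 30:
//
//    auto f = TFile::Open("time_data_t10_d30.root");
//    auto sgn = f->Get<TTree>("sgn");
//    sgn->Print();            // branches vars_time0 ... vars_time9
//    sgn->Draw("vars_time0"); // distribution of the first time step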
void TMVA_RNN_Classification(int nevts = 2000, int use_type = 1)
{
const int ninput = 30;
const int ntime = 10;
const int batchSize = 100;
const int maxepochs = 20;
int nTotEvts = nevts;
bool useKeras = true;
bool useTMVA_RNN = true;
bool useTMVA_DNN = true;
bool useTMVA_BDT = false;
std::vector<std::string> rnn_types = {"RNN", "LSTM", "GRU"};
std::vector<bool> use_rnn_type = {1, 1, 1};
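   // use_type selects a single recurrent layer type (0 = RNN, 1 = LSTM, 2 = GRU,
   // following the order of rnn_types above); any other value keeps all three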
   if (use_type >= 0 && use_type < 3) {
use_rnn_type = {0,0,0};
use_rnn_type[use_type] = 1;
}
bool useGPU = true;
#ifndef R__HAS_TMVAGPU
useGPU = false;
#ifndef R__HAS_TMVACPU
   Warning("TMVA_RNN_Classification",
           "TMVA is not built with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for RNN");
useTMVA_RNN = false;
#endif
#endif
   TString archString = (useGPU) ? "GPU" : "CPU";
bool writeOutputFile = true;
const char *rnn_type = "RNN";
#ifdef R__HAS_PYMVA
   // initialize the Python interpreter needed by the PyKeras method
   TMVA::PyMethodBase::PyInitialize();
#else
   useKeras = false;
#endif
#ifdef R__USE_IMT
   int num_threads = 4; // use at most 4 threads
   // switch off multi-threading in OpenBLAS to avoid conflicts with tbb
   gSystem->Setenv("OMP_NUM_THREADS", "1");
   // enable ROOT implicit multi-threading
   if (num_threads >= 0) {
      ROOT::EnableImplicitMT(num_threads);
   }
#endif
   TString inputFileName = "time_data_t10_d30.root";
   bool fileExist = !gSystem->AccessPathName(inputFileName);
   // if the input file does not exist, generate the toy data first
   if (!fileExist) {
      MakeTimeData(nTotEvts, ntime, ninput);
   }
   auto inputFile = TFile::Open(inputFileName);
   if (!inputFile) {
      Error("TMVA_RNN_Classification", "Error opening input file %s - exit", inputFileName.Data());
      return;
   }
std::cout << "--- RNNClassification : Using input file: " << inputFile->GetName() << std::endl;
   // create a ROOT output file where TMVA will store ntuples, histograms, etc.
   TString outfileName = TString::Format("data_RNN_%s.root", archString.Data());
   TFile *outputFile = nullptr;
   if (writeOutputFile)
      outputFile = TFile::Open(outfileName, "RECREATE");

   // create the TMVA factory and data loader
   TMVA::Factory *factory = new TMVA::Factory("TMVAClassification", outputFile,
                                              "!V:!Silent:Color:DrawProgressBar:Transformations=None:!Correlations:"
                                              "AnalysisType=Classification:ModelPersistence");
   TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
   TTree *signalTree = (TTree *)inputFile->Get("sgn");
   TTree *background = (TTree *)inputFile->Get("bkg");
   const int nvar = ninput * ntime;
   // register the input variables as one array of size ninput per time step
   for (auto i = 0; i < ntime; i++) {
      dataloader->AddVariablesArray(Form("vars_time%d", i), ninput);
   }
   dataloader->AddSignalTree(signalTree, 1.0);
   dataloader->AddBackgroundTree(background, 1.0);
   auto vars = dataloader->GetDataSetInfo().GetListOfVariables();
   std::cout << "number of variables is " << vars.size() << std::endl;
   std::cout << std::endl;
int nTrainSig = 0.8 * nTotEvts;
int nTrainBkg = 0.8 * nTotEvts;
   // options for DataLoader::PrepareTrainingAndTestTree
   TString prepareOptions = TString::Format(
      "nTrain_Signal=%d:nTrain_Background=%d:SplitMode=Random:SplitSeed=100:NormMode=NumEvents:!V:!CalcCorrelations",
      nTrainSig, nTrainBkg);
   dataloader->PrepareTrainingAndTestTree(TCut(""), prepareOptions);
std::cout << "prepared DATA LOADER " << std::endl;
   if (useTMVA_RNN) {
      for (int i = 0; i < 3; ++i) {
         if (!use_rnn_type[i])
            continue;

         const char *rnn_type = rnn_types[i].c_str();

         // input layout for the RNN: time steps x input dimension
         TString inputLayoutString = TString::Format("InputLayout=%d|%d", ntime, ninput);

         // RNN layer layout: type | n. units | n. inputs | time steps | remember state (0) | return full sequence (1)
         TString rnnLayout = TString::Format("%s|10|%d|%d|0|1", rnn_type, ninput, ntime);

         // after the RNN, flatten the output and add a dense layer; the last layer is
         // linear because a sigmoid is applied automatically when using cross-entropy
         TString layoutString = TString("Layout=") + rnnLayout + TString(",RESHAPE|FLAT,DENSE|64|TANH,LINEAR");

         // training strategy (several strategies can be concatenated with '|')
         TString trainingString1 = TString::Format("LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
                                                   "ConvergenceSteps=5,BatchSize=%d,TestRepetitions=1,"
                                                   "WeightDecay=1e-2,Regularization=None,MaxEpochs=%d,"
                                                   "Optimizer=ADAM,DropConfig=0.0+0.+0.+0.",
                                                   batchSize, maxepochs);
         TString trainingStrategyString("TrainingStrategy=");
         trainingStrategyString += trainingString1;

         // define the full option string for the network
         TString rnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=None:"
                            "WeightInitialization=XAVIERUNIFORM:ValidationSize=0.2:RandomSeed=1234");
         rnnOptions.Append(":");
         rnnOptions.Append(inputLayoutString);
         rnnOptions.Append(":");
         rnnOptions.Append(layoutString);
         rnnOptions.Append(":");
         rnnOptions.Append(trainingStrategyString);
         rnnOptions.Append(":");
         rnnOptions.Append(TString::Format("Architecture=%s", archString.Data()));

         TString rnnName = "TMVA_" + TString(rnn_type);
         factory->BookMethod(dataloader, TMVA::Types::kDL, rnnName, rnnOptions);
      }
}
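   // As an illustration, for the LSTM case with ntime = 10 and ninput = 30 the
   // layout string built above expands to
   //    Layout=LSTM|10|30|10|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR
   // i.e. an LSTM with 10 units unrolled over 10 time steps returning the full
   // output sequence, followed by a flattening reshape and a dense tanh layer.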
   if (useTMVA_DNN) {
      // fully-connected (dense) network for comparison; the input is flattened
      TString inputLayoutString = TString::Format("InputLayout=1|1|%d", ntime * ninput);
      TString layoutString("Layout=DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,LINEAR");
      // training strategy
      TString trainingString1("LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
                              "ConvergenceSteps=10,BatchSize=256,TestRepetitions=1,"
                              "WeightDecay=1e-4,Regularization=None,MaxEpochs=20,"
                              "DropConfig=0.0+0.+0.+0.,Optimizer=ADAM");
      TString trainingStrategyString("TrainingStrategy=");
      trainingStrategyString += trainingString1;

      // general options
      TString dnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=None:"
                         "WeightInitialization=XAVIER:RandomSeed=0");
      dnnOptions.Append(":");
      dnnOptions.Append(inputLayoutString);
      dnnOptions.Append(":");
      dnnOptions.Append(layoutString);
      dnnOptions.Append(":");
      dnnOptions.Append(trainingStrategyString);
      dnnOptions.Append(":");
      dnnOptions.Append(TString::Format("Architecture=%s", archString.Data()));

      factory->BookMethod(dataloader, TMVA::Types::kDL, "TMVA_DNN", dnnOptions);
   }
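   // note: unlike the recurrent methods, the dense network sees each event as a
   // flat vector of ntime * ninput = 300 inputs, so any correlation across time
   // steps must be learned implicitly by the dense layers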
   if (useKeras) {
      for (int i = 0; i < 3; i++) {
         if (use_rnn_type[i]) {
            // model file names (assumed naming scheme, one model per RNN type)
            TString modelName = TString::Format("model_%s.h5", rnn_types[i].c_str());
            TString trainedModelName = TString::Format("trained_model_%s.h5", rnn_types[i].c_str());

            Info("TMVA_RNN_Classification", "Building recurrent keras model using a %s layer", rnn_types[i].c_str());

            // build a Python script that creates and saves the Keras model
            TMacro m;
            m.AddLine("import tensorflow");
            m.AddLine("from tensorflow.keras.models import Sequential");
            m.AddLine("from tensorflow.keras.optimizers import Adam");
            m.AddLine("from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, SimpleRNN, GRU, LSTM, Reshape, "
                      "BatchNormalization");
            m.AddLine("model = Sequential()");
            m.AddLine("model.add(Reshape((10, 30), input_shape = (10*30, )))");
            // add the recurrent layer depending on the type; return the full output sequence
            if (rnn_types[i] == "LSTM")
               m.AddLine("model.add(LSTM(units=10, return_sequences=True))");
            else if (rnn_types[i] == "GRU")
               m.AddLine("model.add(GRU(units=10, return_sequences=True))");
            else
               m.AddLine("model.add(SimpleRNN(units=10, return_sequences=True))");
            m.AddLine("model.add(Flatten())"); // needed since the full time sequence is returned
            m.AddLine("model.add(Dense(64, activation = 'tanh'))");
            m.AddLine("model.add(Dense(2, activation = 'sigmoid'))");
"model.compile(loss = 'binary_crossentropy', optimizer = Adam(learning_rate = 0.001), weighted_metrics = ['accuracy'])");
m.AddLine(
"model.save(modelName)");
m.AddLine(
"model.summary()");
m.SaveSource(
"make_rnn_model.py");
gSystem->Exec(python_exe +
" make_rnn_model.py");
            if (gSystem->AccessPathName(modelName)) {
               Warning("TMVA_RNN_Classification", "Error creating Keras recurrent model file - Skip using Keras");
useKeras = false;
            } else {
               // book the PyKeras method only if the model file could be created
               Info("TMVA_RNN_Classification", "Booking Keras %s model", rnn_types[i].c_str());
               factory->BookMethod(dataloader, TMVA::Types::kPyKeras,
                                   TString::Format("PyKeras_%s", rnn_types[i].c_str()),
                                   TString::Format("!H:!V:VarTransform=None:FilenameModel=%s:"
                                                   "FilenameTrainedModel=%s:NumEpochs=%d:BatchSize=%d",
                                                   modelName.Data(), trainedModelName.Data(), maxepochs, batchSize));
            }
}
}
}
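   // for reference, the script generated above for the SimpleRNN case is
   // equivalent to:
   //    model = Sequential()
   //    model.add(Reshape((10, 30), input_shape=(10*30,)))
   //    model.add(SimpleRNN(units=10, return_sequences=True))
   //    model.add(Flatten())
   //    model.add(Dense(64, activation='tanh'))
   //    model.add(Dense(2, activation='sigmoid'))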
   // fall back to a BDT when neither Keras nor the TMVA DL methods are available
   if (!useKeras && !useTMVA_RNN && !useTMVA_DNN)
      useTMVA_BDT = true;
   if (useTMVA_BDT) {
      factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTG",
                          "!H:!V:NTrees=100:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:UseBaggedBoost:"
                          "BaggedSampleFraction=0.5:nCuts=20:"
                          "MaxDepth=2");
   }
   // train, test and evaluate all booked methods
   factory->TrainAllMethods();
   factory->TestAllMethods();
   factory->EvaluateAllMethods();

   // plot the ROC curve comparing all booked methods
   auto c1 = factory->GetROCCurve(dataloader);
   c1->Draw();

   if (outputFile)
      outputFile->Close();
}
Error("WriteTObject","The current directory (%s) is not associated with a file. The object (%s) has not been written.", GetName(), objname)
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
void Warning(const char *location, const char *msgfmt,...)
Use this function in warning situations.
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
A specialized string object used for TTree selections.
A file, usually with extension .root, that stores data and code in the form of serialized objects in ...
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
void Close(Option_t *option="") override
Close a file.
1-D histogram with a double per channel (see TH1 documentation)
static Config & Instance()
static function: returns TMVA instance
void AddVariablesArray(const TString &expression, int size, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating array of variables in data set info in case input tree provides an array ...
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
DataSetInfo & GetDataSetInfo()
std::vector< TString > GetListOfVariables() const
returns list of variables
This is the main MVA steering class.
void TrainAllMethods()
Iterates through all booked methods and calls training.
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
MethodBase * BookMethod(DataLoader *loader, MethodName theMethodName, TString methodTitle, TString theOption="")
Books an MVA classifier or regression method.
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
static void PyInitialize()
Initialize Python interpreter.
Class supporting a collection of lines with C++ code.
const char * Data() const
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
A TTree represents a columnar dataset.
RVec< PromoteType< T > > cos(const RVec< T > &v)
RVec< PromoteType< T > > sin(const RVec< T > &v)
void EnableImplicitMT(UInt_t numthreads=0)
Enable ROOT's implicit multi-threading for all objects and methods that provide an internal paralleli...
UInt_t GetThreadPoolSize()
Returns the size of ROOT's thread pool.