Logo ROOT   6.10/09
Reference Guide
testPyKerasRegression.C
Go to the documentation of this file.
#include <iostream>
#include <memory>

#include "TString.h"
#include "TFile.h"
#include "TTree.h"
#include "TSystem.h"
#include "TMVA/Factory.h"
#include "TMVA/Reader.h"
#include "TMVA/DataLoader.h"
#include "TMVA/PyMethodBase.h"

13 from keras.models import Sequential\n\
14 from keras.layers.core import Dense, Activation\n\
15 from keras.optimizers import SGD\n\
16 \n\
17 model = Sequential()\n\
18 model.add(Dense(64, activation=\"tanh\", input_dim=2))\n\
19 model.add(Dense(1, activation=\"linear\"))\n\
20 model.compile(loss=\"mean_squared_error\", optimizer=SGD(lr=0.01))\n\
21 model.save(\"kerasModelRegression.h5\")\n";
22 
24  // Get data file
25  std::cout << "Get test data..." << std::endl;
26  TString fname = "./tmva_reg_example.root";
27  if (gSystem->AccessPathName(fname)) // file does not exist in local directory
28  gSystem->Exec("curl -O http://root.cern.ch/files/tmva_reg_example.root");
29  TFile *input = TFile::Open(fname);
30 
31  // Build model from python file
32  std::cout << "Generate keras model..." << std::endl;
33  UInt_t ret;
34  ret = gSystem->Exec("echo '"+pythonSrc+"' > generateKerasModelRegression.py");
35  if(ret!=0){
36  std::cout << "[ERROR] Failed to write python code to file" << std::endl;
37  return 1;
38  }
39  ret = gSystem->Exec("python generateKerasModelRegression.py");
40  if(ret!=0){
41  std::cout << "[ERROR] Failed to generate model using python" << std::endl;
42  return 1;
43  }
44 
45  // Setup PyMVA and factory
46  std::cout << "Setup TMVA..." << std::endl;
48  TFile* outputFile = TFile::Open("ResultsTestPyKerasRegression.root", "RECREATE");
49  TMVA::Factory *factory = new TMVA::Factory("testPyKerasRegression", outputFile,
50  "!V:Silent:Color:!DrawProgressBar:AnalysisType=Regression");
51 
52  // Load data
53  TMVA::DataLoader *dataloader = new TMVA::DataLoader("datasetTestPyKerasRegression");
54 
55  TTree *tree = (TTree*)input->Get("TreeR");
56  dataloader->AddRegressionTree(tree);
57 
58  dataloader->AddVariable("var1");
59  dataloader->AddVariable("var2");
60  dataloader->AddTarget("fvalue");
61 
62  dataloader->PrepareTrainingAndTestTree("",
63  "SplitMode=Random:NormMode=NumEvents:!V");
64 
65  // Book and train method
66  factory->BookMethod(dataloader, TMVA::Types::kPyKeras, "PyKeras",
67  "!H:!V:VarTransform=D,G:FilenameModel=kerasModelRegression.h5:FilenameTrainedModel=trainedKerasModelRegression.h5:NumEpochs=10:BatchSize=32:SaveBestOnly=false:Verbose=0");
68  std::cout << "Train model..." << std::endl;
69  factory->TrainAllMethods();
70 
71  // Clean-up
72  delete factory;
73  delete dataloader;
74  delete outputFile;
75 
76  // Setup reader
77  UInt_t numEvents = 100;
78  std::cout << "Run reader and estimate target of " << numEvents << " events..." << std::endl;
79  TMVA::Reader *reader = new TMVA::Reader("!Color:Silent");
80  Float_t vars[3];
81  reader->AddVariable("var1", vars+0);
82  reader->AddVariable("var2", vars+1);
83  reader->BookMVA("PyKeras", "datasetTestPyKerasRegression/weights/testPyKerasRegression_PyKeras.weights.xml");
84 
85  // Get mean squared error on events
86  tree->SetBranchAddress("var1", vars+0);
87  tree->SetBranchAddress("var2", vars+1);
88  tree->SetBranchAddress("fvalue", vars+2);
89 
90  Float_t meanMvaError = 0;
91  for(UInt_t i=0; i<numEvents; i++){
92  tree->GetEntry(i);
93  meanMvaError += std::pow(vars[2]-reader->EvaluateMVA("PyKeras"),2);
94  }
95  meanMvaError = meanMvaError/float(numEvents);
96 
97  // Check whether the response is obviously better than guessing
98  std::cout << "Mean squared error: " << meanMvaError << std::endl;
99  if(meanMvaError > 30.0){
100  std::cout << "[ERROR] Mean squared error is " << meanMvaError << " (>30.0)" << std::endl;
101  return 1;
102  }
103 
104  return 0;
105 }
106 
107 int main(){
108  int err = testPyKerasRegression();
109  return err;
110 }
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode (i.e. it returns TRUE when the file is NOT accessible).
Definition: TSystem.cxx:1272
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:343
float Float_t
Definition: RtypesCore.h:53
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition: Reader.cxx:308
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:46
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
int testPyKerasRegression()
virtual Int_t GetEntry(Long64_t entry=0, Int_t getall=0)
Read all branches of entry and return total number of bytes read.
Definition: TTree.cxx:5321
Basic string class.
Definition: TString.h:129
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1017
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
User inserts a discriminating variable into the data set info.
Definition: DataLoader.cxx:491
static void PyInitialize()
Initialize Python interpreter.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3909
virtual Int_t SetBranchAddress(const char *bname, void *add, TBranch **ptr=0)
Change branch address, dealing with clone trees properly.
Definition: TTree.cxx:7873
double pow(double, double)
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
Read the method name from the weight file.
Definition: Reader.cxx:377
R__EXTERN TSystem * gSystem
Definition: TSystem.h:539
unsigned int UInt_t
Definition: RtypesCore.h:42
TString pythonSrc
void AddRegressionTree(TTree *tree, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Definition: DataLoader.h:112
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition: TSystem.cxx:660
This is the main MVA steering class.
Definition: Factory.h:81
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
Prepare the training and test trees (the same cuts are applied to signal and background).
Definition: DataLoader.cxx:629
void AddTarget(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
User inserts a target into the data set info.
Definition: DataLoader.cxx:509
Double_t EvaluateMVA(const std::vector< Float_t > &, const TString &methodTag, Double_t aux=0)
Evaluate a std::vector<float> of input data for a given method. The parameter aux is obligatory for th...
Definition: Reader.cxx:485
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:63
Definition: tree.py:1
A TTree object has a header with a name and a title.
Definition: TTree.h:78
int main()