ROOT 6.08/07
Reference Guide
testPyKerasClassification.C
Go to the documentation of this file.
1 #include <iostream>
2 
3 #include "TString.h"
4 #include "TFile.h"
5 #include "TTree.h"
6 #include "TSystem.h"
7 #include "TMVA/Factory.h"
8 #include "TMVA/Reader.h"
9 #include "TMVA/DataLoader.h"
10 #include "TMVA/PyMethodBase.h"
11 
13 from keras.models import Sequential\n\
14 from keras.layers.core import Dense, Activation\n\
15 from keras import initializations\n\
16 from keras.optimizers import SGD\n\
17 \n\
18 model = Sequential()\n\
19 model.add(Dense(64, init=\"normal\", activation=\"relu\", input_dim=4))\n\
20 model.add(Dense(2, init=\"normal\", activation=\"softmax\"))\n\
21 model.compile(loss=\"categorical_crossentropy\", optimizer=SGD(lr=0.01), metrics=[\"accuracy\",])\n\
22 model.save(\"kerasModelClassification.h5\")\n";
23 
25  // Get data file
26  std::cout << "Get test data..." << std::endl;
27  TString fname = "./tmva_class_example.root";
28  if (gSystem->AccessPathName(fname)) // file does not exist in local directory
29  gSystem->Exec("curl -O http://root.cern.ch/files/tmva_class_example.root");
30  TFile *input = TFile::Open(fname);
31 
32  // Build model from python file
33  std::cout << "Generate keras model..." << std::endl;
34  UInt_t ret;
35  ret = gSystem->Exec("echo '"+pythonSrc+"' > generateKerasModelClassification.py");
36  if(ret!=0){
37  std::cout << "[ERROR] Failed to write python code to file" << std::endl;
38  return 1;
39  }
40  ret = gSystem->Exec("python generateKerasModelClassification.py");
41  if(ret!=0){
42  std::cout << "[ERROR] Failed to generate model using python" << std::endl;
43  return 1;
44  }
45 
46  // Setup PyMVA and factory
47  std::cout << "Setup TMVA..." << std::endl;
49  TFile* outputFile = TFile::Open("ResultsTestPyKerasClassification.root", "RECREATE");
50  TMVA::Factory *factory = new TMVA::Factory("testPyKerasClassification", outputFile,
51  "!V:Silent:Color:!DrawProgressBar:AnalysisType=Classification");
52 
53  // Load data
54  TMVA::DataLoader *dataloader = new TMVA::DataLoader("datasetTestPyKerasClassification");
55 
56  TTree *signal = (TTree*)input->Get("TreeS");
57  TTree *background = (TTree*)input->Get("TreeB");
58  dataloader->AddSignalTree(signal);
59  dataloader->AddBackgroundTree(background);
60 
61  dataloader->AddVariable("var1");
62  dataloader->AddVariable("var2");
63  dataloader->AddVariable("var3");
64  dataloader->AddVariable("var4");
65 
66  dataloader->PrepareTrainingAndTestTree("",
67  "SplitMode=Random:NormMode=NumEvents:!V");
68 
69  // Book and train method
70  factory->BookMethod(dataloader, TMVA::Types::kPyKeras, "PyKeras",
71  "!H:!V:VarTransform=D,G:FilenameModel=kerasModelClassification.h5:FilenameTrainedModel=trainedKerasModelClassification.h5:NumEpochs=10:BatchSize=32:SaveBestOnly=false:Verbose=0");
72  std::cout << "Train model..." << std::endl;
73  factory->TrainAllMethods();
74 
75  // Clean-up
76  delete factory;
77  delete dataloader;
78  delete outputFile;
79 
80  // Setup reader
81  UInt_t numEvents = 100;
82  std::cout << "Run reader and classify " << numEvents << " events..." << std::endl;
83  TMVA::Reader *reader = new TMVA::Reader("!Color:Silent");
84  Float_t vars[4];
85  reader->AddVariable("var1", vars+0);
86  reader->AddVariable("var2", vars+1);
87  reader->AddVariable("var3", vars+2);
88  reader->AddVariable("var4", vars+3);
89  reader->BookMVA("PyKeras", "datasetTestPyKerasClassification/weights/testPyKerasClassification_PyKeras.weights.xml");
90 
91  // Get mean response of method on signal and background events
92  signal->SetBranchAddress("var1", vars+0);
93  signal->SetBranchAddress("var2", vars+1);
94  signal->SetBranchAddress("var3", vars+2);
95  signal->SetBranchAddress("var4", vars+3);
96 
97  background->SetBranchAddress("var1", vars+0);
98  background->SetBranchAddress("var2", vars+1);
99  background->SetBranchAddress("var3", vars+2);
100  background->SetBranchAddress("var4", vars+3);
101 
102  Float_t meanMvaSignal = 0;
103  Float_t meanMvaBackground = 0;
104  for(UInt_t i=0; i<numEvents; i++){
105  signal->GetEntry(i);
106  meanMvaSignal += reader->EvaluateMVA("PyKeras");
107  background->GetEntry(i);
108  meanMvaBackground += reader->EvaluateMVA("PyKeras");
109  }
110  meanMvaSignal = meanMvaSignal/float(numEvents);
111  meanMvaBackground = meanMvaBackground/float(numEvents);
112 
113  // Check whether the response is obviously better than guessing
114  std::cout << "Mean MVA response on signal: " << meanMvaSignal << std::endl;
115  if(meanMvaSignal < 0.6){
116  std::cout << "[ERROR] Mean response on signal is " << meanMvaSignal << " (<0.6)" << std::endl;
117  return 1;
118  }
119  std::cout << "Mean MVA response on background: " << meanMvaBackground << std::endl;
120  if(meanMvaBackground > 0.4){
121  std::cout << "[ERROR] Mean response on background is " << meanMvaBackground << " (>0.4)" << std::endl;
122  return 1;
123  }
124 
125  return 0;
126 }
127 
128 int main(){
129  int err = testPyKerasClassification();
130  return err;
131 }
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Definition: DataLoader.cxx:382
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1266
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Definition: Factory.cxx:337
float Float_t
Definition: RtypesCore.h:53
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition: Reader.cxx:309
A ROOT file is a suite of consecutive data records (TKey instances) with a well-defined format...
Definition: TFile.h:50
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
virtual Int_t GetEntry(Long64_t entry=0, Int_t getall=0)
Read all branches of entry and return total number of bytes read.
Definition: TTree.cxx:5211
Basic string class.
Definition: TString.h:137
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:822
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
Definition: DataLoader.cxx:456
static void PyInitialize()
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3907
virtual Int_t SetBranchAddress(const char *bname, void *add, TBranch **ptr=0)
Change branch address, dealing with clone trees properly.
Definition: TTree.cxx:7760
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
Read method name from weight file.
Definition: Reader.cxx:378
int testPyKerasClassification()
R__EXTERN TSystem * gSystem
Definition: TSystem.h:549
TString pythonSrc
unsigned int UInt_t
Definition: RtypesCore.h:42
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition: TSystem.cxx:658
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
Definition: DataLoader.cxx:580
Double_t EvaluateMVA(const std::vector< Float_t > &, const TString &methodTag, Double_t aux=0)
Evaluate a std::vector<float> of input data for a given method. The parameter aux is obligatory for th...
Definition: Reader.cxx:486
A TTree object has a header with a name and a title.
Definition: TTree.h:98
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Definition: DataLoader.cxx:353