Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RegressionKeras.py
Go to the documentation of this file.
#!/usr/bin/env python
# \file
# \ingroup tutorial_tmva_keras
# \notebook -nodraw
# This tutorial shows how to do regression in TMVA with neural networks
# trained with Keras.
#
# \macro_code
#
# \date 2017
# \author TMVA Team
12
13from ROOT import TMVA, TFile, TCut, gROOT
14from subprocess import call
15from os.path import isfile
16
17from tensorflow.keras.models import Sequential
18from tensorflow.keras.layers import Dense
19from tensorflow.keras.optimizers import SGD
20
21
def create_model(input_dim=2, filename='modelRegression.h5'):
    """Build, compile and save the Keras regression model used by TMVA PyKeras.

    The network is one hidden ``tanh`` layer of 64 units feeding a single
    linear output unit. It is compiled with mean-squared-error loss and plain
    SGD (learning rate 0.01), then written to disk *untrained*. TMVA's PyKeras
    method loads it from that file and performs the actual training.

    Args:
        input_dim: Number of input features. Defaults to 2. It should match
            the number of input variables registered with the TMVA DataLoader.
        filename: Path the compiled model is saved to. Defaults to
            'modelRegression.h5', the name the tutorial passes to PyKeras via
            its ``FilenameModel`` option.

    Returns:
        The compiled ``Sequential`` model. Callers that ignore the return
        value behave exactly as before.
    """
    # Define model: one hidden layer, linear output for a scalar regression target
    model = Sequential()
    model.add(Dense(64, activation='tanh', input_dim=input_dim))
    model.add(Dense(1, activation='linear'))

    # Set loss and optimizer. An explicit empty weighted_metrics list is kept
    # as in the original tutorial.
    model.compile(loss='mean_squared_error',
                  optimizer=SGD(learning_rate=0.01),
                  weighted_metrics=[])

    # Store model to file so PyKeras can load it, and print the layer summary
    model.save(filename)
    model.summary()
    return model
def run():
    """Train, test and evaluate a Keras network and a gradient-boosted BDT.

    Reads the ``TreeR`` tree from the ROOT tutorial file
    ``tmva/data/tmva_reg_example.root``. Every branch except ``fvalue`` is
    used as an input variable, and ``fvalue`` is the regression target. Two
    methods are booked:

    * ``PyKeras``, which loads the model saved in ``modelRegression.h5``
      (see ``create_model``) and writes the trained model to
      ``trainedModelRegression.h5``.
    * ``BDTG``, a gradient-boosted decision tree.

    Both methods are then trained, tested and evaluated. TMVA output goes to
    ``TMVA_Regression_Keras.root`` and the ``dataset`` directory.
    """
    with TFile.Open('TMVA_Regression_Keras.root', 'RECREATE') as output, TFile.Open(str(gROOT.GetTutorialDir()) + '/tmva/data/tmva_reg_example.root') as data:
        factory = TMVA.Factory('TMVARegression', output,
                               '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')

        tree = data.Get('TreeR')

        dataloader = TMVA.DataLoader('dataset')
        # Every branch except the target becomes an input variable.
        for branch in tree.GetListOfBranches():
            name = branch.GetName()
            if name != 'fvalue':
                dataloader.AddVariable(name)
        # Register the target exactly once, outside the loop.
        dataloader.AddTarget('fvalue')

        # Attach the tree as the regression sample with a global weight of 1.0.
        dataloader.AddRegressionTree(tree, 1.0)
        # use only 1000 events since evaluation is very slow (especially on MacOS). Increase it to get meaningful results
        dataloader.PrepareTrainingAndTestTree(TCut(''),
                                              'nTrain_Regression=1000:SplitMode=Random:NormMode=NumEvents:!V')

        # Book methods
        factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                           'H:!V:VarTransform=D,G:FilenameModel=modelRegression.h5:FilenameTrainedModel=trainedModelRegression.h5:NumEpochs=20:BatchSize=32')
        factory.BookMethod(dataloader, TMVA.Types.kBDT, 'BDTG',
                           '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4')

        # Run TMVA: without these calls the booked methods are never trained.
        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()
if __name__ == "__main__":
    # Setup TMVA: create the TMVA tools singleton and initialise the Python
    # side used by PyMVA methods such as PyKeras.
    TMVA.Tools.Instance()
    TMVA.PyMethodBase.PyInitialize()

    # Generate model: PyKeras (booked in run()) loads it from
    # modelRegression.h5, so the file must exist before TMVA runs.
    create_model()

    # Run TMVA
    run()
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
A specialized string object used for TTree selections.
Definition TCut.h:25
This is the main MVA steering class.
Definition Factory.h:80