ROOT Reference Guide
RegressionKeras.py File Reference

Detailed Description

This tutorial shows how to perform regression in TMVA with neural networks trained with Keras.
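For orientation, the Keras network built in the script is small: two inputs, one hidden `Dense` layer of 64 `tanh` units and a single linear output, trained on the mean squared error with SGD at `learning_rate=0.01`. The plain-Python sketch below mimics that computation without Keras. It is not how Keras implements training: it uses full-batch gradient descent instead of the mini-batches requested with `BatchSize=32`, a uniform(-0.5, 0.5) weight initialization chosen only for illustration, and a hypothetical toy target `y = x0 + x1`. The helper names `init_params`, `forward`, `mse` and `sgd_step` are likewise made up for this example.

```python
import math
import random

def init_params(n_in=2, n_hidden=64, seed=0):
    """Random parameters for an n_in -> n_hidden (tanh) -> 1 (linear) network."""
    rng = random.Random(seed)
    W = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_hidden)]
    b = [0.0] * n_hidden
    v = [rng.uniform(-0.5, 0.5) for _ in range(n_hidden)]
    c = 0.0
    return W, b, v, c

def forward(params, x):
    """Return (prediction, hidden activations) for one input vector x."""
    W, b, v, c = params
    h = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + bj) for row, bj in zip(W, b)]
    return sum(vj * hj for vj, hj in zip(v, h)) + c, h

def mse(params, xs, ys):
    """Mean squared error, the same loss the script passes to model.compile."""
    return sum((forward(params, x)[0] - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def sgd_step(params, xs, ys, lr=0.01):
    """One full-batch gradient-descent step on the mean squared error."""
    W, b, v, c = params
    n = len(xs)
    gW = [[0.0] * len(row) for row in W]
    gb = [0.0] * len(b)
    gv = [0.0] * len(v)
    gc = 0.0
    for x, y in zip(xs, ys):
        pred, h = forward(params, x)
        d = 2.0 * (pred - y) / n            # dLoss/dprediction for this sample
        gc += d
        for j, hj in enumerate(h):
            gv[j] += d * hj
            dz = d * v[j] * (1.0 - hj * hj)  # back-propagate through tanh
            gb[j] += dz
            for i, xi in enumerate(x):
                gW[j][i] += dz * xi
    W = [[w - lr * g for w, g in zip(row, grow)] for row, grow in zip(W, gW)]
    b = [bj - lr * g for bj, g in zip(b, gb)]
    v = [vj - lr * g for vj, g in zip(v, gv)]
    return W, b, v, c - lr * gc

# Usage on a hypothetical toy target y = x0 + x1 sampled on a 5x5 grid
xs = [(i / 4, j / 4) for i in range(5) for j in range(5)]
ys = [x0 + x1 for x0, x1 in xs]
params = init_params(n_hidden=16)
before = mse(params, xs, ys)
for _ in range(200):
    params = sgd_step(params, xs, ys, lr=0.01)
after = mse(params, xs, ys)
print(f"MSE before: {before:.3f}, after: {after:.3f}")
```

Changing `n_hidden` or swapping `math.tanh` for another activation corresponds to editing the `Dense` layers in the script.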

from ROOT import TMVA, TFile, TTree, TCut, gROOT
from subprocess import call
from os.path import isfile

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import SGD

# Setup TMVA
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA_Regression_Keras.root', 'RECREATE')
factory = TMVA.Factory('TMVARegression', output,
        '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')

# Load data
if not isfile('tmva_reg_example.root'):
    call(['curl', '-L', '-O', 'http://root.cern/files/tmva_reg_example.root'])

data = TFile.Open(str(gROOT.GetTutorialDir()) + '/tmva/data/tmva_reg_example.root')
tree = data.Get('TreeR')

dataloader = TMVA.DataLoader('dataset')
# Use every branch except the target 'fvalue' as an input variable
for branch in tree.GetListOfBranches():
    name = branch.GetName()
    if name != 'fvalue':
        dataloader.AddVariable(name)
dataloader.AddTarget('fvalue')

dataloader.AddRegressionTree(tree, 1.0)
# Use only 1000 training events since evaluation is very slow (especially on macOS); increase this number to get meaningful results
dataloader.PrepareTrainingAndTestTree(TCut(''),
        'nTrain_Regression=1000:SplitMode=Random:NormMode=NumEvents:!V')

# Generate model

# Define model
model = Sequential()
model.add(Dense(64, activation='tanh', input_dim=2))
model.add(Dense(1, activation='linear'))

# Set loss and optimizer
model.compile(loss='mean_squared_error', optimizer=SGD(learning_rate=0.01), weighted_metrics=[])

# Store model to file
model.save('modelRegression.h5')

# Book methods
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
        'H:!V:VarTransform=D,G:FilenameModel=modelRegression.h5:FilenameTrainedModel=trainedModelRegression.h5:NumEpochs=20:BatchSize=32')
factory.BookMethod(dataloader, TMVA.Types.kBDT, 'BDTG',
        '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4')

# Run TMVA
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()
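The TMVA configuration strings in the script (passed to `Factory`, `PrepareTrainingAndTestTree` and `BookMethod`) are colon-separated option lists: `Key=Value` sets a value, a bare name such as `Color` switches a flag on, and a leading `!` as in `!V` switches it off. The hypothetical helper `parse_tmva_options` below decodes strings that follow this convention into a Python dict. It is only a reading aid, not TMVA's own option parser, and it keeps every value as a string.

```python
def parse_tmva_options(opts):
    """Split a TMVA-style option string into a dict (illustrative sketch).

    Assumed conventions: tokens are separated by ':', 'Key=Value' sets a
    string value, a bare 'Flag' means True and '!Flag' means False.
    """
    parsed = {}
    for token in opts.split(':'):
        token = token.strip()
        if not token:
            continue  # tolerate empty tokens such as a trailing ':'
        if '=' in token:
            key, value = token.split('=', 1)
            parsed[key] = value
        elif token.startswith('!'):
            parsed[token[1:]] = False
        else:
            parsed[token] = True
    return parsed

factory_opts = parse_tmva_options(
    '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')
print(factory_opts['Transformations'])
```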
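`Transformations=D,G` in the `Factory` options and `VarTransform=D,G` in the booking options request decorrelation (`D`) followed by Gaussianization (`G`) of the input variables. A minimal sketch of both ideas for two variables follows. It assumes rank-based Gaussianization (each value replaced by the standard-normal quantile of its empirical CDF) and decorrelation with the inverse symmetric square root of the covariance matrix, with the variables centered first; TMVA's own implementation may differ in details, and `gaussianize` and `decorrelate_2d` are hypothetical names.

```python
import math
from statistics import NormalDist

def gaussianize(values):
    """Replace each value by the standard-normal quantile of its empirical CDF.

    Uses (rank + 0.5) / n so the probabilities stay strictly inside (0, 1).
    Ties are not treated specially in this sketch.
    """
    n = len(values)
    order = sorted(range(n), key=lambda i: values[i])
    std_normal = NormalDist()
    out = [0.0] * n
    for rank, i in enumerate(order):
        out[i] = std_normal.inv_cdf((rank + 0.5) / n)
    return out

def decorrelate_2d(xs, ys):
    """Center two variables and multiply by C^(-1/2), C being their covariance matrix."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    a = sum((x - mx) ** 2 for x in xs) / n
    d = sum((y - my) ** 2 for y in ys) / n
    b = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / n
    # Closed-form symmetric square root of the 2x2 matrix [[a, b], [b, d]]:
    # sqrt(C) = (C + s*I) / t with s = sqrt(det C) and t = sqrt(trace C + 2s)
    s = math.sqrt(a * d - b * b)
    t = math.sqrt(a + d + 2 * s)
    r11, r12, r22 = (a + s) / t, b / t, (d + s) / t
    # Invert the 2x2 square root
    det = r11 * r22 - r12 * r12
    i11, i12, i22 = r22 / det, -r12 / det, r11 / det
    u = [i11 * (x - mx) + i12 * (y - my) for x, y in zip(xs, ys)]
    w = [i12 * (x - mx) + i22 * (y - my) for x, y in zip(xs, ys)]
    return u, w

# Usage on hypothetical, strongly correlated toy data
xs = [float(i) for i in range(10)]
ys = [2.0 * i + (i % 3) for i in range(10)]
u, w = decorrelate_2d(xs, ys)
g = gaussianize(u)
```

After `decorrelate_2d` the two outputs have unit variance and zero covariance; `gaussianize` then reshapes each one separately without changing the ordering of the events.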
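With `nTrain_Regression=1000:SplitMode=Random`, 1000 randomly chosen events are used for training. The sketch below illustrates a random split of that kind, under the simplifying assumption that all remaining events form the test sample. `split_train_test` is a hypothetical helper that works on any list of events; it is not a TMVA function.

```python
import random

def split_train_test(events, n_train, seed=None):
    """Randomly pick n_train events for training; the rest form the test sample."""
    if n_train > len(events):
        raise ValueError("n_train exceeds the number of available events")
    rng = random.Random(seed)
    shuffled = list(events)   # copy so the caller's list is left untouched
    rng.shuffle(shuffled)
    return shuffled[:n_train], shuffled[n_train:]

# Usage with hypothetical event indices
train, test = split_train_test(list(range(5000)), 1000, seed=42)
print(len(train), len(test))
```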
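The `BDTG` booking trains gradient-boosted decision trees: `BoostType=Grad` fits each new tree to the gradient of the loss, `Shrinkage=0.1` scales each tree's contribution, and `UseBaggedBoost:BaggedSampleFraction=0.5` grows every tree on a random half of the events. The sketch below shows that loop for a single input variable, with simplifications that are assumptions of this example: depth-1 stumps instead of `MaxDepth=4` trees, a scan over every midpoint instead of `nCuts=20` grid points, and a squared-error loss (TMVA offers other regression loss functions). `fit_stump` and `gradient_boost` are hypothetical names.

```python
import random

def fit_stump(xs, residuals):
    """Best single-split regression stump on one variable, minimizing squared error."""
    pairs = sorted(zip(xs, residuals))
    n = len(pairs)
    total = sum(r for _, r in pairs)
    best = None
    left_sum = 0.0
    for k in range(1, n):
        left_sum += pairs[k - 1][1]
        if pairs[k - 1][0] == pairs[k][0]:
            continue  # cannot split between equal x values
        right_sum = total - left_sum
        # Maximizing this quantity minimizes the squared error after the split
        score = left_sum ** 2 / k + right_sum ** 2 / (n - k)
        if best is None or score > best[0]:
            threshold = 0.5 * (pairs[k - 1][0] + pairs[k][0])
            best = (score, threshold, left_sum / k, right_sum / (n - k))
    if best is None:
        mean = total / n
        return lambda x: mean
    _, threshold, left_mean, right_mean = best
    return lambda x: left_mean if x < threshold else right_mean

def gradient_boost(xs, ys, n_trees=100, shrinkage=0.1, bagged_fraction=0.5, seed=0):
    """Gradient boosting with shrinkage and bagging; returns a predict(x) function."""
    rng = random.Random(seed)
    base = sum(ys) / len(ys)
    preds = [base] * len(xs)
    stumps = []
    for _ in range(n_trees):
        # Residuals are the negative gradient of the loss 0.5 * (y - prediction)^2
        residuals = [y - p for y, p in zip(ys, preds)]
        m = max(1, int(bagged_fraction * len(xs)))
        idx = rng.sample(range(len(xs)), m)  # bagged subsample for this tree
        stump = fit_stump([xs[i] for i in idx], [residuals[i] for i in idx])
        stumps.append(stump)
        preds = [p + shrinkage * stump(x) for p, x in zip(preds, xs)]

    def predict(x):
        return base + shrinkage * sum(s(x) for s in stumps)
    return predict

# Usage on a hypothetical toy target y = x^2
xs = [i / 20 for i in range(21)]
ys = [x * x for x in xs]
model = gradient_boost(xs, ys, n_trees=100)
print(round(model(0.5), 3))
```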
Date
2017
Author
TMVA Team

Definition in file RegressionKeras.py.