ROOT
Version v6.34
master
v6.36
v6.32
v6.30
v6.28
v6.26
v6.24
v6.22
v6.20
v6.18
v6.16
v6.14
v6.12
v6.10
v6.08
v6.06
Reference Guide
▼
ROOT
ROOT Reference Documentation
Tutorials
►
Functional Parts
►
Namespaces
►
All Classes
▼
Files
▼
File List
►
bindings
►
core
►
documentation
►
geom
►
graf2d
►
graf3d
►
gui
►
hist
►
html
►
io
►
main
►
math
►
montecarlo
►
net
►
proof
►
roofit
►
sql
►
tmva
►
tree
▼
tutorials
►
cocoa
►
cont
►
dataframe
►
eve
►
eve7
►
fft
►
fit
►
fitsio
►
foam
►
geom
►
gl
►
graphics
►
graphs
►
gui
►
hist
►
histfactory
►
http
►
image
►
io
►
legacy
►
math
►
matrix
►
mc
►
multicore
►
net
►
physics
►
proof
►
pyroot
►
pythia
►
quadp
►
r
►
rcanvas
►
roofit
►
roostats
►
spectrum
►
splot
►
sql
▼
tmva
►
envelope
▼
keras
ApplicationClassificationKeras.py
ApplicationRegressionKeras.py
ClassificationKeras.py
GenerateModel.py
MulticlassKeras.py
RegressionKeras.py
►
pytorch
createData.C
►
PyTorch_Generate_CNN_Model.py
►
RBatchGenerator_filters_vectors.py
RBatchGenerator_NumPy.py
RBatchGenerator_PyTorch.py
RBatchGenerator_TensorFlow.py
tmva001_RTensor.C
tmva002_RDataFrameAsTensor.C
tmva003_RReader.C
tmva004_RStandardScaler.C
tmva100_DataPreparation.py
tmva101_Training.py
tmva102_Testing.py
tmva103_Application.C
TMVA_CNN_Classification.C
TMVA_CNN_Classification.py
TMVA_Higgs_Classification.C
TMVA_Higgs_Classification.py
TMVA_RNN_Classification.C
TMVA_RNN_Classification.py
►
TMVA_SOFIE_GNN.py
►
TMVA_SOFIE_GNN_Application.C
►
TMVA_SOFIE_GNN_Parser.py
TMVA_SOFIE_Inference.py
TMVA_SOFIE_Keras.C
TMVA_SOFIE_Keras_HiggsModel.C
TMVA_SOFIE_Models.py
TMVA_SOFIE_ONNX.C
TMVA_SOFIE_PyTorch.C
TMVA_SOFIE_RDataFrame.C
TMVA_SOFIE_RDataFrame.py
TMVA_SOFIE_RDataFrame_JIT.C
TMVA_SOFIE_RSofieReader.C
TMVAClassification.C
TMVAClassificationApplication.C
TMVAClassificationCategory.C
TMVAClassificationCategoryApplication.C
TMVACrossValidation.C
TMVACrossValidationApplication.C
TMVACrossValidationRegression.C
TMVAGAexample.C
TMVAGAexample2.C
TMVAMinimalClassification.C
TMVAMulticlass.C
TMVAMulticlassApplication.C
TMVAMultipleBackgroundExample.C
TMVARegression.C
TMVARegressionApplication.C
►
tree
►
unfold
►
unuran
►
v7
►
vecops
►
webcanv
►
webgui
►
xml
►
.enableImplicitMTWrapper.py
.rootlogon.py
demos.C
demoshelp.C
hsimple.C
rootlogoff.C
rootlogon.C
►
v6-34-00-patches
►
File Members
Release Notes
•
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Properties
Friends
Macros
Modules
Pages
Loading...
Searching...
No Matches
ClassificationKeras.py
Go to the documentation of this file.
1
#!/usr/bin/env python
2
# \file
3
# \ingroup tutorial_tmva_keras
4
# \notebook -nodraw
5
# This tutorial shows how to do classification in TMVA with neural networks
6
# trained with keras.
7
#
8
# \macro_code
9
#
10
# \date 2017
11
# \author TMVA Team
12
13
from
ROOT
import
TMVA, TFile, TCut, gROOT
14
from
subprocess
import
call
15
from
os.path
import
isfile
16
17
from
tensorflow.keras.models
import
Sequential
18
from
tensorflow.keras.layers
import
Dense
19
from
tensorflow.keras.optimizers
import
SGD
20
21
22
def create_model():
    """Build, compile and save the Keras network trained by TMVA's PyKeras method.

    The network is a small fully connected classifier: 4 inputs, one hidden
    layer of 64 ReLU units and a 2-unit softmax output layer. It is compiled
    with categorical cross-entropy, plain SGD (learning rate 0.01) and a
    weighted 'accuracy' metric. It is then written to
    ``modelClassification.h5`` and its summary is printed.

    NOTE(review): ``input_dim=4`` has to match the number of input variables
    that ``run()`` registers (one per branch of ``TreeS``). The branch count
    of the input file is not visible here, so confirm this if the dataset
    changes.
    """
    net = Sequential()
    hidden = Dense(64, activation='relu', input_dim=4)
    output = Dense(2, activation='softmax')
    for layer in (hidden, output):
        net.add(layer)

    sgd = SGD(learning_rate=0.01)
    net.compile(
        loss='categorical_crossentropy',
        optimizer=sgd,
        weighted_metrics=['accuracy'],
    )

    # run() hands this file to TMVA through the PyKeras FilenameModel option.
    net.save('modelClassification.h5')
    net.summary()
37
38
39
def run():
    """Train, test and evaluate a Fisher discriminant and the Keras model with TMVA.

    Results go to ``TMVA_Classification_Keras.root``, which is recreated on
    every call. Signal (``TreeS``) and background (``TreeB``) events are read
    from the tutorial data file ``tmva/data/tmva_class_example.root``.

    Every branch of ``TreeS`` is registered as an input variable. The
    training/test split uses 4000 randomly chosen training events per class.
    Two methods are booked:

    * ``Fisher`` (``TMVA.Types.kFisher``)
    * ``PyKeras`` (``TMVA.Types.kPyKeras``), which is pointed at
      ``modelClassification.h5`` via ``FilenameModel``. That file must already
      exist, i.e. ``create_model()`` has to run first.

    NOTE(review): assumes ``TreeB`` carries the same branches as ``TreeS``;
    only ``TreeS`` is inspected here.
    """
    factory_opts = '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification'
    split_opts = 'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V'

    with TFile.Open('TMVA_Classification_Keras.root', 'RECREATE') as output:
        input_path = str(gROOT.GetTutorialDir()) + '/tmva/data/tmva_class_example.root'
        with TFile.Open(input_path) as data:
            factory = TMVA.Factory('TMVAClassification', output, factory_opts)

            signal, background = (data.Get(tree_name) for tree_name in ('TreeS', 'TreeB'))

            # One input variable per branch of the signal tree.
            loader = TMVA.DataLoader('dataset')
            for var_name in (br.GetName() for br in signal.GetListOfBranches()):
                loader.AddVariable(var_name)

            loader.AddSignalTree(signal, 1.0)
            loader.AddBackgroundTree(background, 1.0)
            loader.PrepareTrainingAndTestTree(TCut(''), split_opts)

            # (method type, method title, option string) for each booked method
            bookings = (
                (TMVA.Types.kFisher, 'Fisher',
                 '!H:!V:Fisher:VarTransform=D,G'),
                (TMVA.Types.kPyKeras, 'PyKeras',
                 'H:!V:VarTransform=D,G:FilenameModel=modelClassification.h5:FilenameTrainedModel=trainedModelClassification.h5:NumEpochs=20:BatchSize=32'),
            )
            for method_type, method_title, method_opts in bookings:
                factory.BookMethod(loader, method_type, method_title, method_opts)

            # Training, testing and evaluation, strictly in this order.
            for stage in (factory.TrainAllMethods,
                          factory.TestAllMethods,
                          factory.EvaluateAllMethods):
                stage()
66
67
68
if __name__ == "__main__":
    # Set up the TMVA tools instance and the Python interpreter used by the
    # PyMVA (PyKeras) method before anything else touches TMVA.
    TMVA.Tools.Instance()
    TMVA.PyMethodBase.PyInitialize()

    # The Keras model file must be written before run() books PyKeras,
    # because run() loads it via FilenameModel=modelClassification.h5.
    create_model()
    run()
TRangeDynCast
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Definition
TCollection.h:358
ROOT::Detail::TRangeCast
Definition
TCollection.h:311
TCut
A specialized string object used for TTree selections.
Definition
TCut.h:25
TMVA::DataLoader
Definition
DataLoader.h:50
TMVA::Factory
This is the main MVA steering class.
Definition
Factory.h:80
tutorials
tmva
keras
ClassificationKeras.py
ROOT v6-34 - Reference Guide Generated on Mon Apr 21 2025 16:53:50 (GVA Time) using Doxygen 1.10.0