13from ROOT
import TMVA, TFile, TTree, TCut, gROOT
14from os.path
import isfile
16from tensorflow.keras.models
import Sequential
17from tensorflow.keras.layers
import Dense, Activation
18from tensorflow.keras.optimizers
import SGD
26 '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')
# Generate the example input file on the first run: compile the createData.C
# macro from the ROOT tutorials directory, then call its generator function.
if not isfile('tmva_example_multiple_background.root'):
    createDataMacro = '{}/tmva/createData.C'.format(str(gROOT.GetTutorialDir()))
    print(createDataMacro)
    load_command = '.L ' + createDataMacro
    gROOT.ProcessLine(load_command)
    gROOT.ProcessLine('create_MultipleBackground(4000)')
def _get_tree(root_file, tree_name):
    """Return the object stored under ``tree_name`` in ``root_file``.

    Raises:
        RuntimeError: if ``root_file`` holds no object with that name.
            PyROOT maps the null pointer returned by ``TFile::Get`` to a
            falsy object; without this check the failure would only show up
            later as an unrelated AttributeError.
    """
    tree = root_file.Get(tree_name)
    if not tree:
        raise RuntimeError(
            "Tree '{}' not found in '{}'".format(tree_name, root_file.GetName()))
    return tree


# Open the example input file (generated by createData.C if it was missing).
data = TFile.Open('tmva_example_multiple_background.root')
# TFile.Open yields a null (falsy) object when the file cannot be opened;
# IsZombie() is checked as an extra guard against an unusable file.
if not data or data.IsZombie():
    raise RuntimeError(
        "Cannot open input file 'tmva_example_multiple_background.root'")

# One signal tree and three background trees; each becomes a separate class.
signal = _get_tree(data, 'TreeS')
background0 = _get_tree(data, 'TreeB0')
background1 = _get_tree(data, 'TreeB1')
background2 = _get_tree(data, 'TreeB2')
# Use every branch of the signal tree as a TMVA input variable.
# NOTE(review): presumably the background trees share this branch layout — confirm.
variable_names = [b.GetName() for b in signal.GetListOfBranches()]
for variable_name in variable_names:
    dataloader.AddVariable(variable_name)
# Register each tree under its own class name (multiclass analysis).
class_trees = (
    (signal, 'Signal'),
    (background0, 'Background_0'),
    (background1, 'Background_1'),
    (background2, 'Background_2'),
)
for tree, class_name in class_trees:
    dataloader.AddTree(tree, class_name)

# No preselection cut; events are split randomly into training/test samples.
dataloader.PrepareTrainingAndTestTree(TCut(''), 'SplitMode=Random:NormMode=NumEvents:!V')
# Build the Keras network that the PyKeras method will train.
# The network needs one input per variable registered with the dataloader,
# i.e. one per branch of the signal tree. Derive the width instead of
# hard-coding it, so it cannot drift out of sync with the input tree.
n_input_variables = signal.GetListOfBranches().GetEntries()
# One softmax output per class added with dataloader.AddTree:
# Signal, Background_0, Background_1 and Background_2.
n_classes = 4

model = Sequential()
model.add(Dense(32, activation='relu', input_dim=n_input_variables))
model.add(Dense(n_classes, activation='softmax'))

# Multiclass cross-entropy loss, plain SGD, accuracy tracked as a metric.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01),
              metrics=['accuracy'])

# The PyKeras booking reads this file through its FilenameModel option.
model.save('modelMultiClass.h5')
# Book a Fisher discriminant plus the Keras model saved to modelMultiClass.h5.
booked_methods = (
    (TMVA.Types.kFisher, 'Fisher', '!H:!V:Fisher:VarTransform=D,G'),
    (TMVA.Types.kPyKeras, 'PyKeras',
     'H:!V:VarTransform=D,G:FilenameModel=modelMultiClass.h5:FilenameTrainedModel=trainedModelMultiClass.h5:NumEpochs=20:BatchSize=32'),
)
for method_type, method_name, method_options in booked_methods:
    factory.BookMethod(dataloader, method_type, method_name, method_options)
# Run TMVA's three stages for all booked methods, in order:
# training, testing, evaluation.
for run_stage in (factory.TrainAllMethods,
                  factory.TestAllMethods,
                  factory.EvaluateAllMethods):
    run_stage()
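# Reference notes (ROOT API documentation) for names used in this tutorial:
# - TCut: a specialized string object used for TTree selections.
# - TFile::Open:
#     static TFile *Open(const char *name, Option_t *option="", const char *ftitle="",
#                        Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault,
#                        Int_t netopt=0)
#   Create / open a file.
# - TMVA::Factory: this is the main MVA steering class.
# - TMVA::PyMethodBase::PyInitialize: static void PyInitialize()
#   Initialize the Python interpreter.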