"""Multiclass classification in TMVA with a neural network defined in Keras.

The script does four things:

1. Builds and compiles a small Keras model and saves it to
   ``modelMultiClass.keras``.
2. Generates ``tmva_example_multiple_background.root`` with ROOT's
   ``createData.C`` tutorial macro if that file does not exist yet.
3. Uses TMVA to train, test and evaluate a Fisher discriminant and the PyKeras
   model on one signal tree and three background trees.
4. Writes the TMVA output to ``TMVA.root``.
"""

from os.path import isfile

from ROOT import TMVA, TFile, TCut, gROOT
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import SGD

# Input ROOT file holding the trees TreeS, TreeB0, TreeB1 and TreeB2.
DATA_FILE = 'tmva_example_multiple_background.root'
# Untrained Keras model written by create_model() and loaded by PyKeras.
MODEL_FILE = 'modelMultiClass.keras'


def create_model():
    """Build, compile and save the Keras classifier to ``MODEL_FILE``.

    The network has a 4-wide input layer and 4 softmax outputs, one per class
    booked in ``run()`` (Signal, Background_0, Background_1, Background_2).
    """
    # NOTE(review): the extracted listing dropped the line creating ``model``;
    # ``Sequential`` is imported for exactly this purpose.
    model = Sequential()
    # assumes TreeS in DATA_FILE has exactly 4 branches, since run() adds one
    # input variable per branch -- TODO confirm against the data file.
    model.add(Dense(32, activation='relu', input_dim=4))
    model.add(Dense(4, activation='softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer=SGD(learning_rate=0.01),
                  weighted_metrics=['accuracy'])

    model.save(MODEL_FILE)


def run():
    """Train, test and evaluate the booked TMVA methods on ``DATA_FILE``.

    Every branch of ``TreeS`` becomes an input variable. ``TreeS`` and
    ``TreeB0``..``TreeB2`` become the classes Signal and
    Background_0..Background_2.
    """
    with TFile.Open('TMVA.root', 'RECREATE') as output, \
            TFile.Open(DATA_FILE) as data:
        # NOTE(review): the extracted listing dropped the start of this call;
        # the job name 'TMVAClassification' is assumed -- confirm upstream.
        factory = TMVA.Factory(
            'TMVAClassification', output,
            '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:'
            'AnalysisType=multiclass')

        signal = data.Get('TreeS')

        # NOTE(review): the extracted listing dropped the DataLoader creation;
        # the dataset name 'dataset' is assumed -- confirm upstream.
        dataloader = TMVA.DataLoader('dataset')
        for branch in signal.GetListOfBranches():
            dataloader.AddVariable(branch.GetName())

        dataloader.AddTree(signal, 'Signal')
        for index in range(3):
            dataloader.AddTree(data.Get('TreeB{}'.format(index)),
                               'Background_{}'.format(index))

        dataloader.PrepareTrainingAndTestTree(
            TCut(''), 'SplitMode=Random:NormMode=NumEvents:!V')

        # Book a linear baseline and the Keras network on the same inputs.
        factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                           '!H:!V:Fisher:VarTransform=D,G')
        factory.BookMethod(
            dataloader, TMVA.Types.kPyKeras, 'PyKeras',
            ('H:!V:VarTransform=D,G:FilenameModel={}:'
             'FilenameTrainedModel=trainedModelMultiClass.keras:'
             'NumEpochs=20:BatchSize=32').format(MODEL_FILE))

        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()


if __name__ == "__main__":
    # NOTE(review): these setup steps and the create_model()/run() calls fill
    # gaps in the extracted listing; PyInitialize is referenced by the listing's
    # own symbol tooltips. Confirm against the upstream tutorial.
    TMVA.Tools.Instance()
    TMVA.PyMethodBase.PyInitialize()

    create_model()

    if not isfile(DATA_FILE):
        create_data_macro = (str(gROOT.GetTutorialDir())
                             + '/machine_learning/createData.C')
        print(create_data_macro)
        gROOT.ProcessLine('.L {}'.format(create_data_macro))
        gROOT.ProcessLine('create_MultipleBackground(4000)')
        # Fail with a clear message here rather than inside TFile.Open in run().
        if not isfile(DATA_FILE):
            raise FileNotFoundError(
                'data generation did not produce {}'.format(DATA_FILE))

    run()
# TCut: a specialized string object used for TTree selections.
# TFile::Open -- create / open a file:
#   static TFile *Open(const char *name, Option_t *option="",
#                      const char *ftitle="",
#                      Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault,
#                      Int_t netopt=0)
# TMVA::Factory: this is the main MVA steering class.
# TMVA::PyMethodBase::PyInitialize() (static void): initialize the Python
# interpreter.