from os.path import isfile
from subprocess import call

from ROOT import TCut, TFile, TMVA, TTree
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import SGD
 
   # --- TMVA setup -----------------------------------------------------------
# PyInitialize() starts the Python interpreter used by TMVA's Python-based
# methods; the PyKeras method booked further down depends on it.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# Output file that the Factory writes its training/testing results into.
# (Fixed: the original name 'TMVA_Classification_Keras..root' had a stray dot.)
output = TFile.Open('TMVA_Classification_Keras.root', 'RECREATE')
# NOTE(review): the first line of this call was missing from this copy; the job
# name 'TMVAClassification' follows the standard TMVA tutorials - confirm.
factory = TMVA.Factory('TMVAClassification', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')
 
   # --- Load the example dataset ---------------------------------------------
# With "CACHEREAD", TFile.Open keeps a local copy of the remote file in the
# directory configured by SetCacheFileDir (here: the working directory).
TFile.SetCacheFileDir(".")
data = TFile.Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD")
# TFile.Open yields a null pointer on failure. PyROOT may expose that as None
# or as a falsy null proxy, so test truthiness instead of `is None`.
if not data:
    raise FileNotFoundError("Input file cannot be downloaded - exit")

signal = data.Get('TreeS')
background = data.Get('TreeB')
 
   # --- Configure the DataLoader ---------------------------------------------
# NOTE(review): the DataLoader construction was missing from this copy;
# 'dataset' is the name used by the standard TMVA tutorials - confirm.
dataloader = TMVA.DataLoader('dataset')

# Every branch of the signal tree is registered as an input variable.
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# Both trees get a global weight of 1.0.
dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)
# Empty TCut: no event selection; 4000 signal + 4000 background training
# events are drawn at random, the rest is used for testing.
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')
 
   # --- Define the Keras model -----------------------------------------------
# The input width must equal the number of variables given to the DataLoader,
# which registers one variable per branch of the signal tree. Derive it from
# the tree instead of hard-coding 4 so the two cannot drift apart.
n_variables = signal.GetListOfBranches().GetEntries()

# One hidden ReLU layer, then a 2-way softmax (one output per class).
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_variables))
model.add(Dense(2, activation='softmax'))

# Presumably weighted_metrics is used because TMVA feeds per-event weights to
# Keras as sample weights - confirm against the PyKeras implementation.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01), weighted_metrics=['accuracy', ])

# The untrained model is handed to TMVA through this file
# (see FilenameModel in the PyKeras booking options).
model.save('modelClassification.h5')
 
   # --- Book the classifiers -------------------------------------------------
# A Fisher discriminant serves as a simple baseline next to the neural network.
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
# PyKeras loads the model saved above (FilenameModel), trains it for 20 epochs
# with batches of 32 events, and stores the result in FilenameTrainedModel.
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                   'H:!V:VarTransform=D,G:FilenameModel=modelClassification.h5:FilenameTrainedModel=trainedModelClassification.h5:NumEpochs=20:BatchSize=32')
 
   # --- Train, test and evaluate all booked methods --------------------------
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the output file explicitly (as the C++ TMVA tutorials do) so that all
# results are written, instead of relying on cleanup at interpreter exit.
output.Close()
 
# TCut: a specialized string object used for TTree selections.
 
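# static Bool_t TFile::SetCacheFileDir(ROOT::Internal::TStringView cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
#   Set the local directory in which remote files opened with the "CACHEREAD" option are cached.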
 
# static TFile *TFile::Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
#   Create / open a file.
 
# TMVA::Factory: this is the main MVA steering class.
 
# static void TMVA::PyMethodBase::PyInitialize()
#   Initialize the Python interpreter.