from ROOT import TMVA, TFile, TTree, TCut
from subprocess import call
from os.path import isfile
 
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import SGD
 
 
# Set up TMVA. The embedded Python interpreter must be initialised before any
# PyMVA method (PyKeras is booked further down) can be used.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# Output file that receives the TMVA training/test results.
output = TFile.Open('TMVA_Classification_Keras.root', 'RECREATE')
# TFile.Open returns a null pointer on failure; PyROOT exposes that as a falsy
# object, so fail early here instead of deep inside the factory.
if not output:
    raise OSError("Cannot create output file TMVA_Classification_Keras.root")

# Steering object for booking, training, testing and evaluating MVA methods.
factory = TMVA.Factory('TMVAClassification', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')
 
# Load the example data set. "CACHEREAD" stages the remote file into the cache
# directory and reads it from there; without a cache directory ROOT falls back
# to reading remotely on every run, so point the cache at the working directory.
TFile.SetCacheFileDir('.')
data = TFile.Open("http://root.cern.ch/files/tmva_class_example.root",
                  "CACHEREAD")
# A failed open yields a null pointer, which PyROOT returns as a falsy proxy
# rather than None; `not data` covers both cases.
if not data:
    raise FileNotFoundError("Input file cannot be downloaded - exit")
signal = data.Get('TreeS')
background = data.Get('TreeB')

# Register every branch of the signal tree as an input variable.
dataloader = TMVA.DataLoader('dataset')
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# Both trees enter with a global event weight of 1.0.
dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)
# No selection cut; 4000 randomly chosen training events per class.
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')
 
 
# Define the Keras model: one ReLU hidden layer and a two-node softmax output
# (two classes: signal and background).
# The input width must equal the number of variables given to the DataLoader,
# which registers one variable per branch of the signal tree.
n_variables = signal.GetListOfBranches().GetEntries()
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_variables))
model.add(Dense(2, activation='softmax'))

# weighted_metrics makes the reported accuracy honour per-event weights.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01), weighted_metrics=['accuracy'])

# Saved untrained; the PyKeras method reads it back via FilenameModel.
model.save('modelClassification.h5')
model.summary()
 
# Book two classifiers on the same inputs: a Fisher discriminant as a baseline
# and the Keras model saved above, wrapped by PyKeras.
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                   'H:!V:VarTransform=D,G:FilenameModel=modelClassification.h5:FilenameTrainedModel=trainedModelClassification.h5:NumEpochs=20:BatchSize=32')

# Train, test and evaluate every booked method.
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the output file explicitly so all results are written to disk now,
# instead of relying on cleanup at interpreter shutdown.
output.Close()
# TCut: a specialized string object used for TTree selections.
 
# static Bool_t TFile::SetCacheFileDir(ROOT::Internal::TStringView cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
# Set the local directory in which remote files opened with "CACHEREAD" are cached.
 
# static TFile *TFile::Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
# Create or open a file.
 
# TMVA::Factory: the main MVA steering class.
 
# static void TMVA::PyMethodBase::PyInitialize()
# Initialize the Python interpreter.