38 bool useLikelihood =
true;
39 bool useLikelihoodKDE =
false;
40 bool useFischer =
true;
45 bool usePyTorch =
true;
58 auto outputFile =
TFile::Open(
"Higgs_ClassificationOutput.root",
"RECREATE");
60 TMVA::Factory factory(
"TMVA_Higgs_Classification", outputFile,
61 "!V:ROC:!Silent:Color:AnalysisType=Classification" );
71 TString inputFileName =
"Higgs_data.root";
72 TString inputFileLink =
"http://root.cern/files/" + inputFileName;
74 TFile *inputFile =
nullptr;
83 Info(
"TMVA_Higgs_Classification",
"Download Higgs_data.root file");
85 inputFile =
TFile::Open(inputFileLink,
"CACHEREAD");
87 Error(
"TMVA_Higgs_Classification",
"Input file cannot be downloaded - exit");
147 "nTrain_Signal=7000:nTrain_Background=7000:SplitMode=Random:NormMode=NumEvents:!V" );
161 "H:!V:TransformOutput:PDFInterpol=Spline2:NSmoothSig[0]=20:NSmoothBkg[0]=20:NSmoothBkg[1]=10:NSmooth=1:NAvEvtPerBin=50" );
164if (useLikelihoodKDE) {
166 "!H:!V:!TransformOutput:PDFInterpol=KDE:KDEtype=Gauss:KDEiter=Adaptive:KDEFineFactor=0.3:KDEborder=None:NAvEvtPerBin=50" );
172 factory.BookMethod(loader,
TMVA::Types::kFisher,
"Fisher",
"H:!V:Fisher:VarTransform=None:CreateMVAPdfs:PDFInterpolMVAPdf=Spline2:NbinsMVAPdf=50:NsmoothMVAPdf=10" );
178 "!V:NTrees=200:MinNodeSize=2.5%:MaxDepth=2:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20" );
184 "!H:!V:NeuronType=tanh:VarTransform=N:NCycles=100:HiddenLayers=N+5:TestRate=5:!UseRegulator" );
254 bool useDLGPU =
false;
260 TString inputLayoutString =
"InputLayout=1|1|7";
261 TString batchLayoutString=
"BatchLayout=1|128|7";
262 TString layoutString (
"Layout=DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,DENSE|1|LINEAR");
265 TString training1(
"LearningRate=1e-3,Momentum=0.9,"
266 "ConvergenceSteps=10,BatchSize=128,TestRepetitions=1,"
267 "MaxEpochs=30,WeightDecay=1e-4,Regularization=None,"
268 "Optimizer=ADAM,ADAM_beta1=0.9,ADAM_beta2=0.999,ADAM_eps=1.E-7,"
269 "DropConfig=0.0+0.0+0.0+0.");
275 TString trainingStrategyString (
"TrainingStrategy=");
276 trainingStrategyString += training1;
280 TString dnnOptions (
"!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=G:"
281 "WeightInitialization=XAVIER");
282 dnnOptions.Append (
":"); dnnOptions.Append (inputLayoutString);
283 dnnOptions.Append (
":"); dnnOptions.Append (batchLayoutString);
284 dnnOptions.Append (
":"); dnnOptions.Append (layoutString);
285 dnnOptions.Append (
":"); dnnOptions.Append (trainingStrategyString);
287 TString dnnMethodName =
"DNN_CPU";
289 dnnOptions +=
":Architecture=GPU";
290 dnnMethodName =
"DNN_GPU";
292 dnnOptions +=
":Architecture=CPU";
301 Info(
"TMVA_Higgs_Classification",
"Building deep neural network with keras ");
305 m.AddLine(
"import tensorflow");
306 m.AddLine(
"from tensorflow.keras.models import Sequential");
307 m.AddLine(
"from tensorflow.keras.optimizers import Adam");
308 m.AddLine(
"from tensorflow.keras.layers import Input, Dense");
310 m.AddLine(
"model = Sequential() ");
311 m.AddLine(
"model.add(Dense(64, activation='relu',input_dim=7))");
312 m.AddLine(
"model.add(Dense(64, activation='relu'))");
313 m.AddLine(
"model.add(Dense(64, activation='relu'))");
314 m.AddLine(
"model.add(Dense(64, activation='relu'))");
315 m.AddLine(
"model.add(Dense(2, activation='sigmoid'))");
316 m.AddLine(
"model.compile(loss = 'binary_crossentropy', optimizer = Adam(learning_rate = 0.001), weighted_metrics = ['accuracy'])");
317 m.AddLine(
"model.save('Higgs_model.h5')");
318 m.AddLine(
"model.summary()");
320 m.SaveSource(
"make_higgs_model.py");
322 auto ret = (
TString *)
gROOT->ProcessLine(
"TMVA::Python_Executable()");
323 TString python_exe = (ret) ? *(ret) :
"python";
324 gSystem->
Exec(python_exe +
" make_higgs_model.py");
327 Warning(
"TMVA_Higgs_Classification",
"Error creating Keras model file - skip using Keras");
330 Info(
"TMVA_Higgs_Classification",
"Booking tf.Keras Dense model");
333 "H:!V:VarTransform=None:FilenameModel=Higgs_model.h5:tf.keras:"
334 "FilenameTrainedModel=Higgs_trained_model.h5:NumEpochs=20:BatchSize=100:"
335 "GpuOptions=allow_growth=True");
346 factory.TrainAllMethods();
354 factory.TestAllMethods();
356 factory.EvaluateAllMethods();
360 auto c1 = factory.GetROCCurve(loader);
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
void Warning(const char *location, const char *msgfmt,...)
Use this function in warning situations.
R__EXTERN TSystem * gSystem
A specialized string object used for TTree selections.
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
static Bool_t SetCacheFileDir(std::string_view cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Sets the directory where to locally stage/cache remote files.
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
This is the main MVA steering class.
static void PyInitialize()
Initialize Python interpreter.
Class supporting a collection of lines with C++ code.
virtual Int_t Exec(const char *shellcmd)
Execute a command.
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
virtual void Setenv(const char *name, const char *value)
Set environment variable.
A TTree represents a columnar dataset.
void Print(Option_t *option="") const override
Print a summary of the tree contents.