Classification.cxx
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan
3// Therhaag
4
6
8#include <TMVA/Config.h>
9#include <TMVA/Configurable.h>
10#include <TMVA/Tools.h>
11#include <TMVA/Ranking.h>
12#include <TMVA/DataSet.h>
13#include <TMVA/IMethod.h>
14#include <TMVA/MethodBase.h>
16#include <TMVA/DataSetManager.h>
17#include <TMVA/DataSetInfo.h>
18#include <TMVA/DataLoader.h>
19#include <TMVA/MethodBoost.h>
20#include <TMVA/MethodCategory.h>
21#include <TMVA/ROCCalc.h>
22#include <TMVA/ROCCurve.h>
23#include <TMVA/MsgLogger.h>
24
25#include <TMVA/VariableInfo.h>
27
28#include <TMVA/Types.h>
29
30#include <TROOT.h>
31#include <TFile.h>
32#include <TTree.h>
33#include <TLeaf.h>
34#include <TEventList.h>
35#include <TH2.h>
36#include <TText.h>
37#include <TLegend.h>
38#include <TGraph.h>
39#include <TStyle.h>
40#include <TMatrixF.h>
41#include <TMatrixDSym.h>
42#include <TMultiGraph.h>
43#include <TPaletteAxis.h>
44#include <TPrincipal.h>
45#include <TMath.h>
46#include <TObjString.h>
47#include <TSystem.h>
48#include <TCanvas.h>
49#include <iostream>
50#include <memory>
51#define MinNoTrainingEvents 10
52
53//_______________________________________________________________________
54TMVA::Experimental::ClassificationResult::ClassificationResult()
55{
56}
57
58//_______________________________________________________________________
59TMVA::Experimental::ClassificationResult::ClassificationResult(const ClassificationResult &cr)
60{
61 fMethod = cr.fMethod;
62 fDataLoaderName = cr.fDataLoaderName;
63 fMvaTrain = cr.fMvaTrain;
64 fMvaTest = cr.fMvaTest;
65 fIsCuts = cr.fIsCuts;
66 fROCIntegral = cr.fROCIntegral;
67}
68
69//_______________________________________________________________________
70/**
71 * Method to get the ROC-Integral value from the stored MVA values.
72 * \param iClass class index; default 0 (signal).
73 * \param type train/test tree; default is the test tree.
74 * \return Double_t with the ROC-Integral value.
75 */
76Double_t TMVA::Experimental::ClassificationResult::GetROCIntegral(UInt_t iClass, TMVA::Types::ETreeType type)
77{
78 if (fIsCuts) {
79 return fROCIntegral;
80 } else {
81 auto roc = GetROC(iClass, type);
82 auto inte = roc->GetROCIntegral();
83 delete roc;
84 return inte;
85 }
86}
87
88//_______________________________________________________________________
89/**
90 * Method to get the TMVA::ROCCurve object.
91 * \param iClass class index; default 0 (signal).
92 * \param type train/test tree; default is the test tree.
93 * \return TMVA::ROCCurve object.
94 */
95ROCCurve *TMVA::Experimental::ClassificationResult::GetROC(UInt_t iClass, TMVA::Types::ETreeType type)
96{
97 ROCCurve *fROCCurve = nullptr;
98 if (type == Types::kTesting)
99 fROCCurve = new ROCCurve(fMvaTest[iClass]);
100 else
101 fROCCurve = new ROCCurve(fMvaTrain[iClass]);
102 return fROCCurve;
103}
104
105//_______________________________________________________________________
106TMVA::Experimental::ClassificationResult &
107TMVA::Experimental::ClassificationResult::operator=(const TMVA::Experimental::ClassificationResult &cr)
108{
109 fMethod = cr.fMethod;
110 fDataLoaderName = cr.fDataLoaderName;
111 fMvaTrain = cr.fMvaTrain;
112 fMvaTest = cr.fMvaTest;
113 fIsCuts = cr.fIsCuts;
114 fROCIntegral = cr.fROCIntegral;
115 return *this;
116}
117
118//_______________________________________________________________________
119/**
120 * Method to print the results to stdout:
121 * data loader name, method name/title and ROC-integral.
122 */
123void TMVA::Experimental::ClassificationResult::Show()
124{
125 MsgLogger fLogger("Classification");
128 TString hLine = "--------------------------------------------------- :";
129
130 fLogger << kINFO << hLine << Endl;
131 fLogger << kINFO << "DataSet MVA :" << Endl;
132 fLogger << kINFO << "Name: Method/Title: ROC-integ :" << Endl;
133 fLogger << kINFO << hLine << Endl;
134 fLogger << kINFO << Form("%-20s %-15s %#1.3f :", fDataLoaderName.Data(),
135 Form("%s/%s", fMethod.GetValue<TString>("MethodName").Data(),
136 fMethod.GetValue<TString>("MethodTitle").Data()),
137 GetROCIntegral())
138 << Endl;
139 fLogger << kINFO << hLine << Endl;
140
142}
143
144//_______________________________________________________________________
145/**
146 * Method to get a TGraph object with the ROC curve.
147 * \param iClass class index; default 0 (signal).
148 * \param type train/test tree; default is the test tree.
149 * \return TGraph object.
150 */
151TGraph *TMVA::Experimental::ClassificationResult::GetROCGraph(UInt_t iClass, TMVA::Types::ETreeType type)
152{
153 TGraph *roc = GetROC(iClass, type)->GetROCCurve();
154 roc->SetName(Form("%s/%s", GetMethodName().Data(), GetMethodTitle().Data()));
155 roc->SetTitle(Form("%s/%s", GetMethodName().Data(), GetMethodTitle().Data()));
156 roc->GetXaxis()->SetTitle(" Signal Efficiency ");
157 roc->GetYaxis()->SetTitle(" Background Rejection ");
158 return roc;
159}
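// A minimal usage sketch, assuming `res` is a filled ClassificationResult obtained
// from Classification::GetResults(); the returned graph is simply drawn on a canvas:
//
//    TCanvas c("c", "ROC curve");
//    TGraph *g = res.GetROCGraph();
//    g->Draw("AL");
//    c.SaveAs("ROCCurve.png");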
160
161//_______________________________________________________________________
162/**
163 * Method to check if the method was booked.
164 * \param methodname name of the method.
165 * \param methodtitle method title.
166 * \return boolean: kTRUE if the method was booked, kFALSE otherwise.
167 */
168Bool_t TMVA::Experimental::ClassificationResult::IsMethod(TString methodname, TString methodtitle)
169{
170 return fMethod.GetValue<TString>("MethodName") == methodname &&
171 fMethod.GetValue<TString>("MethodTitle") == methodtitle
172 ? kTRUE
173 : kFALSE;
174}
175
176//_______________________________________________________________________
177/**
178 * Constructor to create a two-class classifier.
179 * \param dataloader TMVA::DataLoader object with the data to train/test.
180 * \param file TFile object to save the results.
181 * \param options string with extra options.
182 */
183TMVA::Experimental::Classification::Classification(DataLoader *dataloader, TFile *file, TString options)
184 : TMVA::Envelope("Classification", dataloader, file, options), fAnalysisType(Types::kClassification),
185 fCorrelations(kFALSE), fROC(kTRUE)
186{
187 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
188 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
189 ParseOptions();
191
193 gSystem->MakeDirectory(fDataLoader->GetName()); // creating directory for DataLoader output
194}
195
196//_______________________________________________________________________
197/**
198 * Constructor to create a two-class classifier without an output file.
199 * \param dataloader TMVA::DataLoader object with the data to train/test.
200 * \param options string with extra options.
201 */
202TMVA::Experimental::Classification::Classification(DataLoader *dataloader, TString options)
203 : TMVA::Envelope("Classification", dataloader, NULL, options), fAnalysisType(Types::kClassification),
204 fCorrelations(kFALSE), fROC(kTRUE)
205{
206
207 // init configurable
208 SetConfigDescription("Configuration options for Classification running");
210
211 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
212 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
213 ParseOptions();
216 gSystem->MakeDirectory(fDataLoader->GetName()); // creating directory for DataLoader output
218}
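// A minimal end-to-end sketch, assuming a TMVA::DataLoader ("dataloader") prepared
// elsewhere with variables and signal/background trees; "ROC" and "Correlations" are
// the options declared above, and BookMethod comes from the TMVA::Envelope base class:
//
//    TFile *out = TFile::Open("TMVAClassification.root", "RECREATE");
//    TMVA::Experimental::Classification cl(dataloader, out, "ROC:!Correlations");
//    cl.BookMethod(TMVA::Types::kBDT, "BDT", "NTrees=100:MaxDepth=3");
//    cl.Evaluate();                        // Train() + Test() over all booked methods
//    for (auto &res : cl.GetResults())
//       res.Show();                        // dataset, method/title and ROC-integral
//    out->Close();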
219
220//_______________________________________________________________________
221TMVA::Experimental::Classification::~Classification()
222{
223 for (auto m : fIMethods) {
224 if (m != NULL)
225 delete m;
226 }
227}
228
229//_______________________________________________________________________
230/**
231 * Return the options for the booked method.
232 * \param methodname name of the method.
233 * \param methodtitle method title.
234 * \return string with the options for the ml method.
235 */
236TString TMVA::Experimental::Classification::GetMethodOptions(TString methodname, TString methodtitle)
237{
238 for (auto &meth : fMethods) {
239 if (meth.GetValue<TString>("MethodName") == methodname && meth.GetValue<TString>("MethodTitle") == methodtitle)
240 return meth.GetValue<TString>("MethodOptions");
241 }
242 return "";
243}
244
245//_______________________________________________________________________
246/**
247 * Method to perform Train/Test over all booked ml methods.
248 * If the option Jobs > 1, the methods are processed in parallel using MultiProc.
249 */
250void TMVA::Experimental::Classification::Evaluate()
251{
252 fTimer.Reset();
253 fTimer.Start();
254
255 Bool_t roc = fROC;
256 fROC = kFALSE;
257 if (fJobs <= 1) {
258 Train();
259 Test();
260 } else {
261 for (auto &meth : fMethods) {
262 GetMethod(meth.GetValue<TString>("MethodName"), meth.GetValue<TString>("MethodTitle"));
263 }
264#ifndef _MSC_VER
265 fWorkers.SetNWorkers(fJobs);
266#endif
267 auto executor = [=](UInt_t workerID) -> ClassificationResult {
272 auto methodname = fMethods[workerID].GetValue<TString>("MethodName");
273 auto methodtitle = fMethods[workerID].GetValue<TString>("MethodTitle");
274 auto meth = GetMethod(methodname, methodtitle);
275 if (!IsSilentFile()) {
276 auto fname = Form(".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
277 auto f = new TFile(fname, "RECREATE");
278 f->mkdir(fDataLoader->GetName());
279 SetFile(f);
280 meth->SetFile(f);
281 }
282 TrainMethod(methodname, methodtitle);
283 TestMethod(methodname, methodtitle);
284 if (!IsSilentFile()) {
285 GetFile()->Close();
286 }
287 return GetResults(methodname, methodtitle);
288 };
289
290#ifndef _MSC_VER
291 fResults = fWorkers.Map(executor, ROOT::TSeqI(fMethods.size()));
292#endif
293 if (!IsSilentFile())
294 MergeFiles();
295 }
296
297 fROC = roc;
299
300 TString hLine = "--------------------------------------------------- :";
301 Log() << kINFO << hLine << Endl;
302 Log() << kINFO << "DataSet MVA :" << Endl;
303 Log() << kINFO << "Name: Method/Title: ROC-integ :" << Endl;
304 Log() << kINFO << hLine << Endl;
305 for (auto &r : fResults) {
306
307 Log() << kINFO << Form("%-20s %-15s %#1.3f :", r.GetDataLoaderName().Data(),
308 Form("%s/%s", r.GetMethodName().Data(), r.GetMethodTitle().Data()), r.GetROCIntegral())
309 << Endl;
310 }
311 Log() << kINFO << hLine << Endl;
312
313 Log() << kINFO << "-----------------------------------------------------" << Endl;
314 Log() << kHEADER << "Evaluation done." << Endl << Endl;
315 Log() << kINFO << Form("Jobs = %d Real Time = %lf ", fJobs, fTimer.RealTime()) << Endl;
316 Log() << kINFO << "-----------------------------------------------------" << Endl;
317 Log() << kINFO << "Evaluation done." << Endl;
319}
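// Sketch of a parallel run, assuming the "Jobs" option is parsed into fJobs by the
// TMVA::Envelope base class; with Jobs > 1 each booked method is trained and tested
// in its own worker process through fWorkers.Map (see the executor lambda above):
//
//    TMVA::Experimental::Classification cl(dataloader, "Jobs=4:!Correlations:!ROC");
//    cl.BookMethod(TMVA::Types::kBDT, "BDT", "NTrees=100");
//    cl.BookMethod(TMVA::Types::kSVM, "SVM", "");
//    cl.Evaluate();   // runs the per-method executor in parallel, then prints the summary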
320
321//_______________________________________________________________________
322/**
323 * Method to train all booked ml methods.
324 */
325void TMVA::Experimental::Classification::Train()
326{
327 for (auto &meth : fMethods) {
328 TrainMethod(meth.GetValue<TString>("MethodName"), meth.GetValue<TString>("MethodTitle"));
329 }
330}
331
332//_______________________________________________________________________
333/**
334 * Train a specific ml method.
335 * \param methodname name of the method.
336 * \param methodtitle method title.
337 */
338void TMVA::Experimental::Classification::TrainMethod(TString methodname, TString methodtitle)
339{
340 auto method = GetMethod(methodname, methodtitle);
341 if (!method) {
342 Log() << kFATAL
343 << Form("Trying to train method %s %s that may not be booked.", methodname.Data(), methodtitle.Data())
344 << Endl;
345 }
346 Log() << kHEADER << gTools().Color("bold") << Form("Training method %s %s", methodname.Data(), methodtitle.Data())
347 << gTools().Color("reset") << Endl;
348
349 Event::SetIsTraining(kTRUE);
350 if ((fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification) &&
351 method->DataInfo().GetNClasses() < 2)
352 Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;
353
354 // first print some information about the default dataset
355 // if(!IsSilentFile()) WriteDataInformation(method->fDataSetInfo);
356
357 if (method->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
358 Log() << kWARNING << "Method " << method->GetMethodName() << " not trained (training tree has less entries ["
359 << method->Data()->GetNTrainingEvents() << "] than required [" << MinNoTrainingEvents << "]" << Endl;
360 return;
361 }
362
363 Log() << kHEADER << "Train method: " << method->GetMethodName() << " for Classification" << Endl << Endl;
364 method->TrainMethod();
365 Log() << kHEADER << "Training finished" << Endl << Endl;
366}
367
368//_______________________________________________________________________
369/**
370 * Train a specific ml method given the method type in enum TMVA::Types::EMVA.
371 * \param method TMVA::Types::EMVA type.
372 * \param methodtitle method title.
373 */
374void TMVA::Experimental::Classification::TrainMethod(Types::EMVA method, TString methodtitle)
375{
376 TrainMethod(Types::Instance().GetMethodName(method), methodtitle);
377}
378
379//_______________________________________________________________________
380/**
381 * Return a TMVA::MethodBase object. If the method is not booked, a null
382 * pointer is returned.
383 * \param methodname name of the method.
384 * \param methodtitle method title.
385 * \return TMVA::MethodBase object
386 */
387TMVA::MethodBase *TMVA::Experimental::Classification::GetMethod(TString methodname, TString methodtitle)
388{
389
390 if (!HasMethod(methodname, methodtitle)) {
391 std::cout << methodname << " " << methodtitle << std::endl;
392 Log() << kERROR << "Trying to get method not booked." << Endl;
393 return 0;
394 }
395 Int_t index = -1;
396 if (HasMethodObject(methodname, methodtitle, index)) {
397 return dynamic_cast<MethodBase *>(fIMethods[index]);
398 }
399 // if it has not been created yet, create it
400 if (GetDataLoaderDataInput().GetEntries() <=
401 1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
402 Log() << kFATAL << "No input data for the training provided!" << Endl;
403 }
404 Log() << kHEADER << "Loading booked method: " << gTools().Color("bold") << methodname << " " << methodtitle
405 << gTools().Color("reset") << Endl << Endl;
406
407 TString moptions = GetMethodOptions(methodname, methodtitle);
408
409 // interpret option string with respect to a request for boosting (i.e., BoostNum > 0)
410 Int_t boostNum = 0;
411 auto conf = new TMVA::Configurable(moptions);
412 conf->DeclareOptionRef(boostNum = 0, "Boost_num", "Number of times the classifier will be boosted");
413 conf->ParseOptions();
414 delete conf;
415
416 TString fFileDir;
417 if (fModelPersistence) {
418 fFileDir = fDataLoader->GetName();
419 fFileDir += "/" + gConfig().GetIONames().fWeightFileDir;
420 }
421
422 // initialize methods
423 IMethod *im;
424 TString fJobName = GetName();
425 if (!boostNum) {
426 im = ClassifierFactory::Instance().Create(std::string(methodname.Data()), fJobName, methodtitle,
427 GetDataLoaderDataSetInfo(), moptions);
428 } else {
429 // boosted classifier, requires a specific definition, making it transparent for the user
430 Log() << kDEBUG << "Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
431 im = ClassifierFactory::Instance().Create(std::string("Boost"), fJobName, methodtitle, GetDataLoaderDataSetInfo(),
432 moptions);
433 MethodBoost *methBoost = dynamic_cast<MethodBoost *>(im);
434 if (!methBoost)
435 Log() << kFATAL << "Method with type kBoost cannot be cast to MethodBoost. /Classification" << Endl;
436
437 if (fModelPersistence)
438 methBoost->SetWeightFileDir(fFileDir);
439 methBoost->SetModelPersistence(fModelPersistence);
440 methBoost->SetBoostedMethodName(methodname);
441 methBoost->fDataSetManager = GetDataLoaderDataSetManager();
442 methBoost->SetFile(fFile.get());
443 methBoost->SetSilentFile(IsSilentFile());
444 }
445
446 MethodBase *method = dynamic_cast<MethodBase *>(im);
447 if (method == 0)
448 return 0; // could not create method
449
450 // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects)
451 if (method->GetMethodType() == Types::kCategory) {
452 MethodCategory *methCat = (dynamic_cast<MethodCategory *>(im));
453 if (!methCat)
454 Log() << kFATAL << "Method with type kCategory cannot be cast to MethodCategory. /Classification" << Endl;
455
456 if (fModelPersistence)
457 methCat->SetWeightFileDir(fFileDir);
458 methCat->SetModelPersistence(fModelPersistence);
459 methCat->fDataSetManager = GetDataLoaderDataSetManager();
460 methCat->SetFile(fFile.get());
461 methCat->SetSilentFile(IsSilentFile());
462 }
463
464 if (!method->HasAnalysisType(fAnalysisType, GetDataLoaderDataSetInfo().GetNClasses(),
465 GetDataLoaderDataSetInfo().GetNTargets())) {
466 Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling ";
467 Log() << "classification with " << GetDataLoaderDataSetInfo().GetNClasses() << " classes." << Endl;
468 return 0;
469 }
470
471 if (fModelPersistence)
472 method->SetWeightFileDir(fFileDir);
473 method->SetModelPersistence(fModelPersistence);
474 method->SetAnalysisType(fAnalysisType);
475 method->SetupMethod();
476 method->ParseOptions();
477 method->ProcessSetup();
478 method->SetFile(fFile.get());
479 method->SetSilentFile(IsSilentFile());
480
481 // check-for-unused-options is performed; may be overridden by derived classes
482 method->CheckSetup();
483 fIMethods.push_back(method);
484 return method;
485}
486
487//_______________________________________________________________________
488/**
489 * Check whether the TMVA::MethodBase object was created and return its index in the vector.
490 * \param methodname name of the method.
491 * \param methodtitle method title.
492 * \param index reference to Int_t with the position of the method in the vector fIMethods.
493 * \return boolean: kTRUE if the method was found.
494 */
495Bool_t TMVA::Experimental::Classification::HasMethodObject(TString methodname, TString methodtitle, Int_t &index)
496{
497 if (fIMethods.empty())
498 return kFALSE;
499 for (UInt_t i = 0; i < fIMethods.size(); i++) {
500 // in MethodBase the method title is stored as the method name and the method type as the type name
501 auto methbase = dynamic_cast<MethodBase *>(fIMethods[i]);
502 if (methbase->GetMethodTypeName() == methodname && methbase->GetMethodName() == methodtitle) {
503 index = i;
504 return kTRUE;
505 }
506 }
507 return kFALSE;
508}
509
510//_______________________________________________________________________
511/**
512 * Perform test evaluation in all booked methods.
513 */
514void TMVA::Experimental::Classification::Test()
515{
516 for (auto &meth : fMethods) {
517 TestMethod(meth.GetValue<TString>("MethodName"), meth.GetValue<TString>("MethodTitle"));
518 }
519}
520
521//_______________________________________________________________________
522/**
523 * Test a specific ml method.
524 * \param methodname name of the method.
525 * \param methodtitle method title.
526 */
527void TMVA::Experimental::Classification::TestMethod(TString methodname, TString methodtitle)
528{
529 auto method = GetMethod(methodname, methodtitle);
530 if (!method) {
531 Log() << kFATAL
532 << Form("Trying to test method %s %s that may not be booked.", methodname.Data(), methodtitle.Data())
533 << Endl;
534 }
535
536 Log() << kHEADER << gTools().Color("bold") << "Test all methods" << gTools().Color("reset") << Endl;
537 Event::SetIsTraining(kFALSE);
538
539 Types::EAnalysisType analysisType = method->GetAnalysisType();
540 Log() << kHEADER << "Test method: " << method->GetMethodName() << " for Classification"
541 << " performance" << Endl << Endl;
542 method->AddOutput(Types::kTesting, analysisType);
543
544 // -----------------------------------------------------------------------
545 // First part of evaluation process
546 // --> compute efficiencies, and other separation estimators
547 // -----------------------------------------------------------------------
548
549 // although equal, we now want to separate the output for the variables
550 // and the real methods
551 Int_t isel; // will be 0 for a Method; 1 for a Variable
552 Int_t nmeth_used[2] = {0, 0}; // 0 Method; 1 Variable
553
554 std::vector<std::vector<TString>> mname(2);
555 std::vector<std::vector<Double_t>> sig(2), sep(2), roc(2);
556 std::vector<std::vector<Double_t>> eff01(2), eff10(2), eff30(2), effArea(2);
557 std::vector<std::vector<Double_t>> eff01err(2), eff10err(2), eff30err(2);
558 std::vector<std::vector<Double_t>> trainEff01(2), trainEff10(2), trainEff30(2);
559
560 method->SetFile(fFile.get());
561 method->SetSilentFile(IsSilentFile());
562
563 MethodBase *methodNoCuts = NULL;
564 if (!IsCutsMethod(method))
565 methodNoCuts = method;
566
567 Log() << kHEADER << "Evaluate classifier: " << method->GetMethodName() << Endl << Endl;
568 isel = (method->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
569
570 // perform the evaluation
571 method->TestClassification();
572
573 // evaluate the classifier
574 mname[isel].push_back(method->GetMethodName());
575 sig[isel].push_back(method->GetSignificance());
576 sep[isel].push_back(method->GetSeparation());
577 roc[isel].push_back(method->GetROCIntegral());
578
579 Double_t err;
580 eff01[isel].push_back(method->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
581 eff01err[isel].push_back(err);
582 eff10[isel].push_back(method->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
583 eff10err[isel].push_back(err);
584 eff30[isel].push_back(method->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
585 eff30err[isel].push_back(err);
586 effArea[isel].push_back(method->GetEfficiency("", Types::kTesting, err)); // computes the area (average)
587
588 trainEff01[isel].push_back(method->GetTrainingEfficiency("Efficiency:0.01")); // the first pass takes longer
589 trainEff10[isel].push_back(method->GetTrainingEfficiency("Efficiency:0.10"));
590 trainEff30[isel].push_back(method->GetTrainingEfficiency("Efficiency:0.30"));
591
592 nmeth_used[isel]++;
593
594 if (!IsSilentFile()) {
595 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
596 method->WriteEvaluationHistosToFile(Types::kTesting);
597 method->WriteEvaluationHistosToFile(Types::kTraining);
598 }
599
600 // now sort the variables according to the best 'eff at Beff=0.10'
601 for (Int_t k = 0; k < 2; k++) {
602 std::vector<std::vector<Double_t>> vtemp;
603 vtemp.push_back(effArea[k]); // this is the vector that is ranked
604 vtemp.push_back(eff10[k]);
605 vtemp.push_back(eff01[k]);
606 vtemp.push_back(eff30[k]);
607 vtemp.push_back(eff10err[k]);
608 vtemp.push_back(eff01err[k]);
609 vtemp.push_back(eff30err[k]);
610 vtemp.push_back(trainEff10[k]);
611 vtemp.push_back(trainEff01[k]);
612 vtemp.push_back(trainEff30[k]);
613 vtemp.push_back(sig[k]);
614 vtemp.push_back(sep[k]);
615 vtemp.push_back(roc[k]);
616 std::vector<TString> vtemps = mname[k];
617 gTools().UsefulSortDescending(vtemp, &vtemps);
618 effArea[k] = vtemp[0];
619 eff10[k] = vtemp[1];
620 eff01[k] = vtemp[2];
621 eff30[k] = vtemp[3];
622 eff10err[k] = vtemp[4];
623 eff01err[k] = vtemp[5];
624 eff30err[k] = vtemp[6];
625 trainEff10[k] = vtemp[7];
626 trainEff01[k] = vtemp[8];
627 trainEff30[k] = vtemp[9];
628 sig[k] = vtemp[10];
629 sep[k] = vtemp[11];
630 roc[k] = vtemp[12];
631 mname[k] = vtemps;
632 }
633
634 // -----------------------------------------------------------------------
635 // Second part of evaluation process
636 // --> compute correlations among MVAs
637 // --> compute correlations between input variables and MVA (determines importance)
638 // --> count overlaps
639 // -----------------------------------------------------------------------
640 if (fCorrelations) {
641 const Int_t nmeth = methodNoCuts == NULL ? 0 : 1;
642 const Int_t nvar = method->fDataSetInfo.GetNVariables();
643 if (nmeth > 0) {
644
645 // needed for correlations
646 Double_t *dvec = new Double_t[nmeth + nvar];
647 std::vector<Double_t> rvec;
648
649 // for correlations
650 TPrincipal *tpSig = new TPrincipal(nmeth + nvar, "");
651 TPrincipal *tpBkg = new TPrincipal(nmeth + nvar, "");
652
653 // set required tree branch references
654 std::vector<TString> *theVars = new std::vector<TString>;
655 std::vector<ResultsClassification *> mvaRes;
656 theVars->push_back(methodNoCuts->GetTestvarName());
657 rvec.push_back(methodNoCuts->GetSignalReferenceCut());
658 theVars->back().ReplaceAll("MVA_", "");
659 mvaRes.push_back(dynamic_cast<ResultsClassification *>(
660 methodNoCuts->Data()->GetResults(methodNoCuts->GetMethodName(), Types::kTesting, Types::kMaxAnalysisType)));
661
662 // for overlap study
663 TMatrixD *overlapS = new TMatrixD(nmeth, nmeth);
664 TMatrixD *overlapB = new TMatrixD(nmeth, nmeth);
665 (*overlapS) *= 0; // init...
666 (*overlapB) *= 0; // init...
667
668 // loop over test tree
669 DataSet *defDs = method->fDataSetInfo.GetDataSet();
670 defDs->SetCurrentType(Types::kTesting);
671 for (Int_t ievt = 0; ievt < defDs->GetNEvents(); ievt++) {
672 const Event *ev = defDs->GetEvent(ievt);
673
674 // for correlations
675 TMatrixD *theMat = 0;
676 for (Int_t im = 0; im < nmeth; im++) {
677 // check for NaN value
678 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
679 if (TMath::IsNaN(retval)) {
680 Log() << kWARNING << "Found NaN return value in event: " << ievt << " for method \""
681 << methodNoCuts->GetName() << "\"" << Endl;
682 dvec[im] = 0;
683 } else
684 dvec[im] = retval;
685 }
686 for (Int_t iv = 0; iv < nvar; iv++)
687 dvec[iv + nmeth] = (Double_t)ev->GetValue(iv);
688 if (method->fDataSetInfo.IsSignal(ev)) {
689 tpSig->AddRow(dvec);
690 theMat = overlapS;
691 } else {
692 tpBkg->AddRow(dvec);
693 theMat = overlapB;
694 }
695
696 // count overlaps
697 for (Int_t im = 0; im < nmeth; im++) {
698 for (Int_t jm = im; jm < nmeth; jm++) {
699 if ((dvec[im] - rvec[im]) * (dvec[jm] - rvec[jm]) > 0) {
700 (*theMat)(im, jm)++;
701 if (im != jm)
702 (*theMat)(jm, im)++;
703 }
704 }
705 }
706 }
707
708 // renormalise overlap matrix
709 (*overlapS) *= (1.0 / defDs->GetNEvtSigTest()); // init...
710 (*overlapB) *= (1.0 / defDs->GetNEvtBkgdTest()); // init...
711
712 tpSig->MakePrincipals();
713 tpBkg->MakePrincipals();
714
715 const TMatrixD *covMatS = tpSig->GetCovarianceMatrix();
716 const TMatrixD *covMatB = tpBkg->GetCovarianceMatrix();
717
718 const TMatrixD *corrMatS = gTools().GetCorrelationMatrix(covMatS);
719 const TMatrixD *corrMatB = gTools().GetCorrelationMatrix(covMatB);
720
721 // print correlation matrices
722 if (corrMatS != 0 && corrMatB != 0) {
723
724 // extract MVA matrix
725 TMatrixD mvaMatS(nmeth, nmeth);
726 TMatrixD mvaMatB(nmeth, nmeth);
727 for (Int_t im = 0; im < nmeth; im++) {
728 for (Int_t jm = 0; jm < nmeth; jm++) {
729 mvaMatS(im, jm) = (*corrMatS)(im, jm);
730 mvaMatB(im, jm) = (*corrMatB)(im, jm);
731 }
732 }
733
734 // extract variables - to MVA matrix
735 std::vector<TString> theInputVars;
736 TMatrixD varmvaMatS(nvar, nmeth);
737 TMatrixD varmvaMatB(nvar, nmeth);
738 for (Int_t iv = 0; iv < nvar; iv++) {
739 theInputVars.push_back(method->fDataSetInfo.GetVariableInfo(iv).GetLabel());
740 for (Int_t jm = 0; jm < nmeth; jm++) {
741 varmvaMatS(iv, jm) = (*corrMatS)(nmeth + iv, jm);
742 varmvaMatB(iv, jm) = (*corrMatB)(nmeth + iv, jm);
743 }
744 }
745
746 if (nmeth > 1) {
747 Log() << kINFO << Endl;
748 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
749 << "Inter-MVA correlation matrix (signal):" << Endl;
750 gTools().FormattedOutput(mvaMatS, *theVars, Log());
751 Log() << kINFO << Endl;
752
753 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
754 << "Inter-MVA correlation matrix (background):" << Endl;
755 gTools().FormattedOutput(mvaMatB, *theVars, Log());
756 Log() << kINFO << Endl;
757 }
758
759 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
760 << "Correlations between input variables and MVA response (signal):" << Endl;
761 gTools().FormattedOutput(varmvaMatS, theInputVars, *theVars, Log());
762 Log() << kINFO << Endl;
763
764 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
765 << "Correlations between input variables and MVA response (background):" << Endl;
766 gTools().FormattedOutput(varmvaMatB, theInputVars, *theVars, Log());
767 Log() << kINFO << Endl;
768 } else
769 Log() << kWARNING << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
770 << "<TestAllMethods> cannot compute correlation matrices" << Endl;
771
772 // print overlap matrices
773 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
774 << "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
775 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
776 << "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
777 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
778 << "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
779 gTools().FormattedOutput(rvec, *theVars, "Method", "Cut value", Log());
780 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
781 << "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
782
783 // give notice that cut method has been excluded from this test
784 if (nmeth != 1)
785 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
786 << "Note: no correlations and overlap with cut method are provided at present" << Endl;
787
788 if (nmeth > 1) {
789 Log() << kINFO << Endl;
790 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
791 << "Inter-MVA overlap matrix (signal):" << Endl;
792 gTools().FormattedOutput(*overlapS, *theVars, Log());
793 Log() << kINFO << Endl;
794
795 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
796 << "Inter-MVA overlap matrix (background):" << Endl;
797 gTools().FormattedOutput(*overlapB, *theVars, Log());
798 }
799
800 // cleanup
801 delete tpSig;
802 delete tpBkg;
803 delete corrMatS;
804 delete corrMatB;
805 delete theVars;
806 delete overlapS;
807 delete overlapB;
808 delete[] dvec;
809 }
810 }
811
812 // -----------------------------------------------------------------------
813 // Third part of evaluation process
814 // --> output
815 // -----------------------------------------------------------------------
816 // putting results in the classification result object
817 auto &fResult = GetResults(methodname, methodtitle);
818
819 // Binary classification
820 if (fROC) {
821 Log().EnableOutput();
823 Log() << Endl;
824 TString hLine = "------------------------------------------------------------------------------------------"
825 "-------------------------";
826 Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
827 Log() << kINFO << hLine << Endl;
828 Log() << kINFO << "DataSet MVA " << Endl;
829 Log() << kINFO << "Name: Method: ROC-integ" << Endl;
830
831 Log() << kDEBUG << hLine << Endl;
832 for (Int_t k = 0; k < 2; k++) {
833 if (k == 1 && nmeth_used[k] > 0) {
834 Log() << kINFO << hLine << Endl;
835 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
836 }
837 for (Int_t i = 0; i < nmeth_used[k]; i++) {
838 TString datasetName = fDataLoader->GetName();
839 TString methodName = mname[k][i];
840
841 if (k == 1) {
842 methodName.ReplaceAll("Variable_", "");
843 }
844
845 TMVA::DataSet *dataset = method->Data();
846 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
847 std::vector<Bool_t> *mvaResType = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
848
849 Double_t rocIntegral = 0.0;
850 if (mvaResType->size() != 0) {
851 rocIntegral = GetROCIntegral(methodname, methodtitle);
852 }
853
854 if (sep[k][i] < 0 || sig[k][i] < 0) {
855 // cannot compute separation/significance -> no MVA (usually for Cuts)
856 fResult.fROCIntegral = effArea[k][i];
857 Log() << kINFO
858 << Form("%-13s %-15s: %#1.3f", fDataLoader->GetName(), methodName.Data(), fResult.fROCIntegral)
859 << Endl;
860 } else {
861 fResult.fROCIntegral = rocIntegral;
862 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
863 << Endl;
864 }
865 }
866 }
867 Log() << kINFO << hLine << Endl;
868 Log() << kINFO << Endl;
869 Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
870 Log() << kINFO << hLine << Endl;
871 Log() << kINFO
872 << "DataSet MVA Signal efficiency: from test sample (from training sample) "
873 << Endl;
874 Log() << kINFO << "Name: Method: @B=0.01 @B=0.10 @B=0.30 "
875 << Endl;
876 Log() << kINFO << hLine << Endl;
877 for (Int_t k = 0; k < 2; k++) {
878 if (k == 1 && nmeth_used[k] > 0) {
879 Log() << kINFO << hLine << Endl;
880 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
881 }
882 for (Int_t i = 0; i < nmeth_used[k]; i++) {
883 if (k == 1)
884 mname[k][i].ReplaceAll("Variable_", "");
885
886 Log() << kINFO << Form("%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
887 method->fDataSetInfo.GetName(), (const char *)mname[k][i], eff01[k][i],
888 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
889 << Endl;
890 }
891 }
892 Log() << kINFO << hLine << Endl;
893 Log() << kINFO << Endl;
894
895 if (gTools().CheckForSilentOption(GetOptions()))
896 Log().InhibitOutput();
897 } else if (IsCutsMethod(method)) { // end fROC
898 for (Int_t k = 0; k < 2; k++) {
899 for (Int_t i = 0; i < nmeth_used[k]; i++) {
900
901 if (sep[k][i] < 0 || sig[k][i] < 0) {
902 // cannot compute separation/significance -> no MVA (usually for Cuts)
903 fResult.fROCIntegral = effArea[k][i];
904 }
905 }
906 }
907 }
908
909 TMVA::DataSet *dataset = method->Data();
911
912 if (IsCutsMethod(method)) {
913 fResult.fIsCuts = kTRUE;
914 } else {
915 auto rocCurveTest = GetROC(methodname, methodtitle, 0, Types::kTesting);
916 fResult.fMvaTest[0] = rocCurveTest->GetMvas();
917 fResult.fROCIntegral = GetROCIntegral(methodname, methodtitle);
918 }
919 TString className = method->DataInfo().GetClassInfo(0)->GetName();
920 fResult.fClassNames.push_back(className);
921
922 if (!IsSilentFile()) {
923 // write test/training trees
924 RootBaseDir()->cd(method->fDataSetInfo.GetName());
925 method->fDataSetInfo.GetDataSet()->GetTree(Types::kTesting)->Write("", TObject::kOverwrite);
926 method->fDataSetInfo.GetDataSet()->GetTree(Types::kTraining)->Write("", TObject::kOverwrite);
927 }
928}
929
930//_______________________________________________________________________
931/**
932 * Test a specific ml method given the method type in enum TMVA::Types::EMVA.
933 * \param method TMVA::Types::EMVA type.
934 * \param methodtitle method title.
935 */
936void TMVA::Experimental::Classification::TestMethod(Types::EMVA method, TString methodtitle)
937{
938 TestMethod(Types::Instance().GetMethodName(method), methodtitle);
939}
940
941//_______________________________________________________________________
942/**
943 * Return the vector of TMVA::Experimental::ClassificationResult objects.
944 * \return vector of results.
945 */
946std::vector<TMVA::Experimental::ClassificationResult> &TMVA::Experimental::Classification::GetResults()
947{
948 if (fResults.size() == 0)
949 Log() << kFATAL << "No Classification results available" << Endl;
950 return fResults;
951}
952
953//_______________________________________________________________________
954/**
955 * Check whether the ml method is a Cuts method.
956 * \return boolean: kTRUE if the method is a Cuts method.
957 */
958Bool_t TMVA::Experimental::Classification::IsCutsMethod(TMVA::MethodBase *method)
959{
960 return method->GetMethodType() == Types::kCuts ? kTRUE : kFALSE;
961}
962
963//_______________________________________________________________________
964/**
965 * Get the result for a specific ml method.
966 * \param methodname name of the method.
967 * \param methodtitle method title.
968 * \return TMVA::Experimental::ClassificationResult object for the method.
969 */
970TMVA::Experimental::ClassificationResult &
971TMVA::Experimental::Classification::GetResults(TString methodname, TString methodtitle)
972{
973 for (auto &result : fResults) {
974 if (result.IsMethod(methodname, methodtitle))
975 return result;
976 }
977 ClassificationResult result;
978 result.fMethod["MethodName"] = methodname;
979 result.fMethod["MethodTitle"] = methodtitle;
980 result.fDataLoaderName = fDataLoader->GetName();
981 fResults.push_back(result);
982 return fResults.back();
983}
984
985//_______________________________________________________________________
986/**
987 * Method to get the TMVA::ROCCurve object.
988 * \param method TMVA::MethodBase object.
989 * \param iClass class index; default 0 (signal).
990 * \param type train/test tree; default is the test tree.
991 * \return TMVA::ROCCurve object.
992 */
993TMVA::ROCCurve *
994TMVA::Experimental::Classification::GetROC(TMVA::MethodBase *method, UInt_t iClass, TMVA::Types::ETreeType type)
995{
996 TMVA::DataSet *dataset = method->Data();
997 dataset->SetCurrentType(type);
998 TMVA::Results *results = dataset->GetResults(method->GetName(), type, this->fAnalysisType);
999
1000 UInt_t nClasses = method->DataInfo().GetNClasses();
1001 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
1002 Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
1003 iClass, nClasses)
1004 << Endl;
1005 return nullptr;
1006 }
1007
1008 TMVA::ROCCurve *rocCurve = nullptr;
1009 if (this->fAnalysisType == Types::kClassification) {
1010
1011 std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
1012 std::vector<Bool_t> *mvaResTypes = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
1013 std::vector<Float_t> mvaResWeights;
1014
1015 auto eventCollection = dataset->GetEventCollection(type);
1016 mvaResWeights.reserve(eventCollection.size());
1017 for (auto ev : eventCollection) {
1018 mvaResWeights.push_back(ev->GetWeight());
1019 }
1020
1021 rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);
1022
1023 } else if (this->fAnalysisType == Types::kMulticlass) {
1024 std::vector<Float_t> mvaRes;
1025 std::vector<Bool_t> mvaResTypes;
1026 std::vector<Float_t> mvaResWeights;
1027
1028 std::vector<std::vector<Float_t>> *rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();
1029
1030 // Vector transpose due to values being stored as
1031 // [ [0, 1, 2], [0, 1, 2], ... ]
1032 // in ResultsMulticlass::GetValueVector.
1033 mvaRes.reserve(rawMvaRes->size());
1034 for (auto item : *rawMvaRes) {
1035 mvaRes.push_back(item[iClass]);
1036 }
1037
1038 auto eventCollection = dataset->GetEventCollection(type);
1039 mvaResTypes.reserve(eventCollection.size());
1040 mvaResWeights.reserve(eventCollection.size());
1041 for (auto ev : eventCollection) {
1042 mvaResTypes.push_back(ev->GetClass() == iClass);
1043 mvaResWeights.push_back(ev->GetWeight());
1044 }
1045
1046 rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
1047 }
1048
1049 return rocCurve;
1050}
1051
1052//_______________________________________________________________________
1053/**
1054 * Method to get the TMVA::ROCCurve object.
1055 * \param methodname ml method name.
1056 * \param methodtitle ml method title.
1057 * \param iClass class index; default 0 (signal).
1058 * \param type train/test tree; default is the test tree.
1059 * \return TMVA::ROCCurve object.
1060 */
1061TMVA::ROCCurve *
1062TMVA::Experimental::Classification::GetROC(TString methodname, TString methodtitle, UInt_t iClass, TMVA::Types::ETreeType type)
1063{
1064 return GetROC(GetMethod(methodname, methodtitle), iClass, type);
1065}
1066
1067//_______________________________________________________________________
1068/**
1069 * Method to get the ROC-Integral value from the stored MVA values.
1070 * \param methodname ml method name.
1071 * \param methodtitle ml method title.
1072 * \param iClass class index; default 0 (signal).
1073 * \return Double_t with the ROC-Integral value.
1074 */
1075Double_t TMVA::Experimental::Classification::GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass)
1076{
1077 TMVA::ROCCurve *rocCurve = GetROC(methodname, methodtitle, iClass);
1078 if (!rocCurve) {
1079 Log() << kFATAL
1080 << Form("ROCCurve object was not created in MethodName = %s MethodTitle = %s not found with Dataset = %s ",
1081 methodname.Data(), methodtitle.Data(), fDataLoader->GetName())
1082 << Endl;
1083 return 0;
1084 }
1085
1086 Int_t npoints = TMVA::gConfig().fVariablePlotting.fNbinsXOfROCCurve;
1087 Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
1088 delete rocCurve;
1089
1090 return rocIntegral;
1091}
1092
1093//_______________________________________________________________________
1094void TMVA::Experimental::Classification::CopyFrom(TDirectory *src, TFile *file)
1095{
1096 TFile *savdir = file;
1097 TDirectory *adir = savdir;
1098 adir->cd();
1099 // loop on all entries of this directory
1100 TKey *key;
1101 TIter nextkey(src->GetListOfKeys());
1102 while ((key = (TKey *)nextkey())) {
1103 const Char_t *classname = key->GetClassName();
1104 TClass *cl = gROOT->GetClass(classname);
1105 if (!cl)
1106 continue;
1107 if (cl->InheritsFrom(TDirectory::Class())) {
1108 src->cd(key->GetName());
1109 TDirectory *subdir = file;
1110 adir->cd();
1111 CopyFrom(subdir, file);
1112 adir->cd();
1113 } else if (cl->InheritsFrom(TTree::Class())) {
1114 TTree *T = (TTree *)src->Get(key->GetName());
1115 adir->cd();
1116 TTree *newT = T->CloneTree(-1, "fast");
1117 newT->Write();
1118 } else {
1119 src->cd();
1120 TObject *obj = key->ReadObj();
1121 adir->cd();
1122 obj->Write();
1123 delete obj;
1124 }
1125 }
1126 adir->SaveSelf(kTRUE);
1127 savdir->cd();
1128}
1129
1130//_______________________________________________________________________
1131void TMVA::Experimental::Classification::MergeFiles()
1132{
1133
1134 auto dsdir = fFile->mkdir(fDataLoader->GetName()); // dataset dir
1135 TTree *TrainTree = 0;
1136 TTree *TestTree = 0;
1137 TFile *ifile = 0;
1138 TFile *ofile = 0;
1139 for (UInt_t i = 0; i < fMethods.size(); i++) {
1140 auto methodname = fMethods[i].GetValue<TString>("MethodName");
1141 auto methodtitle = fMethods[i].GetValue<TString>("MethodTitle");
1142 auto fname = Form(".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
1143 TDirectoryFile *ds = 0;
1144 if (i == 0) {
1145 ifile = new TFile(fname);
1146 ds = (TDirectoryFile *)ifile->Get(fDataLoader->GetName());
1147 } else {
1148 ofile = new TFile(fname);
1149 ds = (TDirectoryFile *)ofile->Get(fDataLoader->GetName());
1150 }
1151 auto tmptrain = (TTree *)ds->Get("TrainTree");
1152 auto tmptest = (TTree *)ds->Get("TestTree");
1153 fFile->cd();
1154 fFile->cd(fDataLoader->GetName());
1155
1156 auto methdirname = Form("Method_%s", methodtitle.Data());
1157 auto methdir = dsdir->mkdir(methdirname, methdirname);
1158 auto methdirbase = methdir->mkdir(methodtitle.Data(), methodtitle.Data());
1159 auto mfdir = (TDirectoryFile *)ds->Get(methdirname);
1160 auto mfdirbase = (TDirectoryFile *)mfdir->Get(methodtitle.Data());
1161
1162 CopyFrom(mfdirbase, (TFile *)methdirbase);
1163 dsdir->cd();
1164 if (i == 0) {
1165 TrainTree = tmptrain->CopyTree("");
1166 TestTree = tmptest->CopyTree("");
1167 } else {
1168 Float_t mva = 0;
1169 auto trainbranch = TrainTree->Branch(methodtitle.Data(), &mva);
1170 tmptrain->SetBranchAddress(methodtitle.Data(), &mva);
1171 auto entries = tmptrain->GetEntries();
1172 for (UInt_t ev = 0; ev < entries; ev++) {
1173 tmptrain->GetEntry(ev);
1174 trainbranch->Fill();
1175 }
1176 auto testbranch = TestTree->Branch(methodtitle.Data(), &mva);
1177 tmptest->SetBranchAddress(methodtitle.Data(), &mva);
1178 entries = tmptest->GetEntries();
1179 for (UInt_t ev = 0; ev < entries; ev++) {
1180 tmptest->GetEntry(ev);
1181 testbranch->Fill();
1182 }
1183 ofile->Close();
1184 }
1185 }
1186 TrainTree->Write();
1187 TestTree->Write();
1188 ifile->Close();
1189 // cleaning
1190 for (UInt_t i = 0; i < fMethods.size(); i++) {
1191 auto methodname = fMethods[i].GetValue<TString>("MethodName");
1192 auto methodtitle = fMethods[i].GetValue<TString>("MethodTitle");
1193 auto fname = Form(".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
1194 gSystem->Unlink(fname);
1195 }
1196}