Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Classification.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan
3// Therhaag
4
6
8#include <TMVA/Config.h>
9#include <TMVA/Configurable.h>
10#include <TMVA/Tools.h>
11#include <TMVA/Ranking.h>
12#include <TMVA/DataSet.h>
13#include <TMVA/IMethod.h>
14#include <TMVA/MethodBase.h>
16#include <TMVA/DataSetManager.h>
17#include <TMVA/DataSetInfo.h>
18#include <TMVA/DataLoader.h>
19#include <TMVA/MethodBoost.h>
20#include <TMVA/MethodCategory.h>
21#include <TMVA/ROCCalc.h>
22#include <TMVA/ROCCurve.h>
23#include <TMVA/MsgLogger.h>
24
25#include <TMVA/VariableInfo.h>
27
28#include <TMVA/Types.h>
29
30#include <TROOT.h>
31#include <TFile.h>
32#include <TKey.h>
33#include <TTree.h>
34#include <TKey.h>
35#include <TLeaf.h>
36#include <TBranch.h>
37#include <TEventList.h>
38#include <TGraph.h>
39#include <TMatrixF.h>
40#include <TMatrixDSym.h>
41#include <TMultiGraph.h>
42#include <TPrincipal.h>
43#include <TMath.h>
44#include <TSystem.h>
45
46#include <iostream>
47#include <memory>
48
49#define MinNoTrainingEvents 10
50
51//_______________________________________________________________________
53{
54}
55
56//_______________________________________________________________________
58{
59 fMethod = cr.fMethod;
62 fMvaTest = cr.fMvaTest;
63 fIsCuts = cr.fIsCuts;
65}
66
67//_______________________________________________________________________
68/**
69 * Method to get ROC-Integral value from mvas.
70 * \param iClass class index; default 0, i.e. signal
71 * \param type train/test tree, default test.
72 * \return Double_t with the ROC-Integral value.
73 */
75{
76 if (fIsCuts) {
77 return fROCIntegral;
78 } else {
79 auto roc = GetROC(iClass, type);
80 auto inte = roc->GetROCIntegral();
81 delete roc;
82 return inte;
83 }
84}
85
86//_______________________________________________________________________
87/**
88 * Method to get TMVA::ROCCurve Object.
89 * \param iClass class index; default 0, i.e. signal
90 * \param type train/test tree, default test.
91 * \return TMVA::ROCCurve object.
92 */
94{
95 ROCCurve *fROCCurve = nullptr;
97 fROCCurve = new ROCCurve(fMvaTest[iClass]);
98 else
99 fROCCurve = new ROCCurve(fMvaTrain[iClass]);
100 return fROCCurve;
101}
102
103//_______________________________________________________________________
106{
107 fMethod = cr.fMethod;
108 fDataLoaderName = cr.fDataLoaderName;
109 fMvaTrain = cr.fMvaTrain;
110 fMvaTest = cr.fMvaTest;
111 fIsCuts = cr.fIsCuts;
112 fROCIntegral = cr.fROCIntegral;
113 return *this;
114}
115
116//_______________________________________________________________________
117/**
118 * Method to print the results in stdout.
119 * Prints the data loader name, method name/title and ROC integral.
120 */
122{
123 MsgLogger fLogger("Classification");
126 TString hLine = "--------------------------------------------------- :";
127
128 fLogger << kINFO << hLine << Endl;
129 fLogger << kINFO << "DataSet MVA :" << Endl;
130 fLogger << kINFO << "Name: Method/Title: ROC-integ :" << Endl;
131 fLogger << kINFO << hLine << Endl;
132 fLogger << kINFO << TString::Format("%-20s %-15s %#1.3f :", fDataLoaderName.Data(),
133 TString::Format("%s/%s", fMethod.GetValue<TString>("MethodName").Data(),
134 fMethod.GetValue<TString>("MethodTitle").Data()).Data(),
135 GetROCIntegral())
136 << Endl;
137 fLogger << kINFO << hLine << Endl;
138
140}
141
142//_______________________________________________________________________
143/**
144 * Method to get TGraph object with the ROC curve.
145 * \param iClass class index; default 0, i.e. signal
146 * \param type train/test tree, default test.
147 * \return TGraph object.
148 */
150{
151 TGraph *roc = GetROC(iClass, type)->GetROCCurve();
152 roc->SetName(TString::Format("%s/%s", GetMethodName().Data(), GetMethodTitle().Data()).Data());
153 roc->SetTitle(TString::Format("%s/%s", GetMethodName().Data(), GetMethodTitle().Data()).Data());
154 roc->GetXaxis()->SetTitle(" Signal Efficiency ");
155 roc->GetYaxis()->SetTitle(" Background Rejection ");
156 return roc;
157}
158
159//_______________________________________________________________________
160/**
161 * Method to check if method was booked.
162 * \param methodname name of the method.
163 * \param methodtitle method title.
164 * \return boolean true if the method was booked, false in other case.
165 */
167{
168 return fMethod.GetValue<TString>("MethodName") == methodname &&
169 fMethod.GetValue<TString>("MethodTitle") == methodtitle
170 ? kTRUE
171 : kFALSE;
172}
173
174//_______________________________________________________________________
175/**
176 * Constructor to create a two-class classifier.
177 * \param dataloader TMVA::DataLoader object with the data to train/test.
178 * \param file TFile object to save the results
179 * \param options string extra options.
180 */
182 : TMVA::Envelope("Classification", dataloader, file, options), fAnalysisType(Types::kClassification),
183 fCorrelations(kFALSE), fROC(kTRUE)
184{
185 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
186 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
187 ParseOptions();
189
191 gSystem->MakeDirectory(fDataLoader->GetName()); // creating directory for DataLoader output
192}
193
194//_______________________________________________________________________
195/**
196 * Constructor to create a two-class classifier without output file.
197 * \param dataloader TMVA::DataLoader object with the data to train/test.
198 * \param options string extra options.
199 */
201 : TMVA::Envelope("Classification", dataloader, NULL, options), fAnalysisType(Types::kClassification),
202 fCorrelations(kFALSE), fROC(kTRUE)
203{
204
205 // init configurable
206 SetConfigDescription("Configuration options for Classification running");
208
209 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
210 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
211 ParseOptions();
214 gSystem->MakeDirectory(fDataLoader->GetName()); // creating directory for DataLoader output
216}
217
218//_______________________________________________________________________
220{
221 for (auto m : fIMethods) {
222 if (m != NULL)
223 delete m;
224 }
225}
226
227//_______________________________________________________________________
228/**
229 * return the options for the booked method.
230 * \param methodname name of the method.
231 * \param methodtitle method title.
232 * \return string with the options for the ML method (empty if the method is not found).
233 */
235{
236 for (auto &meth : fMethods) {
237 if (meth.GetValue<TString>("MethodName") == methodname && meth.GetValue<TString>("MethodTitle") == methodtitle)
238 return meth.GetValue<TString>("MethodOptions");
239 }
240 return "";
241}
242
243//_______________________________________________________________________
244/**
245 * Method to perform Train/Test over all booked ML methods.
246 * If the option Jobs > 1, this can be done in parallel with MultiProc.
247 */
249{
250 fTimer.Reset();
251 fTimer.Start();
252
253 Bool_t roc = fROC;
254 fROC = kFALSE;
255 if (fJobs <= 1) {
256 Train();
257 Test();
258 } else {
259 for (auto &meth : fMethods) {
260 GetMethod(meth.GetValue<TString>("MethodName"), meth.GetValue<TString>("MethodTitle"));
261 }
262#ifndef _MSC_VER
263 fWorkers.SetNWorkers(fJobs);
264#endif
265 auto executor = [this](UInt_t workerID) -> ClassificationResult {
270 auto methodname = fMethods[workerID].GetValue<TString>("MethodName");
271 auto methodtitle = fMethods[workerID].GetValue<TString>("MethodTitle");
272 auto meth = GetMethod(methodname, methodtitle);
273 if (!IsSilentFile()) {
274 auto fname = TString::Format(".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
275 auto f = new TFile(fname.Data(), "RECREATE");
276 f->mkdir(fDataLoader->GetName());
277 SetFile(f);
278 meth->SetFile(f);
279 }
280 TrainMethod(methodname, methodtitle);
281 TestMethod(methodname, methodtitle);
282 if (!IsSilentFile()) {
283 GetFile()->Close();
284 }
285 return GetResults(methodname, methodtitle);
286 };
287
288#ifndef _MSC_VER
289 fResults = fWorkers.Map(executor, ROOT::TSeqI(fMethods.size()));
290#endif
291 if (!IsSilentFile())
292 MergeFiles();
293 }
294
295 fROC = roc;
297
298 TString hLine = "--------------------------------------------------- :";
299 Log() << kINFO << hLine << Endl;
300 Log() << kINFO << "DataSet MVA :" << Endl;
301 Log() << kINFO << "Name: Method/Title: ROC-integ :" << Endl;
302 Log() << kINFO << hLine << Endl;
303 for (auto &r : fResults) {
304
305 Log() << kINFO << TString::Format("%-20s %-15s %#1.3f :", r.GetDataLoaderName().Data(),
306 TString::Format("%s/%s", r.GetMethodName().Data(), r.GetMethodTitle().Data()).Data(),
307 r.GetROCIntegral())
308 << Endl;
309 }
310 Log() << kINFO << hLine << Endl;
311
312 Log() << kINFO << "-----------------------------------------------------" << Endl;
313 Log() << kHEADER << "Evaluation done." << Endl << Endl;
314 Log() << kINFO << TString::Format("Jobs = %d Real Time = %lf ", fJobs, fTimer.RealTime()) << Endl;
315 Log() << kINFO << "-----------------------------------------------------" << Endl;
316 Log() << kINFO << "Evaluation done." << Endl;
318}
319
320//_______________________________________________________________________
321/**
322 * Method to train all booked ml methods.
323 */
325{
326 for (auto &meth : fMethods) {
327 TrainMethod(meth.GetValue<TString>("MethodName"), meth.GetValue<TString>("MethodTitle"));
328 }
329}
330
331//_______________________________________________________________________
332/**
333 * Train a specific ML method.
334 * \param methodname name of the method.
335 * \param methodtitle method title.
336 */
338{
339 auto method = GetMethod(methodname, methodtitle);
340 if (!method) {
341 Log() << kFATAL
342 << TString::Format("Trying to train method %s %s that maybe is not booked.", methodname.Data(), methodtitle.Data())
343 << Endl;
344 }
345 Log() << kHEADER << gTools().Color("bold") << TString::Format("Training method %s %s", methodname.Data(), methodtitle.Data())
346 << gTools().Color("reset") << Endl;
347
349 if ((fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification) &&
350 method->DataInfo().GetNClasses() < 2)
351 Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;
352
353 // first print some information about the default dataset
354 // if(!IsSilentFile()) WriteDataInformation(method->fDataSetInfo);
355
356 if (method->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
357 Log() << kWARNING << "Method " << method->GetMethodName() << " not trained (training tree has less entries ["
358 << method->Data()->GetNTrainingEvents() << "] than required [" << MinNoTrainingEvents << "]" << Endl;
359 return;
360 }
361
362 Log() << kHEADER << "Train method: " << method->GetMethodName() << " for Classification" << Endl << Endl;
363 method->TrainMethod();
364 Log() << kHEADER << "Training finished" << Endl << Endl;
365}
366
367//_______________________________________________________________________
368/**
369 * Train a specific ML method, given the method type as a TMVA::Types::EMVA enum value.
370 * \param method TMVA::Types::EMVA type.
371 * \param methodtitle method title.
372 */
374{
375 TrainMethod(Types::Instance().GetMethodName(method), methodtitle);
376}
377
378//_______________________________________________________________________
379/**
380 * Return a TMVA::MethodBase object. If the method is not booked, return a null
381 * pointer.
382 * \param methodname name of the method.
383 * \param methodtitle method title.
384 * \return TMVA::MethodBase object
385 */
387{
388
389 if (!HasMethod(methodname, methodtitle)) {
390 std::cout << methodname << " " << methodtitle << std::endl;
391 Log() << kERROR << "Trying to get method not booked." << Endl;
392 return 0;
393 }
394 Int_t index = -1;
395 if (HasMethodObject(methodname, methodtitle, index)) {
396 return dynamic_cast<MethodBase *>(fIMethods[index]);
397 }
398 // if the method object has not been created yet, create it.
399 if (GetDataLoaderDataInput().GetEntries() <=
400 1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
401 Log() << kFATAL << "No input data for the training provided!" << Endl;
402 }
403 Log() << kHEADER << "Loading booked method: " << gTools().Color("bold") << methodname << " " << methodtitle
404 << gTools().Color("reset") << Endl << Endl;
405
406 TString moptions = GetMethodOptions(methodname, methodtitle);
407
408 // interpret option string with respect to a request for boosting (i.e., Boost_num > 0)
409 Int_t boostNum = 0;
410 auto conf = new TMVA::Configurable(moptions);
411 conf->DeclareOptionRef(boostNum = 0, "Boost_num", "Number of times the classifier will be boosted");
412 conf->ParseOptions();
413 delete conf;
414
415 TString fFileDir;
416 if (fModelPersistence) {
417 fFileDir = fDataLoader->GetName();
418 fFileDir += "/" + gConfig().GetIONames().fWeightFileDir;
419 }
420
421 // initialize methods
422 IMethod *im;
423 TString fJobName = GetName();
424 if (!boostNum) {
425 im = ClassifierFactory::Instance().Create(std::string(methodname.Data()), fJobName, methodtitle,
426 GetDataLoaderDataSetInfo(), moptions);
427 } else {
428 // boosted classifier, requires a specific definition, making it transparent for the user
429 Log() << kDEBUG << "Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
430 im = ClassifierFactory::Instance().Create(std::string("Boost"), fJobName, methodtitle, GetDataLoaderDataSetInfo(),
431 moptions);
432 MethodBoost *methBoost = dynamic_cast<MethodBoost *>(im);
433 if (!methBoost)
434 Log() << kFATAL << "Method with type kBoost cannot be casted to MethodCategory. /Classification" << Endl;
435
436 if (fModelPersistence)
437 methBoost->SetWeightFileDir(fFileDir);
438 methBoost->SetModelPersistence(fModelPersistence);
439 methBoost->SetBoostedMethodName(methodname);
440 methBoost->fDataSetManager = GetDataLoaderDataSetManager();
441 methBoost->SetFile(fFile.get());
442 methBoost->SetSilentFile(IsSilentFile());
443 }
444
445 MethodBase *method = dynamic_cast<MethodBase *>(im);
446 if (method == 0)
447 return 0; // could not create method
448
449 // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects)
450 if (method->GetMethodType() == Types::kCategory) {
451 MethodCategory *methCat = (dynamic_cast<MethodCategory *>(im));
452 if (!methCat)
453 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Classification" << Endl;
454
455 if (fModelPersistence)
456 methCat->SetWeightFileDir(fFileDir);
457 methCat->SetModelPersistence(fModelPersistence);
458 methCat->fDataSetManager = GetDataLoaderDataSetManager();
459 methCat->SetFile(fFile.get());
460 methCat->SetSilentFile(IsSilentFile());
461 }
462
463 if (!method->HasAnalysisType(fAnalysisType, GetDataLoaderDataSetInfo().GetNClasses(),
464 GetDataLoaderDataSetInfo().GetNTargets())) {
465 Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling ";
466 Log() << "classification with " << GetDataLoaderDataSetInfo().GetNClasses() << " classes." << Endl;
467 return 0;
468 }
469
470 if (fModelPersistence)
471 method->SetWeightFileDir(fFileDir);
472 method->SetModelPersistence(fModelPersistence);
473 method->SetAnalysisType(fAnalysisType);
474 method->SetupMethod();
475 method->ParseOptions();
476 method->ProcessSetup();
477 method->SetFile(fFile.get());
478 method->SetSilentFile(IsSilentFile());
479
480 // check-for-unused-options is performed; may be overridden by derived classes
481 method->CheckSetup();
482 fIMethods.push_back(method);
483 return method;
484}
485
486//_______________________________________________________________________
487/**
488 * Allows to check if the TMVA::MethodBase was created and return the index in the vector.
489 * \param methodname name of the method.
490 * \param methodtitle method title.
491 * \param index reference to an Int_t that receives the position of the method in the vector fIMethods
492 * \return boolean true if the method was found.
493 */
495{
496 if (fIMethods.empty())
497 return kFALSE;
498 for (UInt_t i = 0; i < fIMethods.size(); i++) {
499 // MethodBase stores the method title as its method name, and the method name as its type name
500 auto methbase = dynamic_cast<MethodBase *>(fIMethods[i]);
501 if (methbase->GetMethodTypeName() == methodname && methbase->GetMethodName() == methodtitle) {
502 index = i;
503 return kTRUE;
504 }
505 }
506 return kFALSE;
507}
508
509//_______________________________________________________________________
510/**
511 * Perform test evaluation in all booked methods.
512 */
514{
515 for (auto &meth : fMethods) {
516 TestMethod(meth.GetValue<TString>("MethodName"), meth.GetValue<TString>("MethodTitle"));
517 }
518}
519
520//_______________________________________________________________________
521/**
522 * Perform the test of a specific ML method.
523 * \param methodname name of the method.
524 * \param methodtitle method title.
525 */
527{
528 auto method = GetMethod(methodname, methodtitle);
529 if (!method) {
530 Log() << kFATAL
531 << TString::Format("Trying to train method %s %s that maybe is not booked.", methodname.Data(), methodtitle.Data())
532 << Endl;
533 }
534
535 Log() << kHEADER << gTools().Color("bold") << "Test all methods" << gTools().Color("reset") << Endl;
537
538 Types::EAnalysisType analysisType = method->GetAnalysisType();
539 Log() << kHEADER << "Test method: " << method->GetMethodName() << " for Classification"
540 << " performance" << Endl << Endl;
541 method->AddOutput(Types::kTesting, analysisType);
542
543 // -----------------------------------------------------------------------
544 // First part of evaluation process
545 // --> compute efficiencies, and other separation estimators
546 // -----------------------------------------------------------------------
547
548 // although equal, we now want to separate the output for the variables
549 // and the real methods
550 Int_t isel; // will be 0 for a Method; 1 for a Variable
551 Int_t nmeth_used[2] = {0, 0}; // 0 Method; 1 Variable
552
553 std::vector<std::vector<TString>> mname(2);
554 std::vector<std::vector<Double_t>> sig(2), sep(2), roc(2);
555 std::vector<std::vector<Double_t>> eff01(2), eff10(2), eff30(2), effArea(2);
556 std::vector<std::vector<Double_t>> eff01err(2), eff10err(2), eff30err(2);
557 std::vector<std::vector<Double_t>> trainEff01(2), trainEff10(2), trainEff30(2);
558
559 method->SetFile(fFile.get());
560 method->SetSilentFile(IsSilentFile());
561
562 MethodBase *methodNoCuts = NULL;
563 if (!IsCutsMethod(method))
564 methodNoCuts = method;
565
566 Log() << kHEADER << "Evaluate classifier: " << method->GetMethodName() << Endl << Endl;
567 isel = (method->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
568
569 // perform the evaluation
570 method->TestClassification();
571
572 // evaluate the classifier
573 mname[isel].push_back(method->GetMethodName());
574 sig[isel].push_back(method->GetSignificance());
575 sep[isel].push_back(method->GetSeparation());
576 roc[isel].push_back(method->GetROCIntegral());
577
578 Double_t err;
579 eff01[isel].push_back(method->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
580 eff01err[isel].push_back(err);
581 eff10[isel].push_back(method->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
582 eff10err[isel].push_back(err);
583 eff30[isel].push_back(method->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
584 eff30err[isel].push_back(err);
585 effArea[isel].push_back(method->GetEfficiency("", Types::kTesting, err)); // computes the area (average)
586
587 trainEff01[isel].push_back(method->GetTrainingEfficiency("Efficiency:0.01")); // the first pass takes longer
588 trainEff10[isel].push_back(method->GetTrainingEfficiency("Efficiency:0.10"));
589 trainEff30[isel].push_back(method->GetTrainingEfficiency("Efficiency:0.30"));
590
591 nmeth_used[isel]++;
592
593 if (!IsSilentFile()) {
594 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
595 method->WriteEvaluationHistosToFile(Types::kTesting);
596 method->WriteEvaluationHistosToFile(Types::kTraining);
597 }
598
599 // now sort the methods by efficiency area, in descending order (effArea is the ranked vector)
600 for (Int_t k = 0; k < 2; k++) {
601 std::vector<std::vector<Double_t>> vtemp;
602 vtemp.push_back(effArea[k]); // this is the vector that is ranked
603 vtemp.push_back(eff10[k]);
604 vtemp.push_back(eff01[k]);
605 vtemp.push_back(eff30[k]);
606 vtemp.push_back(eff10err[k]);
607 vtemp.push_back(eff01err[k]);
608 vtemp.push_back(eff30err[k]);
609 vtemp.push_back(trainEff10[k]);
610 vtemp.push_back(trainEff01[k]);
611 vtemp.push_back(trainEff30[k]);
612 vtemp.push_back(sig[k]);
613 vtemp.push_back(sep[k]);
614 vtemp.push_back(roc[k]);
615 std::vector<TString> vtemps = mname[k];
616 gTools().UsefulSortDescending(vtemp, &vtemps);
617 effArea[k] = vtemp[0];
618 eff10[k] = vtemp[1];
619 eff01[k] = vtemp[2];
620 eff30[k] = vtemp[3];
621 eff10err[k] = vtemp[4];
622 eff01err[k] = vtemp[5];
623 eff30err[k] = vtemp[6];
624 trainEff10[k] = vtemp[7];
625 trainEff01[k] = vtemp[8];
626 trainEff30[k] = vtemp[9];
627 sig[k] = vtemp[10];
628 sep[k] = vtemp[11];
629 roc[k] = vtemp[12];
630 mname[k] = vtemps;
631 }
632
633 // -----------------------------------------------------------------------
634 // Second part of evaluation process
635 // --> compute correlations among MVAs
636 // --> compute correlations between input variables and MVA (determines importance)
637 // --> count overlaps
638 // -----------------------------------------------------------------------
639 if (fCorrelations) {
640 const Int_t nmeth = methodNoCuts == NULL ? 0 : 1;
641 const Int_t nvar = method->fDataSetInfo.GetNVariables();
642 if (nmeth > 0) {
643
644 // needed for correlations
645 Double_t *dvec = new Double_t[nmeth + nvar];
646 std::vector<Double_t> rvec;
647
648 // for correlations
649 TPrincipal *tpSig = new TPrincipal(nmeth + nvar, "");
650 TPrincipal *tpBkg = new TPrincipal(nmeth + nvar, "");
651
652 // set required tree branch references
653 std::vector<TString> *theVars = new std::vector<TString>;
654 std::vector<ResultsClassification *> mvaRes;
655 theVars->push_back(methodNoCuts->GetTestvarName());
656 rvec.push_back(methodNoCuts->GetSignalReferenceCut());
657 theVars->back().ReplaceAll("MVA_", "");
658 mvaRes.push_back(dynamic_cast<ResultsClassification *>(
659 methodNoCuts->Data()->GetResults(methodNoCuts->GetMethodName(), Types::kTesting, Types::kMaxAnalysisType)));
660
661 // for overlap study
662 TMatrixD *overlapS = new TMatrixD(nmeth, nmeth);
663 TMatrixD *overlapB = new TMatrixD(nmeth, nmeth);
664 (*overlapS) *= 0; // init...
665 (*overlapB) *= 0; // init...
666
667 // loop over test tree
668 DataSet *defDs = method->fDataSetInfo.GetDataSet();
670 for (Int_t ievt = 0; ievt < defDs->GetNEvents(); ievt++) {
671 const Event *ev = defDs->GetEvent(ievt);
672
673 // for correlations
674 TMatrixD *theMat = 0;
675 for (Int_t im = 0; im < nmeth; im++) {
676 // check for NaN value
677 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
678 if (TMath::IsNaN(retval)) {
679 Log() << kWARNING << "Found NaN return value in event: " << ievt << " for method \""
680 << methodNoCuts->GetName() << "\"" << Endl;
681 dvec[im] = 0;
682 } else
683 dvec[im] = retval;
684 }
685 for (Int_t iv = 0; iv < nvar; iv++)
686 dvec[iv + nmeth] = (Double_t)ev->GetValue(iv);
687 if (method->fDataSetInfo.IsSignal(ev)) {
688 tpSig->AddRow(dvec);
689 theMat = overlapS;
690 } else {
691 tpBkg->AddRow(dvec);
692 theMat = overlapB;
693 }
694
695 // count overlaps
696 for (Int_t im = 0; im < nmeth; im++) {
697 for (Int_t jm = im; jm < nmeth; jm++) {
698 if ((dvec[im] - rvec[im]) * (dvec[jm] - rvec[jm]) > 0) {
699 (*theMat)(im, jm)++;
700 if (im != jm)
701 (*theMat)(jm, im)++;
702 }
703 }
704 }
705 }
706
707 // renormalise overlap matrix
708 (*overlapS) *= (1.0 / defDs->GetNEvtSigTest()); // init...
709 (*overlapB) *= (1.0 / defDs->GetNEvtBkgdTest()); // init...
710
711 tpSig->MakePrincipals();
712 tpBkg->MakePrincipals();
713
714 const TMatrixD *covMatS = tpSig->GetCovarianceMatrix();
715 const TMatrixD *covMatB = tpBkg->GetCovarianceMatrix();
716
717 const TMatrixD *corrMatS = gTools().GetCorrelationMatrix(covMatS);
718 const TMatrixD *corrMatB = gTools().GetCorrelationMatrix(covMatB);
719
720 // print correlation matrices
721 if (corrMatS != 0 && corrMatB != 0) {
722
723 // extract MVA matrix
724 TMatrixD mvaMatS(nmeth, nmeth);
725 TMatrixD mvaMatB(nmeth, nmeth);
726 for (Int_t im = 0; im < nmeth; im++) {
727 for (Int_t jm = 0; jm < nmeth; jm++) {
728 mvaMatS(im, jm) = (*corrMatS)(im, jm);
729 mvaMatB(im, jm) = (*corrMatB)(im, jm);
730 }
731 }
732
733 // extract variables - to MVA matrix
734 std::vector<TString> theInputVars;
735 TMatrixD varmvaMatS(nvar, nmeth);
736 TMatrixD varmvaMatB(nvar, nmeth);
737 for (Int_t iv = 0; iv < nvar; iv++) {
738 theInputVars.push_back(method->fDataSetInfo.GetVariableInfo(iv).GetLabel());
739 for (Int_t jm = 0; jm < nmeth; jm++) {
740 varmvaMatS(iv, jm) = (*corrMatS)(nmeth + iv, jm);
741 varmvaMatB(iv, jm) = (*corrMatB)(nmeth + iv, jm);
742 }
743 }
744
745 if (nmeth > 1) {
746 Log() << kINFO << Endl;
747 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
748 << "Inter-MVA correlation matrix (signal):" << Endl;
749 gTools().FormattedOutput(mvaMatS, *theVars, Log());
750 Log() << kINFO << Endl;
751
752 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
753 << "Inter-MVA correlation matrix (background):" << Endl;
754 gTools().FormattedOutput(mvaMatB, *theVars, Log());
755 Log() << kINFO << Endl;
756 }
757
758 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
759 << "Correlations between input variables and MVA response (signal):" << Endl;
760 gTools().FormattedOutput(varmvaMatS, theInputVars, *theVars, Log());
761 Log() << kINFO << Endl;
762
763 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
764 << "Correlations between input variables and MVA response (background):" << Endl;
765 gTools().FormattedOutput(varmvaMatB, theInputVars, *theVars, Log());
766 Log() << kINFO << Endl;
767 } else
768 Log() << kWARNING << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
769 << "<TestAllMethods> cannot compute correlation matrices" << Endl;
770
771 // print overlap matrices
772 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
773 << "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
774 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
775 << "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
776 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
777 << "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
778 gTools().FormattedOutput(rvec, *theVars, "Method", "Cut value", Log());
779 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
780 << "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
781
782 // give notice that cut method has been excluded from this test
783 if (nmeth != 1)
784 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
785 << "Note: no correlations and overlap with cut method are provided at present" << Endl;
786
787 if (nmeth > 1) {
788 Log() << kINFO << Endl;
789 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
790 << "Inter-MVA overlap matrix (signal):" << Endl;
791 gTools().FormattedOutput(*overlapS, *theVars, Log());
792 Log() << kINFO << Endl;
793
794 Log() << kINFO << TString::Format("Dataset[%s] : ", method->fDataSetInfo.GetName())
795 << "Inter-MVA overlap matrix (background):" << Endl;
796 gTools().FormattedOutput(*overlapB, *theVars, Log());
797 }
798
799 // cleanup
800 delete tpSig;
801 delete tpBkg;
802 delete corrMatS;
803 delete corrMatB;
804 delete theVars;
805 delete overlapS;
806 delete overlapB;
807 delete[] dvec;
808 }
809 }
810
811 // -----------------------------------------------------------------------
812 // Third part of evaluation process
813 // --> output
814 // -----------------------------------------------------------------------
815 // putting results in the classification result object
816 auto &fResult = GetResults(methodname, methodtitle);
817
818 // Binary classification
819 if (fROC) {
820 Log().EnableOutput();
822 Log() << Endl;
823 TString hLine = "------------------------------------------------------------------------------------------"
824 "-------------------------";
825 Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
826 Log() << kINFO << hLine << Endl;
827 Log() << kINFO << "DataSet MVA " << Endl;
828 Log() << kINFO << "Name: Method: ROC-integ" << Endl;
829
830 Log() << kDEBUG << hLine << Endl;
831 for (Int_t k = 0; k < 2; k++) {
832 if (k == 1 && nmeth_used[k] > 0) {
833 Log() << kINFO << hLine << Endl;
834 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
835 }
836 for (Int_t i = 0; i < nmeth_used[k]; i++) {
837 TString datasetName = fDataLoader->GetName();
838 TString methodName = mname[k][i];
839
840 if (k == 1) {
841 methodName.ReplaceAll("Variable_", "");
842 }
843
844 TMVA::DataSet *dataset = method->Data();
845 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
846 std::vector<Bool_t> *mvaResType = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
847
848 Double_t rocIntegral = 0.0;
849 if (mvaResType->size() != 0) {
850 rocIntegral = GetROCIntegral(methodname, methodtitle);
851 }
852
853 if (sep[k][i] < 0 || sig[k][i] < 0) {
854 // cannot compute separation/significance -> no MVA (usually for Cuts)
855 fResult.fROCIntegral = effArea[k][i];
856 Log() << kINFO
857 << TString::Format("%-13s %-15s: %#1.3f", fDataLoader->GetName(), methodName.Data(), fResult.fROCIntegral)
858 << Endl;
859 } else {
860 fResult.fROCIntegral = rocIntegral;
861 Log() << kINFO << TString::Format("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
862 << Endl;
863 }
864 }
865 }
866 Log() << kINFO << hLine << Endl;
867 Log() << kINFO << Endl;
868 Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
869 Log() << kINFO << hLine << Endl;
870 Log() << kINFO
871 << "DataSet MVA Signal efficiency: from test sample (from training sample) "
872 << Endl;
873 Log() << kINFO << "Name: Method: @B=0.01 @B=0.10 @B=0.30 "
874 << Endl;
875 Log() << kINFO << hLine << Endl;
876 for (Int_t k = 0; k < 2; k++) {
877 if (k == 1 && nmeth_used[k] > 0) {
878 Log() << kINFO << hLine << Endl;
879 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
880 }
881 for (Int_t i = 0; i < nmeth_used[k]; i++) {
882 if (k == 1)
883 mname[k][i].ReplaceAll("Variable_", "");
884
885 Log() << kINFO << TString::Format("%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
886 method->fDataSetInfo.GetName(), mname[k][i].Data(), eff01[k][i],
887 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
888 << Endl;
889 }
890 }
891 Log() << kINFO << hLine << Endl;
892 Log() << kINFO << Endl;
893
894 if (gTools().CheckForSilentOption(GetOptions()))
895 Log().InhibitOutput();
896 } else if (IsCutsMethod(method)) { // end fROC
897 for (Int_t k = 0; k < 2; k++) {
898 for (Int_t i = 0; i < nmeth_used[k]; i++) {
899
900 if (sep[k][i] < 0 || sig[k][i] < 0) {
901 // cannot compute separation/significance -> no MVA (usually for Cuts)
902 fResult.fROCIntegral = effArea[k][i];
903 }
904 }
905 }
906 }
907
908 TMVA::DataSet *dataset = method->Data();
910
911 if (IsCutsMethod(method)) {
912 fResult.fIsCuts = kTRUE;
913 } else {
914 auto rocCurveTest = GetROC(methodname, methodtitle, 0, Types::kTesting);
915 fResult.fMvaTest[0] = rocCurveTest->GetMvas();
916 fResult.fROCIntegral = GetROCIntegral(methodname, methodtitle);
917 }
918 TString className = method->DataInfo().GetClassInfo(0)->GetName();
919 fResult.fClassNames.push_back(className);
920
921 if (!IsSilentFile()) {
922 // write test/training trees
923 RootBaseDir()->cd(method->fDataSetInfo.GetName());
924 method->fDataSetInfo.GetDataSet()->GetTree(Types::kTesting)->Write("", TObject::kOverwrite);
925 method->fDataSetInfo.GetDataSet()->GetTree(Types::kTraining)->Write("", TObject::kOverwrite);
926 }
927}
928
929//_______________________________________________________________________
930/**
931 * Tests a specific ML method selected by its TMVA::Types::EMVA enum value.
 * The enum value is translated to a method name with
 * Types::Instance().GetMethodName() and the call is forwarded to the
 * TestMethod overload that takes a method name and a method title.
932 * \param method TMVA::Types::EMVA type.
933 * \param methodtitle method title.
934 */
936{
937 TestMethod(Types::Instance().GetMethodName(method), methodtitle);
938}
939
940//_______________________________________________________________________
941/**
942 * Return the vector of TMVA::Experimental::ClassificationResult objects.
943 * \return vector of results.
944 */
945std::vector<TMVA::Experimental::ClassificationResult> &TMVA::Experimental::Classification::GetResults()
946{
947 if (fResults.size() == 0)
948 Log() << kFATAL << "No Classification results available" << Endl;
949 return fResults;
950}
951
952//_______________________________________________________________________
953/**
954 * Checks whether the given ML method is a Cuts method.
 * The method's GetMethodType() is compared against Types::kCuts.
 * \param method method to inspect; it is dereferenced without a null check.
955 * \return kTRUE if the method is a Cuts method, kFALSE otherwise.
956 */
958{
959 return method->GetMethodType() == Types::kCuts ? kTRUE : kFALSE;
960}
961
962//_______________________________________________________________________
963/**
964 * Gets the result entry for a specific ML method.
 * fResults is searched for an entry whose IsMethod(methodname, methodtitle)
 * is true; if one is found it is returned. Otherwise a result is filled with
 * the method name, the method title and the data loader name, appended to
 * fResults, and the newly appended element (fResults.back()) is returned.
965 * \param methodname name of the method.
966 * \param methodtitle method title.
967 * \return TMVA::Experimental::ClassificationResult object for the method.
968 */
971{
 // Reuse the stored entry when this method already has one.
972 for (auto &result : fResults) {
973 if (result.IsMethod(methodname, methodtitle))
974 return result;
975 }
 // No entry matched: fill in a new result and store it.
 // NOTE(review): `result` is used below but its declaration is not visible in this listing - confirm.
977 result.fMethod["MethodName"] = methodname;
978 result.fMethod["MethodTitle"] = methodtitle;
979 result.fDataLoaderName = fDataLoader->GetName();
980 fResults.push_back(result);
981 return fResults.back();
982}
983
984//_______________________________________________________________________
985/**
986 * Method to get TMVA::ROCCurve Object.
987 * \param method TMVA::MethodBase object
988 * \param iClass category, default 0 then signal
989 * \param type train/test tree, default test.
990 * \return TMVA::ROCCurve object.
991 */
994{
995 TMVA::DataSet *dataset = method->Data();
996 dataset->SetCurrentType(type);
997 TMVA::Results *results = dataset->GetResults(method->GetName(), type, this->fAnalysisType);
998
999 UInt_t nClasses = method->DataInfo().GetNClasses();
1000 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
1001 Log() << kERROR << TString::Format("Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
1002 iClass, nClasses)
1003 << Endl;
1004 return nullptr;
1005 }
1006
1007 TMVA::ROCCurve *rocCurve = nullptr;
1008 if (this->fAnalysisType == Types::kClassification) {
1009
1010 std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
1011 std::vector<Bool_t> *mvaResTypes = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
1012 std::vector<Float_t> mvaResWeights;
1013
1014 auto eventCollection = dataset->GetEventCollection(type);
1015 mvaResWeights.reserve(eventCollection.size());
1016 for (auto ev : eventCollection) {
1017 mvaResWeights.push_back(ev->GetWeight());
1018 }
1019
1020 rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);
1021
1022 } else if (this->fAnalysisType == Types::kMulticlass) {
1023 std::vector<Float_t> mvaRes;
1024 std::vector<Bool_t> mvaResTypes;
1025 std::vector<Float_t> mvaResWeights;
1026
1027 std::vector<std::vector<Float_t>> *rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();
1028
1029 // Vector transpose due to values being stored as
1030 // [ [0, 1, 2], [0, 1, 2], ... ]
1031 // in ResultsMulticlass::GetValueVector.
1032 mvaRes.reserve(rawMvaRes->size());
1033 for (auto item : *rawMvaRes) {
1034 mvaRes.push_back(item[iClass]);
1035 }
1036
1037 auto eventCollection = dataset->GetEventCollection(type);
1038 mvaResTypes.reserve(eventCollection.size());
1039 mvaResWeights.reserve(eventCollection.size());
1040 for (auto ev : eventCollection) {
1041 mvaResTypes.push_back(ev->GetClass() == iClass);
1042 mvaResWeights.push_back(ev->GetWeight());
1043 }
1044
1045 rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
1046 }
1047
1048 return rocCurve;
1049}
1050
1051//_______________________________________________________________________
1052/**
1053 * Method to get TMVA::ROCCurve Object.
 * The method object is looked up with GetMethod(methodname, methodtitle) and
 * the call is forwarded to the GetROC overload that takes a method object.
 * The value returned by GetMethod is passed on without a null check here.
1054 * \param methodname ml method name.
1055 * \param methodtitle ml method title.
1056 * \param iClass category, default 0 then signal
1057 * \param type train/test tree, default test.
1058 * \return TMVA::ROCCurve object.
1059 */
1062{
1063 return GetROC(GetMethod(methodname, methodtitle), iClass, type);
1064}
1065
1066//_______________________________________________________________________
1067/**
1068 * Method to get ROC-Integral value from mvas.
 * A ROC curve is requested from GetROC(methodname, methodtitle, iClass), its
 * integral is computed with ROCCurve::GetROCIntegral and the curve is deleted
 * before returning.
1069 * \param methodname ml method name.
1070 * \param methodtitle ml method title.
1071 * \param iClass category, default 0 then signal
1072 * \return Double_t with the ROC-Integral value; 0 is returned after the kFATAL
 *         message when no ROC curve could be created.
1073 */
1075{
1076 TMVA::ROCCurve *rocCurve = GetROC(methodname, methodtitle, iClass);
 // GetROC hands back a null pointer when it cannot build the curve.
1077 if (!rocCurve) {
1078 Log() << kFATAL
1079 << TString::Format("ROCCurve object was not created in MethodName = %s MethodTitle = %s not found with Dataset = %s ",
1080 methodname.Data(), methodtitle.Data(), fDataLoader->GetName())
1081 << Endl;
1082 return 0;
1083 }
1084
 // NOTE(review): `npoints` is defined on a line that is not visible in this listing - confirm its value.
1086 Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
 // The curve was obtained as a raw pointer from GetROC and is released here.
1087 delete rocCurve;
1088
1089 return rocIntegral;
1090}
1091
1092//_______________________________________________________________________
1094{
1095 TFile *savdir = file;
1096 TDirectory *adir = savdir;
1097 adir->cd();
1098 // loop on all entries of this directory
1099 TKey *key;
1100 TIter nextkey(src->GetListOfKeys());
1101 while ((key = (TKey *)nextkey())) {
1102 const Char_t *classname = key->GetClassName();
1103 TClass *cl = gROOT->GetClass(classname);
1104 if (!cl)
1105 continue;
1106 if (cl->InheritsFrom(TDirectory::Class())) {
1107 src->cd(key->GetName());
1108 TDirectory *subdir = file;
1109 adir->cd();
1110 CopyFrom(subdir, file);
1111 adir->cd();
1112 } else if (cl->InheritsFrom(TTree::Class())) {
1113 TTree *T = (TTree *)src->Get(key->GetName());
1114 adir->cd();
1115 TTree *newT = T->CloneTree(-1, "fast");
1116 newT->Write();
1117 } else {
1118 src->cd();
1119 TObject *obj = key->ReadObj();
1120 adir->cd();
1121 obj->Write();
1122 delete obj;
1123 }
1124 }
1125 adir->SaveSelf(kTRUE);
1126 savdir->cd();
1127}
1128
1129//_______________________________________________________________________
1131{
 // Merges the per-method ROOT files (".<dataloader><methodname><methodtitle>.root")
 // into fFile: the content of each method directory is copied with CopyFrom, the
 // train/test trees of the first method are copied and one branch per further
 // method is appended to them; finally the per-method files are unlinked.
 // NOTE(review): none of the pointers returned by mkdir()/Get() below is checked
 // before use, and TrainTree, TestTree and ifile are still nullptr when fMethods
 // is empty yet are dereferenced after the loop - confirm callers guarantee both.
1132
1133 auto dsdir = fFile->mkdir(fDataLoader->GetName()); // dataset dir
1134 TTree *TrainTree = nullptr;
1135 TTree *TestTree = nullptr;
1136 TFile *ifile = nullptr;
1137 TFile *ofile = nullptr;
1138 for (UInt_t i = 0; i < fMethods.size(); i++) {
1139 auto methodname = fMethods[i].GetValue<TString>("MethodName");
1140 auto methodtitle = fMethods[i].GetValue<TString>("MethodTitle");
1141 auto fname = TString::Format(".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
1142 TDirectoryFile *ds = nullptr;
 // The file of the first method stays referenced by ifile until after the loop;
 // the files of all other methods go through ofile and are closed at the end of
 // their iteration.
 // NOTE(review): the TFile objects are created with new and only Close() is
 // called in this function; no delete is visible here - confirm ownership.
1143 if (i == 0) {
1144 ifile = new TFile(fname.Data());
1145 ds = (TDirectoryFile *)ifile->Get(fDataLoader->GetName());
1146 } else {
1147 ofile = new TFile(fname.Data());
1148 ds = (TDirectoryFile *)ofile->Get(fDataLoader->GetName());
1149 }
1150 auto tmptrain = (TTree *)ds->Get("TrainTree");
1151 auto tmptest = (TTree *)ds->Get("TestTree");
1152 fFile->cd();
1153 fFile->cd(fDataLoader->GetName());
1154
 // Create Method_<title>/<title> below the dataset directory of fFile and fetch
 // the directory with the same path from the per-method file.
1155 auto methdirname = TString::Format("Method_%s", methodtitle.Data());
1156 auto methdir = dsdir->mkdir(methdirname.Data(), methdirname.Data());
1157 auto methdirbase = methdir->mkdir(methodtitle.Data(), methodtitle.Data());
1158 auto mfdir = (TDirectoryFile *)ds->Get(methdirname.Data());
1159 auto mfdirbase = (TDirectoryFile *)mfdir->Get(methodtitle.Data());
1160
 // NOTE(review): methdirbase comes from mkdir() and is cast to TFile* only to
 // match the CopyFrom parameter type - presumably CopyFrom uses it purely as a
 // directory; confirm.
1161 CopyFrom(mfdirbase, (TFile *)methdirbase);
1162 dsdir->cd();
1163 if (i == 0) {
 // First method: its trees become the merged trees.
1164 TrainTree = tmptrain->CopyTree("");
1165 TestTree = tmptest->CopyTree("");
1166 } else {
 // Other methods: add a branch named after the method title to the merged
 // trees and fill it entry by entry from the per-method tree. `mva` is the
 // shared buffer: GetEntry() on the per-method tree writes it, Fill() on the
 // new branch reads it.
1167 Float_t mva = 0;
1168 auto trainbranch = TrainTree->Branch(methodtitle.Data(), &mva);
1169 tmptrain->SetBranchAddress(methodtitle.Data(), &mva);
1170 auto entries = tmptrain->GetEntries();
 // NOTE(review): the counter is UInt_t while `entries` has the type returned
 // by GetEntries() (presumably Long64_t) - confirm entry counts stay in range.
1171 for (UInt_t ev = 0; ev < entries; ev++) {
1172 tmptrain->GetEntry(ev);
1173 trainbranch->Fill();
1174 }
1175 auto testbranch = TestTree->Branch(methodtitle.Data(), &mva);
1176 tmptest->SetBranchAddress(methodtitle.Data(), &mva);
1177 entries = tmptest->GetEntries();
1178 for (UInt_t ev = 0; ev < entries; ev++) {
1179 tmptest->GetEntry(ev);
1180 testbranch->Fill();
1181 }
1182 ofile->Close();
1183 }
1184 }
 // Write the merged trees; the file of the first method is closed only now,
 // after the trees copied from it have been written.
1185 TrainTree->Write();
1186 TestTree->Write();
1187 ifile->Close();
1188 // cleaning: unlink the per-method files that were merged above
1189 for (UInt_t i = 0; i < fMethods.size(); i++) {
1190 auto methodname = fMethods[i].GetValue<TString>("MethodName");
1191 auto methodtitle = fMethods[i].GetValue<TString>("MethodTitle");
1192 auto fname = TString::Format(".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
1193 gSystem->Unlink(fname.Data());
1194 }
1195}
#define MinNoTrainingEvents
#define f(i)
Definition RSha256.hxx:104
char Char_t
Definition RtypesCore.h:37
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
double Double_t
Definition RtypesCore.h:59
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t src
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
TMatrixT< Double_t > TMatrixD
Definition TMatrixDfwd.h:23
#define gROOT
Definition TROOT.h:407
R__EXTERN TSystem * gSystem
Definition TSystem.h:560
A pseudo container class which is a generator of indices.
Definition TSeq.hxx:67
Long64_t GetEntries() const
Definition TBranch.h:251
TClass instances represent classes, structs and namespaces in the ROOT type system.
Definition TClass.h:81
Bool_t InheritsFrom(const char *cl) const override
Return kTRUE if this class inherits from a class with name "classname".
Definition TClass.cxx:4874
static TClass * GetClass(const char *name, Bool_t load=kTRUE, Bool_t silent=kFALSE)
Static method returning pointer to TClass of the specified class name.
Definition TClass.cxx:2968
A ROOT file is structured in Directories (like a file system).
Bool_t cd() override
Change current directory to "this" directory.
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
Describe directory structure in memory.
Definition TDirectory.h:45
static TClass * Class()
virtual Bool_t cd()
Change current directory to "this" directory.
virtual void SaveSelf(Bool_t=kFALSE)
Definition TDirectory.h:255
A ROOT file is composed of a header, followed by consecutive data records (TKey instances) with a wel...
Definition TFile.h:53
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:936
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
void SetName(const char *name="") override
Set graph name.
Definition TGraph.cxx:2354
TAxis * GetXaxis() const
Get x axis of the graph.
Definition TGraph.cxx:1540
TAxis * GetYaxis() const
Get y axis of the graph.
Definition TGraph.cxx:1549
void SetTitle(const char *title="") override
Change (i.e.
Definition TGraph.cxx:2370
Book space in a file, create I/O buffers, to fill them, (un)compress them.
Definition TKey.h:28
virtual const char * GetClassName() const
Definition TKey.h:75
virtual TObject * ReadObj()
To read a TObject* from the file.
Definition TKey.cxx:758
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition Config.h:124
void SetDrawProgressBar(Bool_t d)
Definition Config.h:69
void SetUseColor(Bool_t uc)
Definition Config.h:60
class TMVA::Config::VariablePlotting fVariablePlotting
void SetSilent(Bool_t s)
Definition Config.h:63
IONames & GetIONames()
Definition Config.h:98
void SetConfigDescription(const char *d)
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void SetConfigName(const char *n)
virtual void ParseOptions()
options parser
void CheckForUnusedOptions() const
checks for unused options in option string
UInt_t GetNClasses() const
Class that contains all the data information.
Definition DataSet.h:58
Long64_t GetNEvtSigTest()
return number of signal test events in dataset
Definition DataSet.cxx:427
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
Results * GetResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
Definition DataSet.cxx:265
void SetCurrentType(Types::ETreeType type) const
Definition DataSet.h:89
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:216
Long64_t GetNEvtBkgdTest()
return number of background test events in dataset
Definition DataSet.cxx:435
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Definition Envelope.h:44
Bool_t fModelPersistence
! flag to save the trained model
Definition Envelope.h:49
std::shared_ptr< DataLoader > fDataLoader
! data
Definition Envelope.h:47
virtual void ParseOptions()
Method to parse the internal option string.
Definition Envelope.cxx:182
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition Event.cxx:236
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition Event.cxx:399
Double_t GetROCIntegral(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get ROC-Integral value from mvas.
TGraph * GetROCGraph(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TGraph object with the ROC curve.
void Show()
Method to print the results in stdout.
Bool_t IsMethod(TString methodname, TString methodtitle)
Method to check if method was booked.
std::map< UInt_t, std::vector< std::tuple< Float_t, Float_t, Bool_t > > > fMvaTest
Mvas for two-class and multiclass classification.
ROCCurve * GetROC(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TMVA::ROCCurve Object.
Bool_t fIsCuts
if it is a method cuts need special output
ClassificationResult & operator=(const ClassificationResult &r)
std::map< UInt_t, std::vector< std::tuple< Float_t, Float_t, Bool_t > > > fMvaTrain
Mvas for two-class classification.
Classification(DataLoader *loader, TFile *file, TString options)
Constructor to create a two-class classifier.
Double_t GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass=0)
Method to get ROC-Integral value from mvas.
virtual void Test()
Perform test evaluation in all booked methods.
TString GetMethodOptions(TString methodname, TString methodtitle)
return the options for the booked method.
MethodBase * GetMethod(TString methodname, TString methodtitle)
Return a TMVA::MethodBase object.
virtual void TrainMethod(TString methodname, TString methodtitle)
Trains a specific ML method.
Bool_t HasMethodObject(TString methodname, TString methodtitle, Int_t &index)
Allows to check if the TMVA::MethodBase was created and return the index in the vector.
std::vector< ClassificationResult > & GetResults()
Return the vector of TMVA::Experimental::ClassificationResult objects.
virtual void Train()
Method to train all booked ml methods.
virtual void Evaluate()
Method to perform Train/Test over all ml method booked.
Types::EAnalysisType fAnalysisType
!
TMVA::ROCCurve * GetROC(TMVA::MethodBase *method, UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TMVA::ROCCurve Object.
Bool_t IsCutsMethod(TMVA::MethodBase *method)
Allows to check if the ml method is a Cuts method.
void CopyFrom(TDirectory *src, TFile *file)
virtual void TestMethod(TString methodname, TString methodtitle)
Tests a specific ML method.
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
Virtual base Class for all MVA method.
Definition MethodBase.h:111
void SetSilentFile(Bool_t status)
Definition MethodBase.h:378
void SetWeightFileDir(TString fileDir)
set directory of weight file
TString GetMethodTypeName() const
Definition MethodBase.h:332
const char * GetName() const
Definition MethodBase.h:334
const TString & GetTestvarName() const
Definition MethodBase.h:335
void SetupMethod()
setup of methods
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition MethodBase.h:436
const TString & GetMethodName() const
Definition MethodBase.h:331
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
Types::EMVA GetMethodType() const
Definition MethodBase.h:333
void SetFile(TFile *file)
Definition MethodBase.h:375
DataSet * Data() const
Definition MethodBase.h:409
void SetModelPersistence(Bool_t status)
Definition MethodBase.h:382
Double_t GetSignalReferenceCut() const
Definition MethodBase.h:360
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Class for boosting a TMVA method.
Definition MethodBoost.h:58
void SetBoostedMethodName(TString methodName)
Definition MethodBoost.h:86
DataSetManager * fDataSetManager
DSMTEST.
Class for categorizing the phase space.
DataSetManager * fDataSetManager
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
static void InhibitOutput()
Definition MsgLogger.cxx:67
static void EnableOutput()
Definition MsgLogger.cxx:68
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition ROCCurve.cxx:250
Class that is the base-class for a vector of result.
Class which takes the results of a multiclass classification.
Class that is the base-class for a vector of result.
Definition Results.h:57
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition Tools.cxx:887
void UsefulSortDescending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=nullptr)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition Tools.cxx:564
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
const TMatrixD * GetCorrelationMatrix(const TMatrixD *covMat)
turns covariance into correlation matrix
Definition Tools.cxx:324
Singleton class for Global types used by TMVA.
Definition Types.h:71
static Types & Instance()
The single instance of "Types" if existing already, or create it (Singleton)
Definition Types.cxx:70
@ kCategory
Definition Types.h:97
@ kMulticlass
Definition Types.h:129
@ kClassification
Definition Types.h:127
@ kMaxAnalysisType
Definition Types.h:131
@ kTraining
Definition Types.h:143
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition TNamed.cxx:164
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
Mother of all ROOT objects.
Definition TObject.h:41
@ kOverwrite
overwrite existing object with same name
Definition TObject.h:92
virtual Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition TObject.cxx:880
Principal Components Analysis (PCA)
Definition TPrincipal.h:21
virtual void AddRow(const Double_t *x)
Add a data point and update the covariance matrix.
const TMatrixD * GetCovarianceMatrix() const
Definition TPrincipal.h:59
virtual void MakePrincipals()
Perform the principal components analysis.
Basic string class.
Definition TString.h:139
const char * Data() const
Definition TString.h:380
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:704
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2356
virtual int MakeDirectory(const char *name)
Make a directory.
Definition TSystem.cxx:814
virtual int Unlink(const char *name)
Unlink, i.e.
Definition TSystem.cxx:1368
A TTree represents a columnar dataset.
Definition TTree.h:79
virtual TTree * CopyTree(const char *selection, Option_t *option="", Long64_t nentries=kMaxEntries, Long64_t firstentry=0)
Copy a tree with selection.
Definition TTree.cxx:3716
TBranch * Branch(const char *name, T *obj, Int_t bufsize=32000, Int_t splitlevel=99)
Add a new branch, and infer the data type from the type of obj being passed.
Definition TTree.h:353
Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0) override
Write this object to the current directory.
Definition TTree.cxx:9740
static TClass * Class()
create variable transformations
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Bool_t IsNaN(Double_t x)
Definition TMath.h:892
Definition file.py:1
TMarker m
Definition textangle.C:8