Classification.cxx
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan
3// Therhaag
4
5#include <TMVA/Classification.h>
6
7#include <TMVA/ClassifierFactory.h>
8#include <TMVA/Config.h>
9#include <TMVA/Configurable.h>
10#include <TMVA/Tools.h>
11#include <TMVA/Ranking.h>
12#include <TMVA/DataSet.h>
13#include <TMVA/IMethod.h>
14#include <TMVA/MethodBase.h>
16#include <TMVA/DataSetManager.h>
17#include <TMVA/DataSetInfo.h>
18#include <TMVA/DataLoader.h>
19#include <TMVA/MethodBoost.h>
20#include <TMVA/MethodCategory.h>
21#include <TMVA/ROCCalc.h>
22#include <TMVA/ROCCurve.h>
23#include <TMVA/MsgLogger.h>
24
25#include <TMVA/VariableInfo.h>
27
28#include <TMVA/Types.h>
29
30#include <TROOT.h>
31#include <TFile.h>
32#include <TKey.h>
33#include <TTree.h>
34#include <TKey.h>
35#include <TLeaf.h>
36#include <TBranch.h>
37#include <TEventList.h>
38#include <TGraph.h>
39#include <TMatrixF.h>
40#include <TMatrixDSym.h>
41#include <TMultiGraph.h>
42#include <TPrincipal.h>
43#include <TMath.h>
44#include <TSystem.h>
45
46#include <iostream>
47#include <memory>
48
49#define MinNoTrainingEvents 10
50
51//_______________________________________________________________________
52TMVA::Experimental::ClassificationResult::ClassificationResult() : fROCIntegral(0)
53{
54}
55
56//_______________________________________________________________________
57TMVA::Experimental::ClassificationResult::ClassificationResult(const ClassificationResult &cr)
58{
59 fMethod = cr.fMethod;
60 fDataLoaderName = cr.fDataLoaderName;
61 fMvaTrain = cr.fMvaTrain;
62 fMvaTest = cr.fMvaTest;
63 fIsCuts = cr.fIsCuts;
64 fROCIntegral = cr.fROCIntegral;
65}
66
67//_______________________________________________________________________
68/**
69 * Method to get the ROC-Integral value from the MVA outputs.
70 * \param iClass class index; the default 0 corresponds to signal.
71 * \param type train/test tree; the default is the test tree.
72 * \return Double_t with the ROC-Integral value.
73 */
74Double_t TMVA::Experimental::ClassificationResult::GetROCIntegral(UInt_t iClass, TMVA::Types::ETreeType type)
75{
76 if (fIsCuts) {
77 return fROCIntegral;
78 } else {
79 auto roc = GetROC(iClass, type);
80 auto inte = roc->GetROCIntegral();
81 delete roc;
82 return inte;
83 }
84}
85
86//_______________________________________________________________________
87/**
88 * Method to get TMVA::ROCCurve Object.
89 * \param iClass class index; the default 0 corresponds to signal.
90 * \param type train/test tree; the default is the test tree.
91 * \return TMVA::ROCCurve object.
92 */
93TMVA::ROCCurve *TMVA::Experimental::ClassificationResult::GetROC(UInt_t iClass, TMVA::Types::ETreeType type)
94{
95 ROCCurve *fROCCurve = nullptr;
96 if (type == TMVA::Types::kTesting)
97 fROCCurve = new ROCCurve(fMvaTest[iClass]);
98 else
99 fROCCurve = new ROCCurve(fMvaTrain[iClass]);
100 return fROCCurve;
101}
102
103//_______________________________________________________________________
104TMVA::Experimental::ClassificationResult &
105TMVA::Experimental::ClassificationResult::operator=(const TMVA::Experimental::ClassificationResult &cr)
106{
107 fMethod = cr.fMethod;
108 fDataLoaderName = cr.fDataLoaderName;
109 fMvaTrain = cr.fMvaTrain;
110 fMvaTest = cr.fMvaTest;
111 fIsCuts = cr.fIsCuts;
112 fROCIntegral = cr.fROCIntegral;
113 return *this;
114}
115
116//_______________________________________________________________________
117/**
118 * Method to print the results to stdout.
119 * Prints the data loader name, method name/title and ROC integral.
120 */
121void TMVA::Experimental::ClassificationResult::Show()
122{
123 MsgLogger fLogger("Classification");
126 TString hLine = "--------------------------------------------------- :";
127
128 fLogger << kINFO << hLine << Endl;
129 fLogger << kINFO << "DataSet MVA :" << Endl;
130 fLogger << kINFO << "Name: Method/Title: ROC-integ :" << Endl;
131 fLogger << kINFO << hLine << Endl;
132 fLogger << kINFO << Form("%-20s %-15s %#1.3f :", fDataLoaderName.Data(),
133 Form("%s/%s", fMethod.GetValue<TString>("MethodName").Data(),
134 fMethod.GetValue<TString>("MethodTitle").Data()),
135 GetROCIntegral())
136 << Endl;
137 fLogger << kINFO << hLine << Endl;
138
140}
141
142//_______________________________________________________________________
143/**
144 * Method to get TGraph object with the ROC curve.
145 * \param iClass class index; the default 0 corresponds to signal.
146 * \param type train/test tree; the default is the test tree.
147 * \return TGraph object.
148 */
149TGraph *TMVA::Experimental::ClassificationResult::GetROCGraph(UInt_t iClass, TMVA::Types::ETreeType type)
150{
151 TGraph *roc = GetROC(iClass, type)->GetROCCurve();
152 roc->SetName(Form("%s/%s", GetMethodName().Data(), GetMethodTitle().Data()));
153 roc->SetTitle(Form("%s/%s", GetMethodName().Data(), GetMethodTitle().Data()));
154 roc->GetXaxis()->SetTitle(" Signal Efficiency ");
155 roc->GetYaxis()->SetTitle(" Background Rejection ");
156 return roc;
157}
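A minimal usage sketch of these ClassificationResult accessors, for orientation only; the Classification object cl and the assumption that Evaluate() has already run are illustrative and not part of this file:

// Hypothetical fragment: inspect the first stored result after cl->Evaluate().
TMVA::Experimental::ClassificationResult &res = cl->GetResults()[0];
std::cout << "ROC integral (test tree) = " << res.GetROCIntegral() << std::endl;
TGraph *rocGraph = res.GetROCGraph(); // background rejection vs. signal efficiency
rocGraph->Draw("AL");                 // draw the ROC curve on the current canvas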
158
159//_______________________________________________________________________
160/**
161 * Method to check if method was booked.
162 * \param methodname name of the method.
163 * \param methodtitle method title.
164 * \return boolean: kTRUE if the method was booked, kFALSE otherwise.
165 */
166Bool_t TMVA::Experimental::ClassificationResult::IsMethod(TString methodname, TString methodtitle)
167{
168 return fMethod.GetValue<TString>("MethodName") == methodname &&
169 fMethod.GetValue<TString>("MethodTitle") == methodtitle
170 ? kTRUE
171 : kFALSE;
172}
173
174//_______________________________________________________________________
175/**
176 * Constructor to create a two-class classifier.
177 * \param dataloader TMVA::DataLoader object with the data to train/test.
178 * \param file TFile object to save the results.
179 * \param options string extra options.
180 */
181TMVA::Experimental::Classification::Classification(DataLoader *dataloader, TFile *file, TString options)
182 : TMVA::Envelope("Classification", dataloader, file, options), fAnalysisType(Types::kClassification),
183 fCorrelations(kFALSE), fROC(kTRUE)
184{
185 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
186 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
187 ParseOptions();
188 CheckForUnusedOptions();
189
190 if (fModelPersistence)
191 gSystem->MakeDirectory(fDataLoader->GetName()); // creating directory for DataLoader output
192}
193
194//_______________________________________________________________________
195/**
196 * Constructor to create a two-class classifier without an output file.
197 * \param dataloader TMVA::DataLoader object with the data to train/test.
198 * \param options string extra options.
199 */
200TMVA::Experimental::Classification::Classification(DataLoader *dataloader, TString options)
201 : TMVA::Envelope("Classification", dataloader, NULL, options), fAnalysisType(Types::kClassification),
202 fCorrelations(kFALSE), fROC(kTRUE)
203{
204
205 // init configurable
206 SetConfigDescription("Configuration options for Classification running");
207 SetConfigName(GetName());
208
209 DeclareOptionRef(fCorrelations, "Correlations", "boolean to show correlation in output");
210 DeclareOptionRef(fROC, "ROC", "boolean to show ROC in output");
211 ParseOptions();
212 CheckForUnusedOptions();
213 if (fModelPersistence)
214 gSystem->MakeDirectory(fDataLoader->GetName()); // creating directory for DataLoader output
216}
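As a rough sketch of how these constructors are meant to be used; the dataset name, variables and trees are assumptions for illustration, not taken from this file:

// Hypothetical fragment: build a DataLoader and hand it to Classification.
auto *dataloader = new TMVA::DataLoader("dataset");
dataloader->AddVariable("var1", 'F');
dataloader->AddVariable("var2", 'F');
dataloader->AddSignalTree(signalTree);         // TTree* filled elsewhere
dataloader->AddBackgroundTree(backgroundTree); // TTree* filled elsewhere
dataloader->PrepareTrainingAndTestTree("", "SplitMode=Random:!V");

TFile *outputFile = TFile::Open("TMVAClass.root", "RECREATE");
TMVA::Experimental::Classification cl(dataloader, outputFile, ""); // results written to the file
// TMVA::Experimental::Classification cl(dataloader, "");          // or without an output file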
217
218//_______________________________________________________________________
219TMVA::Experimental::Classification::~Classification()
220{
221 for (auto m : fIMethods) {
222 if (m != NULL)
223 delete m;
224 }
225}
226
227//_______________________________________________________________________
228/**
229 * Return the options for the booked method.
230 * \param methodname name of the method.
231 * \param methodtitle method title.
232 * \return string with the options for the ml method.
233 */
234TString TMVA::Experimental::Classification::GetMethodOptions(TString methodname, TString methodtitle)
235{
236 for (auto &meth : fMethods) {
237 if (meth.GetValue<TString>("MethodName") == methodname && meth.GetValue<TString>("MethodTitle") == methodtitle)
238 return meth.GetValue<TString>("MethodOptions");
239 }
240 return "";
241}
242
243//_______________________________________________________________________
244/**
245 * Method to perform Train/Test over all booked ml methods.
246 * If the option Jobs > 1, the work is carried out in parallel using MultiProc.
247 */
248void TMVA::Experimental::Classification::Evaluate()
249{
250 fTimer.Reset();
251 fTimer.Start();
252
253 Bool_t roc = fROC;
254 fROC = kFALSE;
255 if (fJobs <= 1) {
256 Train();
257 Test();
258 } else {
259 for (auto &meth : fMethods) {
260 GetMethod(meth.GetValue<TString>("MethodName"), meth.GetValue<TString>("MethodTitle"));
261 }
262#ifndef _MSC_VER
263 fWorkers.SetNWorkers(fJobs);
264#endif
265 auto executor = [this](UInt_t workerID) -> ClassificationResult {
270 auto methodname = fMethods[workerID].GetValue<TString>("MethodName");
271 auto methodtitle = fMethods[workerID].GetValue<TString>("MethodTitle");
272 auto meth = GetMethod(methodname, methodtitle);
273 if (!IsSilentFile()) {
274 auto fname = Form(".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
275 auto f = new TFile(fname, "RECREATE");
276 f->mkdir(fDataLoader->GetName());
277 SetFile(f);
278 meth->SetFile(f);
279 }
280 TrainMethod(methodname, methodtitle);
281 TestMethod(methodname, methodtitle);
282 if (!IsSilentFile()) {
283 GetFile()->Close();
284 }
285 return GetResults(methodname, methodtitle);
286 };
287
288#ifndef _MSC_VER
289 fResults = fWorkers.Map(executor, ROOT::TSeqI(fMethods.size()));
290#endif
291 if (!IsSilentFile())
292 MergeFiles();
293 }
294
295 fROC = roc;
297
298 TString hLine = "--------------------------------------------------- :";
299 Log() << kINFO << hLine << Endl;
300 Log() << kINFO << "DataSet MVA :" << Endl;
301 Log() << kINFO << "Name: Method/Title: ROC-integ :" << Endl;
302 Log() << kINFO << hLine << Endl;
303 for (auto &r : fResults) {
304
305 Log() << kINFO << Form("%-20s %-15s %#1.3f :", r.GetDataLoaderName().Data(),
306 Form("%s/%s", r.GetMethodName().Data(), r.GetMethodTitle().Data()), r.GetROCIntegral())
307 << Endl;
308 }
309 Log() << kINFO << hLine << Endl;
310
311 Log() << kINFO << "-----------------------------------------------------" << Endl;
312 Log() << kHEADER << "Evaluation done." << Endl << Endl;
313 Log() << kINFO << Form("Jobs = %d Real Time = %lf ", fJobs, fTimer.RealTime()) << Endl;
314 Log() << kINFO << "-----------------------------------------------------" << Endl;
315 Log() << kINFO << "Evaluation done." << Endl;
317}
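A hedged sketch of the Evaluate() workflow implemented above; the booked methods and their options are illustrative assumptions:

// Hypothetical fragment: book methods, then train and test them all.
cl.BookMethod(TMVA::Types::kBDT, "BDTG", "NTrees=100:BoostType=Grad");
cl.BookMethod(TMVA::Types::kLD, "LD", "");
cl.Evaluate(); // trains and tests every booked method and prints the ROC summary table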
318
319//_______________________________________________________________________
320/**
321 * Method to train all booked ml methods.
322 */
323void TMVA::Experimental::Classification::Train()
324{
325 for (auto &meth : fMethods) {
326 TrainMethod(meth.GetValue<TString>("MethodName"), meth.GetValue<TString>("MethodTitle"));
327 }
328}
329
330//_______________________________________________________________________
331/**
332 * Trains a specific ml method.
333 * \param methodname name of the method.
334 * \param methodtitle method title.
335 */
336void TMVA::Experimental::Classification::TrainMethod(TString methodname, TString methodtitle)
337{
338 auto method = GetMethod(methodname, methodtitle);
339 if (!method) {
340 Log() << kFATAL
341 << Form("Trying to train method %s %s that maybe is not booked.", methodname.Data(), methodtitle.Data())
342 << Endl;
343 }
344 Log() << kHEADER << gTools().Color("bold") << Form("Training method %s %s", methodname.Data(), methodtitle.Data())
345 << gTools().Color("reset") << Endl;
346
347 Event::SetIsTraining(kTRUE);
348 if ((fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification) &&
349 method->DataInfo().GetNClasses() < 2)
350 Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;
351
352 // first print some information about the default dataset
353 // if(!IsSilentFile()) WriteDataInformation(method->fDataSetInfo);
354
355 if (method->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
356 Log() << kWARNING << "Method " << method->GetMethodName() << " not trained (training tree has fewer entries ["
357 << method->Data()->GetNTrainingEvents() << "] than required [" << MinNoTrainingEvents << "])" << Endl;
358 return;
359 }
360
361 Log() << kHEADER << "Train method: " << method->GetMethodName() << " for Classification" << Endl << Endl;
362 method->TrainMethod();
363 Log() << kHEADER << "Training finished" << Endl << Endl;
364}
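For completeness, a sketch of driving a single method by hand instead of calling Evaluate(); the method name and title mirror the hypothetical booking above:

// Hypothetical fragment: train and test one booked method explicitly.
cl.TrainMethod("BDT", "BDTG");             // by method name and title ...
cl.TestMethod("BDT", "BDTG");
cl.TrainMethod(TMVA::Types::kBDT, "BDTG"); // ... or by type enum (see the overload below)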
365
366//_______________________________________________________________________
367/**
368 * Trains a specific ml method, given the method type in enum TMVA::Types::EMVA.
369 * \param method TMVA::Types::EMVA type.
370 * \param methodtitle method title.
371 */
372void TMVA::Experimental::Classification::TrainMethod(Types::EMVA method, TString methodtitle)
373{
374 TrainMethod(Types::Instance().GetMethodName(method), methodtitle);
375}
376
377//_______________________________________________________________________
378/**
379 * Return a TMVA::MethodBase object. If the method is not booked, a null
380 * pointer is returned.
381 * \param methodname name of the method.
382 * \param methodtitle method title.
383 * \return TMVA::MethodBase object
384 */
385TMVA::MethodBase *TMVA::Experimental::Classification::GetMethod(TString methodname, TString methodtitle)
386{
387
388 if (!HasMethod(methodname, methodtitle)) {
389 std::cout << methodname << " " << methodtitle << std::endl;
390 Log() << kERROR << "Trying to get method not booked." << Endl;
391 return 0;
392 }
393 Int_t index = -1;
394 if (HasMethodObject(methodname, methodtitle, index)) {
395 return dynamic_cast<MethodBase *>(fIMethods[index]);
396 }
397 // if it is not created yet, create it.
398 if (GetDataLoaderDataInput().GetEntries() <=
399 1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
400 Log() << kFATAL << "No input data for the training provided!" << Endl;
401 }
402 Log() << kHEADER << "Loading booked method: " << gTools().Color("bold") << methodname << " " << methodtitle
403 << gTools().Color("reset") << Endl << Endl;
404
405 TString moptions = GetMethodOptions(methodname, methodtitle);
406
407 // interpret option string with respect to a request for boosting (i.e., Boost_num > 0)
408 Int_t boostNum = 0;
409 auto conf = new TMVA::Configurable(moptions);
410 conf->DeclareOptionRef(boostNum = 0, "Boost_num", "Number of times the classifier will be boosted");
411 conf->ParseOptions();
412 delete conf;
413
414 TString fFileDir;
415 if (fModelPersistence) {
416 fFileDir = fDataLoader->GetName();
417 fFileDir += "/" + gConfig().GetIONames().fWeightFileDir;
418 }
419
420 // initialize methods
421 IMethod *im;
422 TString fJobName = GetName();
423 if (!boostNum) {
424 im = ClassifierFactory::Instance().Create(std::string(methodname.Data()), fJobName, methodtitle,
425 GetDataLoaderDataSetInfo(), moptions);
426 } else {
427 // boosted classifier, requires a specific definition, making it transparent for the user
428 Log() << kDEBUG << "Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
429 im = ClassifierFactory::Instance().Create(std::string("Boost"), fJobName, methodtitle, GetDataLoaderDataSetInfo(),
430 moptions);
431 MethodBoost *methBoost = dynamic_cast<MethodBoost *>(im);
432 if (!methBoost)
433 Log() << kFATAL << "Method with type kBoost cannot be casted to MethodCategory. /Classification" << Endl;
434
435 if (fModelPersistence)
436 methBoost->SetWeightFileDir(fFileDir);
437 methBoost->SetModelPersistence(fModelPersistence);
438 methBoost->SetBoostedMethodName(methodname);
439 methBoost->fDataSetManager = GetDataLoaderDataSetManager();
440 methBoost->SetFile(fFile.get());
441 methBoost->SetSilentFile(IsSilentFile());
442 }
443
444 MethodBase *method = dynamic_cast<MethodBase *>(im);
445 if (method == 0)
446 return 0; // could not create method
447
448 // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects)
449 if (method->GetMethodType() == Types::kCategory) {
450 MethodCategory *methCat = (dynamic_cast<MethodCategory *>(im));
451 if (!methCat)
452 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Classification" << Endl;
453
454 if (fModelPersistence)
455 methCat->SetWeightFileDir(fFileDir);
456 methCat->SetModelPersistence(fModelPersistence);
457 methCat->fDataSetManager = GetDataLoaderDataSetManager();
458 methCat->SetFile(fFile.get());
459 methCat->SetSilentFile(IsSilentFile());
460 }
461
462 if (!method->HasAnalysisType(fAnalysisType, GetDataLoaderDataSetInfo().GetNClasses(),
463 GetDataLoaderDataSetInfo().GetNTargets())) {
464 Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling ";
465 Log() << "classification with " << GetDataLoaderDataSetInfo().GetNClasses() << " classes." << Endl;
466 return 0;
467 }
468
469 if (fModelPersistence)
470 method->SetWeightFileDir(fFileDir);
471 method->SetModelPersistence(fModelPersistence);
472 method->SetAnalysisType(fAnalysisType);
473 method->SetupMethod();
474 method->ParseOptions();
475 method->ProcessSetup();
476 method->SetFile(fFile.get());
477 method->SetSilentFile(IsSilentFile());
478
479 // check-for-unused-options is performed; may be overridden by derived classes
480 method->CheckSetup();
481 fIMethods.push_back(method);
482 return method;
483}
484
485//_______________________________________________________________________
486/**
487 * Allows checking whether the TMVA::MethodBase object was created, and returns its index in the vector.
488 * \param methodname name of the method.
489 * \param methodtitle method title.
490 * \param index reference to an Int_t with the position of the method in the vector fIMethods.
491 * \return boolean true if the method was found.
492 */
493Bool_t TMVA::Experimental::Classification::HasMethodObject(TString methodname, TString methodtitle, Int_t &index)
494{
495 if (fIMethods.empty())
496 return kFALSE;
497 for (UInt_t i = 0; i < fIMethods.size(); i++) {
498 // MethodBase stores the method title as its name, and the method type name as its type
499 auto methbase = dynamic_cast<MethodBase *>(fIMethods[i]);
500 if (methbase->GetMethodTypeName() == methodname && methbase->GetMethodName() == methodtitle) {
501 index = i;
502 return kTRUE;
503 }
504 }
505 return kFALSE;
506}
507
508//_______________________________________________________________________
509/**
510 * Perform test evaluation in all booked methods.
511 */
512void TMVA::Experimental::Classification::Test()
513{
514 for (auto &meth : fMethods) {
515 TestMethod(meth.GetValue<TString>("MethodName"), meth.GetValue<TString>("MethodTitle"));
516 }
517}
518
519//_______________________________________________________________________
520/**
521 * Tests a specific ml method.
522 * \param methodname name of the method.
523 * \param methodtitle method title.
524 */
525void TMVA::Experimental::Classification::TestMethod(TString methodname, TString methodtitle)
526{
527 auto method = GetMethod(methodname, methodtitle);
528 if (!method) {
529 Log() << kFATAL
530 << Form("Trying to train method %s %s that maybe is not booked.", methodname.Data(), methodtitle.Data())
531 << Endl;
532 }
533
534 Log() << kHEADER << gTools().Color("bold") << "Test all methods" << gTools().Color("reset") << Endl;
535 Event::SetIsTraining(kFALSE);
536
537 Types::EAnalysisType analysisType = method->GetAnalysisType();
538 Log() << kHEADER << "Test method: " << method->GetMethodName() << " for Classification"
539 << " performance" << Endl << Endl;
540 method->AddOutput(Types::kTesting, analysisType);
541
542 // -----------------------------------------------------------------------
543 // First part of evaluation process
544 // --> compute efficiencies, and other separation estimators
545 // -----------------------------------------------------------------------
546
547 // although equal, we now want to separate the output for the variables
548 // and the real methods
549 Int_t isel; // will be 0 for a Method; 1 for a Variable
550 Int_t nmeth_used[2] = {0, 0}; // 0 Method; 1 Variable
551
552 std::vector<std::vector<TString>> mname(2);
553 std::vector<std::vector<Double_t>> sig(2), sep(2), roc(2);
554 std::vector<std::vector<Double_t>> eff01(2), eff10(2), eff30(2), effArea(2);
555 std::vector<std::vector<Double_t>> eff01err(2), eff10err(2), eff30err(2);
556 std::vector<std::vector<Double_t>> trainEff01(2), trainEff10(2), trainEff30(2);
557
558 method->SetFile(fFile.get());
559 method->SetSilentFile(IsSilentFile());
560
561 MethodBase *methodNoCuts = NULL;
562 if (!IsCutsMethod(method))
563 methodNoCuts = method;
564
565 Log() << kHEADER << "Evaluate classifier: " << method->GetMethodName() << Endl << Endl;
566 isel = (method->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
567
568 // perform the evaluation
569 method->TestClassification();
570
571 // evaluate the classifier
572 mname[isel].push_back(method->GetMethodName());
573 sig[isel].push_back(method->GetSignificance());
574 sep[isel].push_back(method->GetSeparation());
575 roc[isel].push_back(method->GetROCIntegral());
576
577 Double_t err;
578 eff01[isel].push_back(method->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
579 eff01err[isel].push_back(err);
580 eff10[isel].push_back(method->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
581 eff10err[isel].push_back(err);
582 eff30[isel].push_back(method->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
583 eff30err[isel].push_back(err);
584 effArea[isel].push_back(method->GetEfficiency("", Types::kTesting, err)); // computes the area (average)
585
586 trainEff01[isel].push_back(method->GetTrainingEfficiency("Efficiency:0.01")); // the first pass takes longer
587 trainEff10[isel].push_back(method->GetTrainingEfficiency("Efficiency:0.10"));
588 trainEff30[isel].push_back(method->GetTrainingEfficiency("Efficiency:0.30"));
589
590 nmeth_used[isel]++;
591
592 if (!IsSilentFile()) {
593 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
594 method->WriteEvaluationHistosToFile(Types::kTesting);
595 method->WriteEvaluationHistosToFile(Types::kTraining);
596 }
597
598 // now sort the variables according to the best 'eff at Beff=0.10'
599 for (Int_t k = 0; k < 2; k++) {
600 std::vector<std::vector<Double_t>> vtemp;
601 vtemp.push_back(effArea[k]); // this is the vector that is ranked
602 vtemp.push_back(eff10[k]);
603 vtemp.push_back(eff01[k]);
604 vtemp.push_back(eff30[k]);
605 vtemp.push_back(eff10err[k]);
606 vtemp.push_back(eff01err[k]);
607 vtemp.push_back(eff30err[k]);
608 vtemp.push_back(trainEff10[k]);
609 vtemp.push_back(trainEff01[k]);
610 vtemp.push_back(trainEff30[k]);
611 vtemp.push_back(sig[k]);
612 vtemp.push_back(sep[k]);
613 vtemp.push_back(roc[k]);
614 std::vector<TString> vtemps = mname[k];
615 gTools().UsefulSortDescending(vtemp, &vtemps);
616 effArea[k] = vtemp[0];
617 eff10[k] = vtemp[1];
618 eff01[k] = vtemp[2];
619 eff30[k] = vtemp[3];
620 eff10err[k] = vtemp[4];
621 eff01err[k] = vtemp[5];
622 eff30err[k] = vtemp[6];
623 trainEff10[k] = vtemp[7];
624 trainEff01[k] = vtemp[8];
625 trainEff30[k] = vtemp[9];
626 sig[k] = vtemp[10];
627 sep[k] = vtemp[11];
628 roc[k] = vtemp[12];
629 mname[k] = vtemps;
630 }
631
632 // -----------------------------------------------------------------------
633 // Second part of evaluation process
634 // --> compute correlations among MVAs
635 // --> compute correlations between input variables and MVA (determines importance)
636 // --> count overlaps
637 // -----------------------------------------------------------------------
638 if (fCorrelations) {
639 const Int_t nmeth = methodNoCuts == NULL ? 0 : 1;
640 const Int_t nvar = method->fDataSetInfo.GetNVariables();
641 if (nmeth > 0) {
642
643 // needed for correlations
644 Double_t *dvec = new Double_t[nmeth + nvar];
645 std::vector<Double_t> rvec;
646
647 // for correlations
648 TPrincipal *tpSig = new TPrincipal(nmeth + nvar, "");
649 TPrincipal *tpBkg = new TPrincipal(nmeth + nvar, "");
650
651 // set required tree branch references
652 std::vector<TString> *theVars = new std::vector<TString>;
653 std::vector<ResultsClassification *> mvaRes;
654 theVars->push_back(methodNoCuts->GetTestvarName());
655 rvec.push_back(methodNoCuts->GetSignalReferenceCut());
656 theVars->back().ReplaceAll("MVA_", "");
657 mvaRes.push_back(dynamic_cast<ResultsClassification *>(
658 methodNoCuts->Data()->GetResults(methodNoCuts->GetMethodName(), Types::kTesting, Types::kMaxAnalysisType)));
659
660 // for overlap study
661 TMatrixD *overlapS = new TMatrixD(nmeth, nmeth);
662 TMatrixD *overlapB = new TMatrixD(nmeth, nmeth);
663 (*overlapS) *= 0; // init...
664 (*overlapB) *= 0; // init...
665
666 // loop over test tree
667 DataSet *defDs = method->fDataSetInfo.GetDataSet();
668 defDs->SetCurrentType(Types::kTesting);
669 for (Int_t ievt = 0; ievt < defDs->GetNEvents(); ievt++) {
670 const Event *ev = defDs->GetEvent(ievt);
671
672 // for correlations
673 TMatrixD *theMat = 0;
674 for (Int_t im = 0; im < nmeth; im++) {
675 // check for NaN value
676 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
677 if (TMath::IsNaN(retval)) {
678 Log() << kWARNING << "Found NaN return value in event: " << ievt << " for method \""
679 << methodNoCuts->GetName() << "\"" << Endl;
680 dvec[im] = 0;
681 } else
682 dvec[im] = retval;
683 }
684 for (Int_t iv = 0; iv < nvar; iv++)
685 dvec[iv + nmeth] = (Double_t)ev->GetValue(iv);
686 if (method->fDataSetInfo.IsSignal(ev)) {
687 tpSig->AddRow(dvec);
688 theMat = overlapS;
689 } else {
690 tpBkg->AddRow(dvec);
691 theMat = overlapB;
692 }
693
694 // count overlaps
695 for (Int_t im = 0; im < nmeth; im++) {
696 for (Int_t jm = im; jm < nmeth; jm++) {
697 if ((dvec[im] - rvec[im]) * (dvec[jm] - rvec[jm]) > 0) {
698 (*theMat)(im, jm)++;
699 if (im != jm)
700 (*theMat)(jm, im)++;
701 }
702 }
703 }
704 }
705
706 // renormalise overlap matrix
707 (*overlapS) *= (1.0 / defDs->GetNEvtSigTest()); // init...
708 (*overlapB) *= (1.0 / defDs->GetNEvtBkgdTest()); // init...
709
710 tpSig->MakePrincipals();
711 tpBkg->MakePrincipals();
712
713 const TMatrixD *covMatS = tpSig->GetCovarianceMatrix();
714 const TMatrixD *covMatB = tpBkg->GetCovarianceMatrix();
715
716 const TMatrixD *corrMatS = gTools().GetCorrelationMatrix(covMatS);
717 const TMatrixD *corrMatB = gTools().GetCorrelationMatrix(covMatB);
718
719 // print correlation matrices
720 if (corrMatS != 0 && corrMatB != 0) {
721
722 // extract MVA matrix
723 TMatrixD mvaMatS(nmeth, nmeth);
724 TMatrixD mvaMatB(nmeth, nmeth);
725 for (Int_t im = 0; im < nmeth; im++) {
726 for (Int_t jm = 0; jm < nmeth; jm++) {
727 mvaMatS(im, jm) = (*corrMatS)(im, jm);
728 mvaMatB(im, jm) = (*corrMatB)(im, jm);
729 }
730 }
731
732 // extract variables - to MVA matrix
733 std::vector<TString> theInputVars;
734 TMatrixD varmvaMatS(nvar, nmeth);
735 TMatrixD varmvaMatB(nvar, nmeth);
736 for (Int_t iv = 0; iv < nvar; iv++) {
737 theInputVars.push_back(method->fDataSetInfo.GetVariableInfo(iv).GetLabel());
738 for (Int_t jm = 0; jm < nmeth; jm++) {
739 varmvaMatS(iv, jm) = (*corrMatS)(nmeth + iv, jm);
740 varmvaMatB(iv, jm) = (*corrMatB)(nmeth + iv, jm);
741 }
742 }
743
744 if (nmeth > 1) {
745 Log() << kINFO << Endl;
746 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
747 << "Inter-MVA correlation matrix (signal):" << Endl;
748 gTools().FormattedOutput(mvaMatS, *theVars, Log());
749 Log() << kINFO << Endl;
750
751 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
752 << "Inter-MVA correlation matrix (background):" << Endl;
753 gTools().FormattedOutput(mvaMatB, *theVars, Log());
754 Log() << kINFO << Endl;
755 }
756
757 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
758 << "Correlations between input variables and MVA response (signal):" << Endl;
759 gTools().FormattedOutput(varmvaMatS, theInputVars, *theVars, Log());
760 Log() << kINFO << Endl;
761
762 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
763 << "Correlations between input variables and MVA response (background):" << Endl;
764 gTools().FormattedOutput(varmvaMatB, theInputVars, *theVars, Log());
765 Log() << kINFO << Endl;
766 } else
767 Log() << kWARNING << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
768 << "<TestAllMethods> cannot compute correlation matrices" << Endl;
769
770 // print overlap matrices
771 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
772 << "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
773 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
774 << "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
775 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
776 << "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
777 gTools().FormattedOutput(rvec, *theVars, "Method", "Cut value", Log());
778 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
779 << "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
780
781 // give notice that cut method has been excluded from this test
782 if (nmeth != 1)
783 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
784 << "Note: no correlations and overlap with cut method are provided at present" << Endl;
785
786 if (nmeth > 1) {
787 Log() << kINFO << Endl;
788 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
789 << "Inter-MVA overlap matrix (signal):" << Endl;
790 gTools().FormattedOutput(*overlapS, *theVars, Log());
791 Log() << kINFO << Endl;
792
793 Log() << kINFO << Form("Dataset[%s] : ", method->fDataSetInfo.GetName())
794 << "Inter-MVA overlap matrix (background):" << Endl;
795 gTools().FormattedOutput(*overlapB, *theVars, Log());
796 }
797
798 // cleanup
799 delete tpSig;
800 delete tpBkg;
801 delete corrMatS;
802 delete corrMatB;
803 delete theVars;
804 delete overlapS;
805 delete overlapB;
806 delete[] dvec;
807 }
808 }
809
810 // -----------------------------------------------------------------------
811 // Third part of evaluation process
812 // --> output
813 // -----------------------------------------------------------------------
814 // putting results in the classification result object
815 auto &fResult = GetResults(methodname, methodtitle);
816
817 // Binary classification
818 if (fROC) {
819 Log().EnableOutput();
821 Log() << Endl;
822 TString hLine = "------------------------------------------------------------------------------------------"
823 "-------------------------";
824 Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
825 Log() << kINFO << hLine << Endl;
826 Log() << kINFO << "DataSet MVA " << Endl;
827 Log() << kINFO << "Name: Method: ROC-integ" << Endl;
828
829 Log() << kDEBUG << hLine << Endl;
830 for (Int_t k = 0; k < 2; k++) {
831 if (k == 1 && nmeth_used[k] > 0) {
832 Log() << kINFO << hLine << Endl;
833 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
834 }
835 for (Int_t i = 0; i < nmeth_used[k]; i++) {
836 TString datasetName = fDataLoader->GetName();
837 TString methodName = mname[k][i];
838
839 if (k == 1) {
840 methodName.ReplaceAll("Variable_", "");
841 }
842
843 TMVA::DataSet *dataset = method->Data();
844 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
845 std::vector<Bool_t> *mvaResType = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
846
847 Double_t rocIntegral = 0.0;
848 if (mvaResType->size() != 0) {
849 rocIntegral = GetROCIntegral(methodname, methodtitle);
850 }
851
852 if (sep[k][i] < 0 || sig[k][i] < 0) {
853 // cannot compute separation/significance -> no MVA (usually for Cuts)
854 fResult.fROCIntegral = effArea[k][i];
855 Log() << kINFO
856 << Form("%-13s %-15s: %#1.3f", fDataLoader->GetName(), methodName.Data(), fResult.fROCIntegral)
857 << Endl;
858 } else {
859 fResult.fROCIntegral = rocIntegral;
860 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
861 << Endl;
862 }
863 }
864 }
865 Log() << kINFO << hLine << Endl;
866 Log() << kINFO << Endl;
867 Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
868 Log() << kINFO << hLine << Endl;
869 Log() << kINFO
870 << "DataSet MVA Signal efficiency: from test sample (from training sample) "
871 << Endl;
872 Log() << kINFO << "Name: Method: @B=0.01 @B=0.10 @B=0.30 "
873 << Endl;
874 Log() << kINFO << hLine << Endl;
875 for (Int_t k = 0; k < 2; k++) {
876 if (k == 1 && nmeth_used[k] > 0) {
877 Log() << kINFO << hLine << Endl;
878 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
879 }
880 for (Int_t i = 0; i < nmeth_used[k]; i++) {
881 if (k == 1)
882 mname[k][i].ReplaceAll("Variable_", "");
883
884 Log() << kINFO << Form("%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
885 method->fDataSetInfo.GetName(), (const char *)mname[k][i], eff01[k][i],
886 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
887 << Endl;
888 }
889 }
890 Log() << kINFO << hLine << Endl;
891 Log() << kINFO << Endl;
892
893 if (gTools().CheckForSilentOption(GetOptions()))
894 Log().InhibitOutput();
895 } else if (IsCutsMethod(method)) { // end fROC
896 for (Int_t k = 0; k < 2; k++) {
897 for (Int_t i = 0; i < nmeth_used[k]; i++) {
898
899 if (sep[k][i] < 0 || sig[k][i] < 0) {
900 // cannot compute separation/significance -> no MVA (usually for Cuts)
901 fResult.fROCIntegral = effArea[k][i];
902 }
903 }
904 }
905 }
906
907 TMVA::DataSet *dataset = method->Data();
908 dataset->SetCurrentType(Types::kTesting);
909
910 if (IsCutsMethod(method)) {
911 fResult.fIsCuts = kTRUE;
912 } else {
913 auto rocCurveTest = GetROC(methodname, methodtitle, 0, Types::kTesting);
914 fResult.fMvaTest[0] = rocCurveTest->GetMvas();
915 fResult.fROCIntegral = GetROCIntegral(methodname, methodtitle);
916 }
917 TString className = method->DataInfo().GetClassInfo(0)->GetName();
918 fResult.fClassNames.push_back(className);
919
920 if (!IsSilentFile()) {
921 // write test/training trees
922 RootBaseDir()->cd(method->fDataSetInfo.GetName());
923 method->fDataSetInfo.GetDataSet()->GetTree(Types::kTesting)->Write("", TObject::kOverwrite);
924 method->fDataSetInfo.GetDataSet()->GetTree(Types::kTraining)->Write("", TObject::kOverwrite);
925 }
926}
927
928//_______________________________________________________________________
929/**
930 * Tests a specific ml method, given the method type in enum TMVA::Types::EMVA.
931 * \param method TMVA::Types::EMVA type.
932 * \param methodtitle method title.
933 */
934void TMVA::Experimental::Classification::TestMethod(Types::EMVA method, TString methodtitle)
935{
936 TestMethod(Types::Instance().GetMethodName(method), methodtitle);
937}
938
939//_______________________________________________________________________
940/**
941 * Return the vector of TMVA::Experimental::ClassificationResult objects.
942 * \return vector of results.
943 */
944std::vector<TMVA::Experimental::ClassificationResult> &TMVA::Experimental::Classification::GetResults()
945{
946 if (fResults.size() == 0)
947 Log() << kFATAL << "No Classification results available" << Endl;
948 return fResults;
949}
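A short sketch of consuming the results vector; the object name cl is an assumption carried over from the earlier fragments:

// Hypothetical fragment: loop over the stored results.
for (auto &result : cl.GetResults()) {
   result.Show();                         // dataset name, method name/title and ROC integral
   double auc = result.GetROCIntegral();  // AUC on the test tree
   std::cout << "AUC = " << auc << std::endl;
}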
950
951//_______________________________________________________________________
952/**
953 * Allows checking whether the ml method is a Cuts method.
954 * \return boolean: kTRUE if the method is a Cuts method.
955 */
956Bool_t TMVA::Experimental::Classification::IsCutsMethod(TMVA::MethodBase *method)
957{
958 return method->GetMethodType() == Types::kCuts ? kTRUE : kFALSE;
959}
960
961//_______________________________________________________________________
962/**
963 * Allows getting the result for a specific ml method.
964 * \param methodname name of the method.
965 * \param methodtitle method title.
966 * \return TMVA::Experimental::ClassificationResult object for the method.
967 */
968TMVA::Experimental::ClassificationResult &
969TMVA::Experimental::Classification::GetResults(TString methodname, TString methodtitle)
970{
971 for (auto &result : fResults) {
972 if (result.IsMethod(methodname, methodtitle))
973 return result;
974 }
975 ClassificationResult result;
976 result.fMethod["MethodName"] = methodname;
977 result.fMethod["MethodTitle"] = methodtitle;
978 result.fDataLoaderName = fDataLoader->GetName();
979 fResults.push_back(result);
980 return fResults.back();
981}
982
983//_______________________________________________________________________
984/**
985 * Method to get TMVA::ROCCurve Object.
986 * \param method TMVA::MethodBase object
987 * \param iClass class index; the default 0 corresponds to signal.
988 * \param type train/test tree; the default is the test tree.
989 * \return TMVA::ROCCurve object.
990 */
991TMVA::ROCCurve *
992TMVA::Experimental::Classification::GetROC(TMVA::MethodBase *method, UInt_t iClass, TMVA::Types::ETreeType type)
993{
994 TMVA::DataSet *dataset = method->Data();
995 dataset->SetCurrentType(type);
996 TMVA::Results *results = dataset->GetResults(method->GetName(), type, this->fAnalysisType);
997
998 UInt_t nClasses = method->DataInfo().GetNClasses();
999 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
1000 Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
1001 iClass, nClasses)
1002 << Endl;
1003 return nullptr;
1004 }
1005
1006 TMVA::ROCCurve *rocCurve = nullptr;
1007 if (this->fAnalysisType == Types::kClassification) {
1008
1009 std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
1010 std::vector<Bool_t> *mvaResTypes = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
1011 std::vector<Float_t> mvaResWeights;
1012
1013 auto eventCollection = dataset->GetEventCollection(type);
1014 mvaResWeights.reserve(eventCollection.size());
1015 for (auto ev : eventCollection) {
1016 mvaResWeights.push_back(ev->GetWeight());
1017 }
1018
1019 rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);
1020
1021 } else if (this->fAnalysisType == Types::kMulticlass) {
1022 std::vector<Float_t> mvaRes;
1023 std::vector<Bool_t> mvaResTypes;
1024 std::vector<Float_t> mvaResWeights;
1025
1026 std::vector<std::vector<Float_t>> *rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();
1027
1028 // Vector transpose due to values being stored as
1029 // [ [0, 1, 2], [0, 1, 2], ... ]
1030 // in ResultsMulticlass::GetValueVector.
1031 mvaRes.reserve(rawMvaRes->size());
1032 for (auto item : *rawMvaRes) {
1033 mvaRes.push_back(item[iClass]);
1034 }
1035
1036 auto eventCollection = dataset->GetEventCollection(type);
1037 mvaResTypes.reserve(eventCollection.size());
1038 mvaResWeights.reserve(eventCollection.size());
1039 for (auto ev : eventCollection) {
1040 mvaResTypes.push_back(ev->GetClass() == iClass);
1041 mvaResWeights.push_back(ev->GetWeight());
1042 }
1043
1044 rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
1045 }
1046
1047 return rocCurve;
1048}
1049
1050//_______________________________________________________________________
1051/**
1052 * Method to get TMVA::ROCCurve Object.
1053 * \param methodname ml method name.
1054 * \param methodtitle ml method title.
1055 * \param iClass class index; the default 0 corresponds to signal.
1056 * \param type train/test tree; the default is the test tree.
1057 * \return TMVA::ROCCurve object.
1058 */
1059TMVA::ROCCurve *TMVA::Experimental::Classification::GetROC(TString methodname, TString methodtitle, UInt_t iClass,
1060 TMVA::Types::ETreeType type)
1061{
1062 return GetROC(GetMethod(methodname, methodtitle), iClass, type);
1063}
1064
1065//_______________________________________________________________________
1066/**
1067 * Method to get the ROC-Integral value from the MVA outputs.
1068 * \param methodname ml method name.
1069 * \param methodtitle ml method title.
1070 * \param iClass class index; the default 0 corresponds to signal.
1071 * \return Double_t with the ROC-Integral value.
1072 */
1073Double_t TMVA::Experimental::Classification::GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass)
1074{
1075 TMVA::ROCCurve *rocCurve = GetROC(methodname, methodtitle, iClass);
1076 if (!rocCurve) {
1077 Log() << kFATAL
1078 << Form("ROCCurve object was not created in MethodName = %s MethodTitle = %s not found with Dataset = %s ",
1079 methodname.Data(), methodtitle.Data(), fDataLoader->GetName())
1080 << Endl;
1081 return 0;
1082 }
1083
1084 Int_t npoints = TMVA::gConfig().fVariablePlotting.fNbinsXOfROCCurve + 1;
1085 Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
1086 delete rocCurve;
1087
1088 return rocIntegral;
1089}
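A sketch of the name/title based ROC helpers; the method name and title are assumptions matching the earlier hypothetical booking:

// Hypothetical fragment: query ROC information for a booked method.
Double_t auc = cl.GetROCIntegral("BDT", "BDTG"); // ROC integral for the signal class
TMVA::ROCCurve *roc = cl.GetROC("BDT", "BDTG");  // caller owns the returned ROCCurve
std::cout << "AUC = " << auc << std::endl;
delete roc;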
1090
1091//_______________________________________________________________________
1092void TMVA::Experimental::Classification::CopyFrom(TDirectory *src, TFile *file)
1093{
1094 TFile *savdir = file;
1095 TDirectory *adir = savdir;
1096 adir->cd();
1097 // loop on all entries of this directory
1098 TKey *key;
1099 TIter nextkey(src->GetListOfKeys());
1100 while ((key = (TKey *)nextkey())) {
1101 const Char_t *classname = key->GetClassName();
1102 TClass *cl = gROOT->GetClass(classname);
1103 if (!cl)
1104 continue;
1105 if (cl->InheritsFrom(TDirectory::Class())) {
1106 src->cd(key->GetName());
1107 TDirectory *subdir = file;
1108 adir->cd();
1109 CopyFrom(subdir, file);
1110 adir->cd();
1111 } else if (cl->InheritsFrom(TTree::Class())) {
1112 TTree *T = (TTree *)src->Get(key->GetName());
1113 adir->cd();
1114 TTree *newT = T->CloneTree(-1, "fast");
1115 newT->Write();
1116 } else {
1117 src->cd();
1118 TObject *obj = key->ReadObj();
1119 adir->cd();
1120 obj->Write();
1121 delete obj;
1122 }
1123 }
1124 adir->SaveSelf(kTRUE);
1125 savdir->cd();
1126}
1127
1128//_______________________________________________________________________
1129void TMVA::Experimental::Classification::MergeFiles()
1130{
1131
1132 auto dsdir = fFile->mkdir(fDataLoader->GetName()); // dataset dir
1133 TTree *TrainTree = 0;
1134 TTree *TestTree = 0;
1135 TFile *ifile = 0;
1136 TFile *ofile = 0;
1137 for (UInt_t i = 0; i < fMethods.size(); i++) {
1138 auto methodname = fMethods[i].GetValue<TString>("MethodName");
1139 auto methodtitle = fMethods[i].GetValue<TString>("MethodTitle");
1140 auto fname = Form(".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
1141 TDirectoryFile *ds = 0;
1142 if (i == 0) {
1143 ifile = new TFile(fname);
1144 ds = (TDirectoryFile *)ifile->Get(fDataLoader->GetName());
1145 } else {
1146 ofile = new TFile(fname);
1147 ds = (TDirectoryFile *)ofile->Get(fDataLoader->GetName());
1148 }
1149 auto tmptrain = (TTree *)ds->Get("TrainTree");
1150 auto tmptest = (TTree *)ds->Get("TestTree");
1151 fFile->cd();
1152 fFile->cd(fDataLoader->GetName());
1153
1154 auto methdirname = Form("Method_%s", methodtitle.Data());
1155 auto methdir = dsdir->mkdir(methdirname, methdirname);
1156 auto methdirbase = methdir->mkdir(methodtitle.Data(), methodtitle.Data());
1157 auto mfdir = (TDirectoryFile *)ds->Get(methdirname);
1158 auto mfdirbase = (TDirectoryFile *)mfdir->Get(methodtitle.Data());
1159
1160 CopyFrom(mfdirbase, (TFile *)methdirbase);
1161 dsdir->cd();
1162 if (i == 0) {
1163 TrainTree = tmptrain->CopyTree("");
1164 TestTree = tmptest->CopyTree("");
1165 } else {
1166 Float_t mva = 0;
1167 auto trainbranch = TrainTree->Branch(methodtitle.Data(), &mva);
1168 tmptrain->SetBranchAddress(methodtitle.Data(), &mva);
1169 auto entries = tmptrain->GetEntries();
1170 for (UInt_t ev = 0; ev < entries; ev++) {
1171 tmptrain->GetEntry(ev);
1172 trainbranch->Fill();
1173 }
1174 auto testbranch = TestTree->Branch(methodtitle.Data(), &mva);
1175 tmptest->SetBranchAddress(methodtitle.Data(), &mva);
1176 entries = tmptest->GetEntries();
1177 for (UInt_t ev = 0; ev < entries; ev++) {
1178 tmptest->GetEntry(ev);
1179 testbranch->Fill();
1180 }
1181 ofile->Close();
1182 }
1183 }
1184 TrainTree->Write();
1185 TestTree->Write();
1186 ifile->Close();
1187 // cleaning
1188 for (UInt_t i = 0; i < fMethods.size(); i++) {
1189 auto methodname = fMethods[i].GetValue<TString>("MethodName");
1190 auto methodtitle = fMethods[i].GetValue<TString>("MethodTitle");
1191 auto fname = Form(".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
1192 gSystem->Unlink(fname);
1193 }
1194}
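Tying the fragments above together, a self-contained but still hypothetical ROOT macro sketch; the input file, tree names, variables and method options are assumptions for illustration and do not come from Classification.cxx:

// classification_sketch.C -- illustrative only
#include "TFile.h"
#include "TTree.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Classification.h"
#include "TMVA/Types.h"

void classification_sketch()
{
   TFile *input = TFile::Open("input.root");  // assumed input file
   TTree *sig = (TTree *)input->Get("TreeS"); // assumed signal tree
   TTree *bkg = (TTree *)input->Get("TreeB"); // assumed background tree

   auto *dataloader = new TMVA::DataLoader("dataset");
   dataloader->AddVariable("var1", 'F');
   dataloader->AddVariable("var2", 'F');
   dataloader->AddSignalTree(sig);
   dataloader->AddBackgroundTree(bkg);
   dataloader->PrepareTrainingAndTestTree("", "SplitMode=Random:!V");

   TFile *output = TFile::Open("TMVAClassOut.root", "RECREATE");
   TMVA::Experimental::Classification cl(dataloader, output, "");

   cl.BookMethod(TMVA::Types::kBDT, "BDTG", "NTrees=100:BoostType=Grad");
   cl.Evaluate(); // train, test and print the ROC summary

   for (auto &result : cl.GetResults())
      result.Show();

   output->Close();
}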
#define MinNoTrainingEvents
char Char_t
Definition RtypesCore.h:37
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
double Double_t
Definition RtypesCore.h:59
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
TMatrixT< Double_t > TMatrixD
Definition TMatrixDfwd.h:23
#define gROOT
Definition TROOT.h:405
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2467
R__EXTERN TSystem * gSystem
Definition TSystem.h:560
A pseudo container class which is a generator of indices.
Definition TSeq.hxx:67
Long64_t GetEntries() const
Definition TBranch.h:247
TClass instances represent classes, structs and namespaces in the ROOT type system.
Definition TClass.h:81
Bool_t InheritsFrom(const char *cl) const override
Return kTRUE if this class inherits from a class with name "classname".
Definition TClass.cxx:4874
static TClass * GetClass(const char *name, Bool_t load=kTRUE, Bool_t silent=kFALSE)
Static method returning pointer to TClass of the specified class name.
Definition TClass.cxx:2968
A ROOT file is structured in Directories (like a file system).
Bool_t cd() override
Change current directory to "this" directory.
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
Describe directory structure in memory.
Definition TDirectory.h:45
static TClass * Class()
virtual Bool_t cd()
Change current directory to "this" directory.
virtual void SaveSelf(Bool_t=kFALSE)
Definition TDirectory.h:255
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:51
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:914
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
void SetName(const char *name="") override
Set graph name.
Definition TGraph.cxx:2364
TAxis * GetXaxis() const
Get x axis of the graph.
Definition TGraph.cxx:1550
TAxis * GetYaxis() const
Get y axis of the graph.
Definition TGraph.cxx:1559
void SetTitle(const char *title="") override
Change (i.e.
Definition TGraph.cxx:2380
Book space in a file, create I/O buffers, to fill them, (un)compress them.
Definition TKey.h:28
virtual const char * GetClassName() const
Definition TKey.h:75
virtual TObject * ReadObj()
To read a TObject* from the file.
Definition TKey.cxx:750
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition Config.h:124
void SetDrawProgressBar(Bool_t d)
Definition Config.h:69
void SetUseColor(Bool_t uc)
Definition Config.h:60
class TMVA::Config::VariablePlotting fVariablePlotting
void SetSilent(Bool_t s)
Definition Config.h:63
IONames & GetIONames()
Definition Config.h:98
void SetConfigDescription(const char *d)
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void SetConfigName(const char *n)
virtual void ParseOptions()
options parser
void CheckForUnusedOptions() const
checks for unused options in option string
UInt_t GetNClasses() const
Class that contains all the data information.
Definition DataSet.h:58
Long64_t GetNEvtSigTest()
return number of signal test events in dataset
Definition DataSet.cxx:427
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
Results * GetResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
Definition DataSet.cxx:265
void SetCurrentType(Types::ETreeType type) const
Definition DataSet.h:89
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:216
Long64_t GetNEvtBkgdTest()
return number of background test events in dataset
Definition DataSet.cxx:435
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Definition Envelope.h:44
Bool_t fModelPersistence
! flag to save the trained model
Definition Envelope.h:49
std::shared_ptr< DataLoader > fDataLoader
! data
Definition Envelope.h:47
virtual void ParseOptions()
Method to parse the internal option string.
Definition Envelope.cxx:182
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition Event.cxx:236
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition Event.cxx:399
Double_t GetROCIntegral(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get ROC-Integral value from mvas.
TGraph * GetROCGraph(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TGraph object with the ROC curve.
void Show()
Method to print the results in stdout.
Bool_t IsMethod(TString methodname, TString methodtitle)
Method to check if method was booked.
std::map< UInt_t, std::vector< std::tuple< Float_t, Float_t, Bool_t > > > fMvaTest
Mvas for two-class and multiclass classification.
ROCCurve * GetROC(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TMVA::ROCCurve Object.
Bool_t fIsCuts
if it is a method cuts need special output
ClassificationResult & operator=(const ClassificationResult &r)
std::map< UInt_t, std::vector< std::tuple< Float_t, Float_t, Bool_t > > > fMvaTrain
Mvas for two-class classification.
Classification(DataLoader *loader, TFile *file, TString options)
Contructor to create a two class classifier.
Double_t GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass=0)
Method to get ROC-Integral value from mvas.
virtual void Test()
Perform test evaluation in all booked methods.
TString GetMethodOptions(TString methodname, TString methodtitle)
return the options for the booked method.
MethodBase * GetMethod(TString methodname, TString methodtitle)
Return a TMVA::MethodBase object.
virtual void TrainMethod(TString methodname, TString methodtitle)
Lets train an specific ml method.
Bool_t HasMethodObject(TString methodname, TString methodtitle, Int_t &index)
Allows to check if the TMVA::MethodBase was created and return the index in the vector.
std::vector< ClassificationResult > & GetResults()
Return the vector of TMVA::Experimental::ClassificationResult objects.
virtual void Train()
Method to train all booked ml methods.
virtual void Evaluate()
Method to perform Train/Test over all ml method booked.
Types::EAnalysisType fAnalysisType
!
TMVA::ROCCurve * GetROC(TMVA::MethodBase *method, UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TMVA::ROCCurve Object.
Bool_t IsCutsMethod(TMVA::MethodBase *method)
Allows to check if the ml method is a Cuts method.
void CopyFrom(TDirectory *src, TFile *file)
virtual void TestMethod(TString methodname, TString methodtitle)
Lets perform test an specific ml method.
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
Virtual base Class for all MVA method.
Definition MethodBase.h:111
void SetSilentFile(Bool_t status)
Definition MethodBase.h:378
void SetWeightFileDir(TString fileDir)
set directory of weight file
TString GetMethodTypeName() const
Definition MethodBase.h:332
const char * GetName() const
Definition MethodBase.h:334
const TString & GetTestvarName() const
Definition MethodBase.h:335
void SetupMethod()
setup of methods
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition MethodBase.h:436
const TString & GetMethodName() const
Definition MethodBase.h:331
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
Types::EMVA GetMethodType() const
Definition MethodBase.h:333
void SetFile(TFile *file)
Definition MethodBase.h:375
DataSet * Data() const
Definition MethodBase.h:409
void SetModelPersistence(Bool_t status)
Definition MethodBase.h:382
Double_t GetSignalReferenceCut() const
Definition MethodBase.h:360
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Class for boosting a TMVA method.
Definition MethodBoost.h:58
void SetBoostedMethodName(TString methodName)
Definition MethodBoost.h:86
DataSetManager * fDataSetManager
DSMTEST.
Class for categorizing the phase space.
DataSetManager * fDataSetManager
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
static void InhibitOutput()
Definition MsgLogger.cxx:67
static void EnableOutput()
Definition MsgLogger.cxx:68
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition ROCCurve.cxx:250
Class that is the base-class for a vector of result.
Class which takes the results of a multiclass classification.
Class that is the base-class for a vector of result.
Definition Results.h:57
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition Tools.cxx:887
void UsefulSortDescending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=nullptr)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition Tools.cxx:564
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
const TMatrixD * GetCorrelationMatrix(const TMatrixD *covMat)
turns covariance into correlation matrix
Definition Tools.cxx:324
Singleton class for Global types used by TMVA.
Definition Types.h:71
static Types & Instance()
The single instance of "Types" if existing already, or create it (Singleton)
Definition Types.cxx:70
@ kCategory
Definition Types.h:97
@ kMulticlass
Definition Types.h:129
@ kClassification
Definition Types.h:127
@ kMaxAnalysisType
Definition Types.h:131
@ kTraining
Definition Types.h:143
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition TNamed.cxx:164
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
Mother of all ROOT objects.
Definition TObject.h:41
@ kOverwrite
overwrite existing object with same name
Definition TObject.h:92
virtual Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition TObject.cxx:874
Principal Components Analysis (PCA)
Definition TPrincipal.h:21
virtual void AddRow(const Double_t *x)
Add a data point and update the covariance matrix.
const TMatrixD * GetCovarianceMatrix() const
Definition TPrincipal.h:59
virtual void MakePrincipals()
Perform the principal components analysis.
Basic string class.
Definition TString.h:139
const char * Data() const
Definition TString.h:380
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:704
virtual int MakeDirectory(const char *name)
Make a directory.
Definition TSystem.cxx:830
virtual int Unlink(const char *name)
Unlink, i.e.
Definition TSystem.cxx:1384
A TTree represents a columnar dataset.
Definition TTree.h:79
virtual TTree * CopyTree(const char *selection, Option_t *option="", Long64_t nentries=kMaxEntries, Long64_t firstentry=0)
Copy a tree with selection.
Definition TTree.cxx:3713
TBranch * Branch(const char *name, T *obj, Int_t bufsize=32000, Int_t splitlevel=99)
Add a new branch, and infer the data type from the type of obj being passed.
Definition TTree.h:350
Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0) override
Write this object to the current directory.
Definition TTree.cxx:9708
static TClass * Class()
create variable transformations
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Bool_t IsNaN(Double_t x)
Definition TMath.h:890