Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
MethodCategory.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCategory *
8 * *
9 * *
10 * Description: *
11 * Class for categorizing the phase space *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16 * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * MSU East Lansing, USA *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (see tmva/doc/LICENSE) *
31 **********************************************************************************/
32
33/*! \class TMVA::MethodCategory
34\ingroup TMVA
35
36Class for categorizing the phase space
37
38This class is meant to allow categorisation of the data. For different
39categories, different classifiers may be booked and different variables
40may be considered. The aim is to account for the difference that
41is due to different locations/angles.
42*/
43
44
45#include "TMVA/MethodCategory.h"
46
47#include <algorithm>
48#include <vector>
49
50#include "TRandom3.h"
51#include "TH1F.h"
52#include "TSpline.h"
53#include "TDirectory.h"
54#include "TTreeFormula.h"
55
57#include "TMVA/Config.h"
58#include "TMVA/DataSet.h"
59#include "TMVA/DataSetInfo.h"
60#include "TMVA/DataSetManager.h"
61#include "TMVA/IMethod.h"
62#include "TMVA/MethodBase.h"
64#include "TMVA/MsgLogger.h"
65#include "TMVA/PDF.h"
66#include "TMVA/Ranking.h"
67#include "TMVA/Timer.h"
68#include "TMVA/Tools.h"
69#include "TMVA/Types.h"
70#include "TMVA/VariableInfo.h"
72
73REGISTER_METHOD(Category)
74
75
76////////////////////////////////////////////////////////////////////////////////
77/// standard constructor
78
80 const TString& methodTitle,
81 DataSetInfo& theData,
82 const TString& theOption )
83 : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
84 fCatTree(0),
85 fDataSetManager(NULL)
86{
87}
88
89////////////////////////////////////////////////////////////////////////////////
90/// constructor from weight file
91
93 const TString& theWeightFile)
94 : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
95 fCatTree(0),
96 fDataSetManager(NULL)
97{
98}
99
100////////////////////////////////////////////////////////////////////////////////
101/// destructor
102
104{
105 std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
106 std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
107 for(;formIt!=lastF; ++formIt) delete *formIt;
108 delete fCatTree;
109}
110
111////////////////////////////////////////////////////////////////////////////////
112/// check whether the category method supports the given analysis type;
113/// this is the case only if every booked sub-method supports it
114
116{
117 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
118
119 // iterate over methods and check whether they have the analysis type
120 for(; itrMethod != fMethods.end(); ++itrMethod ) {
121 if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
122 return kFALSE;
123 }
124 return kTRUE;
125}
126
127////////////////////////////////////////////////////////////////////////////////
128/// options for this method
129
133
134////////////////////////////////////////////////////////////////////////////////
135/// adds sub-classifier for a category
136
138 const TString& theVariables,
139 Types::EMVA theMethod ,
140 const TString& theTitle,
141 const TString& theOptions )
142{
143 std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
144
145 Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
146
147 DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
148
149 IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
150
151 MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
152 if(method==0) return 0;
153
157 method->SetupMethod();
158 method->ParseOptions();
159 method->ProcessSetup();
160 method->SetFile(fFile);
161 method->SetSilentFile(IsSilentFile());
162
163
164 // set or create correct method base dir for added method
165 const TString dirName = TString::Format("Method_%s",method->GetMethodTypeName().Data());
166 TDirectory * dir = BaseDir()->GetDirectory(dirName);
167 if (dir != 0) method->SetMethodBaseDir( dir );
168 else method->SetMethodBaseDir( BaseDir()->mkdir(dirName, TString::Format("Directory for all %s methods", method->GetMethodTypeName().Data())) );
169
170 // method->SetBaseDir( own base dir: check whether the Fisher dir exists, otherwise create it )
171
172 // check-for-unused-options is performed; may be overridden by derived
173 // classes
174 method->CheckSetup();
175
176 // disable writing of XML files and standalone classes for sub methods
177 method->DisableWriting( kTRUE );
178
179 // store method, cut and variable names and create cut formula
180 fMethods.push_back(method);
181 fCategoryCuts.push_back(theCut);
182 fVars.push_back(theVariables);
183
184 DataSetInfo& primaryDSI = DataInfo();
185
186 UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
187 fCategorySpecIdx.push_back(newSpectatorIndex);
188
189 primaryDSI.AddSpectator( TString::Format("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()).Data(),
190 TString::Format("%s:%s",GetName(),method->GetName()).Data(),
191 "pass", 0, 0, 'C' );
192
193 return method;
194}
195
196////////////////////////////////////////////////////////////////////////////////
197/// create a DataSetInfo object for a sub-classifier
198
200 const TString& theVariables,
201 const TString& theTitle)
202{
203 // create a new dsi with name: theTitle+"_dsi"
204 TString dsiName=theTitle+"_dsi";
205 DataSetInfo& oldDSI = DataInfo();
206 DataSetInfo* dsi = new DataSetInfo(dsiName);
207
208 // register the new dsi
209 // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
210 fDataSetManager->AddDataSetInfo(*dsi);
211
212 // copy the targets and spectators from the old dsi to the new dsi
213 std::vector<VariableInfo>::iterator itrVarInfo;
214
215 for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
216 dsi->AddTarget(*itrVarInfo);
217
218 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
219 dsi->AddSpectator(*itrVarInfo);
220
221 // split string that contains the variables into tiny little pieces
222 std::vector<TString> variables = gTools().SplitString(theVariables,':' );
223
224 // prepare to create varMap
225 std::vector<UInt_t> varMap;
226 Int_t counter=0;
227
228 // add the variables that were specified in theVariables
229 std::vector<TString>::iterator itrVariables;
230 Bool_t found = kFALSE;
231
232 // iterate over all variables in 'variables' and add them
233 for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
234 counter=0;
235
236 // check the variables of the old dsi for the variable that we want to add
237 for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
238 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
239 // don't compare the expression, since the user might use the same expression twice, but with different labels,
240 // and apply different transformations to the variables.
241 dsi->AddVariable(*itrVarInfo);
242 varMap.push_back(counter);
243 found = kTRUE;
244 }
245 counter++;
246 }
247
248 // check the spectators of the old dsi for the variable that we want to add
249 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
250 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
251 // don't compare the expression, since the user might use the same expression twice, but with different labels,
252 // and apply different transformations to the variables.
253 dsi->AddVariable(*itrVarInfo);
254 varMap.push_back(counter);
255 found = kTRUE;
256 }
257 counter++;
258 }
259
260 // if the variable is neither in the variables nor in the spectators, we abort
261 if (!found) {
262 Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
263 }
264 found = kFALSE;
265 }
266
267 // in the case that no variables are specified, add the default-variables from the original dsi
268 if (theVariables=="") {
269 for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
270 dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
271 varMap.push_back(i);
272 }
273 }
274
275 // add the variable map 'varMap' to the vector of varMaps
276 fVarMaps.push_back(varMap);
277
278 // set classes and cuts
279 UInt_t nClasses=oldDSI.GetNClasses();
280 TString className;
281
282 for (UInt_t i=0; i<nClasses; i++) {
283 className = oldDSI.GetClassInfo(i)->GetName();
284 dsi->AddClass(className);
285 dsi->SetCut(oldDSI.GetCut(i),className);
286 dsi->AddCut(theCut,className);
287 dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
288 }
289
290 // set split options, root dir and normalization for the new dsi
291 dsi->SetSplitOptions(oldDSI.GetSplitOptions());
292 dsi->SetRootDir(oldDSI.GetRootDir());
293 TString norm(oldDSI.GetNormalization().Data());
294 dsi->SetNormalization(norm);
295 // need to add split options to normalize with cut efficiency
296 TString splitOpt = dsi->GetSplitOptions();
297 splitOpt += ":ScaleWithPreselEff";
298 dsi->SetSplitOptions(splitOpt);
299
300 DataSetInfo& dsiReference= (*dsi);
301
302 return dsiReference;
303}
304
305////////////////////////////////////////////////////////////////////////////////
306/// initialize the method
307
311
312////////////////////////////////////////////////////////////////////////////////
313/// initialize the circular tree
314
316{
317 delete fCatTree;
318 fCatTree = nullptr;
319
320 std::vector<VariableInfo>::const_iterator viIt;
321 const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
322 const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
323
324 Bool_t hasAllExternalLinks = kTRUE;
325 for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
326 if( viIt->GetExternalLink() == 0 ) {
327 hasAllExternalLinks = kFALSE;
328 break;
329 }
330 for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
331 if( viIt->GetExternalLink() == 0 ) {
332 hasAllExternalLinks = kFALSE;
333 break;
334 }
335
336 if(!hasAllExternalLinks) return;
337
338 {
339 // Rather than having TTree::TTree add the tree to the current directory and then removing it,
340 // make sure it is not added in the first place.
341 // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
342 // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
343 TDirectory::TContext ctxt(nullptr);
344 fCatTree = new TTree(TString::Format("Circ%s",GetMethodName().Data()).Data(),"Circular Tree for categorization");
345 fCatTree->SetCircular(1);
346 }
347
348 for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
349 const VariableInfo& vi = *viIt;
351 }
352 for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
353 const VariableInfo& vi = *viIt;
354 if(vi.GetVarType()=='C') continue;
356 }
357
358 for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
359 fCatFormulas.push_back(new TTreeFormula(TString::Format("Category_%i",cat).Data(), fCategoryCuts[cat].GetTitle(), fCatTree));
360 }
361}
362
363////////////////////////////////////////////////////////////////////////////////
364/// train all sub-classifiers
365
367{
368 // specify the minimum number of training events and fetch the analysis type
369 const Int_t MinNoTrainingEvents = 10;
370
371 Types::EAnalysisType analysisType = GetAnalysisType();
372
373 // start the training
374 Log() << kINFO << "Train all sub-classifiers for "
375 << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
376
377 // don't do anything if no sub-classifier booked
378 if (fMethods.empty()) {
379 Log() << kINFO << "...nothing found to train" << Endl;
380 return;
381 }
382
383 std::vector<IMethod*>::iterator itrMethod;
384
385 // iterate over all booked sub-classifiers and train them
386 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
387
388 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
389 if(!mva) continue;
390 mva->SetAnalysisType( analysisType );
391 if (!mva->HasAnalysisType( analysisType,
392 mva->DataInfo().GetNClasses(),
393 mva->DataInfo().GetNTargets() ) ) {
394 Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
395 if (analysisType == Types::kRegression)
396 Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
397 else
398 Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
399 itrMethod = fMethods.erase( itrMethod );
400 continue;
401 }
403
404 Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
405 << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
406 mva->TrainMethod();
407 Log() << kINFO << "Training finished" << Endl;
408
409 } else {
410
411 Log() << kWARNING << "Method " << mva->GetMethodName()
412 << " not trained (training tree has less entries ["
413 << mva->Data()->GetNTrainingEvents()
414 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
415
416 Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
417 Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
418
419 }
420 }
421
422 if (analysisType != Types::kRegression) {
423
424 // variable ranking
425 Log() << kINFO << "Begin ranking of input variables..." << Endl;
426 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
427 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
428 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
429 const Ranking* ranking = (*itrMethod)->CreateRanking();
430 if (ranking != 0)
431 ranking->Print();
432 else
433 Log() << kINFO << "No variable ranking supplied by classifier: "
434 << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
435 }
436 }
437 }
438}
439
440////////////////////////////////////////////////////////////////////////////////
441/// create XML description of Category classifier
442
444{
445 void* wght = gTools().AddChild(parent, "Weights");
446 gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
447 void* submethod(0);
448
449 // iterate over methods and write them to XML file
450 for (UInt_t i=0; i<fMethods.size(); i++) {
451 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
452 submethod = gTools().AddChild(wght, "SubMethod");
453 gTools().AddAttr(submethod, "Index", i);
454 gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
455 gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
456 gTools().AddAttr(submethod, "Variables", fVars[i]);
457 method->WriteStateToXML( submethod );
458 }
459}
460
461////////////////////////////////////////////////////////////////////////////////
462/// read weights of sub-classifiers of MethodCategory from xml weight file
463
465{
466 UInt_t nSubMethods;
467 TString fullMethodName;
468 TString methodType;
469 TString methodTitle;
470 TString theCutString;
471 TString theVariables;
472 Int_t titleLength;
473 gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
474 void* subMethodNode = gTools().GetChild(wghtnode);
475
476 Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
477
478 // recreate all sub-methods from weight file
479 for (UInt_t i=0; i<nSubMethods; i++) {
480 gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
481 gTools().ReadAttr( subMethodNode, "Cut", theCutString );
482 gTools().ReadAttr( subMethodNode, "Variables", theVariables );
483
484 // determine sub-method type
485 methodType = fullMethodName(0,fullMethodName.Index("::"));
486 if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
487
488 // determine sub-method title
489 titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
490 methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
491
492 // reconstruct dsi for sub-method
493 DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
494
495 // recreate sub-method from weights and add to fMethods
496 MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
497 dsi, "none" ) );
498 if(method==0)
499 Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
500
501 method->SetupMethod();
502 method->ReadStateFromXML(subMethodNode);
503
504 fMethods.push_back(method);
505 fCategoryCuts.push_back(TCut(theCutString));
506 fVars.push_back(theVariables);
507
508 DataSetInfo& primaryDSI = DataInfo();
509
510 UInt_t spectatorIdx = 10000;
511 UInt_t counter=0;
512
513 // find the spectator index
514 std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
515 std::vector<VariableInfo>::iterator itrVarInfo;
516 TString specName= TString::Format("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
517
518 for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
519 if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
520 spectatorIdx=counter;
521 fCategorySpecIdx.push_back(spectatorIdx);
522 break;
523 }
524 }
525
526 subMethodNode = gTools().GetNextChild(subMethodNode);
527 }
528
530
531}
532
533////////////////////////////////////////////////////////////////////////////////
534/// process user options
535
539
540////////////////////////////////////////////////////////////////////////////////
541/// Get help message text
542///
543/// typical length of text line:
544/// "|--------------------------------------------------------------|"
545
547{
548 Log() << Endl;
549 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
550 Log() << Endl;
551 Log() << "This method allows to define different categories of events. The" <<Endl;
552 Log() << "categories are defined via cuts on the variables. For each" << Endl;
553 Log() << "category, a different classifier and set of variables can be" <<Endl;
554 Log() << "specified. The categories which are defined for this method must" << Endl;
555 Log() << "be disjoint." << Endl;
556}
557
558////////////////////////////////////////////////////////////////////////////////
559/// no variable ranking is provided for the category method itself (returns 0)
560
562{
563 return 0;
564}
565
566////////////////////////////////////////////////////////////////////////////////
567
569{
570 // if the categories are not defined by a simple 'spectator' variable (0 or 1)
571 // but rather by some 'formula' (e.g. eta>0), then these formulas are bound to fCatTree and
572 // the formula for this category is evaluated (the formulae return 'true' or 'false')
573 if (fCatTree) {
574 if (methodIdx>=fCatFormulas.size()) {
575 Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
576 << fCatFormulas.size() << Endl;
577 }
578 TTreeFormula* f = fCatFormulas[methodIdx];
579 return f->EvalInstance(0) > 0.5;
580 }
581 // otherwise, simply check whether the spectator "variable == true" (greater than 0.5, to be "sure")
582 else {
583
584 // checks whether an event lies within a cut
585 if (methodIdx>=fCategorySpecIdx.size()) {
586 Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
587 << fCategorySpecIdx.size() << Endl;
588 }
589 UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
590 Float_t specVal = ev->GetSpectator(spectatorIdx);
591 Bool_t pass = (specVal>0.5);
592 return pass;
593 }
594}
595
596////////////////////////////////////////////////////////////////////////////////
597/// returns the mva value of the right sub-classifier
598
600{
601 if (fMethods.empty()) return 0;
602
603 UInt_t methodToUse = 0;
604 const Event* ev = GetEvent();
605
606 // determine which sub-classifier to use for this event
607 Int_t suitableCutsN = 0;
608
609 for (UInt_t i=0; i<fMethods.size(); ++i) {
610 if (PassesCut(ev, i)) {
611 ++suitableCutsN;
612 methodToUse=i;
613 }
614 }
615
616 if (suitableCutsN == 0) {
617 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
618 return 0;
619 }
620
621 if (suitableCutsN > 1) {
622 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
623 return 0;
624 }
625
626 // get mva value from the suitable sub-classifier
627 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
628 Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
630
631 Log() << kDEBUG << "Event is for method " << methodToUse << " spectator is " << ev->GetSpectator(0) << " "
632 << fVarMaps[0][0] << " classID " << DataInfo().IsSignal(ev) << " value " << mvaValue
633 << " type " << Data()->GetCurrentType() << Endl;
634
635 return mvaValue;
636}
637
638///////////////////////////////////////////////////////////////
639/// returns the mva values of the right sub-classifier
640///
641std::vector<Double_t>
643{
644
645 std::vector<Double_t> result;
646
647 Info("GetMVaValues", "Evaluate MethodCategory for %d events type %d on the dataset %s", int(lastEvt - firstEvt),
648 (int)Data()->GetCurrentType(), DataInfo().GetName());
649
650 if (fMethods.empty())
651 return result;
652
653 auto data = Data();
654
655 // it is faster to evaluate all categories
656 std::vector<std::vector<Double_t>> mvaValues(fMethods.size());
657 for (UInt_t i = 0; i < fMethods.size(); ++i) {
658 // need to set variable map
659 for (UInt_t iev = firstEvt; iev < lastEvt; ++iev) {
660 data->SetCurrentEvent(iev);
661 const Event *ev = GetEvent(data->GetEvent());
663 }
664 // need to set correct data in the different method
665 mvaValues[i] = dynamic_cast<MethodBase *>(fMethods[i])->GetDataMvaValues(data,firstEvt, lastEvt, logProgress);
666 }
667
668 // now loop on all events
669 result.resize(lastEvt - firstEvt);
670
671 for (UInt_t iev = firstEvt; iev < lastEvt; ++iev)
672 {
673 data->SetCurrentEvent(iev);
674 UInt_t methodToUse = 0;
675 const Event *ev = GetEvent(data->GetEvent());
676
677 // determine which sub-classifier to use for this event
678 Int_t suitableCutsN = 0;
679
680 for (UInt_t i = 0; i < fMethods.size(); ++i) {
681 if (PassesCut(ev, i)) {
682 ++suitableCutsN;
683 methodToUse = i;
684 }
685 }
686
687 if (suitableCutsN == 0) {
688 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
689 result[iev] = 0;
690 }
691
692 if (suitableCutsN > 1) {
693 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
694 return result;
695 }
696
697
698 result[iev - firstEvt] = mvaValues[methodToUse][iev - firstEvt];
699
700 // reset the variable map that was set before
701 ev->SetVariableArrangement(nullptr);
702 }
703 return result;
704}
705
706////////////////////////////////////////////////////////////////////////////////
707/// Returns the multi-class values of the sub-classifier selected for the current event.
708///
/// The event returned by GetEvent() is tested with PassesCut() against every entry of
/// fMethods; the last index that passes is the one used. A kWARNING message is logged
/// when no category passes and a kFATAL message when more than one passes. The event's
/// variable arrangement is pointed at fVarMaps[methodToUse] while the sub-method's
/// GetMulticlassValues() is called, and is reset to nullptr afterwards.
///
/// \return reference to the vector returned by the selected sub-method
///
/// NOTE(review): the embedded line numbers of this listing skip 712, 729, 734 and 739,
/// so the statements that followed the empty-check and the three log messages are not
/// visible here - confirm against the original source file.
709const std::vector<Float_t> &TMVA::MethodCategory::GetMulticlassValues()
710{
711 if (fMethods.empty())
 // NOTE(review): the statement controlled by this `if` is not visible in this listing
 // (embedded line 712 is absent) - presumably an early return; confirm.
713
714 UInt_t methodToUse = 0;
715 const Event *ev = GetEvent();
716
717 // count how many category cuts the event passes and remember the last matching index
718 Int_t suitableCutsN = 0;
719
720 for (UInt_t i = 0; i < fMethods.size(); ++i) {
721 if (PassesCut(ev, i)) {
722 ++suitableCutsN;
723 methodToUse = i;
724 }
725 }
726
 // no category matched: only the warning is visible here (embedded line 729 is absent)
727 if (suitableCutsN == 0) {
728 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
730 }
731
 // more than one category matched: the categories overlap (embedded line 734 is absent)
732 if (suitableCutsN > 1) {
733 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
735 }
 // fMethods stores IMethod*; the MethodBase interface is needed for the calls below
736 MethodBase *meth = dynamic_cast<MethodBase *>(fMethods[methodToUse]);
737 if (!meth) {
 // NOTE(review): the message text says "Regression" although this is the multi-class
 // accessor - looks copied from GetRegressionValues(); confirm whether intended.
738 Log() << kFATAL << "method not found in Category Regression method" << Endl;
740 }
741 // get the multi-class values from the selected sub-classifier, using its variable map
742 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
743 auto &result = meth->GetMulticlassValues();
 // restore the default variable arrangement of the event before returning
744 ev->SetVariableArrangement(nullptr);
745 return result;
746}
747
748////////////////////////////////////////////////////////////////////////////////
749/// Returns the regression values of the sub-classifier selected for the current event.
///
/// If no sub-method is booked, the call is forwarded to MethodBase::GetRegressionValues().
/// Otherwise the event returned by GetEvent() is tested with PassesCut() against every
/// entry of fMethods; the last index that passes is the one used. A kWARNING message is
/// logged when no category passes and a kFATAL message when more than one passes. The
/// event's variable arrangement is pointed at fVarMaps[methodToUse] before the
/// sub-method's GetRegressionValues(ev) is called.
///
/// \return reference to the vector returned by the selected sub-method
///
/// NOTE(review): the embedded line numbers of this listing skip 770, 775 and 780, so the
/// statements that followed the three log messages are not visible here - confirm
/// against the original source file.
750
751const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
752{
 // nothing booked: fall back to the base-class implementation
753 if (fMethods.empty()) return MethodBase::GetRegressionValues();
754
755 UInt_t methodToUse = 0;
756 const Event* ev = GetEvent();
757
758 // count how many category cuts the event passes and remember the last matching index
759 Int_t suitableCutsN = 0;
760
761 for (UInt_t i=0; i<fMethods.size(); ++i) {
762 if (PassesCut(ev, i)) {
763 ++suitableCutsN;
764 methodToUse=i;
765 }
766 }
767
 // no category matched: only the warning is visible here (embedded line 770 is absent)
768 if (suitableCutsN == 0) {
769 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
771 }
772
 // more than one category matched: the categories overlap (embedded line 775 is absent)
773 if (suitableCutsN > 1) {
774 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
776 }
 // fMethods stores IMethod*; the MethodBase interface is needed for the call below
777 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
778 if (!meth){
779 Log() << kFATAL << "method not found in Category Regression method" << Endl;
781 }
782 // get the regression values from the selected sub-classifier, using its variable map
783 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
784 auto & result = meth->GetRegressionValues(ev);
 // NOTE(review): unlike GetMulticlassValues() in this file, the variable arrangement is
 // not reset to nullptr before returning - confirm whether that is intentional.
785 return result;
786}
#define MinNoTrainingEvents
#define REGISTER_METHOD(CLASS)
for example
#define f(i)
Definition RSha256.hxx:104
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
int Int_t
Signed integer 4 bytes (int).
Definition RtypesCore.h:59
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int).
Definition RtypesCore.h:60
bool Bool_t
Boolean (0=false, 1=true) (bool).
Definition RtypesCore.h:77
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
double Double_t
Double 8 bytes.
Definition RtypesCore.h:73
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
float Float_t
Float 4 bytes (float).
Definition RtypesCore.h:71
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
Definition TError.cxx:241
Double_t err
A specialized string object used for TTree selections.
Definition TCut.h:25
TDirectory::TContext keeps track and restore the current directory.
Definition TDirectory.h:89
Describe directory structure in memory.
Definition TDirectory.h:45
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
virtual void ParseOptions()
options parser
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
const TString GetWeightExpression(Int_t i) const
std::vector< VariableInfo > & GetVariableInfos()
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=nullptr)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
void SetSplitOptions(const TString &so)
ClassInfo * AddClass(const TString &className)
const TString & GetNormalization() const
std::vector< VariableInfo > & GetSpectatorInfos()
TDirectory * GetRootDir() const
void SetNormalization(const TString &norm)
UInt_t GetNClasses() const
const TString & GetSplitOptions() const
UInt_t GetNTargets() const
ClassInfo * GetClassInfo(Int_t clNum) const
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=nullptr)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
const TCut & GetCut(Int_t i) const
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
std::vector< VariableInfo > & GetTargetInfos()
void SetRootDir(TDirectory *d)
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=nullptr)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
Long64_t GetNTrainingEvents() const
Definition DataSet.h:68
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition Event.cxx:191
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition Event.cxx:261
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual const std::vector< Float_t > & GetRegressionValues()
Definition MethodBase.h:224
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition MethodBase.h:217
TString fFileDir
unix sub-directory for weight files (default: DataLoader's Name + "weights")
Definition MethodBase.h:640
const char * GetName() const override
Definition MethodBase.h:337
void SetSilentFile(Bool_t status)
Definition MethodBase.h:381
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
TString GetMethodTypeName() const
Definition MethodBase.h:335
void DisableWriting(Bool_t setter)
Definition MethodBase.h:445
virtual const std::vector< Float_t > & GetMulticlassValues()
Definition MethodBase.h:230
Types::EAnalysisType GetAnalysisType() const
Definition MethodBase.h:440
const TString & GetJobName() const
Definition MethodBase.h:333
Bool_t fModelPersistence
Definition MethodBase.h:636
virtual std::vector< Double_t > GetDataMvaValues(DataSet *data=nullptr, Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the given Data type
void SetupMethod()
setup of methods
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition MethodBase.h:439
const TString & GetMethodName() const
Definition MethodBase.h:334
const Event * GetEvent() const
Definition MethodBase.h:754
void ProcessSetup()
process all options; the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
DataSetInfo & DataInfo() const
Definition MethodBase.h:413
Types::EAnalysisType fAnalysisType
Definition MethodBase.h:598
Bool_t IsSilentFile() const
Definition MethodBase.h:382
void SetFile(TFile *file)
Definition MethodBase.h:378
void ReadStateFromXML(void *parent)
friend class MethodCategory
Definition MethodBase.h:272
void SetMethodBaseDir(TDirectory *methodDir)
Definition MethodBase.h:377
DataSet * Data() const
Definition MethodBase.h:412
void SetModelPersistence(Bool_t status)
Definition MethodBase.h:385
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
void DeclareOptions() override
options for this method
void ReadWeightsFromXML(void *wghtnode) override
read weights of sub-classifiers of MethodCategory from xml weight file
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
void Train(void) override
train all sub-classifiers
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t) override
check whether method category has analysis type; the method type has to be the same for all sub-method...
MethodCategory(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
standard constructor
const Ranking * CreateRanking() override
no ranking
std::vector< std::vector< UInt_t > > fVarMaps
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
std::vector< IMethod * > fMethods
DataSetManager * fDataSetManager
std::vector< UInt_t > fCategorySpecIdx
std::vector< TString > fVars
TTree * fCatTree
! needed in conjunction with TTreeFormulas for evaluating category expressions
const std::vector< Float_t > & GetRegressionValues() override
returns the mva value of the right sub-classifier
void Init() override
initialize the method
virtual ~MethodCategory(void)
destructor
void AddWeightsXMLTo(void *parent) const override
create XML description of Category classifier
void ProcessOptions() override
process user options
std::vector< TTreeFormula * > fCatFormulas
std::vector< TCut > fCategoryCuts
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
void GetHelpMessage() const override
Get help message text.
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr) override
returns the mva value of the right sub-classifier
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false) override
returns the mva values of the right sub-classifier
const std::vector< Float_t > & GetMulticlassValues() override
returns the multi-class mva values of the right sub-classifier
MethodCompositeBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Ranking for variables in method (implementation).
Definition Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition Ranking.cxx:110
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition Tools.cxx:1174
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:803
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1125
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1099
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1137
Singleton class for Global types used by TMVA.
Definition Types.h:71
static Types & Instance()
Returns the single instance of "Types" if it already exists, or creates it (Singleton).
Definition Types.cxx:70
@ kRegression
Definition Types.h:128
Class for type info of MVA input variable.
const TString & GetExpression() const
char GetVarType() const
void * GetExternalLink() const
const char * GetName() const override
Returns name of object.
Definition TNamed.h:49
const char * GetTitle() const override
Returns title of object.
Definition TNamed.h:50
Basic string class.
Definition TString.h:138
Ssiz_t Length() const
Definition TString.h:425
const char * Data() const
Definition TString.h:384
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition TString.cxx:938
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2385
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition TString.h:641
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition TString.h:660
Used to pass a selection expression to the Tree drawing routine.
A TTree represents a columnar dataset.
Definition TTree.h:89
create variable transformations
Tools & gTools()
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148