Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodCategory.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCategory *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Class for categorizing the phase space *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16 * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * MSU East Lansing, USA *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (http://tmva.sourceforge.net/LICENSE) *
31 **********************************************************************************/
32
33/*! \class TMVA::MethodCategory
34\ingroup TMVA
35
36Class for categorizing the phase space
37
38This class is meant to allow categorisation of the data. For different
39categories, different classifiers may be booked and different variables
40may be considered. The aim is to account for the difference that
41is due to different locations/angles.
42*/
43
44
45#include "TMVA/MethodCategory.h"
46
47#include <algorithm>
48#include <vector>
49
50#include "TRandom3.h"
51#include "TH1F.h"
52#include "TSpline.h"
53#include "TDirectory.h"
54#include "TTreeFormula.h"
55
57#include "TMVA/Config.h"
58#include "TMVA/DataSet.h"
59#include "TMVA/DataSetInfo.h"
60#include "TMVA/DataSetManager.h"
61#include "TMVA/IMethod.h"
62#include "TMVA/MethodBase.h"
64#include "TMVA/MsgLogger.h"
65#include "TMVA/PDF.h"
66#include "TMVA/Ranking.h"
67#include "TMVA/Timer.h"
68#include "TMVA/Tools.h"
69#include "TMVA/Types.h"
70#include "TMVA/VariableInfo.h"
72
73REGISTER_METHOD(Category)
74
76
77////////////////////////////////////////////////////////////////////////////////
78/// standard constructor
79
81 const TString& methodTitle,
82 DataSetInfo& theData,
83 const TString& theOption )
84 : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
85 fCatTree(0),
86 fDataSetManager(NULL)
87{
88}
89
90////////////////////////////////////////////////////////////////////////////////
91/// constructor from weight file
92
94 const TString& theWeightFile)
95 : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
96 fCatTree(0),
97 fDataSetManager(NULL)
98{
99}
100
101////////////////////////////////////////////////////////////////////////////////
102/// destructor
103
105{
106 std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
107 std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
108 for(;formIt!=lastF; ++formIt) delete *formIt;
109 delete fCatTree;
110}
111
112////////////////////////////////////////////////////////////////////////////////
113/// check whether method category has analysis type
114/// the method type has to be the same for all sub-methods
115
117{
118 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
119
120 // iterate over methods and check whether they have the analysis type
121 for(; itrMethod != fMethods.end(); ++itrMethod ) {
122 if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
123 return kFALSE;
124 }
125 return kTRUE;
126}
127
128////////////////////////////////////////////////////////////////////////////////
129/// options for this method
130
132{
133}
134
135////////////////////////////////////////////////////////////////////////////////
136/// adds sub-classifier for a category
137
139 const TString& theVariables,
140 Types::EMVA theMethod ,
141 const TString& theTitle,
142 const TString& theOptions )
143{
144 std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
145
146 Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
147
148 DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
149
150 IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
151
152 MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
153 if(method==0) return 0;
154
155 if(fModelPersistence) method->SetWeightFileDir(fFileDir);
156 method->SetModelPersistence(fModelPersistence);
157 method->SetAnalysisType( fAnalysisType );
158 method->SetupMethod();
159 method->ParseOptions();
160 method->ProcessSetup();
161 method->SetFile(fFile);
162 method->SetSilentFile(IsSilentFile());
163
164
165 // set or create correct method base dir for added method
166 const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
167 TDirectory * dir = BaseDir()->GetDirectory(dirName);
168 if (dir != 0) method->SetMethodBaseDir( dir );
169 else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data())) );
170
171 // method->SetBaseDir(own base dir; check whether the Fisher dir exists, otherwise create it)
172
173 // check-for-unused-options is performed; may be overridden by derived
174 // classes
175 method->CheckSetup();
176
177 // disable writing of XML files and standalone classes for sub methods
178 method->DisableWriting( kTRUE );
179
180 // store method, cut and variable names and create cut formula
181 fMethods.push_back(method);
182 fCategoryCuts.push_back(theCut);
183 fVars.push_back(theVariables);
184
185 DataSetInfo& primaryDSI = DataInfo();
186
187 UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
188 fCategorySpecIdx.push_back(newSpectatorIndex);
189
190 primaryDSI.AddSpectator( Form("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()),
191 Form("%s:%s",GetName(),method->GetName()),
192 "pass", 0, 0, 'C' );
193
194 return method;
195}
196
197////////////////////////////////////////////////////////////////////////////////
198/// create a DataSetInfo object for a sub-classifier
199
201 const TString& theVariables,
202 const TString& theTitle)
203{
204 // create a new dsi with name: theTitle+"_dsi"
205 TString dsiName=theTitle+"_dsi";
206 DataSetInfo& oldDSI = DataInfo();
207 DataSetInfo* dsi = new DataSetInfo(dsiName);
208
209 // register the new dsi
210 // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
211 fDataSetManager->AddDataSetInfo(*dsi);
212
213 // copy the targets and spectators from the old dsi to the new dsi
214 std::vector<VariableInfo>::iterator itrVarInfo;
215
216 for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
217 dsi->AddTarget(*itrVarInfo);
218
219 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
220 dsi->AddSpectator(*itrVarInfo);
221
222 // split string that contains the variables into tiny little pieces
223 std::vector<TString> variables = gTools().SplitString(theVariables,':' );
224
225 // prepare to create varMap
226 std::vector<UInt_t> varMap;
227 Int_t counter=0;
228
229 // add the variables that were specified in theVariables
230 std::vector<TString>::iterator itrVariables;
231 Bool_t found = kFALSE;
232
233 // iterate over all variables in 'variables' and add them
234 for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
235 counter=0;
236
237 // check the variables of the old dsi for the variable that we want to add
238 for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
239 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
240 // don't compare the expression, since the user might take two times the same expression, but with different labels
241 // and apply different transformations to the variables.
242 dsi->AddVariable(*itrVarInfo);
243 varMap.push_back(counter);
244 found = kTRUE;
245 }
246 counter++;
247 }
248
249 // check the spectators of the old dsi for the variable that we want to add
250 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
251 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
252 // don't compare the expression, since the user might take two times the same expression, but with different labels
253 // and apply different transformations to the variables.
254 dsi->AddVariable(*itrVarInfo);
255 varMap.push_back(counter);
256 found = kTRUE;
257 }
258 counter++;
259 }
260
261 // if the variable is neither in the variables nor in the spectators, we abort
262 if (!found) {
263 Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
264 }
265 found = kFALSE;
266 }
267
268 // in the case that no variables are specified, add the default-variables from the original dsi
269 if (theVariables=="") {
270 for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
271 dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
272 varMap.push_back(i);
273 }
274 }
275
276 // add the variable map 'varMap' to the vector of varMaps
277 fVarMaps.push_back(varMap);
278
279 // set classes and cuts
280 UInt_t nClasses=oldDSI.GetNClasses();
281 TString className;
282
283 for (UInt_t i=0; i<nClasses; i++) {
284 className = oldDSI.GetClassInfo(i)->GetName();
285 dsi->AddClass(className);
286 dsi->SetCut(oldDSI.GetCut(i),className);
287 dsi->AddCut(theCut,className);
288 dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
289 }
290
291 // set split options, root dir and normalization for the new dsi
292 dsi->SetSplitOptions(oldDSI.GetSplitOptions());
293 dsi->SetRootDir(oldDSI.GetRootDir());
294 TString norm(oldDSI.GetNormalization().Data());
295 dsi->SetNormalization(norm);
296
297 DataSetInfo& dsiReference= (*dsi);
298
299 return dsiReference;
300}
301
302////////////////////////////////////////////////////////////////////////////////
303/// initialize the method
304
306{
307}
308
309////////////////////////////////////////////////////////////////////////////////
310/// initialize the circular tree
311
313{
314 delete fCatTree;
315
316 std::vector<VariableInfo>::const_iterator viIt;
317 const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
318 const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
319
320 Bool_t hasAllExternalLinks = kTRUE;
321 for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
322 if( viIt->GetExternalLink() == 0 ) {
323 hasAllExternalLinks = kFALSE;
324 break;
325 }
326 for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
327 if( viIt->GetExternalLink() == 0 ) {
328 hasAllExternalLinks = kFALSE;
329 break;
330 }
331
332 if(!hasAllExternalLinks) return;
333
334 {
335 // Rather than having TTree::TTree add to the current directory and then remove it, let's
336 // make sure to not add it in the first place.
337 // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
338 // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
339 TDirectory::TContext ctxt(nullptr);
340 fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circular Tree for categorization");
341 fCatTree->SetCircular(1);
342 }
343
344 for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
345 const VariableInfo& vi = *viIt;
346 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
347 }
348 for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
349 const VariableInfo& vi = *viIt;
350 if(vi.GetVarType()=='C') continue;
351 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
352 }
353
354 for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
355 fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
356 }
357}
358
359////////////////////////////////////////////////////////////////////////////////
360/// train all sub-classifiers
361
363{
364 // specify the minimum # of training events and set 'classification'
365 const Int_t MinNoTrainingEvents = 10;
366
367 Types::EAnalysisType analysisType = GetAnalysisType();
368
369 // start the training
370 Log() << kINFO << "Train all sub-classifiers for "
371 << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
372
373 // don't do anything if no sub-classifier booked
374 if (fMethods.empty()) {
375 Log() << kINFO << "...nothing found to train" << Endl;
376 return;
377 }
378
379 std::vector<IMethod*>::iterator itrMethod;
380
381 // iterate over all booked sub-classifiers and train them
382 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
383
384 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
385 if(!mva) continue;
386 mva->SetAnalysisType( analysisType );
387 if (!mva->HasAnalysisType( analysisType,
388 mva->DataInfo().GetNClasses(),
389 mva->DataInfo().GetNTargets() ) ) {
390 Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
391 if (analysisType == Types::kRegression)
392 Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
393 else
394 Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
395 itrMethod = fMethods.erase( itrMethod );
396 continue;
397 }
399
400 Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
401 << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
402 mva->TrainMethod();
403 Log() << kINFO << "Training finished" << Endl;
404
405 } else {
406
407 Log() << kWARNING << "Method " << mva->GetMethodName()
408 << " not trained (training tree has less entries ["
409 << mva->Data()->GetNTrainingEvents()
410 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
411
412 Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
413 Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
414
415 }
416 }
417
418 if (analysisType != Types::kRegression) {
419
420 // variable ranking
421 Log() << kINFO << "Begin ranking of input variables..." << Endl;
422 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
423 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
424 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
425 const Ranking* ranking = (*itrMethod)->CreateRanking();
426 if (ranking != 0)
427 ranking->Print();
428 else
429 Log() << kINFO << "No variable ranking supplied by classifier: "
430 << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
431 }
432 }
433 }
434}
435
436////////////////////////////////////////////////////////////////////////////////
437/// create XML description of Category classifier
438
440{
441 void* wght = gTools().AddChild(parent, "Weights");
442 gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
443 void* submethod(0);
444
445 // iterate over methods and write them to XML file
446 for (UInt_t i=0; i<fMethods.size(); i++) {
447 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
448 submethod = gTools().AddChild(wght, "SubMethod");
449 gTools().AddAttr(submethod, "Index", i);
450 gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
451 gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
452 gTools().AddAttr(submethod, "Variables", fVars[i]);
453 method->WriteStateToXML( submethod );
454 }
455}
456
457////////////////////////////////////////////////////////////////////////////////
458/// read weights of sub-classifiers of MethodCategory from xml weight file
459
461{
462 UInt_t nSubMethods;
463 TString fullMethodName;
464 TString methodType;
465 TString methodTitle;
466 TString theCutString;
467 TString theVariables;
468 Int_t titleLength;
469 gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
470 void* subMethodNode = gTools().GetChild(wghtnode);
471
472 Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
473
474 // recreate all sub-methods from weight file
475 for (UInt_t i=0; i<nSubMethods; i++) {
476 gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
477 gTools().ReadAttr( subMethodNode, "Cut", theCutString );
478 gTools().ReadAttr( subMethodNode, "Variables", theVariables );
479
480 // determine sub-method type
481 methodType = fullMethodName(0,fullMethodName.Index("::"));
482 if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
483
484 // determine sub-method title
485 titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
486 methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
487
488 // reconstruct dsi for sub-method
489 DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
490
491 // recreate sub-method from weights and add to fMethods
492 MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
493 dsi, "none" ) );
494 if(method==0)
495 Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
496
497 method->SetupMethod();
498 method->ReadStateFromXML(subMethodNode);
499
500 fMethods.push_back(method);
501 fCategoryCuts.push_back(TCut(theCutString));
502 fVars.push_back(theVariables);
503
504 DataSetInfo& primaryDSI = DataInfo();
505
506 UInt_t spectatorIdx = 10000;
507 UInt_t counter=0;
508
509 // find the spectator index
510 std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
511 std::vector<VariableInfo>::iterator itrVarInfo;
512 TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
513
514 for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
515 if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
516 spectatorIdx=counter;
517 fCategorySpecIdx.push_back(spectatorIdx);
518 break;
519 }
520 }
521
522 subMethodNode = gTools().GetNextChild(subMethodNode);
523 }
524
525 InitCircularTree(DataInfo());
526
527}
528
529////////////////////////////////////////////////////////////////////////////////
530/// process user options
531
533{
534}
535
536////////////////////////////////////////////////////////////////////////////////
537/// Get help message text
538///
539/// typical length of text line:
540/// "|--------------------------------------------------------------|"
541
543{
544 Log() << Endl;
545 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
546 Log() << Endl;
547 Log() << "This method allows to define different categories of events. The" <<Endl;
548 Log() << "categories are defined via cuts on the variables. For each" << Endl;
549 Log() << "category, a different classifier and set of variables can be" <<Endl;
550 Log() << "specified. The categories which are defined for this method must" << Endl;
551 Log() << "be disjoint." << Endl;
552}
553
554////////////////////////////////////////////////////////////////////////////////
555/// no ranking
556
558{
559 return 0;
560}
561
562////////////////////////////////////////////////////////////////////////////////
563
565{
566 // if the categories are not defined by a simple 'spectator' variable (0 or 1)
567 // but rather by some 'formula' (e.g. eta>0), then these formulas are stored in fCatTree and
568 // evaluated there (the formulas return 'true' or 'false')
569 if (fCatTree) {
570 if (methodIdx>=fCatFormulas.size()) {
571 Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
572 << fCatFormulas.size() << Endl;
573 }
574 TTreeFormula* f = fCatFormulas[methodIdx];
575 return f->EvalInstance(0) > 0.5;
576 }
577 // otherwise, it simply checks whether "variable == true" (greater than 0.5, to be "sure")
578 else {
579
580 // checks whether an event lies within a cut
581 if (methodIdx>=fCategorySpecIdx.size()) {
582 Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
583 << fCategorySpecIdx.size() << Endl;
584 }
585 UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
586 Float_t specVal = ev->GetSpectator(spectatorIdx);
587 Bool_t pass = (specVal>0.5);
588 return pass;
589 }
590}
591
592////////////////////////////////////////////////////////////////////////////////
593/// returns the mva value of the right sub-classifier
594
596{
597 if (fMethods.empty()) return 0;
598
599 UInt_t methodToUse = 0;
600 const Event* ev = GetEvent();
601
602 // determine which sub-classifier to use for this event
603 Int_t suitableCutsN = 0;
604
605 for (UInt_t i=0; i<fMethods.size(); ++i) {
606 if (PassesCut(ev, i)) {
607 ++suitableCutsN;
608 methodToUse=i;
609 }
610 }
611
612 if (suitableCutsN == 0) {
613 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
614 return 0;
615 }
616
617 if (suitableCutsN > 1) {
618 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
619 return 0;
620 }
621
622 // get mva value from the suitable sub-classifier
623 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
624 Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
626
627 return mvaValue;
628}
629
630////////////////////////////////////////////////////////////////////////////////
631/// returns the mva values of the multi-class right sub-classifier
632///
633const std::vector<Float_t> &TMVA::MethodCategory::GetMulticlassValues()
634{
635 if (fMethods.empty())
637
638 UInt_t methodToUse = 0;
639 const Event *ev = GetEvent();
640
641 // determine which sub-classifier to use for this event
642 Int_t suitableCutsN = 0;
643
644 for (UInt_t i = 0; i < fMethods.size(); ++i) {
645 if (PassesCut(ev, i)) {
646 ++suitableCutsN;
647 methodToUse = i;
648 }
649 }
650
651 if (suitableCutsN == 0) {
652 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
654 }
655
656 if (suitableCutsN > 1) {
657 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
659 }
660 MethodBase *meth = dynamic_cast<MethodBase *>(fMethods[methodToUse]);
661 if (!meth) {
662 Log() << kFATAL << "method not found in Category Regression method" << Endl;
664 }
665 // get mva value from the suitable sub-classifier
666 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
667 auto & result = meth->GetMulticlassValues();
668 ev->SetVariableArrangement(nullptr);
669 return result;
670}
671
672////////////////////////////////////////////////////////////////////////////////
673/// returns the mva value of the right sub-classifier
674
675const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
676{
677 if (fMethods.empty()) return MethodBase::GetRegressionValues();
678
679 UInt_t methodToUse = 0;
680 const Event* ev = GetEvent();
681
682 // determine which sub-classifier to use for this event
683 Int_t suitableCutsN = 0;
684
685 for (UInt_t i=0; i<fMethods.size(); ++i) {
686 if (PassesCut(ev, i)) {
687 ++suitableCutsN;
688 methodToUse=i;
689 }
690 }
691
692 if (suitableCutsN == 0) {
693 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
695 }
696
697 if (suitableCutsN > 1) {
698 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
700 }
701 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
702 if (!meth){
703 Log() << kFATAL << "method not found in Category Regression method" << Endl;
705 }
706 // get mva value from the suitable sub-classifier
707 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
708 auto & result = meth->GetRegressionValues(ev);
709 return result;
710}
#define MinNoTrainingEvents
#define REGISTER_METHOD(CLASS)
for example
#define f(i)
Definition RSha256.hxx:104
const Bool_t kFALSE
Definition RtypesCore.h:92
double Double_t
Definition RtypesCore.h:59
float Float_t
Definition RtypesCore.h:57
const Bool_t kTRUE
Definition RtypesCore.h:91
#define ClassImp(name)
Definition Rtypes.h:364
int type
Definition TGX11.cxx:121
char * Form(const char *fmt,...)
A specialized string object used for TTree selections.
Definition TCut.h:25
Small helper to keep current directory context.
Definition TDirectory.h:52
Describe directory structure in memory.
Definition TDirectory.h:45
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
virtual void ParseOptions()
options parser
Class that contains all the data information.
Definition DataSetInfo.h:62
const TString GetWeightExpression(Int_t i) const
std::vector< VariableInfo > & GetVariableInfos()
void SetSplitOptions(const TString &so)
ClassInfo * AddClass(const TString &className)
const TString & GetNormalization() const
std::vector< VariableInfo > & GetSpectatorInfos()
TDirectory * GetRootDir() const
void SetNormalization(const TString &norm)
UInt_t GetNClasses() const
const TString & GetSplitOptions() const
UInt_t GetNTargets() const
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
ClassInfo * GetClassInfo(Int_t clNum) const
const TCut & GetCut(Int_t i) const
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
std::vector< VariableInfo > & GetTargetInfos()
void SetRootDir(TDirectory *d)
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
Long64_t GetNTrainingEvents() const
Definition DataSet.h:68
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition Event.cxx:191
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition Event.cxx:261
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual const std::vector< Float_t > & GetRegressionValues()
Definition MethodBase.h:220
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition MethodBase.h:213
void SetSilentFile(Bool_t status)
Definition MethodBase.h:377
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
TString GetMethodTypeName() const
Definition MethodBase.h:331
void DisableWriting(Bool_t setter)
Definition MethodBase.h:442
const char * GetName() const
Definition MethodBase.h:333
virtual const std::vector< Float_t > & GetMulticlassValues()
Definition MethodBase.h:226
void SetupMethod()
setup of methods
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition MethodBase.h:436
const TString & GetMethodName() const
Definition MethodBase.h:330
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
DataSetInfo & DataInfo() const
Definition MethodBase.h:409
void SetFile(TFile *file)
Definition MethodBase.h:374
void ReadStateFromXML(void *parent)
friend class MethodCategory
Definition MethodBase.h:268
void SetMethodBaseDir(TDirectory *methodDir)
Definition MethodBase.h:373
DataSet * Data() const
Definition MethodBase.h:408
void SetModelPersistence(Bool_t status)
Definition MethodBase.h:381
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Class for categorizing the phase space.
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
void GetHelpMessage() const
Get help message text.
void Init()
initialize the method
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
void ProcessOptions()
process user options
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns the mva value of the right sub-classifier
virtual const std::vector< Float_t > & GetMulticlassValues()
returns the mva values of the multi-class right sub-classifier
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
void DeclareOptions()
options for this method
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
const Ranking * CreateRanking()
no ranking
virtual ~MethodCategory(void)
destructor
virtual const std::vector< Float_t > & GetRegressionValues()
returns the mva value of the right sub-classifier
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
void Train(void)
train all sub-classifiers
Virtual base class for combining several TMVA methods.
Ranking for variables in method (implementation)
Definition Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition Ranking.cxx:111
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition Tools.cxx:1174
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition Tools.cxx:1136
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition Tools.cxx:1211
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:840
void * GetChild(void *parent, const char *childname=0)
get child node
Definition Tools.cxx:1162
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:353
Singleton class for Global types used by TMVA.
Definition Types.h:73
static Types & Instance()
Returns the single instance of "Types" if it already exists, otherwise creates it (Singleton)
Definition Types.cxx:69
@ kRegression
Definition Types.h:130
Class for type info of MVA input variable.
const TString & GetExpression() const
char GetVarType() const
void * GetExternalLink() const
virtual const char * GetTitle() const
Returns title of object.
Definition TNamed.h:48
virtual const char * GetName() const
Returns name of object.
Definition TNamed.h:47
Basic string class.
Definition TString.h:136
Ssiz_t Length() const
Definition TString.h:410
const char * Data() const
Definition TString.h:369
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition TString.cxx:912
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition TString.h:624
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition TString.h:639
Used to pass a selection expression to the Tree drawing routine.
T EvalInstance(Int_t i=0, const char *stringStack[]=0)
Evaluate this treeformula.
A TTree represents a columnar dataset.
Definition TTree.h:79
create variable transformations
Tools & gTools()
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:158