Logo ROOT  
Reference Guide
MethodCategory.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCategory *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Class for categorizing the phase space (one sub-classifier per category) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16 * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * MSU East Lansing, USA *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (http://tmva.sourceforge.net/LICENSE) *
31 **********************************************************************************/
32
33/*! \class TMVA::MethodCategory
34\ingroup TMVA
35
36Class for categorizing the phase space.
37
38This class allows the data to be split into categories. For each
39category, a different classifier may be booked and a different set of
40variables may be used, in order to account for differences in the data
41that arise from different locations/angles. The categories must be disjoint.
42*/
43
44
45#include "TMVA/MethodCategory.h"
46
47#include <algorithm>
48#include <vector>
49
50#include "TRandom3.h"
51#include "TH1F.h"
52#include "TSpline.h"
53#include "TDirectory.h"
54#include "TTreeFormula.h"
55
57#include "TMVA/Config.h"
58#include "TMVA/DataSet.h"
59#include "TMVA/DataSetInfo.h"
60#include "TMVA/DataSetManager.h"
61#include "TMVA/IMethod.h"
62#include "TMVA/MethodBase.h"
64#include "TMVA/MsgLogger.h"
65#include "TMVA/PDF.h"
66#include "TMVA/Ranking.h"
67#include "TMVA/Timer.h"
68#include "TMVA/Tools.h"
69#include "TMVA/Types.h"
70#include "TMVA/VariableInfo.h"
72
73REGISTER_METHOD(Category)
74
76
77////////////////////////////////////////////////////////////////////////////////
78/// standard constructor
79
81 const TString& methodTitle,
82 DataSetInfo& theData,
83 const TString& theOption )
84 : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
85 fCatTree(0),
86 fDataSetManager(NULL)
87{
88}
89
90////////////////////////////////////////////////////////////////////////////////
91/// constructor from weight file
92
94 const TString& theWeightFile)
95 : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
96 fCatTree(0),
97 fDataSetManager(NULL)
98{
99}
100
101////////////////////////////////////////////////////////////////////////////////
102/// destructor
103
105{
106 std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
107 std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
108 for(;formIt!=lastF; ++formIt) delete *formIt;
109 delete fCatTree;
110}
111
112////////////////////////////////////////////////////////////////////////////////
113/// check whether the category method supports the given analysis type;
114/// it does only if every booked sub-method supports it
115
117{
118 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
119
120 // iterate over methods and check whether they have the analysis type
121 for(; itrMethod != fMethods.end(); ++itrMethod ) {
122 if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
123 return kFALSE;
124 }
125 return kTRUE;
126}
127
128////////////////////////////////////////////////////////////////////////////////
129/// options for this method
130
132{
133}
134
135////////////////////////////////////////////////////////////////////////////////
136/// adds sub-classifier for a category
137
139 const TString& theVariables,
140 Types::EMVA theMethod ,
141 const TString& theTitle,
142 const TString& theOptions )
143{
144 std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
145
146 Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
147
148 DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
149
150 IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
151
152 MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
153 if(method==0) return 0;
154
155 if(fModelPersistence) method->SetWeightFileDir(fFileDir);
156 method->SetModelPersistence(fModelPersistence);
157 method->SetAnalysisType( fAnalysisType );
158 method->SetupMethod();
159 method->ParseOptions();
160 method->ProcessSetup();
161 method->SetFile(fFile);
162 method->SetSilentFile(IsSilentFile());
163
164
165 // set or create correct method base dir for added method
166 const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
167 TDirectory * dir = BaseDir()->GetDirectory(dirName);
168 if (dir != 0) method->SetMethodBaseDir( dir );
169 else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data())) );
170
171 // method->SetBaseDir(own base dir; check whether the Fisher dir exists, otherwise create it )
172
173 // check-for-unused-options is performed; may be overridden by derived
174 // classes
175 method->CheckSetup();
176
177 // disable writing of XML files and standalone classes for sub methods
178 method->DisableWriting( kTRUE );
179
180 // store method, cut and variable names and create cut formula
181 fMethods.push_back(method);
182 fCategoryCuts.push_back(theCut);
183 fVars.push_back(theVariables);
184
185 DataSetInfo& primaryDSI = DataInfo();
186
187 UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
188 fCategorySpecIdx.push_back(newSpectatorIndex);
189
190 primaryDSI.AddSpectator( Form("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()),
191 Form("%s:%s",GetName(),method->GetName()),
192 "pass", 0, 0, 'C' );
193
194 return method;
195}
196
197////////////////////////////////////////////////////////////////////////////////
198/// create a DataSetInfo object for a sub-classifier
199
201 const TString& theVariables,
202 const TString& theTitle)
203{
204 // create a new dsi with name: theTitle+"_dsi"
205 TString dsiName=theTitle+"_dsi";
206 DataSetInfo& oldDSI = DataInfo();
207 DataSetInfo* dsi = new DataSetInfo(dsiName);
208
209 // register the new dsi
210 // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
211 fDataSetManager->AddDataSetInfo(*dsi);
212
213 // copy the targets and spectators from the old dsi to the new dsi
214 std::vector<VariableInfo>::iterator itrVarInfo;
215
216 for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
217 dsi->AddTarget(*itrVarInfo);
218
219 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
220 dsi->AddSpectator(*itrVarInfo);
221
222 // split the ':'-separated variable list into individual variable names
223 std::vector<TString> variables = gTools().SplitString(theVariables,':' );
224
225 // prepare to create varMap
226 std::vector<UInt_t> varMap;
227 Int_t counter=0;
228
229 // add the variables that were specified in theVariables
230 std::vector<TString>::iterator itrVariables;
231 Bool_t found = kFALSE;
232
233 // iterate over all variables in 'variables' and add them
234 for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
235 counter=0;
236
237 // check the variables of the old dsi for the variable that we want to add
238 for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
239 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
240 // don't compare the expression, since the user might use the same expression twice with different labels
241 // and apply different transformations to the variables.
242 dsi->AddVariable(*itrVarInfo);
243 varMap.push_back(counter);
244 found = kTRUE;
245 }
246 counter++;
247 }
248
249 // check the spectators of the old dsi for the variable that we want to add
250 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
251 if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
252 // don't compare the expression, since the user might use the same expression twice with different labels
253 // and apply different transformations to the variables.
254 dsi->AddVariable(*itrVarInfo);
255 varMap.push_back(counter);
256 found = kTRUE;
257 }
258 counter++;
259 }
260
261 // if the variable is neither in the variables nor in the spectators, we abort
262 if (!found) {
263 Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
264 }
265 found = kFALSE;
266 }
267
268 // in the case that no variables are specified, add the default-variables from the original dsi
269 if (theVariables=="") {
270 for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
271 dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
272 varMap.push_back(i);
273 }
274 }
275
276 // add the variable map 'varMap' to the vector of varMaps
277 fVarMaps.push_back(varMap);
278
279 // set classes and cuts
280 UInt_t nClasses=oldDSI.GetNClasses();
281 TString className;
282
283 for (UInt_t i=0; i<nClasses; i++) {
284 className = oldDSI.GetClassInfo(i)->GetName();
285 dsi->AddClass(className);
286 dsi->SetCut(oldDSI.GetCut(i),className);
287 dsi->AddCut(theCut,className);
288 dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
289 }
290
291 // set split options, root dir and normalization for the new dsi
292 dsi->SetSplitOptions(oldDSI.GetSplitOptions());
293 dsi->SetRootDir(oldDSI.GetRootDir());
294 TString norm(oldDSI.GetNormalization().Data());
295 dsi->SetNormalization(norm);
296 // need to add split options to normalize with cut efficiency
297 TString splitOpt = dsi->GetSplitOptions();
298 splitOpt += ":ScaleWithPreselEff";
299 dsi->SetSplitOptions(splitOpt);
300
301 DataSetInfo& dsiReference= (*dsi);
302
303 return dsiReference;
304}
305
306////////////////////////////////////////////////////////////////////////////////
307/// initialize the method
308
310{
311}
312
313////////////////////////////////////////////////////////////////////////////////
314/// initialize the circular tree
315
317{
318 delete fCatTree;
319 fCatTree = nullptr;
320
321 std::vector<VariableInfo>::const_iterator viIt;
322 const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
323 const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
324
325 Bool_t hasAllExternalLinks = kTRUE;
326 for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
327 if( viIt->GetExternalLink() == 0 ) {
328 hasAllExternalLinks = kFALSE;
329 break;
330 }
331 for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
332 if( viIt->GetExternalLink() == 0 ) {
333 hasAllExternalLinks = kFALSE;
334 break;
335 }
336
337 if(!hasAllExternalLinks) return;
338
339 {
340 // Rather than having TTree::TTree add the tree to the current directory and then remove it,
341 // make sure it is not added in the first place.
342 // Adding and then removing can cause problems if gDirectory points to the same directory (for example
343 // gROOT) in the current thread and in another one, and both try to add to that directory at the same time.
344 TDirectory::TContext ctxt(nullptr);
345 fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circular Tree for categorization");
346 fCatTree->SetCircular(1);
347 }
348
349 for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
350 const VariableInfo& vi = *viIt;
351 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
352 }
353 for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
354 const VariableInfo& vi = *viIt;
355 if(vi.GetVarType()=='C') continue;
356 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
357 }
358
359 for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
360 fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
361 }
362}
363
364////////////////////////////////////////////////////////////////////////////////
365/// train all sub-classifiers
366
368{
369 // specify the minimum # of training events and set 'classification'
370 const Int_t MinNoTrainingEvents = 10;
371
372 Types::EAnalysisType analysisType = GetAnalysisType();
373
374 // start the training
375 Log() << kINFO << "Train all sub-classifiers for "
376 << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
377
378 // don't do anything if no sub-classifier booked
379 if (fMethods.empty()) {
380 Log() << kINFO << "...nothing found to train" << Endl;
381 return;
382 }
383
384 std::vector<IMethod*>::iterator itrMethod;
385
386 // iterate over all booked sub-classifiers and train them
387 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
388
389 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
390 if(!mva) continue;
391 mva->SetAnalysisType( analysisType );
392 if (!mva->HasAnalysisType( analysisType,
393 mva->DataInfo().GetNClasses(),
394 mva->DataInfo().GetNTargets() ) ) {
395 Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
396 if (analysisType == Types::kRegression)
397 Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
398 else
399 Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
400 itrMethod = fMethods.erase( itrMethod );
401 continue;
402 }
404
405 Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
406 << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
407 mva->TrainMethod();
408 Log() << kINFO << "Training finished" << Endl;
409
410 } else {
411
412 Log() << kWARNING << "Method " << mva->GetMethodName()
413 << " not trained (training tree has less entries ["
414 << mva->Data()->GetNTrainingEvents()
415 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
416
417 Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
418 Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
419
420 }
421 }
422
423 if (analysisType != Types::kRegression) {
424
425 // variable ranking
426 Log() << kINFO << "Begin ranking of input variables..." << Endl;
427 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
428 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
429 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
430 const Ranking* ranking = (*itrMethod)->CreateRanking();
431 if (ranking != 0)
432 ranking->Print();
433 else
434 Log() << kINFO << "No variable ranking supplied by classifier: "
435 << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
436 }
437 }
438 }
439}
440
441////////////////////////////////////////////////////////////////////////////////
442/// create XML description of Category classifier
443
445{
446 void* wght = gTools().AddChild(parent, "Weights");
447 gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
448 void* submethod(0);
449
450 // iterate over methods and write them to XML file
451 for (UInt_t i=0; i<fMethods.size(); i++) {
452 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
453 submethod = gTools().AddChild(wght, "SubMethod");
454 gTools().AddAttr(submethod, "Index", i);
455 gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
456 gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
457 gTools().AddAttr(submethod, "Variables", fVars[i]);
458 method->WriteStateToXML( submethod );
459 }
460}
461
462////////////////////////////////////////////////////////////////////////////////
463/// read weights of sub-classifiers of MethodCategory from xml weight file
464
466{
467 UInt_t nSubMethods;
468 TString fullMethodName;
469 TString methodType;
470 TString methodTitle;
471 TString theCutString;
472 TString theVariables;
473 Int_t titleLength;
474 gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
475 void* subMethodNode = gTools().GetChild(wghtnode);
476
477 Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
478
479 // recreate all sub-methods from weight file
480 for (UInt_t i=0; i<nSubMethods; i++) {
481 gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
482 gTools().ReadAttr( subMethodNode, "Cut", theCutString );
483 gTools().ReadAttr( subMethodNode, "Variables", theVariables );
484
485 // determine sub-method type
486 methodType = fullMethodName(0,fullMethodName.Index("::"));
487 if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
488
489 // determine sub-method title
490 titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
491 methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
492
493 // reconstruct dsi for sub-method
494 DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
495
496 // recreate sub-method from weights and add to fMethods
497 MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
498 dsi, "none" ) );
499 if(method==0)
500 Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
501
502 method->SetupMethod();
503 method->ReadStateFromXML(subMethodNode);
504
505 fMethods.push_back(method);
506 fCategoryCuts.push_back(TCut(theCutString));
507 fVars.push_back(theVariables);
508
509 DataSetInfo& primaryDSI = DataInfo();
510
511 UInt_t spectatorIdx = 10000;
512 UInt_t counter=0;
513
514 // find the spectator index
515 std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
516 std::vector<VariableInfo>::iterator itrVarInfo;
517 TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
518
519 for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
520 if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
521 spectatorIdx=counter;
522 fCategorySpecIdx.push_back(spectatorIdx);
523 break;
524 }
525 }
526
527 subMethodNode = gTools().GetNextChild(subMethodNode);
528 }
529
530 InitCircularTree(DataInfo());
531
532}
533
534////////////////////////////////////////////////////////////////////////////////
535/// process user options
536
538{
539}
540
541////////////////////////////////////////////////////////////////////////////////
542/// Get help message text
543///
544/// typical length of text line:
545/// "|--------------------------------------------------------------|"
546
548{
549 Log() << Endl;
550 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
551 Log() << Endl;
552 Log() << "This method allows to define different categories of events. The" <<Endl;
553 Log() << "categories are defined via cuts on the variables. For each" << Endl;
554 Log() << "category, a different classifier and set of variables can be" <<Endl;
555 Log() << "specified. The categories which are defined for this method must" << Endl;
556 Log() << "be disjoint." << Endl;
557}
558
559////////////////////////////////////////////////////////////////////////////////
560/// no ranking
561
563{
564 return 0;
565}
566
567////////////////////////////////////////////////////////////////////////////////
568
570{
571 // If the categories are not defined by a simple 'spectator' variable (0 or 1)
572 // but by a formula (e.g. eta>0), the category formulas (fCatFormulas) are
573 // evaluated on fCatTree; each formula returns 'true' or 'false'.
574 if (fCatTree) {
575 if (methodIdx>=fCatFormulas.size()) {
576 Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
577 << fCatFormulas.size() << Endl;
578 }
579 TTreeFormula* f = fCatFormulas[methodIdx];
580 return f->EvalInstance(0) > 0.5;
581 }
582 // otherwise, simply check whether the category spectator is true (value > 0.5, to be safe)
583 else {
584
585 // checks whether an event lies within a cut
586 if (methodIdx>=fCategorySpecIdx.size()) {
587 Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
588 << fCategorySpecIdx.size() << Endl;
589 }
590 UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
591 Float_t specVal = ev->GetSpectator(spectatorIdx);
592 Bool_t pass = (specVal>0.5);
593 return pass;
594 }
595}
596
597////////////////////////////////////////////////////////////////////////////////
598/// returns the mva value of the right sub-classifier
599
601{
602 if (fMethods.empty()) return 0;
603
604 UInt_t methodToUse = 0;
605 const Event* ev = GetEvent();
606
607 // determine which sub-classifier to use for this event
608 Int_t suitableCutsN = 0;
609
610 for (UInt_t i=0; i<fMethods.size(); ++i) {
611 if (PassesCut(ev, i)) {
612 ++suitableCutsN;
613 methodToUse=i;
614 }
615 }
616
617 if (suitableCutsN == 0) {
618 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
619 return 0;
620 }
621
622 if (suitableCutsN > 1) {
623 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
624 return 0;
625 }
626
627 // get mva value from the suitable sub-classifier
628 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
629 Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
631
632 std::cout << "Event is for method " << methodToUse << " spectator is " << ev->GetSpectator(0) << " "
633 << fVarMaps[0][0] << " classID " << DataInfo().IsSignal(ev) << " value " << mvaValue
634 << " type " << Data()->GetCurrentType() << std::endl;
635
636 return mvaValue;
637}
638
639///////////////////////////////////////////////////////////////
640/// returns the mva values of the right sub-classifier
641///
642std::vector<Double_t>
644{
645
646 std::vector<Double_t> result;
647
648 Info("GetMVaValues", "Evaluate MethodCategory for %d events type %d on the dataset %s", int(lastEvt - firstEvt),
649 (int)Data()->GetCurrentType(), DataInfo().GetName());
650
651 if (fMethods.empty())
652 return result;
653
654 auto data = Data();
655
656 // it is faster to evaluate all categories
657 std::vector<std::vector<Double_t>> mvaValues(fMethods.size());
658 for (UInt_t i = 0; i < fMethods.size(); ++i) {
659 // need to set variable map
660 for (UInt_t iev = firstEvt; iev < lastEvt; ++iev) {
661 data->SetCurrentEvent(iev);
662 const Event *ev = GetEvent(data->GetEvent());
663 ev->SetVariableArrangement(&fVarMaps[i]);
664 }
665 // need to set correct data in the different method
666 mvaValues[i] = dynamic_cast<MethodBase *>(fMethods[i])->GetDataMvaValues(data,firstEvt, lastEvt, logProgress);
667 }
668
669 // now loop on all events
670 result.resize(lastEvt - firstEvt);
671
672 for (UInt_t iev = firstEvt; iev < lastEvt; ++iev)
673 {
674 //std::cout << "Loop on event " << iev << " of " << DataInfo().GetName() << std::endl;
675 data->SetCurrentEvent(iev);
676 UInt_t methodToUse = 0;
677 const Event *ev = GetEvent(data->GetEvent());
678
679 // determine which sub-classifier to use for this event
680 Int_t suitableCutsN = 0;
681
682 for (UInt_t i = 0; i < fMethods.size(); ++i) {
683 if (PassesCut(ev, i)) {
684 ++suitableCutsN;
685 methodToUse = i;
686 }
687 }
688
689 if (suitableCutsN == 0) {
690 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
691 result[iev] = 0;
692 }
693
694 if (suitableCutsN > 1) {
695 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
696 return result;
697 }
698
699
700 result[iev - firstEvt] = mvaValues[methodToUse][iev - firstEvt];
701
702 // std::cout << "Event " << iev << " is for method " << methodToUse << " spectator is " << ev->GetSpectator(0)
703 // << " " << fVarMaps[0][0] << " classID " << DataInfo().IsSignal(ev) << " value "
704 // << result[iev - firstEvt] << " type " << data->GetCurrentType() << std::endl;
705
706 // reset variable map which was set it before
707 ev->SetVariableArrangement(nullptr);
708 }
709 return result;
710}
711
712////////////////////////////////////////////////////////////////////////////////
713/// returns the mva values of the multi-class right sub-classifier
714///
715const std::vector<Float_t> &TMVA::MethodCategory::GetMulticlassValues()
716{
717 if (fMethods.empty())
719
720 UInt_t methodToUse = 0;
721 const Event *ev = GetEvent();
722
723 // determine which sub-classifier to use for this event
724 Int_t suitableCutsN = 0;
725
726 for (UInt_t i = 0; i < fMethods.size(); ++i) {
727 if (PassesCut(ev, i)) {
728 ++suitableCutsN;
729 methodToUse = i;
730 }
731 }
732
733 if (suitableCutsN == 0) {
734 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
736 }
737
738 if (suitableCutsN > 1) {
739 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
741 }
742 MethodBase *meth = dynamic_cast<MethodBase *>(fMethods[methodToUse]);
743 if (!meth) {
744 Log() << kFATAL << "method not found in Category Regression method" << Endl;
746 }
747 // get mva value from the suitable sub-classifier
748 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
749 auto &result = meth->GetMulticlassValues();
750 ev->SetVariableArrangement(nullptr);
751 return result;
752}
753
754////////////////////////////////////////////////////////////////////////////////
755/// returns the mva value of the right sub-classifier
756
757const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
758{
759 if (fMethods.empty()) return MethodBase::GetRegressionValues();
760
761 UInt_t methodToUse = 0;
762 const Event* ev = GetEvent();
763
764 // determine which sub-classifier to use for this event
765 Int_t suitableCutsN = 0;
766
767 for (UInt_t i=0; i<fMethods.size(); ++i) {
768 if (PassesCut(ev, i)) {
769 ++suitableCutsN;
770 methodToUse=i;
771 }
772 }
773
774 if (suitableCutsN == 0) {
775 Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
777 }
778
779 if (suitableCutsN > 1) {
780 Log() << kFATAL << "The defined categories are not disjoint." << Endl;
782 }
783 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
784 if (!meth){
785 Log() << kFATAL << "method not found in Category Regression method" << Endl;
787 }
788 // get mva value from the suitable sub-classifier
789 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
790 auto & result = meth->GetRegressionValues(ev);
791 return result;
792}
#define REGISTER_METHOD(CLASS)
for example
#define f(i)
Definition: RSha256.hxx:104
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
int Int_t
Definition: RtypesCore.h:45
unsigned int UInt_t
Definition: RtypesCore.h:46
const Bool_t kFALSE
Definition: RtypesCore.h:101
bool Bool_t
Definition: RtypesCore.h:63
double Double_t
Definition: RtypesCore.h:59
long long Long64_t
Definition: RtypesCore.h:80
float Float_t
Definition: RtypesCore.h:57
const Bool_t kTRUE
Definition: RtypesCore.h:100
#define ClassImp(name)
Definition: Rtypes.h:364
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
Definition: TError.cxx:220
int type
Definition: TGX11.cxx:121
char * Form(const char *fmt,...)
A specialized string object used for TTree selections.
Definition: TCut.h:25
TDirectory::TContext keeps track and restore the current directory.
Definition: TDirectory.h:89
Describe directory structure in memory.
Definition: TDirectory.h:45
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:407
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
virtual void ParseOptions()
options parser
Class that contains all the data information.
Definition: DataSetInfo.h:62
const TString GetWeightExpression(Int_t i) const
Definition: DataSetInfo.h:164
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:103
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:185
ClassInfo * AddClass(const TString &className)
const TString & GetNormalization() const
Definition: DataSetInfo.h:131
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:122
TDirectory * GetRootDir() const
Definition: DataSetInfo.h:190
void SetNormalization(const TString &norm)
Definition: DataSetInfo.h:132
UInt_t GetNClasses() const
Definition: DataSetInfo.h:155
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:186
UInt_t GetNTargets() const
Definition: DataSetInfo.h:128
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
ClassInfo * GetClassInfo(Int_t clNum) const
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:168
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:114
void SetRootDir(TDirectory *d)
Definition: DataSetInfo.h:189
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:68
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition: Event.cxx:191
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:261
Interface for all concrete MVA method implementations.
Definition: IMethod.h:53
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodBase.h:221
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition: MethodBase.h:214
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:378
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
TString GetMethodTypeName() const
Definition: MethodBase.h:332
void DisableWriting(Bool_t setter)
Definition: MethodBase.h:442
const char * GetName() const
Definition: MethodBase.h:334
virtual const std::vector< Float_t > & GetMulticlassValues()
Definition: MethodBase.h:227
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:406
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:436
const TString & GetMethodName() const
Definition: MethodBase.h:331
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:423
DataSetInfo & DataInfo() const
Definition: MethodBase.h:410
void SetFile(TFile *file)
Definition: MethodBase.h:375
void ReadStateFromXML(void *parent)
friend class MethodCategory
Definition: MethodBase.h:269
void SetMethodBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:374
DataSet * Data() const
Definition: MethodBase.h:409
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:382
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:433
Class for categorizing the phase space.
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
void GetHelpMessage() const
Get help message text.
void Init()
initialize the method
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
void ProcessOptions()
process user options
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns the mva value of the right sub-classifier
virtual const std::vector< Float_t > & GetMulticlassValues()
returns the multi-class mva values of the right sub-classifier
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
void DeclareOptions()
options for this method
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
const Ranking * CreateRanking()
no ranking
virtual ~MethodCategory(void)
destructor
virtual const std::vector< Float_t > & GetRegressionValues()
returns the mva value of the right sub-classifier
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
returns the mva values of the right sub-classifier
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
void Train(void)
train all sub-classifiers
Virtual base class for combining several TMVA methods.
Ranking for variables in method (implementation)
Definition: Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition: Tools.cxx:1211
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
Singleton class for Global types used by TMVA.
Definition: Types.h:73
static Types & Instance()
return the single instance of "Types" if it already exists, or create it (Singleton)
Definition: Types.cxx:69
EAnalysisType
Definition: Types.h:128
@ kRegression
Definition: Types.h:130
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
const TString & GetExpression() const
Definition: VariableInfo.h:57
char GetVarType() const
Definition: VariableInfo.h:61
void * GetExternalLink() const
Definition: VariableInfo.h:83
@ kERROR
Definition: Types.h:62
@ kINFO
Definition: Types.h:60
@ kWARNING
Definition: Types.h:61
@ kFATAL
Definition: Types.h:63
virtual const char * GetTitle() const
Returns title of object.
Definition: TNamed.h:48
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Basic string class.
Definition: TString.h:136
Ssiz_t Length() const
Definition: TString.h:410
const char * Data() const
Definition: TString.h:369
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:916
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:624
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:639
Used to pass a selection expression to the Tree drawing routine.
Definition: TTreeFormula.h:58
A TTree represents a columnar dataset.
Definition: TTree.h:79
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:342
create variable transformations
Tools & gTools()
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:760
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:95