Logo ROOT  
Reference Guide
MethodCategory.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodCategory *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Class for categorizing the phase space *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16  * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21  * *
22  * Copyright (c) 2005-2011: *
23  * CERN, Switzerland *
24  * MSU East Lansing, USA *
25  * MPI-K Heidelberg, Germany *
26  * U. of Bonn, Germany *
27  * *
28  * Redistribution and use in source and binary forms, with or without *
29  * modification, are permitted according to the terms listed in LICENSE *
30  * (http://tmva.sourceforge.net/LICENSE) *
31  **********************************************************************************/
32 
33 /*! \class TMVA::MethodCategory
34 \ingroup TMVA
35 
36 Class for categorizing the phase space
37 
38 This class is meant to allow categorisation of the data. For different
39 categories, different classifiers may be booked and different variables
40 may be considered. The aim is to account for the difference that
41 is due to different locations/angles.
42 */
43 
44 
45 #include "TMVA/MethodCategory.h"
46 
47 #include <algorithm>
48 #include <vector>
49 
50 #include "TRandom3.h"
51 #include "TH1F.h"
52 #include "TSpline.h"
53 #include "TDirectory.h"
54 #include "TTreeFormula.h"
55 
56 #include "TMVA/ClassifierFactory.h"
57 #include "TMVA/Config.h"
58 #include "TMVA/DataSet.h"
59 #include "TMVA/DataSetInfo.h"
60 #include "TMVA/DataSetManager.h"
61 #include "TMVA/IMethod.h"
62 #include "TMVA/MethodBase.h"
64 #include "TMVA/MsgLogger.h"
65 #include "TMVA/PDF.h"
66 #include "TMVA/Ranking.h"
67 #include "TMVA/Timer.h"
68 #include "TMVA/Tools.h"
69 #include "TMVA/Types.h"
70 #include "TMVA/VariableInfo.h"
72 
// Register MethodCategory with the TMVA ClassifierFactory (macro from
// TMVA/ClassifierFactory.h), so that ClassifierFactory::Instance().Create()
// can instantiate it from its method name — presumably "Category"; confirm in the macro.
73 REGISTER_METHOD(Category)
74 
76 
77 ////////////////////////////////////////////////////////////////////////////////
78 /// standard constructor
79 
81  const TString& methodTitle,
82  DataSetInfo& theData,
83  const TString& theOption )
84  : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
85  fCatTree(0),
86  fDataSetManager(NULL)
87 {
88 }
89 
90 ////////////////////////////////////////////////////////////////////////////////
91 /// constructor from weight file
92 
94  const TString& theWeightFile)
95  : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
96  fCatTree(0),
97  fDataSetManager(NULL)
98 {
99 }
100 
101 ////////////////////////////////////////////////////////////////////////////////
102 /// destructor
103 
105 {
106  std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
107  std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
108  for(;formIt!=lastF; ++formIt) delete *formIt;
109  delete fCatTree;
110 }
111 
112 ////////////////////////////////////////////////////////////////////////////////
113 /// check whether method category has analysis type
114 /// the method type has to be the same for all sub-methods
115 
117 {
118  std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
119 
120  // iterate over methods and check whether they have the analysis type
121  for(; itrMethod != fMethods.end(); ++itrMethod ) {
122  if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
123  return kFALSE;
124  }
125  return kTRUE;
126 }
127 
128 ////////////////////////////////////////////////////////////////////////////////
129 /// options for this method
130 
132 {
   // MethodCategory declares no options of its own; each sub-classifier parses
   // its own option string when it is booked (method->ParseOptions() in AddMethod).
133 }
134 
135 ////////////////////////////////////////////////////////////////////////////////
136 /// adds sub-classifier for a category
137 
139  const TString& theVariables,
140  Types::EMVA theMethod ,
141  const TString& theTitle,
142  const TString& theOptions )
143 {
// Books one sub-classifier for the category selected by theCut.
//   theCut       : selection that defines the category
//   theVariables : ':'-separated labels of the variables/spectators the sub-classifier
//                  uses (empty = all variables of the primary DataSetInfo, see CreateCategoryDSI)
//   theMethod    : type of the sub-classifier to create
//   theTitle     : title of the sub-classifier (also names its private DataSetInfo "<title>_dsi")
//   theOptions   : option string handed to the sub-classifier
// Returns the new sub-classifier, or 0 if the factory result is not a MethodBase.
144  std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
145 
146  Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
147 
// private DataSetInfo restricted to this category's cut and variables
148  DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
149 
150  IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
151 
// NOTE(review): on this early return the object created by the factory is not deleted
152  MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
153  if(method==0) return 0;
154 
// configure the sub-classifier like this composite method (persistence, analysis type, output file)
155  if(fModelPersistence) method->SetWeightFileDir(fFileDir);
156  method->SetModelPersistence(fModelPersistence);
157  method->SetAnalysisType( fAnalysisType );
158  method->SetupMethod();
159  method->ParseOptions();
160  method->ProcessSetup();
161  method->SetFile(fFile);
162  method->SetSilentFile(IsSilentFile());
163 
164 
165  // set or create correct method base dir for added method
166  const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
167  TDirectory * dir = BaseDir()->GetDirectory(dirName);
168  if (dir != 0) method->SetMethodBaseDir( dir );
169  else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data())) );
170 
171  // method->SetBaseDir(own base dir; check whether the Fisher dir exists, otherwise create it )
172 
173  // check-for-unused-options is performed; may be overridden by derived
174  // classes
175  method->CheckSetup();
176 
177  // disable writing of XML files and standalone classes for sub methods
178  method->DisableWriting( kTRUE );
179 
180  // store method, cut and variable names (the cut formulas are created later, in InitCircularTree)
// these vectors are kept parallel: entry i belongs to sub-classifier i
181  fMethods.push_back(method);
182  fCategoryCuts.push_back(theCut);
183  fVars.push_back(theVariables);
184 
185  DataSetInfo& primaryDSI = DataInfo();
186 
// the spectator appended below lands at this index of the primary DataSetInfo;
// PassesCut() reads it from the event when no circular tree is available
187  UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
188  fCategorySpecIdx.push_back(newSpectatorIndex);
189 
// book "<name>_cat<N>:=<cut>" (N = 1-based position of this sub-classifier) as a
// 'C'-type spectator carrying the category flag of each event
190  primaryDSI.AddSpectator( Form("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()),
191  Form("%s:%s",GetName(),method->GetName()),
192  "pass", 0, 0, 'C' );
193 
194  return method;
195 }
196 
197 ////////////////////////////////////////////////////////////////////////////////
198 /// create a DataSetInfo object for a sub-classifier
199 
201  const TString& theVariables,
202  const TString& theTitle)
203 {
// Builds the private DataSetInfo of one category's sub-classifier:
//  - named theTitle+"_dsi" and registered with fDataSetManager (which must be set by then),
//  - copies targets and spectators of the primary DataSetInfo,
//  - adds the variables listed (by label) in theVariables, or all primary variables if empty,
//  - copies classes, per-class cuts and weight expressions, and ANDs theCut into every class cut.
// Side effect: appends the variable map of this category to fVarMaps.
// Returns a reference to the heap-allocated DataSetInfo.
// NOTE(review): ownership presumably passes to fDataSetManager via AddDataSetInfo — confirm.
204  // create a new dsi with name: theTitle+"_dsi"
205  TString dsiName=theTitle+"_dsi";
206  DataSetInfo& oldDSI = DataInfo();
207  DataSetInfo* dsi = new DataSetInfo(dsiName);
208 
209  // register the new dsi
210  // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
211  fDataSetManager->AddDataSetInfo(*dsi);
212 
213  // copy the targets and spectators from the old dsi to the new dsi
214  std::vector<VariableInfo>::iterator itrVarInfo;
215 
216  for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
217  dsi->AddTarget(*itrVarInfo);
218 
219  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
220  dsi->AddSpectator(*itrVarInfo);
221 
222  // split string that contains the variables into tiny little pieces
223  std::vector<TString> variables = gTools().SplitString(theVariables,':' );
224 
225  // prepare to create varMap
// varMap[k] = position in the primary event of the k-th variable of this category;
// primary variables are numbered first, spectators continue the same count after them
226  std::vector<UInt_t> varMap;
227  Int_t counter=0;
228 
229  // add the variables that were specified in theVariables
230  std::vector<TString>::iterator itrVariables;
231  Bool_t found = kFALSE;
232 
233  // iterate over all variables in 'variables' and add them
234  for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
235  counter=0;
236 
237  // check the variables of the old dsi for the variable that we want to add
238  for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
239  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
240  // don't compare the expression, since the user might take two times the same expression, but with different labels
241  // and apply different transformations to the variables.
242  dsi->AddVariable(*itrVarInfo);
243  varMap.push_back(counter);
244  found = kTRUE;
245  }
246  counter++;
247  }
248 
249  // check the spectators of the old dsi for the variable that we want to add
// (counter is not reset: spectator positions follow the primary variables)
250  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
251  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
252  // don't compare the expression, since the user might take two times the same expression, but with different labels
253  // and apply different transformations to the variables.
254  dsi->AddVariable(*itrVarInfo);
255  varMap.push_back(counter);
256  found = kTRUE;
257  }
258  counter++;
259  }
260 
261  // if the variable is neither in the variables nor in the spectators, we abort
262  if (!found) {
263  Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
264  }
265  found = kFALSE;
266  }
267 
268  // in the case that no variables are specified, add the default-variables from the original dsi
269  if (theVariables=="") {
270  for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
271  dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
272  varMap.push_back(i);
273  }
274  }
275 
276  // add the variable map 'varMap' to the vector of varMaps
277  fVarMaps.push_back(varMap);
278 
279  // set classes and cuts
// every class keeps its original cut and additionally gets the category cut
280  UInt_t nClasses=oldDSI.GetNClasses();
281  TString className;
282 
283  for (UInt_t i=0; i<nClasses; i++) {
284  className = oldDSI.GetClassInfo(i)->GetName();
285  dsi->AddClass(className);
286  dsi->SetCut(oldDSI.GetCut(i),className);
287  dsi->AddCut(theCut,className);
288  dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
289  }
290 
291  // set split options, root dir and normalization for the new dsi
292  dsi->SetSplitOptions(oldDSI.GetSplitOptions());
293  dsi->SetRootDir(oldDSI.GetRootDir());
294  TString norm(oldDSI.GetNormalization().Data());
295  dsi->SetNormalization(norm);
296  // need to add split options to normalize with cut efficiency
297  TString splitOpt = dsi->GetSplitOptions();
298  splitOpt += ":ScaleWithPreselEff";
299  dsi->SetSplitOptions(splitOpt);
300 
301  DataSetInfo& dsiReference= (*dsi);
302 
303  return dsiReference;
304 }
305 
306 ////////////////////////////////////////////////////////////////////////////////
307 /// initialize the method
308 
310 {
   // Nothing to initialise here: the per-category state is built in AddMethod() /
   // CreateCategoryDSI() (or ReadWeightsFromXML()) and the circular tree in InitCircularTree().
311 }
312 
313 ////////////////////////////////////////////////////////////////////////////////
314 /// initialize the circular tree
315 
317 {
318  delete fCatTree;
319  fCatTree = nullptr;
320 
321  std::vector<VariableInfo>::const_iterator viIt;
322  const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
323  const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
324 
325  Bool_t hasAllExternalLinks = kTRUE;
326  for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
327  if( viIt->GetExternalLink() == 0 ) {
328  hasAllExternalLinks = kFALSE;
329  break;
330  }
331  for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
332  if( viIt->GetExternalLink() == 0 ) {
333  hasAllExternalLinks = kFALSE;
334  break;
335  }
336 
337  if(!hasAllExternalLinks) return;
338 
339  {
340  // Rather than having TTree::TTree add to the current directory and then remove it, let
341  // make sure to not add it in the first place.
342  // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
343  // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
344  TDirectory::TContext ctxt(nullptr);
345  fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circular Tree for categorization");
346  fCatTree->SetCircular(1);
347  }
348 
349  for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
350  const VariableInfo& vi = *viIt;
351  fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
352  }
353  for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
354  const VariableInfo& vi = *viIt;
355  if(vi.GetVarType()=='C') continue;
356  fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
357  }
358 
359  for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
360  fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
361  }
362 }
363 
364 ////////////////////////////////////////////////////////////////////////////////
365 /// train all sub-classifiers
366 
368 {
369  // specify the minimum # of training events and set 'classification'
370  const Int_t MinNoTrainingEvents = 10;
371 
372  Types::EAnalysisType analysisType = GetAnalysisType();
373 
374  // start the training
375  Log() << kINFO << "Train all sub-classifiers for "
376  << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
377 
378  // don't do anything if no sub-classifier booked
379  if (fMethods.empty()) {
380  Log() << kINFO << "...nothing found to train" << Endl;
381  return;
382  }
383 
384  std::vector<IMethod*>::iterator itrMethod;
385 
386  // iterate over all booked sub-classifiers and train them
387  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
388 
389  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
390  if(!mva) continue;
391  mva->SetAnalysisType( analysisType );
392  if (!mva->HasAnalysisType( analysisType,
393  mva->DataInfo().GetNClasses(),
394  mva->DataInfo().GetNTargets() ) ) {
395  Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
396  if (analysisType == Types::kRegression)
397  Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
398  else
399  Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
400  itrMethod = fMethods.erase( itrMethod );
401  continue;
402  }
403  if (mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
404 
405  Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
406  << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
407  mva->TrainMethod();
408  Log() << kINFO << "Training finished" << Endl;
409 
410  } else {
411 
412  Log() << kWARNING << "Method " << mva->GetMethodName()
413  << " not trained (training tree has less entries ["
414  << mva->Data()->GetNTrainingEvents()
415  << "] than required [" << MinNoTrainingEvents << "]" << Endl;
416 
417  Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
418  Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
419 
420  }
421  }
422 
423  if (analysisType != Types::kRegression) {
424 
425  // variable ranking
426  Log() << kINFO << "Begin ranking of input variables..." << Endl;
427  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
428  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
429  if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
430  const Ranking* ranking = (*itrMethod)->CreateRanking();
431  if (ranking != 0)
432  ranking->Print();
433  else
434  Log() << kINFO << "No variable ranking supplied by classifier: "
435  << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
436  }
437  }
438  }
439 }
440 
441 ////////////////////////////////////////////////////////////////////////////////
442 /// create XML description of Category classifier
443 
444 void TMVA::MethodCategory::AddWeightsXMLTo( void* parent ) const
445 {
446  void* wght = gTools().AddChild(parent, "Weights");
447  gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
448  void* submethod(0);
449 
450  // iterate over methods and write them to XML file
451  for (UInt_t i=0; i<fMethods.size(); i++) {
452  MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
453  submethod = gTools().AddChild(wght, "SubMethod");
454  gTools().AddAttr(submethod, "Index", i);
455  gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
456  gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
457  gTools().AddAttr(submethod, "Variables", fVars[i]);
458  method->WriteStateToXML( submethod );
459  }
460 }
461 
462 ////////////////////////////////////////////////////////////////////////////////
463 /// read weights of sub-classifiers of MethodCategory from xml weight file
464 
466 {
467  UInt_t nSubMethods;
468  TString fullMethodName;
469  TString methodType;
470  TString methodTitle;
471  TString theCutString;
472  TString theVariables;
473  Int_t titleLength;
474  gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
475  void* subMethodNode = gTools().GetChild(wghtnode);
476 
477  Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
478 
479  // recreate all sub-methods from weight file
480  for (UInt_t i=0; i<nSubMethods; i++) {
481  gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
482  gTools().ReadAttr( subMethodNode, "Cut", theCutString );
483  gTools().ReadAttr( subMethodNode, "Variables", theVariables );
484 
485  // determine sub-method type
486  methodType = fullMethodName(0,fullMethodName.Index("::"));
487  if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
488 
489  // determine sub-method title
490  titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
491  methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
492 
493  // reconstruct dsi for sub-method
494  DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
495 
496  // recreate sub-method from weights and add to fMethods
497  MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
498  dsi, "none" ) );
499  if(method==0)
500  Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
501 
502  method->SetupMethod();
503  method->ReadStateFromXML(subMethodNode);
504 
505  fMethods.push_back(method);
506  fCategoryCuts.push_back(TCut(theCutString));
507  fVars.push_back(theVariables);
508 
509  DataSetInfo& primaryDSI = DataInfo();
510 
511  UInt_t spectatorIdx = 10000;
512  UInt_t counter=0;
513 
514  // find the spectator index
515  std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
516  std::vector<VariableInfo>::iterator itrVarInfo;
517  TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
518 
519  for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
520  if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
521  spectatorIdx=counter;
522  fCategorySpecIdx.push_back(spectatorIdx);
523  break;
524  }
525  }
526 
527  subMethodNode = gTools().GetNextChild(subMethodNode);
528  }
529 
530  InitCircularTree(DataInfo());
531 
532 }
533 
534 ////////////////////////////////////////////////////////////////////////////////
535 /// process user options
536 
538 {
   // Nothing to process: MethodCategory declares no options of its own (DeclareOptions()
   // is empty); the sub-classifiers process their options when booked in AddMethod().
539 }
540 
541 ////////////////////////////////////////////////////////////////////////////////
542 /// Get help message text
543 ///
544 /// typical length of text line:
545 /// "|--------------------------------------------------------------|"
546 
548 {
549  Log() << Endl;
550  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
551  Log() << Endl;
552  Log() << "This method allows to define different categories of events. The" <<Endl;
553  Log() << "categories are defined via cuts on the variables. For each" << Endl;
554  Log() << "category, a different classifier and set of variables can be" <<Endl;
555  Log() << "specified. The categories which are defined for this method must" << Endl;
556  Log() << "be disjoint." << Endl;
557 }
558 
559 ////////////////////////////////////////////////////////////////////////////////
560 /// no ranking
561 
563 {
564  return 0;
565 }
566 
567 ////////////////////////////////////////////////////////////////////////////////
568 
570 {
571  // if it's not a simple 'spectator' variable (0 or 1) that the categories are defined by
572  // (but rather some 'formula' (i.e. eta>0), then this formulas are stored in fCatTree and that
573  // one will be evaluated.. (the formulae return 'true' or 'false'
574  if (fCatTree) {
575  if (methodIdx>=fCatFormulas.size()) {
576  Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
577  << fCatFormulas.size() << Endl;
578  }
579  TTreeFormula* f = fCatFormulas[methodIdx];
580  return f->EvalInstance(0) > 0.5;
581  }
582  // otherwise, it simply looks if "variable == true" ("greater 0.5 to be "sure" )
583  else {
584 
585  // checks whether an event lies within a cut
586  if (methodIdx>=fCategorySpecIdx.size()) {
587  Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
588  << fCategorySpecIdx.size() << Endl;
589  }
590  UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
591  Float_t specVal = ev->GetSpectator(spectatorIdx);
592  Bool_t pass = (specVal>0.5);
593  return pass;
594  }
595 }
596 
597 ////////////////////////////////////////////////////////////////////////////////
598 /// returns the mva value of the right sub-classifier
599 
601 {
602  if (fMethods.empty()) return 0;
603 
604  UInt_t methodToUse = 0;
605  const Event* ev = GetEvent();
606 
607  // determine which sub-classifier to use for this event
608  Int_t suitableCutsN = 0;
609 
610  for (UInt_t i=0; i<fMethods.size(); ++i) {
611  if (PassesCut(ev, i)) {
612  ++suitableCutsN;
613  methodToUse=i;
614  }
615  }
616 
617  if (suitableCutsN == 0) {
618  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
619  return 0;
620  }
621 
622  if (suitableCutsN > 1) {
623  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
624  return 0;
625  }
626 
627  // get mva value from the suitable sub-classifier
628  ev->SetVariableArrangement(&fVarMaps[methodToUse]);
629  Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
630  ev->SetVariableArrangement(0);
631 
632  std::cout << "Event is for method " << methodToUse << " spectator is " << ev->GetSpectator(0) << " "
633  << fVarMaps[0][0] << " classID " << DataInfo().IsSignal(ev) << " value " << mvaValue
634  << " type " << Data()->GetCurrentType() << std::endl;
635 
636  return mvaValue;
637 }
638 
639 ///////////////////////////////////////////////////////////////
640 /// returns the mva values of the right sub-classifier
641 ///
642 std::vector<Double_t>
644 {
645 
646  std::vector<Double_t> result;
647 
648  Info("GetMVaValues", "Evaluate MethodCategory for %d events type %d on the dataset %s", int(lastEvt - firstEvt),
649  (int)Data()->GetCurrentType(), DataInfo().GetName());
650 
651  if (fMethods.empty())
652  return result;
653 
654  auto data = Data();
655 
656  // it is faster to evaluate all categories
657  std::vector<std::vector<Double_t>> mvaValues(fMethods.size());
658  for (UInt_t i = 0; i < fMethods.size(); ++i) {
659  // need to set variable map
660  for (UInt_t iev = firstEvt; iev < lastEvt; ++iev) {
661  data->SetCurrentEvent(iev);
662  const Event *ev = GetEvent(data->GetEvent());
663  ev->SetVariableArrangement(&fVarMaps[i]);
664  }
665  // need to set correct data in the different method
666  mvaValues[i] = dynamic_cast<MethodBase *>(fMethods[i])->GetDataMvaValues(data,firstEvt, lastEvt, logProgress);
667  }
668 
669  // now loop on all events
670  result.resize(lastEvt - firstEvt);
671 
672  for (UInt_t iev = firstEvt; iev < lastEvt; ++iev)
673  {
674  //std::cout << "Loop on event " << iev << " of " << DataInfo().GetName() << std::endl;
675  data->SetCurrentEvent(iev);
676  UInt_t methodToUse = 0;
677  const Event *ev = GetEvent(data->GetEvent());
678 
679  // determine which sub-classifier to use for this event
680  Int_t suitableCutsN = 0;
681 
682  for (UInt_t i = 0; i < fMethods.size(); ++i) {
683  if (PassesCut(ev, i)) {
684  ++suitableCutsN;
685  methodToUse = i;
686  }
687  }
688 
689  if (suitableCutsN == 0) {
690  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
691  result[iev] = 0;
692  }
693 
694  if (suitableCutsN > 1) {
695  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
696  return result;
697  }
698 
699 
700  result[iev - firstEvt] = mvaValues[methodToUse][iev - firstEvt];
701 
702  // std::cout << "Event " << iev << " is for method " << methodToUse << " spectator is " << ev->GetSpectator(0)
703  // << " " << fVarMaps[0][0] << " classID " << DataInfo().IsSignal(ev) << " value "
704  // << result[iev - firstEvt] << " type " << data->GetCurrentType() << std::endl;
705 
706  // reset variable map which was set it before
707  ev->SetVariableArrangement(nullptr);
708  }
709  return result;
710 }
711 
712 ////////////////////////////////////////////////////////////////////////////////
713 /// returns the mva values of the multi-class right sub-classifier
714 ///
715 const std::vector<Float_t> &TMVA::MethodCategory::GetMulticlassValues()
716 {
717  if (fMethods.empty())
719 
720  UInt_t methodToUse = 0;
721  const Event *ev = GetEvent();
722 
723  // determine which sub-classifier to use for this event
724  Int_t suitableCutsN = 0;
725 
726  for (UInt_t i = 0; i < fMethods.size(); ++i) {
727  if (PassesCut(ev, i)) {
728  ++suitableCutsN;
729  methodToUse = i;
730  }
731  }
732 
733  if (suitableCutsN == 0) {
734  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
736  }
737 
738  if (suitableCutsN > 1) {
739  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
741  }
742  MethodBase *meth = dynamic_cast<MethodBase *>(fMethods[methodToUse]);
743  if (!meth) {
744  Log() << kFATAL << "method not found in Category Regression method" << Endl;
746  }
747  // get mva value from the suitable sub-classifier
748  ev->SetVariableArrangement(&fVarMaps[methodToUse]);
749  auto &result = meth->GetMulticlassValues();
750  ev->SetVariableArrangement(nullptr);
751  return result;
752 }
753 
754 ////////////////////////////////////////////////////////////////////////////////
755 /// returns the mva value of the right sub-classifier
756 
757 const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
758 {
759  if (fMethods.empty()) return MethodBase::GetRegressionValues();
760 
761  UInt_t methodToUse = 0;
762  const Event* ev = GetEvent();
763 
764  // determine which sub-classifier to use for this event
765  Int_t suitableCutsN = 0;
766 
767  for (UInt_t i=0; i<fMethods.size(); ++i) {
768  if (PassesCut(ev, i)) {
769  ++suitableCutsN;
770  methodToUse=i;
771  }
772  }
773 
774  if (suitableCutsN == 0) {
775  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
777  }
778 
779  if (suitableCutsN > 1) {
780  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
782  }
783  MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
784  if (!meth){
785  Log() << kFATAL << "method not found in Category Regression method" << Endl;
787  }
788  // get mva value from the suitable sub-classifier
789  ev->SetVariableArrangement(&fVarMaps[methodToUse]);
790  auto & result = meth->GetRegressionValues(ev);
791  return result;
792 }
TMVA::MethodCategory::HasAnalysisType
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
Definition: MethodCategory.cxx:116
TMVA::MethodCategory::AddMethod
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
Definition: MethodCategory.cxx:138
TCut
A specialized string object used for TTree selections.
Definition: TCut.h:25
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:100
TH1F.h
DataSetManager.h
TMVA::Tools::GetChild
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
TMVA::MethodBase::Data
DataSet * Data() const
Definition: MethodBase.h:409
TMVA::MethodBase::SetModelPersistence
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:382
TMVA::Ranking::Print
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
f
#define f(i)
Definition: RSha256.hxx:104
TDirectory.h
TMVA::kERROR
@ kERROR
Definition: Types.h:62
TMVA::DataSetInfo::AddSpectator
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
Definition: DataSetInfo.cxx:289
TMVA::DataSetInfo::SetCut
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
Definition: DataSetInfo.cxx:360
TMVA::Tools::SplitString
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition: Tools.cxx:1211
TMVA::Types::kRegression
@ kRegression
Definition: Types.h:130
TMVA::MethodBase::SetupMethod
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:406
TMVA::DataSetInfo::SetNormalization
void SetNormalization(const TString &norm)
Definition: DataSetInfo.h:132
TString::Data
const char * Data() const
Definition: TString.h:369
DataSetInfo.h
ClassImp
#define ClassImp(name)
Definition: Rtypes.h:364
Form
char * Form(const char *fmt,...)
TNamed::GetTitle
virtual const char * GetTitle() const
Returns title of object.
Definition: TNamed.h:48
TMVA::Ranking
Ranking for variables in method (implementation)
Definition: Ranking.h:48
TMVA::MethodBase::ReadStateFromXML
void ReadStateFromXML(void *parent)
Definition: MethodBase.cxx:1480
TMVA::MethodCategory::GetRegressionValues
virtual const std::vector< Float_t > & GetRegressionValues()
returns the mva value of the right sub-classifier
Definition: MethodCategory.cxx:757
IMethod.h
TMVA::MethodBase::SetSilentFile
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:378
MinNoTrainingEvents
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:95
Long64_t
long long Long64_t
Definition: RtypesCore.h:80
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:760
TTree
A TTree represents a columnar dataset.
Definition: TTree.h:79
Ranking.h
TMVA::Tools::AddChild
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
TMVA::MethodCategory::ProcessOptions
void ProcessOptions()
process user options
Definition: MethodCategory.cxx:537
TTreeFormula.h
TMVA::MethodBase::TrainMethod
void TrainMethod()
Definition: MethodBase.cxx:650
Float_t
float Float_t
Definition: RtypesCore.h:57
VariableInfo.h
TString::Contains
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:624
TString::Length
Ssiz_t Length() const
Definition: TString.h:410
TMVA::DataSetInfo::GetSpectatorInfos
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:122
TMVA::MethodBase::GetRegressionValues
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition: MethodBase.h:214
TDirectory::TContext
TDirectory::TContext keeps track and restore the current directory.
Definition: TDirectory.h:89
MethodBase.h
TMVA::DataSetInfo::SetSplitOptions
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:185
TString
Basic string class.
Definition: TString.h:136
TMVA::VariableInfo::GetVarType
char GetVarType() const
Definition: VariableInfo.h:61
TMVA::MethodCategory
Class for categorizing the phase space.
Definition: MethodCategory.h:58
TMVA::MethodCategory::PassesCut
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
Definition: MethodCategory.cxx:569
REGISTER_METHOD
#define REGISTER_METHOD(CLASS)
for example
Definition: ClassifierFactory.h:124
TMVA::MethodCompositeBase
Virtual base class for combining several TMVA method.
Definition: MethodCompositeBase.h:50
bool
TMVA::MethodCategory::DeclareOptions
void DeclareOptions()
options for this method
Definition: MethodCategory.cxx:131
TString::Last
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:913
PDF.h
TMVA::MethodCategory::AddWeightsXMLTo
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
Definition: MethodCategory.cxx:444
TMVA::MethodBase::DataInfo
DataSetInfo & DataInfo() const
Definition: MethodBase.h:410
TMVA::DataSetInfo::SetWeightExpression
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
Definition: DataSetInfo.cxx:333
TMVA::ClassifierFactory::Instance
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
Definition: ClassifierFactory.cxx:48
TMVA::DataSetInfo::GetNClasses
UInt_t GetNClasses() const
Definition: DataSetInfo.h:155
TMVA::MethodCategory::~MethodCategory
virtual ~MethodCategory(void)
destructor
Definition: MethodCategory.cxx:104
TMVA::Tools::AddAttr
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
TSpline.h
TMVA::DataSetInfo
Class that contains all the data information.
Definition: DataSetInfo.h:62
TMVA::DataSetInfo::GetCut
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:168
TMVA::MethodCategory::Train
void Train(void)
train all sub-classifiers
Definition: MethodCategory.cxx:367
TMVA::DataSetInfo::GetTargetInfos
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:114
TMVA::DataSetInfo::AddCut
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
Definition: DataSetInfo.cxx:376
TMVA::MethodCategory::GetMvaValues
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
returns the mva values of the right sub-classifier
Definition: MethodCategory.cxx:643
VariableRearrangeTransform.h
TMVA::MethodBase::GetMethodName
const TString & GetMethodName() const
Definition: MethodBase.h:331
MsgLogger.h
Timer.h
TMVA::MethodBase::MethodCategory
friend class MethodCategory
Definition: MethodBase.h:269
TMVA::DataSetInfo::GetNormalization
const TString & GetNormalization() const
Definition: DataSetInfo.h:131
TMVA::Types::EAnalysisType
EAnalysisType
Definition: Types.h:128
size
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
TMVA::MethodCategory::CreateRanking
const Ranking * CreateRanking()
no ranking
Definition: MethodCategory.cxx:562
TMVA::variables
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
TMVA::MethodBase::CheckSetup
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:433
TMVA::DataSetInfo::GetWeightExpression
const TString GetWeightExpression(Int_t i) const
Definition: DataSetInfo.h:164
TMVA::DataSetInfo::GetClassInfo
ClassInfo * GetClassInfo(Int_t clNum) const
Definition: DataSetInfo.cxx:146
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:101
TMVA::MethodBase::GetMulticlassValues
virtual const std::vector< Float_t > & GetMulticlassValues()
Definition: MethodBase.h:227
TMVA::VariableInfo
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
TMVA::MethodBase::SetMethodBaseDir
void SetMethodBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:374
TMVA::MethodCategory::InitCircularTree
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
Definition: MethodCategory.cxx:316
TMVA::Tools::ReadAttr
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
TMVA::MethodCategory::Init
void Init()
initialize the method
Definition: MethodCategory.cxx:309
TMVA::MethodCompositeBase::GetMvaValue
virtual Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)=0
TTreeFormula
Used to pass a selection expression to the Tree drawing routine.
Definition: TTreeFormula.h:58
TRandom3.h
TMVA::MethodCategory::GetMulticlassValues
virtual const std::vector< Float_t > & GetMulticlassValues()
returns the mva values of the multi-class right sub-classifier
Definition: MethodCategory.cxx:715
TDirectory::GetDirectory
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:407
TMVA::MethodBase::SetWeightFileDir
void SetWeightFileDir(TString fileDir)
set directory of weight file
Definition: MethodBase.cxx:2059
TMVA::MethodBase::GetMethodTypeName
TString GetMethodTypeName() const
Definition: MethodBase.h:332
TMVA::MethodBase
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
TMVA::Types
Singleton class for Global types used by TMVA.
Definition: Types.h:73
TMVA::MethodBase::DisableWriting
void DisableWriting(Bool_t setter)
Definition: MethodBase.h:442
TMVA::DataSetInfo::GetRootDir
TDirectory * GetRootDir() const
Definition: DataSetInfo.h:190
Types.h
TMVA::DataSetInfo::GetSplitOptions
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:186
TMVA::MethodBase::WriteStateToXML
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
Definition: MethodBase.cxx:1331
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Config.h
unsigned int
TMVA::TMVAGlob::GetMethodName
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:342
TMVA::IMethod
Interface for all concrete MVA method implementations.
Definition: IMethod.h:53
TMVA::Tools::Color
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
TString::Index
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:639
TMVA::DataSetInfo::AddVariable
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
Definition: DataSetInfo.cxx:207
TMVA::DataSetInfo::GetNTargets
UInt_t GetNTargets() const
Definition: DataSetInfo.h:128
TMVA::Event::GetSpectator
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:261
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::DataSetInfo::SetRootDir
void SetRootDir(TDirectory *d)
Definition: DataSetInfo.h:189
TMVA::DataSetInfo::GetVariableInfos
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:103
TMVA::MethodBase::SetAnalysisType
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:436
TMVA::MethodBase::GetRegressionValues
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodBase.h:221
TMVA::Types::Instance
static Types & Instance()
the the single instance of "Types" if existing already, or create it (Singleton)
Definition: Types.cxx:69
TMVA::DataSetInfo::AddClass
ClassInfo * AddClass(const TString &className)
Definition: DataSetInfo.cxx:113
TMVA::IMethod::HasAnalysisType
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
Info
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
Definition: TError.cxx:220
TMVA::kFATAL
@ kFATAL
Definition: Types.h:63
TMVA::Tools::GetNextChild
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
TMVA::Types::EMVA
EMVA
Definition: Types.h:78
TMVA::MethodBase::GetName
const char * GetName() const
Definition: MethodBase.h:334
TMVA::MethodCategory::CreateCategoryDSI
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
Definition: MethodCategory.cxx:200
TMVA::Event
Definition: Event.h:51
TMVA::Event::SetVariableArrangement
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition: Event.cxx:191
TMVA::MethodCategory::ReadWeightsFromXML
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
Definition: MethodCategory.cxx:465
TDirectory
Describe directory structure in memory.
Definition: TDirectory.h:45
TMVA::VariableInfo::GetExpression
const TString & GetExpression() const
Definition: VariableInfo.h:57
MethodCompositeBase.h
TMVA::MethodBase::SetFile
void SetFile(TFile *file)
Definition: MethodBase.h:375
TMVA::VariableInfo::GetExternalLink
void * GetExternalLink() const
Definition: VariableInfo.h:83
TMVA::kINFO
@ kINFO
Definition: Types.h:60
Tools.h
TMVA::DataSet::GetNTrainingEvents
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:68
TNamed::GetName
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
ClassifierFactory.h
type
int type
Definition: TGX11.cxx:121
TMVA::DataSetInfo::AddTarget
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
Definition: DataSetInfo.cxx:259
TMVA::gTools
Tools & gTools()
TMVA::MethodCategory::GetHelpMessage
void GetHelpMessage() const
Get help message text.
Definition: MethodCategory.cxx:547
MethodCategory.h
DataSet.h
TMVA::Configurable::ParseOptions
virtual void ParseOptions()
options parser
Definition: Configurable.cxx:124
TMVA::ClassifierFactory::Create
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
Definition: ClassifierFactory.cxx:89
TMVA::kWARNING
@ kWARNING
Definition: Types.h:61
TMVA::MethodBase::ProcessSetup
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:423
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
int