Logo ROOT   6.10/09
Reference Guide
MethodCategory.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodCategory *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Class for categorizing the phase space *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16  * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21  * *
22  * Copyright (c) 2005-2011: *
23  * CERN, Switzerland *
24  * MSU East Lansing, USA *
25  * MPI-K Heidelberg, Germany *
26  * U. of Bonn, Germany *
27  * *
28  * Redistribution and use in source and binary forms, with or without *
29  * modification, are permitted according to the terms listed in LICENSE *
30  * (http://tmva.sourceforge.net/LICENSE) *
31  **********************************************************************************/
32 
33 /*! \class TMVA::MethodCategory
34 \ingroup TMVA
35 
36 Class for categorizing the phase space
37 
38 This class is meant to allow categorisation of the data. For different
39 categories, different classifiers may be booked and different variables
40 may be considered. The aim is to account for the difference that
41 is due to different locations/angles.
42 */
43 
44 
45 #include "TMVA/MethodCategory.h"
46 
47 #include <algorithm>
48 #include <iomanip>
49 #include <vector>
50 #include <iostream>
51 
52 #include "Riostream.h"
53 #include "TRandom3.h"
54 #include "TMath.h"
55 #include "TObjString.h"
56 #include "TH1F.h"
57 #include "TGraph.h"
58 #include "TSpline.h"
59 #include "TDirectory.h"
60 #include "TTreeFormula.h"
61 
62 #include "TMVA/ClassifierFactory.h"
63 #include "TMVA/Config.h"
64 #include "TMVA/DataSet.h"
65 #include "TMVA/DataSetInfo.h"
66 #include "TMVA/DataSetManager.h"
67 #include "TMVA/IMethod.h"
68 #include "TMVA/MethodBase.h"
70 #include "TMVA/MsgLogger.h"
71 #include "TMVA/PDF.h"
72 #include "TMVA/Ranking.h"
73 #include "TMVA/Timer.h"
74 #include "TMVA/Tools.h"
75 #include "TMVA/Types.h"
76 #include "TMVA/VariableInfo.h"
78 
79 REGISTER_METHOD(Category)
80 
82 
83 ////////////////////////////////////////////////////////////////////////////////
84 /// standard constructor
85 
87  const TString& methodTitle,
88  DataSetInfo& theData,
89  const TString& theOption )
90  : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
91  fCatTree(0),
92  fDataSetManager(NULL)
93 {
94 }
95 
96 ////////////////////////////////////////////////////////////////////////////////
97 /// constructor from weight file
98 
100  const TString& theWeightFile)
101  : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
102  fCatTree(0),
103  fDataSetManager(NULL)
104 {
105 }
106 
107 ////////////////////////////////////////////////////////////////////////////////
108 /// destructor
109 
111 {
112  std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
113  std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
114  for(;formIt!=lastF; ++formIt) delete *formIt;
115  delete fCatTree;
116 }
117 
118 ////////////////////////////////////////////////////////////////////////////////
119 /// check whether method category has analysis type
120 /// the method type has to be the same for all sub-methods
121 
123 {
124  std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
125 
126  // iterate over methods and check whether they have the analysis type
127  for(; itrMethod != fMethods.end(); ++itrMethod ) {
128  if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
129  return kFALSE;
130  }
131  return kTRUE;
132 }
133 
134 ////////////////////////////////////////////////////////////////////////////////
135 /// options for this method
136 
138 {
139 }
140 
141 ////////////////////////////////////////////////////////////////////////////////
142 /// adds sub-classifier for a category
143 
145  const TString& theVariables,
146  Types::EMVA theMethod ,
147  const TString& theTitle,
148  const TString& theOptions )
149 {
150  std::string addedMethodName = std::string(Types::Instance().GetMethodName(theMethod));
151 
152  Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
153 
154  DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
155 
156  IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
157 
158  MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
159  if(method==0) return 0;
160 
163  method->SetAnalysisType( fAnalysisType );
164  method->SetupMethod();
165  method->ParseOptions();
166  method->ProcessSetup();
167  method->SetFile(fFile);
168  method->SetSilentFile(IsSilentFile());
169 
170 
171  // set or create correct method base dir for added method
172  const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
173  TDirectory * dir = BaseDir()->GetDirectory(dirName);
174  if (dir != 0) method->SetMethodBaseDir( dir );
175  else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data())) );
176 
177  // method->SetBaseDir(own base dir; check whether the Fisher dir exists, otherwise create it)
178 
179  // check-for-unused-options is performed; may be overridden by derived
180  // classes
181  method->CheckSetup();
182 
183  // disable writing of XML files and standalone classes for sub methods
184  method->DisableWriting( kTRUE );
185 
186  // store method, cut and variable names and create cut formula
187  fMethods.push_back(method);
188  fCategoryCuts.push_back(theCut);
189  fVars.push_back(theVariables);
190 
191  DataSetInfo& primaryDSI = DataInfo();
192 
193  UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
194  fCategorySpecIdx.push_back(newSpectatorIndex);
195 
196  primaryDSI.AddSpectator( Form("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()),
197  Form("%s:%s",GetName(),method->GetName()),
198  "pass", 0, 0, 'C' );
199 
200  return method;
201 }
202 
203 ////////////////////////////////////////////////////////////////////////////////
204 /// create a DataSetInfo object for a sub-classifier
205 
207  const TString& theVariables,
208  const TString& theTitle)
209 {
210  // create a new dsi with name: theTitle+"_dsi"
211  TString dsiName=theTitle+"_dsi";
212  DataSetInfo& oldDSI = DataInfo();
213  DataSetInfo* dsi = new DataSetInfo(dsiName);
214 
215  // register the new dsi
216  // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
218 
219  // copy the targets and spectators from the old dsi to the new dsi
220  std::vector<VariableInfo>::iterator itrVarInfo;
221 
222  for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); itrVarInfo++)
223  dsi->AddTarget(*itrVarInfo);
224 
225  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); itrVarInfo++)
226  dsi->AddSpectator(*itrVarInfo);
227 
228  // split string that contains the variables into tiny little pieces
229  std::vector<TString> variables = gTools().SplitString(theVariables,':' );
230 
231  // prepare to create varMap
232  std::vector<UInt_t> varMap;
233  Int_t counter=0;
234 
235  // add the variables that were specified in theVariables
236  std::vector<TString>::iterator itrVariables;
237  Bool_t found = kFALSE;
238 
239  // iterate over all variables in 'variables' and add them
240  for (itrVariables = variables.begin(); itrVariables != variables.end(); itrVariables++) {
241  counter=0;
242 
243  // check the variables of the old dsi for the variable that we want to add
244  for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); itrVarInfo++) {
245  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
246  // don't compare the expression, since the user might take two times the same expression, but with different labels
247  // and apply different transformations to the variables.
248  dsi->AddVariable(*itrVarInfo);
249  varMap.push_back(counter);
250  found = kTRUE;
251  }
252  counter++;
253  }
254 
255  // check the spectators of the old dsi for the variable that we want to add
256  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); itrVarInfo++) {
257  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
258  // don't compare the expression, since the user might take two times the same expression, but with different labels
259  // and apply different transformations to the variables.
260  dsi->AddVariable(*itrVarInfo);
261  varMap.push_back(counter);
262  found = kTRUE;
263  }
264  counter++;
265  }
266 
267  // if the variable is neither in the variables nor in the spectators, we abort
268  if (!found) {
269  Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
270  }
271  found = kFALSE;
272  }
273 
274  // in the case that no variables are specified, add the default-variables from the original dsi
275  if (theVariables=="") {
276  for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
277  dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
278  varMap.push_back(i);
279  }
280  }
281 
282  // add the variable map 'varMap' to the vector of varMaps
283  fVarMaps.push_back(varMap);
284 
285  // set classes and cuts
286  UInt_t nClasses=oldDSI.GetNClasses();
287  TString className;
288 
289  for (UInt_t i=0; i<nClasses; i++) {
290  className = oldDSI.GetClassInfo(i)->GetName();
291  dsi->AddClass(className);
292  dsi->SetCut(oldDSI.GetCut(i),className);
293  dsi->AddCut(theCut,className);
294  dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
295  }
296 
297  // set split options, root dir and normalization for the new dsi
298  dsi->SetSplitOptions(oldDSI.GetSplitOptions());
299  dsi->SetRootDir(oldDSI.GetRootDir());
300  TString norm(oldDSI.GetNormalization().Data());
301  dsi->SetNormalization(norm);
302 
303  DataSetInfo& dsiReference= (*dsi);
304 
305  return dsiReference;
306 }
307 
308 ////////////////////////////////////////////////////////////////////////////////
309 /// initialize the method
310 
312 {
313 }
314 
315 ////////////////////////////////////////////////////////////////////////////////
316 /// initialize the circular tree
317 
319 {
320  delete fCatTree;
321 
322  std::vector<VariableInfo>::const_iterator viIt;
323  const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
324  const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
325 
326  Bool_t hasAllExternalLinks = kTRUE;
327  for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
328  if( viIt->GetExternalLink() == 0 ) {
329  hasAllExternalLinks = kFALSE;
330  break;
331  }
332  for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
333  if( viIt->GetExternalLink() == 0 ) {
334  hasAllExternalLinks = kFALSE;
335  break;
336  }
337 
338  if(!hasAllExternalLinks) return;
339 
340  {
341  // Rather than having TTree::TTree add to the current directory and then remove it, let's
342  // make sure not to add it in the first place.
343  // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
344  // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
345  TDirectory::TContext ctxt(nullptr);
346  fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circular Tree for categorization");
347  fCatTree->SetCircular(1);
348  }
349 
350  for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
351  const VariableInfo& vi = *viIt;
353  }
354  for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
355  const VariableInfo& vi = *viIt;
356  if(vi.GetVarType()=='C') continue;
358  }
359 
360  for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
361  fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
362  }
363 }
364 
365 ////////////////////////////////////////////////////////////////////////////////
366 /// train all sub-classifiers
367 
369 {
370  // specify the minimum # of training events and set 'classification'
371  const Int_t MinNoTrainingEvents = 10;
372 
373  Types::EAnalysisType analysisType = GetAnalysisType();
374 
375  // start the training
376  Log() << kINFO << "Train all sub-classifiers for "
377  << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
378 
379  // don't do anything if no sub-classifier booked
380  if (fMethods.empty()) {
381  Log() << kINFO << "...nothing found to train" << Endl;
382  return;
383  }
384 
385  std::vector<IMethod*>::iterator itrMethod;
386 
387  // iterate over all booked sub-classifiers and train them
388  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
389 
390  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
391  if(!mva) continue;
392  mva->SetAnalysisType( analysisType );
393  if (!mva->HasAnalysisType( analysisType,
394  mva->DataInfo().GetNClasses(),
395  mva->DataInfo().GetNTargets() ) ) {
396  Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
397  if (analysisType == Types::kRegression)
398  Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
399  else
400  Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
401  itrMethod = fMethods.erase( itrMethod );
402  continue;
403  }
404  if (mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
405 
406  Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
407  << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
408  mva->TrainMethod();
409  Log() << kINFO << "Training finished" << Endl;
410 
411  } else {
412 
413  Log() << kWARNING << "Method " << mva->GetMethodName()
414  << " not trained (training tree has less entries ["
415  << mva->Data()->GetNTrainingEvents()
416  << "] than required [" << MinNoTrainingEvents << "]" << Endl;
417 
418  Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
419  Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
420 
421  }
422  }
423 
424  if (analysisType != Types::kRegression) {
425 
426  // variable ranking
427  Log() << kINFO << "Begin ranking of input variables..." << Endl;
428  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); itrMethod++) {
429  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
430  if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
431  const Ranking* ranking = (*itrMethod)->CreateRanking();
432  if (ranking != 0)
433  ranking->Print();
434  else
435  Log() << kINFO << "No variable ranking supplied by classifier: "
436  << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
437  }
438  }
439  }
440 }
441 
442 ////////////////////////////////////////////////////////////////////////////////
443 /// create XML description of Category classifier
444 
445 void TMVA::MethodCategory::AddWeightsXMLTo( void* parent ) const
446 {
447  void* wght = gTools().AddChild(parent, "Weights");
448  gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
449  void* submethod(0);
450 
451  std::vector<IMethod*>::iterator itrMethod;
452 
453  // iterate over methods and write them to XML file
454  for (UInt_t i=0; i<fMethods.size(); i++) {
455  MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
456  submethod = gTools().AddChild(wght, "SubMethod");
457  gTools().AddAttr(submethod, "Index", i);
458  gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
459  gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
460  gTools().AddAttr(submethod, "Variables", fVars[i]);
461  method->WriteStateToXML( submethod );
462  }
463 }
464 
465 ////////////////////////////////////////////////////////////////////////////////
466 /// read weights of sub-classifiers of MethodCategory from xml weight file
467 
469 {
470  UInt_t nSubMethods;
471  TString fullMethodName;
472  TString methodType;
473  TString methodTitle;
474  TString theCutString;
475  TString theVariables;
476  Int_t titleLength;
477  gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
478  void* subMethodNode = gTools().GetChild(wghtnode);
479 
480  Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
481 
482  // recreate all sub-methods from weight file
483  for (UInt_t i=0; i<nSubMethods; i++) {
484  gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
485  gTools().ReadAttr( subMethodNode, "Cut", theCutString );
486  gTools().ReadAttr( subMethodNode, "Variables", theVariables );
487 
488  // determine sub-method type
489  methodType = fullMethodName(0,fullMethodName.Index("::"));
490  if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
491 
492  // determine sub-method title
493  titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
494  methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
495 
496  // reconstruct dsi for sub-method
497  DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
498 
499  // recreate sub-method from weights and add to fMethods
500  MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
501  dsi, "none" ) );
502  if(method==0)
503  Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
504 
505  method->SetupMethod();
506  method->ReadStateFromXML(subMethodNode);
507 
508  fMethods.push_back(method);
509  fCategoryCuts.push_back(TCut(theCutString));
510  fVars.push_back(theVariables);
511 
512  DataSetInfo& primaryDSI = DataInfo();
513 
514  UInt_t spectatorIdx = 10000;
515  UInt_t counter=0;
516 
517  // find the spectator index
518  std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
519  std::vector<VariableInfo>::iterator itrVarInfo;
520  TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
521 
522  for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
523  if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
524  spectatorIdx=counter;
525  fCategorySpecIdx.push_back(spectatorIdx);
526  break;
527  }
528  }
529 
530  subMethodNode = gTools().GetNextChild(subMethodNode);
531  }
532 
534 
535 }
536 
537 ////////////////////////////////////////////////////////////////////////////////
538 /// process user options
539 
541 {
542 }
543 
544 ////////////////////////////////////////////////////////////////////////////////
545 /// Get help message text
546 ///
547 /// typical length of text line:
548 /// "|--------------------------------------------------------------|"
549 
551 {
552  Log() << Endl;
553  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
554  Log() << Endl;
555  Log() << "This method allows to define different categories of events. The" <<Endl;
556  Log() << "categories are defined via cuts on the variables. For each" << Endl;
557  Log() << "category, a different classifier and set of variables can be" <<Endl;
558  Log() << "specified. The categories which are defined for this method must" << Endl;
559  Log() << "be disjoint." << Endl;
560 }
561 
562 ////////////////////////////////////////////////////////////////////////////////
563 /// no ranking
564 
566 {
567  return 0;
568 }
569 
570 ////////////////////////////////////////////////////////////////////////////////
571 
573 {
574  // If the categories are not defined by a simple 'spectator' variable (0 or 1) but rather by
575  // some 'formula' (e.g. eta>0), then the formulas built on fCatTree are used and the one for
576  // this method index is evaluated (the formulae return 'true' or 'false').
577  if (fCatTree) {
578  if (methodIdx>=fCatFormulas.size()) {
579  Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
580  << fCatFormulas.size() << Endl;
581  }
582  TTreeFormula* f = fCatFormulas[methodIdx];
583  return f->EvalInstance(0) > 0.5;
584  }
585  // otherwise, simply check whether the spectator variable is "true" (greater than 0.5, to be "sure")
586  else {
587 
588  // checks whether an event lies within a cut
589  if (methodIdx>=fCategorySpecIdx.size()) {
590  Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
591  << fCategorySpecIdx.size() << Endl;
592  }
593  UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
594  Float_t specVal = ev->GetSpectator(spectatorIdx);
595  Bool_t pass = (specVal>0.5);
596  return pass;
597  }
598 }
599 
600 ////////////////////////////////////////////////////////////////////////////////
601 /// returns the mva value of the right sub-classifier
602 
604 {
605  if (fMethods.empty()) return 0;
606 
607  UInt_t methodToUse = 0;
608  const Event* ev = GetEvent();
609 
610  // determine which sub-classifier to use for this event
611  Int_t suitableCutsN = 0;
612 
613  for (UInt_t i=0; i<fMethods.size(); ++i) {
614  if (PassesCut(ev, i)) {
615  ++suitableCutsN;
616  methodToUse=i;
617  }
618  }
619 
620  if (suitableCutsN == 0) {
621  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
622  return 0;
623  }
624 
625  if (suitableCutsN > 1) {
626  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
627  return 0;
628  }
629 
630  // get mva value from the suitable sub-classifier
631  ev->SetVariableArrangement(&fVarMaps[methodToUse]);
632  Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
633  ev->SetVariableArrangement(0);
634 
635  return mvaValue;
636 }
637 
638 
639 
640 ////////////////////////////////////////////////////////////////////////////////
641 /// returns the regression values of the right sub-classifier
642 
643 const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
644 {
645  if (fMethods.empty()) return MethodBase::GetRegressionValues();
646 
647  UInt_t methodToUse = 0;
648  const Event* ev = GetEvent();
649 
650  // determine which sub-classifier to use for this event
651  Int_t suitableCutsN = 0;
652 
653  for (UInt_t i=0; i<fMethods.size(); ++i) {
654  if (PassesCut(ev, i)) {
655  ++suitableCutsN;
656  methodToUse=i;
657  }
658  }
659 
660  if (suitableCutsN == 0) {
661  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
663  }
664 
665  if (suitableCutsN > 1) {
666  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
668  }
669  MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
670  if (!meth){
671  Log() << kFATAL << "method not found in Category Regression method" << Endl;
673  }
674  // get mva value from the suitable sub-classifier
675  return meth->GetRegressionValues(ev);
676 }
677 
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
Types::EAnalysisType fAnalysisType
Definition: MethodBase.h:577
std::vector< IMethod * > fMethods
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:366
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
void Init()
initialize the method
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables, variable transformation type etc.
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Singleton class for Global types used by TMVA.
Definition: Types.h:73
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
void ReadStateFromXML(void *parent)
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:104
float Float_t
Definition: RtypesCore.h:53
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
static Types & Instance()
the single instance of "Types" if existing already, or create it (Singleton)
Definition: Types.cxx:70
void SetMethodBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:358
MsgLogger & Log() const
Definition: Configurable.h:122
EAnalysisType
Definition: Types.h:125
std::vector< TCut > fCategoryCuts
Virtual base Class for all MVA method.
Definition: MethodBase.h:106
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition: MethodBase.h:204
std::vector< UInt_t > fCategorySpecIdx
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:587
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodBase.h:211
Basic string class.
Definition: TString.h:129
Ranking for variables in method (implementation)
Definition: Ranking.h:48
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
UInt_t GetNClasses() const
Definition: DataSetInfo.h:136
virtual ~MethodCategory(void)
destructor
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
#define NULL
Definition: RtypesCore.h:88
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:362
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:308
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1135
const TString & GetExpression() const
Definition: VariableInfo.h:57
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
const TString & GetNormalization() const
Definition: DataSetInfo.h:114
char GetVarType() const
Definition: VariableInfo.h:61
std::vector< std::vector< UInt_t > > fVarMaps
void DeclareOptions()
options for this method
const Event * GetEvent() const
Definition: MethodBase.h:733
DataSet * Data() const
Definition: MethodBase.h:393
Virtual base class for combining several TMVA method.
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1161
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition: Event.cxx:192
virtual void ParseOptions()
options parser
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:411
DataSetInfo & DataInfo() const
Definition: MethodBase.h:394
Class that contains all the data information.
Definition: DataSetInfo.h:60
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:79
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:99
Bool_t fModelPersistence
Definition: MethodBase.h:615
TDirectory * GetRootDir() const
Definition: DataSetInfo.h:171
Used to pass a selection expression to the Tree drawing routine.
Definition: TTreeFormula.h:58
A specialized string object used for TTree selections.
Definition: TCut.h:25
std::vector< TTreeFormula * > fCatFormulas
needed in conjunction with TTreeFormulas for evaluation category expressions
MethodCategory(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
standard constructor
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:99
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:166
const Ranking * CreateRanking()
no ranking
void * GetExternalLink() const
Definition: VariableInfo.h:81
UInt_t GetNTargets() const
Definition: DataSetInfo.h:111
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
const char * GetName() const
Definition: MethodBase.h:318
ClassInfo * GetClassInfo(Int_t clNum) const
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
DataSetInfo & AddDataSetInfo(DataSetInfo &dsi)
stores a copy of the dataset info object
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
DataSetManager * fDataSetManager
Ssiz_t Length() const
Definition: TString.h:388
const TString & GetJobName() const
Definition: MethodBase.h:314
const TString & GetMethodName() const
Definition: MethodBase.h:315
void Train(void)
train all sub-classifiers
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:290
Tools & gTools()
Bool_t IsSilentFile()
Definition: MethodBase.h:363
const Bool_t kFALSE
Definition: RtypesCore.h:92
Class for categorizing the phase space.
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:438
#define ClassImp(name)
Definition: Rtypes.h:336
double f(double x)
double Double_t
Definition: RtypesCore.h:55
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:875
Describe directory structure in memory.
Definition: TDirectory.h:34
int type
Definition: TGX11.cxx:120
void SetFile(TFile *file)
Definition: MethodBase.h:359
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1173
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:572
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
T EvalInstance(Int_t i=0, const char *stringStack[]=0)
Evaluate this treeformula.
ClassInfo * AddClass(const TString &className)
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:428
std::vector< TString > fVars
void GetHelpMessage() const
Get help message text.
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:839
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns the mva value of the right sub-classifier
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1660
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:167
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
TString GetMethodTypeName() const
Definition: MethodBase.h:316
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:149
void SetWeightFileDir(TString fileDir)
set directory of weight file
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:338
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings ...
Definition: Tools.cxx:1210
virtual void SetCircular(Long64_t maxEntries)
Enable/Disable circularity for this tree.
Definition: TTree.cxx:8365
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
void DisableWriting(Bool_t setter)
Definition: MethodBase.h:426
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:421
A TTree object has a header with a name and a title.
Definition: TTree.h:78
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
TString fFileDir
Definition: MethodBase.h:619
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:262
const TString GetWeightExpression(Int_t i) const
Definition: DataSetInfo.h:145
const Bool_t kTRUE
Definition: RtypesCore.h:91
double norm(double *x, double *p)
Definition: unuranDistr.cxx:40
void SetNormalization(const TString &norm)
Definition: DataSetInfo.h:115
void SetRootDir(TDirectory *d)
Definition: DataSetInfo.h:170
virtual const std::vector< Float_t > & GetRegressionValues()
returns the regression values of the right sub-classifier
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:94
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:420
virtual const char * GetTitle() const
Returns title of object.
Definition: TNamed.h:48
void ProcessOptions()
process user options
const char * Data() const
Definition: TString.h:347