Logo ROOT   6.08/07
Reference Guide
MethodCategory.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodCategory *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * MVA method that books separate sub-classifiers for separate data categories *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16  * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21  * *
22  * Copyright (c) 2005-2011: *
23  * CERN, Switzerland *
24  * MSU East Lansing, USA *
25  * MPI-K Heidelberg, Germany *
26  * U. of Bonn, Germany *
27  * *
28  * Redistribution and use in source and binary forms, with or without *
29  * modification, are permitted according to the terms listed in LICENSE *
30  * (http://tmva.sourceforge.net/LICENSE) *
31  **********************************************************************************/
32 
33 //__________________________________________________________________________
34 //
35 // This class is meant to allow categorisation of the data. For different //
36 // categories, different classifiers may be booked and different //
37 // variables may be considered. The aim is to account for differences //
38 // that are due to different locations/angles. //
39 ////////////////////////////////////////////////////////////////////////////////
40 
41 #include "TMVA/MethodCategory.h"
42 
43 #include <algorithm>
44 #include <iomanip>
45 #include <vector>
46 #include <iostream>
47 
48 #include "Riostream.h"
49 #include "TRandom3.h"
50 #include "TMath.h"
51 #include "TObjString.h"
52 #include "TH1F.h"
53 #include "TGraph.h"
54 #include "TSpline.h"
55 #include "TDirectory.h"
56 #include "TTreeFormula.h"
57 
58 #include "TMVA/ClassifierFactory.h"
59 #include "TMVA/Config.h"
60 #include "TMVA/DataSet.h"
61 #include "TMVA/DataSetInfo.h"
62 #include "TMVA/DataSetManager.h"
63 #include "TMVA/IMethod.h"
64 #include "TMVA/MethodBase.h"
66 #include "TMVA/MsgLogger.h"
67 #include "TMVA/PDF.h"
68 #include "TMVA/Ranking.h"
69 #include "TMVA/Timer.h"
70 #include "TMVA/Tools.h"
71 #include "TMVA/Types.h"
72 #include "TMVA/VariableInfo.h"
74 
75 REGISTER_METHOD(Category)
76 
78 
79 ////////////////////////////////////////////////////////////////////////////////
80 /// standard constructor
81 
83  const TString& methodTitle,
84  DataSetInfo& theData,
85  const TString& theOption )
86  : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
87  fCatTree(0),
88  fDataSetManager(NULL)
89 {
90 }
91 
92 ////////////////////////////////////////////////////////////////////////////////
93 /// constructor from weight file
94 
96  const TString& theWeightFile)
97  : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
98  fCatTree(0),
99  fDataSetManager(NULL)
100 {
101 }
102 
103 ////////////////////////////////////////////////////////////////////////////////
104 /// destructor
105 
107 {
108  std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
109  std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
110  for(;formIt!=lastF; ++formIt) delete *formIt;
111  delete fCatTree;
112 }
113 
114 ////////////////////////////////////////////////////////////////////////////////
115 /// check whether method category has analysis type
116 /// the method type has to be the same for all sub-methods
117 
119 {
120  std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
121 
122  // iterate over methods and check whether they have the analysis type
123  for(; itrMethod != fMethods.end(); ++itrMethod ) {
124  if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
125  return kFALSE;
126  }
127  return kTRUE;
128 }
129 
130 ////////////////////////////////////////////////////////////////////////////////
131 /// options for this method
132 
134 {
135 }
136 
137 ////////////////////////////////////////////////////////////////////////////////
138 /// adds sub-classifier for a category
139 
141  const TString& theVariables,
142  Types::EMVA theMethod ,
143  const TString& theTitle,
144  const TString& theOptions )
145 {
146  std::string addedMethodName = std::string(Types::Instance().GetMethodName(theMethod));
147 
148  Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
149 
150  DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
151 
152  IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
153 
154  MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
155  if(method==0) return 0;
156 
159  method->SetAnalysisType( fAnalysisType );
160  method->SetupMethod();
161  method->ParseOptions();
162  method->ProcessSetup();
163  method->SetFile(fFile);
164  method->SetSilentFile(IsSilentFile());
165 
166 
167  // set or create correct method base dir for added method
168  const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
169  TDirectory * dir = BaseDir()->GetDirectory(dirName);
170  if (dir != 0) method->SetMethodBaseDir( dir );
171  else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data())) );
172 
173  // method->SetBaseDir(eigenes base dir, gucken ob Fisher dir existiert, sonst erzeugen )
174 
175  // check-for-unused-options is performed; may be overridden by derived
176  // classes
177  method->CheckSetup();
178 
179  // disable writing of XML files and standalone classes for sub methods
180  method->DisableWriting( kTRUE );
181 
182  // store method, cut and variable names and create cut formula
183  fMethods.push_back(method);
184  fCategoryCuts.push_back(theCut);
185  fVars.push_back(theVariables);
186 
187  DataSetInfo& primaryDSI = DataInfo();
188 
189  UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
190  fCategorySpecIdx.push_back(newSpectatorIndex);
191 
192  primaryDSI.AddSpectator( Form("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()),
193  Form("%s:%s",GetName(),method->GetName()),
194  "pass", 0, 0, 'C' );
195 
196  return method;
197 }
198 
199 ////////////////////////////////////////////////////////////////////////////////
200 /// create a DataSetInfo object for a sub-classifier
201 
203  const TString& theVariables,
204  const TString& theTitle)
205 {
206  // create a new dsi with name: theTitle+"_dsi"
207  TString dsiName=theTitle+"_dsi";
208  DataSetInfo& oldDSI = DataInfo();
209  DataSetInfo* dsi = new DataSetInfo(dsiName);
210 
211  // register the new dsi
212  // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
214 
215  // copy the targets and spectators from the old dsi to the new dsi
216  std::vector<VariableInfo>::iterator itrVarInfo;
217 
218  for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); itrVarInfo++)
219  dsi->AddTarget(*itrVarInfo);
220 
221  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); itrVarInfo++)
222  dsi->AddSpectator(*itrVarInfo);
223 
224  // split string that contains the variables into tiny little pieces
225  std::vector<TString> variables = gTools().SplitString(theVariables,':' );
226 
227  // prepare to create varMap
228  std::vector<UInt_t> varMap;
229  Int_t counter=0;
230 
231  // add the variables that were specified in theVariables
232  std::vector<TString>::iterator itrVariables;
233  Bool_t found = kFALSE;
234 
235  // iterate over all variables in 'variables' and add them
236  for (itrVariables = variables.begin(); itrVariables != variables.end(); itrVariables++) {
237  counter=0;
238 
239  // check the variables of the old dsi for the variable that we want to add
240  for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); itrVarInfo++) {
241  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
242  // don't compare the expression, since the user might take two times the same expression, but with different labels
243  // and apply different transformations to the variables.
244  dsi->AddVariable(*itrVarInfo);
245  varMap.push_back(counter);
246  found = kTRUE;
247  }
248  counter++;
249  }
250 
251  // check the spectators of the old dsi for the variable that we want to add
252  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); itrVarInfo++) {
253  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
254  // don't compare the expression, since the user might take two times the same expression, but with different labels
255  // and apply different transformations to the variables.
256  dsi->AddVariable(*itrVarInfo);
257  varMap.push_back(counter);
258  found = kTRUE;
259  }
260  counter++;
261  }
262 
263  // if the variable is neither in the variables nor in the spectators, we abort
264  if (!found) {
265  Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
266  }
267  found = kFALSE;
268  }
269 
270  // in the case that no variables are specified, add the default-variables from the original dsi
271  if (theVariables=="") {
272  for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
273  dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
274  varMap.push_back(i);
275  }
276  }
277 
278  // add the variable map 'varMap' to the vector of varMaps
279  fVarMaps.push_back(varMap);
280 
281  // set classes and cuts
282  UInt_t nClasses=oldDSI.GetNClasses();
283  TString className;
284 
285  for (UInt_t i=0; i<nClasses; i++) {
286  className = oldDSI.GetClassInfo(i)->GetName();
287  dsi->AddClass(className);
288  dsi->SetCut(oldDSI.GetCut(i),className);
289  dsi->AddCut(theCut,className);
290  dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
291  }
292 
293  // set split options, root dir and normalization for the new dsi
294  dsi->SetSplitOptions(oldDSI.GetSplitOptions());
295  dsi->SetRootDir(oldDSI.GetRootDir());
296  TString norm(oldDSI.GetNormalization().Data());
297  dsi->SetNormalization(norm);
298 
299  DataSetInfo& dsiReference= (*dsi);
300 
301  return dsiReference;
302 }
303 
304 ////////////////////////////////////////////////////////////////////////////////
305 /// initialize the method
306 
308 {
309 }
310 
311 ////////////////////////////////////////////////////////////////////////////////
312 /// initialize the circular tree
313 
315 {
316  delete fCatTree;
317 
318  std::vector<VariableInfo>::const_iterator viIt;
319  const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
320  const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
321 
322  Bool_t hasAllExternalLinks = kTRUE;
323  for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
324  if( viIt->GetExternalLink() == 0 ) {
325  hasAllExternalLinks = kFALSE;
326  break;
327  }
328  for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
329  if( viIt->GetExternalLink() == 0 ) {
330  hasAllExternalLinks = kFALSE;
331  break;
332  }
333 
334  if(!hasAllExternalLinks) return;
335 
336  {
337  // Rather than having TTree::TTree add to the current directory and then remove it, let
338  // make sure to not add it in the first place.
339  // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
340  // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
341  TDirectory::TContext ctxt(nullptr);
342  fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circlar Tree for categorization");
343  fCatTree->SetCircular(1);
344  }
345 
346  for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
347  const VariableInfo& vi = *viIt;
349  }
350  for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
351  const VariableInfo& vi = *viIt;
352  if(vi.GetVarType()=='C') continue;
354  }
355 
356  for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
357  fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
358  }
359 }
360 
361 ////////////////////////////////////////////////////////////////////////////////
362 /// train all sub-classifiers
363 
365 {
366  // specify the minimum # of training events and set 'classification'
367  const Int_t MinNoTrainingEvents = 10;
368 
369  Types::EAnalysisType analysisType = GetAnalysisType();
370 
371  // start the training
372  Log() << kINFO << "Train all sub-classifiers for "
373  << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
374 
375  // don't do anything if no sub-classifier booked
376  if (fMethods.empty()) {
377  Log() << kINFO << "...nothing found to train" << Endl;
378  return;
379  }
380 
381  std::vector<IMethod*>::iterator itrMethod;
382 
383  // iterate over all booked sub-classifiers and train them
384  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
385 
386  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
387  if(!mva) continue;
388  mva->SetAnalysisType( analysisType );
389  if (!mva->HasAnalysisType( analysisType,
390  mva->DataInfo().GetNClasses(),
391  mva->DataInfo().GetNTargets() ) ) {
392  Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
393  if (analysisType == Types::kRegression)
394  Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
395  else
396  Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
397  itrMethod = fMethods.erase( itrMethod );
398  continue;
399  }
400  if (mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
401 
402  Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
403  << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
404  mva->TrainMethod();
405  Log() << kINFO << "Training finished" << Endl;
406 
407  } else {
408 
409  Log() << kWARNING << "Method " << mva->GetMethodName()
410  << " not trained (training tree has less entries ["
411  << mva->Data()->GetNTrainingEvents()
412  << "] than required [" << MinNoTrainingEvents << "]" << Endl;
413 
414  Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
415  Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
416 
417  }
418  }
419 
420  if (analysisType != Types::kRegression) {
421 
422  // variable ranking
423  Log() << kINFO << "Begin ranking of input variables..." << Endl;
424  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); itrMethod++) {
425  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
426  if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
427  const Ranking* ranking = (*itrMethod)->CreateRanking();
428  if (ranking != 0)
429  ranking->Print();
430  else
431  Log() << kINFO << "No variable ranking supplied by classifier: "
432  << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
433  }
434  }
435  }
436 }
437 
438 ////////////////////////////////////////////////////////////////////////////////
439 /// create XML description of Category classifier
440 
441 void TMVA::MethodCategory::AddWeightsXMLTo( void* parent ) const
442 {
443  void* wght = gTools().AddChild(parent, "Weights");
444  gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
445  void* submethod(0);
446 
447  std::vector<IMethod*>::iterator itrMethod;
448 
449  // iterate over methods and write them to XML file
450  for (UInt_t i=0; i<fMethods.size(); i++) {
451  MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
452  submethod = gTools().AddChild(wght, "SubMethod");
453  gTools().AddAttr(submethod, "Index", i);
454  gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
455  gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
456  gTools().AddAttr(submethod, "Variables", fVars[i]);
457  method->WriteStateToXML( submethod );
458  }
459 }
460 
461 ////////////////////////////////////////////////////////////////////////////////
462 /// read weights of sub-classifiers of MethodCategory from xml weight file
463 
465 {
466  UInt_t nSubMethods;
467  TString fullMethodName;
468  TString methodType;
469  TString methodTitle;
470  TString theCutString;
471  TString theVariables;
472  Int_t titleLength;
473  gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
474  void* subMethodNode = gTools().GetChild(wghtnode);
475 
476  Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
477 
478  // recreate all sub-methods from weight file
479  for (UInt_t i=0; i<nSubMethods; i++) {
480  gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
481  gTools().ReadAttr( subMethodNode, "Cut", theCutString );
482  gTools().ReadAttr( subMethodNode, "Variables", theVariables );
483 
484  // determine sub-method type
485  methodType = fullMethodName(0,fullMethodName.Index("::"));
486  if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
487 
488  // determine sub-method title
489  titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
490  methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
491 
492  // reconstruct dsi for sub-method
493  DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
494 
495  // recreate sub-method from weights and add to fMethods
496  MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
497  dsi, "none" ) );
498  if(method==0)
499  Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
500 
501  method->SetupMethod();
502  method->ReadStateFromXML(subMethodNode);
503 
504  fMethods.push_back(method);
505  fCategoryCuts.push_back(TCut(theCutString));
506  fVars.push_back(theVariables);
507 
508  DataSetInfo& primaryDSI = DataInfo();
509 
510  UInt_t spectatorIdx = 10000;
511  UInt_t counter=0;
512 
513  // find the spectator index
514  std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
515  std::vector<VariableInfo>::iterator itrVarInfo;
516  TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
517 
518  for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
519  if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
520  spectatorIdx=counter;
521  fCategorySpecIdx.push_back(spectatorIdx);
522  break;
523  }
524  }
525 
526  subMethodNode = gTools().GetNextChild(subMethodNode);
527  }
528 
530 
531 }
532 
533 ////////////////////////////////////////////////////////////////////////////////
534 /// process user options
535 
537 {
538 }
539 
540 ////////////////////////////////////////////////////////////////////////////////
541 /// Get help message text
542 ///
543 /// typical length of text line:
544 /// "|--------------------------------------------------------------|"
545 
547 {
548  Log() << Endl;
549  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
550  Log() << Endl;
551  Log() << "This method allows to define different categories of events. The" <<Endl;
552  Log() << "categories are defined via cuts on the variables. For each" << Endl;
553  Log() << "category, a different classifier and set of variables can be" <<Endl;
554  Log() << "specified. The categories which are defined for this method must" << Endl;
555  Log() << "be disjoint." << Endl;
556 }
557 
558 ////////////////////////////////////////////////////////////////////////////////
559 /// no ranking
560 
562 {
563  return 0;
564 }
565 
566 ////////////////////////////////////////////////////////////////////////////////
567 
569 {
570  // if it's not a simple 'spectator' variable (0 or 1) that the categories are defined by
571  // (but rather some 'formula' (i.e. eta>0), then this formulas are stored in fCatTree and that
572  // one will be evaluated.. (the formulae return 'true' or 'false'
573  if (fCatTree) {
574  if (methodIdx>=fCatFormulas.size()) {
575  Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
576  << fCatFormulas.size() << Endl;
577  }
578  TTreeFormula* f = fCatFormulas[methodIdx];
579  return f->EvalInstance(0) > 0.5;
580  }
581  // otherwise, it simply looks if "variable == true" ("greater 0.5 to be "sure" )
582  else {
583 
584  // checks whether an event lies within a cut
585  if (methodIdx>=fCategorySpecIdx.size()) {
586  Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
587  << fCategorySpecIdx.size() << Endl;
588  }
589  UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
590  Float_t specVal = ev->GetSpectator(spectatorIdx);
591  Bool_t pass = (specVal>0.5);
592  return pass;
593  }
594 }
595 
596 ////////////////////////////////////////////////////////////////////////////////
597 /// returns the mva value of the right sub-classifier
598 
600 {
601  if (fMethods.empty()) return 0;
602 
603  UInt_t methodToUse = 0;
604  const Event* ev = GetEvent();
605 
606  // determine which sub-classifier to use for this event
607  Int_t suitableCutsN = 0;
608 
609  for (UInt_t i=0; i<fMethods.size(); ++i) {
610  if (PassesCut(ev, i)) {
611  ++suitableCutsN;
612  methodToUse=i;
613  }
614  }
615 
616  if (suitableCutsN == 0) {
617  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
618  return 0;
619  }
620 
621  if (suitableCutsN > 1) {
622  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
623  return 0;
624  }
625 
626  // get mva value from the suitable sub-classifier
627  ev->SetVariableArrangement(&fVarMaps[methodToUse]);
628  Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
629  ev->SetVariableArrangement(0);
630 
631  return mvaValue;
632 }
633 
634 
635 
636 ////////////////////////////////////////////////////////////////////////////////
637 /// returns the mva value of the right sub-classifier
638 
639 const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
640 {
641  if (fMethods.empty()) return MethodBase::GetRegressionValues();
642 
643  UInt_t methodToUse = 0;
644  const Event* ev = GetEvent();
645 
646  // determine which sub-classifier to use for this event
647  Int_t suitableCutsN = 0;
648 
649  for (UInt_t i=0; i<fMethods.size(); ++i) {
650  if (PassesCut(ev, i)) {
651  ++suitableCutsN;
652  methodToUse=i;
653  }
654  }
655 
656  if (suitableCutsN == 0) {
657  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
659  }
660 
661  if (suitableCutsN > 1) {
662  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
664  }
665  MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
666  if (!meth){
667  Log() << kFATAL << "method not found in Category Regression method" << Endl;
669  }
670  // get mva value from the suitable sub-classifier
671  return meth->GetRegressionValues(ev);
672 }
673 
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
Types::EAnalysisType fAnalysisType
Definition: MethodBase.h:589
std::vector< IMethod * > fMethods
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:378
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:51
void Init()
initialize the method
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables, variable transformation type etc.
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
void ReadStateFromXML(void *parent)
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:122
float Float_t
Definition: RtypesCore.h:53
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
static Types & Instance()
returns the single instance of "Types", creating it if it does not exist yet (singleton)
Definition: Types.cxx:64
void SetMethodBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:370
MsgLogger & Log() const
Definition: Configurable.h:128
EAnalysisType
Definition: Types.h:129
std::vector< TCut > fCategoryCuts
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition: MethodBase.h:217
std::vector< UInt_t > fCategorySpecIdx
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:582
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodBase.h:224
Basic string class.
Definition: TString.h:137
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
UInt_t GetNClasses() const
Definition: DataSetInfo.h:154
virtual ~MethodCategory(void)
destructor
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:374
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
Definition: Tools.h:309
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1134
const TString & GetExpression() const
Definition: VariableInfo.h:65
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
Tools & gTools()
Definition: Tools.cxx:79
const TString & GetNormalization() const
Definition: DataSetInfo.h:132
char GetVarType() const
Definition: VariableInfo.h:69
std::vector< std::vector< UInt_t > > fVarMaps
void DeclareOptions()
options for this method
const Event * GetEvent() const
Definition: MethodBase.h:745
DataSet * Data() const
Definition: MethodBase.h:405
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1158
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition: Event.cxx:188
virtual void ParseOptions()
options parser
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:403
DataSetInfo & DataInfo() const
Definition: MethodBase.h:406
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:93
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:117
Bool_t fModelPersistence
Definition: MethodBase.h:627
TDirectory * GetRootDir() const
Definition: DataSetInfo.h:189
Used to pass a selection expression to the Tree drawing routine.
Definition: TTreeFormula.h:64
A specialized string object used for TTree selections.
Definition: TCut.h:27
std::vector< TTreeFormula * > fCatFormulas
needed in conjunction with TTreeFormulas for evaluating category expressions
MethodCategory(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
standard constructor
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:97
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:184
const Ranking * CreateRanking()
no ranking
void * GetExternalLink() const
Definition: VariableInfo.h:89
UInt_t GetNTargets() const
Definition: DataSetInfo.h:129
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
const char * GetName() const
Definition: MethodBase.h:330
ClassInfo * GetClassInfo(Int_t clNum) const
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
DataSetInfo & AddDataSetInfo(DataSetInfo &dsi)
stores a copy of the dataset info object
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
DataSetManager * fDataSetManager
Ssiz_t Length() const
Definition: TString.h:390
const TString & GetJobName() const
Definition: MethodBase.h:326
const TString & GetMethodName() const
Definition: MethodBase.h:327
void Train(void)
train all sub-classifiers
void ReadAttr(void *node, const char *, T &value)
Definition: Tools.h:296
Bool_t IsSilentFile()
Definition: MethodBase.h:375
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:430
#define ClassImp(name)
Definition: Rtypes.h:279
double f(double x)
double Double_t
Definition: RtypesCore.h:55
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:865
Describe directory structure in memory.
Definition: TDirectory.h:44
int type
Definition: TGX11.cxx:120
void SetFile(TFile *file)
Definition: MethodBase.h:371
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1170
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:567
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
T EvalInstance(Int_t i=0, const char *stringStack[]=0)
Evaluate this treeformula.
ClassInfo * AddClass(const TString &className)
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:420
std::vector< TString > fVars
void GetHelpMessage() const
Get help message text.
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:837
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns the mva value of the right sub-classifier
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1652
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:185
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
TString GetMethodTypeName() const
Definition: MethodBase.h:328
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:167
void SetWeightFileDir(TString fileDir)
set directory of weight file
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:338
#define NULL
Definition: Rtypes.h:82
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings ...
Definition: Tools.cxx:1207
virtual void SetCircular(Long64_t maxEntries)
Enable/Disable circularity for this tree.
Definition: TTree.cxx:8252
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
void DisableWriting(Bool_t setter)
Definition: MethodBase.h:438
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:433
A TTree object has a header with a name and a title.
Definition: TTree.h:98
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
const Bool_t kTRUE
Definition: Rtypes.h:91
TString fFileDir
Definition: MethodBase.h:631
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:258
const TString GetWeightExpression(Int_t i) const
Definition: DataSetInfo.h:163
double norm(double *x, double *p)
Definition: unuranDistr.cxx:40
void SetNormalization(const TString &norm)
Definition: DataSetInfo.h:133
void SetRootDir(TDirectory *d)
Definition: DataSetInfo.h:188
virtual const std::vector< Float_t > & GetRegressionValues()
returns the mva value of the right sub-classifier
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:112
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
Definition: variables.cxx:10
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:432
virtual const char * GetTitle() const
Returns title of object.
Definition: TNamed.h:52
void ProcessOptions()
process user options
const char * Data() const
Definition: TString.h:349