ROOT  6.06/09
Reference Guide
MethodCategory.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodCategory *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * MVA method that dispatches each event to a category-specific sub-classifier *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Nadim Sah <Nadim.Sah@cern.ch> - Berlin, Germany *
16  * Peter Speckmayer <Peter.Speckmazer@cern.ch> - CERN, Switzerland *
17  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU East Lansing, USA *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21  * *
22  * Copyright (c) 2005-2011: *
23  * CERN, Switzerland *
24  * MSU East Lansing, USA *
25  * MPI-K Heidelberg, Germany *
26  * U. of Bonn, Germany *
27  * *
28  * Redistribution and use in source and binary forms, with or without *
29  * modification, are permitted according to the terms listed in LICENSE *
30  * (http://tmva.sourceforge.net/LICENSE) *
31  **********************************************************************************/
32 
33 //__________________________________________________________________________
34 //
35 // This class is meant to allow categorisation of the data. For different //
36 // categories, different classifiers may be booked and different variab- //
37 // les may be considered. The aim is to account for the difference that //
38 // is due to different locations/angles. //
39 ////////////////////////////////////////////////////////////////////////////////
40 
41 #include <algorithm>
42 #include <iomanip>
43 #include <vector>
44 #include <iostream>
45 
46 #include "Riostream.h"
47 #include "TRandom3.h"
48 #include "TMath.h"
49 #include "TObjString.h"
50 #include "TH1F.h"
51 #include "TGraph.h"
52 #include "TSpline.h"
53 #include "TDirectory.h"
54 #include "TTreeFormula.h"
55 
56 #include "TMVA/MethodCategory.h"
57 #include "TMVA/Tools.h"
58 #include "TMVA/ClassifierFactory.h"
59 #include "TMVA/Timer.h"
60 #include "TMVA/Types.h"
61 #include "TMVA/PDF.h"
62 #include "TMVA/Config.h"
63 #include "TMVA/Ranking.h"
64 #include "TMVA/VariableInfo.h"
65 #include "TMVA/DataSetManager.h"
67 
68 REGISTER_METHOD(Category)
69 
70 ClassImp(TMVA::MethodCategory)
71 
72 ////////////////////////////////////////////////////////////////////////////////
73 /// standard constructor
74 
75  TMVA::MethodCategory::MethodCategory( const TString& jobName,
76  const TString& methodTitle,
77  DataSetInfo& theData,
78  const TString& theOption,
79  TDirectory* theTargetDir )
80  : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption, theTargetDir ),
81  fCatTree(0),
82  fDataSetManager(NULL)
83 {
84 }
85 
86 ////////////////////////////////////////////////////////////////////////////////
87 /// constructor from weight file
88 
90  const TString& theWeightFile,
91  TDirectory* theTargetDir )
92  : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile, theTargetDir ),
93  fCatTree(0),
94  fDataSetManager(NULL)
95 {
96 }
97 
98 ////////////////////////////////////////////////////////////////////////////////
99 /// destructor
100 
102 {
103  std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
104  std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
105  for(;formIt!=lastF; ++formIt) delete *formIt;
106  delete fCatTree;
107 }
108 
109 ////////////////////////////////////////////////////////////////////////////////
110 /// check whether method category has analysis type
111 /// the method type has to be the same for all sub-methods
112 
114 {
115  std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
116 
117  // iterate over methods and check whether they have the analysis type
118  for(; itrMethod != fMethods.end(); ++itrMethod ) {
119  if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
120  return kFALSE;
121  }
122  return kTRUE;
123 }
124 
125 ////////////////////////////////////////////////////////////////////////////////
126 /// options for this method
127 
129 {
130 }
131 
132 ////////////////////////////////////////////////////////////////////////////////
133 /// adds sub-classifier for a category
134 
136  const TString& theVariables,
137  Types::EMVA theMethod ,
138  const TString& theTitle,
139  const TString& theOptions )
140 {
141  std::string addedMethodName = std::string(Types::Instance().GetMethodName(theMethod));
142 
143  Log() << kINFO << "Adding sub-classifier: " << addedMethodName << "::" << theTitle << Endl;
144 
145  DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
146 
147  IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
148 
149  MethodBase *method = (dynamic_cast<MethodBase*>(addedMethod));
150  if(method==0) return 0;
151 
152  method->SetAnalysisType( fAnalysisType );
153  method->SetupMethod();
154  method->ParseOptions();
155  method->ProcessSetup();
156 
157  // set or create correct method base dir for added method
158  const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
159  TDirectory * dir = BaseDir()->GetDirectory(dirName);
160  if (dir != 0) method->SetMethodBaseDir( dir );
161  else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data())) );
162 
163  // method->SetBaseDir(eigenes base dir, gucken ob Fisher dir existiert, sonst erzeugen )
164 
165  // check-for-unused-options is performed; may be overridden by derived
166  // classes
167  method->CheckSetup();
168 
169  // disable writing of XML files and standalone classes for sub methods
170  method->DisableWriting( kTRUE );
171 
172  // store method, cut and variable names and create cut formula
173  fMethods.push_back(method);
174  fCategoryCuts.push_back(theCut);
175  fVars.push_back(theVariables);
176 
177  DataSetInfo& primaryDSI = DataInfo();
178 
179  UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
180  fCategorySpecIdx.push_back(newSpectatorIndex);
181 
182  primaryDSI.AddSpectator( Form("%s_cat%i:=%s", GetName(),(int)fMethods.size(),theCut.GetTitle()),
183  Form("%s:%s",GetName(),method->GetName()),
184  "pass", 0, 0, 'C' );
185 
186  return method;
187 }
188 
189 ////////////////////////////////////////////////////////////////////////////////
190 /// create a DataSetInfo object for a sub-classifier
191 
193  const TString& theVariables,
194  const TString& theTitle)
195 {
196  // create a new dsi with name: theTitle+"_dsi"
197  TString dsiName=theTitle+"_dsi";
198  DataSetInfo& oldDSI = DataInfo();
199  DataSetInfo* dsi = new DataSetInfo(dsiName);
200 
201  // register the new dsi
202  // DataSetManager::Instance().AddDataSetInfo(*dsi); // DSMTEST replaced by following line
203  fDataSetManager->AddDataSetInfo(*dsi);
204 
205  // copy the targets and spectators from the old dsi to the new dsi
206  std::vector<VariableInfo>::iterator itrVarInfo;
207 
208  for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); itrVarInfo++)
209  dsi->AddTarget(*itrVarInfo);
210 
211  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); itrVarInfo++)
212  dsi->AddSpectator(*itrVarInfo);
213 
214  // split string that contains the variables into tiny little pieces
215  std::vector<TString> variables = gTools().SplitString(theVariables,':' );
216 
217  // prepare to create varMap
218  std::vector<UInt_t> varMap;
219  Int_t counter=0;
220 
221  // add the variables that were specified in theVariables
222  std::vector<TString>::iterator itrVariables;
223  Bool_t found = kFALSE;
224 
225  // iterate over all variables in 'variables' and add them
226  for (itrVariables = variables.begin(); itrVariables != variables.end(); itrVariables++) {
227  counter=0;
228 
229  // check the variables of the old dsi for the variable that we want to add
230  for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); itrVarInfo++) {
231  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
232  // don't compare the expression, since the user might take two times the same expression, but with different labels
233  // and apply different transformations to the variables.
234  dsi->AddVariable(*itrVarInfo);
235  varMap.push_back(counter);
236  found = kTRUE;
237  }
238  counter++;
239  }
240 
241  // check the spectators of the old dsi for the variable that we want to add
242  for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); itrVarInfo++) {
243  if((*itrVariables==itrVarInfo->GetLabel()) ) { // || (*itrVariables==itrVarInfo->GetExpression())) {
244  // don't compare the expression, since the user might take two times the same expression, but with different labels
245  // and apply different transformations to the variables.
246  dsi->AddVariable(*itrVarInfo);
247  varMap.push_back(counter);
248  found = kTRUE;
249  }
250  counter++;
251  }
252 
253  // if the variable is neither in the variables nor in the spectators, we abort
254  if (!found) {
255  Log() << kFATAL <<"The variable " << itrVariables->Data() << " was not found and could not be added " << Endl;
256  }
257  found = kFALSE;
258  }
259 
260  // in the case that no variables are specified, add the default-variables from the original dsi
261  if (theVariables=="") {
262  for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
263  dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
264  varMap.push_back(i);
265  }
266  }
267 
268  // add the variable map 'varMap' to the vector of varMaps
269  fVarMaps.push_back(varMap);
270 
271  // set classes and cuts
272  UInt_t nClasses=oldDSI.GetNClasses();
273  TString className;
274 
275  for (UInt_t i=0; i<nClasses; i++) {
276  className = oldDSI.GetClassInfo(i)->GetName();
277  dsi->AddClass(className);
278  dsi->SetCut(oldDSI.GetCut(i),className);
279  dsi->AddCut(theCut,className);
280  dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
281  }
282 
283  // set split options, root dir and normalization for the new dsi
284  dsi->SetSplitOptions(oldDSI.GetSplitOptions());
285  dsi->SetRootDir(oldDSI.GetRootDir());
286  TString norm(oldDSI.GetNormalization().Data());
287  dsi->SetNormalization(norm);
288 
289  DataSetInfo& dsiReference= (*dsi);
290 
291  return dsiReference;
292 }
293 
294 ////////////////////////////////////////////////////////////////////////////////
295 /// initialize the method
296 
298 {
299 }
300 
301 ////////////////////////////////////////////////////////////////////////////////
302 /// initialize the circular tree
303 
305 {
306  delete fCatTree;
307 
308  std::vector<VariableInfo>::const_iterator viIt;
309  const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
310  const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
311 
312  Bool_t hasAllExternalLinks = kTRUE;
313  for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
314  if( viIt->GetExternalLink() == 0 ) {
315  hasAllExternalLinks = kFALSE;
316  break;
317  }
318  for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
319  if( viIt->GetExternalLink() == 0 ) {
320  hasAllExternalLinks = kFALSE;
321  break;
322  }
323 
324  if(!hasAllExternalLinks) return;
325 
326  {
327  // Rather than having TTree::TTree add to the current directory and then remove it, let
328  // make sure to not add it in the first place.
329  // The add-then-remove can lead to a problem if gDirectory points to the same directory (for example
330  // gROOT) in the current thread and another one (and both try to add to the directory at the same time).
331  TDirectory::TContext ctxt(nullptr);
332  fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circlar Tree for categorization");
333  fCatTree->SetCircular(1);
334  }
335 
336  for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
337  const VariableInfo& vi = *viIt;
338  fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
339  }
340  for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
341  const VariableInfo& vi = *viIt;
342  if(vi.GetVarType()=='C') continue;
343  fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
344  }
345 
346  for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
347  fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
348  }
349 }
350 
351 ////////////////////////////////////////////////////////////////////////////////
352 /// train all sub-classifiers
353 
355 {
356  // specify the minimum # of training events and set 'classification'
357  const Int_t MinNoTrainingEvents = 10;
358 
359  Types::EAnalysisType analysisType = GetAnalysisType();
360 
361  // start the training
362  Log() << kINFO << "Train all sub-classifiers for "
363  << (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
364 
365  // don't do anything if no sub-classifier booked
366  if (fMethods.empty()) {
367  Log() << kINFO << "...nothing found to train" << Endl;
368  return;
369  }
370 
371  std::vector<IMethod*>::iterator itrMethod;
372 
373  // iterate over all booked sub-classifiers and train them
374  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
375 
376  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
377  if(!mva) continue;
378  mva->SetAnalysisType( analysisType );
379  if (!mva->HasAnalysisType( analysisType,
380  mva->DataInfo().GetNClasses(),
381  mva->DataInfo().GetNTargets() ) ) {
382  Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
383  if (analysisType == Types::kRegression)
384  Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
385  else
386  Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
387  itrMethod = fMethods.erase( itrMethod );
388  continue;
389  }
390  if (mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
391 
392  Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
393  << (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
394  mva->TrainMethod();
395  Log() << kINFO << "Training finished" << Endl;
396 
397  } else {
398 
399  Log() << kWARNING << "Method " << mva->GetMethodName()
400  << " not trained (training tree has less entries ["
401  << mva->Data()->GetNTrainingEvents()
402  << "] than required [" << MinNoTrainingEvents << "]" << Endl;
403 
404  Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
405  Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
406 
407  }
408  }
409 
410  if (analysisType != Types::kRegression) {
411 
412  // variable ranking
413  Log() << kINFO << "Begin ranking of input variables..." << Endl;
414  for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); itrMethod++) {
415  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
416  if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
417  const Ranking* ranking = (*itrMethod)->CreateRanking();
418  if (ranking != 0)
419  ranking->Print();
420  else
421  Log() << kINFO << "No variable ranking supplied by classifier: "
422  << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
423  }
424  }
425  }
426 }
427 
428 ////////////////////////////////////////////////////////////////////////////////
429 /// create XML description of Category classifier
430 
431 void TMVA::MethodCategory::AddWeightsXMLTo( void* parent ) const
432 {
433  void* wght = gTools().AddChild(parent, "Weights");
434  gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
435  void* submethod(0);
436 
437  std::vector<IMethod*>::iterator itrMethod;
438 
439  // iterate over methods and write them to XML file
440  for (UInt_t i=0; i<fMethods.size(); i++) {
441  MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
442  submethod = gTools().AddChild(wght, "SubMethod");
443  gTools().AddAttr(submethod, "Index", i);
444  gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
445  gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
446  gTools().AddAttr(submethod, "Variables", fVars[i]);
447  method->WriteStateToXML( submethod );
448  }
449 }
450 
451 ////////////////////////////////////////////////////////////////////////////////
452 /// read weights of sub-classifiers of MethodCategory from xml weight file
453 
455 {
456  UInt_t nSubMethods;
457  TString fullMethodName;
458  TString methodType;
459  TString methodTitle;
460  TString theCutString;
461  TString theVariables;
462  Int_t titleLength;
463  gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
464  void* subMethodNode = gTools().GetChild(wghtnode);
465 
466  Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
467 
468  // recreate all sub-methods from weight file
469  for (UInt_t i=0; i<nSubMethods; i++) {
470  gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
471  gTools().ReadAttr( subMethodNode, "Cut", theCutString );
472  gTools().ReadAttr( subMethodNode, "Variables", theVariables );
473 
474  // determine sub-method type
475  methodType = fullMethodName(0,fullMethodName.Index("::"));
476  if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
477 
478  // determine sub-method title
479  titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
480  methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
481 
482  // reconstruct dsi for sub-method
483  DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
484 
485  // recreate sub-method from weights and add to fMethods
486  MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
487  dsi, "none" ) );
488  if(method==0)
489  Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
490 
491  method->SetupMethod();
492  method->ReadStateFromXML(subMethodNode);
493 
494  fMethods.push_back(method);
495  fCategoryCuts.push_back(TCut(theCutString));
496  fVars.push_back(theVariables);
497 
498  DataSetInfo& primaryDSI = DataInfo();
499 
500  UInt_t spectatorIdx = 10000;
501  UInt_t counter=0;
502 
503  // find the spectator index
504  std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
505  std::vector<VariableInfo>::iterator itrVarInfo;
506  TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
507 
508  for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
509  if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
510  spectatorIdx=counter;
511  fCategorySpecIdx.push_back(spectatorIdx);
512  break;
513  }
514  }
515 
516  subMethodNode = gTools().GetNextChild(subMethodNode);
517  }
518 
519  InitCircularTree(DataInfo());
520 
521 }
522 
523 ////////////////////////////////////////////////////////////////////////////////
524 /// process user options
525 
527 {
528 }
529 
530 ////////////////////////////////////////////////////////////////////////////////
531 /// Get help message text
532 ///
533 /// typical length of text line:
534 /// "|--------------------------------------------------------------|"
535 
537 {
538  Log() << Endl;
539  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
540  Log() << Endl;
541  Log() << "This method allows to define different categories of events. The" <<Endl;
542  Log() << "categories are defined via cuts on the variables. For each" << Endl;
543  Log() << "category, a different classifier and set of variables can be" <<Endl;
544  Log() << "specified. The categories which are defined for this method must" << Endl;
545  Log() << "be disjoint." << Endl;
546 }
547 
548 ////////////////////////////////////////////////////////////////////////////////
549 /// no ranking
550 
552 {
553  return 0;
554 }
555 
556 ////////////////////////////////////////////////////////////////////////////////
557 
559 {
560  // if it's not a simple 'spectator' variable (0 or 1) that the categories are defined by
561  // (but rather some 'formula' (i.e. eta>0), then this formulas are stored in fCatTree and that
562  // one will be evaluated.. (the formulae return 'true' or 'false'
563  if (fCatTree) {
564  if (methodIdx>=fCatFormulas.size()) {
565  Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
566  << fCatFormulas.size() << Endl;
567  }
568  TTreeFormula* f = fCatFormulas[methodIdx];
569  return f->EvalInstance(0) > 0.5;
570  }
571  // otherwise, it simply looks if "variable == true" ("greater 0.5 to be "sure" )
572  else {
573 
574  // checks whether an event lies within a cut
575  if (methodIdx>=fCategorySpecIdx.size()) {
576  Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
577  << fCategorySpecIdx.size() << Endl;
578  }
579  UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
580  Float_t specVal = ev->GetSpectator(spectatorIdx);
581  Bool_t pass = (specVal>0.5);
582  return pass;
583  }
584 }
585 
586 ////////////////////////////////////////////////////////////////////////////////
587 /// returns the mva value of the right sub-classifier
588 
590 {
591  if (fMethods.empty()) return 0;
592 
593  UInt_t methodToUse = 0;
594  const Event* ev = GetEvent();
595 
596  // determine which sub-classifier to use for this event
597  Int_t suitableCutsN = 0;
598 
599  for (UInt_t i=0; i<fMethods.size(); ++i) {
600  if (PassesCut(ev, i)) {
601  ++suitableCutsN;
602  methodToUse=i;
603  }
604  }
605 
606  if (suitableCutsN == 0) {
607  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
608  return 0;
609  }
610 
611  if (suitableCutsN > 1) {
612  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
613  return 0;
614  }
615 
616  // get mva value from the suitable sub-classifier
617  ev->SetVariableArrangement(&fVarMaps[methodToUse]);
618  Double_t mvaValue = dynamic_cast<MethodBase*>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
619  ev->SetVariableArrangement(0);
620 
621  return mvaValue;
622 }
623 
624 
625 
626 ////////////////////////////////////////////////////////////////////////////////
627 /// returns the regression values of the right sub-classifier
628 
629 const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
630 {
631  if (fMethods.empty()) return MethodBase::GetRegressionValues();
632 
633  UInt_t methodToUse = 0;
634  const Event* ev = GetEvent();
635 
636  // determine which sub-classifier to use for this event
637  Int_t suitableCutsN = 0;
638 
639  for (UInt_t i=0; i<fMethods.size(); ++i) {
640  if (PassesCut(ev, i)) {
641  ++suitableCutsN;
642  methodToUse=i;
643  }
644  }
645 
646  if (suitableCutsN == 0) {
647  Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
649  }
650 
651  if (suitableCutsN > 1) {
652  Log() << kFATAL << "The defined categories are not disjoint." << Endl;
654  }
655  MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
656  if (!meth){
657  Log() << kFATAL << "method not found in Category Regression method" << Endl;
659  }
660  // get mva value from the suitable sub-classifier
661  return meth->GetRegressionValues(ev);
662 }
663 
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
virtual const char * GetTitle() const
Returns title of object.
Definition: TNamed.h:52
void Init()
initialize the method
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:864
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
void AddWeightsXMLTo(void *parent) const
create XML description of Category classifier
const TString GetWeightExpression(Int_t i) const
Definition: DataSetInfo.h:161
void ReadStateFromXML(void *parent)
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:122
Ssiz_t Length() const
Definition: TString.h:390
void variables(TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
Definition: variables.cxx:10
float Float_t
Definition: RtypesCore.h:53
void SetVariableArrangement(std::vector< UInt_t > *const m) const
set the variable arrangement
Definition: Event.cxx:187
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
void InitCircularTree(const DataSetInfo &dsi)
initialize the circular tree
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:256
const TString & GetExpression() const
Definition: VariableInfo.h:62
const char * GetName() const
Returns name of object.
Definition: MethodBase.h:299
UInt_t GetNClasses() const
Definition: DataSetInfo.h:152
void SetMethodBaseDir(TDirectory *methodDir)
Definition: MethodBase.h:337
DataSet * Data() const
Definition: MethodBase.h:363
EAnalysisType
Definition: Types.h:124
UInt_t GetNTargets() const
Definition: DataSetInfo.h:129
const std::vector< Float_t > & GetRegressionValues(const TMVA::Event *const ev)
Definition: MethodBase.h:186
virtual const std::vector< Float_t > & GetRegressionValues()
Definition: MethodBase.h:193
Basic string class.
Definition: TString.h:137
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
virtual ~MethodCategory(void)
destructor
ClassImp(TMVA::MethodCategory) TMVA
standard constructor
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
Definition: Tools.h:308
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1134
void GetHelpMessage() const
Get help message text.
const TString & GetMethodName() const
Definition: MethodBase.h:296
const char * Data() const
Definition: TString.h:349
static Types & Instance()
Returns the single instance of "Types" if it already exists, otherwise creates it (singleton).
Definition: Types.cxx:61
Tools & gTools()
Definition: Tools.cxx:79
void DeclareOptions()
options for this method
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1158
TDirectory * GetRootDir() const
Definition: DataSetInfo.h:187
const TString & GetNormalization() const
Definition: DataSetInfo.h:132
std::vector< std::vector< double > > Data
virtual void ParseOptions()
options parser
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:295
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:117
MethodCategory(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="", TDirectory *theTargetDir=NULL)
Used to pass a selection expression to the Tree drawing routine.
Definition: TTreeFormula.h:64
A specialized string object used for TTree selections.
Definition: TCut.h:27
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:80
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:182
const Ranking * CreateRanking()
no ranking
char GetVarType() const
Definition: VariableInfo.h:67
std::string GetMethodName(TCppMethod_t)
Definition: Cppyy.cxx:707
TMVA::DataSetInfo & CreateCategoryDSI(const TCut &, const TString &, const TString &)
create a DataSetInfo object for a sub-classifier
ClassInfo * GetClassInfo(Int_t clNum) const
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
void Train(void)
train all sub-classifiers
const TString & GetName() const
Definition: ClassInfo.h:72
void ReadAttr(void *node, const char *, T &value)
Definition: Tools.h:295
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:165
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:183
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:320
double f(double x)
double Double_t
Definition: RtypesCore.h:55
Describe directory structure in memory.
Definition: TDirectory.h:41
int type
Definition: TGX11.cxx:120
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1170
DataSetInfo & DataInfo() const
Definition: MethodBase.h:364
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
check whether method category has analysis type the method type has to be the same for all sub-method...
T EvalInstance(Int_t i=0, const char *stringStack[]=0)
Evaluate this treeformula.
ClassInfo * AddClass(const TString &className)
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:310
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:837
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
returns the mva value of the right sub-classifier
void * GetExternalLink() const
Definition: VariableInfo.h:85
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type= 'F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:567
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using a path.
Definition: TDirectory.cxx:336
#define NULL
Definition: Rtypes.h:82
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
void DisableWriting(Bool_t setter)
Definition: MethodBase.h:396
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:90
A TTree object has a header with a name and a title.
Definition: TTree.h:94
void ReadWeightsFromXML(void *wghtnode)
read weights of sub-classifiers of MethodCategory from xml weight file
Bool_t PassesCut(const Event *ev, UInt_t methodIdx)
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:109
TString GetMethodTypeName() const
Definition: MethodBase.h:297
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:582
const Bool_t kTRUE
Definition: Rtypes.h:91
double norm(double *x, double *p)
Definition: unuranDistr.cxx:40
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables, variable transformation type etc.
void SetNormalization(const TString &norm)
Definition: DataSetInfo.h:133
void SetRootDir(TDirectory *d)
Definition: DataSetInfo.h:186
virtual const std::vector< Float_t > & GetRegressionValues()
returns the mva value of the right sub-classifier
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings ...
Definition: Tools.cxx:1207
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:112
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:390
Definition: math.cpp:60
void ProcessOptions()
process user options