Logo ROOT   6.08/07
Reference Guide
MethodCompositeBase.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen
3 
4 /*****************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodCompositeBase *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Virtual base class for all composite MVA methods *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU, USA *
16  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18  * Or Cohen <orcohenor@gmail.com> - Weizmann Inst., Israel *
19  * *
20  * Copyright (c) 2005: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * LAPP, Annecy, France *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://tmva.sourceforge.net/LICENSE) *
29  *****************************************************************************/
30 
31 //_______________________________________________________________________
32 //
33 // This class is a virtual class meant to combine more than one classifier//
34 // together. The training of the classifiers is done by classes that are //
35 // derived from this one, while the saving and loading of weight files //
36 // and the evaluation are done here. //
37 //_______________________________________________________________________
38 
40 
41 #include "TMVA/ClassifierFactory.h"
42 #include "TMVA/DataSetInfo.h"
43 #include "TMVA/Factory.h"
44 #include "TMVA/IMethod.h"
45 #include "TMVA/MethodBase.h"
46 #include "TMVA/MethodBoost.h"
47 #include "TMVA/MsgLogger.h"
48 #include "TMVA/Tools.h"
49 #include "TMVA/Types.h"
50 #include "TMVA/Config.h"
51 
52 #include "Riostream.h"
53 #include "TRandom3.h"
54 #include "TMath.h"
55 #include "TObjString.h"
56 
57 #include <algorithm>
58 #include <iomanip>
59 #include <vector>
60 
61 
62 using std::vector;
63 
65 
66 ////////////////////////////////////////////////////////////////////////////////
67 
69  Types::EMVA methodType,
70  const TString& methodTitle,
71  DataSetInfo& theData,
72  const TString& theOption )
73 : TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption),
74  fCurrentMethodIdx(0), fCurrentMethod(0)
75 {}
76 
77 ////////////////////////////////////////////////////////////////////////////////
78 
80  DataSetInfo& dsi,
81  const TString& weightFile)
82  : TMVA::MethodBase( methodType, dsi, weightFile),
83  fCurrentMethodIdx(0), fCurrentMethod(0)
84 {}
85 
86 ////////////////////////////////////////////////////////////////////////////////
87 /// returns pointer to MVA that corresponds to given method title
88 
90 {
91  std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin();
92  std::vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
93 
94  for (; itrMethod != itrMethodEnd; itrMethod++) {
95  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
96  if ( (mva->GetMethodName())==methodTitle ) return mva;
97  }
98  return 0;
99 }
100 
101 ////////////////////////////////////////////////////////////////////////////////
102 /// returns pointer to MVA that corresponds to given method index
103 
105 {
106  std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin()+index;
107  if (itrMethod<fMethods.end()) return *itrMethod;
108  else return 0;
109 }
110 
111 
112 ////////////////////////////////////////////////////////////////////////////////
113 
115 {
116  void* wght = gTools().AddChild(parent, "Weights");
117  gTools().AddAttr( wght, "NMethods", fMethods.size() );
118  for (UInt_t i=0; i< fMethods.size(); i++)
119  {
120  void* methxml = gTools().AddChild( wght, "Method" );
121  MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
122  gTools().AddAttr(methxml,"Index", i );
123  gTools().AddAttr(methxml,"Weight", fMethodWeight[i]);
124  gTools().AddAttr(methxml,"MethodSigCut", method->GetSignalReferenceCut());
125  gTools().AddAttr(methxml,"MethodSigCutOrientation", method->GetSignalReferenceCutOrientation());
126  gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName());
127  gTools().AddAttr(methxml,"MethodName", method->GetMethodName() );
128  gTools().AddAttr(methxml,"JobName", method->GetJobName());
129  gTools().AddAttr(methxml,"Options", method->GetOptions());
130  if (method->fTransformationPointer)
131  gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("true"));
132  else
133  gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("false"));
134  method->AddWeightsXMLTo(methxml);
135  }
136 }
137 
138 ////////////////////////////////////////////////////////////////////////////////
139 /// delete methods
140 
142 {
143  std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
144  for (; itrMethod != fMethods.end(); itrMethod++) {
145  Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;
146  delete (*itrMethod);
147  }
148  fMethods.clear();
149 }
150 
151 ////////////////////////////////////////////////////////////////////////////////
152 /// XML streamer
153 
155 {
156  UInt_t nMethods;
157  TString methodName, methodTypeName, jobName, optionString;
158 
159  for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
160  fMethods.clear();
161  fMethodWeight.clear();
162  gTools().ReadAttr( wghtnode, "NMethods", nMethods );
163  void* ch = gTools().GetChild(wghtnode);
164  for (UInt_t i=0; i< nMethods; i++) {
165  Double_t methodWeight, methodSigCut, methodSigCutOrientation;
166  gTools().ReadAttr( ch, "Weight", methodWeight );
167  gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
168  gTools().ReadAttr( ch, "MethodSigCutOrientation", methodSigCutOrientation);
169  gTools().ReadAttr( ch, "MethodTypeName", methodTypeName );
170  gTools().ReadAttr( ch, "MethodName", methodName );
171  gTools().ReadAttr( ch, "JobName", jobName );
172  gTools().ReadAttr( ch, "Options", optionString );
173 
174  // Bool_t rerouteTransformation = kFALSE;
175  if (gTools().HasAttr( ch, "UseMainMethodTransformation")) {
176  TString rerouteString("");
177  gTools().ReadAttr( ch, "UseMainMethodTransformation", rerouteString );
178  rerouteString.ToLower();
179  // if (rerouteString=="true")
180  // rerouteTransformation=kTRUE;
181  }
182 
183  //remove trailing "~" to signal that options have to be reused
184  optionString.ReplaceAll("~","");
185  //ignore meta-options for method Boost
186  optionString.ReplaceAll("Boost_","~Boost_");
187  optionString.ReplaceAll("!~","~!");
188 
189  if (i==0){
190  // the cast on MethodBoost is ugly, but a similar line is also in ReadWeightsFromFile --> needs to be fixed later
191  ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString );
192  }
193  fMethods.push_back(ClassifierFactory::Instance().Create(
194  std::string(methodTypeName),jobName, methodName,DataInfo(),optionString));
195 
196  fMethodWeight.push_back(methodWeight);
197  MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
198 
199  if(meth==0)
200  Log() << kFATAL << "Could not read method from XML" << Endl;
201 
202  void* methXML = gTools().GetChild(ch);
203 
204  TString _fFileDir= meth->DataInfo().GetName();
205  _fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
206  meth->SetWeightFileDir(_fFileDir);
208  meth->SetSilentFile(IsSilentFile());
209  meth->SetupMethod();
210  meth->SetMsgType(kWARNING);
211  meth->ParseOptions();
212  meth->ProcessSetup();
213  meth->CheckSetup();
214  meth->ReadWeightsFromXML(methXML);
215  meth->SetSignalReferenceCut(methodSigCut);
216  meth->SetSignalReferenceCutOrientation(methodSigCutOrientation);
217 
219 
220  ch = gTools().GetNextChild(ch);
221  }
222  //Log() << kINFO << "Reading methods from XML done " << Endl;
223 }
224 
225 ////////////////////////////////////////////////////////////////////////////////
226 /// text streamer
227 
229 {
230  TString var, dummy;
231  TString methodName, methodTitle=GetMethodName(),
232  jobName=GetJobName(),optionString=GetOptions();
233  UInt_t methodNum; Double_t methodWeight;
234  // and read the Weights (BDT coefficients)
235  // coverity[tainted_data_argument]
236  istr >> dummy >> methodNum;
237  Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
238  for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
239  fMethods.clear();
240  fMethodWeight.clear();
241  for (UInt_t i=0; i<methodNum; i++) {
242  istr >> dummy >> methodName >> dummy >> fCurrentMethodIdx >> dummy >> methodWeight;
243  if ((UInt_t)fCurrentMethodIdx != i) {
244  Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
245  << fCurrentMethodIdx << " i=" << i
246  << " MethodName " << methodName
247  << " dummy " << dummy
248  << " MethodWeight= " << methodWeight
249  << Endl;
250  }
251  if (GetMethodType() != Types::kBoost || i==0) {
252  istr >> dummy >> jobName;
253  istr >> dummy >> methodTitle;
254  istr >> dummy >> optionString;
255  if (GetMethodType() == Types::kBoost)
256  ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString );
257  }
258  else methodTitle=Form("%s (%04i)",GetMethodName().Data(),fCurrentMethodIdx);
259  fMethods.push_back(ClassifierFactory::Instance().Create( std::string(methodName), jobName,
260  methodTitle,DataInfo(), optionString) );
261  fMethodWeight.push_back( methodWeight );
262  if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
263  m->ReadWeightsFromStream(istr);
264  }
265 }
266 
267 ////////////////////////////////////////////////////////////////////////////////
268 /// return composite MVA response
269 
271 {
272  Double_t mvaValue = 0;
273  for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i];
274 
275  // cannot determine error
276  NoErrorCalc(err, errUpper);
277 
278  return mvaValue;
279 }
static ClassifierFactory & Instance()
Access to the ClassifierFactory singleton; creates the instance if needed.
MethodCompositeBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Config & gConfig()
Definition: Config.cxx:43
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:378
void SetMsgType(EMsgType t)
Definition: Configurable.h:131
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
void ReadWeightsFromXML(void *wghtnode)
XML streamer.
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:635
static Types & Instance()
Returns the single instance of "Types" if it already exists, or creates it (singleton).
Definition: Types.cxx:64
MsgLogger & Log() const
Definition: Configurable.h:128
TransformationHandler * fTransformationPointer
Definition: MethodBase.h:665
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
Definition: MethodBase.h:361
Basic string class.
Definition: TString.h:137
TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true)
Definition: MethodBase.h:390
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1089
int Int_t
Definition: RtypesCore.h:41
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:374
void ReadWeightsFromStream(std::istream &istr)
text streamer
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
Definition: Tools.h:309
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1134
Tools & gTools()
Definition: Tools.cxx:79
void AddWeightsXMLTo(void *parent) const
IMethod * GetMethod(const TString &title) const
returns pointer to MVA that corresponds to given method title
std::vector< Double_t > fMethodWeight
DataSet * Data() const
Definition: MethodBase.h:405
TString fWeightFileDir
Definition: Config.h:100
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1158
virtual ~MethodCompositeBase(void)
delete methods
IONames & GetIONames()
Definition: Config.h:78
virtual void ParseOptions()
options parser
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:403
DataSetInfo & DataInfo() const
Definition: MethodBase.h:406
virtual void AddWeightsXMLTo(void *parent) const =0
virtual void ReadWeightsFromXML(void *wghtnode)=0
unsigned int UInt_t
Definition: RtypesCore.h:42
TMarker * m
Definition: textangle.C:8
char * Form(const char *fmt,...)
const TString & GetJobName() const
Definition: MethodBase.h:326
const TString & GetMethodName() const
Definition: MethodBase.h:327
void ReadAttr(void *node, const char *, T &value)
Definition: Tools.h:296
Bool_t IsSilentFile()
Definition: MethodBase.h:375
Double_t GetSignalReferenceCutOrientation() const
Definition: MethodBase.h:357
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:430
#define ClassImp(name)
Definition: Rtypes.h:279
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
Definition: MethodBase.h:399
double Double_t
Definition: RtypesCore.h:55
static RooMathCoreReg dummy
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1170
virtual const char * GetName() const
Returns name of object.
Definition: DataSetInfo.h:85
void ProcessSetup()
Process all options; the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:420
const TString & GetOptions() const
Definition: Configurable.h:90
std::vector< IMethod * > fMethods
Abstract ClassifierFactory template that handles arbitrary types.
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
return composite MVA response
TString GetMethodTypeName() const
Definition: MethodBase.h:328
void SetWeightFileDir(TString fileDir)
set directory of weight file
Double_t GetSignalReferenceCut() const
Definition: MethodBase.h:356
Types::EMVA GetMethodType() const
Definition: MethodBase.h:329
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:819
void SetSignalReferenceCut(Double_t cut)
Definition: MethodBase.h:360
Bool_t IsModelPersistence()
Definition: MethodBase.h:379