Logo ROOT  
Reference Guide
MethodCompositeBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen
3
4/*****************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCompositeBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Virtual base class for combining several MVA methods *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU, USA *
16 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18 * Or Cohen <orcohenor@gmail.com> - Weizmann Inst., Israel *
19 * *
20 * Copyright (c) 2005: *
21 * CERN, Switzerland *
22 * U. of Victoria, Canada *
23 * MPI-K Heidelberg, Germany *
24 * LAPP, Annecy, France *
25 * *
26 * Redistribution and use in source and binary forms, with or without *
27 * modification, are permitted according to the terms listed in LICENSE *
28 * (http://tmva.sourceforge.net/LICENSE) *
29 *****************************************************************************/
30
31/*! \class TMVA::MethodCompositeBase
32\ingroup TMVA
33
34Virtual base class for combining several TMVA methods.
35
36This is a virtual class meant to combine more than one classifier
37together. The training of the classifiers is done by classes that are
38derived from this one, while the saving and loading of weight files
39and the evaluation are done here.
40*/
41
43
45#include "TMVA/DataSetInfo.h"
46#include "TMVA/Factory.h"
47#include "TMVA/IMethod.h"
48#include "TMVA/MethodBase.h"
49#include "TMVA/MethodBoost.h"
50#include "TMVA/MsgLogger.h"
51#include "TMVA/Tools.h"
52#include "TMVA/Types.h"
53#include "TMVA/Config.h"
54
55#include "Riostream.h"
56#include "TRandom3.h"
57#include "TMath.h"
58
59#include <algorithm>
60#include <iomanip>
61#include <vector>
62
63
64using std::vector;
65
67
68////////////////////////////////////////////////////////////////////////////////
69
71 Types::EMVA methodType,
72 const TString& methodTitle,
73 DataSetInfo& theData,
74 const TString& theOption )
// Forwards job name, method type, title, dataset info and option string
// unchanged to the MethodBase constructor.
75: TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption),
// No sub-method is selected yet: index 0 and a zero current-method pointer.
76 fCurrentMethodIdx(0), fCurrentMethod(0)
77{}
78
79////////////////////////////////////////////////////////////////////////////////
80
82 DataSetInfo& dsi,
83 const TString& weightFile)
// Variant taking a weight-file name; method type, dataset info and the
// weight-file name are forwarded unchanged to MethodBase.
84 : TMVA::MethodBase( methodType, dsi, weightFile),
// No sub-method is selected yet: index 0 and a zero current-method pointer.
85 fCurrentMethodIdx(0), fCurrentMethod(0)
86{}
87
88////////////////////////////////////////////////////////////////////////////////
89/// returns pointer to MVA that corresponds to given method title
90
92{
93 std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin();
94 std::vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
95
96 for (; itrMethod != itrMethodEnd; ++itrMethod) {
97 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
98 if ( (mva->GetMethodName())==methodTitle ) return mva;
99 }
100 return 0;
101}
102
103////////////////////////////////////////////////////////////////////////////////
104/// returns pointer to MVA that corresponds to given method index
105
107{
108 std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin()+index;
109 if (itrMethod<fMethods.end()) return *itrMethod;
110 else return 0;
111}
112
113
114////////////////////////////////////////////////////////////////////////////////
115
117{
118 void* wght = gTools().AddChild(parent, "Weights");
119 gTools().AddAttr( wght, "NMethods", fMethods.size() );
120 for (UInt_t i=0; i< fMethods.size(); i++)
121 {
122 void* methxml = gTools().AddChild( wght, "Method" );
123 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
124 gTools().AddAttr(methxml,"Index", i );
125 gTools().AddAttr(methxml,"Weight", fMethodWeight[i]);
126 gTools().AddAttr(methxml,"MethodSigCut", method->GetSignalReferenceCut());
127 gTools().AddAttr(methxml,"MethodSigCutOrientation", method->GetSignalReferenceCutOrientation());
128 gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName());
129 gTools().AddAttr(methxml,"MethodName", method->GetMethodName() );
130 gTools().AddAttr(methxml,"JobName", method->GetJobName());
131 gTools().AddAttr(methxml,"Options", method->GetOptions());
132 if (method->fTransformationPointer)
133 gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("true"));
134 else
135 gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("false"));
136 method->AddWeightsXMLTo(methxml);
137 }
138}
139
140////////////////////////////////////////////////////////////////////////////////
141/// delete methods
142
144{
145 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
146 for (; itrMethod != fMethods.end(); ++itrMethod) {
147 Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;
148 delete (*itrMethod);
149 }
150 fMethods.clear();
151}
152
153////////////////////////////////////////////////////////////////////////////////
154/// XML streamer
155
157{
// Rebuilds the composite from XML: deletes the current sub-methods, then
// creates "NMethods" sub-methods from consecutive child nodes of wghtnode.
// Each sub-method is configured, loaded from its own first child node, and
// given the signal cut and weight stored alongside it.
158 UInt_t nMethods;
159 TString methodName, methodTypeName, jobName, optionString;
160
// Discard any previously loaded sub-methods and their weights.
161 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
162 fMethods.clear();
163 fMethodWeight.clear();
164 gTools().ReadAttr( wghtnode, "NMethods", nMethods );
165 void* ch = gTools().GetChild(wghtnode);
// NOTE(review): `ch` is not checked for null. If the XML holds fewer child
// nodes than NMethods, ReadAttr below receives a null node; confirm how
// Tools handles that.
166 for (UInt_t i=0; i< nMethods; i++) {
167 Double_t methodWeight, methodSigCut, methodSigCutOrientation;
168 gTools().ReadAttr( ch, "Weight", methodWeight );
169 gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
170 gTools().ReadAttr( ch, "MethodSigCutOrientation", methodSigCutOrientation);
171 gTools().ReadAttr( ch, "MethodTypeName", methodTypeName );
172 gTools().ReadAttr( ch, "MethodName", methodName );
173 gTools().ReadAttr( ch, "JobName", jobName );
174 gTools().ReadAttr( ch, "Options", optionString );
175
// NOTE(review): the "UseMainMethodTransformation" value is read and
// lower-cased but never used. The transformation handler is rerouted
// unconditionally further down.
176 // Bool_t rerouteTransformation = kFALSE;
177 if (gTools().HasAttr( ch, "UseMainMethodTransformation")) {
178 TString rerouteString("");
179 gTools().ReadAttr( ch, "UseMainMethodTransformation", rerouteString );
180 rerouteString.ToLower();
181 // if (rerouteString=="true")
182 // rerouteTransformation=kTRUE;
183 }
184
// Note: ReplaceAll removes every "~", not only a trailing one.
185 //remove trailing "~" to signal that options have to be reused
186 optionString.ReplaceAll("~","");
// Prefix each "Boost_" option with "~". The second replacement moves that
// "~" in front of a negation, turning "!~Boost_" into "~!Boost_".
187 //ignore meta-options for method Boost
188 optionString.ReplaceAll("Boost_","~Boost_");
189 optionString.ReplaceAll("!~","~!");
190
// Only the first stored sub-method is passed to BookMethod.
191 if (i==0){
192 // the cast on MethodBoost is ugly, but a similar line is also in ReadWeightsFromFile --> needs to be fixed later
193 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString );
194 }
195 fMethods.push_back(
196 ClassifierFactory::Instance().Create(methodTypeName.Data(), jobName, methodName, DataInfo(), optionString));
197
198 fMethodWeight.push_back(methodWeight);
199 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
200
// NOTE(review): `meth` is dereferenced right after this check, so this
// relies on kFATAL not returning; confirm MsgLogger's kFATAL behaviour.
201 if(meth==0)
202 Log() << kFATAL << "Could not read method from XML" << Endl;
203
// The sub-method's own weights live in the first child of its "Method" node.
204 void* methXML = gTools().GetChild(ch);
205
// Weight-file directory = <dataset name>/<configured weight-file dir>.
206 TString _fFileDir= meth->DataInfo().GetName();
207 _fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
208 meth->SetWeightFileDir(_fFileDir);
209 meth->SetModelPersistence(IsModelPersistence());
210 meth->SetSilentFile(IsSilentFile());
211 meth->SetupMethod();
212 meth->SetMsgType(kWARNING);
213 meth->ParseOptions();
214 meth->ProcessSetup();
215 meth->CheckSetup();
216 meth->ReadWeightsFromXML(methXML);
217 meth->SetSignalReferenceCut(methodSigCut);
218 meth->SetSignalReferenceCutOrientation(methodSigCutOrientation);
219
// Every sub-method is pointed at this composite's transformation handler.
220 meth->RerouteTransformationHandler (&(this->GetTransformationHandler()));
221
222 ch = gTools().GetNextChild(ch);
223 }
224 //Log() << kINFO << "Reading methods from XML done " << Endl;
225}
226
227////////////////////////////////////////////////////////////////////////////////
228/// text streamer
229
231{
// Rebuilds the composite from a plain-text weight stream: reads the number
// of stored classifiers, deletes the current sub-methods, then for each
// entry reads its type name, index and weight, creates the sub-method and
// lets it read its own weights from the same stream.
// NOTE(review): `var` is declared but not used in this function.
232 TString var, dummy;
233 TString methodName, methodTitle=GetMethodName(),
234 jobName=GetJobName(),optionString=GetOptions();
235 UInt_t methodNum; Double_t methodWeight;
236 // read the number of stored classifiers (the first token goes into `dummy` and is ignored)
237 // coverity[tainted_data_argument]
238 istr >> dummy >> methodNum;
239 Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
240 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
241 fMethods.clear();
242 fMethodWeight.clear();
243 for (UInt_t i=0; i<methodNum; i++) {
// Each entry: <label> methodName <label> index <label> weight.
244 istr >> dummy >> methodName >> dummy >> fCurrentMethodIdx >> dummy >> methodWeight;
245 if ((UInt_t)fCurrentMethodIdx != i) {
246 Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
247 << fCurrentMethodIdx << " i=" << i
248 << " MethodName " << methodName
249 << " dummy " << dummy
250 << " MethodWeight= " << methodWeight
251 << Endl;
252 }
// Job name, title and options are read for every entry, except for Boost
// entries after the first. Those reuse the previously read job name and
// options and get a title built from this method's name plus a
// zero-padded index.
253 if (GetMethodType() != Types::kBoost || i==0) {
254 istr >> dummy >> jobName;
255 istr >> dummy >> methodTitle;
256 istr >> dummy >> optionString;
257 if (GetMethodType() == Types::kBoost)
258 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString );
259 }
260 else methodTitle=Form("%s (%04i)",GetMethodName().Data(),fCurrentMethodIdx);
261 fMethods.push_back(
262 ClassifierFactory::Instance().Create(methodName.Data(), jobName, methodTitle, DataInfo(), optionString));
263 fMethodWeight.push_back( methodWeight );
// NOTE(review): if the cast fails, this entry's weights are not consumed
// and the following entries would be read from the wrong stream position.
264 if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
265 m->ReadWeightsFromStream(istr);
266 }
267}
268
269////////////////////////////////////////////////////////////////////////////////
270/// return composite MVA response
271
273{
274 Double_t mvaValue = 0;
275 for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i];
276
277 // cannot determine error
278 NoErrorCalc(err, errUpper);
279
280 return mvaValue;
281}
static RooMathCoreReg dummy
double Double_t
Definition: RtypesCore.h:57
#define ClassImp(name)
Definition: Rtypes.h:361
char * Form(const char *fmt,...)
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition: Config.h:124
IONames & GetIONames()
Definition: Config.h:100
virtual void ParseOptions()
options parser
const TString & GetOptions() const
Definition: Configurable.h:84
void SetMsgType(EMsgType t)
Definition: Configurable.h:125
Class that contains all the data information.
Definition: DataSetInfo.h:60
virtual const char * GetName() const
Returns name of object.
Definition: DataSetInfo.h:69
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
TransformationHandler * fTransformationPointer
Definition: MethodBase.h:669
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:377
void SetWeightFileDir(TString fileDir)
set directory of weight file
TString GetMethodTypeName() const
Definition: MethodBase.h:331
virtual void ReadWeightsFromXML(void *wghtnode)=0
const TString & GetJobName() const
Definition: MethodBase.h:329
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:408
friend class MethodCompositeBase
Definition: MethodBase.h:269
const TString & GetMethodName() const
Definition: MethodBase.h:330
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:425
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
Definition: MethodBase.h:402
DataSetInfo & DataInfo() const
Definition: MethodBase.h:409
virtual void AddWeightsXMLTo(void *parent) const =0
Double_t GetSignalReferenceCutOrientation() const
Definition: MethodBase.h:360
void SetSignalReferenceCut(Double_t cut)
Definition: MethodBase.h:363
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
Definition: MethodBase.h:364
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:381
Double_t GetSignalReferenceCut() const
Definition: MethodBase.h:359
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:435
Class for boosting a TMVA method.
Definition: MethodBoost.h:58
Virtual base class for combining several TMVA method.
void ReadWeightsFromStream(std::istream &istr)
text streamer
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
return composite MVA response
virtual ~MethodCompositeBase(void)
delete methods
IMethod * GetMethod(const TString &title) const
returns pointer to MVA that corresponds to given method title
void ReadWeightsFromXML(void *wghtnode)
XML streamer.
void AddWeightsXMLTo(void *parent) const
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1173
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1135
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1161
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
static Types & Instance()
returns the single instance of "Types", creating it first if it does not exist yet (singleton)
Definition: Types.cxx:70
@ kBoost
Definition: Types.h:95
Types::EMVA GetMethodType(const TString &method) const
returns the method type (enum) for a given method (string)
Definition: Types.cxx:121
Basic string class.
Definition: TString.h:131
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1125
const char * Data() const
Definition: TString.h:364
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:687
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:335
create variable transformations
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
auto * m
Definition: textangle.C:8