Logo ROOT  
Reference Guide
MethodCompositeBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen
3
4/*****************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCompositeBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Virtual base class for combining several MVA methods *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU, USA *
16 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18 * Or Cohen <orcohenor@gmail.com> - Weizmann Inst., Israel *
19 * *
20 * Copyright (c) 2005: *
21 * CERN, Switzerland *
22 * U. of Victoria, Canada *
23 * MPI-K Heidelberg, Germany *
24 * LAPP, Annecy, France *
25 * *
26 * Redistribution and use in source and binary forms, with or without *
27 * modification, are permitted according to the terms listed in LICENSE *
28 * (http://tmva.sourceforge.net/LICENSE) *
29 *****************************************************************************/
30
31/*! \class TMVA::MethodCompositeBase
32\ingroup TMVA
33
34Virtual base class for combining several TMVA methods.
35
36This is a virtual class meant to combine more than one classifier
37together. The training of the classifiers is done by classes that are
38derived from this one, while the saving and loading of the weight files
39and the evaluation are done here.
40*/
41
43
45#include "TMVA/DataSetInfo.h"
46#include "TMVA/Factory.h"
47#include "TMVA/IMethod.h"
48#include "TMVA/MethodBase.h"
49#include "TMVA/MethodBoost.h"
50#include "TMVA/MsgLogger.h"
51#include "TMVA/Tools.h"
52#include "TMVA/Types.h"
53#include "TMVA/Config.h"
54
55#include "TRandom3.h"
56
57#include <iostream>
58#include <algorithm>
59#include <vector>
60
61
62using std::vector;
63
65
66////////////////////////////////////////////////////////////////////////////////
67
69 Types::EMVA methodType,
70 const TString& methodTitle,
71 DataSetInfo& theData,
72 const TString& theOption )
73: TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption),
74 fCurrentMethodIdx(0), fCurrentMethod(0)
75{}
76
77////////////////////////////////////////////////////////////////////////////////
78
80 DataSetInfo& dsi,
81 const TString& weightFile)
82 : TMVA::MethodBase( methodType, dsi, weightFile),
83 fCurrentMethodIdx(0), fCurrentMethod(0)
84{}
85
86////////////////////////////////////////////////////////////////////////////////
87/// returns pointer to MVA that corresponds to given method title
88
90{
91 std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin();
92 std::vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
93
94 for (; itrMethod != itrMethodEnd; ++itrMethod) {
95 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
96 if ( (mva->GetMethodName())==methodTitle ) return mva;
97 }
98 return 0;
99}
100
101////////////////////////////////////////////////////////////////////////////////
102/// returns pointer to MVA that corresponds to given method index
103
105{
106 std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin()+index;
107 if (itrMethod<fMethods.end()) return *itrMethod;
108 else return 0;
109}
110
111
112////////////////////////////////////////////////////////////////////////////////
113
115{
116 void* wght = gTools().AddChild(parent, "Weights");
117 gTools().AddAttr( wght, "NMethods", fMethods.size() );
118 for (UInt_t i=0; i< fMethods.size(); i++)
119 {
120 void* methxml = gTools().AddChild( wght, "Method" );
121 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
122 gTools().AddAttr(methxml,"Index", i );
123 gTools().AddAttr(methxml,"Weight", fMethodWeight[i]);
124 gTools().AddAttr(methxml,"MethodSigCut", method->GetSignalReferenceCut());
125 gTools().AddAttr(methxml,"MethodSigCutOrientation", method->GetSignalReferenceCutOrientation());
126 gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName());
127 gTools().AddAttr(methxml,"MethodName", method->GetMethodName() );
128 gTools().AddAttr(methxml,"JobName", method->GetJobName());
129 gTools().AddAttr(methxml,"Options", method->GetOptions());
130 if (method->fTransformationPointer)
131 gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("true"));
132 else
133 gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("false"));
134 method->AddWeightsXMLTo(methxml);
135 }
136}
137
138////////////////////////////////////////////////////////////////////////////////
139/// delete methods
140
142{
143 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
144 for (; itrMethod != fMethods.end(); ++itrMethod) {
145 Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;
146 delete (*itrMethod);
147 }
148 fMethods.clear();
149}
150
151////////////////////////////////////////////////////////////////////////////////
152/// XML streamer
153
155{
156 UInt_t nMethods;
157 TString methodName, methodTypeName, jobName, optionString;
158
159 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
160 fMethods.clear();
161 fMethodWeight.clear();
162 gTools().ReadAttr( wghtnode, "NMethods", nMethods );
163 void* ch = gTools().GetChild(wghtnode);
164 for (UInt_t i=0; i< nMethods; i++) {
165 Double_t methodWeight, methodSigCut, methodSigCutOrientation;
166 gTools().ReadAttr( ch, "Weight", methodWeight );
167 gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
168 gTools().ReadAttr( ch, "MethodSigCutOrientation", methodSigCutOrientation);
169 gTools().ReadAttr( ch, "MethodTypeName", methodTypeName );
170 gTools().ReadAttr( ch, "MethodName", methodName );
171 gTools().ReadAttr( ch, "JobName", jobName );
172 gTools().ReadAttr( ch, "Options", optionString );
173
174 // Bool_t rerouteTransformation = kFALSE;
175 if (gTools().HasAttr( ch, "UseMainMethodTransformation")) {
176 TString rerouteString("");
177 gTools().ReadAttr( ch, "UseMainMethodTransformation", rerouteString );
178 rerouteString.ToLower();
179 // if (rerouteString=="true")
180 // rerouteTransformation=kTRUE;
181 }
182
183 //remove trailing "~" to signal that options have to be reused
184 optionString.ReplaceAll("~","");
185 //ignore meta-options for method Boost
186 optionString.ReplaceAll("Boost_","~Boost_");
187 optionString.ReplaceAll("!~","~!");
188
189 if (i==0){
190 // the cast on MethodBoost is ugly, but a similar line is also in ReadWeightsFromFile --> needs to be fixed later
191 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString );
192 }
193 fMethods.push_back(
194 ClassifierFactory::Instance().Create(methodTypeName.Data(), jobName, methodName, DataInfo(), optionString));
195
196 fMethodWeight.push_back(methodWeight);
197 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
198
199 if(meth==0)
200 Log() << kFATAL << "Could not read method from XML" << Endl;
201
202 void* methXML = gTools().GetChild(ch);
203
204 TString _fFileDir= meth->DataInfo().GetName();
205 _fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
206 meth->SetWeightFileDir(_fFileDir);
207 meth->SetModelPersistence(IsModelPersistence());
208 meth->SetSilentFile(IsSilentFile());
209 meth->SetupMethod();
210 meth->SetMsgType(kWARNING);
211 meth->ParseOptions();
212 meth->ProcessSetup();
213 meth->CheckSetup();
214 meth->ReadWeightsFromXML(methXML);
215 meth->SetSignalReferenceCut(methodSigCut);
216 meth->SetSignalReferenceCutOrientation(methodSigCutOrientation);
217
218 meth->RerouteTransformationHandler (&(this->GetTransformationHandler()));
219
220 ch = gTools().GetNextChild(ch);
221 }
222 //Log() << kINFO << "Reading methods from XML done " << Endl;
223}
224
225////////////////////////////////////////////////////////////////////////////////
226/// text streamer
227
229{
230 TString var, dummy;
231 TString methodName, methodTitle=GetMethodName(),
232 jobName=GetJobName(),optionString=GetOptions();
233 UInt_t methodNum; Double_t methodWeight;
234 // and read the Weights (BDT coefficients)
235 // coverity[tainted_data_argument]
236 istr >> dummy >> methodNum;
237 Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
238 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
239 fMethods.clear();
240 fMethodWeight.clear();
241 for (UInt_t i=0; i<methodNum; i++) {
242 istr >> dummy >> methodName >> dummy >> fCurrentMethodIdx >> dummy >> methodWeight;
243 if ((UInt_t)fCurrentMethodIdx != i) {
244 Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
245 << fCurrentMethodIdx << " i=" << i
246 << " MethodName " << methodName
247 << " dummy " << dummy
248 << " MethodWeight= " << methodWeight
249 << Endl;
250 }
251 if (GetMethodType() != Types::kBoost || i==0) {
252 istr >> dummy >> jobName;
253 istr >> dummy >> methodTitle;
254 istr >> dummy >> optionString;
255 if (GetMethodType() == Types::kBoost)
256 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString );
257 }
258 else methodTitle=Form("%s (%04i)",GetMethodName().Data(),fCurrentMethodIdx);
259 fMethods.push_back(
260 ClassifierFactory::Instance().Create(methodName.Data(), jobName, methodTitle, DataInfo(), optionString));
261 fMethodWeight.push_back( methodWeight );
262 if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
263 m->ReadWeightsFromStream(istr);
264 }
265}
266
267////////////////////////////////////////////////////////////////////////////////
268/// return composite MVA response
269
271{
272 Double_t mvaValue = 0;
273 for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i];
274
275 // cannot determine error
276 NoErrorCalc(err, errUpper);
277
278 return mvaValue;
279}
int Int_t
Definition: RtypesCore.h:45
unsigned int UInt_t
Definition: RtypesCore.h:46
double Double_t
Definition: RtypesCore.h:59
#define ClassImp(name)
Definition: Rtypes.h:364
char * Form(const char *fmt,...)
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition: Config.h:126
IONames & GetIONames()
Definition: Config.h:100
virtual void ParseOptions()
options parser
const TString & GetOptions() const
Definition: Configurable.h:84
void SetMsgType(EMsgType t)
Definition: Configurable.h:125
Class that contains all the data information.
Definition: DataSetInfo.h:62
virtual const char * GetName() const
Returns name of object.
Definition: DataSetInfo.h:71
Interface for all concrete MVA method implementations.
Definition: IMethod.h:53
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
TransformationHandler * fTransformationPointer
Definition: MethodBase.h:671
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:378
void SetWeightFileDir(TString fileDir)
set directory of weight file
TString GetMethodTypeName() const
Definition: MethodBase.h:332
virtual void ReadWeightsFromXML(void *wghtnode)=0
const TString & GetJobName() const
Definition: MethodBase.h:330
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:406
friend class MethodCompositeBase
Definition: MethodBase.h:270
const TString & GetMethodName() const
Definition: MethodBase.h:331
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:423
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
Definition: MethodBase.h:403
DataSetInfo & DataInfo() const
Definition: MethodBase.h:410
virtual void AddWeightsXMLTo(void *parent) const =0
Double_t GetSignalReferenceCutOrientation() const
Definition: MethodBase.h:361
void SetSignalReferenceCut(Double_t cut)
Definition: MethodBase.h:364
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
Definition: MethodBase.h:365
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:382
Double_t GetSignalReferenceCut() const
Definition: MethodBase.h:360
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:433
Class for boosting a TMVA method.
Definition: MethodBoost.h:58
Virtual base class for combining several TMVA method.
void ReadWeightsFromStream(std::istream &istr)
text streamer
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
return composite MVA response
virtual ~MethodCompositeBase(void)
delete methods
IMethod * GetMethod(const TString &title) const
returns pointer to MVA that corresponds to given method title
void ReadWeightsFromXML(void *wghtnode)
XML streamer.
void AddWeightsXMLTo(void *parent) const
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
static Types & Instance()
the single instance of "Types" if it already exists, or create it (Singleton)
Definition: Types.cxx:69
@ kBoost
Definition: Types.h:95
Types::EMVA GetMethodType(const TString &method) const
returns the method type (enum) for a given method (string)
Definition: Types.cxx:120
@ kVERBOSE
Definition: Types.h:59
@ kINFO
Definition: Types.h:60
@ kWARNING
Definition: Types.h:61
@ kFATAL
Definition: Types.h:63
Basic string class.
Definition: TString.h:136
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1150
const char * Data() const
Definition: TString.h:369
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:692
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:342
create variable transformations
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:760
auto * m
Definition: textangle.C:8