Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
MethodCompositeBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen
3
4/*****************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCompositeBase *
8 * *
9 * *
10 * Description: *
 * Virtual base class for combining several MVA methods *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU, USA *
16 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18 * Or Cohen <orcohenor@gmail.com> - Weizmann Inst., Israel *
19 * *
20 * Copyright (c) 2005: *
21 * CERN, Switzerland *
22 * U. of Victoria, Canada *
23 * MPI-K Heidelberg, Germany *
24 * LAPP, Annecy, France *
25 * *
26 * Redistribution and use in source and binary forms, with or without *
27 * modification, are permitted according to the terms listed in LICENSE *
28 * (see tmva/doc/LICENSE) *
29 *****************************************************************************/
30
31/*! \class TMVA::MethodCompositeBase
32\ingroup TMVA
33
Virtual base class for combining several TMVA methods.
35
This is a virtual class meant to combine more than one classifier
together. The training of the classifiers is done by classes derived
from this one, while saving and loading of the weight files and the
evaluation are done here.
40*/
41
43
45#include "TMVA/DataSetInfo.h"
46#include "TMVA/Factory.h"
47#include "TMVA/IMethod.h"
48#include "TMVA/MethodBase.h"
49#include "TMVA/MethodBoost.h"
50#include "TMVA/MsgLogger.h"
51#include "TMVA/Tools.h"
52#include "TMVA/Types.h"
53#include "TMVA/Config.h"
54
55#include "TRandom3.h"
56
57#include <iostream>
58#include <algorithm>
59#include <vector>
60
61
62using std::vector;
63
64
65////////////////////////////////////////////////////////////////////////////////
66
// Standard constructor: forwards job name, method type, title, dataset info and
// option string unchanged to the MethodBase constructor.
// NOTE(review): this extraction is missing source line 67 (the start of the
// signature, which must declare the `jobName` used below) and line 73 (the rest
// of the member-initializer list after the trailing comma on line 72); restore
// both from the original file before compiling.
68 Types::EMVA methodType,
69 const TString& methodTitle,
70 DataSetInfo& theData,
71 const TString& theOption )
72: TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption),
74{}
75
76////////////////////////////////////////////////////////////////////////////////
77
// Constructor taking a weight-file name: forwards the method type, dataset info
// and weight-file name unchanged to the corresponding MethodBase constructor.
// NOTE(review): source line 78 (the start of the signature, which must declare
// the `methodType` used below) and line 82 (the remainder of the initializer
// list after the trailing comma on line 81) are missing from this extraction.
79 DataSetInfo& dsi,
80 const TString& weightFile)
81 : TMVA::MethodBase( methodType, dsi, weightFile),
83{}
84
85////////////////////////////////////////////////////////////////////////////////
86/// returns pointer to MVA that corresponds to given method title
87
89{
90 std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin();
91 std::vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
92
93 for (; itrMethod != itrMethodEnd; ++itrMethod) {
94 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
95 if ( (mva->GetMethodName())==methodTitle ) return mva;
96 }
97 return 0;
98}
99
100////////////////////////////////////////////////////////////////////////////////
101/// returns pointer to MVA that corresponds to given method index
102
104{
105 std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin()+index;
106 if (itrMethod<fMethods.end()) return *itrMethod;
107 else return 0;
108}
109
110
111////////////////////////////////////////////////////////////////////////////////
112
114{
115 void* wght = gTools().AddChild(parent, "Weights");
116 gTools().AddAttr( wght, "NMethods", fMethods.size() );
117 for (UInt_t i=0; i< fMethods.size(); i++)
118 {
119 void* methxml = gTools().AddChild( wght, "Method" );
120 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
121 gTools().AddAttr(methxml,"Index", i );
122 gTools().AddAttr(methxml,"Weight", fMethodWeight[i]);
123 gTools().AddAttr(methxml,"MethodSigCut", method->GetSignalReferenceCut());
124 gTools().AddAttr(methxml,"MethodSigCutOrientation", method->GetSignalReferenceCutOrientation());
125 gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName());
126 gTools().AddAttr(methxml,"MethodName", method->GetMethodName() );
127 gTools().AddAttr(methxml,"JobName", method->GetJobName());
128 gTools().AddAttr(methxml,"Options", method->GetOptions());
129 if (method->fTransformationPointer)
130 gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("true"));
131 else
132 gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("false"));
133 method->AddWeightsXMLTo(methxml);
134 }
135}
136
137////////////////////////////////////////////////////////////////////////////////
138/// delete methods
139
141{
142 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
143 for (; itrMethod != fMethods.end(); ++itrMethod) {
144 Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;
145 delete (*itrMethod);
146 }
147 fMethods.clear();
148}
149
150////////////////////////////////////////////////////////////////////////////////
151/// XML streamer
152
154{
// Rebuild the composite from an XML "Weights" node. The attribute names read
// here ("NMethods", and per child "Weight", "MethodSigCut",
// "MethodSigCutOrientation", "MethodTypeName", "MethodName", "JobName",
// "Options", "UseMainMethodTransformation") match those AddWeightsXMLTo writes.
// Previously held sub-methods are deleted before reading.
// NOTE(review): source lines 206-207 and 217 are missing from this extraction
// (numbering jumps 205 -> 208 and 216 -> 218); restore them from the original
// file before compiling -- their content cannot be recovered from here.
155 UInt_t nMethods;
156 TString methodName, methodTypeName, jobName, optionString;
157
158 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
159 fMethods.clear();
160 fMethodWeight.clear();
161 gTools().ReadAttr( wghtnode, "NMethods", nMethods );
// walk the per-method child nodes in document order
162 void* ch = gTools().GetChild(wghtnode);
163 for (UInt_t i=0; i< nMethods; i++) {
164 Double_t methodWeight, methodSigCut, methodSigCutOrientation;
165 gTools().ReadAttr( ch, "Weight", methodWeight );
166 gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
167 gTools().ReadAttr( ch, "MethodSigCutOrientation", methodSigCutOrientation);
168 gTools().ReadAttr( ch, "MethodTypeName", methodTypeName );
169 gTools().ReadAttr( ch, "MethodName", methodName );
170 gTools().ReadAttr( ch, "JobName", jobName );
171 gTools().ReadAttr( ch, "Options", optionString );
172
// The optional UseMainMethodTransformation flag is read and lower-cased, but
// the decision that used it is commented out, so its value is currently unused.
173 // Bool_t rerouteTransformation = kFALSE;
174 if (gTools().HasAttr( ch, "UseMainMethodTransformation")) {
175 TString rerouteString("");
176 gTools().ReadAttr( ch, "UseMainMethodTransformation", rerouteString );
177 rerouteString.ToLower();
178 // if (rerouteString=="true")
179 // rerouteTransformation=kTRUE;
180 }
181
// NOTE(review): ReplaceAll("~","") removes every "~" in the option string,
// not only a trailing one as the following comment suggests.
182 //remove trailing "~" to signal that options have to be reused
183 optionString.ReplaceAll("~","");
184 //ignore meta-options for method Boost
// Every "Boost_" option is prefixed with "~"; a negated one ("!~Boost_...")
// is then rewritten to "~!Boost_..." so that "~" always comes first.
185 optionString.ReplaceAll("Boost_","~Boost_");
186 optionString.ReplaceAll("!~","~!");
187
// BookMethod is invoked on this object only for the first sub-method.
// NOTE(review): the C-style cast assumes this object really is a MethodBoost.
188 if (i==0){
189 // the cast on MethodBoost is ugly, but a similar line is also in ReadWeightsFromFile --> needs to be fixed later
190 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString );
191 }
192 fMethods.push_back(
193 ClassifierFactory::Instance().Create(methodTypeName.Data(), jobName, methodName, DataInfo(), optionString));
194
195 fMethodWeight.push_back(methodWeight);
196 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
197
// NOTE(review): meth is dereferenced unconditionally below, so this relies on
// a kFATAL message not returning control here.
198 if(meth==0)
199 Log() << kFATAL << "Could not read method from XML" << Endl;
200
// the first child of the current "Method" node is handed to the sub-method's
// own ReadWeightsFromXML further down
201 void* methXML = gTools().GetChild(ch);
202
// weight-file directory: "<dataset name>/<configured weight-file directory>"
203 TString _fFileDir= meth->DataInfo().GetName();
204 _fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
205 meth->SetWeightFileDir(_fFileDir);
208 meth->SetupMethod();
209 meth->SetMsgType(kWARNING);
210 meth->ParseOptions();
211 meth->ProcessSetup();
212 meth->CheckSetup();
213 meth->ReadWeightsFromXML(methXML);
214 meth->SetSignalReferenceCut(methodSigCut);
215 meth->SetSignalReferenceCutOrientation(methodSigCutOrientation);
216
218
// NOTE(review): no check that ch is still valid; if the file holds fewer
// "Method" children than NMethods this presumably yields a null node -- confirm.
219 ch = gTools().GetNextChild(ch);
220 }
221 //Log() << kINFO << "Reading methods from XML done " << Endl;
222}
223
224////////////////////////////////////////////////////////////////////////////////
225/// text streamer
226
228{
// Rebuild the composite from the plain-text weight format: a header token plus
// the number of classifiers, then for each classifier its name, index and
// combination weight (each preceded by a label token that is discarded into
// `dummy`), optionally job name / title / options, and the classifier's own
// weights read by its ReadWeightsFromStream.
// NOTE(review): source line 254 is missing from this extraction (numbering
// jumps 253 -> 255); it sits just before the BookMethod call and may guard it
// -- restore it from the original file before compiling.
229 TString var, dummy;
230 TString methodName, methodTitle = GetMethodName(),
231 jobName=GetJobName(),optionString = GetOptions();
232 UInt_t methodNum; Double_t methodWeight;
233 // and read the Weights (BDT coefficients)
234 // coverity[tainted_data_argument]
// methodNum comes straight from the stream and bounds the loop below unchecked
235 istr >> dummy >> methodNum;
236 Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
237 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
238 fMethods.clear();
239 fMethodWeight.clear();
240 for (UInt_t i=0; i<methodNum; i++) {
// note: the member fCurrentMethodIdx is overwritten with the index read here
241 istr >> dummy >> methodName >> dummy >> fCurrentMethodIdx >> dummy >> methodWeight;
242 if ((UInt_t)fCurrentMethodIdx != i) {
243 Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
244 << fCurrentMethodIdx << " i=" << i
245 << " MethodName " << methodName
246 << " dummy " << dummy
247 << " MethodWeight= " << methodWeight
248 << Endl;
249 }
// Non-Boost composites read job name, title and options for every entry;
// Boost reads them only for entry 0 and derives later titles as
// "<method name> (NNNN)" from the index.
250 if (GetMethodType() != Types::kBoost || i==0) {
251 istr >> dummy >> jobName;
252 istr >> dummy >> methodTitle;
253 istr >> dummy >> optionString;
// NOTE(review): C-style cast assumes this object really is a MethodBoost.
255 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString );
256 } else {
257 methodTitle = TString::Format("%s (%04i)",GetMethodName().Data(),fCurrentMethodIdx);
258 }
259 fMethods.push_back(
260 ClassifierFactory::Instance().Create(methodName.Data(), jobName, methodTitle, DataInfo(), optionString));
261 fMethodWeight.push_back( methodWeight );
// if the created classifier is not a MethodBase its weights are not read,
// which would leave the stream positioned mid-record for the next entry
262 if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
263 m->ReadWeightsFromStream(istr);
264 }
265}
266
267////////////////////////////////////////////////////////////////////////////////
268/// return composite MVA response
269
271{
272 Double_t mvaValue = 0;
273 for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i];
274
275 // cannot determine error
276 NoErrorCalc(err, errUpper);
277
278 return mvaValue;
279}
int Int_t
Signed integer 4 bytes (int).
Definition RtypesCore.h:59
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int).
Definition RtypesCore.h:60
double Double_t
Double 8 bytes.
Definition RtypesCore.h:73
Double_t err
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition Config.h:124
IONames & GetIONames()
Definition Config.h:98
virtual void ParseOptions()
options parser
const TString & GetOptions() const
MsgLogger & Log() const
void SetMsgType(EMsgType t)
Class that contains all the data information.
Definition DataSetInfo.h:62
const char * GetName() const override
Returns name of object.
Definition DataSetInfo.h:71
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
TransformationHandler * fTransformationPointer
pointer to the rest of transformations
Definition MethodBase.h:674
MethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
standard constructor
void SetSilentFile(Bool_t status)
Definition MethodBase.h:381
void SetWeightFileDir(TString fileDir)
set directory of weight file
TString GetMethodTypeName() const
Definition MethodBase.h:335
virtual void ReadWeightsFromXML(void *wghtnode)=0
Bool_t IsModelPersistence() const
Definition MethodBase.h:386
const TString & GetJobName() const
Definition MethodBase.h:333
void SetupMethod()
setup of methods
friend class MethodCompositeBase
Definition MethodBase.h:273
const TString & GetMethodName() const
Definition MethodBase.h:334
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
Definition MethodBase.h:406
DataSetInfo & DataInfo() const
Definition MethodBase.h:413
virtual void AddWeightsXMLTo(void *parent) const =0
TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true)
Definition MethodBase.h:397
Bool_t IsSilentFile() const
Definition MethodBase.h:382
Types::EMVA GetMethodType() const
Definition MethodBase.h:336
Double_t GetSignalReferenceCutOrientation() const
Definition MethodBase.h:364
void SetSignalReferenceCut(Double_t cut)
Definition MethodBase.h:367
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
Definition MethodBase.h:368
DataSet * Data() const
Definition MethodBase.h:412
void SetModelPersistence(Bool_t status)
Definition MethodBase.h:385
Double_t GetSignalReferenceCut() const
Definition MethodBase.h:363
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Class for boosting a TMVA method.
Definition MethodBoost.h:58
void AddWeightsXMLTo(void *parent) const override
std::vector< Double_t > fMethodWeight
virtual ~MethodCompositeBase(void)
delete methods
std::vector< IMethod * > fMethods
vector of all classifiers
IMethod * GetMethod(const TString &title) const
accessor by name
void ReadWeightsFromXML(void *wghtnode) override
XML streamer.
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr) override
return composite MVA response
void ReadWeightsFromStream(std::istream &istr) override
text streamer
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1125
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1099
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1137
static Types & Instance()
The single instance of "Types" if existing already, or create it (Singleton).
Definition Types.cxx:70
Basic string class.
Definition TString.h:138
void ToLower()
Change string to lower-case.
Definition TString.cxx:1189
const char * Data() const
Definition TString.h:384
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:713
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2385
create variable transformations
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
TMarker m
Definition textangle.C:8