Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodCompositeBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen
3
4/*****************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodCompositeBase *
8 * *
9 * *
10 * Description: *
11 * Virtual base class for combining several MVA methods *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU, USA *
16 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18 * Or Cohen <orcohenor@gmail.com> - Weizmann Inst., Israel *
19 * *
20 * Copyright (c) 2005: *
21 * CERN, Switzerland *
22 * U. of Victoria, Canada *
23 * MPI-K Heidelberg, Germany *
24 * LAPP, Annecy, France *
25 * *
26 * Redistribution and use in source and binary forms, with or without *
27 * modification, are permitted according to the terms listed in LICENSE *
28 * (see tmva/doc/LICENSE) *
29 *****************************************************************************/
30
31/*! \class TMVA::MethodCompositeBase
32\ingroup TMVA
33
34Virtual base class for combining several TMVA methods.
35
36This is a virtual class meant to combine more than one classifier
37together. The training of the classifiers is done by classes that are
38derived from this one, while the saving and loading of weight files
39and the evaluation are done here.
40*/
41
43
45#include "TMVA/DataSetInfo.h"
46#include "TMVA/Factory.h"
47#include "TMVA/IMethod.h"
48#include "TMVA/MethodBase.h"
49#include "TMVA/MethodBoost.h"
50#include "TMVA/MsgLogger.h"
51#include "TMVA/Tools.h"
52#include "TMVA/Types.h"
53#include "TMVA/Config.h"
54
55#include "TRandom3.h"
56
57#include <iostream>
58#include <algorithm>
59#include <vector>
60
61
62using std::vector;
63
65
66////////////////////////////////////////////////////////////////////////////////
67
69 Types::EMVA methodType,
70 const TString& methodTitle,
71 DataSetInfo& theData,
72 const TString& theOption )
73: TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption),
74 fCurrentMethodIdx(0), fCurrentMethod(0)
75{}
76
77////////////////////////////////////////////////////////////////////////////////
78
80 DataSetInfo& dsi,
81 const TString& weightFile)
82 : TMVA::MethodBase( methodType, dsi, weightFile),
83 fCurrentMethodIdx(0), fCurrentMethod(0)
84{}
85
86////////////////////////////////////////////////////////////////////////////////
87/// returns pointer to MVA that corresponds to given method title
88
90{
91 std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin();
92 std::vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
93
94 for (; itrMethod != itrMethodEnd; ++itrMethod) {
95 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
96 if ( (mva->GetMethodName())==methodTitle ) return mva;
97 }
98 return 0;
99}
100
101////////////////////////////////////////////////////////////////////////////////
102/// returns pointer to MVA that corresponds to given method index
103
105{
106 std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin()+index;
107 if (itrMethod<fMethods.end()) return *itrMethod;
108 else return 0;
109}
110
111
112////////////////////////////////////////////////////////////////////////////////
113
115{
116 void* wght = gTools().AddChild(parent, "Weights");
117 gTools().AddAttr( wght, "NMethods", fMethods.size() );
118 for (UInt_t i=0; i< fMethods.size(); i++)
119 {
120 void* methxml = gTools().AddChild( wght, "Method" );
121 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
122 gTools().AddAttr(methxml,"Index", i );
123 gTools().AddAttr(methxml,"Weight", fMethodWeight[i]);
124 gTools().AddAttr(methxml,"MethodSigCut", method->GetSignalReferenceCut());
125 gTools().AddAttr(methxml,"MethodSigCutOrientation", method->GetSignalReferenceCutOrientation());
126 gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName());
127 gTools().AddAttr(methxml,"MethodName", method->GetMethodName() );
128 gTools().AddAttr(methxml,"JobName", method->GetJobName());
129 gTools().AddAttr(methxml,"Options", method->GetOptions());
130 if (method->fTransformationPointer)
131 gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("true"));
132 else
133 gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("false"));
134 method->AddWeightsXMLTo(methxml);
135 }
136}
137
138////////////////////////////////////////////////////////////////////////////////
139/// delete methods
140
142{
143 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
144 for (; itrMethod != fMethods.end(); ++itrMethod) {
145 Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;
146 delete (*itrMethod);
147 }
148 fMethods.clear();
149}
150
151////////////////////////////////////////////////////////////////////////////////
152/// XML streamer
153
155{
156 UInt_t nMethods;
157 TString methodName, methodTypeName, jobName, optionString;
158
159 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
160 fMethods.clear();
161 fMethodWeight.clear();
162 gTools().ReadAttr( wghtnode, "NMethods", nMethods );
163 void* ch = gTools().GetChild(wghtnode);
164 for (UInt_t i=0; i< nMethods; i++) {
165 Double_t methodWeight, methodSigCut, methodSigCutOrientation;
166 gTools().ReadAttr( ch, "Weight", methodWeight );
167 gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
168 gTools().ReadAttr( ch, "MethodSigCutOrientation", methodSigCutOrientation);
169 gTools().ReadAttr( ch, "MethodTypeName", methodTypeName );
170 gTools().ReadAttr( ch, "MethodName", methodName );
171 gTools().ReadAttr( ch, "JobName", jobName );
172 gTools().ReadAttr( ch, "Options", optionString );
173
174 // Bool_t rerouteTransformation = kFALSE;
175 if (gTools().HasAttr( ch, "UseMainMethodTransformation")) {
176 TString rerouteString("");
177 gTools().ReadAttr( ch, "UseMainMethodTransformation", rerouteString );
178 rerouteString.ToLower();
179 // if (rerouteString=="true")
180 // rerouteTransformation=kTRUE;
181 }
182
183 //remove trailing "~" to signal that options have to be reused
184 optionString.ReplaceAll("~","");
185 //ignore meta-options for method Boost
186 optionString.ReplaceAll("Boost_","~Boost_");
187 optionString.ReplaceAll("!~","~!");
188
189 if (i==0){
190 // the cast on MethodBoost is ugly, but a similar line is also in ReadWeightsFromFile --> needs to be fixed later
191 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString );
192 }
193 fMethods.push_back(
194 ClassifierFactory::Instance().Create(methodTypeName.Data(), jobName, methodName, DataInfo(), optionString));
195
196 fMethodWeight.push_back(methodWeight);
197 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
198
199 if(meth==0)
200 Log() << kFATAL << "Could not read method from XML" << Endl;
201
202 void* methXML = gTools().GetChild(ch);
203
204 TString _fFileDir= meth->DataInfo().GetName();
205 _fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
206 meth->SetWeightFileDir(_fFileDir);
207 meth->SetModelPersistence(IsModelPersistence());
208 meth->SetSilentFile(IsSilentFile());
209 meth->SetupMethod();
210 meth->SetMsgType(kWARNING);
211 meth->ParseOptions();
212 meth->ProcessSetup();
213 meth->CheckSetup();
214 meth->ReadWeightsFromXML(methXML);
215 meth->SetSignalReferenceCut(methodSigCut);
216 meth->SetSignalReferenceCutOrientation(methodSigCutOrientation);
217
218 meth->RerouteTransformationHandler (&(this->GetTransformationHandler()));
219
220 ch = gTools().GetNextChild(ch);
221 }
222 //Log() << kINFO << "Reading methods from XML done " << Endl;
223}
224
225////////////////////////////////////////////////////////////////////////////////
226/// text streamer
227
229{
230 TString var, dummy;
231 TString methodName, methodTitle = GetMethodName(),
232 jobName=GetJobName(),optionString = GetOptions();
233 UInt_t methodNum; Double_t methodWeight;
234 // and read the Weights (BDT coefficients)
235 // coverity[tainted_data_argument]
236 istr >> dummy >> methodNum;
237 Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
238 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
239 fMethods.clear();
240 fMethodWeight.clear();
241 for (UInt_t i=0; i<methodNum; i++) {
242 istr >> dummy >> methodName >> dummy >> fCurrentMethodIdx >> dummy >> methodWeight;
243 if ((UInt_t)fCurrentMethodIdx != i) {
244 Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
245 << fCurrentMethodIdx << " i=" << i
246 << " MethodName " << methodName
247 << " dummy " << dummy
248 << " MethodWeight= " << methodWeight
249 << Endl;
250 }
251 if (GetMethodType() != Types::kBoost || i==0) {
252 istr >> dummy >> jobName;
253 istr >> dummy >> methodTitle;
254 istr >> dummy >> optionString;
255 if (GetMethodType() == Types::kBoost)
256 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString );
257 } else {
258 methodTitle = TString::Format("%s (%04i)",GetMethodName().Data(),fCurrentMethodIdx);
259 }
260 fMethods.push_back(
261 ClassifierFactory::Instance().Create(methodName.Data(), jobName, methodTitle, DataInfo(), optionString));
262 fMethodWeight.push_back( methodWeight );
263 if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
264 m->ReadWeightsFromStream(istr);
265 }
266}
267
268////////////////////////////////////////////////////////////////////////////////
269/// return composite MVA response
270
272{
273 Double_t mvaValue = 0;
274 for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i];
275
276 // cannot determine error
277 NoErrorCalc(err, errUpper);
278
279 return mvaValue;
280}
#define ClassImp(name)
Definition Rtypes.h:382
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition Config.h:124
IONames & GetIONames()
Definition Config.h:98
virtual void ParseOptions()
options parser
const TString & GetOptions() const
void SetMsgType(EMsgType t)
Class that contains all the data information.
Definition DataSetInfo.h:62
virtual const char * GetName() const
Returns name of object.
Definition DataSetInfo.h:71
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
Virtual base Class for all MVA method.
Definition MethodBase.h:111
TransformationHandler * fTransformationPointer
pointer to the rest of transformations
Definition MethodBase.h:671
void SetSilentFile(Bool_t status)
Definition MethodBase.h:378
void SetWeightFileDir(TString fileDir)
set directory of weight file
TString GetMethodTypeName() const
Definition MethodBase.h:332
virtual void ReadWeightsFromXML(void *wghtnode)=0
const TString & GetJobName() const
Definition MethodBase.h:330
void SetupMethod()
setup of methods
friend class MethodCompositeBase
Definition MethodBase.h:270
const TString & GetMethodName() const
Definition MethodBase.h:331
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
void RerouteTransformationHandler(TransformationHandler *fTargetTransformation)
Definition MethodBase.h:403
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
virtual void AddWeightsXMLTo(void *parent) const =0
Double_t GetSignalReferenceCutOrientation() const
Definition MethodBase.h:361
void SetSignalReferenceCut(Double_t cut)
Definition MethodBase.h:364
void SetSignalReferenceCutOrientation(Double_t cutOrientation)
Definition MethodBase.h:365
void SetModelPersistence(Bool_t status)
Definition MethodBase.h:382
Double_t GetSignalReferenceCut() const
Definition MethodBase.h:360
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Class for boosting a TMVA method.
Definition MethodBoost.h:58
Virtual base class for combining several TMVA method.
void ReadWeightsFromStream(std::istream &istr)
text streamer
virtual ~MethodCompositeBase(void)
delete methods
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr)
return composite MVA response
IMethod * GetMethod(const TString &title) const
accessor by name
void ReadWeightsFromXML(void *wghtnode)
XML streamer.
void AddWeightsXMLTo(void *parent) const
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1150
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1162
static Types & Instance()
The single instance of "Types" if existing already, or create it (Singleton)
Definition Types.cxx:70
Types::EMVA GetMethodType(const TString &method) const
returns the method type (enum) for a given method (string)
Definition Types.cxx:121
Basic string class.
Definition TString.h:139
void ToLower()
Change string to lower-case.
Definition TString.cxx:1182
const char * Data() const
Definition TString.h:376
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:704
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2378
create variable transformations
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
TMarker m
Definition textangle.C:8