#if ROOT_VERSION_CODE > ROOT_VERSION(5,13,06)
   // standard constructor: initialise data members
   fLocalTrainingTree(0),
   fValidationFraction(0.5),
   fLearningMethod( "" )

// constructor from a weight file
TMVA::MethodTMlpANN::MethodTMlpANN( DataSetInfo& theData,
                                    const TString& theWeightFile ) :
   fLocalTrainingTree(0),
   fValidationFraction(0.5),
   fLearningMethod( "" )
   // destructor
   if (fMLP) delete fMLP;
   // CreateMLPOptions(): translate the option string into TMlpANN language;
   // parse the hidden-layer specification, e.g. "N,N-1"
   while (layerSpec.Length()>0) {
      TString sToAdd = "";
      if (layerSpec.First(',')<0) {
         sToAdd    = layerSpec;
         layerSpec = "";
      }
      else {
         sToAdd    = layerSpec(0,layerSpec.First(','));
         layerSpec = layerSpec(layerSpec.First(',')+1,layerSpec.Length());
      }
      Int_t nNodes = 0;
      if (sToAdd.BeginsWith("N") || sToAdd.BeginsWith("n")) { sToAdd.Remove(0,1); nNodes = GetNvar(); }
      nNodes += atoi(sToAdd);
      fHiddenLayer = Form( "%s%i:", (const char*)fHiddenLayer, nNodes );
   }
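
// Worked illustration (an assumption, not from this file): with four input
// variables the default spec "N,N-1" is consumed token by token; a leading
// "N"/"n" is replaced by GetNvar() before atoi() adds the remainder:
//    token "N"   -> nNodes = 4         -> fHiddenLayer = "4:"
//    token "N-1" -> nNodes = 4 - 1 = 3 -> fHiddenLayer = "4:3:"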
   // concatenate the input variable names into the MLP build options
   std::vector<TString>::iterator itrVar    = (*fInputVars).begin();
   std::vector<TString>::iterator itrVarEnd = (*fInputVars).end();
   fMLPBuildOptions = "";
   for (; itrVar != itrVarEnd; ++itrVar) {
      TString myVar = *itrVar;
      fMLPBuildOptions += myVar;
      fMLPBuildOptions += ",";
   }
   fMLPBuildOptions.Chop();   // remove the trailing ","

   // prepare the final options for the MLP kernel
   fMLPBuildOptions += ":";
   fMLPBuildOptions += fHiddenLayer;
   fMLPBuildOptions += "type";

   Log() << kINFO << "Use " << fNcycles << " training cycles" << Endl;
   Log() << kINFO << "Use configuration (nodes per hidden layer): " << fHiddenLayer << Endl;
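
// For illustration (assumed inputs var1..var4 and fHiddenLayer = "4:3:"):
// the build string handed to TMultiLayerPerceptron would read
//    "var1,var2,var3,var4:4:3:type"
// i.e. comma-separated inputs, colon-separated hidden-layer sizes, and the
// target branch "type" as output.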
   // DeclareOptions(): define the options (their key words) that can be set in the option string
   DeclareOptionRef( fNcycles = 200, "NCycles", "Number of training cycles" );
   DeclareOptionRef( fLayerSpec = "N,N-1", "HiddenLayers",
                     "Specification of hidden layer architecture (N stands for number of variables; any integers may also be used)" );

   DeclareOptionRef( fValidationFraction = 0.5, "ValidationFraction",
                     "Fraction of events in training tree used for cross validation" );

   DeclareOptionRef( fLearningMethod = "Stochastic", "LearningMethod", "Learning method" );
   AddPreDefVal( TString("Stochastic") );
   AddPreDefVal( TString("Batch") );
   AddPreDefVal( TString("SteepestDescent") );
   AddPreDefVal( TString("RibierePolak") );
   AddPreDefVal( TString("FletcherReeves") );
   AddPreDefVal( TString("BFGS") );
   // ProcessOptions(): build the network as specified by the user
   CreateMLPOptions(fLayerSpec);

   if (IgnoreEventsWithNegWeightsInTraining()) {
      Log() << kFATAL << "Mechanism to ignore events with negative weights in training not available for method: "
            << GetMethodTypeName()
            << " --> please remove \"IgnoreNegWeightsInTraining\" option from booking string."
            << Endl;
   }
   // GetMvaValue(): calculate the value of the neural net for the current event
   const Event* ev = GetEvent();
   Double_t* d = new Double_t[Data()->GetNVariables()];   // input-value buffer
   for (UInt_t ivar = 0; ivar < Data()->GetNVariables(); ivar++) {
      d[ivar] = (Double_t)ev->GetValue( ivar );
   }
   NoErrorCalc(err, errUpper);   // no error estimate available for this method
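
// Sketch of the elided evaluation step (an assumption about the omitted
// lines): output neuron 0 of the TMultiLayerPerceptron gives the MVA value.
//    Double_t mvaVal = fMLP->Evaluate( 0, d );
//    return mvaVal;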
   // Train(): prepare a local training tree holding type, weight and all input variables
   const Long_t basketsize = 128000;
   Int_t    type;
   Float_t  weight;
   Float_t* vArr = new Float_t[GetNvar()];

   TTree *localTrainingTree = new TTree( "TMLPtrain", "Local training tree for TMlpANN" );
   localTrainingTree->Branch( "type",   &type,   "type/I",   basketsize );
   localTrainingTree->Branch( "weight", &weight, "weight/F", basketsize );

   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      const char* myVar = GetInternalVarName(ivar).Data();
      localTrainingTree->Branch( myVar, &vArr[ivar], Form("Var%02i/F", ivar), basketsize );
   }

   for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
      const Event *ev = GetEvent(ievt);
      for (UInt_t i=0; i<GetNvar(); i++) {
         vArr[i] = ev->GetValue( i );
      }
      type   = DataInfo().IsSignal( ev ) ? 1 : 0;
      weight = ev->GetWeight();
      localTrainingTree->Fill();
   }
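
// Resulting layout of the local tree (illustrative): branches "type" (I),
// "weight" (F), and one Float_t branch per input variable (leaves Var00,
// Var01, ...). Signal events are filled before background events; this is
// the ordering the Entry$-based train/validation split below relies on.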
   // select training and validation samples from the local tree; signal
   // events were filled first, then background
   TString trainList = "Entry$<";
   trainList += 1.0-fValidationFraction;
   trainList += "*";
   trainList += (Int_t)Data()->GetNEvtSigTrain();
   trainList += " || (Entry$>";
   trainList += (Int_t)Data()->GetNEvtSigTrain();
   trainList += " && Entry$<";
   trainList += (Int_t)(Data()->GetNEvtSigTrain() + (1.0 - fValidationFraction)*Data()->GetNEvtBkgdTrain());
   trainList += ")";
   TString testList = TString("!(") + trainList + ")";

   Log() << kHEADER << "Requirement for training events: \"" << trainList << "\"" << Endl;
   Log() << kINFO << "Requirement for validation events: \"" << testList << "\"" << Endl;
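
// Worked example (assumed event counts): with 1000 signal and 2000
// background training events and fValidationFraction = 0.3, the first 70%
// of each block is used for training:
//    trainList = "Entry$<0.7*1000 || (Entry$>1000 && Entry$<2400)"
// and testList is its negation.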
   // create the MLP from the build options and the local training tree
   if (fMLP != 0) { delete fMLP; fMLP = 0; }
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), localTrainingTree,
                                     trainList, testList );
   fMLP->SetEventWeight( "weight" );

   // set the learning method
#if ROOT_VERSION_CODE > ROOT_VERSION(5,13,06)
   TMultiLayerPerceptron::ELearningMethod learningMethod = TMultiLayerPerceptron::kStochastic;
#else
   TMultiLayerPerceptron::LearningMethod  learningMethod = TMultiLayerPerceptron::kStochastic;
#endif
   fLearningMethod.ToLower();
   if      (fLearningMethod == "stochastic"     ) learningMethod = TMultiLayerPerceptron::kStochastic;
   else if (fLearningMethod == "batch"          ) learningMethod = TMultiLayerPerceptron::kBatch;
   else if (fLearningMethod == "steepestdescent") learningMethod = TMultiLayerPerceptron::kSteepestDescent;
   else if (fLearningMethod == "ribierepolak"   ) learningMethod = TMultiLayerPerceptron::kRibierePolak;
   else if (fLearningMethod == "fletcherreeves" ) learningMethod = TMultiLayerPerceptron::kFletcherReeves;
   else if (fLearningMethod == "bfgs"           ) learningMethod = TMultiLayerPerceptron::kBFGS;
   else Log() << kFATAL << "Unknown Learning Method: \"" << fLearningMethod << "\"" << Endl;
   fMLP->SetLearningMethod( learningMethod );

   // train the network, then drop the local tree
   fMLP->Train( fNcycles, "" );
   delete localTrainingTree;
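
// Hedged aside: TMultiLayerPerceptron::Train() accepts display options in
// its second argument (e.g. "text" to print the error after each epoch,
// "update=N" to set the output frequency); the empty string here keeps the
// training quiet.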
   // AddWeightsXMLTo(): write the weights to the XML file, starting with the architecture
   void* wght = gTools().AddChild( parent, "Weights" );
   void* arch = gTools().AddChild( wght, "Architecture" );
   gTools().AddAttr( arch, "BuildOptions", fMLPBuildOptions.Data() );

   // dump the weights into a temporary text file and read them back
   const TString tmpfile = GetWeightFileDir() + "/TMlp.nn.weights.temp";
   fMLP->DumpWeights( tmpfile.Data() );
   std::ifstream inf( tmpfile.Data() );
   char temp[256];
   TString data("");
   void* ch = NULL;
   while (inf.getline(temp,256)) {
      TString dummy(temp);
      if (dummy.BeginsWith('#')) {   // '#' opens a new weights section
         if (ch != NULL) gTools().AddRawLine( ch, data.Data() );
         dummy = dummy.Strip( TString::kLeading, '#' );
         dummy = dummy( 0, dummy.First(' ') );
         ch = gTools().AddChild( wght, dummy );   // one XML child per section
         data.Resize(0);
         continue;
      }
      data += (dummy + " ");
   }
   if (ch != NULL) gTools().AddRawLine( ch, data.Data() );
   // ReadWeightsFromXML(): rebuild the temporary text file from the XML weight file
   // and load it into the MLP
   void* ch = gTools().GetChild( wghtnode );
   gTools().ReadAttr( ch, "BuildOptions", fMLPBuildOptions );
   ch = gTools().GetNextChild( ch );

   const TString fname = GetWeightFileDir() + "/TMlp.nn.weights.temp";
   std::ofstream fout( fname.Data() );
   double temp1=0, temp2=0;
   while (ch) {
      const char* nodecontent = gTools().GetContent( ch );
      std::stringstream content( nodecontent );
      if (strcmp( gTools().GetName(ch), "input" ) == 0) {
         fout << "#input normalization" << std::endl;
         while ((content >> temp1) && (content >> temp2)) {
            fout << temp1 << " " << temp2 << std::endl;
         }
      }
      if (strcmp( gTools().GetName(ch), "output" ) == 0) {
         fout << "#output normalization" << std::endl;
         while ((content >> temp1) && (content >> temp2)) {
            fout << temp1 << " " << temp2 << std::endl;
         }
      }
      if (strcmp( gTools().GetName(ch), "neurons" ) == 0) {
         fout << "#neurons weights" << std::endl;
         while (content >> temp1) {
            fout << temp1 << std::endl;
         }
      }
      if (strcmp( gTools().GetName(ch), "synapses" ) == 0) {
         fout << "#synapses weights";
         while (content >> temp1) {
            fout << std::endl << temp1;
         }
      }
      ch = gTools().GetNextChild( ch );
   }
   fout.close();

   // create a dummy tree, needed to build a minimal network for testing,
   // evaluation and application
   Double_t* d = new Double_t[Data()->GetNVariables()];
   TTree* dummyTree = new TTree( "dummy", "Empty dummy tree", 1 );
   for (UInt_t ivar = 0; ivar < Data()->GetNVariables(); ivar++) {
      TString vn = DataInfo().GetVariableInfo(ivar).GetInternalName();
      dummyTree->Branch( vn.Data(), d+ivar, Form("%s/D", vn.Data()) );
   }

   if (fMLP != 0) { delete fMLP; fMLP = 0; }
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), dummyTree );
   fMLP->LoadWeights( fname );
   // ReadWeightsFromStream(): the MLP cannot read weights directly from the
   // stream, so write them first to a temporary file
   std::ofstream fout( "./TMlp.nn.weights.temp" );
   fout << istr.rdbuf();
   fout.close();

   Log() << kINFO << "Load TMLP weights into " << fMLP << Endl;

   // create a dummy tree, needed to build a minimal network
   Double_t* d = new Double_t[Data()->GetNVariables()];
   TTree* dummyTree = new TTree( "dummy", "Empty dummy tree", 1 );
   for (UInt_t ivar = 0; ivar < Data()->GetNVariables(); ivar++) {
      TString vn = DataInfo().GetVariableInfo(ivar).GetLabel();
      dummyTree->Branch( vn.Data(), d+ivar, Form("%s/D", vn.Data()) );
   }

   if (fMLP != 0) { delete fMLP; fMLP = 0; }
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), dummyTree );
   fMLP->LoadWeights( "./TMlp.nn.weights.temp" );
   // MakeClass(): create the reader class for the classifier; overwrites the
   // base-class function, since TMultiLayerPerceptron exports its own response class
   TString classFileName = "";
   if (theClassFileName == "")
      classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class";
   else
      classFileName = theClassFileName;

   classFileName.ReplaceAll(".class","");
   Log() << kINFO << "Creating specific (TMultiLayerPerceptron) standalone response class: " << classFileName << Endl;
   fMLP->Export( classFileName.Data() );
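
// Hedged note: TMultiLayerPerceptron::Export() writes standalone source
// files (C++ by default) named after classFileName, which evaluate the
// trained network without any TMVA dependence; this is why the ".class"
// suffix is stripped before the call.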
   // GetHelpMessage(): get help message text
   Log() << "This feed-forward multilayer perceptron neural network is the " << Endl;
   Log() << "standard implementation distributed with ROOT (class TMultiLayerPerceptron)." << Endl;
   Log() << Endl;
   Log() << "Detailed information is available here:" << Endl;
   if (gConfig().WriteOptionsReference()) {
      Log() << "<a href=\"http://root.cern.ch/root/html/TMultiLayerPerceptron.html\">";
      Log() << "http://root.cern.ch/root/html/TMultiLayerPerceptron.html</a>" << Endl;
   }
   else Log() << "http://root.cern.ch/root/html/TMultiLayerPerceptron.html" << Endl;
Member functions of TMVA::MethodTMlpANN, the TMVA interface class to the ROOT TMultiLayerPerceptron:

MethodTMlpANN(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="3000:N-1:N-2")
   standard constructor
virtual ~MethodTMlpANN(void)
   destructor
void Init(void)
   default initialisations
void DeclareOptions()
   defines the options (their key words) that can be set in the option string
void ProcessOptions()
   builds the neural network as specified by the user
void CreateMLPOptions(TString)
   translates options from the option string into TMlpANN language
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
   TMlpANN can handle classification with two classes
void Train(void)
   performs the TMlpANN training (available learning methods: Stochastic, Batch, SteepestDescent, RibierePolak, FletcherReeves, BFGS)
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
   calculates the value of the neural net for the current event
void AddWeightsXMLTo(void *parent) const
   writes the weights to the XML weight file
void ReadWeightsFromXML(void *wghtnode)
   rebuilds a temporary text file from the XML weight file and loads it into the MLP
void ReadWeightsFromStream(std::istream &istr)
   reads weights from a stream; since the MLP cannot read directly from a stream, the weights are first written to a temporary file
void MakeClass(const TString &classFileName=TString("")) const
   creates a reader class for the classifier; overwrites the base-class function to create a specific class for TMultiLayerPerceptron
void MakeClassSpecific(std::ostream &, const TString &) const
   writes the specific classifier response; nothing to do here, all is taken care of by TMultiLayerPerceptron
void GetHelpMessage() const
   gets the help message text