// Begin_Html
/*
This is the TMVA TMultiLayerPerceptron interface class. It provides the
training and testing of the ROOT-internal MLP class in the TMVA framework.<br>
Available learning methods:<br>
<ul>
<li>TMultiLayerPerceptron::kStochastic </li>
<li>TMultiLayerPerceptron::kBatch </li>
<li>TMultiLayerPerceptron::kSteepestDescent </li>
<li>TMultiLayerPerceptron::kRibierePolak </li>
<li>TMultiLayerPerceptron::kFletcherReeves </li>
<li>TMultiLayerPerceptron::kBFGS </li>
</ul>
See the
<a href="http://root.cern.ch/root/html/TMultiLayerPerceptron.html">TMultiLayerPerceptron class description</a>
for details on this ANN.
*/
// End_Html
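//
// A minimal booking sketch, assuming the standard TMVA Factory workflow
// (the method title "MyTMlpANN" and the option values are illustrative):
//
//    TMVA::Factory factory( "TMVAnalysis", outputFile, "" );
//    factory.BookMethod( TMVA::Types::kTMlpANN, "MyTMlpANN",
//                        "NCycles=3000:HiddenLayers=N-1,N-2" );
//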
#include "TMVA/MethodTMlpANN.h"
#include <stdlib.h>
#include <vector>
#include "Riostream.h"
#include "TROOT.h"
#include "TTree.h"
#include "TMultiLayerPerceptron.h"
#include "TLeaf.h"
#include "TEventList.h"
#include "TObjString.h"
#include "TMVA/Tools.h"
// enforce the automatic input normalisation of TMultiLayerPerceptron
// (adds the "@" prefix to each input variable in the build options)
const Bool_t EnforceNormalization__ = kTRUE;

// learning method used for the training
const TMultiLayerPerceptron::ELearningMethod LearningMethod__ = TMultiLayerPerceptron::kStochastic;
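// To try a different optimiser, replace kStochastic above by any of the
// learning methods listed in the class description, e.g. (illustrative):
//
//    const TMultiLayerPerceptron::ELearningMethod LearningMethod__ = TMultiLayerPerceptron::kBFGS;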
ClassImp(TMVA::MethodTMlpANN);
TMVA::MethodTMlpANN::MethodTMlpANN( TString jobName, TString methodTitle, DataSet& theData,
                                    TString theOption, TDirectory* theTargetDir )
   : TMVA::MethodBase( jobName, methodTitle, theData, theOption, theTargetDir ),
     fMLP(0)
{
   // standard constructor: declare, parse and process the method options
   InitTMlpANN();
   DeclareOptions();
   ParseOptions();
   ProcessOptions();
}
TMVA::MethodTMlpANN::MethodTMlpANN( DataSet& theData,
                                    TString theWeightFile,
                                    TDirectory* theTargetDir )
   : TMVA::MethodBase( theData, theWeightFile, theTargetDir ),
     fMLP(0)
{
   // constructor used to set up the method from a previously written weight file
   InitTMlpANN();
   DeclareOptions();
}
void TMVA::MethodTMlpANN::InitTMlpANN( void )
{
   // default initialisation of the method name, type and test variable
   SetMethodName( "TMlpANN" );
   SetMethodType( TMVA::Types::kTMlpANN );
   SetTestvarName();
}
TMVA::MethodTMlpANN::~MethodTMlpANN( void )
{
   // destructor
   if (fMLP != 0) delete fMLP;
}
void TMVA::MethodTMlpANN::CreateMLPOptions( TString layerSpec )
{
   // translate the layer specification (e.g. "N-1,N-2") into the option
   // string expected by the TMultiLayerPerceptron constructor

   // parse the hidden-layer specification; "N" stands for the number of input variables
   fHiddenLayer = ":";
   while (layerSpec.Length() > 0) {
      TString sToAdd = "";
      if (layerSpec.First(',') < 0) {
         sToAdd    = layerSpec;
         layerSpec = "";
      }
      else {
         sToAdd    = layerSpec(0, layerSpec.First(','));
         layerSpec = layerSpec(layerSpec.First(',') + 1, layerSpec.Length());
      }
      Int_t nNodes = 0;
      if (sToAdd.BeginsWith("N")) { sToAdd.Remove(0,1); nNodes = GetNvar(); }
      nNodes += atoi(sToAdd);
      fHiddenLayer = Form( "%s%i:", (const char*)fHiddenLayer, nNodes );
   }

   // build the comma-separated list of input variables; a leading "@"
   // requests the automatic input normalisation of TMultiLayerPerceptron
   std::vector<TString>::iterator itrVar    = (*fInputVars).begin();
   std::vector<TString>::iterator itrVarEnd = (*fInputVars).end();
   fMLPBuildOptions = "";
   for (; itrVar != itrVarEnd; itrVar++) {
      if (EnforceNormalization__) fMLPBuildOptions += "@";
      TString myVar = *itrVar;
      fMLPBuildOptions += myVar;
      fMLPBuildOptions += ",";
   }
   fMLPBuildOptions.Chop(); // remove the trailing comma
   fMLPBuildOptions += fHiddenLayer;
   fMLPBuildOptions += "type"; // the output branch is named "type"

   fLogger << kINFO << "using " << fNcycles << " training cycles" << Endl;
   fLogger << kINFO << "using configuration (nodes per hidden layer): " << fHiddenLayer << Endl;
}
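// For illustration: with four input variables var1..var4 (hypothetical names)
// and the default layer specification "N-1,N-2" (here N = 4), the build
// string produced above becomes
//
//    "@var1,@var2,@var3,@var4:3:2:type"
//
// i.e. four normalised inputs, hidden layers with 3 and 2 nodes, and the
// output branch "type".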
void TMVA::MethodTMlpANN::DeclareOptions()
{
   // define the options (their key words) that can be set in the option string
   DeclareOptionRef( fNcycles = 3000,        "NCycles",      "Number of training cycles" );
   DeclareOptionRef( fLayerSpec = "N-1,N-2", "HiddenLayers", "Specification of the hidden layers" );
}
void TMVA::MethodTMlpANN::ProcessOptions()
{
   // build the network as specified by the option string; since
   // TMultiLayerPerceptron resolves its variable names from a TTree, an
   // empty dummy tree with the correct branch structure is created first
   CreateMLPOptions(fLayerSpec);

   static Double_t* d = new Double_t[Data().GetNVariables()];
   static Int_t type;
   gROOT->cd();
   TTree* dummyTree = new TTree("dummy", "Empty dummy tree", 1);
   for (UInt_t ivar = 0; ivar < Data().GetNVariables(); ivar++) {
      TString vn = Data().GetInternalVarName(ivar);
      dummyTree->Branch(vn.Data(), d + ivar, Form("%s/D", vn.Data()));
   }
   dummyTree->Branch("type", &type, "type/I");

   if (fMLP != 0) delete fMLP;
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), dummyTree );
}
Double_t TMVA::MethodTMlpANN::GetMvaValue()
{
   // calculate the MVA response for the current event
   static Double_t* d = new Double_t[Data().GetNVariables()];
   for (UInt_t ivar = 0; ivar < Data().GetNVariables(); ivar++) {
      d[ivar] = (Double_t)Data().Event().GetVal(ivar);
   }
   Double_t mvaVal = fMLP->Evaluate(0, d);
   return mvaVal;
}
void TMVA::MethodTMlpANN::Train( void )
{
   // perform the TMultiLayerPerceptron training: the training and test
   // trees are merged into one local tree, and entry lists are used to
   // separate the training sample from the test sample
   if (!CheckSanity()) fLogger << kFATAL << "<Train> sanity check failed" << Endl;

   fLogger << kVERBOSE << "option string: " << GetOptions() << Endl;

   TTree* localTrainingTree = Data().GetTrainingTree()->CloneTree();
   localTrainingTree->CopyEntries(GetTestTree());

   // the first NEvtTrain entries form the training sample, the rest the test sample
   TString trainList = "Entry$<";
   trainList += (Int_t)Data().GetNEvtTrain();
   TString testList  = "Entry$>=";
   testList  += (Int_t)Data().GetNEvtTrain();

   if (fMLP != 0) delete fMLP;
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(),
                                     localTrainingTree,
                                     trainList,
                                     testList );
   fMLP->SetLearningMethod( LearningMethod__ );

   // train the network: print text progress, updated every 200 cycles
   fMLP->Train(fNcycles, "text,update=200");

   localTrainingTree->Delete();
}
void TMVA::MethodTMlpANN::WriteWeightsToStream( ostream& o ) const
{
   // write weights to stream: TMultiLayerPerceptron can only dump its
   // weights into a file, so a temporary file is written and then streamed
   fMLP->DumpWeights("weights/TMlp.nn.weights.temp");
   ifstream inf("weights/TMlp.nn.weights.temp");
   o << inf.rdbuf();
   inf.close();
}
void TMVA::MethodTMlpANN::ReadWeightsFromStream( istream& istr )
{
   // read weights from stream: the stream content is copied into a
   // temporary file from which TMultiLayerPerceptron can load its weights
   ofstream fout("weights/TMlp.nn.weights.temp");
   fout << istr.rdbuf();
   fout.close();

   fLogger << kINFO << "loading TMLP weights" << Endl;
   fMLP->LoadWeights("weights/TMlp.nn.weights.temp");
}
void TMVA::MethodTMlpANN::WriteMonitoringHistosToFile( void ) const
{
   // write monitoring histograms to the target file; this method currently
   // only reports the target directory
   fLogger << kINFO << "wrote monitoring histograms to file: " << BaseDir()->GetPath() << Endl;
}