* *
**********************************************************************************/
// Begin_Html
/*
This is the TMVA TMultiLayerPerceptron interface class. It provides the
training and testing of the ROOT internal MLP class in the TMVA framework.<br>
Available learning methods:<br>
<ul>
<li>TMultiLayerPerceptron::kStochastic </li>
<li>TMultiLayerPerceptron::kBatch </li>
<li>TMultiLayerPerceptron::kSteepestDescent </li>
<li>TMultiLayerPerceptron::kRibierePolak </li>
<li>TMultiLayerPerceptron::kFletcherReeves </li>
<li>TMultiLayerPerceptron::kBFGS </li>
</ul>
See the
<a href="http://root.cern.ch/root/html/TMultiLayerPerceptron.html">TMultiLayerPerceptron class description</a>
for details on this ANN.
*/
// End_Html
#include "TMVA/MethodTMlpANN.h"
#include <stdlib.h>
#include "Riostream.h"
#include "TMultiLayerPerceptron.h"
#include "TLeaf.h"
#include "TEventList.h"
#include "TObjString.h"
#include "TMVA/Tools.h"
// If true, every input variable is prefixed with '@' in the option string
// built by CreateMLPOptions(), which asks TMultiLayerPerceptron to
// normalize that variable.
const Bool_t EnforceNormalization__=kTRUE;
// Learning method passed to TMultiLayerPerceptron::SetLearningMethod() in Train().
const TMultiLayerPerceptron::LearningMethod LearningMethod__=TMultiLayerPerceptron::kStochastic;
// ROOT dictionary implementation macro for this class.
ClassImp(TMVA::MethodTMlpANN)
// Standard constructor: initialise the method defaults, fall back to the
// option string "3000:N-1:N-2" (cycles : nodes layer 1 : nodes layer 2)
// when the user supplied none, and translate the TMVA options into the
// configuration format understood by TMultiLayerPerceptron.
TMVA::MethodTMlpANN::MethodTMlpANN( TString jobName, std::vector<TString>* theVariables,
TTree* theTree, TString theOption, TDirectory* theTargetDir)
: TMVA::MethodBase(jobName, theVariables, theTree, theOption, theTargetDir )
{
   InitTMlpANN();

   // an empty TString has Sizeof()==1 (terminating null only), so
   // Sizeof()<2 means "no option string given"
   const Bool_t noUserOptions = (fOptions.Sizeof() < 2);
   if (noUserOptions) {
      fOptions = "3000:N-1:N-2";
      cout << "--- " << GetName() << ": using default options= "<< fOptions << endl;
   }

   CreateMLPOptions();

   cout << "--- " << GetName() << ": use " << fNcycles << " training cycles" << endl;
   cout << "--- " << GetName() << ": use configuration (nodes per hidden layer): "
        << fHiddenLayer << endl;
}
// Constructor used when instantiating the method from a weight file
// (application phase): only base-class and TMlpANN defaults are set up;
// the configuration is recovered later via ReadWeightsFromFile().
TMVA::MethodTMlpANN::MethodTMlpANN( vector<TString> *theVariables,
TString theWeightFile,
TDirectory* theTargetDir )
: TMVA::MethodBase( theVariables, theWeightFile, theTargetDir )
{
InitTMlpANN();
}
// Default initialisation common to both constructors: set the method
// name/type identifiers and derive the name of the MVA output variable.
void TMVA::MethodTMlpANN::InitTMlpANN( void )
{
fMethodName = "TMlpANN";
fMethod = TMVA::Types::TMlpANN;
// must come after fMethodName is set: GetMethodName() is used here
fTestvar = fTestvarPrefix+GetMethodName();
}
// Destructor: nothing to release — this class owns no resources beyond
// what the base class manages.
TMVA::MethodTMlpANN::~MethodTMlpANN( void )
{
}
// Translate the TMVA option string (e.g. "3000:N-1:N-2") into the format
// expected by TMultiLayerPerceptron:
//    "@var1,@var2,...:<nodes layer1>:<nodes layer2>:...:type"
// The first parsed number is the number of training cycles (fNcycles);
// the remaining numbers are the node counts of the hidden layers.
void TMVA::MethodTMlpANN::CreateMLPOptions( void )
{
   // first entry = number of cycles, following entries = hidden-layer sizes
   vector<Int_t>* nodes = ParseOptionString( fOptions, fNvar, new vector<Int_t>() );
   fNcycles = (*nodes)[0];

   // build the hidden-layer section ":n1:n2:...:"
   fHiddenLayer = ":";
   for (UInt_t i=1; i<nodes->size(); i++)
      fHiddenLayer = Form( "%s%i:", (const char*)fHiddenLayer, (*nodes)[i] );

   // build the comma-separated input-variable list; a leading '@' asks
   // TMultiLayerPerceptron to normalize the variable
   fOptions = "";
   vector<TString>::iterator itrVar    = fInputVars->begin();
   vector<TString>::iterator itrVarEnd = fInputVars->end();
   for (; itrVar != itrVarEnd; itrVar++) {
      if (EnforceNormalization__) fOptions += "@";
      fOptions += *itrVar;   // fixed: dropped redundant temporary and stray ';'
      fOptions += ",";
   }
   fOptions.Chop();          // remove the trailing comma

   fOptions += fHiddenLayer;
   fOptions += "type";       // output node is the event type ("type" branch)

   delete nodes;
}
// Train the network: copy training and test events into one temporary
// tree, hand it to a TMultiLayerPerceptron (split again by entry number),
// train for fNcycles cycles, and dump the weights to the weight file.
void TMVA::MethodTMlpANN::Train( void )
{
if (!CheckSanity()) {
cout << "--- " << GetName() << ": Error: sanity check failed ==> exit(1)" << endl;
exit(1);
}
if (Verbose())
cout << "--- " << GetName() << " <verbose>: option string: " << fOptions << endl;
// NOTE(review): fixed-size branch buffer — assumes fNvar <= 100; verify upstream
Double_t v[100];
Int_t type;
// merged tree: training events first, test events appended afterwards;
// the entry index is used below to split it again via TTree selections
TTree *localTrainingTree = new TTree("localTrainingTree","Merged fTraining + fTestTree");
localTrainingTree->Branch("type",&type,"type/I");
for(Int_t ivar=0; ivar<fNvar; ivar++) {
// "type" already has its own branch above
if (!(*fInputVars)[ivar].Contains("type")) {
localTrainingTree->Branch( (*fInputVars)[ivar], &v[ivar], (*fInputVars)[ivar] + "/D" );
}
}
// copy the training events into the local tree
for (Int_t ievt=0;ievt<fTrainingTree->GetEntries(); ievt++) {
type = (Int_t)TMVA::Tools::GetValue( fTrainingTree, ievt, "type" );
for (Int_t ivar=0; ivar<fNvar; ivar++) {
if (!(*fInputVars)[ivar].Contains("type")) {
v[ivar] = TMVA::Tools::GetValue( fTrainingTree, ievt, (*fInputVars)[ivar] );
}
}
localTrainingTree->Fill();
}
// append the test events (used by the MLP as its monitoring sample)
for (Int_t ievt=0;ievt<fTestTree->GetEntries(); ievt++) {
type = (Int_t)TMVA::Tools::GetValue( fTestTree, ievt, "type" );
for(Int_t ivar=0; ivar<fNvar; ivar++) {
if (!(*fInputVars)[ivar].Contains("type")) {
v[ivar]= TMVA::Tools::GetValue( fTestTree, ievt, (*fInputVars)[ivar] );
}
}
localTrainingTree->Fill();
}
// entry-number based selections: first block = training, remainder = test
TString trainList = "Entry$<";
trainList += (Int_t)fTrainingTree->GetEntries();
TString testList = "Entry$>=";
testList += (Int_t)fTrainingTree->GetEntries();
TMultiLayerPerceptron *mlp = new TMultiLayerPerceptron( fOptions,
localTrainingTree,
trainList,
testList );
mlp->SetLearningMethod( LearningMethod__ );
// "text,update=200": print text progress every 200 cycles
mlp->Train(fNcycles, "text,update=200");
// dump the weights first; WriteWeightsToFile() then appends the option
// string to the same file (opened in append mode)
mlp->DumpWeights(GetWeightFileName());
WriteWeightsToFile();
localTrainingTree->Delete();
delete mlp;
}
// Append the network option string to the weight file. The file already
// contains the weights written by TMultiLayerPerceptron::DumpWeights()
// in Train(), hence the append mode.
void TMVA::MethodTMlpANN::WriteWeightsToFile( void )
{
   TString fname = GetWeightFileName();
   cout << "--- " << GetName() << ": creating weight file: " << fname << endl;

   ofstream fout( fname , ios::out | ios::app);
   if (fout.good( )) {
      fout << fOptions;
      fout.close();
   }
   else {
      cout << "--- " << GetName() << ": Error in ::WriteWeightsToFile: "
           << "unable to open output weight file: " << fname << endl;
      exit(1);
   }
}
void TMVA::MethodTMlpANN::ReadWeightsFromFile( void )
{
TString fname = GetWeightFileName();
cout << "--- " << GetName() << ": reading weight file: " << fname << endl;
ifstream fin( fname );
if (!fin.good( )) {
cout << "--- " << GetName() << ": Error in ::ReadWeightsFromFile: "
<< "unable to open input file: " << fname << endl;
exit(1);
}
while (!fin.eof()) fin >> fOptions;
fin.close();
}
// Evaluate the trained network on every event of the given test tree and
// attach the MLP response as a new branch (named fTestvar) to that tree.
void TMVA::MethodTMlpANN::PrepareEvaluationTree( TTree* testTree )
{
if (Verbose()) cout << "--- " << GetName() << " <verbose>: begin testing" << endl;
// NOTE(review): fixed-size branch buffer — assumes fNvar <= 100; verify upstream
Double_t v[100];
Int_t type;
// local copy of the test tree holding only the type + input variables,
// in the layout TMultiLayerPerceptron expects
TTree *localTestTree = new TTree("localTestTree","copy of testTree");
localTestTree->Branch("type",&type,"type/I",128000);
for (Int_t ivar=0; ivar<fNvar; ivar++) {
// "type" already has its own branch above
if (!(*fInputVars)[ivar].Contains("type")) {
localTestTree->Branch( (*fInputVars)[ivar], &v[ivar], (*fInputVars)[ivar] + "/D",128000 );
}
}
// copy all test events into the local tree
for (Int_t ievt=0;ievt<testTree->GetEntries(); ievt++) {
type = (Int_t)TMVA::Tools::GetValue( testTree, ievt, "type" );
for (Int_t ivar=0; ivar<fNvar; ivar++) {
if (!(*fInputVars)[ivar].Contains("type")) {
v[ivar] = TMVA::Tools::GetValue( testTree, ievt, (*fInputVars)[ivar] );
}
}
localTestTree->Fill();
}
// restore the option string (network architecture), then the weights
ReadWeightsFromFile();
TMultiLayerPerceptron *mlp = new TMultiLayerPerceptron( fOptions, localTestTree );
mlp->LoadWeights(GetWeightFileName());
// fill the MVA response for each entry into the new branch
Double_t myMVA;
TBranch *newBranch = testTree->Branch( fTestvar, &myMVA, fTestvar + "/D" );
for (Int_t i=0; i< (testTree->GetEntries()) ; i++) {
myMVA=mlp->Result(i);
newBranch->Fill();
}
localTestTree->Delete();
delete mlp;
}
// Set the internal test-tree pointer (no ownership is taken).
void TMVA::MethodTMlpANN::SetTestTree( TTree* testTree )
{
fTestTree = testTree;
}
// Write method-specific monitoring histograms to the target file.
// Currently this only announces the target directory — no special
// histograms are produced by this method.
void TMVA::MethodTMlpANN::WriteHistosToFile( void )
{
   const char* methodName = GetName();
   cout << "--- " << methodName << ": write " << methodName
        << " special histos to file: " << fBaseDir->GetPath() << endl;
}
// ROOT page - Class index - Class Hierarchy - Top of the page
// This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.