// @(#)root/tmva $Id: MethodMLP.h 29122 2009-06-22 06:51:30Z brun $
// Author: Andreas Hoecker, Matt Jachowski

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : MethodMLP                                                             *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      ANN Multilayer Perceptron class for the discrimination of signal          *
 *      from background.  BFGS implementation based on TMultiLayerPerceptron      *
 *      class from ROOT (http://root.cern.ch).                                    *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Krzysztof Danielowski <danielow@cern.ch>       - IFJ & AGH, Poland        *
 *      Andreas Hoecker       <Andreas.Hocker@cern.ch> - CERN, Switzerland        *
 *      Peter Speckmayer      <peter.speckmayer@cern.ch> - CERN, Switzerland      *
 *      Matt Jachowski        <jachowski@stanford.edu> - Stanford University, USA *
 *      Kamil Kraszewski      <kalq@cern.ch>           - IFJ & UJ, Poland         *
 *      Maciej Kruk           <mkruk@cern.ch>          - IFJ & AGH, Poland        *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_MethodMLP
#define ROOT_TMVA_MethodMLP

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodMLP                                                            //
//                                                                      //
// Multilayer perceptron derived from MethodANNBase                     //
//                                                                      //
//////////////////////////////////////////////////////////////////////////
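
// A minimal usage sketch (an assumption of the standard TMVA Factory workflow;
// option names are the usual TMVA ones, option values are illustrative, not
// defaults):
//
//    TFile* outputFile = TFile::Open( "TMVA.root", "RECREATE" );
//    TMVA::Factory factory( "TMVAClassification", outputFile, "!V" );
//    factory.AddVariable( "var1", 'F' );
//    factory.AddVariable( "var2", 'F' );
//    factory.AddSignalTree    ( sigTree, 1.0 );  // sigTree/bkgTree: user-supplied TTrees
//    factory.AddBackgroundTree( bkgTree, 1.0 );
//    factory.BookMethod( TMVA::Types::kMLP, "MLP",
//                        "NCycles=600:HiddenLayers=N+5:NeuronType=tanh:"
//                        "TrainingMethod=BP:LearningRate=0.02:DecayRate=0.01:TestRate=5" );
//    factory.TrainAllMethods();
//    factory.TestAllMethods();
//    factory.EvaluateAllMethods();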

#include <vector>
#ifndef ROOT_TString
#include "TString.h"
#endif
#ifndef ROOT_TTree
#include "TTree.h"
#endif
#ifndef ROOT_TObjArray
#include "TObjArray.h"
#endif
#ifndef ROOT_TRandom3
#include "TRandom3.h"
#endif
#ifndef ROOT_TH1F
#include "TH1F.h"
#endif
#ifndef ROOT_TMatrixDfwd
#include "TMatrixDfwd.h"
#endif

#ifndef ROOT_TMVA_IFitterTarget
#include "TMVA/IFitterTarget.h"
#endif
#ifndef ROOT_TMVA_MethodBase
#include "TMVA/MethodBase.h"
#endif
#ifndef ROOT_TMVA_MethodANNBase
#include "TMVA/MethodANNBase.h"
#endif
#ifndef ROOT_TMVA_TNeuron
#include "TMVA/TNeuron.h"
#endif
#ifndef ROOT_TMVA_TActivation
#include "TMVA/TActivation.h"
#endif
#ifndef ROOT_TMVA_ConvergenceTest
#include "TMVA/ConvergenceTest.h"
#endif

// Minuit-based minimization is disabled by default; remove the #undef below
// to enable the MethodMLP_UseMinuit__ blocks further down
#define MethodMLP_UseMinuit__
#undef  MethodMLP_UseMinuit__

namespace TMVA {

   class MethodMLP : public MethodANNBase, public IFitterTarget, public ConvergenceTest {

   public:

      // standard constructors
      MethodMLP( const TString& jobName,
                 const TString& methodTitle,
                 DataSetInfo& theData,
                 const TString& theOption,
                 TDirectory* theTargetDir = 0 );

      MethodMLP( DataSetInfo& theData, 
                 const TString& theWeightFile, 
                 TDirectory* theTargetDir = 0 );

      virtual ~MethodMLP();

      virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );

      void Train() { Train(NumCycles()); }

      // estimator interface for the genetic-algorithm fitter (EstimatorFunction implements IFitterTarget)
      Double_t ComputeEstimator ( std::vector<Double_t>& parameters );
      Double_t EstimatorFunction( std::vector<Double_t>& parameters );

      enum ETrainingMethod { kBP=0, kBFGS, kGA };
      enum EBPTrainingMode { kSequential=0, kBatch };
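
      // the training method and backprop mode are chosen via option strings at
      // booking time (stored in fTrainMethodS / fBpModeS below); assuming the
      // usual TMVA option names, "TrainingMethod=BFGS" selects kBFGS and
      // "BP_Mode=batch" selects kBatch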

   protected:

      // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
      void MakeClassSpecific( std::ostream&, const TString& ) const;

      // get help message text
      void GetHelpMessage() const;


   private:

      // the option handling methods
      void DeclareOptions();
      void ProcessOptions();

      // general helper functions
      void     Train( Int_t nEpochs );
      void     Init();
      void     InitializeLearningRates(); // needed only by back-propagation

      // used as a measure of success in all minimization techniques
      Double_t CalculateEstimator( Types::ETreeType treeType = Types::kTraining );

      // BFGS functions
      void     BFGSMinimize( Int_t nEpochs );
      void     SetGammaDelta( TMatrixD &Gamma, TMatrixD &Delta, std::vector<Double_t> &Buffer );
      void     SteepestDir( TMatrixD &Dir );
      Bool_t   GetHessian( TMatrixD &Hessian, TMatrixD &Gamma, TMatrixD &Delta );
      void     SetDir( TMatrixD &Hessian, TMatrixD &Dir );
      Double_t DerivDir( TMatrixD &Dir );
      Bool_t   LineSearch( TMatrixD &Dir, std::vector<Double_t> &Buffer );
      void     ComputeDEDw();
      void     SimulateEvent( const Event* ev );
      void     SetDirWeights( std::vector<Double_t> &Origin, TMatrixD &Dir, Double_t alpha );
      Double_t GetError();
      Double_t GetSqrErr( const Event* ev, UInt_t index = 0 );
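
      // For orientation: the standard BFGS inverse-Hessian update that the
      // functions above implement (a generic sketch -- see the .cxx for the
      // exact conventions), with gamma = g_{k+1} - g_k (gradient difference)
      // and delta = w_{k+1} - w_k (weight difference):
      //
      //   H_{k+1} = H_k + (1 + gamma^T*H*gamma/(delta^T*gamma)) * delta*delta^T/(delta^T*gamma)
      //                 - (delta*gamma^T*H + H*gamma*delta^T) / (delta^T*gamma)
      //
      // SteepestDir() and SetDir() build the search direction from the
      // gradient and H; LineSearch() determines the step length along it.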

      // backpropagation functions
      void     BackPropagationMinimize( Int_t nEpochs );
      void     TrainOneEpoch();
      void     Shuffle( Int_t* index, Int_t n );
      void     DecaySynapseWeights( Bool_t lateEpoch );
      void     TrainOneEvent( Int_t ievt );
      Double_t GetDesiredOutput( const Event* ev );
      void     UpdateNetwork( Double_t desired, Double_t eventWeight=1.0 );
      void     UpdateNetwork( std::vector<Float_t>& desired, Double_t eventWeight=1.0 );
      void     CalculateNeuronDeltas();
      void     UpdateSynapses();
      void     AdjustSynapseWeights();
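
      // For orientation: in sequential mode the per-event weight update follows
      // the standard back-propagation delta rule (a generic sketch; eta is
      // fLearnRate, reduced over time via fDecayRate):
      //
      //   delta_j  =  dE/da_j                  (CalculateNeuronDeltas)
      //   w_ij    -=  eta * delta_j * y_i      (UpdateSynapses / AdjustSynapseWeights)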

      // faster backpropagation
      void     TrainOneEventFast( Int_t ievt, Float_t*& branchVar, Int_t& type );

      // genetic algorithm functions
      void GeneticMinimize();

      // the neural network can be initialized only after the analysis type has been set
      void   SetAnalysisType( Types::EAnalysisType type );

#ifdef MethodMLP_UseMinuit__
      // Minuit functions -- compiled out via MethodMLP_UseMinuit__ because they rely on a static this-pointer
      void MinuitMinimize();
      static MethodMLP* GetThisPtr() { return fgThis; }
      static void IFCN( Int_t& npars, Double_t* grad, Double_t &f, Double_t* fitPars, Int_t ifl );
      void FCN( Int_t& npars, Double_t* grad, Double_t &f, Double_t* fitPars, Int_t ifl );
#endif

      // general
      ETrainingMethod fTrainingMethod; // training method: BP, BFGS or GA
      TString         fTrainMethodS;   // training method option param

      Float_t         fSamplingFraction;  // fraction of events sampled for training
      Float_t         fSamplingEpoch;     // fraction of epochs during which sampling is used
      Float_t         fSamplingWeight;    // factor by which event weights are changed when sampling is on
      Bool_t          fSamplingTraining;  // apply sampling to the training sample
      Bool_t          fSamplingTesting;   // apply sampling to the testing sample
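
      // e.g., assuming the usual TMVA option names, booking with
      // "Sampling=0.5:SamplingEpoch=0.3" trains on a random 50% subsample
      // during the first 30% of the epochs (values illustrative)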

      // BFGS variables
      Double_t        fLastAlpha;      // line search variable
      Double_t        fTau;            // line search variable
      Int_t           fResetStep;      // reset period (how often the Hessian matrix is cleared)

      // backpropagation variables
      Double_t        fLearnRate;      // learning rate for synapse weight adjustments
      Double_t        fDecayRate;      // decay rate for above learning rate
      EBPTrainingMode fBPMode;         // backprop learning mode (sequential or batch)
      TString         fBpModeS;        // backprop learning mode option string (sequential or batch)
      Int_t           fBatchSize;      // batch size, only matters if in batch learning mode
      Int_t           fTestRate;       // overtraining test performed every fTestRate-th epoch
      
      // genetic algorithm variables
      Int_t           fGA_nsteps;      // GA settings: number of steps
      Int_t           fGA_preCalc;     // GA settings: number of pre-calc steps
      Int_t           fGA_SC_steps;    // GA settings: SC_steps
      Int_t           fGA_SC_rate;     // GA settings: SC_rate
      Double_t        fGA_SC_factor;   // GA settings: SC_factor

#ifdef MethodMLP_UseMinuit__
      // Minuit variables -- compiled out via MethodMLP_UseMinuit__ because they rely on a static pointer
      Int_t          fNumberOfWeights; // Minuit: number of weights
      static MethodMLP* fgThis;        // Minuit: this pointer
#endif

      // debugging flags
      static const Int_t  fgPRINT_ESTIMATOR_INC = 10;     // debug flag: estimator printout interval
      static const Bool_t fgPRINT_SEQ           = kFALSE; // debug flag: print sequential-mode details
      static const Bool_t fgPRINT_BATCH         = kFALSE; // debug flag: print batch-mode details

      ClassDef(MethodMLP,0) // Multi-layer perceptron implemented specifically for TMVA
   };

} // namespace TMVA

#endif