ROOT logo
// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Peter Speckmayer, Matt Jachowski, Jan Therhaag

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : MethodANNBase                                                         *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Artificial neural network base class for the discrimination of signal     *
 *      from background.                                                          *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker  <Andreas.Hocker@cern.ch> - CERN, Switzerland             *
 *      Matt Jachowski   <jachowski@stanford.edu> - Stanford University, USA      *
 *      Peter Speckmayer <Peter.Speckmayer@cern.ch>  - CERN, Switzerland          *
 *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>   - CERN, Switzerland             *
 *      Jan Therhaag       <Jan.Therhaag@cern.ch>     - U of Bonn, Germany        *
 *                                                                                *
 * Small changes (regression):                                                    *
 *      Krzysztof Danielowski <danielow@cern.ch>  - IFJ PAN & AGH, Poland         *
 *      Kamil Kraszewski      <kalq@cern.ch>      - IFJ PAN & UJ , Poland         *
 *      Maciej Kruk           <mkruk@cern.ch>     - IFJ PAN & AGH, Poland         *
 *                                                                                *
 * Copyright (c) 2005-2011:                                                       *
 *      CERN, Switzerland                                                         *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_MethodANNBase
#define ROOT_TMVA_MethodANNBase

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodANNBase                                                        //
//                                                                      //
// Base class for all TMVA methods using artificial neural networks     //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#ifndef ROOT_TString
#include "TString.h"
#endif
#include <vector>
#ifndef ROOT_TTree
#include "TTree.h"
#endif
#ifndef ROOT_TObjArray
#include "TObjArray.h"
#endif
#ifndef ROOT_TRandom3
#include "TRandom3.h"
#endif
#ifndef ROOT_TMatrix
#include "TMatrix.h"
#endif

#ifndef ROOT_TMVA_MethodBase
#include "TMVA/MethodBase.h"
#endif
#ifndef ROOT_TMVA_TActivation
#include "TMVA/TActivation.h"
#endif
#ifndef ROOT_TMVA_TNeuron
#include "TMVA/TNeuron.h"
#endif
#ifndef ROOT_TMVA_TNeuronInput
#include "TMVA/TNeuronInput.h"
#endif

class TH1;
class TH1F;

namespace TMVA {

   class MethodANNBase : public MethodBase {
      
   public:
      
      // constructors dictated by subclassing off of MethodBase
      MethodANNBase( const TString& jobName,
                     Types::EMVA methodType,
                     const TString& methodTitle,
                     DataSetInfo& theData, 
                     const TString& theOption,
                     TDirectory* theTargetDir );
      
      MethodANNBase( Types::EMVA methodType,
                     DataSetInfo& theData,
                     const TString& theWeightFile, 
                     TDirectory* theTargetDir );
      
      virtual ~MethodANNBase();
      
      // this does the real initialization work
      void InitANNBase();
      
      // setters for subclasses
      void SetActivation(TActivation* activation) {
         if (fActivation != NULL) delete fActivation; fActivation = activation;
      }
      void SetNeuronInputCalculator(TNeuronInput* inputCalculator) {
         if (fInputCalculator != NULL) delete fInputCalculator;
         fInputCalculator = inputCalculator;
      }
      
      // this will have to be overridden by every subclass
      virtual void Train() = 0;
      
      // print network, for debugging  
      virtual void PrintNetwork() const;


      // call this function like that:
      // ...
      // MethodMLP* mlp = dynamic_cast<MethodMLP*>(method);
      // std::vector<float> layerValues;
      // mlp->GetLayerActivation (2, std::back_inserter(layerValues));
      // ... do now something with the layerValues
      // 
      template <typename WriteIterator>
      void GetLayerActivation (size_t layer, WriteIterator writeIterator);

      using MethodBase::ReadWeightsFromStream;

      // write weights to file
      void AddWeightsXMLTo( void* parent ) const;
      void ReadWeightsFromXML( void* wghtnode );

      // read weights from file
      virtual void ReadWeightsFromStream( std::istream& istr );
      
      // calculate the MVA value
      virtual Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0 );

      virtual const std::vector<Float_t> &GetRegressionValues();

      virtual const std::vector<Float_t> &GetMulticlassValues();
      
      // write method specific histos to target file
      virtual void WriteMonitoringHistosToFile() const;
     
      // ranking of input variables
      const Ranking* CreateRanking();

      // the option handling methods
      virtual void DeclareOptions();
      virtual void ProcessOptions();
      
      Bool_t Debug() const;

      enum EEstimator      { kMSE=0,kCE};


   protected:

      virtual void MakeClassSpecific( std::ostream&, const TString& ) const;
      
      std::vector<Int_t>* ParseLayoutString( TString layerSpec );
      virtual void        BuildNetwork( std::vector<Int_t>* layout, std::vector<Double_t>* weights=NULL,
                                        Bool_t fromFile = kFALSE );
      void     ForceNetworkInputs( const Event* ev, Int_t ignoreIndex = -1 );
      Double_t GetNetworkOutput() { return GetOutputNeuron()->GetActivationValue(); }
      
      // debugging utilities
      void     PrintMessage( TString message, Bool_t force = kFALSE ) const;
      void     ForceNetworkCalculations();
      void     WaitForKeyboard();
      
      // accessors
      Int_t    NumCycles()  { return fNcycles;   }
      TNeuron* GetInputNeuron (Int_t index)       { return (TNeuron*)fInputLayer->At(index); }
      TNeuron* GetOutputNeuron(Int_t index = 0)   { return fOutputNeurons.at(index); }
      
      // protected variables
      TObjArray*    fNetwork;         // TObjArray of TObjArrays representing network
      TObjArray*    fSynapses;        // array of pointers to synapses, no structural data
      TActivation*  fActivation;      // activation function to be used for hidden layers
      TActivation*  fOutput;          // activation function to be used for output layers, depending on estimator
      TActivation*  fIdentity;        // activation for input and output layers
      TRandom3*     frgen;            // random number generator for various uses
      TNeuronInput* fInputCalculator; // input calculator for all neurons

      std::vector<Int_t>        fRegulatorIdx;  //index to different priors from every synapses
      std::vector<Double_t>     fRegulators;    //the priors as regulator
      EEstimator                fEstimator;
      TString                   fEstimatorS;

      // monitoring histograms
      TH1F* fEstimatorHistTrain; // monitors convergence of training sample
      TH1F* fEstimatorHistTest;  // monitors convergence of independent test sample
      
      // monitoring histograms (not available for regression)
      void CreateWeightMonitoringHists( const TString& bulkname, std::vector<TH1*>* hv = 0 ) const;
      std::vector<TH1*> fEpochMonHistS; // epoch monitoring hitograms for signal
      std::vector<TH1*> fEpochMonHistB; // epoch monitoring hitograms for background
      std::vector<TH1*> fEpochMonHistW; // epoch monitoring hitograms for weights

      
      // general
      TMatrixD           fInvHessian;           // zjh
      bool               fUseRegulator;         // zjh

   protected:
      Int_t                   fRandomSeed;      // random seed for initial synapse weights

      Int_t                   fNcycles;         // number of epochs to train

      TString                 fNeuronType;      // name of neuron activation function class
      TString                 fNeuronInputType; // name of neuron input calculator class


   private:
      
      // helper functions for building network
      void BuildLayers(std::vector<Int_t>* layout, Bool_t from_file = false);
      void BuildLayer(Int_t numNeurons, TObjArray* curLayer, TObjArray* prevLayer, 
                      Int_t layerIndex, Int_t numLayers, Bool_t from_file = false);
      void AddPreLinks(TNeuron* neuron, TObjArray* prevLayer);
     
      // helper functions for weight initialization
      void InitWeights();
      void ForceWeights(std::vector<Double_t>* weights);
      
      // helper functions for deleting network
      void DeleteNetwork();
      void DeleteNetworkLayer(TObjArray*& layer);
      
      // debugging utilities
      void PrintLayer(TObjArray* layer) const;
      void PrintNeuron(TNeuron* neuron) const;
      
      // private variables
      TObjArray*              fInputLayer;      // cache this for fast access
      std::vector<TNeuron*>   fOutputNeurons;   // cache this for fast access
      TString                 fLayerSpec;       // layout specification option

      // some static flags
      static const Bool_t fgDEBUG      = kTRUE;  // debug flag
    
      ClassDef(MethodANNBase,0) // Base class for TMVA ANNs
   };



    template <typename WriteIterator>
    inline void MethodANNBase::GetLayerActivation (size_t layerNumber, WriteIterator writeIterator)
    {
	// get the activation values of the nodes in layer "layer"
	// write the node activation values into the writeIterator
        // assumes, that the network has been computed already (by calling
	// "GetRegressionValues")

	if (layerNumber >= (size_t)fNetwork->GetEntriesFast())
	    return;

	TObjArray* layer = (TObjArray*)fNetwork->At(layerNumber);
	UInt_t nNodes    = layer->GetEntriesFast();
	for (UInt_t iNode = 0; iNode < nNodes; iNode++) 
	{
	    (*writeIterator) = ((TNeuron*)layer->At(iNode))->GetActivationValue();
	    ++writeIterator;
	}
    }

   
} // namespace TMVA

#endif
 MethodANNBase.h:1
 MethodANNBase.h:2
 MethodANNBase.h:3
 MethodANNBase.h:4
 MethodANNBase.h:5
 MethodANNBase.h:6
 MethodANNBase.h:7
 MethodANNBase.h:8
 MethodANNBase.h:9
 MethodANNBase.h:10
 MethodANNBase.h:11
 MethodANNBase.h:12
 MethodANNBase.h:13
 MethodANNBase.h:14
 MethodANNBase.h:15
 MethodANNBase.h:16
 MethodANNBase.h:17
 MethodANNBase.h:18
 MethodANNBase.h:19
 MethodANNBase.h:20
 MethodANNBase.h:21
 MethodANNBase.h:22
 MethodANNBase.h:23
 MethodANNBase.h:24
 MethodANNBase.h:25
 MethodANNBase.h:26
 MethodANNBase.h:27
 MethodANNBase.h:28
 MethodANNBase.h:29
 MethodANNBase.h:30
 MethodANNBase.h:31
 MethodANNBase.h:32
 MethodANNBase.h:33
 MethodANNBase.h:34
 MethodANNBase.h:35
 MethodANNBase.h:36
 MethodANNBase.h:37
 MethodANNBase.h:38
 MethodANNBase.h:39
 MethodANNBase.h:40
 MethodANNBase.h:41
 MethodANNBase.h:42
 MethodANNBase.h:43
 MethodANNBase.h:44
 MethodANNBase.h:45
 MethodANNBase.h:46
 MethodANNBase.h:47
 MethodANNBase.h:48
 MethodANNBase.h:49
 MethodANNBase.h:50
 MethodANNBase.h:51
 MethodANNBase.h:52
 MethodANNBase.h:53
 MethodANNBase.h:54
 MethodANNBase.h:55
 MethodANNBase.h:56
 MethodANNBase.h:57
 MethodANNBase.h:58
 MethodANNBase.h:59
 MethodANNBase.h:60
 MethodANNBase.h:61
 MethodANNBase.h:62
 MethodANNBase.h:63
 MethodANNBase.h:64
 MethodANNBase.h:65
 MethodANNBase.h:66
 MethodANNBase.h:67
 MethodANNBase.h:68
 MethodANNBase.h:69
 MethodANNBase.h:70
 MethodANNBase.h:71
 MethodANNBase.h:72
 MethodANNBase.h:73
 MethodANNBase.h:74
 MethodANNBase.h:75
 MethodANNBase.h:76
 MethodANNBase.h:77
 MethodANNBase.h:78
 MethodANNBase.h:79
 MethodANNBase.h:80
 MethodANNBase.h:81
 MethodANNBase.h:82
 MethodANNBase.h:83
 MethodANNBase.h:84
 MethodANNBase.h:85
 MethodANNBase.h:86
 MethodANNBase.h:87
 MethodANNBase.h:88
 MethodANNBase.h:89
 MethodANNBase.h:90
 MethodANNBase.h:91
 MethodANNBase.h:92
 MethodANNBase.h:93
 MethodANNBase.h:94
 MethodANNBase.h:95
 MethodANNBase.h:96
 MethodANNBase.h:97
 MethodANNBase.h:98
 MethodANNBase.h:99
 MethodANNBase.h:100
 MethodANNBase.h:101
 MethodANNBase.h:102
 MethodANNBase.h:103
 MethodANNBase.h:104
 MethodANNBase.h:105
 MethodANNBase.h:106
 MethodANNBase.h:107
 MethodANNBase.h:108
 MethodANNBase.h:109
 MethodANNBase.h:110
 MethodANNBase.h:111
 MethodANNBase.h:112
 MethodANNBase.h:113
 MethodANNBase.h:114
 MethodANNBase.h:115
 MethodANNBase.h:116
 MethodANNBase.h:117
 MethodANNBase.h:118
 MethodANNBase.h:119
 MethodANNBase.h:120
 MethodANNBase.h:121
 MethodANNBase.h:122
 MethodANNBase.h:123
 MethodANNBase.h:124
 MethodANNBase.h:125
 MethodANNBase.h:126
 MethodANNBase.h:127
 MethodANNBase.h:128
 MethodANNBase.h:129
 MethodANNBase.h:130
 MethodANNBase.h:131
 MethodANNBase.h:132
 MethodANNBase.h:133
 MethodANNBase.h:134
 MethodANNBase.h:135
 MethodANNBase.h:136
 MethodANNBase.h:137
 MethodANNBase.h:138
 MethodANNBase.h:139
 MethodANNBase.h:140
 MethodANNBase.h:141
 MethodANNBase.h:142
 MethodANNBase.h:143
 MethodANNBase.h:144
 MethodANNBase.h:145
 MethodANNBase.h:146
 MethodANNBase.h:147
 MethodANNBase.h:148
 MethodANNBase.h:149
 MethodANNBase.h:150
 MethodANNBase.h:151
 MethodANNBase.h:152
 MethodANNBase.h:153
 MethodANNBase.h:154
 MethodANNBase.h:155
 MethodANNBase.h:156
 MethodANNBase.h:157
 MethodANNBase.h:158
 MethodANNBase.h:159
 MethodANNBase.h:160
 MethodANNBase.h:161
 MethodANNBase.h:162
 MethodANNBase.h:163
 MethodANNBase.h:164
 MethodANNBase.h:165
 MethodANNBase.h:166
 MethodANNBase.h:167
 MethodANNBase.h:168
 MethodANNBase.h:169
 MethodANNBase.h:170
 MethodANNBase.h:171
 MethodANNBase.h:172
 MethodANNBase.h:173
 MethodANNBase.h:174
 MethodANNBase.h:175
 MethodANNBase.h:176
 MethodANNBase.h:177
 MethodANNBase.h:178
 MethodANNBase.h:179
 MethodANNBase.h:180
 MethodANNBase.h:181
 MethodANNBase.h:182
 MethodANNBase.h:183
 MethodANNBase.h:184
 MethodANNBase.h:185
 MethodANNBase.h:186
 MethodANNBase.h:187
 MethodANNBase.h:188
 MethodANNBase.h:189
 MethodANNBase.h:190
 MethodANNBase.h:191
 MethodANNBase.h:192
 MethodANNBase.h:193
 MethodANNBase.h:194
 MethodANNBase.h:195
 MethodANNBase.h:196
 MethodANNBase.h:197
 MethodANNBase.h:198
 MethodANNBase.h:199
 MethodANNBase.h:200
 MethodANNBase.h:201
 MethodANNBase.h:202
 MethodANNBase.h:203
 MethodANNBase.h:204
 MethodANNBase.h:205
 MethodANNBase.h:206
 MethodANNBase.h:207
 MethodANNBase.h:208
 MethodANNBase.h:209
 MethodANNBase.h:210
 MethodANNBase.h:211
 MethodANNBase.h:212
 MethodANNBase.h:213
 MethodANNBase.h:214
 MethodANNBase.h:215
 MethodANNBase.h:216
 MethodANNBase.h:217
 MethodANNBase.h:218
 MethodANNBase.h:219
 MethodANNBase.h:220
 MethodANNBase.h:221
 MethodANNBase.h:222
 MethodANNBase.h:223
 MethodANNBase.h:224
 MethodANNBase.h:225
 MethodANNBase.h:226
 MethodANNBase.h:227
 MethodANNBase.h:228
 MethodANNBase.h:229
 MethodANNBase.h:230
 MethodANNBase.h:231
 MethodANNBase.h:232
 MethodANNBase.h:233
 MethodANNBase.h:234
 MethodANNBase.h:235
 MethodANNBase.h:236
 MethodANNBase.h:237
 MethodANNBase.h:238
 MethodANNBase.h:239
 MethodANNBase.h:240
 MethodANNBase.h:241
 MethodANNBase.h:242
 MethodANNBase.h:243
 MethodANNBase.h:244
 MethodANNBase.h:245
 MethodANNBase.h:246
 MethodANNBase.h:247
 MethodANNBase.h:248
 MethodANNBase.h:249
 MethodANNBase.h:250
 MethodANNBase.h:251
 MethodANNBase.h:252
 MethodANNBase.h:253
 MethodANNBase.h:254
 MethodANNBase.h:255
 MethodANNBase.h:256
 MethodANNBase.h:257
 MethodANNBase.h:258
 MethodANNBase.h:259
 MethodANNBase.h:260
 MethodANNBase.h:261
 MethodANNBase.h:262
 MethodANNBase.h:263
 MethodANNBase.h:264
 MethodANNBase.h:265
 MethodANNBase.h:266
 MethodANNBase.h:267
 MethodANNBase.h:268
 MethodANNBase.h:269
 MethodANNBase.h:270
 MethodANNBase.h:271
 MethodANNBase.h:272
 MethodANNBase.h:273