ROOT logo

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : CCTreeWrapper                                                         *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description: a light wrapper of a decision tree, used to perform cost          *
 *              complexity pruning "in-place"                                     *
 *                                                                                *  
 * Author: Doug Schouten (dschoute@sfu.ca)                                        *
 *                                                                                *
 *                                                                                *
 * Copyright (c) 2007:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Texas at Austin, USA                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_CCTreeWrapper
#define ROOT_TMVA_CCTreeWrapper

#include <iosfwd>
#include <vector>

#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif
#ifndef ROOT_TMVA_SeparationBase
#include "TMVA/SeparationBase.h"
#endif
#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_DataSet
#include "TMVA/DataSet.h"
#endif


namespace TMVA {

   class CCTreeWrapper {

   public:

      typedef std::vector<Event*> EventList;

      /////////////////////////////////////////////////////////////
      // CCTreeNode - a light wrapper of a decision tree node    //
      //                                                         //
      /////////////////////////////////////////////////////////////

      class CCTreeNode : virtual public Node {

      public:

         CCTreeNode( DecisionTreeNode* n = NULL );
         virtual ~CCTreeNode( );
      
         virtual Node* CreateNode() const { return new CCTreeNode(); }

         // set |~T_t|, the number of terminal descendants of node t 
         inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); }

         // return |~T_t|
         inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; }

         // set R(t), the node resubstitution estimate (Gini, misclassification, etc.) for the node t
         inline void SetNodeResubstitutionEstimate( Double_t R ) { fNodeResubstitutionEstimate = (R >= 0 ? R : 0.0); }
      
         // return R(t) for node t
         inline Double_t GetNodeResubstitutionEstimate( ) const { return fNodeResubstitutionEstimate; }

         // set R(T_t) = sum[t' in ~T_t]{ R(t) }, the resubstitution estimate for the branch rooted at
         // node t (it is an estimate because it is calculated from the training dataset, i.e., the original tree)
         inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ?  R : 0.0); }
      
         // return R(T_t) for node t
         inline Double_t GetResubstitutionEstimate( ) const { return fResubstitutionEstimate; }
      
         // set the critical point of alpha
         //             R(t) - R(T_t)
         //  alpha_c <  ------------- := g(t)
         //              |~T_t| - 1
         // which is the value of alpha such that the branch rooted at node t is pruned
         inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; }

         // get the critical alpha value for this node
         inline Double_t GetAlphaC( ) const { return fAlphaC; }

         // set the minimum critical alpha value for descendants of node t ( G(t) = min(alpha_c, g(t_l), g(t_r)) )
         inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; }

         // get the minimum critical alpha value 
         inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; }

         // get the pointer to the wrapped DT node
         inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; }

         // get pointers to children, mother in the CC tree
         inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); }
         inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); }
         inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); }

         // printout of the node (can be read in with ReadDataRecord)
         virtual void Print( ostream& os ) const;

         // recursive printout of the node and its daughters 
         virtual void PrintRec ( ostream& os ) const;

         virtual void AddAttributesToNode(void* node) const;
         virtual void AddContentToNode(std::stringstream& s) const;
         

         // test event if it decends the tree at this node to the right  
         inline virtual Bool_t GoesRight( const Event& e ) const { return (GetDTNode() != NULL ? 
                                                                           GetDTNode()->GoesRight(e) : false); }
      
         // test event if it decends the tree at this node to the left 
         inline virtual Bool_t GoesLeft ( const Event& e ) const { return (GetDTNode() != NULL ? 
                                                                           GetDTNode()->GoesLeft(e) : false); }
      
      private:

         // initialize a node from a data record
         virtual void ReadAttributes(void* node);
         virtual Bool_t ReadDataRecord( std::istream& in );
         virtual void ReadContent(std::stringstream& s);
         
         Int_t fNLeafDaughters; //! number of terminal descendants
         Double_t fNodeResubstitutionEstimate; //! R(t) = misclassification rate for node t
         Double_t fResubstitutionEstimate; //! R(T_t) = sum[t' in ~T_t]{ R(t) }
         Double_t fAlphaC; //! critical point, g(t) = alpha_c(t)
         Double_t fMinAlphaC; //! G(t), minimum critical point of t and its descendants
         DecisionTreeNode* fDTNode; //! pointer to wrapped node in the decision tree
      };

      CCTreeWrapper( DecisionTree* T,  SeparationBase* qualityIndex );
      ~CCTreeWrapper( );

      // return the decision tree output for an event 
      Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false );
      // return the misclassification rate of a pruned tree for a validation event sample
      Double_t TestTreeQuality( const EventList* validationSample );
      Double_t TestTreeQuality( const DataSet* validationSample );

      // remove the branch rooted at node t
      void PruneNode( CCTreeNode* t );
      // initialize the node t and all its descendants
      void InitTree( CCTreeNode* t );

      // return the root node for this tree
      CCTreeNode* GetRoot() { return fRoot; }
   private:
      SeparationBase* fQualityIndex;  //! pointer to the used quality index calculator
      DecisionTree* fDTParent;        //! pointer to underlying DecisionTree
      CCTreeNode* fRoot;              //! the root node of the (wrapped) decision Tree
   };

}

#endif



 CCTreeWrapper.h:1
 CCTreeWrapper.h:2
 CCTreeWrapper.h:3
 CCTreeWrapper.h:4
 CCTreeWrapper.h:5
 CCTreeWrapper.h:6
 CCTreeWrapper.h:7
 CCTreeWrapper.h:8
 CCTreeWrapper.h:9
 CCTreeWrapper.h:10
 CCTreeWrapper.h:11
 CCTreeWrapper.h:12
 CCTreeWrapper.h:13
 CCTreeWrapper.h:14
 CCTreeWrapper.h:15
 CCTreeWrapper.h:16
 CCTreeWrapper.h:17
 CCTreeWrapper.h:18
 CCTreeWrapper.h:19
 CCTreeWrapper.h:20
 CCTreeWrapper.h:21
 CCTreeWrapper.h:22
 CCTreeWrapper.h:23
 CCTreeWrapper.h:24
 CCTreeWrapper.h:25
 CCTreeWrapper.h:26
 CCTreeWrapper.h:27
 CCTreeWrapper.h:28
 CCTreeWrapper.h:29
 CCTreeWrapper.h:30
 CCTreeWrapper.h:31
 CCTreeWrapper.h:32
 CCTreeWrapper.h:33
 CCTreeWrapper.h:34
 CCTreeWrapper.h:35
 CCTreeWrapper.h:36
 CCTreeWrapper.h:37
 CCTreeWrapper.h:38
 CCTreeWrapper.h:39
 CCTreeWrapper.h:40
 CCTreeWrapper.h:41
 CCTreeWrapper.h:42
 CCTreeWrapper.h:43
 CCTreeWrapper.h:44
 CCTreeWrapper.h:45
 CCTreeWrapper.h:46
 CCTreeWrapper.h:47
 CCTreeWrapper.h:48
 CCTreeWrapper.h:49
 CCTreeWrapper.h:50
 CCTreeWrapper.h:51
 CCTreeWrapper.h:52
 CCTreeWrapper.h:53
 CCTreeWrapper.h:54
 CCTreeWrapper.h:55
 CCTreeWrapper.h:56
 CCTreeWrapper.h:57
 CCTreeWrapper.h:58
 CCTreeWrapper.h:59
 CCTreeWrapper.h:60
 CCTreeWrapper.h:61
 CCTreeWrapper.h:62
 CCTreeWrapper.h:63
 CCTreeWrapper.h:64
 CCTreeWrapper.h:65
 CCTreeWrapper.h:66
 CCTreeWrapper.h:67
 CCTreeWrapper.h:68
 CCTreeWrapper.h:69
 CCTreeWrapper.h:70
 CCTreeWrapper.h:71
 CCTreeWrapper.h:72
 CCTreeWrapper.h:73
 CCTreeWrapper.h:74
 CCTreeWrapper.h:75
 CCTreeWrapper.h:76
 CCTreeWrapper.h:77
 CCTreeWrapper.h:78
 CCTreeWrapper.h:79
 CCTreeWrapper.h:80
 CCTreeWrapper.h:81
 CCTreeWrapper.h:82
 CCTreeWrapper.h:83
 CCTreeWrapper.h:84
 CCTreeWrapper.h:85
 CCTreeWrapper.h:86
 CCTreeWrapper.h:87
 CCTreeWrapper.h:88
 CCTreeWrapper.h:89
 CCTreeWrapper.h:90
 CCTreeWrapper.h:91
 CCTreeWrapper.h:92
 CCTreeWrapper.h:93
 CCTreeWrapper.h:94
 CCTreeWrapper.h:95
 CCTreeWrapper.h:96
 CCTreeWrapper.h:97
 CCTreeWrapper.h:98
 CCTreeWrapper.h:99
 CCTreeWrapper.h:100
 CCTreeWrapper.h:101
 CCTreeWrapper.h:102
 CCTreeWrapper.h:103
 CCTreeWrapper.h:104
 CCTreeWrapper.h:105
 CCTreeWrapper.h:106
 CCTreeWrapper.h:107
 CCTreeWrapper.h:108
 CCTreeWrapper.h:109
 CCTreeWrapper.h:110
 CCTreeWrapper.h:111
 CCTreeWrapper.h:112
 CCTreeWrapper.h:113
 CCTreeWrapper.h:114
 CCTreeWrapper.h:115
 CCTreeWrapper.h:116
 CCTreeWrapper.h:117
 CCTreeWrapper.h:118
 CCTreeWrapper.h:119
 CCTreeWrapper.h:120
 CCTreeWrapper.h:121
 CCTreeWrapper.h:122
 CCTreeWrapper.h:123
 CCTreeWrapper.h:124
 CCTreeWrapper.h:125
 CCTreeWrapper.h:126
 CCTreeWrapper.h:127
 CCTreeWrapper.h:128
 CCTreeWrapper.h:129
 CCTreeWrapper.h:130
 CCTreeWrapper.h:131
 CCTreeWrapper.h:132
 CCTreeWrapper.h:133
 CCTreeWrapper.h:134
 CCTreeWrapper.h:135
 CCTreeWrapper.h:136
 CCTreeWrapper.h:137
 CCTreeWrapper.h:138
 CCTreeWrapper.h:139
 CCTreeWrapper.h:140
 CCTreeWrapper.h:141
 CCTreeWrapper.h:142
 CCTreeWrapper.h:143
 CCTreeWrapper.h:144
 CCTreeWrapper.h:145
 CCTreeWrapper.h:146
 CCTreeWrapper.h:147
 CCTreeWrapper.h:148
 CCTreeWrapper.h:149
 CCTreeWrapper.h:150
 CCTreeWrapper.h:151
 CCTreeWrapper.h:152
 CCTreeWrapper.h:153
 CCTreeWrapper.h:154
 CCTreeWrapper.h:155
 CCTreeWrapper.h:156
 CCTreeWrapper.h:157
 CCTreeWrapper.h:158
 CCTreeWrapper.h:159
 CCTreeWrapper.h:160
 CCTreeWrapper.h:161
 CCTreeWrapper.h:162
 CCTreeWrapper.h:163
 CCTreeWrapper.h:164
 CCTreeWrapper.h:165
 CCTreeWrapper.h:166