ROOT Reference Guide
 
TMVA::CostComplexityPruneTool Class Reference

A class to prune a decision tree using the Cost Complexity method.

(see "Classification and Regression Trees" by Leo Breiman et al.)

Some definitions:

  • \( T_{max} \) - the initial, usually highly overtrained, tree that is to be pruned back
  • \( R(T) \) - quality index (Gini, misclassification rate, or other) of a tree \( T \)
  • \( \sim T \) - set of terminal nodes in \( T \)
  • \( T' \) - the pruned subtree of \( T_{max} \) that has the best quality index \( R(T') \)
  • \( \alpha \) - the prune strength parameter in Cost Complexity pruning, \( R_{\alpha}(T) = R(T) + \alpha \cdot |\sim T| \)
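As a minimal sketch of the cost-complexity measure defined above (plain C++ with no ROOT dependency; the function name `CostComplexity` and passing the terminal-node quality indices as a flat list are assumptions made for illustration), \( R_{\alpha}(T) \) can be evaluated from the \( R(t) \) of the terminal nodes, using \( R(T) = \sum_{t \in \sim T} R(t) \) as stated for fQualityIndexTool:

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of TMVA): evaluates the cost-complexity measure
// R_alpha(T) = R(T) + alpha * |~T| from the quality indices R(t) of the
// terminal nodes of T, with R(T) taken as the sum of those R(t).
double CostComplexity(const std::vector<double>& terminalR, double alpha) {
    assert(alpha >= 0.0 && "the prune strength is expected to be non-negative");
    double treeR = 0.0;
    for (double r : terminalR) treeR += r;                         // R(T)
    return treeR + alpha * static_cast<double>(terminalR.size());  // + alpha * |~T|
}
```

For two terminal nodes with \( R(t) = 0.25 \) and \( 0.5 \) and \( \alpha = 0.125 \), this gives \( 0.75 + 2 \cdot 0.125 = 1.0 \): a larger \( \alpha \) penalises trees with many terminal nodes more strongly.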

There are two running modes in CCPruner: (i) one may select a prune strength and prune back the tree \( T_{max}\) until the criterion

\[ \alpha < \frac{R(t) - R(T_t)}{|\sim T_t| - 1} \]

is true for all nonterminal nodes t in \( T \), where \( T_t \) denotes the branch of \( T \) rooted at node t; or (ii) the algorithm finds the sequence of critical points \( \alpha_k < \alpha_{k+1} < \dots < \alpha_K \) such that \( T_K = root(T_{max}) \) and then selects the optimally pruned subtree, defined as the subtree with the best quality index on the validation sample.
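Mode (i) can be sketched as a single bottom-up pass over a toy binary tree. This is an illustration under assumptions, not TMVA's code: `Node`, `BranchInfo` and `PruneFixed` are hypothetical names, and `nodeR` stands for \( R(t) \), the quality index node t would have as a leaf. Comparing \( R_{\alpha} \) of a node with that of its (already pruned) branch, an interior node t is kept only while \( R(t) + \alpha > R(T_t) + \alpha |\sim T_t| \); otherwise it is collapsed into a leaf.

```cpp
#include <cassert>
#include <memory>

// Hypothetical toy node (not TMVA::DecisionTreeNode).
struct Node {
    double nodeR = 0.0;                 // R(t): quality index of t treated as a leaf
    std::unique_ptr<Node> left, right;  // both empty for a terminal node
    bool IsTerminal() const { return !left && !right; }
};

// R(T_t) and |~T_t| of a (possibly already pruned) branch.
struct BranchInfo {
    double subTreeR;  // R(T_t): sum of R over the terminal nodes of the branch
    int nTerminal;    // |~T_t|
};

// Fixed-strength pruning of the branch rooted at n, children first.
// An interior node t survives only if R(t) + alpha > R(T_t) + alpha * |~T_t|,
// i.e. alpha < (R(t) - R(T_t)) / (|~T_t| - 1); otherwise it becomes a leaf.
BranchInfo PruneFixed(Node& n, double alpha) {
    if (n.IsTerminal()) return {n.nodeR, 1};
    assert(n.left && n.right && "the toy tree is assumed to be strictly binary");
    const BranchInfo l = PruneFixed(*n.left, alpha);
    const BranchInfo r = PruneFixed(*n.right, alpha);
    const BranchInfo branch{l.subTreeR + r.subTreeR, l.nTerminal + r.nTerminal};
    const double g = (n.nodeR - branch.subTreeR) / (branch.nTerminal - 1);  // |~T_t| >= 2 here
    if (alpha < g) return branch;  // criterion holds: keep the branch
    n.left.reset();                // criterion violated: prune t back to a leaf
    n.right.reset();
    return {n.nodeR, 1};
}
```

For a root with \( R(t) = 0.5 \) and two leaves with \( R = 0.125 \) each, \( (0.5 - 0.25)/1 = 0.25 \): a prune strength of 0.125 keeps both leaves, while 0.5 collapses the tree to its root.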

Definition at line 62 of file CostComplexityPruneTool.h.

Public Member Functions

 CostComplexityPruneTool (SeparationBase *qualityIndex=nullptr)
 the constructor for the cost complexity pruning
 
virtual ~CostComplexityPruneTool ()
 the destructor for the cost complexity pruning
 
PruningInfo * CalculatePruningInfo (DecisionTree *dt, const IPruneTool::EventSample *testEvents=nullptr, Bool_t isAutomatic=kFALSE) override
 the routine that steers the pruning process.
 
- Public Member Functions inherited from TMVA::IPruneTool
 IPruneTool ()
 
virtual ~IPruneTool ()
 
Double_t GetPruneStrength () const
 
Bool_t IsAutomatic () const
 
void SetAutomatic ()
 
void SetPruneStrength (Double_t alpha)
 

Private Member Functions

void InitTreePruningMetaData (DecisionTreeNode *n)
 initialise the pruning "meta data" for each node, such as the cost complexity, the critical alpha and the minimal alpha down the tree
 
MsgLogger & Log () const
 
void Optimize (DecisionTree *dt, Double_t weights)
 after the critical \( \alpha \) values (at which the corresponding nodes would be pruned away) have been established in InitTreePruningMetaData, carry out the pruning, either automatically or with a fixed prune strength
 

Private Attributes

MsgLogger * fLogger
 output stream to save logging information
 
Int_t fOptimalK
 the optimal index of the prune sequence
 
std::vector< DecisionTreeNode * > fPruneSequence
 the weakest links (i.e., the branches to prune), indexed by pruning step
 
std::vector< Double_t > fPruneStrengthList
 the prune strength alpha, indexed by pruning step
 
std::vector< Double_t > fQualityIndexList
 the tree quality index R(T), indexed by pruning step
 
SeparationBase * fQualityIndexTool
 the quality index used to calculate \( R(t) \), with \( R(T) = \sum_{t \in \sim T} R(t) \)
 

Additional Inherited Members

- Public Types inherited from TMVA::IPruneTool
typedef std::vector< const Event * > EventSample
 
- Protected Attributes inherited from TMVA::IPruneTool
Double_t B
 
Double_t fPruneStrength
 regularization parameter in pruning
 
Double_t S
 

#include <TMVA/CostComplexityPruneTool.h>

Inheritance diagram for TMVA::CostComplexityPruneTool (base class: TMVA::IPruneTool)

Constructor & Destructor Documentation

◆ CostComplexityPruneTool()

CostComplexityPruneTool::CostComplexityPruneTool ( SeparationBase * qualityIndex = nullptr)

the constructor for the cost complexity pruning

Definition at line 68 of file CostComplexityPruneTool.cxx.

◆ ~CostComplexityPruneTool()

CostComplexityPruneTool::~CostComplexityPruneTool ( )
virtual

the destructor for the cost complexity pruning

Definition at line 89 of file CostComplexityPruneTool.cxx.

Member Function Documentation

◆ CalculatePruningInfo()

PruningInfo * CostComplexityPruneTool::CalculatePruningInfo ( DecisionTree * dt,
const IPruneTool::EventSample * validationSample = nullptr,
Bool_t isAutomatic = kFALSE )
override virtual

the routine that steers the pruning process.

It triggers the calculation of the pruning sequence, the tree quality and related quantities.

Implements TMVA::IPruneTool.

Definition at line 98 of file CostComplexityPruneTool.cxx.
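In the fixed-strength mode (i) described at the top of this page, the prune strength is chosen by the user rather than optimised. Assuming a prune sequence with non-decreasing critical strengths \( \alpha_0 = 0 \le \alpha_1 \le \dots \le \alpha_K \) has already been recorded, a fixed strength selects one subtree \( T_k \) of that sequence. The helper below is a hypothetical sketch (`StepForStrength` is not a TMVA or IPruneTool function, and this is not necessarily how TMVA resolves it); it applies a pruning step once the strength reaches that step's critical \( \alpha \):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: given the critical prune strengths recorded along a
// prune sequence (pruneStrengths[0] = 0 for the unpruned tree, non-decreasing),
// return the index k of the subtree T_k selected by a fixed prune strength alpha,
// i.e. the last step whose critical alpha does not exceed the chosen strength.
std::size_t StepForStrength(const std::vector<double>& pruneStrengths, double alpha) {
    assert(!pruneStrengths.empty());
    std::size_t k = 0;
    while (k + 1 < pruneStrengths.size() && pruneStrengths[k + 1] <= alpha) ++k;
    return k;
}
```

With recorded strengths {0, 0.125, 0.375}, a strength of 0.2 selects \( T_1 \), and any strength of at least 0.375 selects the root alone.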

◆ InitTreePruningMetaData()

void CostComplexityPruneTool::InitTreePruningMetaData ( DecisionTreeNode * n)
private

initialise the pruning "meta data" for each node, such as the cost complexity, the critical alpha and the minimal alpha down the tree.

Definition at line 181 of file CostComplexityPruneTool.cxx.
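A minimal sketch of this per-node bookkeeping on a hypothetical toy node type (`Node`, `InitMetaData` and the field names are assumptions, not the TMVA::DecisionTreeNode API). It computes, bottom-up, \( |\sim T_t| \), \( R(T_t) \), the critical \( \alpha \) of each interior node, \( (R(t) - R(T_t))/(|\sim T_t| - 1) \), and the minimal critical \( \alpha \) in the branch rooted at each node. Giving terminal nodes an infinite critical \( \alpha \) is a choice made here so that leaves never count as weakest links.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <memory>

// Hypothetical toy node carrying per-node pruning "meta data" (not TMVA::DecisionTreeNode).
struct Node {
    double nodeR = 0.0;                 // R(t) of t treated as a leaf
    std::unique_ptr<Node> left, right;  // both empty for a terminal node
    int nTerminal = 1;                  // |~T_t|
    double subTreeR = 0.0;              // R(T_t)
    double alpha = 0.0;                 // critical alpha of t
    double minAlpha = 0.0;              // smallest critical alpha in the branch rooted at t
};

// Fill the meta data bottom-up for the branch rooted at n.
void InitMetaData(Node& n) {
    if (!n.left && !n.right) {  // terminal node: cannot be pruned itself
        n.nTerminal = 1;
        n.subTreeR = n.nodeR;
        n.alpha = std::numeric_limits<double>::infinity();
        n.minAlpha = n.alpha;
        return;
    }
    assert(n.left && n.right && "the toy tree is assumed to be strictly binary");
    InitMetaData(*n.left);
    InitMetaData(*n.right);
    n.nTerminal = n.left->nTerminal + n.right->nTerminal;   // |~T_t|
    n.subTreeR = n.left->subTreeR + n.right->subTreeR;      // R(T_t) = sum of leaf R
    n.alpha = (n.nodeR - n.subTreeR) / (n.nTerminal - 1);   // critical alpha
    n.minAlpha = std::min({n.alpha, n.left->minAlpha, n.right->minAlpha});
}
```

After the call, the root's `minAlpha` is the prune strength at which the first branch of the tree would be cut.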

◆ Log()

MsgLogger & TMVA::CostComplexityPruneTool::Log ( ) const
inline private

Definition at line 87 of file CostComplexityPruneTool.h.

◆ Optimize()

void CostComplexityPruneTool::Optimize ( DecisionTree * dt,
Double_t weights )
private

after the critical \( \alpha \) values (at which the corresponding nodes would be pruned away) have been established in InitTreePruningMetaData, one of two modes is applied:

  • automatic pruning: find the value of \( \alpha \) for which the test sample gives the minimal error on the tree with all nodes pruned that have \( \alpha_{critical} < \alpha \);
  • fixed parameter pruning: prune back all nodes whose critical \( \alpha \) lies below the chosen prune strength.

Definition at line 236 of file CostComplexityPruneTool.cxx.
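The automatic mode can be sketched as the classic weakest-link procedure from Breiman et al.: repeatedly prune the interior node with the smallest critical \( \alpha \), recording \( \alpha_k \) and \( R(T_k) \) at each step, then choose the step with the best quality index on a validation sample. Everything below is a hypothetical illustration (`Node`, `PruneSequence` and `OptimalK` are not TMVA names, and the role of the `weights` parameter is not modelled). Evaluating the validation sample is left out, so `OptimalK` takes precomputed per-step validation values, with lower taken as better, as an assumed input. For simplicity, branch sums are recomputed at every step (quadratic cost) and tied weakest links are pruned one at a time, so consecutive \( \alpha_k \) may be equal.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <memory>
#include <vector>

// Hypothetical toy node (not TMVA::DecisionTreeNode).
struct Node {
    double nodeR = 0.0;                 // R(t) of t treated as a leaf
    std::unique_ptr<Node> left, right;  // both empty for a terminal node
    bool IsTerminal() const { return !left && !right; }
};

struct Branch { double subTreeR; int nTerminal; };  // R(T_t) and |~T_t|

// Recompute R(T_t) and |~T_t| for the branch rooted at n.
Branch Measure(const Node& n) {
    if (n.IsTerminal()) return {n.nodeR, 1};
    const Branch l = Measure(*n.left);
    const Branch r = Measure(*n.right);
    return {l.subTreeR + r.subTreeR, l.nTerminal + r.nTerminal};
}

// Locate the interior node with the smallest critical alpha (the "weakest link").
void FindWeakest(Node& n, Node*& best, double& bestAlpha) {
    if (n.IsTerminal()) return;
    const Branch b = Measure(n);
    const double g = (n.nodeR - b.subTreeR) / (b.nTerminal - 1);
    if (g < bestAlpha) { bestAlpha = g; best = &n; }
    FindWeakest(*n.left, best, bestAlpha);
    FindWeakest(*n.right, best, bestAlpha);
}

struct PruneStep { double alpha; double treeR; int nTerminal; };  // alpha_k, R(T_k), |~T_k|

// Build T_0 (unpruned), T_1, ..., T_K = root by cutting one weakest link per step.
std::vector<PruneStep> PruneSequence(Node& root) {
    Branch b = Measure(root);
    std::vector<PruneStep> seq;
    seq.push_back({0.0, b.subTreeR, b.nTerminal});
    while (!root.IsTerminal()) {
        Node* weakest = nullptr;
        double alpha = std::numeric_limits<double>::infinity();
        FindWeakest(root, weakest, alpha);
        assert(weakest != nullptr);
        weakest->left.reset();  // collapse the weakest link into a leaf
        weakest->right.reset();
        b = Measure(root);
        seq.push_back({alpha, b.subTreeR, b.nTerminal});
    }
    return seq;
}

// Pick the step k whose (precomputed) validation quality index is smallest.
std::size_t OptimalK(const std::vector<double>& validationR) {
    assert(!validationR.empty());
    std::size_t best = 0;
    for (std::size_t k = 1; k < validationR.size(); ++k)
        if (validationR[k] < validationR[best]) best = k;
    return best;
}
```

On a three-leaf tree whose inner branch has critical \( \alpha = 0.125 \) and whose root has \( 0.25 \), the sequence first cuts the inner branch at \( \alpha_1 = 0.125 \) and then, with the branch sums recomputed, the root at \( \alpha_2 = 0.375 \).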

Member Data Documentation

◆ fLogger

MsgLogger* TMVA::CostComplexityPruneTool::fLogger
mutable private

output stream to save logging information

Definition at line 86 of file CostComplexityPruneTool.h.

◆ fOptimalK

Int_t TMVA::CostComplexityPruneTool::fOptimalK
private

the optimal index of the prune sequence

Definition at line 77 of file CostComplexityPruneTool.h.

◆ fPruneSequence

std::vector<DecisionTreeNode*> TMVA::CostComplexityPruneTool::fPruneSequence
private

the weakest links (i.e., the branches to prune), indexed by pruning step

Definition at line 73 of file CostComplexityPruneTool.h.

◆ fPruneStrengthList

std::vector<Double_t> TMVA::CostComplexityPruneTool::fPruneStrengthList
private

the prune strength alpha, indexed by pruning step

Definition at line 74 of file CostComplexityPruneTool.h.

◆ fQualityIndexList

std::vector<Double_t> TMVA::CostComplexityPruneTool::fQualityIndexList
private

the tree quality index R(T), indexed by pruning step

Definition at line 75 of file CostComplexityPruneTool.h.

◆ fQualityIndexTool

SeparationBase* TMVA::CostComplexityPruneTool::fQualityIndexTool
private

the quality index used to calculate \( R(t) \), with \( R(T) = \sum_{t \in \sim T} R(t) \)

Definition at line 71 of file CostComplexityPruneTool.h.

The documentation for this class was generated from the following files:

  • CostComplexityPruneTool.h
  • CostComplexityPruneTool.cxx