ROOT  6.06/09
Reference Guide
MethodPyAdaBoost.h
Go to the documentation of this file.
1 // @(#)root/tmva/pymva $Id$
2 // Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodPyAdaBoost *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * scikit-learn Package AdaBoostClassifier method based on python *
12  * *
13  **********************************************************************************/
14 
15 #ifndef ROOT_TMVA_MethodPyAdaBoost
16 #define ROOT_TMVA_MethodPyAdaBoost
17 
18 //////////////////////////////////////////////////////////////////////////
19 // //
20 // MethodPyAdaBoost //
21 // //
22 // //
23 //////////////////////////////////////////////////////////////////////////
24 
25 #ifndef ROOT_TMVA_PyMethodBase
26 #include "TMVA/PyMethodBase.h"
27 #endif
28 
29 namespace TMVA {
30 
31  class Factory; // DSMTEST
32  class Reader; // DSMTEST
33  class DataSetManager; // DSMTEST
34  class Types;
35  class MethodPyAdaBoost : public PyMethodBase {
36 
37  public :
38 
39  // constructors
40  MethodPyAdaBoost(const TString &jobName,
41  const TString &methodTitle,
42  DataSetInfo &theData,
43  const TString &theOption = "",
44  TDirectory *theTargetDir = NULL);
45 
47  const TString &theWeightFile,
48  TDirectory *theTargetDir = NULL);
49 
50 
51  ~MethodPyAdaBoost(void);
52  void Train();
53  // options treatment
54  void Init();
55  void DeclareOptions();
56  void ProcessOptions();
57  // create ranking
59  {
60  return NULL; // = 0;
61  }
62 
63 
64  Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
65 
66  // performs classifier testing
67  virtual void TestClassification();
68 
69 
70  Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
71 
73  // the actual "weights"
74  virtual void AddWeightsXMLTo(void * /*parent */ ) const {} // = 0;
75  virtual void ReadWeightsFromXML(void * /*wghtnode*/ ) {} // = 0;
76  virtual void ReadWeightsFromStream(std::istream &) {} //= 0; // backward compatibility
77  void ReadStateFromFile();
78  private :
80  friend class Factory; // DSMTEST
81  friend class Reader; // DSMTEST
82  protected:
83  //AdaBoost options
84  TString base_estimator;//object, optional (default=DecisionTreeClassifier)
85  //The base estimator from which the boosted ensemble is built.
86  //Support for sample weighting is required, as well as proper `classes_`
87  //and `n_classes_` attributes.
88  Int_t n_estimators;//integer, optional (default=10)
89  //The number of trees in the forest.
90  Double_t learning_rate;//loat, optional (default=1.)
91  //Learning rate shrinks the contribution of each classifier by
92  //``learning_rate``. There is a trade-off between ``learning_rate`` and ``n_estimators``.
93  TString algorithm;//{'SAMME', 'SAMME.R'}, optional (default='SAMME.R')
94  //If 'SAMME.R' then use the SAMME.R real boosting algorithm.
95  //``base_estimator`` must support calculation of class probabilities.
96  //If 'SAMME' then use the SAMME discrete boosting algorithm.
97  //The SAMME.R algorithm typically converges faster than SAMME,
98  //achieving a lower test error with fewer boosting iterations.
99  TString random_state;//int, RandomState instance or None, optional (default=None)
100  //If int, random_state is the seed used by the random number generator;
101  //If RandomState instance, random_state is the random number generator;
102  //If None, the random number generator is the RandomState instance used by `np.random`.
103  // get help message text
104  void GetHelpMessage() const;
105 
106 
108  };
109 } // namespace TMVA
110 #endif
EAnalysisType
Definition: Types.h:124
Basic string class.
Definition: TString.h:137
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
virtual void TestClassification()
initialization
#define ClassDef(name, id)
Definition: Rtypes.h:254
DataSetManager * fDataSetManager
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
virtual void ReadWeightsFromXML(void *)
virtual void ReadWeightsFromStream(std::istream &)
unsigned int UInt_t
Definition: RtypesCore.h:42
virtual void AddWeightsXMLTo(void *) const
double Double_t
Definition: RtypesCore.h:55
Describe directory structure in memory.
Definition: TDirectory.h:41
int type
Definition: TGX11.cxx:120
const Ranking * CreateRanking()
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Abstract ClassifierFactory template that handles arbitrary types.
MethodPyAdaBoost(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="", TDirectory *theTargetDir=NULL)
#define NULL
Definition: Rtypes.h:82
virtual void ReadWeightsFromStream(std::istream &)=0