Logo ROOT  
Reference Guide
MethodPyAdaBoost.h
Go to the documentation of this file.
1 // @(#)root/tmva/pymva $Id$
2 // Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodPyAdaBoost *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * scikit-learn package AdaBoostClassifier method based on python *
12  * *
13  **********************************************************************************/
14 
15 #ifndef ROOT_TMVA_MethodPyAdaBoost
16 #define ROOT_TMVA_MethodPyAdaBoost
17 
18 //////////////////////////////////////////////////////////////////////////
19 // //
20 // MethodPyAdaBoost //
21 // //
22 //////////////////////////////////////////////////////////////////////////
23 
24 #include "TMVA/PyMethodBase.h"
25 
26 #include "TString.h"
27 #include <vector>
28 
29 namespace TMVA {
30 
31  class Factory;
32  class Reader;
33  class DataSetManager;
34  class Types;
35  class MethodPyAdaBoost : public PyMethodBase {
36 
37  public :
38  MethodPyAdaBoost(const TString &jobName,
39  const TString &methodTitle,
40  DataSetInfo &theData,
41  const TString &theOption = "");
42 
44  const TString &theWeightFile);
45 
47 
48  void Train();
49 
50  void Init();
51  void DeclareOptions();
52  void ProcessOptions();
53 
54  // create ranking
55  const Ranking *CreateRanking();
56 
57  Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
58 
59  // performs classifier testing
60  virtual void TestClassification();
61 
62  Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
63  std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
64  std::vector<Float_t>& GetMulticlassValues();
65 
66  virtual void ReadModelFromFile();
67 
69  // the actual "weights"
70  virtual void AddWeightsXMLTo(void * /*parent */ ) const {} // = 0;
71  virtual void ReadWeightsFromXML(void * /*wghtnode*/ ) {} // = 0;
72  virtual void ReadWeightsFromStream(std::istream &) {} //= 0; backward compatibility
73 
74  private :
76  friend class Factory;
77  friend class Reader;
78 
79  protected:
80  std::vector<Double_t> mvaValues;
81  std::vector<Float_t> classValues;
82 
83  UInt_t fNvars; // number of variables
84  UInt_t fNoutputs; // number of outputs
85  TString fFilenameClassifier; // Path to serialized classifier (default in `weights` folder)
86 
87  //AdaBoost options
88 
90  TString fBaseEstimator; //object, optional (default=DecisionTreeClassifier)
91  //The base estimator from which the boosted ensemble is built.
92  //Support for sample weighting is required, as well as proper `classes_`
93  //and `n_classes_` attributes.
94 
96  Int_t fNestimators; //integer, optional (default=10)
97  //The number of trees in the forest.
98 
100  Double_t fLearningRate; //loat, optional (default=1.)
101  //Learning rate shrinks the contribution of each classifier by
102  //``learning_rate``. There is a trade-off between ``learning_rate`` and ``n_estimators``.
103 
105  TString fAlgorithm; //{'SAMME', 'SAMME.R'}, optional (default='SAMME.R')
106  //If 'SAMME.R' then use the SAMME.R real boosting algorithm.
107  //``base_estimator`` must support calculation of class probabilities.
108  //If 'SAMME' then use the SAMME discrete boosting algorithm.
109  //The SAMME.R algorithm typically converges faster than SAMME,
110  //achieving a lower test error with fewer boosting iterations.
111 
113  TString fRandomState; //int, RandomState instance or None, optional (default=None)
114  //If int, random_state is the seed used by the random number generator;
115  //If RandomState instance, random_state is the random number generator;
116  //If None, the random number generator is the RandomState instance used by `np.random`.
117 
118  // get help message text
119  void GetHelpMessage() const;
120 
122  };
123 
124 } // namespace TMVA
125 
126 #endif // ROOT_TMVA_MethodPyAdaBoost
TMVA::MethodPyAdaBoost::GetMulticlassValues
std::vector< Float_t > & GetMulticlassValues()
Definition: MethodPyAdaBoost.cxx:353
TMVA::MethodPyAdaBoost::fRandomState
TString fRandomState
Definition: MethodPyAdaBoost.h:113
TMVA::MethodPyAdaBoost::MethodPyAdaBoost
MethodPyAdaBoost(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Definition: MethodPyAdaBoost.cxx:62
TMVA::MethodPyAdaBoost::fNestimators
Int_t fNestimators
Definition: MethodPyAdaBoost.h:96
TMVA::PyMethodBase
Definition: PyMethodBase.h:56
PyObject
_object PyObject
Definition: PyMethodBase.h:42
TMVA::MethodBase::ReadWeightsFromStream
virtual void ReadWeightsFromStream(std::istream &)=0
TMVA::MethodPyAdaBoost::TestClassification
virtual void TestClassification()
performs classifier testing
Definition: MethodPyAdaBoost.cxx:257
TMVA::Ranking
Ranking for variables in method (implementation)
Definition: Ranking.h:48
TMVA::MethodPyAdaBoost::GetHelpMessage
void GetHelpMessage() const
Definition: MethodPyAdaBoost.cxx:428
TMVA::MethodPyAdaBoost::pAlgorithm
PyObject * pAlgorithm
Definition: MethodPyAdaBoost.h:104
Long64_t
long long Long64_t
Definition: RtypesCore.h:80
TMVA::MethodPyAdaBoost
Definition: MethodPyAdaBoost.h:35
TMVA::MethodPyAdaBoost::GetMvaValue
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Definition: MethodPyAdaBoost.cxx:321
TMVA::MethodPyAdaBoost::Train
void Train()
Definition: MethodPyAdaBoost.cxx:199
TMVA::MethodPyAdaBoost::ReadModelFromFile
virtual void ReadModelFromFile()
Definition: MethodPyAdaBoost.cxx:379
TString
Basic string class.
Definition: TString.h:136
TMVA::MethodPyAdaBoost::~MethodPyAdaBoost
~MethodPyAdaBoost()
Definition: MethodPyAdaBoost.cxx:88
TMVA::MethodPyAdaBoost::pLearningRate
PyObject * pLearningRate
Definition: MethodPyAdaBoost.h:99
TString.h
TMVA::MethodPyAdaBoost::ProcessOptions
void ProcessOptions()
Definition: MethodPyAdaBoost.cxx:138
bool
TMVA::MethodPyAdaBoost::pBaseEstimator
PyObject * pBaseEstimator
Definition: MethodPyAdaBoost.h:89
TMVA::DataSetInfo
Class that contains all the data information.
Definition: DataSetInfo.h:62
TMVA::MethodPyAdaBoost::fNoutputs
UInt_t fNoutputs
Definition: MethodPyAdaBoost.h:84
TMVA::Types::EAnalysisType
EAnalysisType
Definition: Types.h:128
TMVA::MethodPyAdaBoost::fNvars
UInt_t fNvars
Definition: MethodPyAdaBoost.h:83
TMVA::MethodPyAdaBoost::pRandomState
PyObject * pRandomState
Definition: MethodPyAdaBoost.h:112
TMVA::MethodPyAdaBoost::fBaseEstimator
TString fBaseEstimator
Definition: MethodPyAdaBoost.h:90
TMVA::MethodPyAdaBoost::GetMvaValues
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodPyAdaBoost.cxx:263
TMVA::MethodPyAdaBoost::classValues
std::vector< Float_t > classValues
Definition: MethodPyAdaBoost.h:81
TMVA::Factory
This is the main MVA steering class.
Definition: Factory.h:80
TMVA::MethodPyAdaBoost::DeclareOptions
void DeclareOptions()
Definition: MethodPyAdaBoost.cxx:101
unsigned int
TMVA::DataSetManager
Class that contains all the data information.
Definition: DataSetManager.h:51
TMVA::MethodPyAdaBoost::HasAnalysisType
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Definition: MethodPyAdaBoost.cxx:93
TMVA::MethodPyAdaBoost::fLearningRate
Double_t fLearningRate
Definition: MethodPyAdaBoost.h:100
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::MethodPyAdaBoost::ReadWeightsFromXML
virtual void ReadWeightsFromXML(void *)
Definition: MethodPyAdaBoost.h:71
TMVA::MethodPyAdaBoost::fFilenameClassifier
TString fFilenameClassifier
Definition: MethodPyAdaBoost.h:85
ClassDef
#define ClassDef(name, id)
Definition: Rtypes.h:325
TMVA::MethodPyAdaBoost::CreateRanking
const Ranking * CreateRanking()
Definition: MethodPyAdaBoost.cxx:406
TMVA::MethodPyAdaBoost::ReadWeightsFromStream
virtual void ReadWeightsFromStream(std::istream &)
Definition: MethodPyAdaBoost.h:72
type
int type
Definition: TGX11.cxx:121
PyMethodBase.h
TMVA::MethodPyAdaBoost::pNestimators
PyObject * pNestimators
Definition: MethodPyAdaBoost.h:95
TMVA::Reader
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:64
TMVA::MethodPyAdaBoost::fAlgorithm
TString fAlgorithm
Definition: MethodPyAdaBoost.h:105
TMVA::MethodPyAdaBoost::AddWeightsXMLTo
virtual void AddWeightsXMLTo(void *) const
Definition: MethodPyAdaBoost.h:70
TMVA::MethodPyAdaBoost::mvaValues
std::vector< Double_t > mvaValues
Definition: MethodPyAdaBoost.h:80
TMVA::MethodPyAdaBoost::Init
void Init()
Definition: MethodPyAdaBoost.cxx:182
TMVA
Namespace of TMVA, a ROOT-integrated toolkit for multivariate data analysis
Definition: GeneticMinimizer.h:22
int
TMVA::MethodPyAdaBoost::fDataSetManager
DataSetManager * fDataSetManager
Definition: MethodPyAdaBoost.h:75