Logo ROOT   6.16/01
Reference Guide
MethodPyAdaBoost.h
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodPyAdaBoost *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * scikit-learn package AdaBoostClassifier method based on python *
12 * *
13 **********************************************************************************/
14
15#ifndef ROOT_TMVA_MethodPyAdaBoost
16#define ROOT_TMVA_MethodPyAdaBoost
17
18//////////////////////////////////////////////////////////////////////////
19// //
20// MethodPyAdaBoost //
21// //
22//////////////////////////////////////////////////////////////////////////
23
24#include "TMVA/PyMethodBase.h"
25
26#include "TString.h"
27
28namespace TMVA {
29
30 class Factory;
31 class Reader;
32 class DataSetManager;
33 class Types;
35
// NOTE(review): this extract is missing original lines 34, 42, 45, 67, 74, 88,
// 94, 98, 103, 111 and 120 (the embedded line numbering skips them). Line 34
// presumably held the `class MethodPyAdaBoost ...` head that the closing `};`
// below belongs to, so this fragment does not compile as shown -- restore the
// missing lines from the ROOT sources before editing the declarations.
//
// MethodPyAdaBoost: TMVA method backed by scikit-learn's AdaBoostClassifier,
// driven through Python (per the file banner).
36 public :
// Construct from a job name, a method title, the dataset description and an
// option string (empty by default).
37 MethodPyAdaBoost(const TString &jobName,
38 const TString &methodTitle,
39 DataSetInfo &theData,
40 const TString &theOption = "");
41
// Tail of a second constructor that takes a weight-file path; its first line
// (original line 42) is missing from this extract.
43 const TString &theWeightFile);
44
46
47 void Train();
48
49 void Init();
50 void DeclareOptions();
51 void ProcessOptions();
52
53 // create the variable ranking for this method
54 const Ranking *CreateRanking();
55
// Report whether this method supports the given analysis type for the given
// number of classes and targets.
56 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
57
58 // performs classifier testing
59 virtual void TestClassification();
60
// Classifier response for a single event; errLower/errUpper are optional
// out-parameters (null by default).
61 Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
// Get all the MVA values for the events of the current data type, optionally
// restricted to [firstEvt, lastEvt) (lastEvt = -1 presumably means "to the end" -- confirm).
62 std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
// Returns a reference, so the storage must outlive the call
// (see classValues below).
63 std::vector<Float_t>& GetMulticlassValues();
64
// NOTE(review): presumably loads the serialized classifier referenced by
// fFilenameClassifier -- confirm in the implementation.
65 virtual void ReadModelFromFile();
66
68 // the actual "weights"
// These overrides have empty bodies: nothing is written to or read from the
// XML/stream here. The "= 0" notes record that the base class declares them
// pure virtual, which is why they must be overridden at all.
69 virtual void AddWeightsXMLTo(void * /*parent */ ) const {} // = 0;
70 virtual void ReadWeightsFromXML(void * /*wghtnode*/ ) {} // = 0;
71 virtual void ReadWeightsFromStream(std::istream &) {} //= 0; backward compatibility
72
73 private :
// Let the TMVA steering classes reach the non-public members.
75 friend class Factory;
76 friend class Reader;
77
78 protected:
// NOTE(review): presumably the buffer filled by GetMvaValues() -- confirm.
79 std::vector<Double_t> mvaValues;
// NOTE(review): presumably the storage behind the reference returned by
// GetMulticlassValues() -- confirm.
80 std::vector<Float_t> classValues;
81
82 UInt_t fNvars; // number of variables
83 UInt_t fNoutputs; // number of outputs
84 TString fFilenameClassifier; // Path to serialized classifier (default in `weights` folder)
85
86 //AdaBoost options (mirroring scikit-learn AdaBoostClassifier parameters)
87
89 TString fBaseEstimator; //object, optional (default=DecisionTreeClassifier)
90 //The base estimator from which the boosted ensemble is built.
//In scikit-learn, None selects DecisionTreeClassifier(max_depth=1).
91 //Support for sample weighting is required, as well as proper `classes_`
92 //and `n_classes_` attributes.
93
95 Int_t fNestimators; //integer, optional (scikit-learn default=50; TMVA option default set in the .cxx -- confirm)
96 //The maximum number of estimators at which boosting is terminated.
//In case of a perfect fit, the learning procedure is stopped early.
97
99 Double_t fLearningRate; //float, optional (default=1.)
100 //Learning rate shrinks the contribution of each classifier by
101 //``learning_rate``. There is a trade-off between ``learning_rate`` and ``n_estimators``.
102
104 TString fAlgorithm; //{'SAMME', 'SAMME.R'}, optional (default='SAMME.R')
105 //If 'SAMME.R' then use the SAMME.R real boosting algorithm.
106 //``base_estimator`` must support calculation of class probabilities.
107 //If 'SAMME' then use the SAMME discrete boosting algorithm.
108 //The SAMME.R algorithm typically converges faster than SAMME,
109 //achieving a lower test error with fewer boosting iterations.
110
112 TString fRandomState; //int, RandomState instance or None, optional (default=None)
113 //If int, random_state is the seed used by the random number generator;
114 //If RandomState instance, random_state is the random number generator;
115 //If None, the random number generator is the RandomState instance used by `np.random`.
116
117 // get help message text
118 void GetHelpMessage() const;
119
// NOTE(review): original line 120 is missing; given the ClassDef
// cross-reference on this page it presumably held the ROOT ClassDef macro -- confirm.
121 };
122
123} // namespace TMVA
124
125#endif // ROOT_TMVA_MethodPyAdaBoost
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
long long Long64_t
Definition: RtypesCore.h:69
#define ClassDef(name, id)
Definition: Rtypes.h:324
int type
Definition: TGX11.cxx:120
_object PyObject
Definition: TPyArg.h:20
Class that contains all the data information.
Definition: DataSetInfo.h:60
Class that contains all the data information.
This is the main MVA steering class.
Definition: Factory.h:81
virtual void ReadWeightsFromStream(std::istream &)=0
DataSetManager * fDataSetManager
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
std::vector< Double_t > mvaValues
const Ranking * CreateRanking()
std::vector< Float_t > classValues
virtual void AddWeightsXMLTo(void *) const
virtual void ReadWeightsFromStream(std::istream &)
MethodPyAdaBoost(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
virtual void TestClassification()
initialization
virtual void ReadModelFromFile()
std::vector< Float_t > & GetMulticlassValues()
virtual void ReadWeightsFromXML(void *)
Ranking for variables in method (implementation)
Definition: Ranking.h:48
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:63
EAnalysisType
Definition: Types.h:127
Basic string class.
Definition: TString.h:131
Abstract ClassifierFactory template that handles arbitrary types.