Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodPyAdaBoost.h
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodPyAdaBoost *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * scikit-learn package AdaBoostClassifier method based on python *
12 * *
13 **********************************************************************************/
14
15#ifndef ROOT_TMVA_MethodPyAdaBoost
16#define ROOT_TMVA_MethodPyAdaBoost
17
18//////////////////////////////////////////////////////////////////////////
19// //
20// MethodPyAdaBoost //
21// //
22//////////////////////////////////////////////////////////////////////////
23
24#include "TMVA/PyMethodBase.h"
25
26#include "TString.h"
27#include <vector>
28
29namespace TMVA {
30
31 class Factory;
32 class Reader;
33 class DataSetManager;
34 class Types;
36
37 public :
39 const TString &methodTitle,
41 const TString &theOption = "");
42
44 const TString &theWeightFile);
45
47
48 void Train() override;
49
50 void Init() override;
51 void DeclareOptions() override;
52 void ProcessOptions() override;
53
54 // create ranking
55 const Ranking *CreateRanking() override;
56
58
59 // performs classifier testing
60 void TestClassification() override;
61
62 Double_t GetMvaValue(Double_t *errLower = nullptr, Double_t *errUpper = nullptr) override;
63 std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false) override;
64 std::vector<Float_t>& GetMulticlassValues() override;
65
66 void ReadModelFromFile() override;
67
69 // the actual "weights"
70 void AddWeightsXMLTo(void * /*parent */ ) const override {} // = 0;
71 void ReadWeightsFromXML(void * /*wghtnode*/ ) override {} // = 0;
72 void ReadWeightsFromStream(std::istream &) override {} //= 0; backward compatibility
73
74 private :
76 friend class Factory;
77 friend class Reader;
78
79 protected:
80 std::vector<Double_t> mvaValues;
81 std::vector<Float_t> classValues;
82
83 UInt_t fNvars; // number of variables
84 UInt_t fNoutputs; // number of outputs
85 TString fFilenameClassifier; // Path to serialized classifier (default in `weights` folder)
86
87 //AdaBoost options
88
90 TString fBaseEstimator; //object, optional (default=DecisionTreeClassifier)
91 //The base estimator from which the boosted ensemble is built.
92 //Support for sample weighting is required, as well as proper `classes_`
93 //and `n_classes_` attributes.
94
96 Int_t fNestimators; //integer, optional (default=10)
97 //The maximum number of estimators at which boosting is terminated.
98
100 Double_t fLearningRate; //float, optional (default=1.)
101 //Learning rate shrinks the contribution of each classifier by
102 //``learning_rate``. There is a trade-off between ``learning_rate`` and ``n_estimators``.
103
105 TString fAlgorithm; //{'SAMME', 'SAMME.R'}, optional (default='SAMME.R')
106 //If 'SAMME.R' then use the SAMME.R real boosting algorithm.
107 //``base_estimator`` must support calculation of class probabilities.
108 //If 'SAMME' then use the SAMME discrete boosting algorithm.
109 //The SAMME.R algorithm typically converges faster than SAMME,
110 //achieving a lower test error with fewer boosting iterations.
111
113 TString fRandomState; //int, RandomState instance or None, optional (default=None)
114 //If int, random_state is the seed used by the random number generator;
115 //If RandomState instance, random_state is the random number generator;
116 //If None, the random number generator is the RandomState instance used by `np.random`.
117
118 // get help message text
119 void GetHelpMessage() const override;
120
122 };
123
124} // namespace TMVA
125
126#endif // ROOT_TMVA_MethodPyAdaBoost
_object PyObject
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
#define ClassDefOverride(name, id)
Definition Rtypes.h:348
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
Class that contains all the data information.
Definition DataSetInfo.h:62
Class that contains all the data information.
This is the main MVA steering class.
Definition Factory.h:80
void ReadWeightsFromStream(std::istream &) override=0
DataSetManager * fDataSetManager
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false) override
get all the MVA values for the events of the current Data type
void ReadWeightsFromXML(void *) override
std::vector< Double_t > mvaValues
void AddWeightsXMLTo(void *) const override
void ReadWeightsFromStream(std::istream &) override
std::vector< Float_t > classValues
const Ranking * CreateRanking() override
void GetHelpMessage() const override
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) override
std::vector< Float_t > & GetMulticlassValues() override
MethodPyAdaBoost(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr) override
void TestClassification() override
initialization
void ReadModelFromFile() override
Virtual base class for all TMVA method based on Python.
Ranking for variables in method (implementation)
Definition Ranking.h:48
The Reader class serves to use the MVAs in a specific analysis context.
Definition Reader.h:64
Basic string class.
Definition TString.h:138
create variable transformations