Logo ROOT   6.16/01
Reference Guide
RMethodBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3
4
5/**********************************************************************************
6 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
7 * Package: TMVA *
8 * Class : RMethodBase *
9 * *
10 * Description: *
11 * Virtual base class for all MVA methods based on ROOTR *
12 * *
13 **********************************************************************************/
14
15#include<TMVA/RMethodBase.h>
16#include <TMVA/DataSetInfo.h>
17#include<TApplication.h>
18using namespace TMVA;
19
21
22//_______________________________________________________________________
23RMethodBase::RMethodBase(const TString &jobName,
24 Types::EMVA methodType,
25 const TString &methodTitle,
26 DataSetInfo &dsi,
27 const TString &theOption , ROOT::R::TRInterface &_r): MethodBase(jobName, methodType, methodTitle, dsi, theOption),
28 r(_r)
29{
30 LoadData();
31}
32
33//_______________________________________________________________________
35 DataSetInfo &dsi,
36 const TString &weightFile,ROOT::R::TRInterface &_r): MethodBase(methodType, dsi, weightFile),
37 r(_r)
38{
39 LoadData();
40}
41
42//_______________________________________________________________________
44{
45 ///////////////////////////
46 //Loading Training Data //
47 ///////////////////////////
48 const UInt_t nvar = DataInfo().GetNVariables();
49
50 const UInt_t ntrains = Data()->GetNTrainingEvents();
51
52 //array of columns for every var to create a dataframe for training
53 std::vector<std::vector<Float_t> > fArrayTrain(nvar);
54// Data()->SetCurrentEvent(1);
55// Data()->SetCurrentType(Types::ETreeType::kTraining);
56
57 fWeightTrain.ResizeTo(ntrains);
58 for (UInt_t j = 0; j < ntrains; j++) {
59 const Event *ev = Data()->GetEvent(j, Types::ETreeType::kTraining);
60// const Event *ev=Data()->GetEvent( j );
61 //creating array with class type(signal or background) for factor required
62 if (ev->GetClass() == Types::kSignal) fFactorTrain.push_back("signal");
63 else fFactorTrain.push_back("background");
64
65 fWeightTrain[j] = ev->GetWeight();
66
67 //filling vector of columns for training
68 for (UInt_t i = 0; i < nvar; i++) {
69 fArrayTrain[i].push_back(ev->GetValue(i));
70 }
71
72 }
73 for (UInt_t i = 0; i < nvar; i++) {
74 fDfTrain[DataInfo().GetListOfVariables()[i].Data()] = fArrayTrain[i];
75 }
76 ////////////////////////
77 //Loading Test Data //
78 ////////////////////////
79
80 const UInt_t ntests = Data()->GetNTestEvents();
81 const UInt_t nspectators = DataInfo().GetNSpectators(kTRUE);
82
83 //array of columns for every var to create a dataframe for testing
84 std::vector<std::vector<Float_t> > fArrayTest(nvar);
85 //array of columns for every spectator to create a dataframe for testing
86 std::vector<std::vector<Float_t> > fArraySpectators(nvar);
87 fWeightTest.ResizeTo(ntests);
88// Data()->SetCurrentType(Types::ETreeType::kTesting);
89 for (UInt_t j = 0; j < ntests; j++) {
90 const Event *ev = Data()->GetEvent(j, Types::ETreeType::kTesting);
91// const Event *ev=Data()->GetEvent(j);
92 //creating array with class type(signal or background) for factor required
93 if (ev->GetClass() == Types::kSignal) fFactorTest.push_back("signal");
94 else fFactorTest.push_back("background");
95
96 fWeightTest[j] = ev->GetWeight();
97
98 for (UInt_t i = 0; i < nvar; i++) {
99 fArrayTest[i].push_back(ev->GetValue(i));
100 }
101 for (UInt_t i = 0; i < nspectators; i++) {
102 fArraySpectators[i].push_back(ev->GetSpectator(i));
103 }
104 }
105 for (UInt_t i = 0; i < nvar; i++) {
106 fDfTest[DataInfo().GetListOfVariables()[i].Data()] = fArrayTest[i];
107 }
108 for (UInt_t i = 0; i < nspectators; i++) {
109 fDfSpectators[DataInfo().GetSpectatorInfo(i).GetLabel().Data()] = fArraySpectators[i];
110 }
111
112}
ROOT::R::TRInterface & r
Definition: Object.C:4
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:363
ROOT R was implemented using the R Project library and the modules Rcpp and RInside
Definition: TRInterface.h:137
Class that contains all the data information.
Definition: DataSetInfo.h:60
UInt_t GetNVariables() const
Definition: DataSetInfo.h:110
UInt_t GetNSpectators(bool all=kTRUE) const
std::vector< TString > GetListOfVariables() const
returns list of variables
VariableInfo & GetSpectatorInfo(Int_t i)
Definition: DataSetInfo.h:106
Long64_t GetNTestEvents() const
Definition: DataSet.h:80
const Event * GetEvent() const
Definition: DataSet.cxx:202
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:79
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:237
Double_t GetWeight() const
Return the event weight, depending on whether the flag IgnoreNegWeightsInTraining is set or not.
Definition: Event.cxx:382
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:262
UInt_t GetClass() const
Definition: Event.h:81
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
DataSetInfo & DataInfo() const
Definition: MethodBase.h:401
DataSet * Data() const
Definition: MethodBase.h:400
std::vector< std::string > fFactorTrain
Definition: RMethodBase.h:92
ROOT::R::TRDataFrame fDfTrain
Definition: RMethodBase.h:88
RMethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="", ROOT::R::TRInterface &_r=ROOT::R::TRInterface::Instance())
Definition: RMethodBase.cxx:23
TVectorD fWeightTrain
Definition: RMethodBase.h:90
ROOT::R::TRDataFrame fDfTest
Definition: RMethodBase.h:89
TVectorD fWeightTest
Definition: RMethodBase.h:91
std::vector< std::string > fFactorTest
Definition: RMethodBase.h:93
ROOT::R::TRDataFrame fDfSpectators
Definition: RMethodBase.h:94
@ kSignal
Definition: Types.h:136
const TString & GetLabel() const
Definition: VariableInfo.h:59
TVectorT< Element > & ResizeTo(Int_t lwb, Int_t upb)
Resize the vector to [lwb:upb] .
Definition: TVectorT.cxx:292
Abstract ClassifierFactory template that handles arbitrary types.