Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RSampler.hxx
Go to the documentation of this file.
1// Author: Martin Føll, University of Oslo (UiO) & CERN 01/2026
2
3/*************************************************************************
4 * Copyright (C) 1995-2026, Rene Brun and Fons Rademakers. *
5 * All rights reserved. *
6 * *
7 * For the licensing terms see $ROOTSYS/LICENSE. *
8 * For the list of contributors see $ROOTSYS/README/CREDITS. *
9 *************************************************************************/
10
11#ifndef ROOT_INTERNAL_ML_RSAMPLER
12#define ROOT_INTERNAL_ML_RSAMPLER
13
14#include <vector>
15#include <random>
16#include <algorithm>
17
18#include "ROOT/RDataFrame.hxx"
19#include "ROOT/RDF/Utils.hxx"
20#include "ROOT/RVec.hxx"
22#include "ROOT/RLogger.hxx"
23
25/**
26\class ROOT::Experimental::Internal::ML::RSampler
27
28\brief Implementation of different sampling strategies.
29*/
30
31class RSampler {
32private:
33 std::vector<RFlat2DMatrix> &fDatasets;
34 std::string fSampleType;
38 std::size_t fSetSeed;
39 std::size_t fNumEntries;
40
41 std::size_t fMajor;
42 std::size_t fMinor;
43 std::size_t fNumMajor;
44 std::size_t fNumMinor;
45 std::size_t fNumResampledMajor;
46 std::size_t fNumResampledMinor;
47
48 std::vector<std::size_t> fSamples;
49
50 std::unique_ptr<RFlat2DMatrixOperators> fTensorOperators;
51
52public:
53 RSampler(std::vector<RFlat2DMatrix> &datasets, const std::string &sampleType, float sampleRatio,
54 bool replacement = false, bool shuffle = true, std::size_t setSeed = 0)
55 : fDatasets(datasets),
61 {
62 fTensorOperators = std::make_unique<RFlat2DMatrixOperators>(fShuffle, fSetSeed);
63
64 // setup the sampler for the datasets
66 }
67
68 //////////////////////////////////////////////////////////////////////////
69 /// \brief Calculate fNumEntries and major/minor variables
71 {
72 if (fSampleType == "undersampling") {
74 } else if (fSampleType == "oversampling") {
76 }
77 }
78
79 //////////////////////////////////////////////////////////////////////////
80 /// \brief Collection of sampling types
81 /// \param[in] SampledTensor Tensor with all the sampled entries
83 {
84 if (fSampleType == "undersampling") {
86 } else if (fSampleType == "oversampling") {
88 }
89 }
90
91 //////////////////////////////////////////////////////////////////////////
92 /// \brief Calculate fNumEntries and major/minor variables for the random undersampler
94 {
95 if (fDatasets[0].GetRows() > fDatasets[1].GetRows()) {
96 fMajor = 0;
97 fMinor = 1;
98 } else {
99 fMajor = 1;
100 fMinor = 0;
101 }
102
103 fNumMajor = fDatasets[fMajor].GetRows();
104 fNumMinor = fDatasets[fMinor].GetRows();
105 fNumResampledMajor = static_cast<std::size_t>(fNumMinor / fSampleRatio);
107 auto minRatio = std::to_string(std::round(double(fNumMinor) / double(fNumMajor) * 100.0) / 100.0);
108 minRatio.erase(minRatio.find('.') + 3);
109 throw std::invalid_argument(
110 "The sampling_ratio is too low: not enough entries in the majority class to sample from.\n"
111 "Choose sampling_ratio > " +
112 minRatio + " or set replacement to True.");
113 }
115 }
116
117 //////////////////////////////////////////////////////////////////////////
118 /// \brief Calculate fNumEntries and major/minor variables for the random oversampler
120 {
121 if (fDatasets[0].GetRows() > fDatasets[1].GetRows()) {
122 fMajor = 0;
123 fMinor = 1;
124 } else {
125 fMajor = 1;
126 fMinor = 0;
127 }
128
129 fNumMajor = fDatasets[fMajor].GetRows();
130 fNumMinor = fDatasets[fMinor].GetRows();
131 fNumResampledMinor = static_cast<std::size_t>(fSampleRatio * fNumMajor);
133 }
134
135 //////////////////////////////////////////////////////////////////////////
136 /// \brief Undersample entries randomly from the majority dataset
137 /// \param[in] ShuffledTensor Tensor with all the sampled entries
139 {
140 if (fReplacement) {
142 }
143
144 else {
146 }
147
148 std::size_t cols = fDatasets[0].GetCols();
152
153 std::size_t index = 0;
154 for (std::size_t i = 0; i < fNumResampledMajor; i++) {
155 std::copy(fDatasets[fMajor].GetData() + fSamples[i] * cols,
156 fDatasets[fMajor].GetData() + (fSamples[i] + 1) * cols,
157 UndersampledMajorTensor.GetData() + index * cols);
158 index++;
159 }
160
163 }
164
165 //////////////////////////////////////////////////////////////////////////
166 /// \brief Oversample entries randomly from the minority dataset
167 /// \param[in] ShuffledTensor Tensor with all the sampled entries
169 {
171
172 std::size_t cols = fDatasets[0].GetCols();
176
177 std::size_t index = 0;
178 for (std::size_t i = 0; i < fNumResampledMinor; i++) {
179 std::copy(fDatasets[fMinor].GetData() + fSamples[i] * cols,
180 fDatasets[fMinor].GetData() + (fSamples[i] + 1) * cols,
181 OversampledMinorTensor.GetData() + index * cols);
182 index++;
183 }
184
187 }
188
189 //////////////////////////////////////////////////////////////////////////
190 /// \brief Add indices with replacement to fSamples
191 /// \param[in] n_samples Number of indices to sample
192 /// \param[in] max Max index of the sample distribution
193 void SampleWithReplacement(std::size_t n_samples, std::size_t max)
194 {
195 std::uniform_int_distribution<> dist(0, max - 1);
196 fSamples.clear();
197 fSamples.reserve(n_samples);
198 for (std::size_t i = 0; i < n_samples; ++i) {
199 std::size_t sample;
200 if (fShuffle) {
201 std::random_device rd;
202 std::mt19937 g;
203
204 if (fSetSeed == 0) {
205 g.seed(rd());
206 } else {
207 g.seed(fSetSeed);
208 }
209
210 sample = dist(g);
211 }
212
213 else {
214 sample = i % max;
215 }
216 fSamples.push_back(sample);
217 }
218 }
219
220 //////////////////////////////////////////////////////////////////////////
221 /// \brief Add indices without replacement to fSamples
222 /// \param[in] n_samples Number of indices to sample
223 /// \param[in] max Max index of the sample distribution
224 void SampleWithoutReplacement(std::size_t n_samples, std::size_t max)
225 {
226 std::vector<std::size_t> UniqueSamples;
227 UniqueSamples.reserve(max);
228 fSamples.clear();
229 fSamples.reserve(n_samples);
230
231 for (std::size_t i = 0; i < max; ++i)
232 UniqueSamples.push_back(i);
233
234 if (fShuffle) {
235 std::random_device rd;
236 std::mt19937 g;
237
238 if (fSetSeed == 0) {
239 g.seed(rd());
240 } else {
241 g.seed(fSetSeed);
242 }
243 std::shuffle(UniqueSamples.begin(), UniqueSamples.end(), g);
244 }
245
246 for (std::size_t i = 0; i < n_samples; ++i) {
247 fSamples.push_back(UniqueSamples[i]);
248 }
249 }
250
251 std::size_t GetNumEntries() { return fNumEntries; }
252};
253
254} // namespace ROOT::Experimental::Internal::ML
255#endif // ROOT_INTERNAL_ML_RSAMPLER
#define g(i)
Definition RSha256.hxx:105
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
Implementation of different sampling strategies.
Definition RSampler.hxx:31
void SampleWithoutReplacement(std::size_t n_samples, std::size_t max)
Add indices without replacement to fSamples.
Definition RSampler.hxx:224
void SetupRandomUndersampler()
Calculate fNumEntries and major/minor variables for the random undersampler.
Definition RSampler.hxx:93
void RandomOversampler(RFlat2DMatrix &ShuffledTensor)
Oversample entries randomly from the minority dataset.
Definition RSampler.hxx:168
void SampleWithReplacement(std::size_t n_samples, std::size_t max)
Add indices with replacement to fSamples.
Definition RSampler.hxx:193
void SetupRandomOversampler()
Calculate fNumEntries and major/minor variables for the random oversampler.
Definition RSampler.hxx:119
void SetupSampler()
Calculate fNumEntries and major/minor variables.
Definition RSampler.hxx:70
std::vector< std::size_t > fSamples
Definition RSampler.hxx:48
std::unique_ptr< RFlat2DMatrixOperators > fTensorOperators
Definition RSampler.hxx:50
RSampler(std::vector< RFlat2DMatrix > &datasets, const std::string &sampleType, float sampleRatio, bool replacement=false, bool shuffle=true, std::size_t setSeed=0)
Definition RSampler.hxx:53
void RandomUndersampler(RFlat2DMatrix &ShuffledTensor)
Undersample entries randomly from the majority dataset.
Definition RSampler.hxx:138
std::vector< RFlat2DMatrix > & fDatasets
Definition RSampler.hxx:33
void Sampler(RFlat2DMatrix &SampledTensor)
Collection of sampling types.
Definition RSampler.hxx:82
const_iterator begin() const
const_iterator end() const
Wrapper around ROOT::RVec<float> representing a 2D matrix.