Logo ROOT   6.16/01
Reference Guide
RArrowDS.cxx
Go to the documentation of this file.
1// Author: Giulio Eulisse CERN 2/2018
2
3/*************************************************************************
4 * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers. *
5 * All rights reserved. *
6 * *
7 * For the licensing terms see $ROOTSYS/LICENSE. *
8 * For the list of contributors see $ROOTSYS/README/CREDITS. *
9 *************************************************************************/
10
11// clang-format off
12/** \class ROOT::RDF::RArrowDS
13 \ingroup dataframe
14 \brief RDataFrame data source class to interface with Apache Arrow.
15
16The RArrowDS implements a proxy RDataSource to be able to use Apache Arrow
17tables with RDataFrame.
18
An RDataFrame that adapts an arrow::Table can be constructed using the factory method
ROOT::RDF::MakeArrowDataFrame, which accepts two parameters:
1. An arrow::Table smart pointer.
2. The names of the columns to use; if the list is empty, all the columns found in the table are used.
22
23The types of the columns are derived from the types in the associated
24arrow::Schema.
25
26*/
27// clang-format on
28
29#include <ROOT/RDF/Utils.hxx>
30#include <ROOT/TSeq.hxx>
31#include <ROOT/RArrowDS.hxx>
32#include <ROOT/RMakeUnique.hxx>
33
34#include <algorithm>
35#include <sstream>
36#include <string>
37
38#if defined(__GNUC__)
39#pragma GCC diagnostic push
40#pragma GCC diagnostic ignored "-Wshadow"
41#endif
42#include <arrow/table.h>
43#if defined(__GNUC__)
44#pragma GCC diagnostic pop
45#endif
46
47namespace ROOT {
48namespace Internal {
49namespace RDF {
50// Per slot visitor of an Array.
51class ArrayPtrVisitor : public ::arrow::ArrayVisitor {
52private:
53 /// The pointer to update.
54 void **fResult;
55 bool fCachedBool{false}; // Booleans need to be unpacked, so we use a cached entry.
56 std::string fCachedString;
57 /// The entry in the array which should be looked up.
58 ULong64_t fCurrentEntry;
59
60public:
61 ArrayPtrVisitor(void **result) : fResult{result}, fCurrentEntry{0} {}
62
63 void SetEntry(ULong64_t entry) { fCurrentEntry = entry; }
64
65 /// Check if we are asking the same entry as before.
66 virtual arrow::Status Visit(arrow::Int32Array const &array) final
67 {
68 *fResult = (void *)(array.raw_values() + fCurrentEntry);
69 return arrow::Status::OK();
70 }
71
72 virtual arrow::Status Visit(arrow::Int64Array const &array) final
73 {
74 *fResult = (void *)(array.raw_values() + fCurrentEntry);
75 return arrow::Status::OK();
76 }
77
78 /// Check if we are asking the same entry as before.
79 virtual arrow::Status Visit(arrow::UInt32Array const &array) final
80 {
81 *fResult = (void *)(array.raw_values() + fCurrentEntry);
82 return arrow::Status::OK();
83 }
84
85 virtual arrow::Status Visit(arrow::UInt64Array const &array) final
86 {
87 *fResult = (void *)(array.raw_values() + fCurrentEntry);
88 return arrow::Status::OK();
89 }
90
91 virtual arrow::Status Visit(arrow::FloatArray const &array) final
92 {
93 *fResult = (void *)(array.raw_values() + fCurrentEntry);
94 return arrow::Status::OK();
95 }
96
97 virtual arrow::Status Visit(arrow::DoubleArray const &array) final
98 {
99 *fResult = (void *)(array.raw_values() + fCurrentEntry);
100 return arrow::Status::OK();
101 }
102
103 virtual arrow::Status Visit(arrow::BooleanArray const &array) final
104 {
105 fCachedBool = array.Value(fCurrentEntry);
106 *fResult = reinterpret_cast<void *>(&fCachedBool);
107 return arrow::Status::OK();
108 }
109
110 virtual arrow::Status Visit(arrow::StringArray const &array) final
111 {
112 fCachedString = array.GetString(fCurrentEntry);
113 *fResult = reinterpret_cast<void *>(&fCachedString);
114 return arrow::Status::OK();
115 }
116
117 using ::arrow::ArrayVisitor::Visit;
118};
119
120/// Helper class which keeps track for each slot where to get the entry.
121class TValueGetter {
122private:
123 std::vector<void *> fValuesPtrPerSlot;
124 std::vector<ULong64_t> fLastEntryPerSlot;
125 std::vector<ULong64_t> fLastChunkPerSlot;
126 std::vector<ULong64_t> fFirstEntryPerChunk;
127 std::vector<ArrayPtrVisitor> fArrayVisitorPerSlot;
128 /// Since data can be chunked in different arrays we need to construct an
129 /// index which contains the first element of each chunk, so that we can
130 /// quickly move to the correct chunk.
131 std::vector<ULong64_t> fChunkIndex;
132 arrow::ArrayVector fChunks;
133
134public:
135 TValueGetter(size_t slots, arrow::ArrayVector chunks)
136 : fValuesPtrPerSlot(slots, nullptr), fLastEntryPerSlot(slots, 0), fLastChunkPerSlot(slots, 0), fChunks{chunks}
137 {
138 fChunkIndex.reserve(fChunks.size());
139 size_t next = 0;
140 for (auto &chunk : chunks) {
141 fFirstEntryPerChunk.push_back(next);
142 next += chunk->length();
143 fChunkIndex.push_back(next);
144 }
145 for (size_t si = 0, se = fValuesPtrPerSlot.size(); si != se; ++si) {
146 fArrayVisitorPerSlot.push_back(ArrayPtrVisitor{fValuesPtrPerSlot.data() + si});
147 }
148 }
149
150 /// This returns the ptr to the ptr to actual data.
151 std::vector<void *> SlotPtrs()
152 {
153 std::vector<void *> result;
154 for (size_t i = 0; i < fValuesPtrPerSlot.size(); ++i) {
155 result.push_back(fValuesPtrPerSlot.data() + i);
156 }
157 return result;
158 }
159
160 // Convenience method to avoid code duplication between
161 // SetEntry and InitSlot
162 void UncachedSlotLookup(unsigned int slot, ULong64_t entry)
163 {
164 // If entry is greater than the previous one,
165 // we can skip all the chunks before the last one we
166 // queried.
167 size_t ci = 0;
168 assert(slot < fLastChunkPerSlot.size());
169 if (fLastEntryPerSlot[slot] < entry) {
170 ci = fLastChunkPerSlot.at(slot);
171 }
172
173 for (size_t ce = fChunkIndex.size(); ci != ce; ++ci) {
174 if (entry < fChunkIndex[ci]) {
175 assert(slot < fLastChunkPerSlot.size());
176 fLastChunkPerSlot[slot] = ci;
177 break;
178 }
179 }
180
181 // Update the pointer to the requested entry.
182 // Notice that we need to find the entry
183 auto chunk = fChunks.at(fLastChunkPerSlot[slot]);
184 assert(slot < fArrayVisitorPerSlot.size());
185 fArrayVisitorPerSlot[slot].SetEntry(entry - fFirstEntryPerChunk[fLastChunkPerSlot[slot]]);
186 auto status = chunk->Accept(fArrayVisitorPerSlot.data() + slot);
187 if (!status.ok()) {
188 std::string msg = "Could not get pointer for slot ";
189 msg += std::to_string(slot) + " looking at entry " + std::to_string(entry);
190 throw std::runtime_error(msg);
191 }
192 }
193
194 /// Set the current entry to be retrieved
195 void SetEntry(unsigned int slot, ULong64_t entry)
196 {
197 // Same entry as before
198 if (fLastEntryPerSlot[slot] == entry) {
199 return;
200 }
201 UncachedSlotLookup(slot, entry);
202 }
203};
204
205} // namespace RDF
206} // namespace Internal
207
208namespace RDF {
209
210/// Helper to get the contents of a given column
211
212/// Helper to get the human readable name of type
213class RDFTypeNameGetter : public ::arrow::TypeVisitor {
214private:
215 std::string fTypeName;
216
217public:
218 arrow::Status Visit(const arrow::Int64Type &) override
219 {
220 fTypeName = "Long64_t";
221 return arrow::Status::OK();
222 }
223 arrow::Status Visit(const arrow::Int32Type &) override
224 {
225 fTypeName = "Long_t";
226 return arrow::Status::OK();
227 }
228 arrow::Status Visit(const arrow::UInt64Type &) override
229 {
230 fTypeName = "ULong64_t";
231 return arrow::Status::OK();
232 }
233 arrow::Status Visit(const arrow::UInt32Type &) override
234 {
235 fTypeName = "ULong_t";
236 return arrow::Status::OK();
237 }
238 arrow::Status Visit(const arrow::FloatType &) override
239 {
240 fTypeName = "float";
241 return arrow::Status::OK();
242 }
243 arrow::Status Visit(const arrow::DoubleType &) override
244 {
245 fTypeName = "double";
246 return arrow::Status::OK();
247 }
248 arrow::Status Visit(const arrow::StringType &) override
249 {
250 fTypeName = "string";
251 return arrow::Status::OK();
252 }
253 arrow::Status Visit(const arrow::BooleanType &) override
254 {
255 fTypeName = "bool";
256 return arrow::Status::OK();
257 }
258 std::string result() { return fTypeName; }
259
260 using ::arrow::TypeVisitor::Visit;
261};
262
263/// Helper to determine if a given Column is a supported type.
264class VerifyValidColumnType : public ::arrow::TypeVisitor {
265private:
266public:
267 virtual arrow::Status Visit(const arrow::Int64Type &) override { return arrow::Status::OK(); }
268 virtual arrow::Status Visit(const arrow::UInt64Type &) override { return arrow::Status::OK(); }
269 virtual arrow::Status Visit(const arrow::Int32Type &) override { return arrow::Status::OK(); }
270 virtual arrow::Status Visit(const arrow::UInt32Type &) override { return arrow::Status::OK(); }
271 virtual arrow::Status Visit(const arrow::FloatType &) override { return arrow::Status::OK(); }
272 virtual arrow::Status Visit(const arrow::DoubleType &) override { return arrow::Status::OK(); }
273 virtual arrow::Status Visit(const arrow::StringType &) override { return arrow::Status::OK(); }
274 virtual arrow::Status Visit(const arrow::BooleanType &) override { return arrow::Status::OK(); }
275
276 using ::arrow::TypeVisitor::Visit;
277};
278
279////////////////////////////////////////////////////////////////////////
280/// Constructor to create an Arrow RDataSource for RDataFrame.
281/// \param[in] table the arrow Table to observe.
282/// \param[in] columns the name of the columns to use
283/// In case columns is empty, we use all the columns found in the table
284RArrowDS::RArrowDS(std::shared_ptr<arrow::Table> inTable, std::vector<std::string> const &inColumns)
285 : fTable{inTable}, fColumnNames{inColumns}
286{
287 auto &columnNames = fColumnNames;
288 auto &table = fTable;
289 auto &index = fGetterIndex;
290 // We want to allow people to specify which columns they
291 // need so that we can think of upfront IO optimizations.
292 auto filterWantedColumns = [&columnNames, &table]() {
293 if (columnNames.empty()) {
294 for (auto &field : table->schema()->fields()) {
295 columnNames.push_back(field->name());
296 }
297 }
298 };
299
300 // To support both arrow 0.14.0 and 0.16.0
301 using ColumnType = decltype(fTable->column(0));
302
303 auto getRecordsFirstColumn = [&columnNames, &table]() {
304 if (columnNames.empty()) {
305 throw std::runtime_error("At least one column required");
306 }
307 const auto name = columnNames.front();
308 const auto columnIdx = table->schema()->GetFieldIndex(name);
309 return table->column(columnIdx)->length();
310 };
311
312 // All columns are supposed to have the same number of entries.
313 auto verifyColumnSize = [&table](ColumnType column, int columnIdx, int nRecords) {
314 if (column->length() != nRecords) {
315 std::string msg = "Column ";
316 msg += table->schema()->field(columnIdx)->name() + " has a different number of entries.";
317 throw std::runtime_error(msg);
318 }
319 };
320
321 /// For the moment we support only a few native types.
322 auto verifyColumnType = [&table](ColumnType column, int columnIdx) {
323 auto verifyType = std::make_unique<VerifyValidColumnType>();
324 auto result = column->type()->Accept(verifyType.get());
325 if (result.ok() == false) {
326 std::string msg = "Column ";
327 msg += table->schema()->field(columnIdx)->name() + " contains an unsupported type.";
328 throw std::runtime_error(msg);
329 }
330 };
331
332 /// This is used to create an index between the columnId
333 /// and the associated getter.
334 auto addColumnToGetterIndex = [&index](int columnId) { index.push_back(std::make_pair(columnId, index.size())); };
335
336 /// Assuming we can get called more than once, we need to
337 /// reset the getter index each time.
338 auto resetGetterIndex = [&index]() { index.clear(); };
339
340 /// This is what initialization actually does
341 filterWantedColumns();
342 resetGetterIndex();
343 auto nRecords = getRecordsFirstColumn();
344 for (auto &columnName : fColumnNames) {
345 auto columnIdx = fTable->schema()->GetFieldIndex(columnName);
346 addColumnToGetterIndex(columnIdx);
347
348 auto column = fTable->column(columnIdx);
349 verifyColumnSize(column, columnIdx, nRecords);
350 verifyColumnType(column, columnIdx);
351 }
352}
353
354////////////////////////////////////////////////////////////////////////
355/// Destructor.
357{
358}
359
360const std::vector<std::string> &RArrowDS::GetColumnNames() const
361{
362 return fColumnNames;
363}
364
365std::vector<std::pair<ULong64_t, ULong64_t>> RArrowDS::GetEntryRanges()
366{
367 auto entryRanges(std::move(fEntryRanges)); // empty fEntryRanges
368 return entryRanges;
369}
370
371std::string RArrowDS::GetTypeName(std::string_view colName) const
372{
373 auto field = fTable->schema()->GetFieldByName(std::string(colName));
374 if (!field) {
375 std::string msg = "The dataset does not have column ";
376 msg += colName;
377 throw std::runtime_error(msg);
378 }
379 RDFTypeNameGetter typeGetter;
380 auto status = field->type()->Accept(&typeGetter);
381 if (status.ok() == false) {
382 std::string msg = "RArrowDS does not support a column of type ";
383 msg += field->type()->name();
384 throw std::runtime_error(msg);
385 }
386 return typeGetter.result();
387}
388
390{
391 auto field = fTable->schema()->GetFieldByName(std::string(colName));
392 if (!field) {
393 return false;
394 }
395 return true;
396}
397
398bool RArrowDS::SetEntry(unsigned int slot, ULong64_t entry)
399{
400 for (auto link : fGetterIndex) {
401 auto column = fTable->column(link.first);
402 auto &getter = fValueGetters[link.second];
403 getter->SetEntry(slot, entry);
404 }
405 return true;
406}
407
408void RArrowDS::InitSlot(unsigned int slot, ULong64_t entry)
409{
410 for (auto link : fGetterIndex) {
411 auto column = fTable->column(link.first);
412 auto &getter = fValueGetters[link.second];
413 getter->UncachedSlotLookup(slot, entry);
414 }
415}
416
417template <typename T>
418std::shared_ptr<arrow::ChunkedArray> getData(T p)
419{
420 return p->data();
421}
422
423template <>
424std::shared_ptr<arrow::ChunkedArray>
425getData<std::shared_ptr<arrow::ChunkedArray>>(std::shared_ptr<arrow::ChunkedArray> p)
426{
427 return p;
428}
429
430void RArrowDS::SetNSlots(unsigned int nSlots)
431{
432 assert(0U == fNSlots && "Setting the number of slots even if the number of slots is different from zero.");
433
434 // We dump all the previous getters structures and we rebuild it.
435 auto nColumns = fGetterIndex.size();
436 auto &outNSlots = fNSlots;
437 auto &ranges = fEntryRanges;
438 auto &table = fTable;
439 auto &columnNames = fColumnNames;
440
441 fValueGetters.clear();
442 for (size_t ci = 0; ci != nColumns; ++ci) {
443 auto chunkedArray = getData(fTable->column(fGetterIndex[ci].first));
444 fValueGetters.emplace_back(std::make_unique<ROOT::Internal::RDF::TValueGetter>(nSlots, chunkedArray->chunks()));
445 }
446
447 // We use the same logic as the ROOTDS.
448 auto splitInEqualRanges = [&outNSlots, &ranges](int nRecords, unsigned int newNSlots) {
449 ranges.clear();
450 outNSlots = newNSlots;
451 const auto chunkSize = nRecords / outNSlots;
452 const auto remainder = 1U == outNSlots ? 0 : nRecords % outNSlots;
453 auto start = 0UL;
454 auto end = 0UL;
455 for (auto i : ROOT::TSeqU(outNSlots)) {
456 start = end;
457 end += chunkSize;
458 ranges.emplace_back(start, end);
459 (void)i;
460 }
461 ranges.back().second += remainder;
462 };
463
464 auto getNRecords = [&table, &columnNames]() -> int {
465 auto index = table->schema()->GetFieldIndex(columnNames.front());
466 return table->column(index)->length();
467 };
468
469 auto nRecords = getNRecords();
470 splitInEqualRanges(nRecords, nSlots);
471}
472
473/// This needs to return a pointer to the pointer each value getter
474/// will point to.
475std::vector<void *> RArrowDS::GetColumnReadersImpl(std::string_view colName, const std::type_info &)
476{
477 auto &index = fGetterIndex;
478 auto findGetterIndex = [&index](unsigned int column) {
479 for (auto &entry : index) {
480 if (entry.first == column) {
481 return entry.second;
482 }
483 }
484 throw std::runtime_error("No column found at index " + std::to_string(column));
485 };
486
487 const int columnIdx = fTable->schema()->GetFieldIndex(std::string(colName));
488 const int getterIdx = findGetterIndex(columnIdx);
489 assert(getterIdx != -1);
490 assert((unsigned int)getterIdx < fValueGetters.size());
491 return fValueGetters[getterIdx]->SlotPtrs();
492}
493
495{
496}
497
499{
500 return "ArrowDS";
501}
502
503/// Creates a RDataFrame using an arrow::Table as input.
504/// \param[in] table the arrow Table to observe.
505/// \param[in] columnNames the name of the columns to use
506/// In case columnNames is empty, we use all the columns found in the table
507RDataFrame MakeArrowDataFrame(std::shared_ptr<arrow::Table> table, std::vector<std::string> const &columnNames)
508{
509 ROOT::RDataFrame tdf(std::make_unique<RArrowDS>(table, columnNames));
510 return tdf;
511}
512
513} // namespace RDF
514
515} // namespace ROOT
unsigned long long ULong64_t
Definition: RtypesCore.h:70
typedef void((*Func_t)())
bool HasColumn(std::string_view colName) const override
Checks if the dataset has a certain column.
Definition: RArrowDS.cxx:389
RArrowDS(std::shared_ptr< arrow::Table > table, std::vector< std::string > const &columns)
Constructor to create an Arrow RDataSource for RDataFrame.
Definition: RArrowDS.cxx:284
void Initialise() override
Convenience method called before starting an event-loop.
Definition: RArrowDS.cxx:494
std::string GetLabel() override
Return a string representation of the datasource type.
Definition: RArrowDS.cxx:498
void SetNSlots(unsigned int nSlots) override
Inform RDataSource of the number of processing slots (i.e.
Definition: RArrowDS.cxx:430
~RArrowDS()
Destructor.
Definition: RArrowDS.cxx:356
const std::vector< std::string > & GetColumnNames() const override
Returns a reference to the collection of the dataset's column names.
Definition: RArrowDS.cxx:360
void InitSlot(unsigned int slot, ULong64_t firstEntry) override
Convenience method called at the start of the data processing associated to a slot.
Definition: RArrowDS.cxx:408
std::vector< std::pair< ULong64_t, ULong64_t > > GetEntryRanges() override
Return ranges of entries to distribute to tasks.
Definition: RArrowDS.cxx:365
std::shared_ptr< arrow::Table > fTable
Definition: RArrowDS.hxx:24
std::vector< std::pair< size_t, size_t > > fGetterIndex
Definition: RArrowDS.hxx:29
std::vector< std::unique_ptr< ROOT::Internal::RDF::TValueGetter > > fValueGetters
Definition: RArrowDS.hxx:30
std::vector< void * > GetColumnReadersImpl(std::string_view name, const std::type_info &type) override
This needs to return a pointer to the pointer each value getter will point to.
Definition: RArrowDS.cxx:475
std::vector< std::string > fColumnNames
Definition: RArrowDS.hxx:26
std::string GetTypeName(std::string_view colName) const override
Type of a column as a string, e.g.
Definition: RArrowDS.cxx:371
std::vector< std::pair< ULong64_t, ULong64_t > > fEntryRanges
Definition: RArrowDS.hxx:25
bool SetEntry(unsigned int slot, ULong64_t entry) override
Advance the "cursors" returned by GetColumnReaders to the selected entry for a particular slot.
Definition: RArrowDS.cxx:398
ROOT's RDataFrame offers a high level interface for analyses of data stored in TTrees,...
Definition: RDataFrame.hxx:41
A pseudo container class which is a generator of indices.
Definition: TSeq.hxx:66
double T(double x)
Definition: ChebyshevPol.h:34
std::shared_ptr< arrow::ChunkedArray > getData(T p)
Definition: RArrowDS.cxx:418
RDataFrame MakeArrowDataFrame(std::shared_ptr< arrow::Table > table, std::vector< std::string > const &columns)
Factory method to create a Apache Arrow RDataFrame.
Definition: RArrowDS.cxx:507
Namespace for new ROOT classes and functions.
Definition: StringConv.hxx:21
@ array
array (ordered collection of values)
basic_string_view< char > string_view
Definition: RStringView.hxx:35
void table()
Definition: table.C:85