39#pragma GCC diagnostic push
40#pragma GCC diagnostic ignored "-Wshadow"
42#include <arrow/table.h>
44#pragma GCC diagnostic pop
51class ArrayPtrVisitor :
public ::arrow::ArrayVisitor {
55 bool fCachedBool{
false};
56 std::string fCachedString;
61 ArrayPtrVisitor(
void **result) : fResult{result}, fCurrentEntry{0} {}
63 void SetEntry(
ULong64_t entry) { fCurrentEntry = entry; }
66 virtual arrow::Status Visit(arrow::Int32Array
const &array)
final
68 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
69 return arrow::Status::OK();
72 virtual arrow::Status Visit(arrow::Int64Array
const &array)
final
74 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
75 return arrow::Status::OK();
79 virtual arrow::Status Visit(arrow::UInt32Array
const &array)
final
81 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
82 return arrow::Status::OK();
85 virtual arrow::Status Visit(arrow::UInt64Array
const &array)
final
87 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
88 return arrow::Status::OK();
91 virtual arrow::Status Visit(arrow::FloatArray
const &array)
final
93 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
94 return arrow::Status::OK();
97 virtual arrow::Status Visit(arrow::DoubleArray
const &array)
final
99 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
100 return arrow::Status::OK();
103 virtual arrow::Status Visit(arrow::BooleanArray
const &array)
final
105 fCachedBool =
array.Value(fCurrentEntry);
106 *fResult =
reinterpret_cast<void *
>(&fCachedBool);
107 return arrow::Status::OK();
110 virtual arrow::Status Visit(arrow::StringArray
const &array)
final
112 fCachedString =
array.GetString(fCurrentEntry);
113 *fResult =
reinterpret_cast<void *
>(&fCachedString);
114 return arrow::Status::OK();
117 using ::arrow::ArrayVisitor::Visit;
// One void * per processing slot; each slot's ArrayPtrVisitor writes the address of the
// current value into its own element (see the constructor).
123 std::vector<void *> fValuesPtrPerSlot;
// Last entry recorded per slot; SetEntry() compares against it before doing a lookup.
124 std::vector<ULong64_t> fLastEntryPerSlot;
// Index into fChunks of the chunk last selected for each slot.
125 std::vector<ULong64_t> fLastChunkPerSlot;
// Global index of the first entry of each chunk.
126 std::vector<ULong64_t> fFirstEntryPerChunk;
// One visitor per slot, each bound to that slot's element of fValuesPtrPerSlot.
127 std::vector<ArrayPtrVisitor> fArrayVisitorPerSlot;
// Cumulative entry counts: element ci is one past the last global entry of chunk ci.
131 std::vector<ULong64_t> fChunkIndex;
// The Arrow arrays (chunks) holding this column's data.
132 arrow::ArrayVector fChunks;
/// Build the per-chunk entry index for `chunks` and one ArrayPtrVisitor per slot.
/// @param slots  number of processing slots; sizes every per-slot vector
/// @param chunks the arrays holding the column's data; their lengths are accumulated in order
135 TValueGetter(
size_t slots, arrow::ArrayVector chunks)
136 : fValuesPtrPerSlot(slots, nullptr), fLastEntryPerSlot(slots, 0), fLastChunkPerSlot(slots, 0), fChunks{chunks}
138 fChunkIndex.reserve(fChunks.size());
// `next` (declared on a line not shown in this listing) is the running entry count: before a
// chunk it is that chunk's first global entry, after it the one-past-the-end entry.
140 for (
auto &chunk : chunks) {
141 fFirstEntryPerChunk.push_back(next);
142 next += chunk->length();
143 fChunkIndex.push_back(next);
// Each visitor receives the address of its own slot's void *. fValuesPtrPerSlot is sized in
// the initializer list and not resized in the visible code, so these addresses stay valid.
145 for (
size_t si = 0, se = fValuesPtrPerSlot.size(); si != se; ++si) {
146 fArrayVisitorPerSlot.push_back(ArrayPtrVisitor{fValuesPtrPerSlot.data() + si});
/// Collect, for each slot, the address of that slot's element of fValuesPtrPerSlot, i.e. a
/// pointer to the void * that the slot's ArrayPtrVisitor updates.
// Presumably returns `result`; the return statement is not visible in this listing.
151 std::vector<void *> SlotPtrs()
153 std::vector<void *> result;
154 for (
size_t i = 0; i < fValuesPtrPerSlot.size(); ++i) {
155 result.push_back(fValuesPtrPerSlot.data() + i);
/// Point slot `slot` at global entry `entry` without consulting the SetEntry() cache.
/// Finds the chunk containing `entry`, converts `entry` to a chunk-relative index and lets the
/// slot's ArrayPtrVisitor store the value's address.
/// @throws std::runtime_error when visiting the chunk fails (the status test itself is not
///         visible in this listing)
162 void UncachedSlotLookup(
unsigned int slot,
ULong64_t entry)
168 assert(slot < fLastChunkPerSlot.size());
// When moving forward from the last recorded entry, start the chunk search at the chunk used
// last time instead of at chunk 0 (`ci` is declared on a line not shown here).
169 if (fLastEntryPerSlot[slot] < entry) {
170 ci = fLastChunkPerSlot.at(slot);
// Select a chunk whose cumulative end lies beyond `entry`. If none matches,
// fLastChunkPerSlot[slot] keeps its previous value.
// NOTE(review): no visible line assigns fLastEntryPerSlot. If it is never updated, the test
// above is true for every entry > 0, and a slot moving back to an earlier chunk would search
// from the last-used chunk and could select a chunk that does not contain `entry`. Confirm.
173 for (
size_t ce = fChunkIndex.size(); ci != ce; ++ci) {
174 if (entry < fChunkIndex[ci]) {
175 assert(slot < fLastChunkPerSlot.size());
176 fLastChunkPerSlot[slot] = ci;
// Translate the global entry into an index inside the selected chunk, then visit that chunk.
183 auto chunk = fChunks.at(fLastChunkPerSlot[slot]);
184 assert(slot < fArrayVisitorPerSlot.size());
185 fArrayVisitorPerSlot[slot].SetEntry(entry - fFirstEntryPerChunk[fLastChunkPerSlot[slot]]);
186 auto status = chunk->Accept(fArrayVisitorPerSlot.data() + slot);
188 std::string msg =
"Could not get pointer for slot ";
189 msg += std::to_string(slot) +
" looking at entry " + std::to_string(entry);
190 throw std::runtime_error(msg);
/// Move slot `slot` to global entry `entry`, performing a full lookup only when `entry` differs
/// from the value recorded in fLastEntryPerSlot[slot].
// The body of the `if` is not visible here; presumably it returns early.
// NOTE(review): fLastEntryPerSlot starts at 0 for every slot (constructor), so a first request
// for entry 0 takes the early path. That relies on an earlier UncachedSlotLookup (presumably
// from InitSlot) having already set the pointer. Confirm.
195 void SetEntry(
unsigned int slot,
ULong64_t entry)
198 if (fLastEntryPerSlot[slot] == entry) {
201 UncachedSlotLookup(slot, entry);
213class RDFTypeNameGetter :
public ::arrow::TypeVisitor {
215 std::string fTypeName;
218 arrow::Status Visit(
const arrow::Int64Type &)
override
220 fTypeName =
"Long64_t";
221 return arrow::Status::OK();
223 arrow::Status Visit(
const arrow::Int32Type &)
override
225 fTypeName =
"Long_t";
226 return arrow::Status::OK();
228 arrow::Status Visit(
const arrow::UInt64Type &)
override
230 fTypeName =
"ULong64_t";
231 return arrow::Status::OK();
233 arrow::Status Visit(
const arrow::UInt32Type &)
override
235 fTypeName =
"ULong_t";
236 return arrow::Status::OK();
238 arrow::Status Visit(
const arrow::FloatType &)
override
241 return arrow::Status::OK();
243 arrow::Status Visit(
const arrow::DoubleType &)
override
245 fTypeName =
"double";
246 return arrow::Status::OK();
248 arrow::Status Visit(
const arrow::StringType &)
override
250 fTypeName =
"string";
251 return arrow::Status::OK();
253 arrow::Status Visit(
const arrow::BooleanType &)
override
256 return arrow::Status::OK();
258 std::string result() {
return fTypeName; }
260 using ::arrow::TypeVisitor::Visit;
264class VerifyValidColumnType :
public ::arrow::TypeVisitor {
267 virtual arrow::Status Visit(
const arrow::Int64Type &)
override {
return arrow::Status::OK(); }
268 virtual arrow::Status Visit(
const arrow::UInt64Type &)
override {
return arrow::Status::OK(); }
269 virtual arrow::Status Visit(
const arrow::Int32Type &)
override {
return arrow::Status::OK(); }
270 virtual arrow::Status Visit(
const arrow::UInt32Type &)
override {
return arrow::Status::OK(); }
271 virtual arrow::Status Visit(
const arrow::FloatType &)
override {
return arrow::Status::OK(); }
272 virtual arrow::Status Visit(
const arrow::DoubleType &)
override {
return arrow::Status::OK(); }
273 virtual arrow::Status Visit(
const arrow::StringType &)
override {
return arrow::Status::OK(); }
274 virtual arrow::Status Visit(
const arrow::BooleanType &)
override {
return arrow::Status::OK(); }
276 using ::arrow::TypeVisitor::Visit;
/// RArrowDS constructor (the signature line is not visible in this listing): stores the table
/// and the requested column names, then validates every selected column.
/// Per the helpers below it throws std::runtime_error when:
/// - no column is available;
/// - a column's length differs from the first column's;
/// - VerifyValidColumnType rejects a column's type.
// NOTE(review): `columnNames`, `table`, `index` and `columnName` are declared on lines not shown
// here; presumably they alias fColumnNames, fTable and fGetterIndex, and the loop variable.
// Confirm.
285 : fTable{inTable}, fColumnNames{inColumns}
// With no explicit column selection, select every field of the table's schema.
292 auto filterWantedColumns = [&columnNames, &
table]() {
293 if (columnNames.empty()) {
294 for (
auto &field :
table->schema()->fields()) {
295 columnNames.push_back(field->name());
// Type of a column handle as returned by fTable->column().
301 using ColumnType =
decltype(
fTable->column(0));
// Length of the first selected column; every other column must match it.
// NOTE(review): GetFieldIndex() can return -1 for an unknown name, and the index reaches
// table->column() unchecked here.
303 auto getRecordsFirstColumn = [&columnNames, &
table]() {
304 if (columnNames.empty()) {
305 throw std::runtime_error(
"At least one column required");
307 const auto name = columnNames.front();
308 const auto columnIdx =
table->schema()->GetFieldIndex(
name);
309 return table->column(columnIdx)->length();
// Throw if `column` does not have exactly nRecords entries.
// NOTE(review): nRecords is an int while column->length() is presumably a 64-bit count, so
// tables with more than INT_MAX rows would be truncated. Confirm.
313 auto verifyColumnSize = [&
table](ColumnType column,
int columnIdx,
int nRecords) {
314 if (column->length() != nRecords) {
315 std::string msg =
"Column ";
316 msg +=
table->schema()->field(columnIdx)->name() +
" has a different number of entries.";
317 throw std::runtime_error(msg);
// Throw if VerifyValidColumnType does not accept the column's Arrow type.
322 auto verifyColumnType = [&
table](ColumnType column,
int columnIdx) {
323 auto verifyType = std::make_unique<VerifyValidColumnType>();
324 auto result = column->type()->Accept(verifyType.get());
325 if (result.ok() ==
false) {
326 std::string msg =
"Column ";
327 msg +=
table->schema()->field(columnIdx)->name() +
" contains an unsupported type.";
328 throw std::runtime_error(msg);
// Record the pair (schema column index, insertion position) for each accepted column.
334 auto addColumnToGetterIndex = [&index](
int columnId) { index.push_back(std::make_pair(columnId, index.size())); };
// Clears the getter index; its use is not visible in this listing.
338 auto resetGetterIndex = [&index]() { index.clear(); };
341 filterWantedColumns();
343 auto nRecords = getRecordsFirstColumn();
// For each selected column (the loop header is not visible): register it in the getter index,
// then check its length and type.
// NOTE(review): as above, a -1 from GetFieldIndex() is not rejected before use.
345 auto columnIdx =
fTable->schema()->GetFieldIndex(columnName);
346 addColumnToGetterIndex(columnIdx);
348 auto column =
fTable->column(columnIdx);
349 verifyColumnSize(column, columnIdx, nRecords);
350 verifyColumnType(column, columnIdx);
/// Presumably the body of GetTypeName(colName): look up `colName` in the schema and map its
/// Arrow type to a type-name string with RDFTypeNameGetter.
/// @throws std::runtime_error if the column is missing (the condition is not visible here) or if
///         RDFTypeNameGetter does not handle its type
373 auto field =
fTable->schema()->GetFieldByName(std::string(colName));
375 std::string msg =
"The dataset does not have column ";
377 throw std::runtime_error(msg);
379 RDFTypeNameGetter typeGetter;
380 auto status = field->type()->Accept(&typeGetter);
// A non-OK status means the type reached an inherited TypeVisitor overload, i.e. it is unsupported.
381 if (status.ok() ==
false) {
382 std::string msg =
"RArrowDS does not support a column of type ";
383 msg += field->type()->name();
384 throw std::runtime_error(msg);
386 return typeGetter.result();
// Schema lookup of `colName`; presumably part of HasColumn(). How `field` is turned into the
// result is not visible in this listing.
391 auto field =
fTable->schema()->GetFieldByName(std::string(colName));
// Presumably inside RArrowDS::SetEntry(slot, entry); `link` and `getter` are defined on lines
// not shown. The getter skips the lookup when `entry` matches its cached entry for the slot.
// NOTE(review): `column` is not used in the visible lines.
401 auto column =
fTable->column(link.first);
403 getter->SetEntry(slot, entry);
// Presumably inside RArrowDS::InitSlot(slot, entry): forces a full lookup for the slot's first
// entry, bypassing the getter's cached-entry check.
// NOTE(review): `column` is not used in the visible lines.
411 auto column =
fTable->column(link.first);
413 getter->UncachedSlotLookup(slot, entry);
/// Specialization of getData for a std::shared_ptr<arrow::ChunkedArray> argument; its body is
/// not visible in this listing.
424std::shared_ptr<arrow::ChunkedArray>
425getData<std::shared_ptr<arrow::ChunkedArray>>(std::shared_ptr<arrow::ChunkedArray> p)
// The number of slots may only be set while fNSlots is still 0.
432 assert(0U ==
fNSlots &&
"Setting the number of slots even if the number of slots is different from zero.");
// One TValueGetter per column, sized for nSlots and holding that column's chunks
// (`nColumns` and `chunkedArray` are obtained on lines not shown here).
442 for (
size_t ci = 0; ci != nColumns; ++ci) {
444 fValueGetters.emplace_back(std::make_unique<ROOT::Internal::RDF::TValueGetter>(nSlots, chunkedArray->chunks()));
// Split nRecords entries into newNSlots ranges of nRecords / newNSlots entries each and add the
// remainder to the last range. The loop computing `start`/`end` is not visible here.
// NOTE(review): newNSlots == 0 would divide by zero; nothing visible here rules it out.
448 auto splitInEqualRanges = [&outNSlots, &ranges](
int nRecords,
unsigned int newNSlots) {
450 outNSlots = newNSlots;
451 const auto chunkSize = nRecords / outNSlots;
// A single slot takes every entry, so there is no remainder to redistribute.
452 const auto remainder = 1U == outNSlots ? 0 : nRecords % outNSlots;
458 ranges.emplace_back(start, end);
461 ranges.back().second += remainder;
// Entry count of the first selected column.
// NOTE(review): the length is returned as int (narrowing for very large tables), and the
// GetFieldIndex() result is used without a -1 check.
464 auto getNRecords = [&
table, &columnNames]() ->
int {
465 auto index =
table->schema()->GetFieldIndex(columnNames.front());
466 return table->column(index)->length();
469 auto nRecords = getNRecords();
470 splitInEqualRanges(nRecords, nSlots);
// Look up a schema column index among the (column index, getter position) pairs in `index`.
// The return inside the loop is not visible here; when no pair matches, this throws.
478 auto findGetterIndex = [&index](
unsigned int column) {
479 for (
auto &entry : index) {
480 if (entry.first == column) {
484 throw std::runtime_error(
"No column found at index " + std::to_string(column));
// NOTE(review): a -1 from GetFieldIndex() is converted to a large unsigned value when passed to
// findGetterIndex, which then throws. The assert below therefore does not catch a missing
// column. Confirm this is intended.
487 const int columnIdx =
fTable->schema()->GetFieldIndex(std::string(colName));
488 const int getterIdx = findGetterIndex(columnIdx);
489 assert(getterIdx != -1);
unsigned long long ULong64_t
typedef void((*Func_t)())
bool HasColumn(std::string_view colName) const override
Checks if the dataset has a certain column.
RArrowDS(std::shared_ptr< arrow::Table > table, std::vector< std::string > const &columns)
Constructor to create an Arrow RDataSource for RDataFrame.
void Initialise() override
Convenience method called before starting an event-loop.
std::string GetLabel() override
Return a string representation of the datasource type.
void SetNSlots(unsigned int nSlots) override
Inform RDataSource of the number of processing slots (i.e.
const std::vector< std::string > & GetColumnNames() const override
Returns a reference to the collection of the dataset's column names.
void InitSlot(unsigned int slot, ULong64_t firstEntry) override
Convenience method called at the start of the data processing associated to a slot.
std::vector< std::pair< ULong64_t, ULong64_t > > GetEntryRanges() override
Return ranges of entries to distribute to tasks.
std::shared_ptr< arrow::Table > fTable
std::vector< std::pair< size_t, size_t > > fGetterIndex
std::vector< std::unique_ptr< ROOT::Internal::RDF::TValueGetter > > fValueGetters
std::vector< void * > GetColumnReadersImpl(std::string_view name, const std::type_info &type) override
This needs to return a pointer to the pointer each value getter will point to.
std::vector< std::string > fColumnNames
std::string GetTypeName(std::string_view colName) const override
Type of a column as a string, e.g.
std::vector< std::pair< ULong64_t, ULong64_t > > fEntryRanges
bool SetEntry(unsigned int slot, ULong64_t entry) override
Advance the "cursors" returned by GetColumnReaders to the selected entry for a particular slot.
ROOT's RDataFrame offers a high level interface for analyses of data stored in TTrees,...
A pseudo container class which is a generator of indices.
std::shared_ptr< arrow::ChunkedArray > getData(T p)
RDataFrame MakeArrowDataFrame(std::shared_ptr< arrow::Table > table, std::vector< std::string > const &columns)
Factory method to create an Apache Arrow RDataFrame.
Namespace for new ROOT classes and functions.
@ array
array (ordered collection of values)
basic_string_view< char > string_view