39 #pragma GCC diagnostic push 40 #pragma GCC diagnostic ignored "-Wshadow" 42 #include <arrow/table.h> 44 #pragma GCC diagnostic pop 52 class ArrayPtrVisitor :
public ::arrow::ArrayVisitor {
56 bool fCachedBool{
false};
57 std::string fCachedString;
62 ArrayPtrVisitor(
void **result) : fResult{result}, fCurrentEntry{0} {}
64 void SetEntry(
ULong64_t entry) { fCurrentEntry = entry; }
67 virtual arrow::Status Visit(arrow::Int32Array
const &array)
final 69 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
70 return arrow::Status::OK();
73 virtual arrow::Status Visit(arrow::Int64Array
const &array)
final 75 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
76 return arrow::Status::OK();
80 virtual arrow::Status Visit(arrow::UInt32Array
const &array)
final 82 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
83 return arrow::Status::OK();
86 virtual arrow::Status Visit(arrow::UInt64Array
const &array)
final 88 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
89 return arrow::Status::OK();
92 virtual arrow::Status Visit(arrow::FloatArray
const &array)
final 94 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
95 return arrow::Status::OK();
98 virtual arrow::Status Visit(arrow::DoubleArray
const &array)
final 100 *fResult = (
void *)(
array.raw_values() + fCurrentEntry);
101 return arrow::Status::OK();
104 virtual arrow::Status Visit(arrow::BooleanArray
const &array)
final 106 fCachedBool =
array.Value(fCurrentEntry);
107 *fResult =
reinterpret_cast<void *
>(&fCachedBool);
108 return arrow::Status::OK();
111 virtual arrow::Status Visit(arrow::StringArray
const &array)
final 113 fCachedString =
array.GetString(fCurrentEntry);
114 *fResult =
reinterpret_cast<void *
>(&fCachedString);
115 return arrow::Status::OK();
118 using ::arrow::ArrayVisitor::Visit;
124 std::vector<void *> fValuesPtrPerSlot;
125 std::vector<ULong64_t> fLastEntryPerSlot;
126 std::vector<ULong64_t> fLastChunkPerSlot;
127 std::vector<ULong64_t> fFirstEntryPerChunk;
128 std::vector<ArrayPtrVisitor> fArrayVisitorPerSlot;
132 std::vector<ULong64_t> fChunkIndex;
133 arrow::ArrayVector fChunks;
136 TValueGetter(
size_t slots, arrow::ArrayVector chunks)
137 : fValuesPtrPerSlot(slots, nullptr), fLastEntryPerSlot(slots, 0), fLastChunkPerSlot(slots, 0), fChunks{chunks}
139 fChunkIndex.reserve(fChunks.size());
141 for (
auto &chunk : chunks) {
142 fFirstEntryPerChunk.push_back(next);
143 next += chunk->length();
144 fChunkIndex.push_back(next);
146 for (
size_t si = 0, se = fValuesPtrPerSlot.size(); si != se; ++si) {
147 fArrayVisitorPerSlot.push_back(ArrayPtrVisitor{fValuesPtrPerSlot.data() + si});
152 std::vector<void *> SlotPtrs()
154 std::vector<void *> result;
155 for (
size_t i = 0; i < fValuesPtrPerSlot.size(); ++i) {
156 result.push_back(fValuesPtrPerSlot.data() + i);
163 void UncachedSlotLookup(
unsigned int slot,
ULong64_t entry)
169 assert(slot < fLastChunkPerSlot.size());
170 if (fLastEntryPerSlot[slot] < entry) {
171 ci = fLastChunkPerSlot.at(slot);
174 for (
size_t ce = fChunkIndex.size(); ci != ce; ++ci) {
175 if (entry < fChunkIndex[ci]) {
176 assert(slot < fLastChunkPerSlot.size());
177 fLastChunkPerSlot[slot] = ci;
184 auto chunk = fChunks.at(fLastChunkPerSlot[slot]);
185 assert(slot < fArrayVisitorPerSlot.size());
186 fArrayVisitorPerSlot[slot].SetEntry(entry - fFirstEntryPerChunk[fLastChunkPerSlot[slot]]);
187 auto status = chunk->Accept(fArrayVisitorPerSlot.data() + slot);
189 std::string msg =
"Could not get pointer for slot ";
190 msg += std::to_string(slot) +
" looking at entry " + std::to_string(entry);
191 throw std::runtime_error(msg);
196 void SetEntry(
unsigned int slot,
ULong64_t entry)
199 if (fLastEntryPerSlot[slot] == entry) {
202 UncachedSlotLookup(slot, entry);
215 class RDFTypeNameGetter :
public ::arrow::TypeVisitor {
217 std::string fTypeName;
220 arrow::Status Visit(
const arrow::Int64Type &)
override 222 fTypeName =
"Long64_t";
223 return arrow::Status::OK();
225 arrow::Status Visit(
const arrow::Int32Type &)
override 227 fTypeName =
"Long_t";
228 return arrow::Status::OK();
230 arrow::Status Visit(
const arrow::UInt64Type &)
override 232 fTypeName =
"ULong64_t";
233 return arrow::Status::OK();
235 arrow::Status Visit(
const arrow::UInt32Type &)
override 237 fTypeName =
"ULong_t";
238 return arrow::Status::OK();
240 arrow::Status Visit(
const arrow::FloatType &)
override 243 return arrow::Status::OK();
245 arrow::Status Visit(
const arrow::DoubleType &)
override 247 fTypeName =
"double";
248 return arrow::Status::OK();
250 arrow::Status Visit(
const arrow::StringType &)
override 252 fTypeName =
"string";
253 return arrow::Status::OK();
255 arrow::Status Visit(
const arrow::BooleanType &)
override 258 return arrow::Status::OK();
260 std::string result() {
return fTypeName; }
262 using ::arrow::TypeVisitor::Visit;
266 class VerifyValidColumnType :
public ::arrow::TypeVisitor {
269 virtual arrow::Status Visit(
const arrow::Int64Type &)
override {
return arrow::Status::OK(); }
270 virtual arrow::Status Visit(
const arrow::UInt64Type &)
override {
return arrow::Status::OK(); }
271 virtual arrow::Status Visit(
const arrow::Int32Type &)
override {
return arrow::Status::OK(); }
272 virtual arrow::Status Visit(
const arrow::UInt32Type &)
override {
return arrow::Status::OK(); }
273 virtual arrow::Status Visit(
const arrow::FloatType &)
override {
return arrow::Status::OK(); }
274 virtual arrow::Status Visit(
const arrow::DoubleType &)
override {
return arrow::Status::OK(); }
275 virtual arrow::Status Visit(
const arrow::StringType &)
override {
return arrow::Status::OK(); }
276 virtual arrow::Status Visit(
const arrow::BooleanType &)
override {
return arrow::Status::OK(); }
278 using ::arrow::TypeVisitor::Visit;
297 auto filterWantedColumns = [&columnNames, &table]()
299 if (columnNames.empty()) {
300 for (
auto &field : table->schema()->fields()) {
301 columnNames.push_back(field->name());
306 auto getRecordsFirstColumn = [&columnNames, &table]()
308 if (columnNames.empty()) {
309 throw std::runtime_error(
"At least one column required");
311 const auto name = columnNames.front();
312 const auto columnIdx = table->schema()->GetFieldIndex(
name);
313 return table->column(columnIdx)->length();
317 auto verifyColumnSize = [](std::shared_ptr<arrow::Column> column,
int nRecords)
319 if (column->length() != nRecords) {
320 std::string msg =
"Column ";
321 msg += column->name() +
" has a different number of entries.";
322 throw std::runtime_error(msg);
327 auto verifyColumnType = [](std::shared_ptr<arrow::Column> column) {
328 auto verifyType = std::make_unique<VerifyValidColumnType>();
329 auto result = column->type()->Accept(verifyType.get());
330 if (result.ok() ==
false) {
331 std::string msg =
"Column ";
332 msg += column->name() +
" contains an unsupported type.";
333 throw std::runtime_error(msg);
339 auto addColumnToGetterIndex = [&index](
int columnId)
341 index.push_back(std::make_pair(columnId, index.size()));
346 auto resetGetterIndex = [&index]() { index.clear(); };
// Constructor body (partial: several source lines are not visible in this
// view). It selects the columns, then registers and validates each of them.
349 filterWantedColumns();
// Entry count of the first selected column; every column must match it.
351 auto nRecords = getRecordsFirstColumn();
// For each selected column (the enclosing loop header is not visible here):
// map its schema index to the next getter position, then check its size and type.
// NOTE(review): the result of GetFieldIndex is not checked for -1 (unknown
// name) before fTable->column(columnIdx) — confirm names are validated earlier.
353 auto columnIdx =
fTable->schema()->GetFieldIndex(columnName);
354 addColumnToGetterIndex(columnIdx);
356 auto column =
fTable->column(columnIdx);
357 verifyColumnSize(column, nRecords);
358 verifyColumnType(column);
// GetTypeName (partial): look the column up by name in the table schema.
381 auto field =
fTable->schema()->GetFieldByName(std::string(colName));
// Error path for a column missing from the schema (the guarding condition
// and the line appending the name are not visible here).
383 std::string msg =
"The dataset does not have column ";
385 throw std::runtime_error(msg);
// Translate the Arrow type into the C++ type name RDataFrame should use.
// A non-OK status means RDFTypeNameGetter has no mapping for this type.
387 RDFTypeNameGetter typeGetter;
388 auto status = field->type()->Accept(&typeGetter);
389 if (status.ok() ==
false) {
390 std::string msg =
"RArrowDS does not support a column of type ";
391 msg += field->type()->name();
392 throw std::runtime_error(msg);
394 return typeGetter.result();
// Presumably part of HasColumn: look the name up in the table schema. The
// rest of the function (how `field` is tested) is not visible here.
399 auto field =
fTable->schema()->GetFieldByName(std::string(colName));
// Presumably part of RArrowDS::SetEntry, inside a loop over the
// (schema column index, getter index) pairs of fGetterIndex — confirm.
// NOTE(review): `column` is not used in the visible lines.
409 auto column =
fTable->column(link.first);
// TValueGetter::SetEntry returns early when `entry` equals the slot's
// recorded last entry; otherwise it performs the full lookup.
411 getter->SetEntry(slot, entry);
// Presumably part of InitSlot: same per-column walk as SetEntry, but using
// the uncached lookup so each slot pointer is always recomputed for `entry`.
419 auto column =
fTable->column(link.first);
421 getter->UncachedSlotLookup(slot, entry);
// SetNSlots (partial): expected to be called only while fNSlots is still 0.
427 assert(0U ==
fNSlots &&
"Setting the number of slots even if the number of slots is different from zero.");
// One TValueGetter per column, sized for nSlots slots and built from the
// column's chunks (chunkedArray is defined on a line not visible here).
437 for (
size_t ci = 0; ci != nColumns; ++ci) {
439 fValueGetters.emplace_back(std::make_unique<ROOT::Internal::RDF::TValueGetter>(nSlots, chunkedArray->chunks()));
// Split the entries into newNSlots ranges. Each range presumably holds
// chunkSize = nRecords / newNSlots entries (the loop building them is not
// visible). The division remainder is appended to the last range, and is 0
// when there is a single slot.
// NOTE(review): `nRecords / outNSlots` mixes int and unsigned, so the
// arithmetic is carried out in unsigned — correct only for nRecords >= 0.
443 auto splitInEqualRanges = [&outNSlots, &ranges](
int nRecords,
unsigned int newNSlots)
446 outNSlots = newNSlots;
447 const auto chunkSize = nRecords / outNSlots;
448 const auto remainder = 1U == outNSlots ? 0 : nRecords % outNSlots;
454 ranges.emplace_back(start, end);
457 ranges.back().second += remainder;
// Entry count of the first selected column.
// NOTE(review): declared to return `int`, which narrows the 64-bit Arrow
// column length; tables above INT_MAX entries would be truncated.
460 auto getNRecords = [&table, &columnNames]()->
int 462 auto index = table->schema()->GetFieldIndex(columnNames.front());
463 return table->column(index)->length();
466 auto nRecords = getNRecords();
467 splitInEqualRanges(nRecords, nSlots);
// GetColumnReadersImpl (partial): find which getter serves the requested
// column. findGetterIndex scans `index` (presumably fGetterIndex) linearly
// and throws when no entry matches (its match branch is not fully visible).
475 auto findGetterIndex = [&index](
unsigned int column)
477 for (
auto &entry : index) {
478 if (entry.first == column) {
482 throw std::runtime_error(
"No column found at index " + std::to_string(column));
// NOTE(review): for an unknown name GetFieldIndex yields -1, which becomes
// UINT_MAX in the unsigned parameter and so ends in the throw above. The -1
// assert below therefore looks unreachable through this path — confirm.
485 const int columnIdx =
fTable->schema()->GetFieldIndex(std::string(colName));
486 const int getterIdx = findGetterIndex(columnIdx);
487 assert(getterIdx != -1);
Namespace for new ROOT classes and functions.
bool HasColumn(std::string_view colName) const override
Checks if the dataset has a certain column.
array (ordered collection of values)
RDataFrame MakeArrowDataFrame(std::shared_ptr< arrow::Table > table, std::vector< std::string > const &columns)
Factory method to create an Apache Arrow RDataFrame.
RArrowDS(std::shared_ptr< arrow::Table > table, std::vector< std::string > const &columns)
Constructor to create an Arrow RDataSource for RDataFrame.
std::string GetTypeName(std::string_view colName) const override
Type of a column as a string, e.g. "double" for a column of doubles.
bool SetEntry(unsigned int slot, ULong64_t entry) override
Advance the "cursors" returned by GetColumnReaders to the selected entry for a particular slot...
std::vector< void * > GetColumnReadersImpl(std::string_view name, const std::type_info &type) override
This needs to return a pointer to the pointer each value getter will point to.
std::vector< std::pair< ULong64_t, ULong64_t > > fEntryRanges
const std::vector< std::string > & GetColumnNames() const override
Returns a reference to the collection of the dataset's column names.
std::vector< std::unique_ptr< ROOT::Internal::RDF::TValueGetter > > fValueGetters
void InitSlot(unsigned int slot, ULong64_t firstEntry) override
Convenience method called at the start of the data processing associated to a slot.
void SetNSlots(unsigned int nSlots) override
Inform RDataSource of the number of processing slots (i.e. worker threads) used by the associated RDataFrame.
ROOT's RDataFrame offers a high level interface for analyses of data stored in TTrees, CSV files and other data formats.
A pseudo container class which is a generator of indices.
unsigned long long ULong64_t
basic_string_view< char > string_view
std::shared_ptr< arrow::Table > fTable
typedef void((*Func_t)())
void Initialise() override
Convenience method called before starting an event-loop.
std::vector< std::string > fColumnNames
std::vector< std::pair< ULong64_t, ULong64_t > > GetEntryRanges() override
Return ranges of entries to distribute to tasks.
std::vector< std::pair< size_t, size_t > > fGetterIndex