40#pragma GCC diagnostic push
41#pragma GCC diagnostic ignored "-Wshadow"
42#pragma GCC diagnostic ignored "-Wunused-parameter"
44#include <arrow/table.h>
47#pragma GCC diagnostic pop
// Trait that maps a C++ element type to its Arrow type. The primary definition
// is empty; the usable mappings come from the specializations produced by the
// macro below.
// NOTE(review): the `template` header lines are not visible in this listing.
60struct RootConversionTraits {};
// Defines the RootConversionTraits<c_type> specialization, exposing `ArrowType`
// as an alias for ::arrow::ArrowType_.
62#define ROOT_ARROW_STL_CONVERSION(c_type, ArrowType_) \
64 struct RootConversionTraits<c_type> { \
65 using ArrowType = ::arrow::ArrowType_; \
// Arrow array visitor that publishes, through a type-erased `void *` result
// location, a pointer to the value found at the current entry of the visited
// array.
// NOTE(review): the declarations of fResult and fCurrentEntry (both initialised
// by the constructor) are not visible in this listing.
82class ArrayPtrVisitor :
public ::arrow::ArrayVisitor {
// Copy of the last boolean read; Visit(BooleanArray) publishes its address.
86 bool fCachedBool{
false};
// One cached RVec per supported list element type; Visit(ListArray) refills the
// matching one and publishes its address.
88 RVec<float> fCachedRVecFloat;
89 RVec<double> fCachedRVecDouble;
90 RVec<ULong64_t> fCachedRVecULong64;
91 RVec<UInt_t> fCachedRVecUInt;
92 RVec<Long64_t> fCachedRVecLong64;
93 RVec<Int_t> fCachedRVecInt;
// Copy of the last string read; Visit(StringArray) publishes its address.
94 std::string fCachedString;
// Builds an RVec<T> from the elements of list entry `entry`, swaps it into
// `cache` and returns the address of `cache` as a type-erased pointer.
// NOTE(review): the `template` header introducing T is not visible here.
99 void *getTypeErasedPtrFrom(arrow::ListArray
const &array, int32_t entry, RVec<T> &cache)
// T selects the Arrow type, which in turn selects the concrete array class.
101 using ArrowType =
typename RootConversionTraits<T>::ArrowType;
102 using ArrayType =
typename arrow::TypeTraits<ArrowType>::ArrayType;
// The list's child array is reinterpreted as that concrete array class.
103 auto values =
reinterpret_cast<ArrayType *
>(array.values().get());
104 auto offset = array.value_offset(entry);
// The RVec is built from the raw value pointer shifted by `offset` and the
// number of elements reported for this entry.
107 RVec<T> tmp(
reinterpret_cast<T *
>((
void *)values->raw_values()) + offset, array.value_length(entry));
108 std::swap(cache, tmp);
109 return (
void *)(&cache);
// Remembers where results must be written and starts at entry 0.
113 ArrayPtrVisitor(
void **result) : fResult{result}, fCurrentEntry{0} {}
// Selects the entry that subsequent Visit calls will read.
115 void SetEntry(
ULong64_t entry) { fCurrentEntry = entry; }
// The six overloads below (Int32, Int64, UInt32, UInt64, Float, Double) all do
// the same thing: publish the address of element fCurrentEntry inside the
// array's raw value buffer and report success.
118 virtual arrow::Status Visit(arrow::Int32Array
const &array)
final
120 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
121 return arrow::Status::OK();
124 virtual arrow::Status Visit(arrow::Int64Array
const &array)
final
126 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
127 return arrow::Status::OK();
131 virtual arrow::Status Visit(arrow::UInt32Array
const &array)
final
133 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
134 return arrow::Status::OK();
137 virtual arrow::Status Visit(arrow::UInt64Array
const &array)
final
139 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
140 return arrow::Status::OK();
143 virtual arrow::Status Visit(arrow::FloatArray
const &array)
final
145 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
146 return arrow::Status::OK();
149 virtual arrow::Status Visit(arrow::DoubleArray
const &array)
final
151 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
152 return arrow::Status::OK();
// Booleans: the value is copied into fCachedBool and the address of that copy
// is published, rather than a pointer into the array.
155 virtual arrow::Status Visit(arrow::BooleanArray
const &array)
final
157 fCachedBool = array.Value(fCurrentEntry);
158 *fResult =
reinterpret_cast<void *
>(&fCachedBool);
159 return arrow::Status::OK();
// Strings: the value is copied into fCachedString and the address of that copy
// is published.
162 virtual arrow::Status Visit(arrow::StringArray
const &array)
final
164 fCachedString = array.GetString(fCurrentEntry);
165 *fResult =
reinterpret_cast<void *
>(&fCachedString);
166 return arrow::Status::OK();
// Lists: dispatch on the element type id; each supported type refills its own
// cached RVec via getTypeErasedPtrFrom. Any other element type is reported as
// a TypeError status instead of OK.
169 virtual arrow::Status Visit(arrow::ListArray
const &array)
final
171 switch (array.value_type()->id()) {
172 case arrow::Type::FLOAT: {
173 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecFloat);
174 return arrow::Status::OK();
176 case arrow::Type::DOUBLE: {
177 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecDouble);
178 return arrow::Status::OK();
180 case arrow::Type::UINT32: {
181 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecUInt);
182 return arrow::Status::OK();
184 case arrow::Type::UINT64: {
185 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecULong64);
186 return arrow::Status::OK();
188 case arrow::Type::INT32: {
189 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecInt);
190 return arrow::Status::OK();
192 case arrow::Type::INT64: {
193 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecLong64);
194 return arrow::Status::OK();
196 default:
return arrow::Status::TypeError(
"Type not supported");
// Keep the base-class Visit overloads visible next to the ones defined above.
200 using ::arrow::ArrayVisitor::Visit;
// Data members of TValueGetter (the class header is not visible in this listing).
// Per-slot type-erased value pointer; each slot's visitor is constructed with
// the address of its element in this vector.
206 std::vector<void *> fValuesPtrPerSlot;
// Per-slot entry number recorded by the most recent UncachedSlotLookup.
207 std::vector<ULong64_t> fLastEntryPerSlot;
// Per-slot index of the chunk selected by the most recent UncachedSlotLookup.
208 std::vector<ULong64_t> fLastChunkPerSlot;
// For each chunk, the running entry count before that chunk was added.
209 std::vector<ULong64_t> fFirstEntryPerChunk;
// One ArrayPtrVisitor per slot.
210 std::vector<ArrayPtrVisitor> fArrayVisitorPerSlot;
// For each chunk, the running entry count after that chunk was added; an entry
// belongs to the first chunk whose value here is greater than the entry.
214 std::vector<ULong64_t> fChunkIndex;
// The chunks of the column served by this getter.
215 arrow::ArrayVector fChunks;
// Creates a getter for `slots` slots over the given chunks: sizes the per-slot
// vectors, records the entry boundaries of every chunk and creates one visitor
// per slot.
// NOTE(review): the declaration of `next` is not visible in this listing;
// presumably a running entry count starting at 0 - confirm.
218 TValueGetter(
size_t slots, arrow::ArrayVector chunks)
219 : fValuesPtrPerSlot(slots, nullptr), fLastEntryPerSlot(slots, 0), fLastChunkPerSlot(slots, 0), fChunks{chunks}
221 fChunkIndex.reserve(fChunks.size());
// Record, per chunk, the running count before (first entry) and after
// (one past the last entry) adding the chunk's length.
223 for (
auto &chunk : chunks) {
224 fFirstEntryPerChunk.push_back(next);
225 next += chunk->length();
226 fChunkIndex.push_back(next);
// Each visitor writes into its own element of fValuesPtrPerSlot; that vector
// was sized in the initialiser list, so the element addresses taken here are
// those of the final storage.
228 for (
size_t si = 0, se = fValuesPtrPerSlot.size(); si != se; ++si) {
229 fArrayVisitorPerSlot.push_back(ArrayPtrVisitor{fValuesPtrPerSlot.data() + si});
// Collects, for every slot, the address of that slot's element in
// fValuesPtrPerSlot (a pointer to the per-slot value pointer, stored as void *).
// NOTE(review): the return statement is not visible in this listing;
// presumably `result` is returned - confirm.
234 std::vector<void *> SlotPtrs()
236 std::vector<void *> result;
237 for (
size_t i = 0; i < fValuesPtrPerSlot.size(); ++i) {
238 result.push_back(fValuesPtrPerSlot.data() + i);
// Finds the chunk containing `entry`, positions the visitor of `slot` on the
// chunk-relative entry and lets the chunk accept that visitor, which updates
// the slot's value pointer. Throws std::runtime_error naming the slot and the
// entry.
// NOTE(review): the declaration of `ci` and the condition guarding the throw
// are not visible in this listing; presumably the throw happens only when
// `status` is not OK - confirm.
245 void UncachedSlotLookup(
unsigned int slot,
ULong64_t entry)
251 assert(slot < fLastChunkPerSlot.size());
// When the requested entry is beyond the last one looked up for this slot,
// the scan resumes from the chunk selected last time.
252 if (fLastEntryPerSlot[slot] < entry) {
253 ci = fLastChunkPerSlot.at(slot);
// The first chunk whose cumulative end is greater than `entry` is recorded as
// the slot's current chunk.
256 for (
size_t ce = fChunkIndex.size(); ci != ce; ++ci) {
257 if (entry < fChunkIndex[ci]) {
258 assert(slot < fLastChunkPerSlot.size());
259 fLastChunkPerSlot[slot] = ci;
266 auto chunk = fChunks.at(fLastChunkPerSlot[slot]);
267 assert(slot < fArrayVisitorPerSlot.size());
// The visitor works with chunk-relative entries, hence the subtraction of the
// chunk's first entry.
268 fArrayVisitorPerSlot[slot].SetEntry(entry - fFirstEntryPerChunk[fLastChunkPerSlot[slot]]);
269 fLastEntryPerSlot[slot] = entry;
270 auto status = chunk->Accept(fArrayVisitorPerSlot.data() + slot);
272 std::string msg =
"Could not get pointer for slot ";
273 msg += std::to_string(slot) +
" looking at entry " + std::to_string(entry);
274 throw std::runtime_error(msg);
// Positions `slot` on `entry` by delegating to UncachedSlotLookup. Before that
// it compares `entry` with the last entry recorded for the slot.
// NOTE(review): the body of the `if` is not visible in this listing;
// presumably an early return that skips the lookup - confirm.
279 void SetEntry(
unsigned int slot,
ULong64_t entry)
282 if (fLastEntryPerSlot[slot] == entry) {
285 UncachedSlotLookup(slot, entry);
// Arrow type visitor used by RArrowDS::GetTypeName, which accepts it on a
// field's type and then returns the visitor's result().
// NOTE(review): in this listing each scalar overload only shows its
// `return arrow::Status::OK()`; the statements that record the type name, the
// declaration of fTypeName and the header of result() are not visible.
297class RDFTypeNameGetter :
public ::arrow::TypeVisitor {
302 arrow::Status Visit(
const arrow::Int64Type &)
override
305 return arrow::Status::OK();
307 arrow::Status Visit(
const arrow::Int32Type &)
override
310 return arrow::Status::OK();
312 arrow::Status Visit(
const arrow::UInt64Type &)
override
315 return arrow::Status::OK();
317 arrow::Status Visit(
const arrow::UInt32Type &)
override
320 return arrow::Status::OK();
322 arrow::Status Visit(
const arrow::FloatType &)
override
325 return arrow::Status::OK();
327 arrow::Status Visit(
const arrow::DoubleType &)
override
330 return arrow::Status::OK();
332 arrow::Status Visit(
const arrow::StringType &)
override
335 return arrow::Status::OK();
337 arrow::Status Visit(
const arrow::BooleanType &)
override
340 return arrow::Status::OK();
// Lists: append the RVec pattern (with a %s placeholder) to fTypeName, then
// visit the element type and return the status of that nested visit.
342 arrow::Status Visit(
const arrow::ListType &
l)
override
348 fTypeName.push_back(
"ROOT::VecOps::RVec<%s>");
349 return l.value_type()->Accept(
this);
// Starts from the "%s" pattern and walks over every fTypeName element.
// NOTE(review): the loop body is not visible; presumably each element is
// substituted into the placeholder - confirm.
354 std::string result =
"%s";
356 for (
size_t i = 0; i <
fTypeName.size(); ++i) {
// Keep the base-class Visit overloads visible next to the ones defined above.
363 using ::arrow::TypeVisitor::Visit;
// Arrow type visitor used by the RArrowDS constructor to validate column
// types: every overload listed here returns OK, and the constructor throws
// when Accept() on a column's type yields a non-OK status.
// NOTE(review): types without an override here reach the base-class Visit
// re-exported at the end; presumably that returns a non-OK status - confirm
// against the Arrow TypeVisitor documentation.
367class VerifyValidColumnType :
public ::arrow::TypeVisitor {
370 virtual arrow::Status Visit(
const arrow::Int64Type &)
override {
return arrow::Status::OK(); }
371 virtual arrow::Status Visit(
const arrow::UInt64Type &)
override {
return arrow::Status::OK(); }
372 virtual arrow::Status Visit(
const arrow::Int32Type &)
override {
return arrow::Status::OK(); }
373 virtual arrow::Status Visit(
const arrow::UInt32Type &)
override {
return arrow::Status::OK(); }
374 virtual arrow::Status Visit(
const arrow::FloatType &)
override {
return arrow::Status::OK(); }
375 virtual arrow::Status Visit(
const arrow::DoubleType &)
override {
return arrow::Status::OK(); }
376 virtual arrow::Status Visit(
const arrow::StringType &)
override {
return arrow::Status::OK(); }
377 virtual arrow::Status Visit(
const arrow::BooleanType &)
override {
return arrow::Status::OK(); }
// Lists are accepted without inspecting their element type.
378 virtual arrow::Status Visit(
const arrow::ListType &)
override {
return arrow::Status::OK(); }
// Keep the base-class Visit overloads visible next to the ones defined above.
380 using ::arrow::TypeVisitor::Visit;
// Constructs the data source from an Arrow table and a list of column names.
// The work is organised as local lambdas that are then run in sequence: pick
// the columns, take the record count of the first one, and for every column
// register it in the getter index and check its length and type.
// Throws std::runtime_error when no column is available, when a column has a
// different length, or when a column has an unsupported type.
// NOTE(review): the declarations of `columnNames`, `table` and `index`, and the
// header of the loop that introduces `columnName`, are not visible in this
// listing.
388RArrowDS::RArrowDS(std::shared_ptr<arrow::Table> inTable, std::vector<std::string>
const &inColumns)
389 : fTable{inTable}, fColumnNames{inColumns}
// With no column names, every field of the table schema is used.
396 auto filterWantedColumns = [&columnNames, &table]() {
397 if (columnNames.empty()) {
398 for (
auto &field : table->schema()->fields()) {
399 columnNames.push_back(field->name());
405 using ColumnType =
decltype(
fTable->column(0));
// Returns the length of the first named column; throws when there is none.
407 auto getRecordsFirstColumn = [&columnNames, &table]() {
408 if (columnNames.empty()) {
409 throw std::runtime_error(
"At least one column required");
411 const auto name = columnNames.front();
412 const auto columnIdx = table->schema()->GetFieldIndex(
name);
413 return table->column(columnIdx)->length();
// Throws when the column's length differs from `nRecords`.
// NOTE(review): `nRecords` is an int here while it is fed from
// getRecordsFirstColumn(); check whether length() is wider than int.
417 auto verifyColumnSize = [&table](ColumnType column,
int columnIdx,
int nRecords) {
418 if (column->length() != nRecords) {
419 std::string msg =
"Column ";
420 msg += table->schema()->field(columnIdx)->name() +
" has a different number of entries.";
421 throw std::runtime_error(msg);
// Throws when the VerifyValidColumnType visitor does not return OK for the
// column's type.
426 auto verifyColumnType = [&table](ColumnType column,
int columnIdx) {
427 auto verifyType = std::make_unique<VerifyValidColumnType>();
428 auto result = column->type()->Accept(verifyType.get());
429 if (result.ok() ==
false) {
430 std::string msg =
"Column ";
431 msg += table->schema()->field(columnIdx)->name() +
" contains an unsupported type.";
432 throw std::runtime_error(msg);
// Appends the pair (table column id, position of the new element in `index`).
438 auto addColumnToGetterIndex = [&index](
int columnId) { index.push_back(std::make_pair(columnId, index.size())); };
// Empties the getter index.
// NOTE(review): the call site of resetGetterIndex is not visible here.
442 auto resetGetterIndex = [&index]() { index.clear(); };
445 filterWantedColumns();
447 auto nRecords = getRecordsFirstColumn();
// For each column: register it in the index, then check length and type.
449 auto columnIdx =
fTable->schema()->GetFieldIndex(columnName);
450 addColumnToGetterIndex(columnIdx);
452 auto column =
fTable->column(columnIdx);
453 verifyColumnSize(column, columnIdx, nRecords);
454 verifyColumnType(column, columnIdx);
// Const accessor returning a reference to a vector of column-name strings.
// NOTE(review): the body is not visible in this listing; presumably it returns
// fColumnNames - confirm.
464const std::vector<std::string> &RArrowDS::GetColumnNames()
const
// Returns a vector of (ULong64_t, ULong64_t) pairs describing entry ranges.
// NOTE(review): the body is not visible in this listing; presumably it hands
// out fEntryRanges - confirm.
469std::vector<std::pair<ULong64_t, ULong64_t>> RArrowDS::GetEntryRanges()
// Returns the type name of column `colName`, as produced by running an
// RDFTypeNameGetter over the type of the matching schema field.
// Throws std::runtime_error when the visitor reports a non-OK status
// ("RArrowDS does not support a column of type ...").
// NOTE(review): a second throw ("The dataset does not have column ...") is
// visible, but its guarding condition and the rest of its message are not;
// presumably it fires when the field lookup fails - confirm.
475std::string RArrowDS::GetTypeName(std::string_view colName)
const
477 auto field =
fTable->schema()->GetFieldByName(std::string(colName));
479 std::string msg =
"The dataset does not have column ";
481 throw std::runtime_error(msg);
483 RDFTypeNameGetter typeGetter;
484 auto status = field->type()->Accept(&typeGetter);
485 if (status.ok() ==
false) {
486 std::string msg =
"RArrowDS does not support a column of type ";
487 msg += field->type()->name();
488 throw std::runtime_error(msg);
490 return typeGetter.result();
// Looks up `colName` among the fields of the table schema.
// NOTE(review): the statements that turn `field` into the boolean result are
// not visible in this listing; presumably a null check - confirm.
493bool RArrowDS::HasColumn(std::string_view colName)
const
495 auto field =
fTable->schema()->GetFieldByName(std::string(colName));
// Forwards (slot, entry) to a value getter's SetEntry.
// NOTE(review): the loop that introduces `getter` and the returned value are
// not visible in this listing; presumably every element of fValueGetters is
// visited - confirm.
502bool RArrowDS::SetEntry(
unsigned int slot,
ULong64_t entry)
506 getter->SetEntry(slot, entry);
// Forwards (slot, entry) to a value getter's UncachedSlotLookup, i.e. the
// lookup is performed without the same-entry comparison done by SetEntry.
// NOTE(review): the loop that introduces `getter` is not visible in this
// listing; presumably every element of fValueGetters is visited - confirm.
511void RArrowDS::InitSlot(
unsigned int slot,
ULong64_t entry)
515 getter->UncachedSlotLookup(slot, entry);
// Fills `ranges` with (start, end) pairs of size nRecords / nSlots and adds
// the division remainder to the end of the last range.
// NOTE(review): the declarations of `start` and `end` and the loop around
// emplace_back are not visible in this listing.
// NOTE(review): nSlots == 0 makes the division below divide by zero, and
// ranges.back() requires a non-empty vector; no guard is visible here.
// NOTE(review): nRecords is an int and nSlots is unsigned, so the division and
// modulo are evaluated in unsigned arithmetic; a negative nRecords would be
// converted first.
519void splitInEqualRanges(std::vector<std::pair<ULong64_t, ULong64_t>> &ranges,
int nRecords,
unsigned int nSlots)
522 const auto chunkSize = nRecords / nSlots;
// With a single slot the remainder is forced to 0.
523 const auto remainder = 1U == nSlots ? 0 : nRecords % nSlots;
529 ranges.emplace_back(start, end);
// The leftover records are assigned to the last range.
532 ranges.back().second += remainder;
// Returns the length of the table column named by the first element of
// `columnNames`.
// NOTE(review): columnNames.front() requires a non-empty vector and no guard is
// visible here; the result is also returned as int - check whether length()
// is wider than int.
535int getNRecords(std::shared_ptr<arrow::Table> &table, std::vector<std::string> &columnNames)
537 auto index = table->schema()->GetFieldIndex(columnNames.front());
538 return table->column(index)->length();
// Generic overload: takes a value of type T and returns a shared pointer to an
// arrow::ChunkedArray.
// NOTE(review): the `template` header and the body are not visible in this
// listing.
542std::shared_ptr<arrow::ChunkedArray>
getData(T p)
// Specialization of getData for an argument that already is a
// std::shared_ptr<arrow::ChunkedArray>.
// NOTE(review): the `template <>` header and the body are not visible in this
// listing; presumably the argument is returned unchanged - confirm.
548std::shared_ptr<arrow::ChunkedArray>
549getData<std::shared_ptr<arrow::ChunkedArray>>(std::shared_ptr<arrow::ChunkedArray> p)
// Sets the number of slots and creates one TValueGetter per column, each built
// from `nSlots` and the chunks of that column's chunked array.
// The assert documents that this is expected to be called while fNSlots is
// still zero.
// NOTE(review): the definitions of `nColumns` and `chunkedArray`, and any
// assignment to fNSlots, are not visible in this listing.
554void RArrowDS::SetNSlots(
unsigned int nSlots)
556 assert(0U ==
fNSlots &&
"Setting the number of slots even if the number of slots is different from zero.");
562 for (
size_t ci = 0; ci != nColumns; ++ci) {
564 fValueGetters.emplace_back(std::make_unique<ROOT::Internal::RDF::TValueGetter>(nSlots, chunkedArray->chunks()));
// Returns the per-slot reader pointers for column `colName`; the type_info
// argument is unnamed and therefore unused.
// The column's schema index is translated into a getter index through the
// local findGetterIndex lambda.
// NOTE(review): the definition of `index`, the value returned by the lambda on
// a match, and the final return statement are not visible in this listing.
570std::vector<void *> RArrowDS::GetColumnReadersImpl(std::string_view colName,
const std::type_info &)
// Scans `index` for the pair whose first member equals `column`; throws
// std::runtime_error when none matches.
573 auto findGetterIndex = [&index](
unsigned int column) {
574 for (
auto &entry : index) {
575 if (entry.first == column) {
579 throw std::runtime_error(
"No column found at index " + std::to_string(column));
// NOTE(review): columnIdx is an int but the lambda takes an unsigned int, so
// a negative index would be converted before the search.
582 const int columnIdx =
fTable->schema()->GetFieldIndex(std::string(colName));
583 const int getterIdx = findGetterIndex(columnIdx);
// NOTE(review): the lambda throws when nothing matches, so this check looks
// redundant unless a non-visible path returns -1 - confirm.
584 assert(getterIdx != -1);
// Initialisation hook of the data source.
// NOTE(review): the body is not visible in this listing.
589void RArrowDS::Initialise()
// Returns a string label for this data source.
// NOTE(review): the body is not visible in this listing.
595std::string RArrowDS::GetLabel()
#define ROOT_ARROW_STL_CONVERSION(c_type, ArrowType_)
unsigned long long ULong64_t
typedef void((*Func_t)())
std::shared_ptr< arrow::Table > fTable
std::vector< std::pair< size_t, size_t > > fGetterIndex
std::vector< std::unique_ptr< ROOT::Internal::RDF::TValueGetter > > fValueGetters
std::vector< std::string > fColumnNames
std::vector< std::pair< ULong64_t, ULong64_t > > fEntryRanges
ROOT's RDataFrame offers a high level interface for analyses of data stored in TTrees,...
void splitInEqualRanges(std::vector< std::pair< ULong64_t, ULong64_t > > &ranges, int nRecords, unsigned int nSlots)
int getNRecords(std::shared_ptr< arrow::Table > &table, std::vector< std::string > &columnNames)
std::shared_ptr< arrow::ChunkedArray > getData(T p)
RDataFrame MakeArrowDataFrame(std::shared_ptr< arrow::Table > table, std::vector< std::string > const &columns)
Factory method to create an Apache Arrow RDataFrame.
tbb::task_arena is an alias of tbb::interface7::task_arena, which doesn't allow to forward declare tb...
TSeq< unsigned int > TSeqU