Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RVariedAction.hxx
Go to the documentation of this file.
1// Author: Enrico Guiraud, CERN 11/2021
2
3/*************************************************************************
4 * Copyright (C) 1995-2022, Rene Brun and Fons Rademakers. *
5 * All rights reserved. *
6 * *
7 * For the licensing terms see $ROOTSYS/LICENSE. *
8 * For the list of contributors see $ROOTSYS/README/CREDITS. *
9 *************************************************************************/
10
11#ifndef ROOT_RVARIEDACTION
12#define ROOT_RVARIEDACTION
13
14#include "ColumnReaderUtils.hxx"
15#include "GraphNode.hxx"
16#include "RActionBase.hxx"
17#include "RColumnReaderBase.hxx"
18#include "RLoopManager.hxx"
19#include "RJittedFilter.hxx"
22
23#include <Rtypes.h> // R__CLING_PTRCHECK
24#include <ROOT/TypeTraits.hxx>
25
26#include <algorithm>
27#include <array>
28#include <memory>
29#include <utility> // make_index_sequence
30#include <vector>
31
32namespace ROOT {
33namespace Internal {
34namespace RDF {
35
37
38/// Just like an RAction, but it has N action helpers and N previous nodes (N is the number of variations).
39template <typename Helper, typename PrevNode, typename ColumnTypes_t>
40class R__CLING_PTRCHECK(off) RVariedAction final : public RActionBase {
41 using TypeInd_t = std::make_index_sequence<ColumnTypes_t::list_size>;
42 // If the PrevNode is a RJittedFilter, our collection of previous nodes will have to use the RNodeBase type:
43 // we'll have a RJittedFilter for the nominal case, but the others will be concrete filters.
44 using PrevNodeType = std::conditional_t<std::is_same<PrevNode, RJittedFilter>::value, RFilterBase, PrevNode>;
45
46 std::vector<Helper> fHelpers; ///< Action helpers per variation.
47 /// Owning pointers to upstream nodes for each systematic variation.
48 std::vector<std::shared_ptr<PrevNodeType>> fPrevNodes;
49
50 /// Column readers per slot (outer dimension), per variation and per input column (inner dimension, std::array).
51 std::vector<std::vector<std::array<RColumnReaderBase *, ColumnTypes_t::list_size>>> fInputValues;
52
53 /// The nth flag signals whether the nth input column is a custom column or not.
54 std::array<bool, ColumnTypes_t::list_size> fIsDefine;
55
56 /// \brief Creates new filter nodes, one per variation, from the upstream nominal one.
57 /// \param nominal The nominal filter
58 /// \return The varied filters
59 ///
60 /// The nominal filter is not included in the return value.
61 std::vector<std::shared_ptr<PrevNodeType>> MakePrevFilters(std::shared_ptr<PrevNode> nominal) const
62 {
63 const auto &variations = GetVariations();
64 std::vector<std::shared_ptr<PrevNodeType>> prevFilters;
65 prevFilters.reserve(variations.size());
66 if (static_cast<RNodeBase *>(nominal.get()) == fLoopManager) {
67 // just fill this with the RLoopManager N times
68 prevFilters.resize(variations.size(), nominal);
69 } else {
70 // create varied versions of the previous filter node
71 const auto &prevVariations = nominal->GetVariations();
72 for (const auto &variation : variations) {
73 if (IsStrInVec(variation, prevVariations)) {
74 prevFilters.emplace_back(std::static_pointer_cast<PrevNodeType>(nominal->GetVariedFilter(variation)));
75 } else {
76 prevFilters.emplace_back(nominal);
77 }
78 }
79 }
80
81 return prevFilters;
82 }
83
85 {
86 // The column register and names are private members of RActionBase
87 const auto &colRegister = GetColRegister();
88 const auto &columnNames = GetColumnNames();
89
90 fLoopManager->Register(this);
91
92 for (auto i = 0u; i < columnNames.size(); ++i) {
93 auto *define = colRegister.GetDefine(columnNames[i]);
94 fIsDefine[i] = define != nullptr;
95 if (fIsDefine[i])
96 define->MakeVariations(GetVariations());
97 }
98 }
99
100 /// This constructor takes in input a vector of previous nodes, motivated by the CloneAction logic.
101 RVariedAction(std::vector<Helper> &&helpers, const ColumnNames_t &columns,
102 const std::vector<std::shared_ptr<PrevNodeType>> &prevNodes, const RColumnRegister &colRegister)
103 : RActionBase(prevNodes[0]->GetLoopManagerUnchecked(), columns, colRegister, prevNodes[0]->GetVariations()),
104 fHelpers(std::move(helpers)),
105 fPrevNodes(prevNodes),
106 fInputValues(GetNSlots())
107 {
108 SetupClass();
109 }
110
111public:
112 RVariedAction(std::vector<Helper> &&helpers, const ColumnNames_t &columns, std::shared_ptr<PrevNode> prevNode,
113 const RColumnRegister &colRegister)
114 : RActionBase(prevNode->GetLoopManagerUnchecked(), columns, colRegister, prevNode->GetVariations()),
115 fHelpers(std::move(helpers)),
116 fPrevNodes(MakePrevFilters(prevNode)),
117 fInputValues(GetNSlots())
118 {
119 SetupClass();
120 }
121
122 RVariedAction(const RVariedAction &) = delete;
124
125 ~RVariedAction() { fLoopManager->Deregister(this); }
126
127 void Initialize() final
128 {
129 std::for_each(fHelpers.begin(), fHelpers.end(), [](Helper &h) { h.Initialize(); });
130 }
131
132 void InitSlot(TTreeReader *r, unsigned int slot) final
133 {
134 RColumnReadersInfo info{GetColumnNames(), GetColRegister(), fIsDefine.data(), *fLoopManager};
135
136 // get readers for each systematic variation
137 for (const auto &variation : GetVariations())
138 fInputValues[slot].emplace_back(GetColumnReaders(slot, r, ColumnTypes_t{}, info, variation));
139
140 std::for_each(fHelpers.begin(), fHelpers.end(), [=](Helper &h) { h.InitTask(r, slot); });
141 }
142
143 template <typename ColType>
144 auto GetValueChecked(unsigned int slot, unsigned int varIdx, std::size_t readerIdx, Long64_t entry) -> ColType &
145 {
146 if (auto *val = fInputValues[slot][varIdx][readerIdx]->template TryGet<ColType>(entry))
147 return *val;
148
149 throw std::out_of_range{"RDataFrame: Varied action (" + fHelpers[0].GetActionName() +
150 ") could not retrieve value for column '" + fColumnNames[readerIdx] + "' for entry " +
151 std::to_string(entry) +
152 ". You can use the DefaultValueFor operation to provide a default value, or "
153 "FilterAvailable/FilterMissing to discard/keep entries with missing values instead."};
154 }
155
156 template <typename... ColTypes, std::size_t... ReaderIdxs>
157 void CallExec(unsigned int slot, unsigned int varIdx, Long64_t entry, TypeList<ColTypes...>,
158 std::index_sequence<ReaderIdxs...>)
159 {
160 fHelpers[varIdx].Exec(slot, GetValueChecked<ColTypes>(slot, varIdx, ReaderIdxs, entry)...);
161 (void)entry;
162 }
163
164 void Run(unsigned int slot, Long64_t entry) final
165 {
166 for (auto varIdx = 0u; varIdx < GetVariations().size(); ++varIdx) {
167 if (fPrevNodes[varIdx]->CheckFilters(slot, entry))
168 CallExec(slot, varIdx, entry, ColumnTypes_t{}, TypeInd_t{});
169 }
170 }
171
173 {
174 std::for_each(fPrevNodes.begin(), fPrevNodes.end(), [](auto &f) { f->IncrChildrenCount(); });
175 }
176
177 /// Clean-up operations to be performed at the end of a task.
178 void FinalizeSlot(unsigned int slot) final
179 {
180 fInputValues[slot].clear();
181 std::for_each(fHelpers.begin(), fHelpers.end(), [=](Helper &h) { h.CallFinalizeTask(slot); });
182 }
183
184 /// Clean-up and finalize the action result (e.g. merging slot-local results).
185 /// It invokes the helper's Finalize method.
186 void Finalize() final
187 {
188 std::for_each(fHelpers.begin(), fHelpers.end(), [](Helper &h) { h.Finalize(); });
189 SetHasRun();
190 }
191
192 /// Return the partially-updated value connected to the first variation.
193 void *PartialUpdate(unsigned int slot) final { return PartialUpdateImpl(slot); }
194
195 /// Return a callback that in turn runs the callbacks of each variation's helper.
197 {
198 if (fHelpers[0].GetSampleCallback()) {
199 std::vector<ROOT::RDF::SampleCallback_t> callbacks;
200 for (auto &h : fHelpers)
201 callbacks.push_back(h.GetSampleCallback());
202
203 auto callEachCallback = [cs = std::move(callbacks)](unsigned int slot, const RSampleInfo &info) {
204 for (auto &c : cs)
205 c(slot, info);
206 };
207
208 return callEachCallback;
209 }
210
211 return {};
212 }
213
214 std::shared_ptr<RDFGraphDrawing::GraphNode>
215 GetGraph(std::unordered_map<void *, std::shared_ptr<RDFGraphDrawing::GraphNode>> &visitedMap) final
216 {
217 auto prevNode = fPrevNodes[0]->GetGraph(visitedMap);
218 const auto &prevColumns = prevNode->GetDefinedColumns();
219
220 // Action nodes do not need to go through CreateFilterNode: they are never common nodes between multiple branches
221 const auto nodeType = HasRun() ? RDFGraphDrawing::ENodeType::kUsedAction : RDFGraphDrawing::ENodeType::kAction;
222 auto thisNode = std::make_shared<RDFGraphDrawing::GraphNode>("Varied " + fHelpers[0].GetActionName(),
223 visitedMap.size(), nodeType);
224 visitedMap[(void *)this] = thisNode;
225
226 auto upmostNode = AddDefinesToGraph(thisNode, GetColRegister(), prevColumns, visitedMap);
227
228 thisNode->AddDefinedColumns(GetColRegister().GenerateColumnNames());
229 upmostNode->SetPrevNode(prevNode);
230 return thisNode;
231 }
232
233 /**
234 Retrieve a container holding the names and values of the variations. It
235 knows how to merge with others of the same type.
236 */
237 std::unique_ptr<RMergeableValueBase> GetMergeableValue() const final
238 {
239 std::vector<std::string> keys{GetVariations()};
240
241 std::vector<std::unique_ptr<RDFDetail::RMergeableValueBase>> values;
242 values.reserve(fHelpers.size());
243 for (auto &&h : fHelpers)
244 values.emplace_back(h.GetMergeableValue());
245
246 return std::make_unique<RDFDetail::RMergeableVariationsBase>(std::move(keys), std::move(values));
247 }
248
249 [[noreturn]] std::unique_ptr<RActionBase> MakeVariedAction(std::vector<void *> &&) final
250 {
251 throw std::logic_error("Cannot produce a varied action from a varied action.");
252 }
253
254 std::unique_ptr<RActionBase> CloneAction(void *typeErasedResults) final
255 {
256 const auto &vectorOfTypeErasedResults = *reinterpret_cast<const std::vector<void *> *>(typeErasedResults);
257 assert(vectorOfTypeErasedResults.size() == fHelpers.size() &&
258 "The number of results and the number of helpers are not the same!");
259
260 std::vector<Helper> clonedHelpers;
261 clonedHelpers.reserve(fHelpers.size());
262 for (std::size_t i = 0; i < fHelpers.size(); i++) {
263 clonedHelpers.emplace_back(fHelpers[i].CallMakeNew(vectorOfTypeErasedResults[i]));
264 }
265
266 return std::unique_ptr<RVariedAction>(
267 new RVariedAction(std::move(clonedHelpers), GetColumnNames(), fPrevNodes, GetColRegister()));
268 }
269
270private:
271 // this overload is SFINAE'd out if Helper does not implement `PartialUpdate`
272 // the template parameter is required to defer instantiation of the method to SFINAE time
273 template <typename H = Helper>
274 auto PartialUpdateImpl(unsigned int slot) -> decltype(std::declval<H>().PartialUpdate(slot), (void *)(nullptr))
275 {
276 return &fHelpers[0].PartialUpdate(slot);
277 }
278
279 // this one is always available but has lower precedence thanks to `...`
280 void *PartialUpdateImpl(...) { throw std::runtime_error("This action does not support callbacks!"); }
281};
282
283} // namespace RDF
284} // namespace Internal
285} // namespace ROOT
286
287#endif // ROOT_RVARIEDACTION
#define f(i)
Definition RSha256.hxx:104
#define c(i)
Definition RSha256.hxx:101
#define h(i)
Definition RSha256.hxx:106
long long Long64_t
Definition RtypesCore.h:69
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Base class for non-leaf nodes of the computational graph.
Definition RNodeBase.hxx:43
A binder for user-defined columns, variations and aliases.
Just like an RAction, but it has N action helpers and N previous nodes (N is the number of variations).
void Finalize() final
Clean-up and finalize the action result (e.g. merging slot-local results).
auto GetValueChecked(unsigned int slot, unsigned int varIdx, std::size_t readerIdx, Long64_t entry) -> ColType &
void Run(unsigned int slot, Long64_t entry) final
std::vector< std::shared_ptr< PrevNodeType > > fPrevNodes
Owning pointers to upstream nodes for each systematic variation.
std::unique_ptr< RMergeableValueBase > GetMergeableValue() const final
Retrieve a container holding the names and values of the variations.
RVariedAction(std::vector< Helper > &&helpers, const ColumnNames_t &columns, std::shared_ptr< PrevNode > prevNode, const RColumnRegister &colRegister)
auto PartialUpdateImpl(unsigned int slot) -> decltype(std::declval< H >().PartialUpdate(slot),(void *)(nullptr))
std::unique_ptr< RActionBase > CloneAction(void *typeErasedResults) final
ROOT::RDF::SampleCallback_t GetSampleCallback() final
Return a callback that in turn runs the callbacks of each variation's helper.
void FinalizeSlot(unsigned int slot) final
Clean-up operations to be performed at the end of a task.
std::vector< std::shared_ptr< PrevNodeType > > MakePrevFilters(std::shared_ptr< PrevNode > nominal) const
Creates new filter nodes, one per variation, from the upstream nominal one.
std::make_index_sequence< ColumnTypes_t::list_size > TypeInd_t
std::vector< std::vector< std::array< RColumnReaderBase *, ColumnTypes_t::list_size > > > fInputValues
Column readers per slot (outer dimension), per variation and per input column (inner dimension, std::array).
std::vector< Helper > fHelpers
Action helpers per variation.
RVariedAction & operator=(const RVariedAction &)=delete
void InitSlot(TTreeReader *r, unsigned int slot) final
std::conditional_t< std::is_same< PrevNode, RJittedFilter >::value, RFilterBase, PrevNode > PrevNodeType
std::shared_ptr< RDFGraphDrawing::GraphNode > GetGraph(std::unordered_map< void *, std::shared_ptr< RDFGraphDrawing::GraphNode > > &visitedMap) final
std::array< bool, ColumnTypes_t::list_size > fIsDefine
The nth flag signals whether the nth input column is a custom column or not.
void * PartialUpdate(unsigned int slot) final
Return the partially-updated value connected to the first variation.
std::unique_ptr< RActionBase > MakeVariedAction(std::vector< void * > &&) final
RVariedAction(const RVariedAction &)=delete
void CallExec(unsigned int slot, unsigned int varIdx, Long64_t entry, TypeList< ColTypes... >, std::index_sequence< ReaderIdxs... >)
RVariedAction(std::vector< Helper > &&helpers, const ColumnNames_t &columns, const std::vector< std::shared_ptr< PrevNodeType > > &prevNodes, const RColumnRegister &colRegister)
This constructor takes in input a vector of previous nodes, motivated by the CloneAction logic.
This type represents a sample identifier, to be used in conjunction with RDataFrame features such as ...
A simple, robust and fast interface to read values from ROOT columnar datasets such as TTree,...
Definition TTreeReader.h:46
unsigned int GetNSlots()
Definition RDFUtils.cxx:301
bool IsStrInVec(const std::string &str, const std::vector< std::string > &vec)
Definition RDFUtils.cxx:439
std::array< RDFDetail::RColumnReaderBase *, sizeof...(ColTypes)> GetColumnReaders(unsigned int slot, TTreeReader *r, TypeList< ColTypes... >, const RColumnReadersInfo &colInfo, const std::string &variationName="nominal")
Create a group of column readers, one per type in the parameter pack.
std::vector< std::string > ColumnNames_t
std::function< void(unsigned int, const ROOT::RDF::RSampleInfo &)> SampleCallback_t
The type of a data-block callback, registered with an RDataFrame computation graph via e....
tbb::task_arena is an alias of tbb::interface7::task_arena, which doesn't allow to forward declare tb...
This type aggregates some of the arguments passed to GetColumnReaders.
Lightweight storage for a collection of types.