masahi commented on code in PR #12171:
URL: https://github.com/apache/tvm/pull/12171#discussion_r930523406
##########
src/tir/transforms/inject_software_pipeline.cc:
##########
@@ -494,6 +512,267 @@ class PipelineRewriter : public StmtExprMutator {
return Buffer(new_buffer);
}
+ // Per-stage states that need to be tracked across pipeline prologue, body,
and epilogue.
+ struct AsyncStateGlobal {
+ // Buffers that this stage asynchronously writes.
+ std::unordered_set<const BufferNode*> dst_buffers;
+ // An imaginary index that the latest async operation associated with this
stage has written
+ // into. Only valid if all associated predicates are true, so that we can
count the number of
+ // async invocations exactly. When it is valid, it is the "sum of extents
of loops that have
+ // been executed" - 1, e.g. for epilogue it is prologue extent + body
extent - 1. This
+ // is only needed to compute wait count for epilogue without async
producers.
+ Optional<PrimExpr> producer_head{PrimExpr(-1)};
+
+ bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
+ };
+
+ // Per-stage states that are local to each of pipeline prologue, body, and
epilogue.
+ struct AsyncStateLocal {
+ struct {
+ // The index into a list of blocks, where async_wait_queue should be
attached at the
+ // beginning.
+ int insert_before;
+ // in_flight_count would be a more precise name, but the implementation
uses wait_count for
+ // brevity.
+ PrimExpr wait_count{nullptr};
+
+ bool valid() const { return wait_count.defined(); }
+ } pending_wait;
+
+ // Destination buffers of async operations that have been encountered so
far in the loop
+ //
+ // for (size_t i = 0; i < new_blocks.size(); ++i) {
+ // ...
+ // }
+ //
+ // This is for tracking which async operations have been issued at the
"current" iteration, up
+ // until a point where we encounter a consumer of async result buffers.
This is used to decide
+ // if the producer_head of each buffer points to a copy written in the
current or previous
+ // iteration.
+ std::unordered_set<const BufferNode*> seen;
+
+ // A symbolic expression representing the index the latest async operation
associated with this
+ // stage has written into, at the "current" iteration.
+ Optional<PrimExpr> producer_head;
+ // The predicate of BlockRealize containing the async operation of this
stage.
+ Optional<PrimExpr> predicate;
+ // Indices into a list of blocks, where async_commit_queue scope should be
attached.
+ // If multiple async producers are interleaved with their consumer in
between, we need separate
+ // async_commit_queue for each producer. Thus, we need multiple sets of
indices.
+ std::vector<std::vector<size_t>> commit_groups;
+
+ // This is set to true when we reach a stage that consumes this async
stage.
+ bool consumed{false};
+ };
+
+ /*! Structure holding intermediate information for pipeline loop rewriting.
*/
+ struct RewrittenBlockInfo {
+ int stage;
+ PrimExpr predicate;
+ Block block;
+ PrimExpr access_index;
+ bool is_async;
+ };
+
+ // Determine where to insert async_wait and the corresponding wait count.
+ void PopulateWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks,
+ arith::Analyzer* ana_normalized,
+ const std::unordered_map<const BufferNode*, int>&
buffer_to_commit_group,
+ std::map<int, AsyncStateLocal>* async_states_local) {
+ for (size_t i = 0; i < new_blocks.size(); ++i) {
+ if (new_blocks[i].is_async) {
+ // Record the fact that we have encountered these write buffers.
+ for (auto write_region : new_blocks[i].block->writes) {
+
(*async_states_local)[new_blocks[i].stage].seen.insert(write_region->buffer.get());
+ }
+ }
+
+ int producer_stage_idx = -1;
+ for (auto read_region : new_blocks[i].block->reads) {
+ for (auto kv : async_states) {
+ if (kv.first <= new_blocks[i].stage &&
kv.second.writes(read_region->buffer)) {
+ // Found an earlier stage where read_region->buffer was
asynchronously written
+ ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first)
+ << "A dependency on multiple async stages is not supported";
+ producer_stage_idx = kv.first;
+ }
+ }
+ }
+
+ if (producer_stage_idx == -1) continue;
+
+ // The following logic has become complicated to handle case like this:
+ //
+ // for i in range(13):
+ // # Stage 0
+ // async_commit_queue(0):
+ // async_scope:
+ // A_shared[(i + 3) % 4] = A[...]
+ //
+ //
+ // # Stage 1
+ // async_wait_queue(0, 5):
+ // compute(A_shared[i], B_shared[i])
+ //
+ // # Stage 0
+ // async_commit_queue(0)
+ // async_scope:
+ // B_shared[(i + 3) % 4] = B[...]
+ //
+ //
+ // Here, multiple async producers in the same stage are interleaved with
their consumer in
+ // between. Since each buffer is associated with different commit
groups, the wait_count
+ // before the consumer should be bigger than the simpler case:
+ //
+ // for i in range(13):
+ // # Stage 0
+ // async_commit_queue(0):
+ // async_scope:
+ // A_shared[(i + 3) % 4] = A[...]
+ // B_shared[(i + 3) % 4] = B[...]
+ //
+ // # Stage 1
+ // async_wait_queue(0, 3):
+ // compute(A_shared[i], B_shared[i])
+ //
+ // The correct wait_count can be determined by considering each commit
group separately, and
+ // summing "per-commit" wait_counts.
+ //
+ // From A_shared's perspective, it allows for (i + 3) - i async commit
groups to be in
+ // flight while from B_shared's perspective, the producer head at
compute points to the copy
+ // done by the previous iteration, so its wait_count is calculated as
((i - 1) + 3) - i. The
+ // sum of the two wait_counts gives 5.
+
+ auto& dep_local_state = (*async_states_local)[producer_stage_idx];
+ const auto num_commit_group = dep_local_state.commit_groups.size();
+ std::vector<Optional<PrimExpr>> producer_head_per_commit;
+
+ if (num_commit_group == 0) {
+ // Epilogue, no async producer. Since "local" producer_head is not
available, use
+ // "global" producer_head.
+ ICHECK(!dep_local_state.producer_head);
+
producer_head_per_commit.push_back(async_states[producer_stage_idx].producer_head);
+ } else {
+ ICHECK(dep_local_state.producer_head);
+ std::vector<bool> need_wait_count(num_commit_group, true);
+
+ for (auto read_region : new_blocks[i].block->reads) {
+ if (!async_states[producer_stage_idx].writes(read_region->buffer))
continue;
+ auto commit_group_id =
buffer_to_commit_group.at(read_region->buffer.get());
+ if (!need_wait_count[commit_group_id]) continue;
+
+ if (!dep_local_state.seen.count(read_region->buffer.get())) {
+ // Multiple async producers interleaved: The most recent async
write is from the
+ // previous iteration. This is the B_shared case above.
+
producer_head_per_commit.push_back(dep_local_state.producer_head.value() - 1);
+ } else {
+ // Normal case
+
producer_head_per_commit.push_back(dep_local_state.producer_head.value());
+ }
+
+ need_wait_count[commit_group_id] = false;
+ }
+ }
+
+ auto wait_count = [=, &ana_normalized]() {
+ auto sum = PrimExpr(0);
+ for (auto producer_head : producer_head_per_commit) {
+ if (producer_head && ana_normalized->CanProve(producer_head.value()
>= 0)) {
+ // Here, new_blocks[i].access_index corresponds to "consumer_head".
+ // The difference of producer_head and consumer_head is precisely
the number of
+ // async commit groups that can still be in flight after this wait.
+ sum += analyzer_.Simplify(producer_head.value() -
new_blocks[i].access_index);
+ } else {
+ // The precise count cannot be determined, give up.
+ return PrimExpr(0);
+ }
+ }
+ return sum;
+ }();
+
+ auto& pending_wait = dep_local_state.pending_wait;
+
+ if (!pending_wait.valid()) {
+ pending_wait = {static_cast<int>(i), wait_count};
+ } else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) {
+ // Coalesce multiple wait_queue if the later one allows fewer
in-flight ops.
+ pending_wait = {pending_wait.insert_before, wait_count};
+ }
+ }
+ }
+
+ // Given pipelined blocks and async-related information, generate final loop
statements with async
+ // scopes (if any).
+ Array<Stmt> CompletePipelineLoopStatements(
Review Comment:
I'm not entirely happy with the choice of this name; a suggestion for a
better one is welcome.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]