Lunderberg commented on a change in pull request #9727:
URL: https://github.com/apache/tvm/pull/9727#discussion_r777575815
##########
File path: src/tir/transforms/inject_virtual_thread.cc
##########
@@ -354,46 +391,44 @@ class VTInjector : public StmtExprMutator {
}
// Allocate
Stmt VisitStmt_(const AllocateNode* op) final {
+ Allocate node = GetRef<Allocate>(op);
+
PrimExpr condition = this->VisitExpr(op->condition);
+
+ Array<PrimExpr> extents = op->extents;
+    extents.MutateByApply([this](const PrimExpr& extent) { return this->VisitExpr(extent); });
+
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
}
- bool changed = false;
- Array<PrimExpr> extents;
- for (size_t i = 0; i < op->extents.size(); i++) {
- PrimExpr new_ext = this->VisitExpr(op->extents[i]);
- if (visit_touched_var_ && !vt_loop_injected_) {
- return InjectVTLoop(GetRef<Stmt>(op), true);
- }
- if (!new_ext.same_as(op->extents[i])) changed = true;
- extents.push_back(new_ext);
- }
visit_touched_var_ = false;
- Stmt body;
- // always rewrite if not allow sharing.
+ // Rewrite the buffer if its shape or any value stored in it
+ // depends on the virtual thread var. If `allow_share_` is false,
+ // then the buffer is always rewritten, even if separate virtual
+ // threads only read from the buffer.
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place v on highest dimension.
-    PrimExpr stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
- make_const(DataType::Int(32), 1), op->extents) *
- op->dtype.lanes();
- Array<PrimExpr> other;
- other.push_back(make_const(op->extents[0].dtype(), num_threads_));
- for (PrimExpr e : extents) {
- other.push_back(e);
- }
- extents = other;
- changed = true;
- // mark this buffer get touched.
+
+ // TODO(Lunderberg): Move pass to apply before
Review comment:
Can do. I've collected these TODOs into a list of passes that could be
simplified by operating on pre-flattened buffers, along with a high-level
description of the changes that would be needed for each.
* InjectVirtualThread
Currently operates on a flattened buffer. If each virtual thread
needs its own buffer, the buffer size is increased by a factor of
`num_threads_`, and all accesses to that buffer are rewritten from
`i` to `i + vthread_id * original_size`.
Would be simpler to perform before buffer flattening. If each
virtual thread needs its own buffer, the buffer's shape is changed
from `shape` to `[num_threads_, *shape]`, and all accesses to that
buffer are rewritten from `indices` to `[vthread_id, *indices]`.
* VectorizeLoop
Currently operates on a flattened buffer. If each iteration of a
vectorized loop needs its own buffer, the buffer size is increased
by a factor of `var_lanes_`, and all accesses to that buffer are
rewritten from `i` to `i*var_lanes_ + var_`.
Would be simpler to perform before buffer flattening. If each
iteration of a vectorized loop needs its own buffer, the buffer's
shape is changed from `shape` to `[*shape, var_lanes_]`, and all
accesses to that buffer are rewritten from `indices` to
`[*indices, var_]`.
* InjectDoubleBuffer
Currently operates on a flattened buffer. The size of the buffer
is doubled, all reads to that buffer are rewritten from `i` to
`i + original_size*switch_read_var`, and all writes are rewritten
from `i` to `i + original_size*switch_write_var`.
Would be simpler to perform before buffer flattening. The
buffer's shape is changed from `shape` to `[2, *shape]`, and all
accesses to the buffer are rewritten from `indices` to
`[loop_iter%2, *indices]`.
* BoundChecker
Currently operates on a flattened buffer, using attributes
containing the pre-flattened buffer size. Would be simpler to
perform before buffer flattening, using the size of the buffer.
* InjectCopyIntrin
Currently operates on a flattened buffer. To detect if this is a
memcpy that can be replaced by an intrinsic, the index of the
store/load is broken up into linear combinations of the
surrounding loop iteration variables, then tested to see if those
strides are consistent with copying each data value sequentially.
Would be simpler to perform before buffer flattening. If the
expression is a memcpy that can be replaced by an intrinsic, the
indices of the store/load should be exactly the loop iteration
variables, and the loop iteration extent should be the size of the
buffer.