Lunderberg commented on a change in pull request #9727:
URL: https://github.com/apache/tvm/pull/9727#discussion_r777575815



##########
File path: src/tir/transforms/inject_virtual_thread.cc
##########
@@ -354,46 +391,44 @@ class VTInjector : public StmtExprMutator {
   }
   // Allocate
   Stmt VisitStmt_(const AllocateNode* op) final {
+    Allocate node = GetRef<Allocate>(op);
+
     PrimExpr condition = this->VisitExpr(op->condition);
+
+    Array<PrimExpr> extents = op->extents;
+    extents.MutateByApply([this](const PrimExpr& extent) { return this->VisitExpr(extent); });
+
     if (visit_touched_var_ && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
     }
 
-    bool changed = false;
-    Array<PrimExpr> extents;
-    for (size_t i = 0; i < op->extents.size(); i++) {
-      PrimExpr new_ext = this->VisitExpr(op->extents[i]);
-      if (visit_touched_var_ && !vt_loop_injected_) {
-        return InjectVTLoop(GetRef<Stmt>(op), true);
-      }
-      if (!new_ext.same_as(op->extents[i])) changed = true;
-      extents.push_back(new_ext);
-    }
     visit_touched_var_ = false;
 
-    Stmt body;
-    // always rewrite if not allow sharing.
+    // Rewrite the buffer if its shape or any value stored in it
+    // depends on the virtual thread var.  If `allow_share_` is false,
+    // then the buffer is always rewritten, even if separate virtual
+    // threads only read from the buffer.
     if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
       // place v on highest dimension.
-      PrimExpr stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                              make_const(DataType::Int(32), 1), op->extents) *
-                        op->dtype.lanes();
-      Array<PrimExpr> other;
-      other.push_back(make_const(op->extents[0].dtype(), num_threads_));
-      for (PrimExpr e : extents) {
-        other.push_back(e);
-      }
-      extents = other;
-      changed = true;
-      // mark this buffer get touched.
+
+      // TODO(Lunderberg): Move pass to apply before

Review comment:
       Can do.  I've collected these TODOs into a list of passes that could be 
simplified by operating on pre-flattened buffers, along with a high-level 
description of the changes that would be needed for each.
   
   * InjectVirtualThread
   
       Currently operates on a flattened buffer.  If each virtual thread
       needs its own buffer, the buffer size is increased by a factor of
       `num_threads_`, and all accesses to that buffer are rewritten from
       `i` to `i + vthread_id * original_size`.
   
       Would be simpler to perform before buffer flattening.  If each
       virtual thread needs its own buffer, the buffer's shape is changed
       from `shape` to `[num_threads_, *shape]`, and all accesses to that
       buffer are rewritten from `indices` to `[vthread_id, *indices]`.
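As a sanity check on the equivalence of the two rewrites (a standalone sketch, not TVM code, using a hypothetical buffer shape and thread count), both forms produce the same flattened offset under row-major flattening:

```python
import itertools

def flatten(indices, shape):
    """Row-major flattening of a multi-dimensional index."""
    flat = 0
    for idx, extent in zip(indices, shape):
        flat = flat * extent + idx
    return flat

# Hypothetical buffer shape and virtual-thread count for illustration.
shape = [4, 8]
num_threads = 2
original_size = 4 * 8

for vthread_id in range(num_threads):
    for indices in itertools.product(*(range(e) for e in shape)):
        i = flatten(indices, shape)
        # Current (post-flatten) rewrite: i -> i + vthread_id * original_size
        post = i + vthread_id * original_size
        # Proposed (pre-flatten) rewrite: indices -> [vthread_id, *indices],
        # flattened against the expanded shape [num_threads_, *shape].
        pre = flatten([vthread_id, *indices], [num_threads, *shape])
        assert post == pre
```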
   
   * VectorizeLoop
   
       Currently operates on a flattened buffer.  If each iteration of a
       vectorized loop needs its own buffer, the buffer size is increased
       by a factor of `var_lanes_`, and all accesses to that buffer are
       rewritten from `i` to `i*var_lanes_ + var_`.
   
       Would be simpler to perform before buffer flattening.  If each
       iteration of a vectorized loop needs its own buffer, the buffer's
       shape is changed from `shape` to `[*shape, var_lanes_]`, and all
       accesses to that buffer are rewritten from `indices` to
       `[*indices, var_]`.
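The same kind of identity holds for the vectorized-loop rewrite; a small standalone check (hypothetical shape and lane count, not TVM code):

```python
import itertools

def flatten(indices, shape):
    """Row-major flattening of a multi-dimensional index."""
    flat = 0
    for idx, extent in zip(indices, shape):
        flat = flat * extent + idx
    return flat

# Hypothetical buffer shape and vector lane count.
shape = [3, 5]
var_lanes = 4

for var in range(var_lanes):
    for indices in itertools.product(*(range(e) for e in shape)):
        i = flatten(indices, shape)
        # Current rewrite: i -> i * var_lanes_ + var_.
        # Proposed rewrite: indices -> [*indices, var_], flattened
        # against the expanded shape [*shape, var_lanes_].
        assert i * var_lanes + var == flatten([*indices, var], [*shape, var_lanes])
```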
   
   * InjectDoubleBuffer
   
       Currently operates on a flattened buffer.  The size of the buffer
       is doubled, all reads to that buffer are rewritten from `i` to
       `i + original_size*switch_read_var`, and all writes are rewritten
       from `i` to `i + original_size*switch_write_var`.
   
       Would be simpler to perform before buffer flattening.  The
       buffer's shape is changed from `shape` to `[2, *shape]`, and all
       accesses to the buffer are rewritten from `indices` to
       `[loop_iter%2, *indices]`.
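The double-buffer rewrite follows the same pattern; a standalone check with a hypothetical shape (not TVM code):

```python
import itertools

def flatten(indices, shape):
    """Row-major flattening of a multi-dimensional index."""
    flat = 0
    for idx, extent in zip(indices, shape):
        flat = flat * extent + idx
    return flat

# Hypothetical buffer shape for illustration.
shape = [4, 8]
original_size = 4 * 8

for loop_iter in range(6):
    switch = loop_iter % 2  # which of the two buffer copies is active
    for indices in itertools.product(*(range(e) for e in shape)):
        i = flatten(indices, shape)
        # Current rewrite: i -> i + original_size * switch_var.
        # Proposed rewrite: indices -> [loop_iter % 2, *indices],
        # flattened against the expanded shape [2, *shape].
        assert i + original_size * switch == flatten([switch, *indices], [2, *shape])
```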
   
   * BoundChecker
   
       Currently operates on a flattened buffer, using attributes
       containing the pre-flattened buffer size.  Would be simpler to
       perform before buffer flattening, using the size of the buffer.
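A hypothetical sketch of the difference (not the actual pass): before flattening, the check can be stated per dimension against the buffer's own shape, while after flattening only the total size, carried through an attribute, remains available, so some out-of-bounds accesses go undetected:

```python
import math

def in_bounds_pre_flatten(indices, shape):
    # Before flattening: check each index against its own extent.
    return all(0 <= idx < extent for idx, extent in zip(indices, shape))

def in_bounds_post_flatten(flat_index, pre_flatten_size):
    # After flattening: only the total size (recovered from an
    # attribute carrying the pre-flattened extents) can be checked.
    return 0 <= flat_index < pre_flatten_size

shape = [4, 8]
assert in_bounds_pre_flatten([3, 7], shape)
assert not in_bounds_pre_flatten([4, 0], shape)
# A per-dimension overflow that stays inside the total size, e.g.
# index [0, 9] flattening to offset 9 < 32, is only caught by the
# pre-flatten check.
assert not in_bounds_pre_flatten([0, 9], shape)
assert in_bounds_post_flatten(9, math.prod(shape))
```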
   
   * InjectCopyIntrin
   
       Currently operates on a flattened buffer.  To detect if this is a
       memcpy that can be replaced by an intrinsic, the index of the
       store/load is broken up into linear combinations of the
       surrounding loop iteration variables, then tested to see if those
       strides are consistent with copying each data value sequentially.

       strides are consistent with copying each data value sequentially.

       Would be simpler to perform before buffer flattening.  If the
       expression is a memcpy that can be replaced by an intrinsic, the
       indices of the store/load should be exactly the loop iteration
       variables, and the loop iteration extent should be the size of the
       buffer.
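A standalone sketch (hypothetical helper names, not the actual pass) of how much simpler the pre-flatten detection could be:

```python
def is_simple_copy(loop_vars, loop_extents, store_indices, buffer_shape):
    # Pre-flatten detection: the store is a straight copy when its
    # indices are exactly the surrounding loop variables and each loop
    # extent matches the corresponding buffer extent.  No stride
    # recovery from a flattened index expression is needed.
    return (list(store_indices) == list(loop_vars)
            and list(loop_extents) == list(buffer_shape))

# Copies each element exactly once, in order: detected.
assert is_simple_copy(["i", "j"], [4, 8], ["i", "j"], [4, 8])
# Transposed access is not a plain memcpy: rejected.
assert not is_simple_copy(["i", "j"], [4, 8], ["j", "i"], [4, 8])
```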



