Lunderberg commented on a change in pull request #8528:
URL: https://github.com/apache/tvm/pull/8528#discussion_r679485071



##########
File path: src/tir/transforms/storage_rewrite.cc
##########
@@ -885,108 +893,528 @@ class StoragePlanRewriter : public StmtExprMutator {
   arith::Analyzer analyzer_;
 };
 
-// Turn alloc into vector alloc
-// if all its access is the same vector type.
-class VectorAllocRewriter : public StmtExprMutator {
+/* Helper struct containing information on how a buffer is declared and used
+ *
+ */
+struct BufferVarInfo {
+  enum DeclarationLocation {
+    kPrimFuncParam = (1 << 0),
+    kPrimFuncBufferMap = (1 << 1),
+    kAllocateNode = (1 << 2),
+    kLetNode = (1 << 3),
+  };
+
+  // The tir::Var that represents this buffer.
+  Var var;
+
+  // The data type of an element of the buffer.
+  DataType element_dtype;
+
+  /* The extent of the buffer.
+   *
+   * If multidimensional, the extent of the last dimension of the buffer.  If 
the
+   * size is unknown (e.g. pointer arguments to PrimFunc with no corresponding
+   * entry in buffer_map), then extent is zero.
+   */
+  PrimExpr extent;
+
+  // Where the buffer was declared
+  DeclarationLocation declaration_location;
+
+  // When accessed, how many lanes of data are used.
+  std::set<int> lanes_used;
+
+  int get_preferred_lanes() const {
+    // If there is only one vectorizable size used to access the
+    // buffer, and if that access size is compatible with the array
+    // size, then the buffer is vectorizable.  In the future, this
+    // could be improved to allow vectorized buffer access of size
+    // GCD(*lanes_used), if necessary.
+    if ((element_dtype.lanes() == 1) && (lanes_used.size() == 1)) {
+      arith::Analyzer analyzer_;
+      arith::ModularSet me = analyzer_.modular_set(extent);
+
+      int lanes = *lanes_used.begin();
+      if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
+        return lanes;
+      }
+    }
+
+    return element_dtype.lanes();
+  }
+
+  DataType get_preferred_dtype() const { return 
element_dtype.with_lanes(get_preferred_lanes()); }
+};
+
+/* Checks whether buffers are accessed as scalar or vector parameters in a
+ * function.
+ *
+ */
+class VectorTypeAccessChecker : public StmtExprVisitor {
  public:
-  PrimExpr VisitExpr_(const LoadNode* op) final {
-    UpdateTypeMap(op->buffer_var.get(), op->dtype);
-    return StmtExprMutator::VisitExpr_(op);
+  /* Constructor
+   *
+   * @param params The parameters passed to a PrimFunc
+   *
+   * @param buffer_map The buffer_map associated with a PrimFunc
+   *
+   * @param allow_untyped_handles If a buffer or pointer variable is
+   * missing a type annotation, assume that it has the same underlying
+   * type as it is later accessed, with scalar element types.
+   */
+  VectorTypeAccessChecker(const Array<tir::Var>& params, const Map<Var, 
Buffer>& buffer_map,
+                          bool allow_untyped_pointers = false)
+      : allow_untyped_pointers_(allow_untyped_pointers) {
+    // If a parameter is in the buffer map, we want to track the
+    // version in the map.
+    for (auto it : buffer_map) {
+      Buffer& buffer = it.second;
+      Var buffer_var = buffer->data;
+      DataType dtype = buffer->dtype;
+      PrimExpr extent = buffer->shape.size() ? 
buffer->shape[buffer->shape.size() - 1] : 0;
+      OnArrayDeclaration(buffer_var, dtype, extent, 
BufferVarInfo::kPrimFuncParam);
+    }
+
+    // If a pointer parameter isn't in the buffer map, then we want to
+    // track the parameter itself.
+    for (Var buffer_var : params) {
+      auto pointer_type = GetPointerType(buffer_var->type_annotation);
+      if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) {
+        DataType dtype = pointer_type.second;
+        PrimExpr extent = 0;
+        OnArrayDeclaration(buffer_var, dtype, extent, 
BufferVarInfo::kPrimFuncBufferMap);
+      }
+    }
   }
 
-  Stmt VisitStmt_(const StoreNode* op) final {
-    UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
-    return StmtExprMutator::VisitStmt_(op);
+  void VisitExpr_(const LoadNode* op) final {
+    OnArrayAccess(op->dtype, op->buffer_var.get(), op->index, op->predicate);
+    StmtExprVisitor::VisitExpr_(op);
   }
-  PrimExpr VisitExpr_(const CallNode* op) final {
+
+  void VisitStmt_(const StoreNode* op) final {
+    OnArrayAccess(op->value.dtype(), op->buffer_var.get(), op->index, 
op->predicate);
+    StmtExprVisitor::VisitStmt_(op);
+  }
+  void VisitExpr_(const CallNode* op) final {
     if (op->op.same_as(builtin::tvm_access_ptr())) {
       DataType dtype = op->args[0].dtype();
       const VarNode* buffer = op->args[1].as<VarNode>();
-      UpdateTypeMap(buffer, dtype);
+      PrimExpr index = op->args[2];
+      OnArrayAccess(dtype, buffer, index, const_true(dtype.lanes()));
     }
-    return StmtExprMutator::VisitExpr_(op);
+    StmtExprVisitor::VisitExpr_(op);
   }
 
-  Stmt VisitStmt_(const AllocateNode* op) final {
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<AllocateNode>();
-    const auto& tvec = acc_map_[op->buffer_var.get()];
-
-    if (tvec.size() == 1 && tvec[0].element_of() == op->dtype.element_of() &&
-        tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() != 
op->dtype.lanes()) {
-      int factor = tvec[0].lanes() / op->dtype.lanes();
-      Array<PrimExpr> extents = op->extents;
-      arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 
1]);
-      if (me->base % factor == 0 && me->coeff % factor == 0) {
-        extents.Set(extents.size() - 1,
-                    extents[extents.size() - 1] / 
make_const(extents[0].dtype(), factor));
-        // create a new buffer var
-        DataType new_dtype = tvec[0];
-        Var new_buffer_var(op->buffer_var->name_hint,
-                           PointerType(PrimType(new_dtype), 
GetPtrStorageScope(op->buffer_var)));
-        // update the remap req.
-        var_remap_.Set(op->buffer_var, new_buffer_var);
-        return Allocate(new_buffer_var, new_dtype, extents, op->condition, 
op->body);
+  void VisitStmt_(const AllocateNode* op) final {
+    const Array<PrimExpr>& extents = op->extents;
+    PrimExpr extent = extents[extents.size() - 1];
+    OnArrayDeclaration(op->buffer_var, op->dtype, extent, 
BufferVarInfo::kAllocateNode);
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const LetNode* op) final {
+    HandleLetNode(op->var);
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const LetStmtNode* op) final {
+    HandleLetNode(op->var);
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void HandleLetNode(Var let_var) {
+    if (let_var->dtype.is_handle()) {
+      auto pointer_type = GetPointerType(let_var->type_annotation);
+      if (pointer_type.first) {
+        OnArrayDeclaration(let_var, pointer_type.second, 0, 
BufferVarInfo::kLetNode);
+      } else if (allow_untyped_pointers_) {
+        OnArrayDeclaration(let_var, let_var->dtype, 0, 
BufferVarInfo::kLetNode);
+      } else {
+        LOG(FATAL) << "Let statement of variable " << let_var->name_hint
+                   << " is missing a type annotation, "
+                   << "or type annotation is not a pointer to primitive";
       }
     }
-    return stmt;
   }
 
-  void UpdateTypeMap(const VarNode* buffer, DataType t) {
-    auto& tvec = acc_map_[buffer];
-    if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
-      tvec.push_back(t);
+  /* Update the type map for a buffer based on its declaration
+   *
+   * @param buffer The VarNode representing the buffer.
+   *
+   * @param element_dtype The dtype of a single element of the buffer.
+   * If unknown, when used with the allow_untyped_handles option,
+   * should be a handle dtype.
+   *
+   * @param extent The extent of the buffer.  Zero if size is unknown.
+   *
+   * @param declaration_location How the buffer was allocated, so that
+   * some locations can be rewritten without others.
+   */
+  void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent,
+                          BufferVarInfo::DeclarationLocation 
declaration_location) {
+    ICHECK(info_map_.find(buffer.get()) == info_map_.end())
+        << "Array declaration of " << buffer->name_hint << " occurred multiple 
times.";
+
+    if (element_dtype == DataType::Bool()) {
+      element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
+    }
+
+    info_map_[buffer.get()] = {buffer, element_dtype, extent, 
declaration_location};
+  }
+
+  /* Update the type map for a buffer based on its usage
+   *
+   * @param value_dtype The dtype of the value being stored to or
+   * loaded from the buffer.
+   *
+   * @param buffer The VarNode representing the buffer.
+   *
+   * @param index The index at which the value is being stored/loaded.
+   *
+   * @param predicate The predicate used for the store/load.
+   */
+  void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const 
PrimExpr& index,
+                     const PrimExpr& predicate) {
+    auto it = info_map_.find(buffer);
+    ICHECK(it != info_map_.end()) << "Load/Store of buffer " << 
buffer->name_hint << " (" << buffer
+                                  << ") occurred before its declaration.";
+    BufferVarInfo& var_info = it->second;
+
+    if (value_dtype.element_of() == DataType::Bool()) {
+      value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes());
+    }
+
+    if (var_info.element_dtype.is_handle()) {
+      ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint
+                                      << " was missing a type annotation in 
its declaration";
+      var_info.element_dtype = value_dtype.element_of();
+    }
+
+    // Currently cannot valid the element type being accessed.  See comments in
+    // Load::Load for details.
+    //
+    // ICHECK_EQ(var_info.element_dtype.element_of(), value_dtype.element_of())
+    //     << "Attempting to access buffer of type " << var_info.element_dtype 
<< " as type "
+    //     << value_dtype;
+
+    int lanes_used = var_info.element_dtype.lanes();
+
+    // This can happen due to a previous pass that had rewrite_store_load =
+    // false.  This occurs from the StorageRewrite in tvm::lower, followed by 
the
+    // PointerValueTypeRewrite in BuildSPIRV.  The rewrite_store_load = false 
is
+    // necessary because the C-based codegens do not yet support vectorized
+    // pointer types (e.g. float16x4*).  Once they do, this if statement should
+    // instead be replaced by the below ICHECK_EQ.
+    if (index.dtype().lanes() * var_info.element_dtype.lanes() != 
value_dtype.lanes()) {
+      ICHECK_EQ(index.dtype().lanes(), value_dtype.lanes());
+      lanes_used = 1;
+      var_info.element_dtype = var_info.element_dtype.with_lanes(1);
     }
+
+    // ICHECK_EQ(index.dtype().lanes() * var_info.element_dtype.lanes(), 
value_dtype.lanes())
+    //     << "Attempting to retrieve " << value_dtype.lanes() << " lanes of 
data with "
+    //     << index.dtype().lanes() << " indices into an array whose elements 
have "
+    //     << var_info.element_dtype.lanes() << " lanes.  "
+    //     << "Expected output with " << index.dtype().lanes() * 
var_info.element_dtype.lanes()
+    //     << " lanes.";
+
+    // If the index is a RampNode with stride of 1 and offset
+    // divisible by the number of number of lanes, and the predicate
+    // does not apply any masking, then this array access could be
+    // vectorized.
+    const RampNode* ramp_index = index.as<RampNode>();
+    if (ramp_index && is_one(ramp_index->stride) && is_one(predicate)) {
+      arith::ModularSet me = analyzer_.modular_set(ramp_index->base);
+      if ((me->coeff % ramp_index->lanes == 0) && (me->base % 
ramp_index->lanes == 0)) {
+        lanes_used = ramp_index->lanes;
+      }
+    }
+
+    var_info.lanes_used.insert(lanes_used);
   }
 
-  // Internal access map
-  std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
-  // Variables to remap
-  Map<tir::Var, PrimExpr> var_remap_;
+  // Map of buffer variable information determined
+  std::unordered_map<const VarNode*, BufferVarInfo> info_map_;
+
+  //
+  bool allow_untyped_pointers_{false};
+
   // internal analyzer
   arith::Analyzer analyzer_;
 };
 
-PrimFunc PointerValueTypeRewrite(PrimFunc f) {
-  auto* n = f.CopyOnWrite();
-  VectorAllocRewriter rewriter;
-  n->body = rewriter(std::move(n->body));
+/* \brief Rewrites buffer/pointer variables from scalar types to vectorized
+ * types.
+ *
+ * Some runtimes do not allow casting between composite types and the 
underlying
+ * base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*).
+ * In these cases, in order to have vectorized load/store on an array, the
+ * element type of that array must be vectorized.  This is in contrast to 
C-style
+ * runtimes, in which `float16x4* vec = *(float16x4*)(float_arr + offset)` is
+ * valid.
+ *
+ * By default, VectorTypeRewriter will attempt to rewrite all buffer variables 
to
+ * vectorized access, if the load/store occurring in the PrimFunc are all
+ * vectorized.  This includes adjusting the indices being used to access the
+ * array.  (e.g. If `float16* scalar_arr` is being converted to `float16x4*
+ * vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to
+ * `vec_arr[offset/4]`.)
+ *
+ * Currently, several of the C-style runtimes do not support buffers whose
+ * elements are vectorized types, or rely on the presence of the Ramp nodes to
+ * identify vectorized loads.  The boolean parameters in the constructor are to
+ * mimic the previous behavior of VectorTypeRewriter, to avoid breaking these
+ * runtimes.  Once all runtimes support vectorized buffer elements, these
+ * parameters can be removed.
+ */
+class VectorTypeRewriter : public StmtExprMutator {

Review comment:
       I was surprised as well, as my mental model was that all changes to the 
TIR graph occur before it is passed to the codegen.  This extra call is fairly 
Vulkan-specific and, from talking with @tqchen, exists to adjust the 
pointer type through which Vulkan accesses an array.  For CUDA, it's 
unnecessary because the pointer can be cast to the desired output type, but 
Vulkan doesn't allow those pointer casts.  Instead, we need to choose one 
specific type for each pointer that is passed in, and stick with it through the 
entire `PrimFunc`.
   
   My preference would be to pull this extra call out into a target-aware 
optimization pass, which could be added into either `tvm.lower` or `tvm.build`.  
I've put together some comments 
[on discuss](https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615/4), 
and will tag you there as well.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to