Lunderberg commented on a change in pull request #8528:
URL: https://github.com/apache/tvm/pull/8528#discussion_r679154602
##########
File path: src/tir/transforms/storage_rewrite.cc
##########
@@ -885,108 +893,528 @@ class StoragePlanRewriter : public StmtExprMutator {
arith::Analyzer analyzer_;
};
-// Turn alloc into vector alloc
-// if all its access is the same vector type.
-class VectorAllocRewriter : public StmtExprMutator {
+/* Helper struct containing information on how a buffer is declared and used
+ *
+ */
+struct BufferVarInfo {
+ enum DeclarationLocation {
+ kPrimFuncParam = (1 << 0),
+ kPrimFuncBufferMap = (1 << 1),
+ kAllocateNode = (1 << 2),
+ kLetNode = (1 << 3),
+ };
+
+ // The tir::Var that represents this buffer.
+ Var var;
+
+ // The data type of an element of the buffer.
+ DataType element_dtype;
+
+ /* The extent of the buffer.
+ *
+ * If multidimensional, the extent of the last dimension of the buffer. If
the
+ * size is unknown (e.g. pointer arguments to PrimFunc with no corresponding
+ * entry in buffer_map), then extent is zero.
+ */
+ PrimExpr extent;
+
+ // Where the buffer was declared
+ DeclarationLocation declaration_location;
+
+ // When accessed, how many lanes of data are used.
+ std::set<int> lanes_used;
+
+ int get_preferred_lanes() const {
+ // If there is only one vectorizable size used to access the
+ // buffer, and if that access size is compatible with the array
+ // size, then the buffer is vectorizable. In the future, this
+ // could be improved to allow vectorized buffer access of size
+ // GCD(*lanes_used), if necessary.
+ if ((element_dtype.lanes() == 1) && (lanes_used.size() == 1)) {
+ arith::Analyzer analyzer_;
+ arith::ModularSet me = analyzer_.modular_set(extent);
+
+ int lanes = *lanes_used.begin();
+ if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
+ return lanes;
+ }
+ }
+
+ return element_dtype.lanes();
+ }
+
+ DataType get_preferred_dtype() const { return
element_dtype.with_lanes(get_preferred_lanes()); }
+};
+
+/* Checks whether buffers are accessed as scalar or vector parameters in a
+ * function.
+ *
+ */
+class VectorTypeAccessChecker : public StmtExprVisitor {
public:
- PrimExpr VisitExpr_(const LoadNode* op) final {
- UpdateTypeMap(op->buffer_var.get(), op->dtype);
- return StmtExprMutator::VisitExpr_(op);
+ /* Constructor
+ *
+ * @param params The parameters passed to a PrimFunc
+ *
+ * @param buffer_map The buffer_map associated with a PrimFunc
+ *
+ * @param allow_untyped_handles If a buffer or pointer variable is
+ * missing a type annotation, assume that it has the same underlying
+ * type as it is later accessed, with scalar element types.
+ */
+ VectorTypeAccessChecker(const Array<tir::Var>& params, const Map<Var,
Buffer>& buffer_map,
+ bool allow_untyped_pointers = false)
+ : allow_untyped_pointers_(allow_untyped_pointers) {
+ // If a parameter is in the buffer map, we want to track the
+ // version in the map.
+ for (auto it : buffer_map) {
+ Buffer& buffer = it.second;
+ Var buffer_var = buffer->data;
+ DataType dtype = buffer->dtype;
+ PrimExpr extent = buffer->shape.size() ?
buffer->shape[buffer->shape.size() - 1] : 0;
+ OnArrayDeclaration(buffer_var, dtype, extent,
BufferVarInfo::kPrimFuncParam);
+ }
+
+ // If a pointer parameter isn't in the buffer map, then we want to
+ // track the parameter itself.
+ for (Var buffer_var : params) {
+ auto pointer_type = GetPointerType(buffer_var->type_annotation);
+ if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) {
+ DataType dtype = pointer_type.second;
+ PrimExpr extent = 0;
+ OnArrayDeclaration(buffer_var, dtype, extent,
BufferVarInfo::kPrimFuncBufferMap);
+ }
+ }
}
- Stmt VisitStmt_(const StoreNode* op) final {
- UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
- return StmtExprMutator::VisitStmt_(op);
+ void VisitExpr_(const LoadNode* op) final {
+ OnArrayAccess(op->dtype, op->buffer_var.get(), op->index, op->predicate);
+ StmtExprVisitor::VisitExpr_(op);
}
- PrimExpr VisitExpr_(const CallNode* op) final {
+
+ void VisitStmt_(const StoreNode* op) final {
+ OnArrayAccess(op->value.dtype(), op->buffer_var.get(), op->index,
op->predicate);
+ StmtExprVisitor::VisitStmt_(op);
+ }
+ void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
- UpdateTypeMap(buffer, dtype);
+ PrimExpr index = op->args[2];
+ OnArrayAccess(dtype, buffer, index, const_true(dtype.lanes()));
}
- return StmtExprMutator::VisitExpr_(op);
+ StmtExprVisitor::VisitExpr_(op);
}
- Stmt VisitStmt_(const AllocateNode* op) final {
- Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<AllocateNode>();
- const auto& tvec = acc_map_[op->buffer_var.get()];
-
- if (tvec.size() == 1 && tvec[0].element_of() == op->dtype.element_of() &&
- tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() !=
op->dtype.lanes()) {
- int factor = tvec[0].lanes() / op->dtype.lanes();
- Array<PrimExpr> extents = op->extents;
- arith::ModularSet me = analyzer_.modular_set(extents[extents.size() -
1]);
- if (me->base % factor == 0 && me->coeff % factor == 0) {
- extents.Set(extents.size() - 1,
- extents[extents.size() - 1] /
make_const(extents[0].dtype(), factor));
- // create a new buffer var
- DataType new_dtype = tvec[0];
- Var new_buffer_var(op->buffer_var->name_hint,
- PointerType(PrimType(new_dtype),
GetPtrStorageScope(op->buffer_var)));
- // update the remap req.
- var_remap_.Set(op->buffer_var, new_buffer_var);
- return Allocate(new_buffer_var, new_dtype, extents, op->condition,
op->body);
+ void VisitStmt_(const AllocateNode* op) final {
+ const Array<PrimExpr>& extents = op->extents;
+ PrimExpr extent = extents[extents.size() - 1];
+ OnArrayDeclaration(op->buffer_var, op->dtype, extent,
BufferVarInfo::kAllocateNode);
+
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const LetNode* op) final {
+ HandleLetNode(op->var);
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const LetStmtNode* op) final {
+ HandleLetNode(op->var);
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void HandleLetNode(Var let_var) {
+ if (let_var->dtype.is_handle()) {
+ auto pointer_type = GetPointerType(let_var->type_annotation);
+ if (pointer_type.first) {
+ OnArrayDeclaration(let_var, pointer_type.second, 0,
BufferVarInfo::kLetNode);
+ } else if (allow_untyped_pointers_) {
+ OnArrayDeclaration(let_var, let_var->dtype, 0,
BufferVarInfo::kLetNode);
+ } else {
+ LOG(FATAL) << "Let statement of variable " << let_var->name_hint
+ << " is missing a type annotation, "
+ << "or type annotation is not a pointer to primitive";
}
}
- return stmt;
}
- void UpdateTypeMap(const VarNode* buffer, DataType t) {
- auto& tvec = acc_map_[buffer];
- if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
- tvec.push_back(t);
+ /* Update the type map for a buffer based on its declaration
+ *
+ * @param buffer The VarNode representing the buffer.
+ *
+ * @param element_dtype The dtype of a single element of the buffer.
+ * If unknown, when used with the allow_untyped_handles option,
+ * should be a handle dtype.
+ *
+ * @param extent The extent of the buffer. Zero if size is unknown.
+ *
+ * @param declaration_location How the buffer was allocated, so that
+ * some locations can be rewritten without others.
+ */
+ void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent,
+ BufferVarInfo::DeclarationLocation
declaration_location) {
+ ICHECK(info_map_.find(buffer.get()) == info_map_.end())
+ << "Array declaration of " << buffer->name_hint << " occurred multiple
times.";
+
+ if (element_dtype == DataType::Bool()) {
+ element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
+ }
+
+ info_map_[buffer.get()] = {buffer, element_dtype, extent,
declaration_location};
+ }
+
+ /* Update the type map for a buffer based on its usage
+ *
+ * @param value_dtype The dtype of the value being stored to or
+ * loaded from the buffer.
+ *
+ * @param buffer The VarNode representing the buffer.
+ *
+ * @param index The index at which the value is being stored/loaded.
+ *
+ * @param predicate The predicate used for the store/load.
+ */
+ void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const
PrimExpr& index,
+ const PrimExpr& predicate) {
+ auto it = info_map_.find(buffer);
+ ICHECK(it != info_map_.end()) << "Load/Store of buffer " <<
buffer->name_hint << " (" << buffer
+ << ") occurred before its declaration.";
+ BufferVarInfo& var_info = it->second;
+
+ if (value_dtype.element_of() == DataType::Bool()) {
+ value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes());
+ }
+
+ if (var_info.element_dtype.is_handle()) {
+ ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint
+ << " was missing a type annotation in
its declaration";
+ var_info.element_dtype = value_dtype.element_of();
+ }
+
+ // Currently cannot valid the element type being accessed. See comments in
+ // Load::Load for details.
+ //
+ // ICHECK_EQ(var_info.element_dtype.element_of(), value_dtype.element_of())
+ // << "Attempting to access buffer of type " << var_info.element_dtype
<< " as type "
+ // << value_dtype;
+
+ int lanes_used = var_info.element_dtype.lanes();
+
+ // This can happen due to a previous pass that had rewrite_store_load =
+ // false. This occurs from the StorageRewrite in tvm::lower, followed by
the
+ // PointerValueTypeRewrite in BuildSPIRV. The rewrite_store_load = false
is
+ // necessary because the C-based codegens do not yet support vectorized
+ // pointer types (e.g. float16x4*). Once they do, this if statement should
+ // instead be replaced by the below ICHECK_EQ.
+ if (index.dtype().lanes() * var_info.element_dtype.lanes() !=
value_dtype.lanes()) {
+ ICHECK_EQ(index.dtype().lanes(), value_dtype.lanes());
+ lanes_used = 1;
+ var_info.element_dtype = var_info.element_dtype.with_lanes(1);
}
+
+ // ICHECK_EQ(index.dtype().lanes() * var_info.element_dtype.lanes(),
value_dtype.lanes())
+ // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of
data with "
+ // << index.dtype().lanes() << " indices into an array whose elements
have "
+ // << var_info.element_dtype.lanes() << " lanes. "
+ // << "Expected output with " << index.dtype().lanes() *
var_info.element_dtype.lanes()
+ // << " lanes.";
+
+ // If the index is a RampNode with stride of 1 and offset
+ // divisible by the number of number of lanes, and the predicate
+ // does not apply any masking, then this array access could be
+ // vectorized.
+ const RampNode* ramp_index = index.as<RampNode>();
+ if (ramp_index && is_one(ramp_index->stride) && is_one(predicate)) {
+ arith::ModularSet me = analyzer_.modular_set(ramp_index->base);
+ if ((me->coeff % ramp_index->lanes == 0) && (me->base %
ramp_index->lanes == 0)) {
+ lanes_used = ramp_index->lanes;
+ }
+ }
+
+ var_info.lanes_used.insert(lanes_used);
}
- // Internal access map
- std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
- // Variables to remap
- Map<tir::Var, PrimExpr> var_remap_;
+ // Map of buffer variable information determined
+ std::unordered_map<const VarNode*, BufferVarInfo> info_map_;
+
+ //
+ bool allow_untyped_pointers_{false};
+
// internal analyzer
arith::Analyzer analyzer_;
};
-PrimFunc PointerValueTypeRewrite(PrimFunc f) {
- auto* n = f.CopyOnWrite();
- VectorAllocRewriter rewriter;
- n->body = rewriter(std::move(n->body));
+/* \brief Rewrites buffer/pointer variables from scalar types to vectorized
+ * types.
+ *
+ * Some runtimes do not allow casting between composite types and the
underlying
+ * base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*).
+ * In these cases, in order to have vectorized load/store on an array, the
+ * element type of that array must be vectorized. This is in contrast to
C-style
+ * runtimes, in which `float16x4* vec = *(float16x4*)(float_arr + offset)` is
+ * valid.
+ *
+ * By default, VectorTypeRewriter will attempt to rewrite all buffer variables
to
+ * vectorized access, if the load/store occurring in the PrimFunc are all
+ * vectorized. This includes adjusting the indices being used to access the
+ * array. (e.g. If `float16* scalar_arr` is being converted to `float16x4*
+ * vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to
+ * `vec_arr[offset/4]`.)
+ *
+ * Currently, several of the C-style runtimes do not support buffers whose
+ * elements are vectorized types, or rely on the presence of the Ramp nodes to
+ * identify vectorized loads. The boolean parameters in the constructor are to
+ * mimic the previous behavior of VectorTypeRewriter, to avoid breaking these
+ * runtimes. Once all runtimes support vectorized buffer elements, these
+ * parameters can be removed.
+ */
+class VectorTypeRewriter : public StmtExprMutator {
Review comment:
Ah, good point. Currently, the `tir.transform.PointerValueTypeRewrite`
pass is only used in the SPIR-V codegen step, while the
`tir.transform.StorageRewrite` pass is used as part of
[`tvm.lower`](https://github.com/apache/tvm/blob/main/src/driver/driver_api.cc#L242).
Both passes share the `PointerValueTypeRewrite` function, but they have
conflicting requirements: the SPIR-V codegen needs buffers rewritten to
vectorized element types, while the C-based codegens do not yet support
vectorized pointer types (e.g. `float16x4*`). For now, this is handled by
passing different arguments to `PointerValueTypeRewrite` in the two cases.
Unfortunately, for the C codegen the rewrite isn't simply a no-op, which
would have been easier to handle. Instead, the `AllocateNode`, any function
parameters, and references to those variables get rewritten, but the
`StoreNode` and `LoadNode` do not have their indices rewritten to account for
the different variable type. I'm [trying to
determine](https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615/4)
the best way to handle that in the C codegens, but for now the previous
behavior can be maintained with the boolean options to
`PointerValueTypeRewrite`.
##########
File path: src/tir/ir/expr.cc
##########
@@ -618,8 +619,38 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index,
PrimExpr predicate, S
ICHECK(buffer_var.defined());
ICHECK(predicate.defined());
ICHECK(index.defined());
- ICHECK_EQ(dtype.lanes(), index.dtype().lanes());
- ICHECK_EQ(dtype.lanes(), predicate.dtype().lanes());
+
+ // Assume that the array elements have 1 lane, unless a type
+ // annotation tells us otherwise.
+ int element_lanes = 1;
+ auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
+ if (pointer_type.first) {
+ // Cannot check element type of array, as it may be different than
+ // the loaded type in some cases.
+ //
+ // 1. Booleans use DataType::Int(8) while stored, and the codegens
+ // handle cast to boolean.
+ //
+ // 2. The StorageRewrite pass can merge multiple allocations at
+ // the same scope, regardless of element type. The codegen is
+ // then responsible for casting to the output type.
+
+ // ICHECK(dtype.element_of() == pointer_type.second.element_of())
Review comment:
Good call; I will change them to explicit TODOs.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]