Lunderberg commented on a change in pull request #8528:
URL: https://github.com/apache/tvm/pull/8528#discussion_r679137229
##########
File path: src/tir/transforms/storage_rewrite.cc
##########
@@ -885,108 +893,528 @@ class StoragePlanRewriter : public StmtExprMutator {
arith::Analyzer analyzer_;
};
-// Turn alloc into vector alloc
-// if all its access is the same vector type.
-class VectorAllocRewriter : public StmtExprMutator {
+/* Helper struct containing information on how a buffer is declared and used
+ *
+ */
+struct BufferVarInfo {
+ enum DeclarationLocation {
+ kPrimFuncParam = (1 << 0),
+ kPrimFuncBufferMap = (1 << 1),
+ kAllocateNode = (1 << 2),
+ kLetNode = (1 << 3),
+ };
+
+ // The tir::Var that represents this buffer.
+ Var var;
+
+ // The data type of an element of the buffer.
+ DataType element_dtype;
+
+ /* The extent of the buffer.
+ *
+ * If multidimensional, the extent of the last dimension of the buffer. If the
+ * size is unknown (e.g. pointer arguments to PrimFunc with no corresponding
+ * entry in buffer_map), then extent is zero.
+ */
+ PrimExpr extent;
+
+ // Where the buffer was declared
+ DeclarationLocation declaration_location;
+
+ // When accessed, how many lanes of data are used.
+ std::set<int> lanes_used;
+
+ int get_preferred_lanes() const {
+ // If there is only one vectorizable size used to access the
+ // buffer, and if that access size is compatible with the array
+ // size, then the buffer is vectorizable. In the future, this
+ // could be improved to allow vectorized buffer access of size
+ // GCD(*lanes_used), if necessary.
+ if ((element_dtype.lanes() == 1) && (lanes_used.size() == 1)) {
+ arith::Analyzer analyzer_;
+ arith::ModularSet me = analyzer_.modular_set(extent);
+
+ int lanes = *lanes_used.begin();
+ if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
+ return lanes;
+ }
+ }
+
+ return element_dtype.lanes();
+ }
+
+ DataType get_preferred_dtype() const { return
element_dtype.with_lanes(get_preferred_lanes()); }
+};
+
+/* Checks whether buffers are accessed as scalar or vector parameters in a
+ * function.
+ *
+ */
+class VectorTypeAccessChecker : public StmtExprVisitor {
public:
- PrimExpr VisitExpr_(const LoadNode* op) final {
- UpdateTypeMap(op->buffer_var.get(), op->dtype);
- return StmtExprMutator::VisitExpr_(op);
+ /* Constructor
+ *
+ * @param params The parameters passed to a PrimFunc
+ *
+ * @param buffer_map The buffer_map associated with a PrimFunc
+ *
+ * @param allow_untyped_pointers If a buffer or pointer variable is
+ * missing a type annotation, assume that its underlying type is the
+ * scalar element type with which it is later accessed.
+ */
+ VectorTypeAccessChecker(const Array<tir::Var>& params, const Map<Var,
Buffer>& buffer_map,
+ bool allow_untyped_pointers = false)
+ : allow_untyped_pointers_(allow_untyped_pointers) {
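+ // If a parameter is in the buffer map, we want to track the
+ // version in the map.
+ // NOTE(review): these buffers are tagged kPrimFuncParam, while the bare
+ // pointer parameters below are tagged kPrimFuncBufferMap. The labels look
+ // swapped relative to their names; confirm against how VectorTypeRewriter
+ // consults rewrite_params / rewrite_buffer_map.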
+ for (auto it : buffer_map) {
+ Buffer& buffer = it.second;
+ Var buffer_var = buffer->data;
+ DataType dtype = buffer->dtype;
+ PrimExpr extent = buffer->shape.size() ?
buffer->shape[buffer->shape.size() - 1] : 0;
+ OnArrayDeclaration(buffer_var, dtype, extent,
BufferVarInfo::kPrimFuncParam);
+ }
+
+ // If a pointer parameter isn't in the buffer map, then we want to
+ // track the parameter itself.
+ for (Var buffer_var : params) {
+ auto pointer_type = GetPointerType(buffer_var->type_annotation);
+ if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) {
+ DataType dtype = pointer_type.second;
+ PrimExpr extent = 0;
+ OnArrayDeclaration(buffer_var, dtype, extent,
BufferVarInfo::kPrimFuncBufferMap);
+ }
+ }
}
- Stmt VisitStmt_(const StoreNode* op) final {
- UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
- return StmtExprMutator::VisitStmt_(op);
+ void VisitExpr_(const LoadNode* op) final {
+ OnArrayAccess(op->dtype, op->buffer_var.get(), op->index, op->predicate);
+ StmtExprVisitor::VisitExpr_(op);
}
- PrimExpr VisitExpr_(const CallNode* op) final {
+
+ void VisitStmt_(const StoreNode* op) final {
+ OnArrayAccess(op->value.dtype(), op->buffer_var.get(), op->index,
op->predicate);
+ StmtExprVisitor::VisitStmt_(op);
+ }
+ void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
- UpdateTypeMap(buffer, dtype);
+ PrimExpr index = op->args[2];
+ OnArrayAccess(dtype, buffer, index, const_true(dtype.lanes()));
}
- return StmtExprMutator::VisitExpr_(op);
+ StmtExprVisitor::VisitExpr_(op);
}
- Stmt VisitStmt_(const AllocateNode* op) final {
- Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<AllocateNode>();
- const auto& tvec = acc_map_[op->buffer_var.get()];
-
- if (tvec.size() == 1 && tvec[0].element_of() == op->dtype.element_of() &&
- tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() !=
op->dtype.lanes()) {
- int factor = tvec[0].lanes() / op->dtype.lanes();
- Array<PrimExpr> extents = op->extents;
- arith::ModularSet me = analyzer_.modular_set(extents[extents.size() -
1]);
- if (me->base % factor == 0 && me->coeff % factor == 0) {
- extents.Set(extents.size() - 1,
- extents[extents.size() - 1] /
make_const(extents[0].dtype(), factor));
- // create a new buffer var
- DataType new_dtype = tvec[0];
- Var new_buffer_var(op->buffer_var->name_hint,
- PointerType(PrimType(new_dtype),
GetPtrStorageScope(op->buffer_var)));
- // update the remap req.
- var_remap_.Set(op->buffer_var, new_buffer_var);
- return Allocate(new_buffer_var, new_dtype, extents, op->condition,
op->body);
+ void VisitStmt_(const AllocateNode* op) final {
+ const Array<PrimExpr>& extents = op->extents;
+ PrimExpr extent = extents[extents.size() - 1];
+ OnArrayDeclaration(op->buffer_var, op->dtype, extent,
BufferVarInfo::kAllocateNode);
+
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const LetNode* op) final {
+ HandleLetNode(op->var);
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const LetStmtNode* op) final {
+ HandleLetNode(op->var);
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void HandleLetNode(Var let_var) {
+ if (let_var->dtype.is_handle()) {
+ auto pointer_type = GetPointerType(let_var->type_annotation);
+ if (pointer_type.first) {
+ OnArrayDeclaration(let_var, pointer_type.second, 0,
BufferVarInfo::kLetNode);
+ } else if (allow_untyped_pointers_) {
+ OnArrayDeclaration(let_var, let_var->dtype, 0,
BufferVarInfo::kLetNode);
+ } else {
+ LOG(FATAL) << "Let statement of variable " << let_var->name_hint
+ << " is missing a type annotation, "
+ << "or type annotation is not a pointer to primitive";
}
}
- return stmt;
}
- void UpdateTypeMap(const VarNode* buffer, DataType t) {
- auto& tvec = acc_map_[buffer];
- if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
- tvec.push_back(t);
+ /* Update the type map for a buffer based on its declaration
+ *
+ * @param buffer The VarNode representing the buffer.
+ *
+ * @param element_dtype The dtype of a single element of the buffer.
+ * If unknown, when used with the allow_untyped_pointers option,
+ * should be a handle dtype.
+ *
+ * @param extent The extent of the buffer. Zero if size is unknown.
+ *
+ * @param declaration_location How the buffer was allocated, so that
+ * some locations can be rewritten without others.
+ */
+ void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent,
+ BufferVarInfo::DeclarationLocation
declaration_location) {
+ ICHECK(info_map_.find(buffer.get()) == info_map_.end())
+ << "Array declaration of " << buffer->name_hint << " occurred multiple
times.";
+
+ if (element_dtype == DataType::Bool()) {
+ element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
+ }
+
+ info_map_[buffer.get()] = {buffer, element_dtype, extent,
declaration_location};
+ }
+
+ /* Update the type map for a buffer based on its usage
+ *
+ * @param value_dtype The dtype of the value being stored to or
+ * loaded from the buffer.
+ *
+ * @param buffer The VarNode representing the buffer.
+ *
+ * @param index The index at which the value is being stored/loaded.
+ *
+ * @param predicate The predicate used for the store/load.
+ */
+ void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const
PrimExpr& index,
+ const PrimExpr& predicate) {
+ auto it = info_map_.find(buffer);
+ ICHECK(it != info_map_.end()) << "Load/Store of buffer " <<
buffer->name_hint << " (" << buffer
+ << ") occurred before its declaration.";
+ BufferVarInfo& var_info = it->second;
+
+ if (value_dtype.element_of() == DataType::Bool()) {
+ value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes());
+ }
+
+ if (var_info.element_dtype.is_handle()) {
+ ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint
+ << " was missing a type annotation in
its declaration";
+ var_info.element_dtype = value_dtype.element_of();
+ }
+
+ // Currently we cannot validate the element type being accessed. See the
+ // comments in Load::Load for details.
+ //
+ // ICHECK_EQ(var_info.element_dtype.element_of(), value_dtype.element_of())
+ // << "Attempting to access buffer of type " << var_info.element_dtype
<< " as type "
+ // << value_dtype;
+
+ int lanes_used = var_info.element_dtype.lanes();
+
+ // This can happen due to a previous pass that had rewrite_store_load =
+ // false. This occurs from the StorageRewrite in tvm::lower, followed by
the
+ // PointerValueTypeRewrite in BuildSPIRV. The rewrite_store_load = false
is
+ // necessary because the C-based codegens do not yet support vectorized
+ // pointer types (e.g. float16x4*). Once they do, this if statement should
+ // instead be replaced by the below ICHECK_EQ.
+ if (index.dtype().lanes() * var_info.element_dtype.lanes() !=
value_dtype.lanes()) {
+ ICHECK_EQ(index.dtype().lanes(), value_dtype.lanes());
+ lanes_used = 1;
+ var_info.element_dtype = var_info.element_dtype.with_lanes(1);
}
+
+ // ICHECK_EQ(index.dtype().lanes() * var_info.element_dtype.lanes(),
value_dtype.lanes())
+ // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of
data with "
+ // << index.dtype().lanes() << " indices into an array whose elements
have "
+ // << var_info.element_dtype.lanes() << " lanes. "
+ // << "Expected output with " << index.dtype().lanes() *
var_info.element_dtype.lanes()
+ // << " lanes.";
+
+ // If the index is a RampNode with a stride of 1 and a base
+ // divisible by the number of lanes, and the predicate
+ // does not apply any masking, then this array access could be
+ // vectorized.
+ const RampNode* ramp_index = index.as<RampNode>();
+ if (ramp_index && is_one(ramp_index->stride) && is_one(predicate)) {
+ arith::ModularSet me = analyzer_.modular_set(ramp_index->base);
+ if ((me->coeff % ramp_index->lanes == 0) && (me->base %
ramp_index->lanes == 0)) {
+ lanes_used = ramp_index->lanes;
+ }
+ }
+
+ var_info.lanes_used.insert(lanes_used);
}
- // Internal access map
- std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
- // Variables to remap
- Map<tir::Var, PrimExpr> var_remap_;
+ // Map of buffer variable information determined
+ std::unordered_map<const VarNode*, BufferVarInfo> info_map_;
+
+ // Whether pointer variables without a type annotation are permitted.
+ bool allow_untyped_pointers_{false};
+
// internal analyzer
arith::Analyzer analyzer_;
};
-PrimFunc PointerValueTypeRewrite(PrimFunc f) {
- auto* n = f.CopyOnWrite();
- VectorAllocRewriter rewriter;
- n->body = rewriter(std::move(n->body));
+/* \brief Rewrites buffer/pointer variables from scalar types to vectorized
+ * types.
+ *
+ * Some runtimes do not allow casting between composite types and the underlying
+ * base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*).
+ * In these cases, in order to have vectorized load/store on an array, the
+ * element type of that array must be vectorized. This is in contrast to C-style
+ * runtimes, in which `float16x4 vec = *(float16x4*)(float_arr + offset)` is
+ * valid.
+ *
+ * By default, VectorTypeRewriter will attempt to rewrite all buffer variables to
+ * vectorized access, if the loads/stores occurring in the PrimFunc are all
+ * vectorized. This includes adjusting the indices being used to access the
+ * array. (e.g. If `float16* scalar_arr` is being converted to `float16x4*
+ * vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to
+ * `vec_arr[offset/4]`.)
+ *
+ * Currently, several of the C-style runtimes do not support buffers whose
+ * elements are vectorized types, or rely on the presence of the Ramp nodes to
+ * identify vectorized loads. The boolean parameters in the constructor are to
+ * mimic the previous behavior of VectorAllocRewriter, to avoid breaking these
+ * runtimes. Once all runtimes support vectorized buffer elements, these
+ * parameters can be removed.
+ */
+class VectorTypeRewriter : public StmtExprMutator {
+ public:
+ /* Constructor
+ *
+ * @param checker The VectorTypeAccessChecker that has previously read out
+ * information from the PrimFunc
+ *
+ * @param rewrite_params Whether pointer-type parameters passed into the
+ * function should be rewritten from scalar types to vectorized types.
+ *
+ * @param rewrite_buffer_map Whether buffers present in the buffer_map should
+ * have their data variable be rewritten from scalar types to vectorized
types.
+ *
+ * @param rewrite_allocate_node Whether the buffer variable associated with
+ * AllocateNodes should be rewritten from scalar types to vectorized types.
+ *
+ * @param rewrite_indices Whether the indices to the Load and Store nodes
+ * should be rewritten to correspond to the new buffer_var type.
+ *
+ * @param rewrite_let_node Whether pointer declarations in let nodes
+ * should be rewritten.
+ */
+ VectorTypeRewriter(const VectorTypeAccessChecker& checker, bool
rewrite_params = true,
Review comment:
Certainly, can do. I was going back and forth between passing `checker`
and `checker.info_map_` when I was writing it.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]