This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 925148e444 [TIR] Shuffle in PointerValueTypeRewrite for scalar reads (#15517)
925148e444 is described below
commit 925148e444103f044e9dbe111aacf0c5079abc3a
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Aug 18 12:11:06 2023 -0700
[TIR] Shuffle in PointerValueTypeRewrite for scalar reads (#15517)
---
python/tvm/tir/transform/transform.py | 14 +++
src/target/spirv/codegen_spirv.cc | 11 ++
src/target/spirv/codegen_spirv.h | 1 +
src/tir/transforms/storage_rewrite.cc | 112 +++++++++++++++------
...est_tir_transform_pointer_value_type_rewrite.py | 73 ++++++++++++++
5 files changed, 180 insertions(+), 31 deletions(-)
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 0cd54064a7..a46b2d1037 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -230,6 +230,20 @@ def StorageRewrite():
return _ffi_api.StorageRewrite() # type: ignore
+def PointerValueTypeRewrite():
+ """
+ Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use
+ the most frequently accessed type for load/store to avoid pointer casting in backend when
+ possible.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.PointerValueTypeRewrite() # type: ignore
+
+
def UnrollLoop():
"""Unroll the constant loop marked by unroll.
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index ab9aec0775..5cc3f8f8dd 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -610,6 +610,17 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv:
}
}
+spirv::Value CodeGenSPIRV::VisitExpr_(const ShuffleNode* op) {
+ ICHECK(op->vectors.size() == 1 && op->indices.size() == 1)
+ << "SPIR-V codegen only supports shuffle "
+ << "of one vector with one index";
+ spirv::Value vector = MakeValue(op->vectors[0]);
+ int index = Downcast<Integer>(op->indices[0])->value;
+ spirv::SType etype = builder_->GetSType(op->dtype);
+ spirv::Value element = builder_->MakeValue(spv::OpCompositeExtract, etype,
vector, index);
+ return element;
+}
+
void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory
buffers";
Var buffer_var = op->buffer->data;
diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h
index 1e7b535585..8ea90a9c4b 100644
--- a/src/target/spirv/codegen_spirv.h
+++ b/src/target/spirv/codegen_spirv.h
@@ -102,6 +102,7 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
spirv::Value VisitExpr_(const RampNode* op) override;
spirv::Value VisitExpr_(const BroadcastNode* op) override;
spirv::Value VisitExpr_(const BufferLoadNode* op) override;
+ spirv::Value VisitExpr_(const ShuffleNode* op) override;
// stmt
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 3ecd0f64bb..f271769c80 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -36,6 +36,7 @@
#include <unordered_map>
#include <unordered_set>
+#include "../../arith/int_operator.h"
#include "../../runtime/thread_storage_scope.h"
#include "../ir/buffer_common.h"
#include "ir_utils.h"
@@ -1066,12 +1067,18 @@ struct BufferVarInfo {
// packing in StorageRewrite) or in number of lanes (e.g. float16*
// cast to float16x4*).
std::unordered_set<DataType> access_dtype;
+ // Data types used for scalar reads. This is used to record vectorized read dtypes that can be
+ // shuffled for scalar reads when rewrite_scalar_read_to_vector_shuffle is enabled.
+ std::unordered_set<DataType> scalar_read_dtype;
DataType get_preferred_dtype() const {
std::unordered_set<DataType> base_access_dtype;
for (auto dtype : access_dtype) {
base_access_dtype.insert(dtype.element_of());
}
+ for (auto dtype : scalar_read_dtype) {
+ base_access_dtype.insert(dtype.element_of());
+ }
// If the array is accessed as multiple base types within a
// function, no point in changing the declared type. CodeGenC can
// handle this with a type-cast prior to indexing. Vulkan will
@@ -1088,12 +1095,19 @@ struct BufferVarInfo {
// size, then the buffer is vectorizable. In the future, this
// could be improved to allow vectorized buffer access of size
// GCD(*lanes_used), if necessary.
+ // When there are scalar reads and no writes, access_dtype can be empty and we should avoid
+ // rewriting.
int preferred_lanes = element_dtype.lanes();
- if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) {
+ if (element_dtype.lanes() == 1 && (access_dtype.size() == 1)) {
+ int lanes = access_dtype.begin()->lanes();
+ // Check the scalar read dtypes are compatible with the vectorized access dtype.
+ for (auto dtype : scalar_read_dtype) {
+ if (dtype.lanes() % lanes != 0) {
+ return element_dtype;
+ }
+ }
arith::Analyzer analyzer_;
arith::ModularSet me = analyzer_.modular_set(extent);
-
- int lanes = access_dtype.begin()->lanes();
if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
preferred_lanes = lanes;
}
@@ -1120,8 +1134,10 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
* type as it is later accessed, with scalar element types.
*/
VectorTypeAccessChecker(const Array<tir::Var>& params, const Map<Var,
Buffer>& buffer_map,
- bool allow_untyped_pointers = false)
- : allow_untyped_pointers_(allow_untyped_pointers) {
+ bool allow_untyped_pointers = false,
+ bool detect_scalar_read_patterns = true)
+ : allow_untyped_pointers_(allow_untyped_pointers),
+ detect_scalar_read_patterns_(detect_scalar_read_patterns) {
// If a parameter is in the buffer map, we want to track the
// version in the map.
for (auto it : buffer_map) {
@@ -1145,12 +1161,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
}
void VisitExpr_(const BufferLoadNode* op) final {
- OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices);
+ OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices,
/*is_buffer_load=*/true);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) final {
- OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices);
+ OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices,
/*is_buffer_load=*/false);
StmtExprVisitor::VisitStmt_(op);
}
@@ -1159,7 +1175,10 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
PrimExpr index = op->args[2];
- OnArrayAccess(dtype, buffer, {index});
+ OnArrayAccess(dtype, buffer, {index}, false);
+ } else if (op->op.same_as(builtin::address_of())) {
+ BufferLoad load = Downcast<BufferLoad>(op->args[0]);
+ OnArrayAccess(load->dtype, load->buffer->data.get(), load->indices,
/*is_buffer_load=*/false);
}
StmtExprVisitor::VisitExpr_(op);
}
@@ -1226,8 +1245,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
if (element_dtype == DataType::Bool()) {
element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
}
-
- info_map_[buffer.get()] = {buffer, element_dtype, extent,
declaration_location};
+ info_map_[buffer.get()] = BufferVarInfo{buffer, element_dtype, extent,
declaration_location};
}
/* Update the type map for a buffer based on its usage
@@ -1237,11 +1255,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
*
* @param buffer The VarNode representing the buffer.
*
- * @param index The index at which the value is being stored/loaded.
+ * @param indices The index at which the value is being stored/loaded.
*
- * @param predicate The predicate used for the store/load.
+ * @param is_buffer_load Whether the access is BufferLoad
*/
- void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const
Array<PrimExpr>& indices) {
+ void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const
Array<PrimExpr>& indices,
+ bool is_buffer_load) {
auto it = info_map_.find(buffer);
ICHECK(it != info_map_.end()) << "Load/Store of buffer " <<
buffer->name_hint << " (" << buffer
<< ") occurred before its declaration.";
@@ -1304,6 +1323,14 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
}
}
+ if (detect_scalar_read_patterns_ && is_buffer_load && indices.size()) {
+ const PrimExpr last_dim_index = indices[indices.size() - 1];
+ if (last_dim_index.dtype().lanes() == 1) {
+ arith::ModularSet me = analyzer_.modular_set(last_dim_index);
+ var_info.scalar_read_dtype.emplace(access_dtype.with_lanes(me->coeff));
+ return;
+ }
+ }
var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used));
}
@@ -1312,6 +1339,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
//
bool allow_untyped_pointers_{false};
+ // Whether to detect scalar read patterns for rewriting to vector shuffle
+ bool detect_scalar_read_patterns_{true};
// internal analyzer
arith::Analyzer analyzer_;
@@ -1366,7 +1395,8 @@ class VectorTypeRewriter : public StmtExprMutator {
VectorTypeRewriter(const std::unordered_map<const VarNode*, BufferVarInfo>&
info_map,
bool rewrite_params = true, bool rewrite_buffer_map =
true,
bool rewrite_allocate_node = true, bool rewrite_indices =
true,
- bool rewrite_let_node = true, bool
rewrite_allocate_const_node = true)
+ bool rewrite_let_node = true, bool
rewrite_allocate_const_node = true,
+ bool rewrite_scalar_read_to_vector_shuffle = true)
: rewrite_indices_(rewrite_indices) {
int rewrite_mask = 0;
if (rewrite_params) {
@@ -1401,42 +1431,53 @@ class VectorTypeRewriter : public StmtExprMutator {
}
}
+ /*!
+ * \brief Mutator for BufferLoad or BufferStore.
+ * \return The rewritten node and the shuffle index. (Only for BufferLoad) When the shuffle index
+ * is non-negative, the caller should generate Shuffle to extract the element from the vector.
+ */
template <typename Node>
- Node VisitBufferAccess(Node node) {
+ std::pair<Node, int> VisitBufferAccess(Node node) {
+ int shuffle_index = -1;
if (!rewrite_indices_) {
- return node;
+ return {node, shuffle_index};
}
auto it = rewrite_map_.find(node->buffer->data.get());
if (it == rewrite_map_.end()) {
- return node;
+ return {node, shuffle_index};
}
const auto& info = it->second;
Array<PrimExpr> indices = node->indices;
-
- const RampNode* ramp_index = indices[indices.size() - 1].as<RampNode>();
- if (ramp_index && is_one(ramp_index->stride)) {
+ const PrimExpr& last_dim_index = indices[indices.size() - 1];
+ if (const RampNode* ramp_index = last_dim_index.as<RampNode>();
+ ramp_index && is_one(ramp_index->stride)) {
PrimExpr new_index =
ramp_index->base / make_const(ramp_index->base.dtype(),
ramp_index->lanes);
if (ramp_index->lanes != info.factor()) {
- new_index = Ramp(new_index, ramp_index->stride, ramp_index->lanes /
info.factor(),
- ramp_index->span);
+ ICHECK(info.factor() && ramp_index->lanes % info.factor() == 0);
+ int new_lanes = ramp_index->lanes / info.factor();
+ new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes,
ramp_index->span);
}
-
+ indices.Set(indices.size() - 1, new_index);
+ } else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) {
+ arith::ModularSet me = analyzer_.modular_set(last_dim_index);
+ ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0);
+ PrimExpr new_index = last_dim_index / make_const(last_dim_index.dtype(),
info.factor());
+ shuffle_index = me->base;
indices.Set(indices.size() - 1, new_index);
}
auto writer = node.CopyOnWrite();
writer->buffer = RemapBuffer(node->buffer);
writer->indices = indices;
-
- return node;
+ return {node, shuffle_index};
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
- auto modified = VisitBufferAccess(node);
+ auto [modified, shuffle_index] = VisitBufferAccess(node);
// Not needed for BufferStoreNode, so we can't just call
// LegalizeDtype() in VisitBufferAccess.
@@ -1445,13 +1486,18 @@ class VectorTypeRewriter : public StmtExprMutator {
} else {
auto writer = modified.CopyOnWrite();
writer->LegalizeDType();
+ if (shuffle_index >= 0) {
+ return Shuffle::ExtractElement(std::move(modified), shuffle_index);
+ }
return std::move(modified);
}
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
- return VisitBufferAccess(std::move(node));
+ auto [modified, shuffle_index] = VisitBufferAccess(std::move(node));
+ ICHECK(shuffle_index < 0);
+ return std::move(modified);
}
Stmt VisitStmt_(const LetStmtNode* op) final {
@@ -1627,6 +1673,7 @@ class VectorTypeRewriter : public StmtExprMutator {
bool rewrite_indices_{true};
std::unordered_map<const VarNode*, RewriteInfo> rewrite_map_;
std::unordered_map<const BufferNode*, Buffer> buffer_map_;
+ arith::Analyzer analyzer_;
};
// Rewrite allocates, pointer parameters, and buffer map into vectorized
versions
@@ -1635,13 +1682,15 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false
bool rewrite_params = true, bool
rewrite_buffer_map = true,
bool rewrite_allocate_node = true, bool
rewrite_indices = true,
bool rewrite_let_node = true,
- bool rewrite_allocate_const_node = true) {
- VectorTypeAccessChecker checker(f->params, f->buffer_map,
allow_untyped_pointers);
+ bool rewrite_allocate_const_node = true,
+ bool rewrite_scalar_read_to_vector_shuffle =
true) {
+ VectorTypeAccessChecker checker(f->params, f->buffer_map,
allow_untyped_pointers,
+ rewrite_scalar_read_to_vector_shuffle);
checker(f->body);
VectorTypeRewriter rewriter(checker.info_map_, rewrite_params,
rewrite_buffer_map,
rewrite_allocate_node, rewrite_indices,
rewrite_let_node,
- rewrite_allocate_const_node);
+ rewrite_allocate_const_node,
rewrite_scalar_read_to_vector_shuffle);
PrimFuncNode* n = f.CopyOnWrite();
n->body = rewriter(std::move(n->body));
rewriter.Finalize(&f);
@@ -1661,7 +1710,8 @@ Pass StorageRewrite() {
// padded out to 32 bits) would require either rewriting
// AllocateConst::data, or would require the code generators to
// handle vectorized constants.
- return PointerValueTypeRewrite(std::move(f), true, false, false, true,
true, true, false);
+ return PointerValueTypeRewrite(std::move(f), true, false, false, true,
true, true, false,
+ false);
};
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
}
diff --git a/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py b/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py
new file mode 100644
index 0000000000..7baa96c1a1
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import te
+from tvm.driver.build_module import schedule_to_module
+from tvm.script import tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+ transform = tvm.tir.transform.PointerValueTypeRewrite()
+
+
+class TestRewriteToShuffle(BaseCompare):
+ @T.prim_func
+ def before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
+ A_local_data = T.allocate([16], "float32", scope="local")
+ A_local = T.Buffer((16,), "float32", data=A_local_data, scope="local")
+ for i in range(4):
+ A_local[i * 4 : i * 4 + 4] = A[i * 4 : i * 4 + 4]
+ for i in range(4):
+ B[i] = A_local[i * 4] + A_local[i * 4 + 1] + A_local[i * 4 + 2] +
A_local[i * 4 + 3]
+
+ @T.prim_func
+ def expected(A: T.Buffer((4,), "float32x4"), B: T.Buffer((4,), "float32")):
+ A_local_data = T.allocate([4], "float32x4", scope="local")
+ A_local = T.Buffer((4,), "float32x4", data=A_local_data, scope="local")
+ for i in range(4):
+ A_local[T.Div(i * 4, 4)] = A[T.Div(i * 4, 4)]
+ for i in range(4):
+ B[i] = (
+ T.Shuffle([A_local[T.Div(i * 4, 4)]], [0])
+ + T.Shuffle([A_local[T.Div(i * 4 + 1, 4)]], [1])
+ + T.Shuffle([A_local[T.Div(i * 4 + 2, 4)]], [2])
+ + T.Shuffle([A_local[T.Div(i * 4 + 3, 4)]], [3])
+ )
+
+
+class TestAddressOf(BaseCompare):
+ @T.prim_func
+ def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+ for i in range(4):
+ T.evaluate(T.address_of(A[i * 4]))
+ B[i * 4 : i * 4 + 4] = A[i * 4 : i * 4 + 4]
+
+ @T.prim_func
+ def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,),
"float32x4")):
+ for i in range(4):
+ T.evaluate(T.address_of(A[i * 4]))
+ B[T.Div(i * 4, 4)] = A[i * 4 : i * 4 + 4]
+
+
+class TestScalarReadWithoutWrite(BaseCompare):
+ @T.prim_func
+ def before(A: T.Buffer((16,), "float32")):
+ for i in range(4):
+ T.evaluate(A[i * 4])
+
+ expected = before