This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 925148e444 [TIR] Shuffle in PointerValueTypeRewrite for scalar reads (#15517)
925148e444 is described below

commit 925148e444103f044e9dbe111aacf0c5079abc3a
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Aug 18 12:11:06 2023 -0700

    [TIR] Shuffle in PointerValueTypeRewrite for scalar reads  (#15517)
---
 python/tvm/tir/transform/transform.py              |  14 +++
 src/target/spirv/codegen_spirv.cc                  |  11 ++
 src/target/spirv/codegen_spirv.h                   |   1 +
 src/tir/transforms/storage_rewrite.cc              | 112 +++++++++++++++------
 ...est_tir_transform_pointer_value_type_rewrite.py |  73 ++++++++++++++
 5 files changed, 180 insertions(+), 31 deletions(-)

diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 0cd54064a7..a46b2d1037 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -230,6 +230,20 @@ def StorageRewrite():
     return _ffi_api.StorageRewrite()  # type: ignore
 
 
+def PointerValueTypeRewrite():
+    """
+    Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use
+    the most frequently accessed type for load/store to avoid pointer casting in backend when
+    possible.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.PointerValueTypeRewrite()  # type: ignore
+
+
 def UnrollLoop():
     """Unroll the constant loop marked by unroll.
 
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index ab9aec0775..5cc3f8f8dd 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -610,6 +610,17 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv:
   }
 }
 
+spirv::Value CodeGenSPIRV::VisitExpr_(const ShuffleNode* op) {
+  ICHECK(op->vectors.size() == 1 && op->indices.size() == 1)
+      << "SPIR-V codegen only supports shuffle "
+      << "of one vector with one index";
+  spirv::Value vector = MakeValue(op->vectors[0]);
+  int index = Downcast<Integer>(op->indices[0])->value;
+  spirv::SType etype = builder_->GetSType(op->dtype);
+  spirv::Value element = builder_->MakeValue(spv::OpCompositeExtract, etype, vector, index);
+  return element;
+}
+
 void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) {
   ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers";
   Var buffer_var = op->buffer->data;
diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h
index 1e7b535585..8ea90a9c4b 100644
--- a/src/target/spirv/codegen_spirv.h
+++ b/src/target/spirv/codegen_spirv.h
@@ -102,6 +102,7 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
   spirv::Value VisitExpr_(const RampNode* op) override;
   spirv::Value VisitExpr_(const BroadcastNode* op) override;
   spirv::Value VisitExpr_(const BufferLoadNode* op) override;
+  spirv::Value VisitExpr_(const ShuffleNode* op) override;
   // stmt
   void VisitStmt_(const BufferStoreNode* op) override;
   void VisitStmt_(const ForNode* op) override;
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 3ecd0f64bb..f271769c80 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -36,6 +36,7 @@
 #include <unordered_map>
 #include <unordered_set>
 
+#include "../../arith/int_operator.h"
 #include "../../runtime/thread_storage_scope.h"
 #include "../ir/buffer_common.h"
 #include "ir_utils.h"
@@ -1066,12 +1067,18 @@ struct BufferVarInfo {
   // packing in StorageRewrite) or in number of lanes (e.g. float16*
   // cast to float16x4*).
   std::unordered_set<DataType> access_dtype;
+  // Data types used for scalar reads. This is used to record vectorized read dtypes that can be
+  // shuffled for scalar reads when rewrite_scalar_read_to_vector_shuffle is enabled.
+  std::unordered_set<DataType> scalar_read_dtype;
 
   DataType get_preferred_dtype() const {
     std::unordered_set<DataType> base_access_dtype;
     for (auto dtype : access_dtype) {
       base_access_dtype.insert(dtype.element_of());
     }
+    for (auto dtype : scalar_read_dtype) {
+      base_access_dtype.insert(dtype.element_of());
+    }
     // If the array is accessed as multiple base types within a
     // function, no point in changing the declared type.  CodeGenC can
     // handle this with a type-cast prior to indexing.  Vulkan will
@@ -1088,12 +1095,19 @@ struct BufferVarInfo {
     // size, then the buffer is vectorizable.  In the future, this
     // could be improved to allow vectorized buffer access of size
     // GCD(*lanes_used), if necessary.
+    // When there are scalar reads and no writes, access_dtype can be empty and we should avoid
+    // rewriting.
     int preferred_lanes = element_dtype.lanes();
-    if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) {
+    if (element_dtype.lanes() == 1 && (access_dtype.size() == 1)) {
+      int lanes = access_dtype.begin()->lanes();
+      // Check the scalar read dtypes are compatible with the vectorized access dtype.
+      for (auto dtype : scalar_read_dtype) {
+        if (dtype.lanes() % lanes != 0) {
+          return element_dtype;
+        }
+      }
       arith::Analyzer analyzer_;
       arith::ModularSet me = analyzer_.modular_set(extent);
-
-      int lanes = access_dtype.begin()->lanes();
       if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
         preferred_lanes = lanes;
       }
@@ -1120,8 +1134,10 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
    * type as it is later accessed, with scalar element types.
    */
   VectorTypeAccessChecker(const Array<tir::Var>& params, const Map<Var, Buffer>& buffer_map,
-                          bool allow_untyped_pointers = false)
-      : allow_untyped_pointers_(allow_untyped_pointers) {
+                          bool allow_untyped_pointers = false,
+                          bool detect_scalar_read_patterns = true)
+      : allow_untyped_pointers_(allow_untyped_pointers),
+        detect_scalar_read_patterns_(detect_scalar_read_patterns) {
     // If a parameter is in the buffer map, we want to track the
     // version in the map.
     for (auto it : buffer_map) {
@@ -1145,12 +1161,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
   }
 
   void VisitExpr_(const BufferLoadNode* op) final {
-    OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices);
+    OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices, /*is_buffer_load=*/true);
     StmtExprVisitor::VisitExpr_(op);
   }
 
   void VisitStmt_(const BufferStoreNode* op) final {
-    OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices);
+    OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices, /*is_buffer_load=*/false);
     StmtExprVisitor::VisitStmt_(op);
   }
 
@@ -1159,7 +1175,10 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
       DataType dtype = op->args[0].dtype();
       const VarNode* buffer = op->args[1].as<VarNode>();
       PrimExpr index = op->args[2];
-      OnArrayAccess(dtype, buffer, {index});
+      OnArrayAccess(dtype, buffer, {index}, /*is_buffer_load=*/false);
+    } else if (op->op.same_as(builtin::address_of())) {
+      BufferLoad load = Downcast<BufferLoad>(op->args[0]);
+      OnArrayAccess(load->dtype, load->buffer->data.get(), load->indices, /*is_buffer_load=*/false);
     }
     StmtExprVisitor::VisitExpr_(op);
   }
@@ -1226,8 +1245,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
     if (element_dtype == DataType::Bool()) {
       element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
     }
-
-    info_map_[buffer.get()] = {buffer, element_dtype, extent, declaration_location};
+    info_map_[buffer.get()] = BufferVarInfo{buffer, element_dtype, extent, declaration_location};
   }
 
   /* Update the type map for a buffer based on its usage
@@ -1237,11 +1255,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
    *
    * @param buffer The VarNode representing the buffer.
    *
-   * @param index The index at which the value is being stored/loaded.
+   * @param indices The indices at which the value is being stored/loaded.
    *
-   * @param predicate The predicate used for the store/load.
+   * @param is_buffer_load Whether the access is a BufferLoad (only loads can be scalar reads).
    */
-  void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array<PrimExpr>& indices) {
+  void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array<PrimExpr>& indices,
+                     bool is_buffer_load) {
     auto it = info_map_.find(buffer);
     ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer
                                   << ") occurred before its declaration.";
@@ -1304,6 +1323,14 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
       }
     }
 
+    if (detect_scalar_read_patterns_ && is_buffer_load && indices.size()) {
+      const PrimExpr last_dim_index = indices[indices.size() - 1];
+      if (last_dim_index.dtype().lanes() == 1) {
+        arith::ModularSet me = analyzer_.modular_set(last_dim_index);
+        var_info.scalar_read_dtype.emplace(access_dtype.with_lanes(me->coeff));
+        return;
+      }
+    }
     var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used));
   }
 
@@ -1312,6 +1339,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
 
   //
   bool allow_untyped_pointers_{false};
+  // Whether to detect scalar read patterns for rewriting to vector shuffle
+  bool detect_scalar_read_patterns_{true};
 
   // internal analyzer
   arith::Analyzer analyzer_;
@@ -1366,7 +1395,8 @@ class VectorTypeRewriter : public StmtExprMutator {
   VectorTypeRewriter(const std::unordered_map<const VarNode*, BufferVarInfo>& info_map,
                      bool rewrite_params = true, bool rewrite_buffer_map = true,
                      bool rewrite_allocate_node = true, bool rewrite_indices = true,
-                     bool rewrite_let_node = true, bool rewrite_allocate_const_node = true)
+                     bool rewrite_let_node = true, bool rewrite_allocate_const_node = true,
+                     bool rewrite_scalar_read_to_vector_shuffle = true)
       : rewrite_indices_(rewrite_indices) {
     int rewrite_mask = 0;
     if (rewrite_params) {
@@ -1401,42 +1431,53 @@ class VectorTypeRewriter : public StmtExprMutator {
     }
   }
 
+  /*!
+   * \brief Mutator for BufferLoad or BufferStore.
+   * \return The rewritten node and the shuffle index (BufferLoad only). When the shuffle index
+   * is non-negative, the caller should emit a Shuffle that extracts that lane from the vector.
+   */
   template <typename Node>
-  Node VisitBufferAccess(Node node) {
+  std::pair<Node, int> VisitBufferAccess(Node node) {
+    int shuffle_index = -1;
     if (!rewrite_indices_) {
-      return node;
+      return {node, shuffle_index};
     }
 
     auto it = rewrite_map_.find(node->buffer->data.get());
     if (it == rewrite_map_.end()) {
-      return node;
+      return {node, shuffle_index};
     }
     const auto& info = it->second;
 
     Array<PrimExpr> indices = node->indices;
-
-    const RampNode* ramp_index = indices[indices.size() - 1].as<RampNode>();
-    if (ramp_index && is_one(ramp_index->stride)) {
+    const PrimExpr& last_dim_index = indices[indices.size() - 1];
+    if (const RampNode* ramp_index = last_dim_index.as<RampNode>();
+        ramp_index && is_one(ramp_index->stride)) {
       PrimExpr new_index =
           ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes);
       if (ramp_index->lanes != info.factor()) {
-        new_index = Ramp(new_index, ramp_index->stride, ramp_index->lanes / info.factor(),
-                         ramp_index->span);
+        ICHECK(info.factor() && ramp_index->lanes % info.factor() == 0);
+        int new_lanes = ramp_index->lanes / info.factor();
+        new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes, ramp_index->span);
       }
-
+      indices.Set(indices.size() - 1, new_index);
+    } else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) {
+      arith::ModularSet me = analyzer_.modular_set(last_dim_index);
+      ICHECK(me->coeff % info.factor() == 0) << "Scalar read does not map to a fixed vector lane";
+      PrimExpr new_index = last_dim_index / make_const(last_dim_index.dtype(), info.factor());
+      shuffle_index = static_cast<int>(me->base % info.factor());
       indices.Set(indices.size() - 1, new_index);
     }
 
     auto writer = node.CopyOnWrite();
     writer->buffer = RemapBuffer(node->buffer);
     writer->indices = indices;
-
-    return node;
+    return {node, shuffle_index};
   }
 
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
     auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
-    auto modified = VisitBufferAccess(node);
+    auto [modified, shuffle_index] = VisitBufferAccess(node);
 
     // Not needed for BufferStoreNode, so we can't just call
     // LegalizeDtype() in VisitBufferAccess.
@@ -1445,13 +1486,18 @@ class VectorTypeRewriter : public StmtExprMutator {
     } else {
       auto writer = modified.CopyOnWrite();
       writer->LegalizeDType();
+      if (shuffle_index >= 0) {
+        return Shuffle::ExtractElement(std::move(modified), shuffle_index);
+      }
       return std::move(modified);
     }
   }
 
   Stmt VisitStmt_(const BufferStoreNode* op) final {
     auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
-    return VisitBufferAccess(std::move(node));
+    auto [modified, shuffle_index] = VisitBufferAccess(std::move(node));
+    ICHECK(shuffle_index < 0);
+    return std::move(modified);
   }
 
   Stmt VisitStmt_(const LetStmtNode* op) final {
@@ -1627,6 +1673,7 @@ class VectorTypeRewriter : public StmtExprMutator {
   bool rewrite_indices_{true};
   std::unordered_map<const VarNode*, RewriteInfo> rewrite_map_;
   std::unordered_map<const BufferNode*, Buffer> buffer_map_;
+  arith::Analyzer analyzer_;
 };
 
 // Rewrite allocates, pointer parameters, and buffer map into vectorized versions
@@ -1635,13 +1682,15 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false
                                  bool rewrite_params = true, bool rewrite_buffer_map = true,
                                  bool rewrite_allocate_node = true, bool rewrite_indices = true,
                                  bool rewrite_let_node = true,
-                                 bool rewrite_allocate_const_node = true) {
-  VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers);
+                                 bool rewrite_allocate_const_node = true,
+                                 bool rewrite_scalar_read_to_vector_shuffle = true) {
+  VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers,
+                                  rewrite_scalar_read_to_vector_shuffle);
   checker(f->body);
 
   VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map,
                               rewrite_allocate_node, rewrite_indices, rewrite_let_node,
-                              rewrite_allocate_const_node);
+                              rewrite_allocate_const_node, rewrite_scalar_read_to_vector_shuffle);
   PrimFuncNode* n = f.CopyOnWrite();
   n->body = rewriter(std::move(n->body));
   rewriter.Finalize(&f);
@@ -1661,7 +1710,8 @@ Pass StorageRewrite() {
     // padded out to 32 bits) would require either rewriting
     // AllocateConst::data, or would require the code generators to
     // handle vectorized constants.
-    return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false);
+    return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false,
+                                   false);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
 }
diff --git a/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py b/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py
new file mode 100644
index 0000000000..7baa96c1a1
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_pointer_value_type_rewrite.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import te
+from tvm.driver.build_module import schedule_to_module
+from tvm.script import tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.tir.transform.PointerValueTypeRewrite()
+
+
+class TestRewriteToShuffle(BaseCompare):
+    @T.prim_func
+    def before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
+        A_local_data = T.allocate([16], "float32", scope="local")
+        A_local = T.Buffer((16,), "float32", data=A_local_data, scope="local")
+        for i in range(4):
+            A_local[i * 4 : i * 4 + 4] = A[i * 4 : i * 4 + 4]
+        for i in range(4):
+            B[i] = A_local[i * 4] + A_local[i * 4 + 1] + A_local[i * 4 + 2] + A_local[i * 4 + 3]
+
+    @T.prim_func
+    def expected(A: T.Buffer((4,), "float32x4"), B: T.Buffer((4,), "float32")):
+        A_local_data = T.allocate([4], "float32x4", scope="local")
+        A_local = T.Buffer((4,), "float32x4", data=A_local_data, scope="local")
+        for i in range(4):
+            A_local[T.Div(i * 4, 4)] = A[T.Div(i * 4, 4)]
+        for i in range(4):
+            B[i] = (
+                T.Shuffle([A_local[T.Div(i * 4, 4)]], [0])
+                + T.Shuffle([A_local[T.Div(i * 4 + 1, 4)]], [1])
+                + T.Shuffle([A_local[T.Div(i * 4 + 2, 4)]], [2])
+                + T.Shuffle([A_local[T.Div(i * 4 + 3, 4)]], [3])
+            )
+
+
+class TestAddressOf(BaseCompare):
+    @T.prim_func
+    def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+        for i in range(4):
+            T.evaluate(T.address_of(A[i * 4]))
+            B[i * 4 : i * 4 + 4] = A[i * 4 : i * 4 + 4]
+
+    @T.prim_func
+    def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32x4")):
+        for i in range(4):
+            T.evaluate(T.address_of(A[i * 4]))
+            B[T.Div(i * 4, 4)] = A[i * 4 : i * 4 + 4]
+
+
+class TestScalarReadWithoutWrite(BaseCompare):
+    @T.prim_func
+    def before(A: T.Buffer((16,), "float32")):
+        for i in range(4):
+            T.evaluate(A[i * 4])
+
+    expected = before

Reply via email to