This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8f7c34393a [Unity][TIR][Pass] ForceNarrowIndexToInt32 (#14203)
8f7c34393a is described below

commit 8f7c34393aef847814625cbb323159cdd39518c2
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 5 19:49:50 2023 -0500

    [Unity][TIR][Pass] ForceNarrowIndexToInt32 (#14203)
    
    [TIR][Pass] ForceNarrowIndexToInt32
    
    This PR introduces a pass that forces every index expression in a
    PrimFunc to have dtype int32. It also checks whether all integer
    buffers in the PrimFunc have dtype int32, and reports an error if any
    integer buffer has a dtype other than int32.
    
    In terms of implementation, this pass leverages the
    IndexDataTypeNormalizer, with the target dtype being int32.
    
    This PR contains a few basic tests adapted from
    `test_tir_transform_narrow_datatype.py`, as well as some negative
    tests.
---
 include/tvm/tir/transform.h                        |   8 +
 python/tvm/tir/transform/transform.py              |  15 ++
 src/tir/transforms/force_narrow_index_to_i32.cc    |  84 ++++++++
 ...test_tir_transform_force_narrow_index_to_i32.py | 220 +++++++++++++++++++++
 4 files changed, 327 insertions(+)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index fee5db0875..d212578b8d 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -336,6 +336,14 @@ TVM_DLL Pass CombineContextCall();
  */
 TVM_DLL Pass NarrowDataType(int target_bits);
 
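+/*!
+ * \brief Force every index expression in each PrimFunc to have int32 dtype.
+ *
+ * Integer buffers are not narrowed: the pass reports an error if any integer buffer
+ * (in the function buffer map or allocated by a block) has a dtype other than int32.
+ *
+ * \return The pass.
+ * \note This pass should not be used in default cases.
+ */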
+TVM_DLL Pass ForceNarrowIndexToInt32();
+
 /*!
  * \brief Legalize bf16 typed Ops. Add a cast to fp32
  *   before Ops, then add a cast back to bf16.
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index bc3ec5b2ad..a6e0cf06cb 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -611,6 +611,21 @@ def NarrowDataType(target_bits: int):
     return _ffi_api.NarrowDataType(target_bits)  # type: ignore
 
 
+def ForceNarrowIndexToInt32():
+    """Force every index expression in each PrimFunc to have dtype int32.
+
+    Integer buffers are not narrowed. If an integer buffer of a PrimFunc (in
+    its buffer map or allocated by a block) has a dtype other than int32,
+    applying the pass raises a TVMError instead.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+
+    Note
+    ----
+    This pass should not be used in default cases.
+    """
+    return _ffi_api.ForceNarrowIndexToInt32()  # type: ignore
+
+
 def VerifyMemory():
     """Verify if func contains illegal host side direct memory access.
 
diff --git a/src/tir/transforms/force_narrow_index_to_i32.cc b/src/tir/transforms/force_narrow_index_to_i32.cc
new file mode 100644
index 0000000000..70dc554e12
--- /dev/null
+++ b/src/tir/transforms/force_narrow_index_to_i32.cc
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file force_narrow_index_to_i32.cc
+ * \brief Force every index expression to int32 dtype, and reject PrimFuncs that
+ *        contain integer buffers whose dtype is not int32.
+ * \note This pass is not used in default cases.
+ */
+
+#include <tvm/tir/data_type_rewriter.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Rewrites the index expressions of a PrimFunc to int32.
+ *
+ * Built on IndexDataTypeNormalizer with int32 as the target index dtype. Buffers whose
+ * dtype satisfies `is_int()` are not narrowed; if such a buffer (in the function buffer
+ * map, or allocated by a block) does not have 32 bits, the rewrite fails via LOG(FATAL).
+ */
+class Int32DTypeNarrower : public IndexDataTypeNormalizer {
+ public:
+  /*!
+   * \brief Validate the parameter buffers of \p func, then rewrite its indices to int32.
+   * \param func The PrimFunc to rewrite.
+   * \return The rewritten PrimFunc.
+   */
+  static PrimFunc RewriteDataType(PrimFunc func) {
+    // This pass only narrows indices; integer parameter buffers must already be int32.
+    for (const auto& kv : func->buffer_map) {
+      const Buffer& buffer = kv.second;
+      if (buffer->dtype.is_int() && buffer->dtype.bits() != 32) {
+        LOG(FATAL) << "The buffer " << buffer << " in the function buffer map has dtype "
+                   << buffer->dtype << ". The function is " << func;
+      }
+    }
+
+    Int32DTypeNarrower narrower(func);
+    return narrower.Rewrite(func);
+  }
+
+ private:
+  explicit Int32DTypeNarrower(PrimFunc func)
+      : IndexDataTypeNormalizer(DataType::Int(32)), func_(std::move(func)) {}
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block block = Downcast<Block>(IndexDataTypeNormalizer::VisitStmt_(op));
+    // Integer buffers allocated by a block follow the same int32-only rule as parameters.
+    for (const Buffer& buffer : block->alloc_buffers) {
+      if (buffer->dtype.is_int() && buffer->dtype.bits() != 32) {
+        LOG(FATAL) << "The buffer " << buffer << " allocated in the function has dtype "
+                   << buffer->dtype << ". The function is " << func_;
+      }
+    }
+    return block;
+  }
+
+  /*! \brief The function being rewritten; used only in error messages. */
+  PrimFunc func_;
+};
+
+/*!
+ * \brief Narrow every index expression of \p func to int32.
+ * \param func The PrimFunc to rewrite.
+ * \return The rewritten PrimFunc.
+ */
+PrimFunc ForceNarrowIndexToInt32(PrimFunc func) {
+  PrimFunc narrowed = Int32DTypeNarrower::RewriteDataType(std::move(func));
+  return narrowed;
+}
+
+namespace transform {
+
+Pass ForceNarrowIndexToInt32() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    // Called with a PrimFunc, this resolves to the tir-level
+    // ForceNarrowIndexToInt32(PrimFunc), not to this Pass factory.
+    return ForceNarrowIndexToInt32(f);
+  };
+  // Register the pass under its own name. Reusing "tir.NarrowDataType" would make it
+  // indistinguishable from NarrowDataType wherever passes are identified by name
+  // (e.g. PassContext's disabled_pass list or pass instruments).
+  return CreatePrimFuncPass(pass_func, 0, "tir.ForceNarrowIndexToInt32", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.ForceNarrowIndexToInt32")
+    .set_body_typed(ForceNarrowIndexToInt32);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py b/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
new file mode 100644
index 0000000000..f275d438a7
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
@@ -0,0 +1,220 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+from tvm import TVMError
+from tvm.script import tir as T
+import tvm.testing
+
+
+def test_thread_axis1():
+    # int64 buffer shapes, thread extents and index arithmetic are all rewritten to int32,
+    # and the int64 casts of the thread variables disappear from the expected output.
+    @T.prim_func
+    def before(A: T.Buffer((T.int64(64),), "float32"), B: T.Buffer((T.int64(64),), "float32")):
+        blockIdx_x = T.env_thread("blockIdx.x")
+        T.launch_thread(blockIdx_x, T.int64(2))
+        threadIdx_x = T.env_thread("threadIdx.x")
+        T.launch_thread(threadIdx_x, T.int64(32))
+        B[T.Cast("int64", blockIdx_x) * T.int64(32) + T.Cast("int64", threadIdx_x)] = A[
+            T.Cast("int64", blockIdx_x) * T.int64(32) + T.Cast("int64", threadIdx_x)
+        ] + T.float32(1)
+
+    @T.prim_func
+    def expected(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")):
+        blockIdx_x = T.env_thread("blockIdx.x")
+        T.launch_thread(blockIdx_x, 2)
+        threadIdx_x = T.env_thread("threadIdx.x")
+        T.launch_thread(threadIdx_x, 32)
+        B[blockIdx_x * 32 + threadIdx_x] = A[blockIdx_x * 32 + threadIdx_x] + T.float32(1)
+
+    mod = tvm.IRModule.from_expr(before)
+    func = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"]
+    tvm.ir.assert_structural_equal(func, expected)
+
+
+def test_thread_axis2():
+    # Buffer shapes mix int32 and int64 extents, and the loops/block bindings are int64.
+    # After the pass everything is int32, so the explicit cast to "int32" on ax3 in
+    # `before` is absent from `expected`.
+    @T.prim_func
+    def before(
+        T_reshape: T.Buffer((1, 12, 384, 384), "float32"),
+        placeholder_1: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "bool"),
+        T_where: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "float32"),
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
+            for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
+                for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
+                    with T.block("T_where"):
+                        ax0 = T.axis.spatial(T.int64(1), T.int64(0))
+                        ax1 = T.axis.spatial(
+                            T.int64(12),
+                            (
+                                (i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1)
+                                * T.int64(1024)
+                                + i0_i1_i2_i3_fused_2
+                            )
+                            % T.int64(1769472)
+                            // T.int64(147456),
+                        )
+                        ax2 = T.axis.spatial(
+                            T.int64(384),
+                            (
+                                (i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1)
+                                * T.int64(1024)
+                                + i0_i1_i2_i3_fused_2
+                            )
+                            % T.int64(147456)
+                            // T.int64(384),
+                        )
+                        ax3 = T.axis.spatial(
+                            384,
+                            T.cast(
+                                (
+                                    (i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1)
+                                    * T.int64(1024)
+                                    + i0_i1_i2_i3_fused_2
+                                )
+                                % T.int64(384),
+                                "int32",
+                            ),
+                        )
+                        T.where(
+                            (i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1)
+                            * T.int64(1024)
+                            + i0_i1_i2_i3_fused_2
+                            < T.int64(1769472)
+                        )
+                        T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3])
+                        T.writes(T_where[ax0, ax1, ax2, ax3])
+                        T_where[ax0, ax1, ax2, ax3] = T.Select(
+                            T.cast(placeholder_1[ax0, ax1, ax2, ax3], "int32") != 0,
+                            T.float32(-1000000000),
+                            T_reshape[ax0, ax1, ax2, ax3],
+                        )
+
+    @T.prim_func
+    def expected(
+        T_reshape: T.Buffer((1, 12, 384, 384), "float32"),
+        placeholder_1: T.Buffer((1, 12, 384, 384), "bool"),
+        T_where: T.Buffer((1, 12, 384, 384), "float32"),
+    ):
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i0_i1_i2_i3_fused_1 in T.thread_binding(256, thread="blockIdx.x"):
+            for i0_i1_i2_i3_fused_2 in T.thread_binding(1024, thread="threadIdx.x"):
+                for i0_i1_i2_i3_fused_0 in range(7):
+                    with T.block("T_where"):
+                        ax0 = T.axis.spatial(1, 0)
+                        ax1 = T.axis.spatial(
+                            12,
+                            (
+                                (i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024
+                                + i0_i1_i2_i3_fused_2
+                            )
+                            % 1769472
+                            // 147456,
+                        )
+                        ax2 = T.axis.spatial(
+                            384,
+                            (
+                                (i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024
+                                + i0_i1_i2_i3_fused_2
+                            )
+                            % 147456
+                            // 384,
+                        )
+                        ax3 = T.axis.spatial(
+                            384,
+                            (
+                                (i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024
+                                + i0_i1_i2_i3_fused_2
+                            )
+                            % 384,
+                        )
+                        T.where(
+                            (i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024
+                            + i0_i1_i2_i3_fused_2
+                            < 1769472
+                        )
+                        T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3])
+                        T.writes(T_where[ax0, ax1, ax2, ax3])
+                        T_where[ax0, ax1, ax2, ax3] = T.Select(
+                            T.Cast("int32", placeholder_1[ax0, ax1, ax2, ax3]) != 0,
+                            T.float32(-1000000000),
+                            T_reshape[ax0, ax1, ax2, ax3],
+                        )
+
+    mod = tvm.IRModule.from_expr(before)
+    func = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"]
+    tvm.ir.assert_structural_equal(func, expected)
+
+
+def test_block():
+    # int64 loop extents and int64 block iter bindings over int32-shaped buffers
+    # are rewritten to int32.
+    @T.prim_func
+    def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
+        for i in T.serial(0, T.int64(16)):
+            for j in T.serial(0, T.int64(8)):
+                with T.block():
+                    vi = T.axis.spatial(T.int64(128), i * T.int64(8) + j)
+                    B[vi] = A[vi] + T.float32(1)
+
+    @T.prim_func
+    def expected(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
+        for i in T.serial(0, T.int32(16)):
+            for j in T.serial(0, T.int32(8)):
+                with T.block():
+                    vi = T.axis.spatial(T.int32(128), i * T.int32(8) + j)
+                    B[vi] = A[vi] + T.float32(1)
+
+    mod = tvm.IRModule.from_expr(before)
+    func = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"]
+    tvm.ir.assert_structural_equal(func, expected)
+
+
+def test_fail_on_buffer_map():
+    # The pass must reject a PrimFunc whose parameter (buffer map) buffers hold int64 data.
+    @T.prim_func
+    def func(A: T.Buffer((128,), "int64"), B: T.Buffer((128,), "int64")):
+        for i in T.serial(0, 16):
+            for j in T.serial(0, 8):
+                with T.block():
+                    vi = T.axis.spatial(128, i * 8 + j)
+                    B[vi] = A[vi] + T.int64(1)
+
+    mod = tvm.IRModule.from_expr(func)
+    with pytest.raises(TVMError):
+        tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"]
+
+
+def test_fail_on_buffer_map():
+    @T.prim_func
+    def func(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")):
+        C = T.alloc_buffer((128,), "int64")
+        for i in T.serial(0, 16):
+            for j in T.serial(0, 8):
+                with T.block():
+                    vi = T.axis.spatial(128, i * 8 + j)
+                    C[vi] = T.cast(A[vi], "int64") + T.int64(1)
+        for i in T.serial(0, 16):
+            for j in T.serial(0, 8):
+                with T.block():
+                    vi = T.axis.spatial(128, i * 8 + j)
+                    B[vi] = T.cast(C[vi] + T.int64(1), "int32")
+
+    mod = tvm.IRModule.from_expr(func)
+    with pytest.raises(TVMError):
+        tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"]
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()

Reply via email to