This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ccdb45844 [TIR] Refactor division simplification in RewriteSimplifier 
(#18319)
6ccdb45844 is described below

commit 6ccdb45844605a38a018c0aadb2807f1b765593c
Author: Lei Wang <[email protected]>
AuthorDate: Sun Oct 19 04:57:49 2025 +0800

    [TIR] Refactor division simplification in RewriteSimplifier (#18319)
    
    * Refactor division simplification in RewriteSimplifier and add 
corresponding test
    
    This commit removes the specific case for rewriting division by a constant 
float in the RewriteSimplifier. Additionally, a new test is introduced to 
verify the behavior of float division simplification, ensuring that the 
division is correctly handled without the previous rewrite logic.
    
    
    * test fix
    
    * test fix
    
    * CI fix
    
    * fix
---
 src/arith/rewrite_simplify.cc                      |   7 -
 tests/python/arith/test_arith_simplify.py          |  12 +
 tests/python/relax/test_codegen_cudnn.py           |   4 +-
 tests/python/relax/test_op_create.py               |   2 +-
 .../python/relax/test_transform_legalize_ops_nn.py | 296 ++++++++++-----------
 .../relax/test_transform_legalize_ops_qdq.py       |   4 +-
 ...st_transform_legalize_ops_search_statistical.py |  14 +-
 7 files changed, 170 insertions(+), 169 deletions(-)

diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index e333f85a32..65b6e408e2 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -774,13 +774,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
DivNode* op) {
   // Pattern var for lanes in broadcast and ramp
   PVar<PrimExpr> lanes;
 
-  // x / 2.0 = x * 0.5
-  if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
-    ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() ||
-           datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
-    return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
-  }
-
   // Vector rules
   if (op->dtype.is_scalable_or_fixed_length_vector()) {
     // NOTE: use div as the pattern also works for float.
diff --git a/tests/python/arith/test_arith_simplify.py 
b/tests/python/arith/test_arith_simplify.py
index 5a61cb8a52..161548a7a1 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -21,6 +21,7 @@ import tvm
 import tvm.testing
 from tvm import tir
 from tvm.script import tir as T
+import tvm.ir
 
 
 def test_simplify_reshape_flattened_index():
@@ -144,5 +145,16 @@ def test_simplify_floor_mod_with_linear_offset():
     assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0)
 
 
+def test_simplify_float_division():
+    # Test for the discussion:
+    # https://discuss.tvm.apache.org/t/discuss-is-constant-division-to-multiplication-rewrite-in-tvm-necessary/18615
+    ana = tvm.arith.Analyzer()
+    x = tir.Var("x", "float32")
+    ry = x / 27
+    # In the old version, the division would be rewritten into x * T.float32(1 / 27).
+    sy = ana.rewrite_simplify(ry)
+    tvm.ir.assert_structural_equal(ry, sy)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_codegen_cudnn.py 
b/tests/python/relax/test_codegen_cudnn.py
index 10ba775a6d..f066ad1a69 100644
--- a/tests/python/relax/test_codegen_cudnn.py
+++ b/tests/python/relax/test_codegen_cudnn.py
@@ -193,7 +193,9 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, 
with_bias, activation):
     out = get_result_with_relax_cudnn_offload(mod, args)
     ref = build_and_run(mod, args, "llvm", legalize=True)
     if dtype == "float16":
-        tvm.testing.assert_allclose(out, ref, rtol=1e-1, atol=1e-1)
+        # FIXME(lei): tolerance is currently raised to 3e-1 to prevent a flaky test
+        # see https://github.com/apache/tvm/pull/18319
+        tvm.testing.assert_allclose(out, ref, rtol=3e-1, atol=3e-1)
     else:
         tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
diff --git a/tests/python/relax/test_op_create.py 
b/tests/python/relax/test_op_create.py
index d6e0a5e239..7269dfdbcf 100644
--- a/tests/python/relax/test_op_create.py
+++ b/tests/python/relax/test_op_create.py
@@ -661,7 +661,7 @@ def test_arange_infer_struct_info_shape_var():
     _check_inference(
         bb,
         relax.op.arange(start, stop, 2),
-        relax.TensorStructInfo((T.cast(T.ceil((stop - start) * 0.5), 
"int64"),), "float32"),
+        relax.TensorStructInfo((T.cast(T.ceil((stop - start) / 2), "int64"),), 
"float32"),
     )
     _check_inference(
         bb,
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py 
b/tests/python/relax/test_transform_legalize_ops_nn.py
index ff03ab4152..de2f183a10 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -949,7 +949,7 @@ def test_adaptive_avg_pool2d():
                     T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4])
                     T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4])
                     
T.block_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"})
-                    adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4] = 
adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] * T.float32(0.020408163265306121)
+                    adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4] = 
adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] / T.float32(49.0)
     # fmt: on
 
     mod = LegalizeOps()(AdaptiveAvgPool2D)
@@ -1104,15 +1104,14 @@ def test_leakyrelu():
             return gv
 
         @T.prim_func(private=True)
-        def leaky_relu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+        def leaky_relu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), 
compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
             T.func_attr({"tir.noalias": True})
             for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.Select(T.float32(0) < 
rxplaceholder[i0_1, i1_1], rxplaceholder[i0_1, i1_1], \
-                                                   rxplaceholder[i0_1, i1_1] * 
T.float32(0.02))
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, 
v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * T.float32(0.02))
     # fmt: on
 
     mod = LegalizeOps()(LeakyRelu)
@@ -1140,19 +1139,17 @@ def test_leakyrelu_symbolic():
             return gv
 
         @T.prim_func(private=True)
-        def leaky_relu(var_rxplaceholder: T.handle, var_compute: T.handle):
+        def leaky_relu(var_x: T.handle, var_compute: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
-            compute = T.match_buffer(var_compute, [m, n], dtype="float32")
+            m, n = T.int64(), T.int64()
+            x = T.match_buffer(var_x, (m, n))
+            compute = T.match_buffer(var_compute, (m, n))
             for i0, i1 in T.grid(m, n):
                 with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.Select(T.float32(0) < 
rxplaceholder[i0_1, i1_1], rxplaceholder[i0_1, i1_1], \
-                                                    rxplaceholder[i0_1, i1_1] 
* T.float32(0.03))
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, 
v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * T.float32(0.029999999999999999))
     # fmt: on
 
     mod = LegalizeOps()(LeakyRelu)
@@ -1259,42 +1256,42 @@ def test_gelu():
             return gv
 
         @T.prim_func(private=True)
-        def gelu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), 
T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+        def gelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: 
T.Buffer((T.int64(2), T.int64(3)), "float32")):
             T.func_attr({"tir.noalias": True})
-            T_multiply_1 = T.alloc_buffer([T.int64(2), T.int64(3)], 
dtype="float32")
-            compute = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32")
-            T_multiply_2 = T.alloc_buffer([T.int64(2), T.int64(3)], 
dtype="float32")
-            T_divide = T.alloc_buffer([T.int64(2), T.int64(3)], 
dtype="float32")
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+            T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
+            compute = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_multiply_2 = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_multiply"):
-                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[ax0, ax1])
-                    T.writes(T_multiply_1[ax0, ax1])
-                    T_multiply_1[ax0, ax1] = rxplaceholder[ax0, ax1] * 
T.float32(0.70710678118654757)
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1])
+                    T.writes(T_multiply_1[v_ax0, v_ax1])
+                    T_multiply_1[v_ax0, v_ax1] = x[v_ax0, v_ax1] * 
T.float32(0.70710678118654757)
             for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(T_multiply_1[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.erf(T_multiply_1[i0_1, i1_1], 
dtype="float32")
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(T_multiply_1[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1])
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_multiply_1"):
-                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(compute[ax0, ax1])
-                    T.writes(T_multiply_2[ax0, ax1])
-                    T_multiply_2[ax0, ax1] = compute[ax0, ax1] * T.float32(0.5)
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("T_divide"):
-                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(T_multiply_2[ax0, ax1])
-                    T.writes(T_divide[ax0, ax1])
-                    T_divide[ax0, ax1] = T.float32(0.5) + T_multiply_2[ax0, 
ax1]
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(compute[v_ax0, v_ax1])
+                    T.writes(T_multiply_2[v_ax0, v_ax1])
+                    T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * 
T.float32(0.5)
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_multiply_2[v_ax0, v_ax1])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, 
v_ax1]
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_multiply_2"):
-                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[ax0, ax1], T_divide[ax0, ax1])
-                    T.writes(T_multiply[ax0, ax1])
-                    T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * 
T_divide[ax0, ax1]
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
+                    T.writes(T_multiply[v_ax0, v_ax1])
+                    T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * T_add[v_ax0, 
v_ax1]
     # fmt: on
 
     mod = LegalizeOps()(Gelu)
@@ -1322,46 +1319,45 @@ def test_gelu_symbolic():
             return gv
 
         @T.prim_func(private=True)
-        def gelu(var_rxplaceholder: T.handle, var_T_multiply: T.handle):
+        def gelu(var_x: T.handle, var_T_multiply: T.handle):
             T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], 
dtype="float32")
-            T_multiply = T.match_buffer(var_T_multiply, [m, n], 
dtype="float32")
-            T_multiply_1 = T.alloc_buffer([m, n], dtype="float32")
-            compute = T.alloc_buffer([m, n], dtype="float32")
-            T_multiply_2 = T.alloc_buffer([m, n], dtype="float32")
-            T_add = T.alloc_buffer([m, n], dtype="float32")
-            for i0, i1 in T.grid(m, n):
+            m, n = T.int64(), T.int64()
+            x = T.match_buffer(var_x, (m, n))
+            T_multiply = T.match_buffer(var_T_multiply, (m, n))
+            T_multiply_1 = T.alloc_buffer((m, n))
+            compute = T.alloc_buffer((m, n))
+            T_multiply_2 = T.alloc_buffer((m, n))
+            T_add = T.alloc_buffer((m, n))
+            for ax0, ax1 in T.grid(m, n):
                 with T.block("T_multiply"):
-                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[ax0, ax1])
-                    T.writes(T_multiply_1[ax0, ax1])
-                    T_multiply_1[ax0, ax1] = rxplaceholder[ax0, ax1] * 
T.float32(0.70710678118654757)
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1])
+                    T.writes(T_multiply_1[v_ax0, v_ax1])
+                    T_multiply_1[v_ax0, v_ax1] = x[v_ax0, v_ax1] * 
T.float32(0.70710678118654757)
             for i0, i1 in T.grid(m, n):
                 with T.block("compute"):
-                    i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(T_multiply_1[i0_1, i1_1])
-                    T.writes(compute[i0_1, i1_1])
-                    compute[i0_1, i1_1] = T.erf(T_multiply_1[i0_1, i1_1], 
dtype="float32")
-            for i0, i1 in T.grid(m, n):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(T_multiply_1[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1])
+            for ax0, ax1 in T.grid(m, n):
                 with T.block("T_multiply_1"):
-                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(compute[ax0, ax1])
-                    T.writes(T_multiply_2[ax0, ax1])
-                    T_multiply_2[ax0, ax1] = compute[ax0, ax1] * T.float32(0.5)
-            for i0, i1 in T.grid(m, n):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(compute[v_ax0, v_ax1])
+                    T.writes(T_multiply_2[v_ax0, v_ax1])
+                    T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * 
T.float32(0.5)
+            for ax0, ax1 in T.grid(m, n):
                 with T.block("T_add"):
-                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(T_multiply_2[ax0, ax1])
-                    T.writes(T_add[ax0, ax1])
-                    T_add[ax0, ax1] = T.float32(0.5) + T_multiply_2[ax0, ax1]
-            for i0, i1 in T.grid(m, n):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_multiply_2[v_ax0, v_ax1])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, 
v_ax1]
+            for ax0, ax1 in T.grid(m, n):
                 with T.block("T_multiply_2"):
-                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[ax0, ax1], T_add[ax0, ax1])
-                    T.writes(T_multiply[ax0, ax1])
-                    T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * 
T_add[ax0, ax1]
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
+                    T.writes(T_multiply[v_ax0, v_ax1])
+                    T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * T_add[v_ax0, 
v_ax1]
     # fmt: on
 
     mod = LegalizeOps()(Gelu)
@@ -1887,29 +1883,29 @@ def test_cross_entropy_with_logits():
             return gv
 
         @T.prim_func(private=True)
-        def cross_entropy_with_logits(rxplaceholder: T.Buffer(T.int64(3), 
"float32"), rxplaceholder_1: T.Buffer(T.int64(3), "float32"), T_multiply: 
T.Buffer((), "float32")):
+        def cross_entropy_with_logits(x: T.Buffer((T.int64(3),), "float32"), 
y: T.Buffer((T.int64(3),), "float32"), T_multiply: T.Buffer((), "float32")):
             T.func_attr({"tir.noalias": True})
-            T_multiply_1 = T.alloc_buffer([T.int64(3)], dtype="float32")
-            T_multiply_red = T.alloc_buffer([], dtype="float32")
-            for i0 in T.serial(T.int64(3)):
+            T_multiply_1 = T.alloc_buffer((T.int64(3),))
+            T_multiply_red = T.alloc_buffer(())
+            for ax0 in range(T.int64(3)):
                 with T.block("T_multiply"):
-                    ax0 = T.axis.spatial(T.int64(3), i0)
-                    T.reads(rxplaceholder[ax0], rxplaceholder_1[ax0])
-                    T.writes(T_multiply_1[ax0])
-                    T_multiply_1[ax0] = rxplaceholder[ax0] * 
rxplaceholder_1[ax0]
-            for i0 in T.serial(T.int64(3)):
+                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                    T.reads(x[v_ax0], y[v_ax0])
+                    T.writes(T_multiply_1[v_ax0])
+                    T_multiply_1[v_ax0] = x[v_ax0] * y[v_ax0]
+            for k0 in range(T.int64(3)):
                 with T.block("T_multiply_red"):
-                    k0 = T.axis.reduce(T.int64(3), i0)
-                    T.reads(T_multiply_1[k0])
+                    v_k0 = T.axis.reduce(T.int64(3), k0)
+                    T.reads(T_multiply_1[v_k0])
                     T.writes(T_multiply_red[()])
                     with T.init():
-                        T_multiply_red[()] = T.float32(0)
-                    T_multiply_red[()] = T_multiply_red[()] + T_multiply_1[k0]
+                        T_multiply_red[()] = T.float32(0.0)
+                    T_multiply_red[()] = T_multiply_red[()] + 
T_multiply_1[v_k0]
             with T.block("T_multiply_1"):
                 vi = T.axis.spatial(1, T.int64(0))
                 T.reads(T_multiply_red[()])
                 T.writes(T_multiply[()])
-                T_multiply[()] = T_multiply_red[()] * T.float32(-1)
+                T_multiply[()] = T_multiply_red[()] * T.float32(-1.0)
     # fmt: on
 
     mod = LegalizeOps()(CrossEntropyWithLogits)
@@ -1933,35 +1929,35 @@ def test_cross_entropy_with_logits_batch():
             return gv
 
         @T.prim_func(private=True)
-        def cross_entropy_with_logits(rxplaceholder: T.Buffer((T.int64(2), 
T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), T_divide: T.Buffer((), "float32")):
+        def cross_entropy_with_logits(x: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), y: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: 
T.Buffer((), "float32")):
             T.func_attr({"tir.noalias": True})
-            T_multiply = T.alloc_buffer([T.int64(2), T.int64(3)], 
dtype="float32")
-            T_multiply_red = T.alloc_buffer([], dtype="float32")
-            T_multiply_1 = T.alloc_buffer([], dtype="float32")
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+            T_multiply = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_multiply_red = T.alloc_buffer(())
+            T_multiply_1 = T.alloc_buffer(())
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_multiply"):
-                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, ax1])
-                    T.writes(T_multiply[ax0, ax1])
-                    T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * 
rxplaceholder_1[ax0, ax1]
-            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1], y[v_ax0, v_ax1])
+                    T.writes(T_multiply[v_ax0, v_ax1])
+                    T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * y[v_ax0, 
v_ax1]
+            for k0, k1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_multiply_red"):
-                    k0, k1 = T.axis.remap("RR", [i0, i1])
-                    T.reads(T_multiply[k0, k1])
+                    v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
+                    T.reads(T_multiply[v_k0, v_k1])
                     T.writes(T_multiply_red[()])
                     with T.init():
-                        T_multiply_red[()] = T.float32(0)
-                    T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0, 
k1]
+                        T_multiply_red[()] = T.float32(0.0)
+                    T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, 
v_k1]
             with T.block("T_multiply_1"):
                 vi = T.axis.spatial(1, T.int64(0))
                 T.reads(T_multiply_red[()])
                 T.writes(T_multiply_1[()])
-                T_multiply_1[()] = T_multiply_red[()] * T.float32(-1)
+                T_multiply_1[()] = T_multiply_red[()] * T.float32(-1.0)
             with T.block("T_divide"):
                 vi = T.axis.spatial(1, T.int64(0))
                 T.reads(T_multiply_1[()])
                 T.writes(T_divide[()])
-                T_divide[()] = T_multiply_1[()] * T.float32(0.5)
+                T_divide[()] = T_multiply_1[()] / T.float32(2)
     # fmt: on
 
     mod = LegalizeOps()(CrossEntropyWithLogits)
@@ -1987,34 +1983,33 @@ def test_cross_entropy_with_logits_batch_symbolic():
             return gv
 
         @T.prim_func(private=True)
-        def cross_entropy_with_logits(var_rxplaceholder: T.handle, 
var_rxplaceholder_1: T.handle, T_divide: T.Buffer((), "float32")):
+        def cross_entropy_with_logits(var_x: T.handle, var_y: T.handle, 
T_divide: T.Buffer((), "float32")):
             T.func_attr({"tir.noalias": True})
-            m = T.int64()
-            n = T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], 
dtype="float32")
-            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], 
dtype="float32")
-            T_multiply = T.alloc_buffer([n, m], dtype="float32")
-            T_multiply_red = T.alloc_buffer([], dtype="float32")
-            T_multiply_1 = T.alloc_buffer([], dtype="float32")
+            m, n = T.int64(), T.int64()
+            x = T.match_buffer(var_x, (n, m))
+            y = T.match_buffer(var_y, (n, m))
+            T_multiply = T.alloc_buffer((n, m))
+            T_multiply_red = T.alloc_buffer(())
+            T_multiply_1 = T.alloc_buffer(())
             for ax0, ax1 in T.grid(n, m):
                 with T.block("T_multiply"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(rxplaceholder[v_ax0, v_ax1], 
rxplaceholder_1[v_ax0, v_ax1])
+                    T.reads(x[v_ax0, v_ax1], y[v_ax0, v_ax1])
                     T.writes(T_multiply[v_ax0, v_ax1])
-                    T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * 
rxplaceholder_1[v_ax0, v_ax1]
+                    T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * y[v_ax0, 
v_ax1]
             for k0, k1 in T.grid(n, m):
                 with T.block("T_multiply_red"):
                     v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
                     T.reads(T_multiply[v_k0, v_k1])
                     T.writes(T_multiply_red[()])
                     with T.init():
-                        T_multiply_red[()] = T.float32(0)
+                        T_multiply_red[()] = T.float32(0.0)
                     T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, 
v_k1]
             with T.block("T_multiply_1"):
                 vi = T.axis.spatial(1, T.int64(0))
                 T.reads(T_multiply_red[()])
                 T.writes(T_multiply_1[()])
-                T_multiply_1[()] = T_multiply_red[()] * T.float32(-1)
+                T_multiply_1[()] = T_multiply_red[()] * T.float32(-1.0)
             with T.block("T_divide"):
                 vi = T.axis.spatial(1, T.int64(0))
                 T.reads(T_multiply_1[()])
@@ -2217,7 +2212,7 @@ def test_batch_norm():
                         v_ax0 = T.axis.spatial(T.int64(3), ax0)
                         T.reads(x_red[v_ax0])
                         T.writes(T_divide_1[v_ax0])
-                        T_divide_1[v_ax0] = x_red[v_ax0] * 
T.float32(0.00063775510204081628)
+                        T_divide_1[v_ax0] = x_red[v_ax0] / T.float32(1568)
                 for ax0 in range(T.int64(3)):
                     with T.block("T_multiply_2"):
                         v_ax0 = T.axis.spatial(T.int64(3), ax0)
@@ -2303,7 +2298,7 @@ def test_batch_norm():
                         v_ax0 = T.axis.spatial(T.int64(3), ax0)
                         T.reads(T_multiply_red[v_ax0])
                         T.writes(T_divide_2[v_ax0])
-                        T_divide_2[v_ax0] = T_multiply_red[v_ax0] * 
T.float32(0.00063775510204081628)
+                        T_divide_2[v_ax0] = T_multiply_red[v_ax0] / 
T.float32(1568)
                 for ax0 in range(T.int64(3)):
                     with T.block("T_multiply_5"):
                         v_ax0 = T.axis.spatial(T.int64(3), ax0)
@@ -2676,7 +2671,7 @@ def test_layer_norm():
                     ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                     T.reads(rxplaceholder[ax0, ax1, ax2, ax3], 
rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], 
rxplaceholder_1[ax2, ax3], rxplaceholder_2[ax2, ax3])
                     T.writes(T_layer_norm[ax0, ax1, ax2, ax3])
-                    T_layer_norm[ax0, ax1, ax2, ax3] = (rxplaceholder[ax0, 
ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) * 
T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] * T.float32(0.05) - 
rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05) * 
(rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) + T.float32(1e-05), 
dtype="float32") * rxplaceholder_1[ax2, ax3] + rxplaceholder_2[ax2, ax3]
+                    T_layer_norm[ax0, ax1, ax2, ax3] = (rxplaceholder[ax0, 
ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20)) * 
T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] / T.float32(20) - 
rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20) * 
(rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20)) + T.float32(1e-05), 
dtype="float32") * rxplaceholder_1[ax2, ax3] + rxplaceholder_2[ax2, ax3]
     # fmt: on
     mod = LegalizeOps()(LayerNorm)
     tvm.ir.assert_structural_equal(mod, Expected)
@@ -2720,7 +2715,7 @@ def test_layer_norm_1d():
                     v_ax0 = T.axis.spatial(T.int64(3), ax0)
                     T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()], 
layer_norm_weight[v_ax0], layer_norm_bias[v_ax0])
                     T.writes(T_layer_norm[v_ax0])
-                    T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] * 
T.float32(0.33333333333333331)) * T.rsqrt(x_red_temp_v1[()] * 
T.float32(0.33333333333333331) - x_red_temp_v0[()] * 
T.float32(0.33333333333333331) * (x_red_temp_v0[()] * 
T.float32(0.33333333333333331)) + T.float32(1.0000000000000001e-05)) * 
layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0]
+                    T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] / 
T.float32(3)) * T.rsqrt(x_red_temp_v1[()] / T.float32(3) - x_red_temp_v0[()] / 
T.float32(3) * (x_red_temp_v0[()] / T.float32(3)) + 
T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] + 
layer_norm_bias[v_ax0]
 
         @R.function
         def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: 
R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), 
dtype="float32")) -> R.Tensor((3,), dtype="float32"):
@@ -2911,7 +2906,7 @@ def test_group_norm():
                     v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", 
[ax0, ax1, ax2, ax3, ax4])
                     T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], 
rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, 
v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2])
                     T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
-                    T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = 
(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - 
rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * 
T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * 
T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * 
T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * 
T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05)) * 
T_reshape_2[v_ax1, v [...]
+                    T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = 
(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - 
rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) * 
T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / T.float32(40) - 
rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40) * 
(rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) + 
T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + 
T_reshape_3[v_ax1, v_ax2]
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), 
T.int64(4), T.int64(5)):
                 with T.block("T_reshape_3"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
@@ -2996,7 +2991,7 @@ def test_group_norm_fp16():
                     v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", 
[ax0, ax1, ax2, ax3, ax4])
                     T.reads(T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], 
rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, 
v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2])
                     T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
-                    T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = 
T.Cast("float16", (T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - 
rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * 
T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * 
T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * 
T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * 
T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05))) * T_resh 
[...]
+                    T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = 
T.Cast("float16", (T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - 
rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) * 
T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / T.float32(40) - 
rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40) * 
(rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) + 
T.float32(1.0000000000000001e-05))) * T_reshape_2[v_ax1, v_ax2] + 
T_reshape_3[v_ax1, v_ax2]
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), 
T.int64(4), T.int64(5)):
                 with T.block("T_reshape_3"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
@@ -3143,7 +3138,7 @@ def test_rms_norm():
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(T_multiply_red[v_ax0, v_ax1])
                     T.writes(rsqrt[v_ax0, v_ax1])
-                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] 
* T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] 
/ T.float32(20) + T.float32(1.0000000000000001e-05))
             for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
                 with T.block("T_cast_1"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -3219,7 +3214,7 @@ def test_rms_norm_fp16():
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(T_multiply_red[v_ax0, v_ax1])
                     T.writes(rsqrt[v_ax0, v_ax1])
-                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] 
* T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] 
/ T.float32(20) + T.float32(1.0000000000000001e-05))
             for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
                 with T.block("T_cast_1"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -3381,7 +3376,7 @@ def test_rms_norm_no_bias():
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(T_multiply_red[v_ax0, v_ax1])
                     T.writes(rsqrt[v_ax0, v_ax1])
-                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] 
* T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+                    rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] 
/ T.float32(20) + T.float32(1.0000000000000001e-05))
             for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
                 with T.block("T_cast_1"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -3424,7 +3419,7 @@ def test_attention():
     @tvm.script.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def attention_bias(A: T.Buffer((T.int64(4), T.int64(16), T.int64(32), 
T.int64(8)), "float32"), B: T.Buffer((T.int64(4), T.int64(8), T.int64(32), 
T.int64(8)), "float32"), C: T.Buffer((T.int64(4), T.int64(8), T.int64(32), 
T.int64(16)), "float32"), D: T.Buffer((T.int64(4), T.int64(32), T.int64(16), 
T.int64(8)), "float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), 
T.int64(32), T.int64(16)), "float32")):
+        def attention_bias(q: T.Buffer((T.int64(4), T.int64(16), T.int64(32), 
T.int64(8)), "float32"), k: T.Buffer((T.int64(4), T.int64(8), T.int64(32), 
T.int64(8)), "float32"), v: T.Buffer((T.int64(4), T.int64(8), T.int64(32), 
T.int64(16)), "float32"), bias: T.Buffer((T.int64(4), T.int64(32), T.int64(16), 
T.int64(8)), "float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), 
T.int64(32), T.int64(16)), "float32")):
             T.func_attr({"tir.noalias": True})
             # with T.block("root"):
             T_transpose_1 = T.alloc_buffer((T.int64(4), T.int64(32), 
T.int64(16), T.int64(8)))
@@ -3450,9 +3445,9 @@ def test_attention():
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), 
T.int64(16), T.int64(8)):
                 with T.block("T_transpose"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.reads(q[v_ax0, v_ax2, v_ax1, v_ax3])
                     T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, 
v_ax2, v_ax1, v_ax3]
+                    T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = q[v_ax0, 
v_ax2, v_ax1, v_ax3]
             for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                 with T.block("T_reshape"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
@@ -3462,23 +3457,23 @@ def test_attention():
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), 
T.int64(8), T.int64(8)):
                 with T.block("T_transpose_1"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(B[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.reads(k[v_ax0, v_ax2, v_ax1, v_ax3])
                     T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = B[v_ax0, 
v_ax2, v_ax1, v_ax3]
+                    T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = k[v_ax0, 
v_ax2, v_ax1, v_ax3]
             for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(8)):
                 with T.block("T_reshape_1"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                     T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // 
T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + 
v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % 
T.int64(8), v_ax2 % T.int64(8)])
                     T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2])
                     T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 
// T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), 
((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // 
T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
-            for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(8), 
T.int64(8)):
+            for b, i, j, k_1 in T.grid(T.int64(128), T.int64(16), T.int64(8), 
T.int64(8)):
                 with T.block("T_batch_matmul_NT"):
-                    v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
+                    v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k_1])
                     T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, 
v_k])
                     T.writes(T_batch_matmul_NT[v_b, v_i, v_j])
                     T.block_attr({"layout_free_placeholders": [T_reshape_1]})
                     with T.init():
-                        T_batch_matmul_NT[v_b, v_i, v_j] = T.float32(0)
+                        T_batch_matmul_NT[v_b, v_i, v_j] = T.float32(0.0)
                     T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, 
v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k]
             for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                 with T.block("T_multiply"):
@@ -3495,9 +3490,9 @@ def test_attention():
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), 
T.int64(16), T.int64(8)):
                 with T.block("T_add"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3], D[v_ax0, 
v_ax1, v_ax2, v_ax3])
+                    T.reads(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3], 
bias[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_2[v_ax0, 
v_ax1, v_ax2, v_ax3] + D[v_ax0, v_ax1, v_ax2, v_ax3]
+                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_2[v_ax0, 
v_ax1, v_ax2, v_ax3] + bias[v_ax0, v_ax1, v_ax2, v_ax3]
             for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                 with T.block("T_reshape_3"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
@@ -3509,14 +3504,14 @@ def test_attention():
                     v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                     T.reads(T_reshape_3[v_i0, v_i1, v_i2])
                     T.writes(trilu[v_i0, v_i1, v_i2])
-                    trilu[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, 
T_reshape_3[v_i0, v_i1, v_i2], T.float32(0))
+                    trilu[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, 
T_reshape_3[v_i0, v_i1, v_i2], T.float32(0.0))
             for ax0, ax1, ax2, k2 in T.grid(T.int64(128), T.int64(16), 
T.int64(1), T.int64(8)):
                 with T.block("trilu_red"):
                     v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, 
ax1, ax2, k2])
                     T.reads(trilu[v_ax0, v_ax1, v_k2])
                     T.writes(trilu_red[v_ax0, v_ax1, v_ax2])
                     with T.init():
-                        trilu_red[v_ax0, v_ax1, v_ax2] = 
T.float32(-3.4028234663852886e+38)
+                        trilu_red[v_ax0, v_ax1, v_ax2] = 
T.float32(-340282346638528859811704183484516925440.0)
                     trilu_red[v_ax0, v_ax1, v_ax2] = T.max(trilu_red[v_ax0, 
v_ax1, v_ax2], trilu[v_ax0, v_ax1, v_k2])
             for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                 with T.block("T_subtract"):
@@ -3535,14 +3530,14 @@ def test_attention():
                     v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                     T.reads(compute[v_i0, v_i1, v_i2])
                     T.writes(trilu_1[v_i0, v_i1, v_i2])
-                    trilu_1[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, 
compute[v_i0, v_i1, v_i2], T.float32(0))
+                    trilu_1[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, 
compute[v_i0, v_i1, v_i2], T.float32(0.0))
             for ax0, ax1, ax2, k2 in T.grid(T.int64(128), T.int64(16), 
T.int64(1), T.int64(8)):
                 with T.block("trilu_red_1"):
                     v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, 
ax1, ax2, k2])
                     T.reads(trilu_1[v_ax0, v_ax1, v_k2])
                     T.writes(trilu_red_1[v_ax0, v_ax1, v_ax2])
                     with T.init():
-                        trilu_red_1[v_ax0, v_ax1, v_ax2] = T.float32(0)
+                        trilu_red_1[v_ax0, v_ax1, v_ax2] = T.float32(0.0)
                     trilu_red_1[v_ax0, v_ax1, v_ax2] = trilu_red_1[v_ax0, 
v_ax1, v_ax2] + trilu_1[v_ax0, v_ax1, v_k2]
             for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
                 with T.block("T_divide"):
@@ -3553,23 +3548,23 @@ def test_attention():
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), 
T.int64(8), T.int64(16)):
                 with T.block("T_transpose_2"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(C[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.reads(v[v_ax0, v_ax2, v_ax1, v_ax3])
                     T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = C[v_ax0, 
v_ax2, v_ax1, v_ax3]
+                    T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = v[v_ax0, 
v_ax2, v_ax1, v_ax3]
             for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(16)):
                 with T.block("T_reshape_4"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                     T.reads(T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // 
T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + 
v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % 
T.int64(8), v_ax2 % T.int64(16)])
                     T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2])
                     T_reshape_4[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 
// T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), 
((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // 
T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)]
-            for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(16), 
T.int64(8)):
+            for b, i, j, k_1 in T.grid(T.int64(128), T.int64(16), T.int64(16), 
T.int64(8)):
                 with T.block("T_batch_matmul_NN"):
-                    v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
+                    v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k_1])
                     T.reads(T_divide[v_b, v_i, v_k], T_reshape_4[v_b, v_k, 
v_j])
                     T.writes(T_batch_matmul_NN[v_b, v_i, v_j])
                     T.block_attr({"layout_free_placeholders": [T_reshape_4]})
                     with T.init():
-                        T_batch_matmul_NN[v_b, v_i, v_j] = T.float32(0)
+                        T_batch_matmul_NN[v_b, v_i, v_j] = T.float32(0.0)
                     T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, 
v_i, v_j] + T_divide[v_b, v_i, v_k] * T_reshape_4[v_b, v_k, v_j]
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), 
T.int64(16), T.int64(16)):
                 with T.block("T_reshape_5"):
@@ -3589,7 +3584,6 @@ def test_attention():
             cls = Expected
             gv = R.call_tir(cls.attention_bias, (q, k, v, bias), 
out_sinfo=R.Tensor((4, 16, 32, 16), dtype="float32"))
             return gv
-
     # fmt: on
     mod = LegalizeOps()(Attention)
     tvm.ir.assert_structural_equal(mod, Expected)
diff --git a/tests/python/relax/test_transform_legalize_ops_qdq.py 
b/tests/python/relax/test_transform_legalize_ops_qdq.py
index 55f1acadb1..09706c637e 100644
--- a/tests/python/relax/test_transform_legalize_ops_qdq.py
+++ b/tests/python/relax/test_transform_legalize_ops_qdq.py
@@ -212,7 +212,7 @@ def test_quantize_fp32_to_int8_scalar_param():
                         "int8",
                         T.max(
                             T.min(
-                                T.round(A[v_i0, v_i1] * T.float32(0.5)) + 
T.float32(1),
+                                T.round(A[v_i0, v_i1] / T.float32(2)) + 
T.float32(1),
                                 T.float32(127),
                             ),
                             T.float32(-128),
@@ -311,7 +311,7 @@ def test_quantize_fp16_to_int8_scalar_param():
                         "int8",
                         T.max(
                             T.min(
-                                T.round(A[v_i0, v_i1] * T.float16(0.5)) + 
T.float16(1),
+                                T.round(A[v_i0, v_i1] / T.float16(2)) + 
T.float16(1),
                                 T.float16(127),
                             ),
                             T.float16(-128),
diff --git 
a/tests/python/relax/test_transform_legalize_ops_search_statistical.py 
b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index f8dab89815..7edfff3dfc 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -627,7 +627,7 @@ def test_mean():
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
                     T.reads(rxplaceholder_red[ax0, ax1])
                     T.writes(T_divide[ax0, ax1])
-                    T_divide[ax0, ax1] = rxplaceholder_red[ax0, ax1] * 
T.float32(0.1)
+                    T_divide[ax0, ax1] = rxplaceholder_red[ax0, ax1] / 
T.float32(10)
     # fmt: on
 
     mod = LegalizeOps()(Mean)
@@ -718,7 +718,7 @@ def test_std():
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
                     T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.0083333333333333332)
+                    T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] / T.float32(120.0)
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
                 with T.block("T_subtract"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
@@ -743,7 +743,7 @@ def test_std():
                 vi = T.axis.spatial(1, T.int64(0))
                 T.reads(T_multiply_red[()])
                 T.writes(T_divide_1[()])
-                T_divide_1[()] = T_multiply_red[()] * 
T.float32(0.0083333333333333332)
+                T_divide_1[()] = T_multiply_red[()] / T.float32(120.0)
             with T.block("compute"):
                 vi = T.axis.spatial(1, T.int64(0))
                 T.reads(T_divide_1[()])
@@ -881,7 +881,7 @@ def test_variance():
                     ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                     T.reads(rxplaceholder_red[ax0, ax1, ax2, ax3])
                     T.writes(T_divide_1[ax0, ax1, ax2, ax3])
-                    T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, 
ax1, ax2, ax3] * T.float32(0.10000000000000001)
+                    T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, 
ax1, ax2, ax3] / T.float32(10.0)
             for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
                 with T.block("T_subtract"):
                     ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
@@ -907,7 +907,7 @@ def test_variance():
                     ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                     T.reads(T_multiply_red[ax0, ax1, ax2, ax3])
                     T.writes(T_divide[ax0, ax1, ax2, ax3])
-                    T_divide[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, 
ax2, ax3] * T.float32(0.10000000000000001)
+                    T_divide[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, 
ax2, ax3] / T.float32(10)
     # fmt: on
 
     mod = LegalizeOps()(Variance)
@@ -1027,7 +1027,7 @@ def test_variance_no_keepdims():
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
                     T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.10000000000000001)
+                    T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] / T.float32(10)
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
                 with T.block("T_subtract"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
@@ -1053,7 +1053,7 @@ def test_variance_no_keepdims():
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(T_multiply_red[v_ax0, v_ax1])
                     T.writes(T_divide[v_ax0, v_ax1])
-                    T_divide[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] * 
T.float32(0.10000000000000001)
+                    T_divide[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] / 
T.float32(10)
 
         @R.function
         def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3, 
4), dtype="float32"):

Reply via email to