This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch feature/2022-11-09/printer-explicit-ir-node
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 9cd4ad0077b15bdc03dfcc2db19f99ab224c9d79
Author: Junru Shao <[email protected]>
AuthorDate: Wed Nov 9 19:29:17 2022 -0800

    [TIR] Make syntax of AST nodes different than ops
    
    As part of the effort toward more formal TIR semantics, we want to
    more explicitly differentiate TIR AST nodes (defined in `tir/expr.h`)
    from TIR ops (defined in `tir/op.h`).
    
    The naming convention is:
    - A lowercase method, for example `tvm.tir.mul`, denotes a TIR op, which
      is eagerly constant-folded, i.e. `mul(1, 2)` returns `3`
      immediately rather than creating an AST node.
    - A capitalized callable, for example `Mul`, creates an AST node
      without constant folding.
    
    This PR makes this behavior more explicit by printing `T.Mul(a, b)`
    directly when `a` and `b` are both constants, rather than sugaring it
    into `mul(a, b)` or `a * b`, so that the difference between an op and
    an AST node is clarified.
    
    Co-authored-by: Yaxing Cai <[email protected]>
---
 python/tvm/script/tir/intrin.py                    | 80 +++++++++++++++++-
 src/printer/tvmscript_printer.cc                   | 97 +++++++++++++---------
 .../test_hexagon/test_async_dma_pipeline.py        | 23 +++--
 .../test_hexagon/test_parallel_hvx_load_vtcm.py    | 49 ++++-------
 .../unittest/test_aot_legalize_packed_call.py      | 12 +--
 .../unittest/test_meta_schedule_space_cuda.py      |  2 +-
 .../test_tir_transform_inject_software_pipeline.py | 16 ++--
 .../test_tir_transform_inject_virtual_thread.py    | 17 ++--
 .../unittest/test_tir_transform_thread_sync.py     |  2 +-
 9 files changed, 185 insertions(+), 113 deletions(-)

diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index bd9aa1fdad..8e24f27325 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -17,12 +17,13 @@
 """TVM Script Parser Intrinsic Classes"""
 # pylint: disable=redefined-builtin, relative-beyond-top-level
 import builtins
-from typing import List, Any
+from typing import Any, List
 
 import tvm.tir
 from tvm.tir import FloatImm
-from ..registry import register
+
 from ...target import codegen
+from ..registry import register
 from ..utils import get_param_list, tvm_span_from_synr
 
 
@@ -229,3 +230,78 @@ def comm_reducer(lambda_io, identities, span):
 def llvm_lookup_intrinsic_id(name, span):
     # pylint: disable=unused-argument
     return codegen.llvm_lookup_intrinsic_id(name)
+
+
+@register
+def FloorMod(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.FloorMod(x, y, span)
+
+
+@register
+def FloorDiv(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.FloorDiv(x, y, span)
+
+
+@register
+def Mul(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Mul(x, y, span)
+
+
+@register
+def Div(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Div(x, y, span)
+
+
+@register
+def Add(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Add(x, y, span)
+
+
+@register
+def Sub(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Sub(x, y, span)
+
+
+@register
+def LT(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.LT(x, y, span)
+
+
+@register
+def LE(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.LE(x, y, span)
+
+
+@register
+def GT(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.GT(x, y, span)
+
+
+@register
+def GE(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.GE(x, y, span)
+
+
+@register
+def EQ(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.EQ(x, y, span)
+
+
+@register
+def NE(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.NE(x, y, span)
+
+
+@register
+def And(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.And(x, y, span)
+
+
+@register
+def Or(x, y, span):  # pylint: disable=invalid-name
+    return tvm.tir.Or(x, y, span)
+
+
+@register
+def Cast(dtype, value, span):  # pylint: disable=invalid-name
+    return tvm.tir.Cast(dtype, value, span)
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 64a576ef52..d7a3a406e3 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -788,7 +788,7 @@ Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, 
ExprPrecedence* out_pr
 Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* 
out_precedence) {
   *out_precedence = ExprPrecedence::kIdentity;
   Doc doc;
-  doc << tir_prefix_ << ".cast(" << Print(op->value) << ", " << 
PrintDType(op->dtype) << ")";
+  doc << tir_prefix_ << ".Cast(" << PrintDType(op->dtype) << ", " << 
Print(op->value) << ")";
   return doc;
 }
 
@@ -798,46 +798,61 @@ Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, 
ExprPrecedence* out_preceden
   return meta_.InMeta(var) ? meta_.GetMetaNode(var) : 
AllocVar(GetRef<Var>(op));
 }
 
-#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpPrecedence)    
        \
-  Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* 
out_precedence) { \
-    Doc doc;                                                                   
        \
-    ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown;                  
        \
-    ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown;                  
        \
-    /* Get children expr out_precedence */                                     
        \
-    Doc lhs_doc = VisitExpr(op->a, &lhs_precedence);                           
        \
-    Doc rhs_doc = VisitExpr(op->b, &rhs_precedence);                           
        \
-    ICHECK(lhs_precedence != ExprPrecedence::kUnknown);                        
        \
-    ICHECK(rhs_precedence != ExprPrecedence::kUnknown);                        
        \
-    /* Update out_precedence of current node. */                               
        \
-    *out_precedence = OpPrecedence;                                            
        \
-    if (lhs_precedence > OpPrecedence) {                                       
        \
-      doc << "(" << lhs_doc << ")";                                            
        \
-    } else {                                                                   
        \
-      doc << lhs_doc;                                                          
        \
-    }                                                                          
        \
-    doc << OpString;                                                           
        \
-    if (rhs_precedence >= OpPrecedence) {                                      
        \
-      doc << "(" << rhs_doc << ")";                                            
        \
-    } else {                                                                   
        \
-      doc << rhs_doc;                                                          
        \
-    }                                                                          
        \
-    return doc;                                                                
        \
-  }
-
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", 
ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", 
ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", 
ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", 
ExprPrecedence::kMultiplicationDivision)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", 
ExprPrecedence::kAdditionSubtraction)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", 
ExprPrecedence::kAdditionSubtraction)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", 
ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", 
ExprPrecedence::kRelational)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", ExprPrecedence::kEquality)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", ExprPrecedence::kEquality)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", ExprPrecedence::kAnd)
-TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", ExprPrecedence::kOr)
+bool WillPrintConstScalar(const PrimExpr& expr) {
+  if (const auto* imm = expr.as<IntImmNode>()) {
+    DataType dtype = imm->dtype;
+    return dtype == DataType::Int(32) || dtype == DataType::Bool();
+  }
+  return false;
+}
+
+#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpClass, 
OpPrecedence)              \
+  Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* 
out_precedence) {            \
+    Doc doc;                                                                   
                   \
+    if (WillPrintConstScalar(op->a) && WillPrintConstScalar(op->b)) {          
                   \
+      *out_precedence = ExprPrecedence::kIdentity;                             
                   \
+      doc << tir_prefix_ << "." << OpClass << "(" << Print(op->a) << ", " << 
Print(op->b) << ")"; \
+      return doc;                                                              
                   \
+    }                                                                          
                   \
+    ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown;                  
                   \
+    ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown;                  
                   \
+    /* Get children expr out_precedence */                                     
                   \
+    Doc lhs_doc = VisitExpr(op->a, &lhs_precedence);                           
                   \
+    Doc rhs_doc = VisitExpr(op->b, &rhs_precedence);                           
                   \
+    ICHECK(lhs_precedence != ExprPrecedence::kUnknown);                        
                   \
+    ICHECK(rhs_precedence != ExprPrecedence::kUnknown);                        
                   \
+    /* Update out_precedence of current node. */                               
                   \
+    *out_precedence = OpPrecedence;                                            
                   \
+    if (lhs_precedence > OpPrecedence) {                                       
                   \
+      doc << "(" << lhs_doc << ")";                                            
                   \
+    } else {                                                                   
                   \
+      doc << lhs_doc;                                                          
                   \
+    }                                                                          
                   \
+    doc << OpString;                                                           
                   \
+    if (rhs_precedence >= OpPrecedence) {                                      
                   \
+      doc << "(" << rhs_doc << ")";                                            
                   \
+    } else {                                                                   
                   \
+      doc << rhs_doc;                                                          
                   \
+    }                                                                          
                   \
+    return doc;                                                                
                   \
+  }
+
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", "Mul", 
ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", "Div", 
ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", "FloorDiv",
+                                    ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", "FloorMod",
+                                    ExprPrecedence::kMultiplicationDivision)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", "Add", 
ExprPrecedence::kAdditionSubtraction)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", "Sub", 
ExprPrecedence::kAdditionSubtraction)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", "LT", 
ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", "LE", 
ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", "GT", 
ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", "GE", 
ExprPrecedence::kRelational)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", "EQ", 
ExprPrecedence::kEquality)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", "NE", 
ExprPrecedence::kEquality)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", "And", 
ExprPrecedence::kAnd)
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", "Or", ExprPrecedence::kOr)
 
 Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* 
out_precedence) {
   *out_precedence = ExprPrecedence::kIdentity;
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py 
b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index a7a05c2aa3..19b380c1bd 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -18,11 +18,10 @@
 """ Test different strategies for loading data into vtcm before running HVX 
workloads. """
 
 import numpy as np
-import tvm
 import pytest
-
-from tvm.script import tir as T
+import tvm
 from numpy.random import default_rng
+from tvm.script import tir as T
 
 VRMPY_SIZE_B = 128
 VRMPY_SIZE_INT32 = 32
@@ -126,12 +125,12 @@ def get_single_dma_schedule(size_a, size_w):
     @T.prim_func
     def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> 
None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", 
mem_scope="global")
-        w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", 
mem_scope="global")
-        c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", 
mem_scope="global")
-        a_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", 
mem_scope="global.vtcm")
-        w_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", 
mem_scope="global.vtcm")
-        c_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", 
mem_scope="global.vtcm")
+        a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", 
scope="global")
+        w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", 
scope="global")
+        c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", 
scope="global")
+        a_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", 
scope="global.vtcm")
+        w_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", 
scope="global.vtcm")
+        c_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", 
scope="global.vtcm")
         T.evaluate(
             T.tvm_call_packed(
                 "device_api.hexagon.mem_copy_DLTensor",
@@ -153,7 +152,7 @@ def get_single_dma_schedule(size_a, size_w):
                     0,
                     dtype="handle",
                 ),
-                T.cast(a_bytes, dtype="int"),
+                T.Cast("int", a_bytes),
                 dtype="int32",
             )
         )
@@ -178,7 +177,7 @@ def get_single_dma_schedule(size_a, size_w):
                     0,
                     dtype="handle",
                 ),
-                T.cast(w_bytes, dtype="int"),
+                T.Cast("int", w_bytes),
                 dtype="int32",
             )
         )
@@ -222,7 +221,7 @@ def get_single_dma_schedule(size_a, size_w):
                     0,
                     dtype="handle",
                 ),
-                T.cast(a_bytes, dtype="int"),
+                T.Cast("int", a_bytes),
                 dtype="int32",
             )
         )
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py 
b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
index fb398f4397..e6fc0a3c20 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
@@ -18,9 +18,8 @@
 """ Test different strategies for loading data into vtcm before running HVX 
workloads. """
 
 import numpy as np
-from numpy.random import default_rng
-
 import tvm
+from numpy.random import default_rng
 from tvm.script import tir as T
 
 from .infrastructure import get_hexagon_target
@@ -109,17 +108,17 @@ def preloaded_vrmpy(operations):
             [T.cast(operations, "int32") * 128],
             dtype="uint8",
             align=128,
-            mem_scope="global.vtcm",
+            scope="global.vtcm",
         )
         b_buffer = T.match_buffer(
             b,
             [T.cast(operations, "int32") * 128],
             dtype="uint8",
             align=128,
-            mem_scope="global.vtcm",
+            scope="global.vtcm",
         )
         c_buffer = T.match_buffer(
-            c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, 
mem_scope="global.vtcm"
+            c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, 
scope="global.vtcm"
         )
         for n in T.grid(operations):
             with T.block("c_buffer"):
@@ -149,21 +148,13 @@ def preallocated_vrmpy(operations):
         a: T.handle, b: T.handle, c: T.handle, a_v: T.handle, b_v: T.handle, 
c_v: T.handle
     ) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        a_buffer = T.match_buffer(
-            a, [operations, 128], dtype="uint8", align=128, mem_scope="global"
-        )
-        b_buffer = T.match_buffer(
-            b, [operations, 128], dtype="uint8", align=128, mem_scope="global"
-        )
-        c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", 
align=128, mem_scope="global")
-        a_global_vtcm = T.match_buffer(
-            a_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
-        )
-        b_global_vtcm = T.match_buffer(
-            b_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
-        )
+        a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", 
align=128, scope="global")
+        b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", 
align=128, scope="global")
+        c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", 
align=128, scope="global")
+        a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, 
scope="global.vtcm")
+        b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, 
scope="global.vtcm")
         c_global_vtcm = T.match_buffer(
-            c_v, [out_size], dtype="int32", align=128, mem_scope="global.vtcm"
+            c_v, [out_size], dtype="int32", align=128, scope="global.vtcm"
         )
         for n, i in T.grid(operations, 128):
             with T.block("a_buffer_global.vtcm"):
@@ -212,21 +203,13 @@ def preallocated_single_dma_vrmpy(operations):
         c_v: T.handle,
     ) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        a_buffer = T.match_buffer(
-            a, [operations, 128], dtype="uint8", align=128, mem_scope="global"
-        )
-        b_buffer = T.match_buffer(
-            b, [operations, 128], dtype="uint8", align=128, mem_scope="global"
-        )
-        c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", 
align=128, mem_scope="global")
-        a_global_vtcm = T.match_buffer(
-            a_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
-        )
-        b_global_vtcm = T.match_buffer(
-            b_v, [size], dtype="uint8", align=128, mem_scope="global.vtcm"
-        )
+        a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", 
align=128, scope="global")
+        b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", 
align=128, scope="global")
+        c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", 
align=128, scope="global")
+        a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, 
scope="global.vtcm")
+        b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, 
scope="global.vtcm")
         c_global_vtcm = T.match_buffer(
-            c_v, [out_size], dtype="int32", align=128, mem_scope="global.vtcm"
+            c_v, [out_size], dtype="int32", align=128, scope="global.vtcm"
         )
         T.evaluate(
             T.tvm_call_packed(
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py 
b/tests/python/unittest/test_aot_legalize_packed_call.py
index 9c597a55e5..cd0114d464 100644
--- a/tests/python/unittest/test_aot_legalize_packed_call.py
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -15,11 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
 import tvm
-from tvm.script import tir as T
-from tvm import tir
 import tvm.testing
-import pytest
+from tvm import tir
+from tvm.script import tir as T
 
 
 @tvm.script.ir_module
@@ -85,7 +85,7 @@ class Expected:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
@@ -94,7 +94,7 @@ class Expected:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
@@ -103,7 +103,7 @@ class Expected:
                     T.tvm_stack_make_shape(1, dtype="handle"),
                     T.reinterpret(T.uint64(0), dtype="handle"),
                     T.uint32(1),
-                    T.cast(0, dtype="float32"),
+                    T.Cast("float32", 0),
                     0,
                     dtype="handle",
                 ),
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py 
b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 324d8a9ec4..0a518c840d 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -856,7 +856,7 @@ def test_cuda_nrm():
                 for i0_1 in T.thread_binding(128, thread="threadIdx.x"):
                     with T.block("D"):
                         b = T.axis.spatial(1, i0_1)
-                        T.where(0 * 128 + i0_1 < 1)
+                        T.where(T.Mul(0, 128) + i0_1 < 1)
                         T.reads(C_shared[b])
                         T.writes(D[b])
                         D[b] = T.sqrt(C_shared[b], dtype="float32")
diff --git 
a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py 
b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 2a4cabc541..c70525b057 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -14,16 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import pytest
 import sys
-import numpy as np
 
+import numpy as np
+import pytest
 import tvm
 import tvm.testing
 import tvm.tir.tensor_intrin.cuda
-from tvm import tir, te, TVMError
-from tvm.script import tir as T
+from tvm import TVMError, te, tir
 from tvm.meta_schedule.testing import te_workload
+from tvm.script import tir as T
 from tvm.testing.tir import mma_schedule
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x16_A_DYN_INTRIN,
@@ -1116,7 +1116,7 @@ def test_simple_compute_async():
     mod = tvm.tir.transform.InjectSoftwarePipeline()(sch.mod)
 
     @T.prim_func
-    def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), 
"float32"]) -> None:
+    def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), 
"float32"]):
         for tx in T.thread_binding(16, thread="threadIdx.x"):
             with T.block():
                 T.reads(A[tx, 0:16])
@@ -1127,7 +1127,7 @@ def test_simple_compute_async():
                     T.writes(B[0, tx, 0])
                     with T.attr(0, "async_commit_queue_scope", 0):
                         with T.attr(0, "async_scope", 1):
-                            B[0 % 2, tx, 0] = A[tx, 0] * T.float32(2)
+                            B[T.FloorMod(0, 2), tx, 0] = A[tx, 0] * 
T.float32(2)
                 with T.block():
                     T.reads(A[tx, 1:16], B[0:2, tx, 0])
                     T.writes(B[0:2, tx, 0], C[tx, 0:15])
@@ -1147,11 +1147,11 @@ def test_simple_compute_async():
                                 with T.attr(0, "async_wait_inflight_count", 1):
                                     C[tx, i - 1 + 1] = B[(i - 1 + 1) % 2, tx, 
0] + T.float32(1)
                 with T.block():
-                    T.reads(B[15 % 2, tx, 0])
+                    T.reads(B[T.FloorMod(15, 2), tx, 0])
                     T.writes(C[tx, 15])
                     with T.attr(0, "async_wait_queue_scope", 0):
                         with T.attr(0, "async_wait_inflight_count", 0):
-                            C[tx, 15] = B[15 % 2, tx, 0] + T.float32(1)
+                            C[tx, 15] = B[T.FloorMod(15, 2), tx, 0] + 
T.float32(1)
 
     tvm.ir.assert_structural_equal(mod["main"], ref, True)
 
diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py 
b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
index 548f3bc8d1..b4ea4e712d 100644
--- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
+++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
@@ -16,7 +16,6 @@
 # under the License.
 import tvm
 from tvm import te
-
 from tvm.script import tir as T
 
 vthread_name = tvm.testing.parameter("vthread", "cthread")
@@ -155,10 +154,10 @@ def test_vthread_simplified():
         B = T.buffer_decl([16], "int32", data=B_data, scope="shared")
         # The indices for B should each be a single Ramp node, and
         # should not be the sum of a Ramp and Broadcast node.
-        B[0 * 4 : 0 * 4 + 4] = T.broadcast(0, 4)
-        B[1 * 4 : 1 * 4 + 4] = T.broadcast(1, 4)
-        B[2 * 4 : 2 * 4 + 4] = T.broadcast(2, 4)
-        B[3 * 4 : 3 * 4 + 4] = T.broadcast(3, 4)
+        B[T.Mul(0, 4) : T.Mul(0, 4) + 4] = T.broadcast(0, 4)
+        B[T.Mul(1, 4) : T.Mul(1, 4) + 4] = T.broadcast(1, 4)
+        B[T.Mul(2, 4) : T.Mul(2, 4) + 4] = T.broadcast(2, 4)
+        B[T.Mul(3, 4) : T.Mul(3, 4) + 4] = T.broadcast(3, 4)
 
     before_mod = tvm.IRModule.from_expr(before_func)
     after_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
@@ -182,10 +181,10 @@ def test_vthread_vectorized():
     def expected_func():
         B_data = T.allocate([4], "int32x4", "shared")
         B = T.buffer_decl([4], "int32x4", data=B_data, scope="shared")
-        B[0 * 4 / 4] = T.broadcast(0, 4)
-        B[1 * 4 / 4] = T.broadcast(1, 4)
-        B[2 * 4 / 4] = T.broadcast(2, 4)
-        B[3 * 4 / 4] = T.broadcast(3, 4)
+        B[T.Mul(0, 4) / 4] = T.broadcast(0, 4)
+        B[T.Mul(1, 4) / 4] = T.broadcast(1, 4)
+        B[T.Mul(2, 4) / 4] = T.broadcast(2, 4)
+        B[T.Mul(3, 4) / 4] = T.broadcast(3, 4)
 
     before_mod = tvm.IRModule.from_expr(before_func)
     intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py 
b/tests/python/unittest/test_tir_transform_thread_sync.py
index 18607ca1a0..c80cd55ea2 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -102,9 +102,9 @@ def test_sync_read_thread_id_independent_location():
         threadIdx_x = T.env_thread("threadIdx.x")
         blockIdx_x = T.env_thread("blockIdx.x")
         T.preflattened_buffer(p0, [1, 2, 1, 1], dtype="float32", data=p0.data)
-        T.launch_thread(blockIdx_x, 8)
         result_local = T.alloc_buffer([1], dtype="float32", scope="local")
         temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared")
+        T.launch_thread(blockIdx_x, 8)
         T.launch_thread(threadIdx_x, 4)
         result_local[0] = T.float32(0)
         if threadIdx_x < 1:

Reply via email to