This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9a3b3dd1ce [TVMScript] Fix parse minimal i32 literal for tir script 
(#12772)
9a3b3dd1ce is described below

commit 9a3b3dd1ceac8f9b065636146756baead39b8ab6
Author: wrongtest <[email protected]>
AuthorDate: Fri Sep 16 07:40:55 2022 +0800

    [TVMScript] Fix parse minimal i32 literal for tir script (#12772)
    
    This change fixes an issue introduced by #12515.
    
    Previously, `-2147483648` was parsed as `parse(-literal) = -parse(literal)`,
    and every integer literal was converted to i32, regardless of whether the
    literal value actually overflowed i32.
    
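    Since #12515, parsing `2147483648` yields an i64 integer rather than an i32,
    so `-2147483648` also became an i64 integer. This is wrong, because
    -2147483648 is the minimum i32 value and fits in i32. The parser now folds
    the unary minus into the literal before inferring its type.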
---
 python/tvm/script/parser.py                       |  7 +++++++
 tests/python/unittest/test_tvmscript_roundtrip.py | 10 ++++++++++
 2 files changed, 17 insertions(+)

diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index e9b4286eda..c34aae2345 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -906,6 +906,13 @@ class TVMScriptParser(Transformer):
                 )
             if node.func_name.name in self._unaryop_maker:
                 rhs = self.transform(node.params[0])
+                if node.func_name.name == ast.BuiltinOp.USub and isinstance(
+                    node.params[0], ast.Constant
+                ):
+                    # '-literal' should be parsed together for proper literal 
type inference
+                    if not isinstance(rhs, (tvm.tir.IntImm, tvm.tir.FloatImm)):
+                        self.report_error("The literal is illegal after -", 
node.params[0].span)
+                    return tvm.tir.const(-rhs.value)
                 return self._unaryop_maker[node.func_name.name](
                     rhs, span=tvm_span_from_synr(node.span)
                 )
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py 
b/tests/python/unittest/test_tvmscript_roundtrip.py
index 1762278955..1f5871b488 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3381,6 +3381,15 @@ def float_infinity():
     return func
 
 
+def minimal_i32_literal():
+    @T.prim_func
+    def func() -> None:
+        T.evaluate(T.int32(-2147483648))
+        T.evaluate(-T.int64(2147483648))
+
+    return func
+
+
 ir_generator = tvm.testing.parameter(
     opt_gemm_normalize,
     opt_gemm_lower,
@@ -3423,6 +3432,7 @@ ir_generator = tvm.testing.parameter(
     decl_buffer,
     allocate_and_decl_buffer,
     float_infinity,
+    minimal_i32_literal,
 )
 
 

Reply via email to