This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 14bc5e4585 [FIX][TOPI] Clip with IntImm/FloatImm (#14027)
14bc5e4585 is described below

commit 14bc5e45855f5a80b7c57a53d98bb6016c9bbf53
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Feb 17 21:14:18 2023 -0500

    [FIX][TOPI] Clip with IntImm/FloatImm (#14027)
    
    Prior to this PR, the TOPI clip op only accepted min/max values of
    Python native float/int type, and rejected FloatImm and IntImm.
    
    This PR enhances the clip op so that it accepts FloatImm and IntImm.
    
    Co-authored-by: Siyuan Feng <[email protected]>
---
 python/tvm/topi/math.py                    | 18 ++++++++++++++----
 tests/python/topi/python/test_topi_clip.py | 16 +++++++++++++---
 2 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py
index dd191c49be..4d305d1e2f 100644
--- a/python/tvm/topi/math.py
+++ b/python/tvm/topi/math.py
@@ -18,6 +18,8 @@
 # pylint: disable=redefined-builtin,unused-argument
 import tvm
 from tvm import te
+from tvm.tir import PrimExpr
+
 from . import tag
 from . import cpp
 from .utils import get_const_tuple
@@ -620,9 +622,9 @@ def clip(x, a_min, a_max):
     ----------
     x : tvm.te.Tensor
         Input argument.
-    a_min : int or float
+    a_min : tvm.tir.PrimExpr
         Minimum value.
-    a_max : int or float
+    a_max : tvm.tir.PrimExpr
         Maximum value.
 
     Returns
@@ -633,8 +635,16 @@ def clip(x, a_min, a_max):
 
     def _compute(*indices):
         value = x(*indices)
-        const_min = tvm.tir.const(a_min, value.dtype)
-        const_max = tvm.tir.const(a_max, value.dtype)
+        const_min = (
+            tvm.tir.Cast(value.dtype, a_min)
+            if isinstance(a_min, PrimExpr)
+            else tvm.tir.const(a_min, value.dtype)
+        )
+        const_max = (
+            tvm.tir.Cast(value.dtype, a_max)
+            if isinstance(a_max, PrimExpr)
+            else tvm.tir.const(a_max, value.dtype)
+        )
         return tvm.te.max(tvm.te.min(value, const_max), const_min)
 
     return te.compute(x.shape, _compute)
diff --git a/tests/python/topi/python/test_topi_clip.py b/tests/python/topi/python/test_topi_clip.py
index 21546e8b57..68bb45580f 100644
--- a/tests/python/topi/python/test_topi_clip.py
+++ b/tests/python/topi/python/test_topi_clip.py
@@ -17,7 +17,7 @@
 """Test code for clip operator"""
 import numpy as np
 import tvm
-from tvm import te
+from tvm import te, tir
 from tvm import topi
 import tvm.testing
 import tvm.topi.testing
@@ -32,12 +32,14 @@ def verify_clip(N, a_min, a_max, dtype):
 
     # use memoize to pickle the test data for next time use
     @memoize("topi.tests.test_topi_clip")
-    def get_ref_data():
+    def get_ref_data(a_min, a_max):
         a_np = np.random.uniform(a_min * 2, a_max * 2, size=(N, N)).astype(dtype)
         b_np = np.clip(a_np, a_min, a_max)
         return a_np, b_np
 
-    a_np, b_np = get_ref_data()
+    a_min = a_min.value if isinstance(a_min, (tir.FloatImm, tir.IntImm)) else a_min
+    a_max = a_max.value if isinstance(a_max, (tir.FloatImm, tir.IntImm)) else a_max
+    a_np, b_np = get_ref_data(a_min, a_max)
 
     def check_target(target, dev):
         print("Running on target: %s" % target)
@@ -61,5 +63,13 @@ def test_clip():
     verify_clip(1024, -127, 127, "int8")
 
 
+@tvm.testing.uses_gpu
+def test_clip_floaimm_intimm():
+    verify_clip(1024, tir.FloatImm("float32", -127), tir.FloatImm("float32", 127), "float32")
+    verify_clip(1024, tir.IntImm("int32", -127), tir.IntImm("int32", 127), "int16")
+    verify_clip(1024, tir.IntImm("int32", -127), tir.IntImm("int32", 127), "int8")
+
+
 if __name__ == "__main__":
     test_clip()
+    test_clip_floaimm_intimm()

Reply via email to