This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a89b9f2880 [TOPI] Reject non-float inputs for inverse unary math ops (#18880)
a89b9f2880 is described below

commit a89b9f288014917339ba04de412a748c9cd82b58
Author: YinHanke <[email protected]>
AuthorDate: Mon Mar 9 06:39:21 2026 +0800

    [TOPI] Reject non-float inputs for inverse unary math ops (#18880)
    
    ## Summary
    
    Reject non-float inputs for inverse trigonometric and hyperbolic unary
    ops in TOPI.
    
    ## Changes
    
    - add a shared floating-point dtype check for inverse unary math ops in
    TOPI
    - apply the check to `topi.acos`, `topi.acosh`, `topi.asin`,
    `topi.asinh`, and `topi.atanh`
    - add TE tests covering integer-input rejection for these ops
    - add regression tests covering successful LLVM build for both `float32`
    and `bfloat16`
    
    ## Validation
    
    - `tests/python/te/test_te_create_primfunc.py -k 'topi_float_unary'`
    - local repro now fails early with a clear `TypeError` for integer
    inputs
    - local regression check confirms the valid `float32` and `bfloat16`
    paths still compile with LLVM
    
    ## Issue
    
    Fixes #18729
---
 python/tvm/topi/math.py                    | 13 ++++++++++++-
 tests/python/te/test_te_create_primfunc.py | 26 ++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py
index 3664f61f0d..d086141371 100644
--- a/python/tvm/topi/math.py
+++ b/python/tvm/topi/math.py
@@ -18,13 +18,19 @@
 
 # pylint: disable=redefined-builtin,unused-argument
 import tvm
-from tvm import te
+from tvm import DataType, DataTypeCode, te
 from tvm.tir import PrimExpr
 
 from . import cpp, tag
 from .utils import get_const_tuple
 
 
+def _require_float_tensor(op_name, x):
+    if DataType(x.dtype).type_code not in (DataTypeCode.FLOAT, DataTypeCode.BFLOAT):
+        raise TypeError(f"topi.{op_name} only supports floating-point inputs, but got {x.dtype}")
+    return x
+
+
 @tvm.te.tag_scope(tag=tag.ELEMWISE)
 def identity(x):
     """Take identity of input x.
@@ -211,6 +217,7 @@ def acos(x):
     y : tvm.te.Tensor
         The result.
     """
+    x = _require_float_tensor("acos", x)
     return te.compute(x.shape, lambda *i: te.acos(x(*i)))
 
 
@@ -228,6 +235,7 @@ def acosh(x):
     y : tvm.te.Tensor
         The result.
     """
+    x = _require_float_tensor("acosh", x)
     return te.compute(x.shape, lambda *i: te.acosh(x(*i)))
 
 
@@ -245,6 +253,7 @@ def asin(x):
     y : tvm.te.Tensor
         The result.
     """
+    x = _require_float_tensor("asin", x)
     return te.compute(x.shape, lambda *i: te.asin(x(*i)))
 
 
@@ -262,6 +271,7 @@ def asinh(x):
     y : tvm.te.Tensor
         The result.
     """
+    x = _require_float_tensor("asinh", x)
     return te.compute(x.shape, lambda *i: te.asinh(x(*i)))
 
 
@@ -296,6 +306,7 @@ def atanh(x):
     y : tvm.te.Tensor
         The result.
     """
+    x = _require_float_tensor("atanh", x)
     return te.compute(x.shape, lambda *i: te.atanh(x(*i)))
 
 
diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py
index 7b069fc8cd..3a9bcd3957 100644
--- a/tests/python/te/test_te_create_primfunc.py
+++ b/tests/python/te/test_te_create_primfunc.py
@@ -359,6 +359,32 @@ def test_constant():
     tvm.testing.assert_allclose(a_np + 2, c.numpy())
 
 
+@pytest.mark.parametrize("op_name", ["acos", "acosh", "asin", "asinh", "atanh"])
+def test_topi_float_unary_rejects_integer_input(op_name):
+    x = te.placeholder((1, 8), dtype="int16", name="x")
+    op = getattr(topi, op_name)
+
+    with pytest.raises(
+        TypeError,
+        match=rf"topi\.{op_name} only supports floating-point inputs, but got int16",
+    ):
+        op(x)
+
+
+@pytest.mark.parametrize("op_name", ["acos", "acosh", "asin", "asinh", "atanh"])
+@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
+def test_topi_float_unary_accepts_float_input(op_name, dtype):
+    x = te.placeholder((1, 8), dtype=dtype, name="x")
+    op = getattr(topi, op_name)
+    out = op(x)
+
+    func = te.create_prim_func([x, out]).with_attr("target", tvm.target.Target("llvm"))
+    mod = tvm.IRModule({"main": func})
+    compiled = tvm.build(mod, target="llvm")
+
+    assert compiled is not None
+
+
 def test_data_dependent_access():
     A = te.placeholder((10,), name="A")
     B = te.placeholder((10,), name="B", dtype="int32")

Reply via email to