This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a89b9f2880 [TOPI] Reject non-float inputs for inverse unary math ops (#18880)
a89b9f2880 is described below
commit a89b9f288014917339ba04de412a748c9cd82b58
Author: YinHanke <[email protected]>
AuthorDate: Mon Mar 9 06:39:21 2026 +0800
[TOPI] Reject non-float inputs for inverse unary math ops (#18880)
## Summary
Reject non-float inputs for inverse trigonometric and hyperbolic unary
ops in TOPI.
## Changes
- add a shared floating-point dtype check for inverse unary math ops in
TOPI
- apply the check to `topi.acos`, `topi.acosh`, `topi.asin`,
`topi.asinh`, and `topi.atanh`
- add TE tests covering integer-input rejection for these ops
- add regression tests checking that LLVM builds succeed for both `float32`
and `bfloat16` inputs
## Validation
- `tests/python/te/test_te_create_primfunc.py -k 'topi_float_unary'`
- local repro now fails early with a clear `TypeError` for integer
inputs
- local regression check confirms the valid `float32` and `bfloat16`
paths still compile with LLVM
## Issue
Fixes #18729
---
python/tvm/topi/math.py | 13 ++++++++++++-
tests/python/te/test_te_create_primfunc.py | 26 ++++++++++++++++++++++++++
2 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py
index 3664f61f0d..d086141371 100644
--- a/python/tvm/topi/math.py
+++ b/python/tvm/topi/math.py
@@ -18,13 +18,19 @@
# pylint: disable=redefined-builtin,unused-argument
import tvm
-from tvm import te
+from tvm import DataType, DataTypeCode, te
from tvm.tir import PrimExpr
from . import cpp, tag
from .utils import get_const_tuple
+def _require_float_tensor(op_name, x):
+    if DataType(x.dtype).type_code not in (DataTypeCode.FLOAT, DataTypeCode.BFLOAT):
+        raise TypeError(f"topi.{op_name} only supports floating-point inputs, but got {x.dtype}")
+    return x
+
+
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def identity(x):
"""Take identity of input x.
@@ -211,6 +217,7 @@ def acos(x):
y : tvm.te.Tensor
The result.
"""
+ x = _require_float_tensor("acos", x)
return te.compute(x.shape, lambda *i: te.acos(x(*i)))
@@ -228,6 +235,7 @@ def acosh(x):
y : tvm.te.Tensor
The result.
"""
+ x = _require_float_tensor("acosh", x)
return te.compute(x.shape, lambda *i: te.acosh(x(*i)))
@@ -245,6 +253,7 @@ def asin(x):
y : tvm.te.Tensor
The result.
"""
+ x = _require_float_tensor("asin", x)
return te.compute(x.shape, lambda *i: te.asin(x(*i)))
@@ -262,6 +271,7 @@ def asinh(x):
y : tvm.te.Tensor
The result.
"""
+ x = _require_float_tensor("asinh", x)
return te.compute(x.shape, lambda *i: te.asinh(x(*i)))
@@ -296,6 +306,7 @@ def atanh(x):
y : tvm.te.Tensor
The result.
"""
+ x = _require_float_tensor("atanh", x)
return te.compute(x.shape, lambda *i: te.atanh(x(*i)))
diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py
index 7b069fc8cd..3a9bcd3957 100644
--- a/tests/python/te/test_te_create_primfunc.py
+++ b/tests/python/te/test_te_create_primfunc.py
@@ -359,6 +359,32 @@ def test_constant():
tvm.testing.assert_allclose(a_np + 2, c.numpy())
+@pytest.mark.parametrize("op_name", ["acos", "acosh", "asin", "asinh", "atanh"])
+def test_topi_float_unary_rejects_integer_input(op_name):
+    x = te.placeholder((1, 8), dtype="int16", name="x")
+    op = getattr(topi, op_name)
+
+    with pytest.raises(
+        TypeError,
+        match=rf"topi\.{op_name} only supports floating-point inputs, but got int16",
+    ):
+        op(x)
+
+
+@pytest.mark.parametrize("op_name", ["acos", "acosh", "asin", "asinh", "atanh"])
+@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
+def test_topi_float_unary_accepts_float_input(op_name, dtype):
+    x = te.placeholder((1, 8), dtype=dtype, name="x")
+    op = getattr(topi, op_name)
+    out = op(x)
+
+    func = te.create_prim_func([x, out]).with_attr("target", tvm.target.Target("llvm"))
+    mod = tvm.IRModule({"main": func})
+    compiled = tvm.build(mod, target="llvm")
+
+    assert compiled is not None
+
+
def test_data_dependent_access():
A = te.placeholder((10,), name="A")
B = te.placeholder((10,), name="B", dtype="int32")