This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5d3b525675 [TIR] Reject non-floating inputs for trig unary ops (#18879)
5d3b525675 is described below
commit 5d3b525675937be6c178b3f358e94e41ca080df8
Author: YinHanke <[email protected]>
AuthorDate: Mon Mar 9 06:35:14 2026 +0800
[TIR] Reject non-floating inputs for trig unary ops (#18879)
## Summary
Reject non-floating inputs for trig-style TIR unary ops.
## Changes
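- reject non-floating-point inputs for the trigonometric, inverse-trigonometric,
and hyperbolic TIR unary ops (`sin`, `cos`, `tan`, `asin`, `acos`, `atan`,
`sinh`, `cosh`, `tanh`, `asinh`, `acosh`, `atanh`)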
- add the same dtype check to the corresponding Python TIR wrappers so that
`topi.tan(int32)` fails early with a clear `TypeError`
- add regression tests for `tvm.tir.tan(int32)`, `tvm.tir.sin(int32)`, and
`topi.tan(int32)`, plus a check that `tvm.tir.exp` preserves the `bfloat16` dtype
## Validation
- `pytest tests/python/tir-base/test_tir_constructor.py -q -k 'math_unary_constructor_requires_float_dtype or topi_tan_requires_float_dtype'`
- a local reproduction of the original `where -> tan(int32)` case now fails
early with a `TypeError`
- verified `topi.tan(float32)` still builds with `target="llvm"`
## Issue
Fixes #18769
---
include/tvm/tir/op.h | 38 ++++++++++++++++++---------
python/tvm/tir/op.py | 31 +++++++++++++---------
tests/python/tir-base/test_tir_constructor.py | 27 +++++++++++++++++++
3 files changed, 71 insertions(+), 25 deletions(-)
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 050063300b..59f04e76a3 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -720,10 +720,16 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y,
PrimExpr q, PrimExpr s
*/
TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits);
+/*!
+ * \brief Reject non-floating-point input dtypes for math unary intrinsics.
+ * \param op_name Intrinsic name without the "tir." prefix; used only in the error message.
+ * \param dtype Input dtype. Passes when dtype.is_float() or dtype.is_bfloat16();
+ *        otherwise a TypeError is raised via TVM_FFI_CHECK.
+ * NOTE(review): the Python-side check (_require_float_arg in python/tvm/tir/op.py) is a
+ * substring test on the dtype name ("float" in dtype). Confirm both checks agree for any
+ * other float-like dtypes (e.g. float8 variants) - TODO verify.
+ */
+inline void CheckMathUnaryOpInputDType(const char* op_name, DataType dtype) {
+ TVM_FFI_CHECK(dtype.is_float() || dtype.is_bfloat16(), TypeError)
+ << "tir." << op_name << " only supports floating-point inputs, but got "
<< dtype;
+}
+
// Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName) \
+// Declares `OpName(x, span)`. CheckInputDType(#OpName, x.dtype()) is invoked before the
+// bfloat16 handling below, so a rejecting check stops the call from being built at all.
+#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
+ CheckInputDType(#OpName, x.dtype()); \
if (x.dtype().is_bfloat16()) { \
DataType bf16_dtype = x.dtype(); \
DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \
@@ -735,11 +741,17 @@ TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int
bits);
} \
}
+// Unchecked variant: an empty lambda stands in for the dtype check, so any input dtype
+// is accepted (previous behavior of TVM_DECLARE_INTRIN_UNARY).
+#define TVM_DECLARE_INTRIN_UNARY(OpName) \
+ TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, [](const char*, DataType) {})
+
+// Checked variant: non-floating inputs raise TypeError via CheckMathUnaryOpInputDType.
+#define TVM_DECLARE_FLOAT_INTRIN_UNARY(OpName) \
+ TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckMathUnaryOpInputDType)
+
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(exp2);
TVM_DECLARE_INTRIN_UNARY(exp10);
TVM_DECLARE_INTRIN_UNARY(erf);
-TVM_DECLARE_INTRIN_UNARY(tanh);
+// tanh is float-only: non-floating inputs raise TypeError.
+TVM_DECLARE_FLOAT_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
@@ -748,17 +760,17 @@ TVM_DECLARE_INTRIN_UNARY(log2);
TVM_DECLARE_INTRIN_UNARY(log10);
TVM_DECLARE_INTRIN_UNARY(log1p);
TVM_DECLARE_INTRIN_UNARY(popcount);
-TVM_DECLARE_INTRIN_UNARY(tan);
-TVM_DECLARE_INTRIN_UNARY(cos);
-TVM_DECLARE_INTRIN_UNARY(cosh);
-TVM_DECLARE_INTRIN_UNARY(sin);
-TVM_DECLARE_INTRIN_UNARY(sinh);
-TVM_DECLARE_INTRIN_UNARY(asin);
-TVM_DECLARE_INTRIN_UNARY(acos);
-TVM_DECLARE_INTRIN_UNARY(atan);
-TVM_DECLARE_INTRIN_UNARY(acosh);
-TVM_DECLARE_INTRIN_UNARY(asinh);
-TVM_DECLARE_INTRIN_UNARY(atanh);
+// Trigonometric, inverse-trigonometric, and hyperbolic intrinsics accept only
+// floating-point (including bfloat16) inputs.
+TVM_DECLARE_FLOAT_INTRIN_UNARY(tan);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(cos);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(cosh);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(sin);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(sinh);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(asin);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(acos);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(atan);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(acosh);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(asinh);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(atanh);
TVM_DECLARE_INTRIN_UNARY(clz);
#define TVM_DECLARE_INTRIN_BINARY(OpName) \
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 616373c765..da9a9aecd1 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -238,6 +238,13 @@ def call_extern(dtype, func_name, *args, span=None):
return Call(dtype, Op.get("tir.call_extern"), [func_name, *args],
span=span)
+def _require_float_arg(op_name, x):
+    """Convert ``x`` to a PrimExpr and require a floating-point dtype.
+
+    Parameters
+    ----------
+    op_name : str
+        Name of the TIR op without the ``tir.`` prefix; used in the error message.
+
+    x : PrimExpr
+        Input argument (anything accepted by ``tir.convert``).
+
+    Returns
+    -------
+    x : PrimExpr
+        The converted input.
+
+    Raises
+    ------
+    TypeError
+        If the dtype name of the converted input does not contain ``"float"``.
+    """
+    x = tir.convert(x)
+    # Substring match: "bfloat16" also contains "float", so this single test admits
+    # both float* and bfloat* dtypes (a separate "bfloat" test would be redundant).
+    # NOTE(review): this is looser than the C++ check (is_float() || is_bfloat16());
+    # confirm the two agree for other float-like dtypes such as float8 variants.
+    if "float" not in x.dtype:
+        raise TypeError(
+            f"tir.{op_name} only supports floating-point inputs, but got {x.dtype}"
+        )
+    return x
+
+
def call_llvm_intrin(dtype, name, *args, span=None):
"""Build expression by calling a llvm intrinsic function
@@ -2186,7 +2193,7 @@ def tanh(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("tanh", x)
return call_intrin(x.dtype, "tir.tanh", x)
@@ -2288,7 +2295,7 @@ def tan(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("tan", x)
return call_intrin(x.dtype, "tir.tan", x)
@@ -2305,7 +2312,7 @@ def cos(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("cos", x)
return call_intrin(x.dtype, "tir.cos", x)
@@ -2322,7 +2329,7 @@ def cosh(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("cosh", x)
return call_intrin(x.dtype, "tir.cosh", x)
@@ -2339,7 +2346,7 @@ def acos(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("acos", x)
return call_intrin(x.dtype, "tir.acos", x)
@@ -2356,7 +2363,7 @@ def acosh(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("acosh", x)
return call_intrin(x.dtype, "tir.acosh", x)
@@ -2373,7 +2380,7 @@ def sin(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("sin", x)
return call_intrin(x.dtype, "tir.sin", x)
@@ -2390,7 +2397,7 @@ def sinh(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("sinh", x)
return call_intrin(x.dtype, "tir.sinh", x)
@@ -2407,7 +2414,7 @@ def asin(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("asin", x)
return call_intrin(x.dtype, "tir.asin", x)
@@ -2424,7 +2431,7 @@ def asinh(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("asinh", x)
return call_intrin(x.dtype, "tir.asinh", x)
@@ -2441,7 +2448,7 @@ def atan(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("atan", x)
return call_intrin(x.dtype, "tir.atan", x)
@@ -2458,7 +2465,7 @@ def atanh(x):
y : PrimExpr
The result.
"""
- x = tir.convert(x)
+ x = _require_float_arg("atanh", x)
return call_intrin(x.dtype, "tir.atanh", x)
diff --git a/tests/python/tir-base/test_tir_constructor.py
b/tests/python/tir-base/test_tir_constructor.py
index 7edc734fc7..654c22ab35 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -19,6 +19,7 @@
import pytest
import tvm
+from tvm import te, topi
def test_expr_constructor():
@@ -187,5 +188,31 @@ def test_float_constructor_requires_float_dtype():
tvm.tir.FloatImm("int32", 1.0)
+@pytest.mark.parametrize(
+    "op_name",
+    [
+        "tan",
+        "cos",
+        "cosh",
+        "sin",
+        "sinh",
+        "asin",
+        "acos",
+        "atan",
+        "acosh",
+        "asinh",
+        "atanh",
+        "tanh",
+    ],
+)
+def test_math_unary_constructor_requires_float_dtype(op_name):
+    """Every float-only TIR unary op rejects int32 inputs and accepts float32."""
+    op = getattr(tvm.tir, op_name)
+    x = tvm.tir.Var("x", "int32")
+
+    with pytest.raises(TypeError, match=rf"tir\.{op_name} only supports floating-point inputs"):
+        op(x)
+
+    y = tvm.tir.Var("y", "float32")
+    assert op(y).dtype == "float32"
+
+
+def test_topi_tan_requires_float_dtype():
+    """topi.tan on an int32 tensor fails early in the TIR wrapper with TypeError."""
+    x = te.placeholder((2, 2), dtype="int32", name="x")
+
+    with pytest.raises(TypeError, match=r"tir\.tan only supports floating-point inputs"):
+        topi.tan(x)
+
+
+def test_math_unary_constructor_preserves_bfloat16():
+    """bfloat16 inputs keep their dtype, including through the new float-only check."""
+    x = tvm.tir.Var("x", "bfloat16")
+    assert tvm.tir.exp(x).dtype == "bfloat16"
+    # tan is one of the ops restricted by this change; bfloat16 must still be accepted.
+    assert tvm.tir.tan(x).dtype == "bfloat16"
+
+
if __name__ == "__main__":
tvm.testing.main()