This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5d3b525675 [TIR] Reject non-floating inputs for trig unary ops (#18879)
5d3b525675 is described below

commit 5d3b525675937be6c178b3f358e94e41ca080df8
Author: YinHanke <[email protected]>
AuthorDate: Mon Mar 9 06:35:14 2026 +0800

    [TIR] Reject non-floating inputs for trig unary ops (#18879)
    
    ## Summary
    
    Reject non-floating inputs for trig-style TIR unary ops.
    
    ## Changes
    
    - reject non-floating-point inputs for trig-style TIR unary ops (`tan`,
    `sin`, `cos`, their inverse and hyperbolic variants, and `tanh`)
    - add the same dtype check in the Python TIR wrappers so that
    `topi.tan(int32)` fails early with a clear `TypeError`
    - add regression tests for `tvm.tir.tan(int32)`, `tvm.tir.sin(int32)` and
    `topi.tan(int32)`, plus a check that `tvm.tir.exp` preserves `bfloat16`
    
    ## Validation
    
    - `pytest tests/python/tir-base/test_tir_constructor.py -k
    'math_unary_constructor_requires_float_dtype or
    topi_tan_requires_float_dtype' -q`
    - local repro for the original `where -> tan(int32)` case now fails
    early with `TypeError`
    - verified `topi.tan(float32)` still builds with `target="llvm"`
    
    ## Issue
    
    Fixes #18769
---
 include/tvm/tir/op.h                          | 38 ++++++++++++++++++---------
 python/tvm/tir/op.py                          | 31 +++++++++++++---------
 tests/python/tir-base/test_tir_constructor.py | 27 +++++++++++++++++++
 3 files changed, 71 insertions(+), 25 deletions(-)

diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 050063300b..59f04e76a3 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -720,10 +720,16 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
  */
 TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits);
 
+inline void CheckMathUnaryOpInputDType(const char* op_name, DataType dtype) {
+  TVM_FFI_CHECK(dtype.is_float() || dtype.is_bfloat16(), TypeError)
+      << "tir." << op_name << " only supports floating-point inputs, but got " << dtype;
+}
+
 // Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
+#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType)    \
   inline PrimExpr OpName(PrimExpr x, Span span = Span()) {              \
     static const Op& op = Op::Get("tir." #OpName);                      \
+    CheckInputDType(#OpName, x.dtype());                                \
     if (x.dtype().is_bfloat16()) {                                      \
       DataType bf16_dtype = x.dtype();                                  \
       DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes());            \
@@ -735,11 +741,17 @@ TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits);
     }                                                                   \
   }
 
+#define TVM_DECLARE_INTRIN_UNARY(OpName) \
+  TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, [](const char*, DataType) {})
+
+#define TVM_DECLARE_FLOAT_INTRIN_UNARY(OpName) \
+  TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckMathUnaryOpInputDType)
+
 TVM_DECLARE_INTRIN_UNARY(exp);
 TVM_DECLARE_INTRIN_UNARY(exp2);
 TVM_DECLARE_INTRIN_UNARY(exp10);
 TVM_DECLARE_INTRIN_UNARY(erf);
-TVM_DECLARE_INTRIN_UNARY(tanh);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(tanh);
 TVM_DECLARE_INTRIN_UNARY(sigmoid);
 TVM_DECLARE_INTRIN_UNARY(sqrt);
 TVM_DECLARE_INTRIN_UNARY(rsqrt);
@@ -748,17 +760,17 @@ TVM_DECLARE_INTRIN_UNARY(log2);
 TVM_DECLARE_INTRIN_UNARY(log10);
 TVM_DECLARE_INTRIN_UNARY(log1p);
 TVM_DECLARE_INTRIN_UNARY(popcount);
-TVM_DECLARE_INTRIN_UNARY(tan);
-TVM_DECLARE_INTRIN_UNARY(cos);
-TVM_DECLARE_INTRIN_UNARY(cosh);
-TVM_DECLARE_INTRIN_UNARY(sin);
-TVM_DECLARE_INTRIN_UNARY(sinh);
-TVM_DECLARE_INTRIN_UNARY(asin);
-TVM_DECLARE_INTRIN_UNARY(acos);
-TVM_DECLARE_INTRIN_UNARY(atan);
-TVM_DECLARE_INTRIN_UNARY(acosh);
-TVM_DECLARE_INTRIN_UNARY(asinh);
-TVM_DECLARE_INTRIN_UNARY(atanh);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(tan);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(cos);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(cosh);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(sin);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(sinh);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(asin);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(acos);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(atan);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(acosh);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(asinh);
+TVM_DECLARE_FLOAT_INTRIN_UNARY(atanh);
 TVM_DECLARE_INTRIN_UNARY(clz);
 
 #define TVM_DECLARE_INTRIN_BINARY(OpName)                              \
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 616373c765..da9a9aecd1 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -238,6 +238,13 @@ def call_extern(dtype, func_name, *args, span=None):
     return Call(dtype, Op.get("tir.call_extern"), [func_name, *args], span=span)
 
 
+def _require_float_arg(op_name, x):
+    x = tir.convert(x)
+    if "float" not in x.dtype and "bfloat" not in x.dtype:
+        raise TypeError(f"tir.{op_name} only supports floating-point inputs, but got {x.dtype}")
+    return x
+
+
 def call_llvm_intrin(dtype, name, *args, span=None):
     """Build expression by calling a llvm intrinsic function
 
@@ -2186,7 +2193,7 @@ def tanh(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("tanh", x)
     return call_intrin(x.dtype, "tir.tanh", x)
 
 
@@ -2288,7 +2295,7 @@ def tan(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("tan", x)
     return call_intrin(x.dtype, "tir.tan", x)
 
 
@@ -2305,7 +2312,7 @@ def cos(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("cos", x)
     return call_intrin(x.dtype, "tir.cos", x)
 
 
@@ -2322,7 +2329,7 @@ def cosh(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("cosh", x)
     return call_intrin(x.dtype, "tir.cosh", x)
 
 
@@ -2339,7 +2346,7 @@ def acos(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("acos", x)
     return call_intrin(x.dtype, "tir.acos", x)
 
 
@@ -2356,7 +2363,7 @@ def acosh(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("acosh", x)
     return call_intrin(x.dtype, "tir.acosh", x)
 
 
@@ -2373,7 +2380,7 @@ def sin(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("sin", x)
     return call_intrin(x.dtype, "tir.sin", x)
 
 
@@ -2390,7 +2397,7 @@ def sinh(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("sinh", x)
     return call_intrin(x.dtype, "tir.sinh", x)
 
 
@@ -2407,7 +2414,7 @@ def asin(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("asin", x)
     return call_intrin(x.dtype, "tir.asin", x)
 
 
@@ -2424,7 +2431,7 @@ def asinh(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("asinh", x)
     return call_intrin(x.dtype, "tir.asinh", x)
 
 
@@ -2441,7 +2448,7 @@ def atan(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("atan", x)
     return call_intrin(x.dtype, "tir.atan", x)
 
 
@@ -2458,7 +2465,7 @@ def atanh(x):
     y : PrimExpr
         The result.
     """
-    x = tir.convert(x)
+    x = _require_float_arg("atanh", x)
     return call_intrin(x.dtype, "tir.atanh", x)
 
 
diff --git a/tests/python/tir-base/test_tir_constructor.py b/tests/python/tir-base/test_tir_constructor.py
index 7edc734fc7..654c22ab35 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -19,6 +19,7 @@
 import pytest
 
 import tvm
+from tvm import te, topi
 
 
 def test_expr_constructor():
@@ -187,5 +188,31 @@ def test_float_constructor_requires_float_dtype():
         tvm.tir.FloatImm("int32", 1.0)
 
 
+def test_math_unary_constructor_requires_float_dtype():
+    x = tvm.tir.Var("x", "int32")
+
+    with pytest.raises(TypeError, match=r"tir\.tan only supports floating-point inputs"):
+        tvm.tir.tan(x)
+
+    with pytest.raises(TypeError, match=r"tir\.sin only supports floating-point inputs"):
+        tvm.tir.sin(x)
+
+    y = tvm.tir.Var("y", "float32")
+    assert tvm.tir.tan(y).dtype == "float32"
+
+
+def test_topi_tan_requires_float_dtype():
+    x = te.placeholder((2, 2), dtype="int32", name="x")
+
+    with pytest.raises(TypeError, match=r"tir\.tan only supports floating-point inputs"):
+        topi.tan(x)
+
+
+def test_math_unary_constructor_preserves_bfloat16():
+    x = tvm.tir.Var("x", "bfloat16")
+    y = tvm.tir.exp(x)
+    assert y.dtype == "bfloat16"
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to