This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 70f223b  [CodeGen][CUDA] use hrint for cuda half rounding (#10460)
70f223b is described below

commit 70f223ba08a1cb161d202ff9d668d00d325526e0
Author: wrongtest <[email protected]>
AuthorDate: Sat Mar 5 11:04:14 2022 +0800

    [CodeGen][CUDA] use hrint for cuda half rounding (#10460)
    
    When the CUDA C codegen generates `tir.round` for fp16, it emits a call to
    `hround`, but the CUDA half-precision math API has no function by that name;
    the correct function is `hrint`:
    https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____HALF__FUNCTIONS.html#group__CUDA__MATH____HALF__FUNCTIONS_1gbbf7a989130edcbdbfbb4730f61c79b1
    
    Testcase to reproduce:
    ```python
    import tvm
    from tvm import relay
    from tvm.ir.module import IRModule
    x = relay.var("x", shape=[16], dtype="float16")
    y = relay.round(x)
    f = relay.Function([x], y)
    m = IRModule.from_expr(f)
    m = relay.transform.InferType()(m)
    relay.build(m, target="cuda")
    ```
---
 src/target/source/intrin_rule_cuda.cc | 2 ++
 tests/python/relay/test_op_level1.py  | 4 ++++
 2 files changed, 6 insertions(+)
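
The fp16 name selection that this patch touches in `CUDAMath` can be modelled in a few lines of Python. This is a hypothetical sketch of the 16-bit branch after the change, not TVM's actual code; `cuda_fp16_intrinsic` is an illustrative name that does not exist in TVM.

```python
def cuda_fp16_intrinsic(name: str) -> str:
    """Hypothetical model of the 16-bit branch of CUDAMath after this commit."""
    if name == "fabs":
        # Absolute value has its own intrinsic name.
        return "__habs"
    elif name == "round":
        # CUDA has no `hround`; `hrint` rounds to the nearest integer.
        return "hrint"
    else:
        # Default rule: prefix the op name with "h", e.g. "floor" -> "hfloor".
        return "h" + name


for op in ["fabs", "round", "floor", "ceil", "trunc"]:
    print(op, "->", cuda_fp16_intrinsic(op))
```

Under this model, the `ceil`, `floor` and `trunc` cases added to `TestUnaryOp` resolve to `hceil`, `hfloor` and `htrunc`, which do exist in `cuda_fp16.h`; of the four new cases, only `round` needs a special case.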

diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index d1f3b33..a450b44 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -43,6 +43,8 @@ struct CUDAMath {
         case 16: {
           if (name == "fabs") {
             return "__habs";
+          } else if (name == "round") {
+            return "hrint";
           } else {
             return "h" + name;
           }
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 4f7d810..c7aceb6 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -57,6 +57,10 @@ class TestUnaryOp:
         "sin": (tvm.relay.sin, np.sin),
         "tan": (tvm.relay.tan, np.tan),
         "atan": (tvm.relay.atan, np.arctan),
+        "ceil": (tvm.relay.ceil, np.ceil),
+        "floor": (tvm.relay.floor, np.floor),
+        "trunc": (tvm.relay.trunc, np.trunc),
+        "round": (tvm.relay.round, np.round),
     }
 
     dtype = tvm.testing.parameter("float16", "float32")
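
The new `round` test case uses `np.round` as its reference. NumPy's `np.round` rounds values exactly halfway between integers to the nearest even integer, which, per the CUDA Math API page linked in the commit message, is also how `hrint` breaks ties. The following is a minimal sketch, assuming NumPy is installed, that checks NumPy's float16 tie behaviour against `round_half_even`, a hypothetical pure-Python reference written only for this illustration.

```python
import math

import numpy as np


def round_half_even(x: float) -> float:
    """Hypothetical reference: round to nearest integer, ties to even."""
    lo = math.floor(x)
    diff = x - lo
    if diff > 0.5:
        return float(lo + 1)
    if diff < 0.5:
        return float(lo)
    # Exact tie: choose whichever neighbouring integer is even.
    return float(lo if lo % 2 == 0 else lo + 1)


# All values are exactly representable in float16; several are .5 ties.
values = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.25, 3.75], dtype=np.float16)
expected = np.array([round_half_even(float(v)) for v in values], dtype=np.float16)

print(np.round(values))
assert np.array_equal(np.round(values), expected)
```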
