This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 70f223b [CodeGen][CUDA] use hrint for cuda half rounding (#10460)
70f223b is described below
commit 70f223ba08a1cb161d202ff9d668d00d325526e0
Author: wrongtest <[email protected]>
AuthorDate: Sat Mar 5 11:04:14 2022 +0800
[CodeGen][CUDA] use hrint for cuda half rounding (#10460)
When the CUDA C codegen lowers `tir.round` for fp16, the generic rule emits `hround`, but the CUDA half-precision math API has no such function; the correct intrinsic is `hrint`.
https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____HALF__FUNCTIONS.html#group__CUDA__MATH____HALF__FUNCTIONS_1gbbf7a989130edcbdbfbb4730f61c79b1
Test case to reproduce:
```python
import tvm
from tvm import relay
from tvm.ir.module import IRModule

x = relay.var("x", shape=[16], dtype="float16")
y = relay.round(x)
f = relay.Function([x], y)
m = IRModule.from_expr(f)
m = relay.transform.InferType()(m)
# Before this fix, building for CUDA emitted a call to the nonexistent hround.
relay.build(m, target="cuda")
```
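A quick way to confirm the fix is to inspect the generated CUDA source. The sketch below is illustrative only and assumes a CUDA-enabled TVM build; it reuses `m` from the snippet above and the standard `get_lib()`/`imported_modules` accessors:
```python
lib = relay.build(m, target="cuda")
# The CUDA module is among the modules imported by the host library.
cuda_src = lib.get_lib().imported_modules[0].get_source()
assert "hrint" in cuda_src and "hround" not in cuda_src
```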
---
src/target/source/intrin_rule_cuda.cc | 2 ++
tests/python/relay/test_op_level1.py | 4 ++++
2 files changed, 6 insertions(+)
diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index d1f3b33..a450b44 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -43,6 +43,8 @@ struct CUDAMath {
case 16: {
if (name == "fabs") {
return "__habs";
+ } else if (name == "round") {
+ return "hrint";
} else {
return "h" + name;
}
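To illustrate the dispatch above at the TE level: with this change `round` maps to `hrint`, while the generic `"h" + name` branch still covers ops such as `floor` (`hfloor`). A minimal sketch, not part of this patch, assuming a CUDA-enabled build:
```python
import tvm
from tvm import te

n = 128
A = te.placeholder((n,), dtype="float16", name="A")
B = te.compute((n,), lambda i: te.round(A[i]), name="B")
s = te.create_schedule(B.op)
s[B].bind(B.op.axis[0], te.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B], target="cuda")
# The emitted source should now call hrint, not hround.
print(f.imported_modules[0].get_source())
```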
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 4f7d810..c7aceb6 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -57,6 +57,10 @@ class TestUnaryOp:
"sin": (tvm.relay.sin, np.sin),
"tan": (tvm.relay.tan, np.tan),
"atan": (tvm.relay.atan, np.arctan),
+ "ceil": (tvm.relay.ceil, np.ceil),
+ "floor": (tvm.relay.floor, np.floor),
+ "trunc": (tvm.relay.trunc, np.trunc),
+ "round": (tvm.relay.round, np.round),
}
dtype = tvm.testing.parameter("float16", "float32")
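For context, the parametrized test exercises each entry above for both dtypes. A hedged standalone sketch of the equivalent check (`check_unary` is an illustrative name, not the real test helper):
```python
import numpy as np
import tvm
from tvm import relay

def check_unary(relay_op, np_op, dtype="float16"):
    # Build a single-op Relay function and compare against the NumPy reference.
    x = relay.var("x", shape=[16], dtype=dtype)
    mod = tvm.IRModule.from_expr(relay.Function([x], relay_op(x)))
    data = np.random.uniform(-5, 5, size=16).astype(dtype)
    out = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(data)
    np.testing.assert_allclose(out.numpy(), np_op(data), rtol=1e-3)

check_unary(relay.round, np.round)
```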