This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 90391bb016 [Relax] Enable bfloat16 for softmax struct-info inference
(#17781)
90391bb016 is described below
commit 90391bb016048ee22959dbe71da8a1eb301b027c
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Mar 27 00:26:28 2025 -0400
[Relax] Enable bfloat16 for softmax struct-info inference (#17781)
This PR adds support for inferring struct info for the bfloat16 data type
in the softmax and log-softmax operators.
---
src/relax/op/nn/nn.cc | 3 ++-
tests/python/relax/test_op_nn.py | 9 +++++++--
2 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 826711538c..c768ea19af 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -74,7 +74,8 @@ StructInfo InferStructInfoSoftmax(const Call& call, const
BlockBuilder& ctx) {
if (data_sinfo->IsUnknownNdim()) {
return data_sinfo;
}
- if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+ if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float() &&
+ !data_sinfo->dtype.is_bfloat()) {
ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input
tensor to have float "
"dtype. However, the given
input dtype is "
<< data_sinfo->dtype);
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index 7adfc84283..ec4551872f 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -15,10 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import pytest
+
import tvm
import tvm.testing
-from tvm import relax, tir
-from tvm import TVMError
+from tvm import TVMError, relax, tir
from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -143,6 +143,7 @@ def test_softmax_log_softmax_infer_struct_info():
x3 = relax.Var("x", R.Tensor((2, 3)))
x4 = relax.Var("x", R.Tensor())
x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
+ x6 = relax.Var("x", R.Tensor((2, 3), "bfloat16"))
_check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2,
3), "float32"))
_check_inference(bb, relax.op.nn.softmax(x5), relax.TensorStructInfo((2,
3), "float32", vdev0))
@@ -164,6 +165,10 @@ def test_softmax_log_softmax_infer_struct_info():
bb, relax.op.nn.log_softmax(x3, axis=-1), relax.TensorStructInfo((2,
3), dtype="")
)
_check_inference(bb, relax.op.nn.log_softmax(x4, axis=-2),
relax.TensorStructInfo(dtype=""))
+ _check_inference(bb, relax.op.nn.softmax(x6), relax.TensorStructInfo((2,
3), dtype="bfloat16"))
+ _check_inference(
+ bb, relax.op.nn.log_softmax(x6), relax.TensorStructInfo((2, 3),
dtype="bfloat16")
+ )
def test_softmax_log_softmax_infer_struct_info_shape_symbolic():