This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 90391bb016 [Relax] Enable bfloat16 for softmax struct-info inference (#17781)
90391bb016 is described below

commit 90391bb016048ee22959dbe71da8a1eb301b027c
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Mar 27 00:26:28 2025 -0400

    [Relax] Enable bfloat16 for softmax struct-info inference (#17781)
    
    This PR enables struct-info inference for the bfloat16 data type
    in the softmax and log-softmax operators.
---
 src/relax/op/nn/nn.cc            | 3 ++-
 tests/python/relax/test_op_nn.py | 9 +++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 826711538c..c768ea19af 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -74,7 +74,8 @@ StructInfo InferStructInfoSoftmax(const Call& call, const 
BlockBuilder& ctx) {
   if (data_sinfo->IsUnknownNdim()) {
     return data_sinfo;
   }
-  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float() &&
+      !data_sinfo->dtype.is_bfloat()) {
     ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input 
tensor to have float "
                                                 "dtype. However, the given 
input dtype is "
                                              << data_sinfo->dtype);
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index 7adfc84283..ec4551872f 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -15,10 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
+
 import tvm
 import tvm.testing
-from tvm import relax, tir
-from tvm import TVMError
+from tvm import TVMError, relax, tir
 from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
@@ -143,6 +143,7 @@ def test_softmax_log_softmax_infer_struct_info():
     x3 = relax.Var("x", R.Tensor((2, 3)))
     x4 = relax.Var("x", R.Tensor())
     x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
+    x6 = relax.Var("x", R.Tensor((2, 3), "bfloat16"))
 
     _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 
3), "float32"))
     _check_inference(bb, relax.op.nn.softmax(x5), relax.TensorStructInfo((2, 
3), "float32", vdev0))
@@ -164,6 +165,10 @@ def test_softmax_log_softmax_infer_struct_info():
         bb, relax.op.nn.log_softmax(x3, axis=-1), relax.TensorStructInfo((2, 
3), dtype="")
     )
     _check_inference(bb, relax.op.nn.log_softmax(x4, axis=-2), 
relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.nn.softmax(x6), relax.TensorStructInfo((2, 
3), dtype="bfloat16"))
+    _check_inference(
+        bb, relax.op.nn.log_softmax(x6), relax.TensorStructInfo((2, 3), 
dtype="bfloat16")
+    )
 
 
 def test_softmax_log_softmax_infer_struct_info_shape_symbolic():

Reply via email to