This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c209e27b7 [Relax] Clean up scatter_elements unknown dtype handling (#18577)
1c209e27b7 is described below

commit 1c209e27b7b0c62fcb37968382ffcd1612319eab
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sat Dec 20 19:47:33 2025 +0800

    [Relax] Clean up scatter_elements unknown dtype handling (#18577)
    
    ## Why
    
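    - `LOG(WARNING)` is the standard, correct way to emit non-fatal diagnostics
    throughout the TVM codebase, so the `TODO(tvm-team)` comments asking for a
    warning equivalent of `ctx->ReportFatal` are obsolete and have been removed
    - The same pattern is used consistently across Relax ops (see
    test_op_manipulate.py, index.cc, etc.)
    - Added test coverage for previously untested scenarios: `scatter_elements`
    with an unknown dtype for `data` and with an unknown dtype for `updates`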
---
 src/relax/op/tensor/manipulate.cc        |  2 --
 tests/python/relax/test_op_manipulate.py | 14 ++++++++++++++
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/relax/op/tensor/manipulate.cc 
b/src/relax/op/tensor/manipulate.cc
index 493198fbd0..1aab52ac56 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2456,7 +2456,6 @@ StructInfo InferStructInfoScatterElements(const Call& 
call, const BlockBuilder&
   if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) {
     auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) 
{
       if (sinfo->IsUnknownDtype()) {
-        // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for 
warning?
         LOG(WARNING) << "Data type of " << name
                      << " has not been specified. Assume it has an integer 
type.";
       }
@@ -2473,7 +2472,6 @@ StructInfo InferStructInfoScatterElements(const Call& 
call, const BlockBuilder&
   }
 
   if (indices_sinfo->IsUnknownDtype()) {
-    // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for 
warning?
     LOG(WARNING) << "Data type of indice has not been specified. Assume it has 
an integer type.";
   } else if (!(indices_sinfo->dtype.is_int() || 
indices_sinfo->dtype.is_uint())) {
     ctx->ReportFatal(
diff --git a/tests/python/relax/test_op_manipulate.py 
b/tests/python/relax/test_op_manipulate.py
index d39584e06b..6a73a84fd8 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -3417,6 +3417,20 @@ def test_scatter_elements_infer_struct_info():
         relax.op.scatter_elements(d2, i3, u0, 0, "updates"),
         relax.TensorStructInfo(dtype="float32", ndim=-1),
     )
+    # Test with unknown dtype for data
+    d_unknown = relax.Var("data", R.Tensor((4, 4)))
+    _check_inference(
+        bb,
+        relax.op.scatter_elements(d_unknown, i0, u0, 0, "updates"),
+        relax.TensorStructInfo((4, 4), dtype=""),
+    )
+    # Test with unknown dtype for updates
+    u_unknown = relax.Var("updates", R.Tensor((2, 2)))
+    _check_inference(
+        bb,
+        relax.op.scatter_elements(d0, i0, u_unknown, 0, "updates"),
+        relax.TensorStructInfo((4, 4), dtype="float32"),
+    )
 
 
 def test_scatter_elements_infer_struct_info_symbolic_shape():

Reply via email to