This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b44988ba9f [Relax][ONNX] Make ReduceMax/ReduceMin NaN propagation 
order-independent(numpy semantics) (#19755)
b44988ba9f is described below

commit b44988ba9fd82c02a2be1b2014a12041d8ed5785
Author: Neo Chien <[email protected]>
AuthorDate: Tue Jun 16 04:19:50 2026 +0800

    [Relax][ONNX] Make ReduceMax/ReduceMin NaN propagation 
order-independent(numpy semantics) (#19755)
    
    Hi Committers,
    
    This PR addresses the `ReduceMax`/`ReduceMin` part of issue
    https://github.com/apache/tvm/issues/19572. Any suggestions would be
    appreciated if you are available.
    
    ### Root cause:
    The ONNX frontend ReduceMax / ReduceMin converters return relax.op.max /
    relax.op.min. After legalization these map to topi.max / topi.min, which
    fold with a commutative reducer whose combiner is Max(x, y) / Min(x, y).
    In codegen, Max(a, b) lowers to select(a > b, a, b) using an **ordered**
    float comparison (fcmp ogt), which is false for NaN. As a left-fold (acc
    = Max(acc, elem)), NaN propagation becomes **position-dependent** - a
    later non-NaN element silently overwrites an earlier NaN.
    
    ### Solution:
    Adopt the well-defined, **order-independent numpy/IEEE convention**
    (matching numpy.max/min and torch.amax/amin): the reduction yields NaN
    whenever **any** reduced element is NaN. Minimal, ONNX-frontend-only
    change:
    
    - Add a shared helper _reduce_min_max_preserve_nan(reduce_op, data,
    axes, keepdims).
    - For floating-point inputs, detect NaN along the reduced axes via
    `sum(astype(isnan(data), dtype), axes, keepdims) > 0` and force those
    outputs to `NaN` with `where(has_nan, nan, reduce(data))`. The mask
    reduces over the **same axes/keepdims**, so it aligns in shape with the
    reduced result.
    - Keep non-floating (integer) inputs unchanged.
    - Route all reduce paths (`_impl_v11` and both reduce branches of
    `_impl_v18`) through the helper; the `noop_with_empty_axes` passthrough
    is left untouched since it performs no reduction.
    
    ### Note on scope (re: #19589):
    The underlying NaN behavior of Max/Min is the same family of ops
    discussed in #19589. Per review comments there, enforcing NaN semantics
    at the IR / LLVM-IR level is undesirable (backward-compat with older
    LLVM, and portability to CUDA/OpenCL/Vulkan), and a dedicated portable
    nanmin/nanmax TIRx intrinsic (like `nearbyint`) would be the preferred
    long-term mechanism. This PR deliberately:
    
    - does not touch the IR-level Max/Min lowering, and
    - does not rely on the bool reduction of the NaN mask - it uses
    `sum(isnan) > 0`, fully sidestepping Max/Min NaN behavior.
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 34 +++++++++++++++++----
 tests/python/relax/test_frontend_onnx.py        | 40 +++++++++++++++++++++++++
 2 files changed, 68 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index d64020bfc7..2d1cc47377 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3887,6 +3887,28 @@ class RMSNormalization(OnnxOpConverter):
         return output
 
 
+def _reduce_min_max_preserve_nan(reduce_op, data, axes, keepdims):
+    """Apply a min/max reduction with well-defined, order-independent NaN 
propagation.
+
+    relax.op.max/min legalize to a max/min fold implemented as select(x > y, 
x, y) with an
+    ordered float comparison, so NaN propagation depends on the fold position 
(a later non-NaN
+    element silently overwrites an earlier NaN). ONNX Runtime is also 
order-dependent (it only
+    yields NaN when the first reduced element is NaN), which is an 
implementation artifact rather
+    than a defined semantics and is impractical to replicate portably. We 
instead adopt the
+    numpy/IEEE convention used by numpy.max/min and torch.amax/amin: for 
floating-point inputs,
+    detect NaN along the reduced axes and force the output to NaN whenever any 
reduced element is
+    NaN.
+    """
+    y = reduce_op(data, axes, keepdims)
+    dtype = data.struct_info.dtype if isinstance(data.struct_info, 
relax.TensorStructInfo) else None
+    if dtype is None or not _relax_dtype_is_floating_point(dtype):
+        return y
+    nan_count = relax.op.sum(relax.op.astype(relax.op.isnan(data), dtype), 
axes, keepdims)
+    has_nan = relax.op.greater(nan_count, relax.const(0, dtype))
+    nan_filled = relax.op.full_like(y, relax.const(float("nan"), dtype))
+    return relax.op.where(has_nan, nan_filled, y)
+
+
 class ReduceMax(OnnxOpConverter):
     """Converts an onnx ReduceMax node into an equivalent Relax expression."""
 
@@ -3895,7 +3917,7 @@ class ReduceMax(OnnxOpConverter):
         data = inputs[0]
         axes = attr.get("axes", None)
         keepdims = attr.get("keepdims", 1)
-        return relax.op.max(data, axes, keepdims)
+        return _reduce_min_max_preserve_nan(relax.op.max, data, axes, keepdims)
 
     @classmethod
     def _impl_v18(cls, bb, inputs, attr, params):
@@ -3912,13 +3934,13 @@ class ReduceMax(OnnxOpConverter):
 
         # If axes is empty and noop_with_empty_axes is False, reduce all dims
         if not axes and not noop_with_empty_axes:
-            return relax.op.max(data, None, keepdims)
+            return _reduce_min_max_preserve_nan(relax.op.max, data, None, 
keepdims)
         # If axes is empty and noop_with_empty_axes is True, return input 
unchanged
         elif not axes and noop_with_empty_axes:
             return data
         # Otherwise reduce over specified axes
         else:
-            return relax.op.max(data, axes, keepdims)
+            return _reduce_min_max_preserve_nan(relax.op.max, data, axes, 
keepdims)
 
 
 class ReduceMin(OnnxOpConverter):
@@ -3929,7 +3951,7 @@ class ReduceMin(OnnxOpConverter):
         data = inputs[0]
         axes = attr.get("axes", None)
         keepdims = attr.get("keepdims", 1)
-        return relax.op.min(data, axes, keepdims)
+        return _reduce_min_max_preserve_nan(relax.op.min, data, axes, keepdims)
 
     @classmethod
     def _impl_v18(cls, bb, inputs, attr, params):
@@ -3946,13 +3968,13 @@ class ReduceMin(OnnxOpConverter):
 
         # If axes is empty and noop_with_empty_axes is False, reduce all dims
         if not axes and not noop_with_empty_axes:
-            return relax.op.min(data, None, keepdims)
+            return _reduce_min_max_preserve_nan(relax.op.min, data, None, 
keepdims)
         # If axes is empty and noop_with_empty_axes is True, return input 
unchanged
         elif not axes and noop_with_empty_axes:
             return data
         # Otherwise reduce over specified axes
         else:
-            return relax.op.min(data, axes, keepdims)
+            return _reduce_min_max_preserve_nan(relax.op.min, data, axes, 
keepdims)
 
 
 class ReduceSum(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index a83333e7d7..db8b977efc 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -804,6 +804,46 @@ def test_sign_nan_preserve():
     )
 
 
[email protected]("op_name", ["ReduceMax", "ReduceMin"])
[email protected](
+    "x",
+    [
+        # NaN in different positions. TVM's max/min fold previously dropped 
NaN depending on
+        # position, ONNX Runtime only propagates NaN when it is the first 
reduced element, which
+        # is an order-dependent implementation artifact. We instead adopt the 
well-defined,
+        # order-independent numpy/IEEE semantics: any NaN in the reduced range 
yields NaN.
+        np.array([np.nan, 1.0, 2.0], dtype=np.float32),
+        np.array([2.0, 1.0, np.nan], dtype=np.float32),
+        np.array([1.0, np.nan, 2.0], dtype=np.float32),
+        np.array([1.0, 2.0, 3.0], dtype=np.float32),
+    ],
+)
+def test_reduce_min_max_nan_preserve(op_name, x):
+    reduce_node = helper.make_node(op_name, ["x"], ["y"], keepdims=0)
+    graph = helper.make_graph(
+        [reduce_node],
+        "reduce_nan_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, 
list(x.shape))],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])],
+    )
+    model = helper.make_model(graph, producer_name="reduce_nan_test")
+    model.ir_version = 8
+    for opset_import in model.opset_import:
+        if opset_import.domain in ["", "ai.onnx"]:
+            opset_import.version = 18
+            break
+
+    # Reference is numpy (NaN propagates if any element is NaN), not ONNX 
Runtime.
+    ref_out = (np.max if op_name == "ReduceMax" else np.min)(x)
+
+    tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
+    out_np = (tvm_out[0] if isinstance(tvm_out, (list, tuple)) else 
tvm_out).numpy()
+
+    np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ref_out))
+    if not np.isnan(ref_out):
+        np.testing.assert_allclose(out_np, ref_out, rtol=1e-7, atol=1e-5)
+
+
 @pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
 def test_softmax_family_opset11_default_axis_semantics(op_name: str):
     verify_unary(op_name, [2, 3, 4], opset=11)

Reply via email to