This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b44988ba9f [Relax][ONNX] Make ReduceMax/ReduceMin NaN propagation
order-independent(numpy semantics) (#19755)
b44988ba9f is described below
commit b44988ba9fd82c02a2be1b2014a12041d8ed5785
Author: Neo Chien <[email protected]>
AuthorDate: Tue Jun 16 04:19:50 2026 +0800
[Relax][ONNX] Make ReduceMax/ReduceMin NaN propagation
order-independent(numpy semantics) (#19755)
Hi Committers,
This PR addresses the `ReduceMax`/`ReduceMin` part of issue
https://github.com/apache/tvm/issues/19572. Any suggestions would be
appreciated if you are available.
### Root cause:
The ONNX frontend ReduceMax / ReduceMin converters return relax.op.max /
relax.op.min. After legalization these map to topi.max / topi.min, which
fold with a commutative reducer whose combiner is Max(x, y) / Min(x, y).
In codegen, Max(a, b) lowers to select(a > b, a, b) using an **ordered**
float comparison (fcmp ogt), which is false for NaN. As a left-fold (acc
= Max(acc, elem)), NaN propagation becomes **position-dependent** - a
later non-NaN element silently overwrites an earlier NaN.
### Solution:
Adopt the well-defined, **order-independent numpy/IEEE convention**
(matching numpy.max/min and torch.amax/amin): the reduction yields NaN
whenever **any** reduced element is NaN. Minimal, ONNX-frontend-only
change:
- Add a shared helper _reduce_min_max_preserve_nan(reduce_op, data,
axes, keepdims).
- For floating-point inputs, detect NaN along the reduced axes via
`sum(astype(isnan(data), dtype), axes, keepdims) > 0` and force those
outputs to `NaN` with `where(has_nan, nan, reduce(data))`. The mask
reduces over the **same axes/keepdims**, so it aligns in shape with the
reduced result.
- Keep non-floating (integer) inputs unchanged.
- Route all reduce paths (`_impl_v11` and both reduce branches of
`_impl_v18`) through the helper; the `noop_with_empty_axes` passthrough
is left untouched since it performs no reduction.
### Note on scope (re: #19589 ):
The underlying NaN behavior of Max/Min is the same family of ops
discussed in #19589. Per review comments there, enforcing NaN semantics
at the IR / LLVM-IR level is undesirable (backward-compat with older
LLVM, and portability to CUDA/OpenCL/Vulkan), and a dedicated portable
nanmin/nanmax TIRx intrinsic (like `nearbyint`) would be the preferred
long-term mechanism. This PR deliberately:
- does not touch the IR-level Max/Min lowering, and
- does not rely on the bool reduction of the NaN mask - it uses
`sum(isnan) > 0`, fully sidestepping Max/Min NaN behavior.
---------
Co-authored-by: cchung100m <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 34 +++++++++++++++++----
tests/python/relax/test_frontend_onnx.py | 40 +++++++++++++++++++++++++
2 files changed, 68 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index d64020bfc7..2d1cc47377 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3887,6 +3887,28 @@ class RMSNormalization(OnnxOpConverter):
return output
+def _reduce_min_max_preserve_nan(reduce_op, data, axes, keepdims):
+ """Apply a min/max reduction with well-defined, order-independent NaN
propagation.
+
+ relax.op.max/min legalize to a max/min fold implemented as select(x > y,
x, y) with an
+ ordered float comparison, so NaN propagation depends on the fold position
(a later non-NaN
+ element silently overwrites an earlier NaN). ONNX Runtime is also
order-dependent (it only
+ yields NaN when the first reduced element is NaN), which is an
implementation artifact rather
+ than a defined semantics and is impractical to replicate portably. We
instead adopt the
+ numpy/IEEE convention used by numpy.max/min and torch.amax/amin: for
floating-point inputs,
+ detect NaN along the reduced axes and force the output to NaN whenever any
reduced element is
+ NaN.
+ """
+ y = reduce_op(data, axes, keepdims)
+ dtype = data.struct_info.dtype if isinstance(data.struct_info,
relax.TensorStructInfo) else None
+ if dtype is None or not _relax_dtype_is_floating_point(dtype):
+ return y
+ nan_count = relax.op.sum(relax.op.astype(relax.op.isnan(data), dtype),
axes, keepdims)
+ has_nan = relax.op.greater(nan_count, relax.const(0, dtype))
+ nan_filled = relax.op.full_like(y, relax.const(float("nan"), dtype))
+ return relax.op.where(has_nan, nan_filled, y)
+
+
class ReduceMax(OnnxOpConverter):
"""Converts an onnx ReduceMax node into an equivalent Relax expression."""
@@ -3895,7 +3917,7 @@ class ReduceMax(OnnxOpConverter):
data = inputs[0]
axes = attr.get("axes", None)
keepdims = attr.get("keepdims", 1)
- return relax.op.max(data, axes, keepdims)
+ return _reduce_min_max_preserve_nan(relax.op.max, data, axes, keepdims)
@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
@@ -3912,13 +3934,13 @@ class ReduceMax(OnnxOpConverter):
# If axes is empty and noop_with_empty_axes is False, reduce all dims
if not axes and not noop_with_empty_axes:
- return relax.op.max(data, None, keepdims)
+ return _reduce_min_max_preserve_nan(relax.op.max, data, None,
keepdims)
# If axes is empty and noop_with_empty_axes is True, return input
unchanged
elif not axes and noop_with_empty_axes:
return data
# Otherwise reduce over specified axes
else:
- return relax.op.max(data, axes, keepdims)
+ return _reduce_min_max_preserve_nan(relax.op.max, data, axes,
keepdims)
class ReduceMin(OnnxOpConverter):
@@ -3929,7 +3951,7 @@ class ReduceMin(OnnxOpConverter):
data = inputs[0]
axes = attr.get("axes", None)
keepdims = attr.get("keepdims", 1)
- return relax.op.min(data, axes, keepdims)
+ return _reduce_min_max_preserve_nan(relax.op.min, data, axes, keepdims)
@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
@@ -3946,13 +3968,13 @@ class ReduceMin(OnnxOpConverter):
# If axes is empty and noop_with_empty_axes is False, reduce all dims
if not axes and not noop_with_empty_axes:
- return relax.op.min(data, None, keepdims)
+ return _reduce_min_max_preserve_nan(relax.op.min, data, None,
keepdims)
# If axes is empty and noop_with_empty_axes is True, return input
unchanged
elif not axes and noop_with_empty_axes:
return data
# Otherwise reduce over specified axes
else:
- return relax.op.min(data, axes, keepdims)
+ return _reduce_min_max_preserve_nan(relax.op.min, data, axes,
keepdims)
class ReduceSum(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index a83333e7d7..db8b977efc 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -804,6 +804,46 @@ def test_sign_nan_preserve():
)
[email protected]("op_name", ["ReduceMax", "ReduceMin"])
[email protected](
+ "x",
+ [
+ # NaN in different positions. TVM's max/min fold previously dropped
NaN depending on
+ # position, ONNX Runtime only propagates NaN when it is the first
reduced element, which
+ # is an order-dependent implementation artifact. We instead adopt the
well-defined,
+ # order-independent numpy/IEEE semantics: any NaN in the reduced range
yields NaN.
+ np.array([np.nan, 1.0, 2.0], dtype=np.float32),
+ np.array([2.0, 1.0, np.nan], dtype=np.float32),
+ np.array([1.0, np.nan, 2.0], dtype=np.float32),
+ np.array([1.0, 2.0, 3.0], dtype=np.float32),
+ ],
+)
+def test_reduce_min_max_nan_preserve(op_name, x):
+ reduce_node = helper.make_node(op_name, ["x"], ["y"], keepdims=0)
+ graph = helper.make_graph(
+ [reduce_node],
+ "reduce_nan_test",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT,
list(x.shape))],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])],
+ )
+ model = helper.make_model(graph, producer_name="reduce_nan_test")
+ model.ir_version = 8
+ for opset_import in model.opset_import:
+ if opset_import.domain in ["", "ai.onnx"]:
+ opset_import.version = 18
+ break
+
+ # Reference is numpy (NaN propagates if any element is NaN), not ONNX
Runtime.
+ ref_out = (np.max if op_name == "ReduceMax" else np.min)(x)
+
+ tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
+ out_np = (tvm_out[0] if isinstance(tvm_out, (list, tuple)) else
tvm_out).numpy()
+
+ np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ref_out))
+ if not np.isnan(ref_out):
+ np.testing.assert_allclose(out_np, ref_out, rtol=1e-7, atol=1e-5)
+
+
@pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
def test_softmax_family_opset11_default_axis_semantics(op_name: str):
verify_unary(op_name, [2, 3, 4], opset=11)