This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8597d21a8d [Frontend][ONNX] Add MatMulInteger support to Relax ONNX
frontend (#18951)
8597d21a8d is described below
commit 8597d21a8d065e235f59e82ca178617107857c1e
Author: Kryptonite <[email protected]>
AuthorDate: Sun Mar 29 22:08:45 2026 +0300
[Frontend][ONNX] Add MatMulInteger support to Relax ONNX frontend (#18951)
### Summary
Implements the `MatMulInteger` operator (opset 10) in the Relax ONNX
frontend — INT8 matrix multiplication. Required for quantized model
inference (e.g. ONNX QDQ models).
Closes #18945 (Tier 1 — MatMulInteger operator)
### Tests
- All 4 `int8`/`uint8` dtype combinations, with and without scalar zero
points
- 3-D and 4-D batched matmul
---------
Signed-off-by: OmarAzizi <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 56 +++++++-
tests/python/relax/test_frontend_onnx.py | 170 ++++++++++++++++++++++++
2 files changed, 225 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index fa9d6eb05d..4cc4e99b7b 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -4054,6 +4054,60 @@ class GridSample(OnnxOpConverter):
)
class MatMulInteger(OnnxOpConverter):
    """
    Converts ONNX MatMulInteger (INT8/UINT8 quantized matrix multiply).

    Computes: output = (A - a_zero_point) * (B - b_zero_point)
    in int32 accumulation, per ONNX spec v10.

    Zero-point shapes per spec:
      a_zero_point: scalar | [M] (per-row) | [D1, D2, M, 1] (N-D per-row)
      b_zero_point: scalar | [N] (per-col) | [D1, D2, 1, N] (N-D per-col)
    """

    @classmethod
    def _impl_v10(cls, bb, inputs, attr, params):
        def _optional_input(idx):
            # Optional inputs may be absent or explicitly None; both mean "zero".
            return inputs[idx] if len(inputs) > idx and inputs[idx] is not None else None

        def _subtract_zero_point(operand, zero_point, expand_axis):
            # Cast the zero point to int32 so the subtraction happens in the
            # accumulation dtype, then normalize so struct_info is populated.
            zp = bb.normalize(relax.op.astype(zero_point, "int32"))
            # A 1-D zero point needs an explicit axis so it broadcasts along
            # the right matmul dimension ([M] -> [M, 1], or [N] -> [1, N]).
            # The N-D spec shapes ([..., M, 1] / [..., 1, N]) already broadcast.
            if len(zp.struct_info.shape) == 1:
                zp = relax.op.expand_dims(zp, axis=expand_axis)
            return relax.op.subtract(operand, zp)

        # Widen both operands to int32 before any arithmetic to prevent overflow.
        lhs = relax.op.astype(inputs[0], "int32")
        rhs = relax.op.astype(inputs[1], "int32")

        a_zero_point = _optional_input(2)
        b_zero_point = _optional_input(3)

        if a_zero_point is not None:
            lhs = _subtract_zero_point(lhs, a_zero_point, expand_axis=-1)
        if b_zero_point is not None:
            rhs = _subtract_zero_point(rhs, b_zero_point, expand_axis=0)

        # ONNX spec mandates an int32 result for MatMulInteger.
        return relax.op.matmul(lhs, rhs, out_dtype="int32")
+
+
def _get_convert_map():
return {
# defs/experimental
@@ -4129,7 +4183,7 @@ def _get_convert_map():
"Cast": Cast,
"Gemm": Gemm,
"MatMul": MatMul,
- # "MatMulInteger": MatMulInteger,
+ "MatMulInteger": MatMulInteger,
# "MatMulInteger16": MatMulInteger16,
"Reshape": Reshape,
"Sigmoid": Sigmoid,
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 0fb5f3003a..d04b0c2f33 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4423,6 +4423,176 @@ def test_if_nested():
)
# Helper that builds the ONNX graph for MatMulInteger so the tests don't
# repeat boilerplate code every time
def _make_matmulinteger_model(A_shape, B_shape, A_dtype, B_dtype, a_zp_array=None, b_zp_array=None):
    """Build a minimal single-node ONNX graph for MatMulInteger."""
    np_to_onnx_dtype = {np.int8: TensorProto.INT8, np.uint8: TensorProto.UINT8}

    graph_inputs = [
        helper.make_tensor_value_info("A", np_to_onnx_dtype[A_dtype], A_shape),
        helper.make_tensor_value_info("B", np_to_onnx_dtype[B_dtype], B_shape),
    ]
    node_inputs = ["A", "B"]
    initializers = []

    def _append_zero_point(name, arr, dtype):
        # Zero points are baked into the model as constant initializers.
        tensor = helper.make_tensor(
            name, np_to_onnx_dtype[dtype], list(arr.shape), arr.flatten().tolist()
        )
        initializers.append(tensor)
        node_inputs.append(name)

    if a_zp_array is not None:
        _append_zero_point("a_zero_point", a_zp_array, A_dtype)
    elif b_zp_array is not None:
        # ONNX optional-input convention: an empty name skips a_zero_point
        # so b_zero_point can still occupy the fourth input slot.
        node_inputs.append("")

    if b_zp_array is not None:
        _append_zero_point("b_zero_point", b_zp_array, B_dtype)

    output_info = helper.make_tensor_value_info("output", TensorProto.INT32, None)
    matmul_node = helper.make_node("MatMulInteger", inputs=node_inputs, outputs=["output"])
    graph = helper.make_graph(
        [matmul_node], "matmulinteger", graph_inputs, [output_info], initializer=initializers
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
    model.ir_version = 8
    return model
+
+
@pytest.mark.parametrize(
    "A_dtype,B_dtype,a_zp,b_zp",
    [
        (np.int8, np.int8, None, None),
        (np.uint8, np.uint8, None, None),
        (np.uint8, np.int8, None, None),
        pytest.param(
            np.int8,
            np.uint8,
            None,
            None,
            marks=pytest.mark.xfail(
                reason="Some older ORT versions doesn't support mixed int8/uint8 dtype combination for MatMulInteger",
                strict=False,  # not strict - may pass on newer ORT versions
            ),
        ),
        (np.uint8, np.uint8, np.uint8(128), np.uint8(128)),
        (np.int8, np.int8, np.int8(1), np.int8(2)),
    ],
)
def test_matmulinteger(A_dtype, B_dtype, a_zp, b_zp):
    """2-D matmul across all dtype combos and zero-point configurations."""
    np.random.seed(0)
    lhs_data = np.random.randint(-5, 5, (4, 8)).astype(A_dtype)
    rhs_data = np.random.randint(-5, 5, (8, 6)).astype(B_dtype)
    # Scalar zero points become 0-D initializers; None leaves the input out.
    a_zp_arr = None if a_zp is None else np.array(a_zp, dtype=A_dtype)
    b_zp_arr = None if b_zp is None else np.array(b_zp, dtype=B_dtype)
    model = _make_matmulinteger_model(
        [4, 8], [8, 6], A_dtype, B_dtype, a_zp_array=a_zp_arr, b_zp_array=b_zp_arr
    )
    check_correctness(model, inputs={"A": lhs_data, "B": rhs_data}, opset=10)
+
+
@pytest.mark.parametrize(
    "A_shape,B_shape,a_zp,b_zp",
    [
        ((2, 4, 8), (2, 8, 6), np.int8(1), np.int8(2)),  # 3-D batched
        ((2, 3, 4, 8), (2, 3, 8, 6), np.int8(1), np.int8(2)),  # 4-D batched
    ],
)
def test_matmulinteger_batched(A_shape, B_shape, a_zp, b_zp):
    """Batched matmul — verifies the op generalizes beyond 2-D."""
    np.random.seed(1)
    lhs_data = np.random.randint(-5, 5, A_shape).astype(np.int8)
    rhs_data = np.random.randint(-5, 5, B_shape).astype(np.int8)
    model = _make_matmulinteger_model(
        list(A_shape),
        list(B_shape),
        np.int8,
        np.int8,
        a_zp_array=np.array(a_zp, dtype=np.int8),
        b_zp_array=np.array(b_zp, dtype=np.int8),
    )
    check_correctness(model, inputs={"A": lhs_data, "B": rhs_data}, opset=10)
+
+
def test_matmulinteger_per_channel_zp():
    """
    1-D zero points: per-row for A ([M]) and per-col for B ([N]).
    Exercises the expand_dims path in the converter.
    Note: ORT CPU does not support per-row a_zero_point despite the ONNX spec
    allowing it, so we verify TVM output against a NumPy reference instead.
    """
    np.random.seed(2)
    A = np.random.randint(-5, 5, (4, 8)).astype(np.int8)
    B = np.random.randint(-5, 5, (8, 6)).astype(np.int8)
    a_zp = np.arange(4, dtype=np.int8)  # shape [M=4], per-row
    b_zp = np.arange(6, dtype=np.int8)  # shape [N=6], per-col

    # NumPy reference: mirrors the converter's expand_dims logic
    # (subtract zero points in int32, then matmul — per the ONNX spec formula)
    expected = np.matmul(
        A.astype(np.int32) - a_zp.astype(np.int32)[:, np.newaxis],
        B.astype(np.int32) - b_zp.astype(np.int32)[np.newaxis, :],
    ).astype(np.int32)

    model = _make_matmulinteger_model(
        [4, 8], [8, 6], np.int8, np.int8, a_zp_array=a_zp, b_zp_array=b_zp
    )

    # Run TVM only — ORT doesn't support per-row a_zero_point
    # keep_params_in_input=True keeps the initializers (zero points) as
    # detachable parameters instead of inlining them as constants.
    tvm_model = from_onnx(model, opset=10, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="llvm")
        vm = relax.VirtualMachine(ex, tvm.cpu())

    # Bind graph inputs by name in the order "main" declares them, then append
    # the detached parameters (the zero-point initializers) at the end.
    input_list = [
        {"A": A, "B": B}[k.name_hint] for k in tvm_model["main"].params if k.name_hint in {"A", "B"}
    ]
    if params:
        input_list += params["main"]

    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main").numpy()

    # Integer outputs must match exactly; assert_allclose with defaults is exact here.
    tvm.testing.assert_allclose(tvm_output, expected)
+
+
@pytest.mark.xfail(
    reason=(
        "ORT doesn't support per-row a_zero_point of shape [M] "
        "despite the ONNX spec explicitly allowing it. "
        "See: matmul_integer.cc:63 IsScalarOr1ElementVector(a_zero_point)"
    ),
    strict=True,  # must fail, if ORT ever fixes this, the test will alert us
)
def test_matmulinteger_per_channel_zp_ort_limitation():
    """
    Documents that ORT CPU rejects per-row a_zero_point of shape [M].

    This is a valid ONNX spec case that ORT simply hasn't implemented, hence
    the strict xfail. Should it ever start passing, ORT has lifted the
    limitation and test_matmulinteger_per_channel_zp can be simplified to use
    check_correctness instead of its manual TVM-only NumPy reference.
    """
    np.random.seed(2)
    lhs_data = np.random.randint(-5, 5, (4, 8)).astype(np.int8)
    rhs_data = np.random.randint(-5, 5, (8, 6)).astype(np.int8)
    model = _make_matmulinteger_model(
        [4, 8],
        [8, 6],
        np.int8,
        np.int8,
        a_zp_array=np.arange(4, dtype=np.int8),  # shape [M=4], per-row
        b_zp_array=np.arange(6, dtype=np.int8),  # shape [N=6], per-col
    )
    check_correctness(model, inputs={"A": lhs_data, "B": rhs_data}, opset=10)
+
+
@pytest.mark.parametrize(
("pooled_shape", "rois"),
[