This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8597d21a8d [Frontend][ONNX] Add MatMulInteger support to Relax ONNX 
frontend (#18951)
8597d21a8d is described below

commit 8597d21a8d065e235f59e82ca178617107857c1e
Author: Kryptonite <[email protected]>
AuthorDate: Sun Mar 29 22:08:45 2026 +0300

    [Frontend][ONNX] Add MatMulInteger support to Relax ONNX frontend (#18951)
    
    ### Summary
    
    Implements the `MatMulInteger` operator (opset 10) in the Relax ONNX
    frontend — INT8/UINT8 matrix multiplication with optional zero points
    and int32 accumulation. Required for quantized model inference
    (e.g. ONNX QDQ models).
    
    Closes #18945 (Tier 1 — MatMulInteger operator)
    
    ### Tests
    
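    - All 4 `int8`/`uint8` dtype combinations, with and without scalar zero
      points
    - 3-D and 4-D batched matmul
    - 1-D per-row / per-column zero points, checked against a NumPy reference
      (ORT CPU rejects a per-row `a_zero_point`; documented by a strict xfail)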
    
    ---------
    
    Signed-off-by: OmarAzizi <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  56 +++++++-
 tests/python/relax/test_frontend_onnx.py        | 170 ++++++++++++++++++++++++
 2 files changed, 225 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index fa9d6eb05d..4cc4e99b7b 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -4054,6 +4054,60 @@ class GridSample(OnnxOpConverter):
         )
 
 
+class MatMulInteger(OnnxOpConverter):
+    """
+    Converts ONNX MatMulInteger (INT8/UINT8 quantized matrix multiply).
+
+    Computes: output = (A - a_zero_point) * (B - b_zero_point)
+    in int32 accumulation, per ONNX spec v10.
+
+    Zero-point shapes per spec:
+      a_zero_point: scalar | [M] (per-row) | [D1, D2, M, 1] (N-D per-row)
+      b_zero_point: scalar | [N] (per-col) | [D1, D2, 1, N] (N-D per-col)
+
+    Both operands are cast to int32 before the zero points are subtracted,
+    so neither the subtraction nor the matmul is done in the 8-bit input type.
+    """
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        """Emit astype / subtract / matmul for one MatMulInteger node.
+
+        inputs: [A, B, a_zero_point (optional), b_zero_point (optional)].
+        A missing (absent or None) zero point is treated as 0, i.e. no
+        subtraction is emitted. Returns the int32 matmul expression.
+        """
+        # Optional zero points; out-of-range or None both mean "no zero point".
+        a_zero_point = inputs[2] if len(inputs) > 2 else None
+        b_zero_point = inputs[3] if len(inputs) > 3 else None
+
+        # Widen to int32 before any arithmetic to prevent overflow. `a` is
+        # normalized so its struct_info (rank) can be inspected below.
+        a = bb.normalize(relax.op.astype(inputs[0], "int32"))
+        b = relax.op.astype(inputs[1], "int32")
+
+        if a_zero_point is not None:
+            # Normalize so struct_info is populated; ndim is -1 when rank is unknown.
+            a_zp = bb.normalize(relax.op.astype(a_zero_point, "int32"))
+
+            # Per-row case: [M] -> [M, 1] so it broadcasts over [..., M, K] row-wise.
+            # Scalars and the N-D [D1, D2, M, 1] form already broadcast correctly.
+            # A 1-D A has no M axis, so expanding a 1-element zero point there
+            # would add a spurious leading dimension to the output; skip it.
+            if a_zp.struct_info.ndim == 1 and a.struct_info.ndim != 1:
+                a_zp = relax.op.expand_dims(a_zp, axis=-1)
+
+            a = relax.op.subtract(a, a_zp)
+
+        if b_zero_point is not None:
+            # Scalar, per-col [N] and N-D [D1, D2, 1, N] zero points all align with
+            # B's trailing N axis under numpy-style broadcasting, so no reshape is
+            # needed. (Expanding [N] to [1, N] would mis-shape a 1-D B of shape [K].)
+            b_zp = relax.op.astype(b_zero_point, "int32")
+            b = relax.op.subtract(b, b_zp)
+
+        return relax.op.matmul(a, b, out_dtype="int32")  # Output is int32 per ONNX spec
+
+
 def _get_convert_map():
     return {
         # defs/experimental
@@ -4129,7 +4183,7 @@ def _get_convert_map():
         "Cast": Cast,
         "Gemm": Gemm,
         "MatMul": MatMul,
-        # "MatMulInteger": MatMulInteger,
+        "MatMulInteger": MatMulInteger,
         # "MatMulInteger16": MatMulInteger16,
         "Reshape": Reshape,
         "Sigmoid": Sigmoid,
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 0fb5f3003a..d04b0c2f33 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4423,6 +4423,176 @@ def test_if_nested():
     )
 
 
+# Helper that builds the ONNX graph for MatMulInteger so the tests don't repeat boilerplate code every time
+def _make_matmulinteger_model(A_shape, B_shape, A_dtype, B_dtype, a_zp_array=None, b_zp_array=None):
+    """Build a minimal single-node ONNX graph for MatMulInteger (opset 10, IR version 8).
+
+    A and B are graph inputs. Each zero point, when given, is stored as an
+    initializer with the dtype of its operand. Dtypes may be numpy scalar
+    types (np.int8) or anything np.dtype() accepts ("uint8", np.dtype(...)).
+    """
+    onnx_dtype_of = {np.dtype(np.int8): TensorProto.INT8, np.dtype(np.uint8): TensorProto.UINT8}
+
+    def np_dtype_to_onnx(dt):
+        # Normalize through np.dtype so np.int8, "int8" and np.dtype("int8") all match.
+        return onnx_dtype_of[np.dtype(dt)]
+
+    A_info = helper.make_tensor_value_info("A", np_dtype_to_onnx(A_dtype), A_shape)
+    B_info = helper.make_tensor_value_info("B", np_dtype_to_onnx(B_dtype), B_shape)
+    graph_inputs = [A_info, B_info]
+    node_inputs = ["A", "B"]
+    initializers = []
+
+    def _add_zp(name, arr, dtype):
+        onnx_dtype = np_dtype_to_onnx(dtype)
+        shape = list(arr.shape)
+        initializers.append(helper.make_tensor(name, onnx_dtype, shape, arr.flatten().tolist()))
+        node_inputs.append(name)
+
+    if a_zp_array is not None:
+        _add_zp("a_zero_point", a_zp_array, A_dtype)
+    elif b_zp_array is not None:
+        # Optional inputs are positional: an empty name skips a_zero_point.
+        node_inputs.append("")
+
+    if b_zp_array is not None:
+        _add_zp("b_zero_point", b_zp_array, B_dtype)
+
+    out_info = helper.make_tensor_value_info("output", TensorProto.INT32, None)
+    node = helper.make_node("MatMulInteger", inputs=node_inputs, outputs=["output"])
+    graph = helper.make_graph(
+        [node], "matmulinteger", graph_inputs, [out_info], initializer=initializers
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+    model.ir_version = 8
+    return model
+
+
+@pytest.mark.parametrize(
+    "A_dtype,B_dtype,a_zp,b_zp",
+    [
+        (np.int8, np.int8, None, None),
+        (np.uint8, np.uint8, None, None),
+        (np.uint8, np.int8, None, None),
+        pytest.param(
+            np.int8,
+            np.uint8,
+            None,
+            None,
+            marks=pytest.mark.xfail(
+                reason="Some older ORT versions doesn't support mixed int8/uint8 dtype combination for MatMulInteger",
+                strict=False,  # not strict - may pass on newer ORT versions
+            ),
+        ),
+        (np.uint8, np.uint8, np.uint8(128), np.uint8(128)),
+        (np.int8, np.int8, np.int8(1), np.int8(2)),
+    ],
+)
+def test_matmulinteger(A_dtype, B_dtype, a_zp, b_zp):
+    """2-D [4, 8] x [8, 6] MatMulInteger over dtype / scalar zero-point combos, checked vs ORT."""
+    m, k, n = 4, 8, 6
+    np.random.seed(0)
+    lhs = np.random.randint(-5, 5, (m, k)).astype(A_dtype)
+    rhs = np.random.randint(-5, 5, (k, n)).astype(B_dtype)
+
+    def _as_zp(value, dtype):
+        # Scalar zero point -> 0-d array in the operand dtype; None omits the input.
+        return None if value is None else np.array(value, dtype=dtype)
+
+    model = _make_matmulinteger_model(
+        [m, k],
+        [k, n],
+        A_dtype,
+        B_dtype,
+        a_zp_array=_as_zp(a_zp, A_dtype),
+        b_zp_array=_as_zp(b_zp, B_dtype),
+    )
+    check_correctness(model, inputs={"A": lhs, "B": rhs}, opset=10)
+
+
+@pytest.mark.parametrize(
+    "A_shape,B_shape,a_zp,b_zp",
+    [
+        ((2, 4, 8), (2, 8, 6), np.int8(1), np.int8(2)),  # 3-D batched
+        ((2, 3, 4, 8), (2, 3, 8, 6), np.int8(1), np.int8(2)),  # 4-D batched
+    ],
+)
+def test_matmulinteger_batched(A_shape, B_shape, a_zp, b_zp):
+    """Batched (rank > 2) int8 MatMulInteger with scalar zero points, checked vs ORT."""
+    np.random.seed(1)
+    # Draw A then B, in that order, so the seeded data is reproducible.
+    lhs = np.random.randint(-5, 5, A_shape).astype(np.int8)
+    rhs = np.random.randint(-5, 5, B_shape).astype(np.int8)
+    model = _make_matmulinteger_model(
+        list(A_shape),
+        list(B_shape),
+        np.int8,
+        np.int8,
+        a_zp_array=np.array(a_zp, dtype=np.int8),
+        b_zp_array=np.array(b_zp, dtype=np.int8),
+    )
+    check_correctness(model, inputs={"A": lhs, "B": rhs}, opset=10)
+
+
+def test_matmulinteger_per_channel_zp():
+    """
+    1-D zero points: per-row for A ([M]) and per-col for B ([N]).
+
+    Exercises the converter's 1-D zero-point path. ORT CPU rejects a per-row
+    a_zero_point even though the ONNX spec allows it, so the TVM result is
+    compared against a NumPy reference instead of going through check_correctness.
+    """
+    np.random.seed(2)
+    A = np.random.randint(-5, 5, (4, 8)).astype(np.int8)
+    B = np.random.randint(-5, 5, (8, 6)).astype(np.int8)
+    a_zp = np.arange(4, dtype=np.int8)  # shape [M=4], per-row
+    b_zp = np.arange(6, dtype=np.int8)  # shape [N=6], per-col
+
+    # NumPy reference: row i of A is shifted by a_zp[i], column j of B by b_zp[j].
+    A_shifted = A.astype(np.int32) - a_zp.astype(np.int32).reshape(-1, 1)
+    B_shifted = B.astype(np.int32) - b_zp.astype(np.int32).reshape(1, -1)
+    expected = (A_shifted @ B_shifted).astype(np.int32)
+
+    model = _make_matmulinteger_model(
+        [4, 8], [8, 6], np.int8, np.int8, a_zp_array=a_zp, b_zp_array=b_zp
+    )
+
+    # Compile and run with TVM only (ORT cannot run this model, see docstring).
+    mod = from_onnx(model, opset=10, keep_params_in_input=True)
+    for transform in (relax.transform.DecomposeOpsForInference(), relax.transform.LegalizeOps()):
+        mod = transform(mod)
+    mod, params = relax.frontend.detach_params(mod)
+
+    with tvm.transform.PassContext(opt_level=3):
+        ex = tvm.compile(mod, target="llvm")
+        vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    # A/B in main's parameter order, followed by any detached params (zero points).
+    feeds = {"A": A, "B": B}
+    input_list = [feeds[var.name_hint] for var in mod["main"].params if var.name_hint in feeds]
+    if params:
+        input_list.extend(params["main"])
+
+    vm.set_input("main", *input_list)
+    vm.invoke_stateful("main")
+    tvm_output = vm.get_outputs("main").numpy()
+
+    tvm.testing.assert_allclose(tvm_output, expected)
+
+
+@pytest.mark.xfail(
+    reason=(
+        "ORT doesn't support per-row a_zero_point of shape [M] "
+        "despite the ONNX spec explicitly allowing it. "
+        "See: matmul_integer.cc:63 IsScalarOr1ElementVector(a_zero_point)"
+    ),
+    strict=True,  # must fail; if ORT ever fixes this, the resulting XPASS alerts us
+)
+def test_matmulinteger_per_channel_zp_ort_limitation():
+    """
+    Documents that ORT CPU rejects per-row a_zero_point of shape [M].
+
+    Marked strict xfail because this is a valid ONNX spec case that ORT
+    hasn't implemented. If this test starts passing, ORT has lifted the
+    limitation and test_matmulinteger_per_channel_zp can be simplified to
+    use check_correctness instead of a manual TVM-only NumPy reference.
+    """
+    np.random.seed(2)
+    A = np.random.randint(-5, 5, (4, 8)).astype(np.int8)
+    B = np.random.randint(-5, 5, (8, 6)).astype(np.int8)
+    a_zp = np.arange(4, dtype=np.int8)  # shape [M=4], per-row
+    b_zp = np.arange(6, dtype=np.int8)  # shape [N=6], per-col
+
+    model = _make_matmulinteger_model(
+        [4, 8], [8, 6], np.int8, np.int8, a_zp_array=a_zp, b_zp_array=b_zp
+    )
+    check_correctness(model, inputs={"A": A, "B": B}, opset=10)
+
+
 @pytest.mark.parametrize(
     ("pooled_shape", "rois"),
     [

Reply via email to