This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 645fcf9f0a [Relax][ONNX] Add frontend support for QuantizeLinear,
DequantizeLinear, and DynamicQuantizeLinear (#19391)
645fcf9f0a is described below
commit 645fcf9f0aac0c38fc9f46bcc159486cc19fb635
Author: WANG HUNG-HSIANG <[email protected]>
AuthorDate: Sun Apr 12 09:52:09 2026 +0800
[Relax][ONNX] Add frontend support for QuantizeLinear, DequantizeLinear,
and DynamicQuantizeLinear (#19391)
## Summary
This PR adds Relax ONNX frontend support for:
- `QuantizeLinear`
- `DequantizeLinear`
- `DynamicQuantizeLinear`
The implementation follows existing TVM ONNX frontend patterns and keeps
QDQ handling consistent for singleton quantization parameters and
optional zero-point inputs.
## Changes
- add ONNX frontend converters for `QuantizeLinear`, `DequantizeLinear`,
and `DynamicQuantizeLinear`
- register Q/DQ-related ops in the ONNX converter map
- handle optional zero-point inputs consistently during import
- preserve singleton quantization parameter semantics in the QDQ
legalization path
- improve QDQ legalization behavior for imported ONNX models
- add and update frontend tests for Q/DQ and `DynamicQuantizeLinear`
## Tests
Added or updated tests in `tests/python/relax/test_frontend_onnx.py` to
cover:
- singleton-qparam `QuantizeLinear` in opset 10
- singleton-qparam `DequantizeLinear` in opset 10
- optional-zero-point `QuantizeLinear` in opset 13
- `DynamicQuantizeLinear` in opset 11
## Validation
Validated with:
- `python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k
"quantizelinear or dequantizelinear or dynamicquantizelinear" -v`
Result:
- `4 passed`
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 71 ++++++++++++++++
python/tvm/relax/transform/legalize_ops/qdq.py | 57 +++++++++++--
src/relax/op/tensor/qdq.cc | 54 +++++++++---
tests/python/relax/test_frontend_onnx.py | 108 ++++++++++++++++++++++++
4 files changed, 272 insertions(+), 18 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 2707f6ff1c..5397f2c309 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -311,6 +311,73 @@ class OnnxOpConverter:
return getattr(cls, f"_impl_v{version}")
raise NotImplementedError(f"opset version {version} of {cls.__name__}
not implemented")
class QuantizeLinear(OnnxOpConverter):
    """Converts an onnx QuantizeLinear node into an equivalent Relax expression."""

    @classmethod
    def _impl_v10(cls, bb, inputs, attr, params):
        x, scale = inputs[0], inputs[1]
        # zero_point is an optional third input; when absent the output
        # defaults to uint8 with a zero point of 0.
        zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
        axis = attr.get("axis", 1)
        # The default axis=1 is out of range for scalar/1-D inputs; fold it to 0.
        if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
            axis = 0
        out_dtype = "uint8" if zp is None else zp.struct_info.dtype
        if zp is None:
            zp = relax.const(0, out_dtype)
        return relax.op.quantize(x, scale, zp, axis=axis, out_dtype=out_dtype)

    # Opset 13 semantics are identical for everything supported here (the
    # shared implementation already honors the "axis" attribute), so reuse
    # the v10 converter instead of duplicating its body.
    _impl_v13 = _impl_v10
+
+
class DequantizeLinear(OnnxOpConverter):
    """Converts an onnx DequantizeLinear node into an equivalent Relax expression."""

    @classmethod
    def _impl_v10(cls, bb, inputs, attr, params):
        x, scale = inputs[0], inputs[1]
        # zero_point is an optional third input; when absent it defaults to 0
        # in the (quantized) input's dtype.
        zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
        axis = attr.get("axis", 1)
        # The default axis=1 is out of range for scalar/1-D inputs; fold it to 0.
        if hasattr(x.struct_info, "ndim") and x.struct_info.ndim <= 1 and axis == 1:
            axis = 0
        if zp is None:
            zp = relax.const(0, x.struct_info.dtype)
        return relax.op.dequantize(x, scale, zp, axis=axis, out_dtype="float32")

    # Opset 13 semantics are identical for everything supported here, so
    # reuse the v10 converter instead of duplicating its body.
    _impl_v13 = _impl_v10
+
+
class DynamicQuantizeLinear(OnnxOpConverter):
    """Converts an onnx DynamicQuantizeLinear node into an equivalent Relax expression."""

    @classmethod
    def _impl_v11(cls, bb, inputs, attr, params):
        data = inputs[0]
        data_dtype = data.struct_info.dtype
        # uint8 quantization bounds expressed in the input dtype.
        range_lo = relax.const(0, data_dtype)
        range_hi = relax.const(255, data_dtype)

        # Expand the observed data range so it always contains zero.
        adj_max = relax.op.maximum(range_lo, relax.op.max(data))
        adj_min = relax.op.minimum(range_lo, relax.op.min(data))
        y_scale = relax.op.divide(relax.op.subtract(adj_max, adj_min), range_hi)

        # Saturate the floating-point zero point into [0, 255], then round
        # and cast to uint8.
        zp_float = relax.op.subtract(range_lo, relax.op.divide(adj_min, y_scale))
        y_zero_point = relax.op.astype(relax.op.round(relax.op.clip(zp_float, 0, 255)), "uint8")

        y = relax.op.quantize(data, y_scale, y_zero_point, axis=0, out_dtype="uint8")
        return relax.Tuple([y, y_scale, y_zero_point])
class MatMul(OnnxOpConverter):
"""Converts an onnx MatMul node into an equivalent Relax expression."""
@@ -4812,6 +4879,10 @@ def _get_convert_map():
"ConcatFromSequence": ConcatFromSequence,
"SplitToSequence": SplitToSequence,
"SequenceAt": SequenceAt,
+ # Quantization
+ "QuantizeLinear": QuantizeLinear,
+ "DequantizeLinear": DequantizeLinear,
+ "DynamicQuantizeLinear": DynamicQuantizeLinear,
}
diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py
b/python/tvm/relax/transform/legalize_ops/qdq.py
index caec63ffa8..5e28d1b291 100644
--- a/python/tvm/relax/transform/legalize_ops/qdq.py
+++ b/python/tvm/relax/transform/legalize_ops/qdq.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name
"""Default legalization function for quantize/dequantize operators."""
+from typing import Union
import tvm
from tvm import te, tir
@@ -35,6 +36,18 @@ def is_const_scalar(x):
return isinstance(x, tvm.tir.IntImm | tvm.tir.FloatImm)
def _is_singleton_qparam(qparam: te.Tensor) -> bool:
    """Return True if qparam is a tensor with all dimensions equal to 1.

    Such tensors (e.g. shape [1] or [1, 1]) carry a single quantization
    parameter, so callers may index them at the origin instead of along
    the quantization axis.
    """
    if not isinstance(qparam, te.Tensor):
        return False
    # A rank-0 tensor is a scalar and therefore trivially a singleton.
    if len(qparam.shape) == 0:
        return True
    # Every dimension must be a static IntImm equal to 1; symbolic or
    # larger dimensions disqualify the tensor.
    return all(isinstance(dim, tir.IntImm) and dim.value == 1 for dim in qparam.shape)
+
+
@register_legalize("relax.quantize")
def _quantize(bb: BlockBuilder, call: Call) -> Expr:
"""
@@ -46,12 +59,26 @@ def _quantize(bb: BlockBuilder, call: Call) -> Expr:
def te_quantize(
data: te.Tensor,
- scale: te.Tensor | tir.IntImm | tir.FloatImm,
- zp: te.Tensor | tir.IntImm | tir.FloatImm,
+ scale: Union[te.Tensor, tir.IntImm, tir.FloatImm],
+ zp: Union[te.Tensor, tir.IntImm, tir.FloatImm],
):
+ scale_singleton = _is_singleton_qparam(scale) if isinstance(scale,
te.Tensor) else False
+ zp_singleton = _is_singleton_qparam(zp) if isinstance(zp, te.Tensor)
else False
+
def quantize_compute(*indices):
- scale_value = scale if is_const_scalar(scale) else
scale[indices[axis]]
- zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
+ if is_const_scalar(scale):
+ scale_value = scale
+ elif scale_singleton:
+ scale_value = scale[(0,) * len(scale.shape)]
+ else:
+ scale_value = scale[indices[axis]]
+
+ if is_const_scalar(zp):
+ zp_value = zp
+ elif zp_singleton:
+ zp_value = zp[(0,) * len(zp.shape)]
+ else:
+ zp_value = zp[indices[axis]]
scaled = data[indices] / scale_value
round_val = (te.round(scaled) if "int" in out_dtype else scaled) +
zp_value
return clip_cast(round_val, out_dtype)
@@ -94,12 +121,26 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:
def te_dequantize(
data: te.Tensor,
- scale: te.Tensor | tir.IntImm | tir.FloatImm,
- zp: te.Tensor | tir.IntImm | tir.FloatImm,
+ scale: Union[te.Tensor, tir.IntImm, tir.FloatImm],
+ zp: Union[te.Tensor, tir.IntImm, tir.FloatImm],
):
+ scale_singleton = _is_singleton_qparam(scale) if isinstance(scale,
te.Tensor) else False
+ zp_singleton = _is_singleton_qparam(zp) if isinstance(zp, te.Tensor)
else False
+
def dequantize_compute(*indices):
- scale_value = scale if is_const_scalar(scale) else
scale[indices[axis]]
- zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
+ if is_const_scalar(scale):
+ scale_value = scale
+ elif scale_singleton:
+ scale_value = scale[(0,) * len(scale.shape)]
+ else:
+ scale_value = scale[indices[axis]]
+
+ if is_const_scalar(zp):
+ zp_value = zp
+ elif zp_singleton:
+ zp_value = zp[(0,) * len(zp.shape)]
+ else:
+ zp_value = zp[indices[axis]]
dtype = "float32" if "float" in data.dtype else "int32"
sub = te.subtract(data[indices].astype(dtype), zp_value)
out = te.multiply(sub, scale_value.astype("float32"))
diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc
index 406868ab4b..3a7a9f164a 100644
--- a/src/relax/op/tensor/qdq.cc
+++ b/src/relax/op/tensor/qdq.cc
@@ -79,10 +79,14 @@ StructInfo InferStructInfoQuantize(const Call& call, const
BlockBuilder& ctx) {
}
// Check datatype of zero_point param:
- if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype !=
DataType::Float(16)) {
+ if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype !=
DataType::UInt(8) &&
+ zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype !=
DataType::UInt(16) &&
+ zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype !=
DataType::UInt(32) &&
+ zp_sinfo->dtype != DataType::Float(16)) {
ctx->ReportFatal(Diagnostic::Error(call)
- << "zero_point param datatype should be 'int8' or
'float16', but got "
- << zp_sinfo->dtype);
+ << "zero_point param datatype should be one of "
+ << "['int8', 'uint8', 'int16', 'uint16', 'int32',
'uint32', 'float16'], "
+ << "but got " << zp_sinfo->dtype);
}
// Check that "axis" attribute is not out of range:
@@ -104,9 +108,22 @@ StructInfo InferStructInfoQuantize(const Call& call, const
BlockBuilder& ctx) {
}
};
+ auto is_scalar_or_singleton_vector = [&](const TensorStructInfo&
param_sinfo) {
+ if (IsScalarTensor(param_sinfo)) return true;
+ if (param_sinfo->shape.defined() &&
param_sinfo->shape->IsInstance<ShapeExprNode>()) {
+ const auto& values = param_sinfo->shape.as<ShapeExprNode>()->values;
+ if (!values.empty()) {
+ return std::all_of(values.begin(), values.end(), [&](const PrimExpr&
dim) {
+ return ctx->GetAnalyzer()->CanProveEqual(dim, 1);
+ });
+ }
+ }
+ return false;
+ };
+
// Check size matching of scale/zp params with input shape at dim =
attrs->axis.
- if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo,
"scale");
- if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo,
"zero_point");
+ if (!is_scalar_or_singleton_vector(scale_sinfo))
check_param_size(scale_sinfo, input_sinfo, "scale");
+ if (!is_scalar_or_singleton_vector(zp_sinfo)) check_param_size(zp_sinfo,
input_sinfo, "zero_point");
auto output_sinfo =
ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
output_sinfo->dtype = attrs->out_dtype;
@@ -167,10 +184,14 @@ StructInfo InferStructInfoDequantize(const Call& call,
const BlockBuilder& ctx)
}
// Check datatype of zero_point param:
- if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype !=
DataType::Float(16)) {
+ if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype !=
DataType::UInt(8) &&
+ zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype !=
DataType::UInt(16) &&
+ zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype !=
DataType::UInt(32) &&
+ zp_sinfo->dtype != DataType::Float(16)) {
ctx->ReportFatal(Diagnostic::Error(call)
- << "zero_point param datatype should be 'int8' or
'float16', but got "
- << zp_sinfo->dtype);
+ << "zero_point param datatype should be one of "
+ << "['int8', 'uint8', 'int16', 'uint16', 'int32',
'uint32', 'float16'], "
+ << "but got " << zp_sinfo->dtype);
}
// Check that "axis" attribute is not out of range:
@@ -192,9 +213,22 @@ StructInfo InferStructInfoDequantize(const Call& call,
const BlockBuilder& ctx)
}
};
+ auto is_scalar_or_singleton_vector = [&](const TensorStructInfo&
param_sinfo) {
+ if (IsScalarTensor(param_sinfo)) return true;
+ if (param_sinfo->shape.defined() &&
param_sinfo->shape->IsInstance<ShapeExprNode>()) {
+ const auto& values = param_sinfo->shape.as<ShapeExprNode>()->values;
+ if (!values.empty()) {
+ return std::all_of(values.begin(), values.end(), [&](const PrimExpr&
dim) {
+ return ctx->GetAnalyzer()->CanProveEqual(dim, 1);
+ });
+ }
+ }
+ return false;
+ };
+
// Check size matching of scale/zp params with input shape at dim =
attrs->axis.
- if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo,
"scale");
- if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo,
"zero_point");
+ if (!is_scalar_or_singleton_vector(scale_sinfo))
check_param_size(scale_sinfo, input_sinfo, "scale");
+ if (!is_scalar_or_singleton_vector(zp_sinfo)) check_param_size(zp_sinfo,
input_sinfo, "zero_point");
auto output_sinfo =
ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
output_sinfo->dtype = attrs->out_dtype;
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index db7c3da25a..7e434d2659 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -5599,6 +5599,114 @@ def test_split_to_sequence_uneven_last_chunk(axis: int):
model = helper.make_model(graph,
producer_name="test_split_to_sequence_uneven")
check_correctness(model)
def test_quantizelinear_singleton_qparams_opset10():
    """QuantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
    q_node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
    initializers = [
        helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.03125]),
        helper.make_tensor("zero_point", TensorProto.UINT8, [1], [127]),
    ]
    graph = helper.make_graph(
        [q_node],
        "quantizelinear_singleton_qparams_opset10",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 3, 2, 2])],
        [helper.make_tensor_value_info("y", TensorProto.UINT8, [4, 3, 2, 2])],
        initializer=initializers,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

    data = rg.standard_normal((4, 3, 2, 2)).astype("float32")
    check_correctness(model, inputs={"x": data}, opset=10, check_dtypes=True)
+
+
def test_dequantizelinear_singleton_qparams_opset10():
    """DequantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
    dq_node = helper.make_node("DequantizeLinear", ["x", "scale", "zero_point"], ["y"])
    initializers = [
        helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.125]),
        helper.make_tensor("zero_point", TensorProto.UINT8, [1], [1]),
    ]
    graph = helper.make_graph(
        [dq_node],
        "dequantizelinear_singleton_qparams_opset10",
        [helper.make_tensor_value_info("x", TensorProto.UINT8, [64])],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [64])],
        initializer=initializers,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

    data = rg.integers(low=0, high=255, size=(64,), dtype=np.uint8)
    check_correctness(model, inputs={"x": data}, opset=10, check_dtypes=True)
+
+
def test_quantizelinear_optional_zero_point_opset13():
    """ONNX allows missing zero_point input; importer should default it to 0 (uint8)."""
    q_node = helper.make_node("QuantizeLinear", ["x", "scale"], ["y"])
    scale_init = helper.make_tensor("scale", TensorProto.FLOAT, [], [0.2])
    graph = helper.make_graph(
        [q_node],
        "quantizelinear_optional_zero_point_opset13",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 5])],
        [helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 5])],
        initializer=[scale_init],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

    data = rg.standard_normal((2, 5)).astype("float32")
    check_correctness(model, inputs={"x": data}, opset=13, check_dtypes=True)
+
+
def test_dynamicquantizelinear_opset11():
    """DynamicQuantizeLinear returns (y, y_scale, y_zero_point) with ORT parity."""
    dql_node = helper.make_node("DynamicQuantizeLinear", ["x"], ["y", "y_scale", "y_zero_point"])
    outputs = [
        helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3, 4]),
        helper.make_tensor_value_info("y_scale", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("y_zero_point", TensorProto.UINT8, []),
    ]
    graph = helper.make_graph(
        [dql_node],
        "dynamicquantizelinear_opset11",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
        outputs,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])

    data = rg.standard_normal((2, 3, 4)).astype("float32")
    check_correctness(model, inputs={"x": data}, opset=11, atol=1e-5, rtol=1e-5, check_dtypes=True)
+
def test_quantizelinear_default_axis_opset10():
    """opset10 QuantizeLinear should honor default axis=1 (not hardcode axis=0)."""
    q_node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
    initializers = [
        helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.05, 0.1, 0.2]),
        helper.make_tensor("zero_point", TensorProto.UINT8, [3], [1, 127, 250]),
    ]
    graph = helper.make_graph(
        [q_node],
        "quantizelinear_axis_opset10",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
        [helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3, 4])],
        initializer=initializers,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

    data = rg.standard_normal((2, 3, 4)).astype("float32")
    check_correctness(model, inputs={"x": data}, opset=10, check_dtypes=True)
+
+
def test_dequantizelinear_default_axis_opset10():
    """opset10 DequantizeLinear should honor default axis=1 (not hardcode axis=0)."""
    dq_node = helper.make_node("DequantizeLinear", ["x", "scale", "zero_point"], ["y"])
    initializers = [
        helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.05, 0.1, 0.2]),
        helper.make_tensor("zero_point", TensorProto.UINT8, [3], [1, 127, 250]),
    ]
    graph = helper.make_graph(
        [dq_node],
        "dequantizelinear_axis_opset10",
        [helper.make_tensor_value_info("x", TensorProto.UINT8, [2, 3, 4])],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 4])],
        initializer=initializers,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

    data = rg.integers(low=0, high=255, size=(2, 3, 4), dtype=np.uint8)
    check_correctness(model, inputs={"x": data}, opset=10, check_dtypes=True)
if __name__ == "__main__":
tvm.testing.main()