anwang2009 commented on a change in pull request #8952:
URL: https://github.com/apache/tvm/pull/8952#discussion_r715835129



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -3462,6 +3462,66 @@ def _impl_v10(cls, inputs, attr, params):
         return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)
 
 
+class QLinearMatMul(OnnxOpConverter):
+    """Operator converter for QLinearMatMul from Microsoft onnxruntime contrib 
opset."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        def get_scalar(x, dtype="float32"):
+            if isinstance(x, _expr.Var) and x.name_hint in params:
+                return _op.const(params[x.name_hint].numpy(), dtype)
+            rank = len(infer_shape(x))
+            assert rank <= 1, "QLinearMul scale and zero_point input must be scalars"
+            if rank == 1:
+                x = _op.squeeze(x, [0])
+            return _op.cast(x, dtype)
+
+        import pdb
+        pdb.set_trace()

Review comment:
       Remove this leftover `pdb` debugging code before merging.

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -3462,6 +3462,66 @@ def _impl_v10(cls, inputs, attr, params):
         return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)
 
 
+class QLinearMatMul(OnnxOpConverter):
+    """Operator converter for QLinearMatMul from Microsoft onnxruntime contrib 
opset."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        def get_scalar(x, dtype="float32"):
+            if isinstance(x, _expr.Var) and x.name_hint in params:
+                return _op.const(params[x.name_hint].numpy(), dtype)
+            rank = len(infer_shape(x))
+            assert rank <= 1, "QLinearMul scale and zero_point input must be scalars"
+            if rank == 1:
+                x = _op.squeeze(x, [0])
+            return _op.cast(x, dtype)
+
+        import pdb
+        pdb.set_trace()
+
+        a = inputs[0]
+        a_scale = get_scalar(inputs[1])
+        a_zero_point = get_scalar(inputs[2], "int32")
+
+        b = inputs[3]
+        b_scale = get_scalar(inputs[4])
+        b_zero_point = get_scalar(inputs[5], "int32")
+
+        y_scale = fold_constant(get_scalar(inputs[6]))
+        y_zero_point = get_scalar(inputs[7], "int32")
+
+        dtype = infer_type(a).checked_type.dtype
+
+        a_rank = len(infer_shape(a))
+        b_rank = len(infer_shape(b))
+
+        assert (a_rank == 2) and (
+            b_rank == 2
+        ), "QLinearMatMul importer currently requires both 'a' and 'b' tensors to be 2D, but rank(a)={}, rank(b)={}".format(
+            a_rank, b_rank
+        )
+
+        ## Note: The ONNX documentation for this op is fairly clear about acceptable overflow
+        ## behavior during the matmul operation:
+        ##   - The scalar multiplication ops MAY NOT overflow.
+        ##   - The scalar addition ops, which sum the results of the scalar multiplication,
+        ##     MAY overflow, but if they do so, it must behave as one would expect during
+        ##     32-bit integer-addition overflow.
+
+        ## As of this writing, Relay's nn.matmul operator doesn't expose a way for us to
+        ## express these constraints. So to ensure correct behavior, we'll play it safe by
+        ## converting the input tensors to int32 prior to performing matmul.
+
+        a_int32 = _op.cast(a, "int32")
+        b_int32 = _op.cast(b, "int32")
+        matmul_int32 = _op.nn.matmul(a_int32, b_int32)
+
+        a = _qnn.op.dequantize(inputs[0], a_scale, a_zero_point)
+        b = _qnn.op.dequantize(inputs[3], b_scale, b_zero_point)
+        out =

Review comment:
       Finish the thought here — the `out =` assignment is incomplete. Also, the `a` and `b` assignments above are never read.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to