mbrookhart commented on a change in pull request #8883:
URL: https://github.com/apache/tvm/pull/8883#discussion_r703793875



##########
File path: python/tvm/relay/transform/fake_quantization_to_integer.py
##########
@@ -198,19 +219,51 @@ def clip(expr, type_map):
     amax = expr.attrs.a_max
     scale = fold_constant(t.scale)
     z_p = fold_constant(t.zero_point)
-    if isinstance(scale, relay.expr.Constant) and isinstance(z_p, 
relay.expr.Constant):
+    if (
+        isinstance(scale, relay.expr.Constant)
+        and scale.data.numpy().size == 1
+        and isinstance(z_p, relay.expr.Constant)
+        and z_p.data.numpy().size == 1
+    ):
         scale = scale.data.numpy().item()
         z_p = z_p.data.numpy().item()
         new_min = int(amin / scale + z_p)
         new_max = int(amax / scale + z_p)
         out = relay.op.clip(arg, new_min, new_max)
     else:
-        amin = relay.op.round(relay.op.const(amin) / scale + z_p)
-        amax = relay.op.round(relay.op.const(amax) / scale + z_p)
-        out = relay.op.minimum(relay.op.maximum(arg, amin), amax)
+        if not isinstance(amin, relay.expr.Constant):
+            amin = relay.op.const(amin)
+        if not isinstance(amax, relay.expr.Constant):
+            amax = relay.op.const(amax)
+
+        scale_shape = infer_shape(scale)
+        if len(scale_shape) > 0 and scale_shape[0] > 1:
+            b_shape = [1] * len(infer_shape(arg))
+            b_shape[t.axis] = -1
+            amin = relay.op.reshape(relay.op.broadcast_to(amin, scale_shape), 
b_shape)
+            amax = relay.op.reshape(relay.op.broadcast_to(amax, scale_shape), 
b_shape)
+        amin = relay.qnn.op.quantize(amin, scale, z_p, t.axis, t.dtype)
+        amax = relay.qnn.op.quantize(amax, scale, z_p, t.axis, t.dtype)
+        out = relay.op.minimum(relay.op.maximum(arg, fold_constant(amin)), 
fold_constant(amax))
+
     return [out, t]
 
 
+@register_fake_quantization_to_integer("nn.relu")
+def relu(expr, type_map):
+    """Rewrite a relu op"""
+    arg = expr.args[0]
+    t = type_map[arg]
+    scale_shape = infer_shape(t.scale)
+    z_p = t.zero_point

Review comment:
       On the first question: no, the broadcasted z_p will not be updated in
 place; it's only used in the computation.
   
   On the second question, that's an interesting point. I haven't ever seen
 it, but it's feasible. QNN currently supports scalar scale with scalar zp,
 vector scale with scalar zp, or vector scale with vector zp, which covers all
 of the combinations I've seen in the wild. What do you think: should we try
 to support that in QNN?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to