AndrewZhaoLuo commented on a change in pull request #10718:
URL: https://github.com/apache/tvm/pull/10718#discussion_r834598694
##########
File path: src/relay/qnn/op/mul.cc
##########
@@ -51,44 +53,108 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const
Array<Expr>& new_args,
const auto int32_dtype = DataType::Int(32);
const auto float32_dtype = DataType::Float(32);
- /*
- A tensor multiplication c = a * b can be written in terms of respective
- quantized tensors, scales and zero points as
- S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b).
-
- We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different
- quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp' =
- 0. The quantized multiplication then becomes
- Q_c = S'/S_c Q' + z_c,
- which is essentially a requantization of tensor Q' into tensor Q_c.
- */
-
- auto lhs_shifted = Cast(args.lhs, int32_dtype);
- auto rhs_shifted = Cast(args.rhs, int32_dtype);
-
- auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
- if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
- lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
+ const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
+ ICHECK(broadcast_attrs != nullptr);
+
+ auto lhs_axis = broadcast_attrs->lhs_axis;
+ auto rhs_axis = broadcast_attrs->rhs_axis;
+
+ if (lhs_axis == -1 && rhs_axis == -1) {
+ /*
+ This is per-tensor quantized multiply.
+
+ A tensor multiplication c = a * b can be written in terms of respective
+ quantized tensors, scales and zero points as
+ S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b).
+
+ We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different
+ quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp'
=
+ 0. The quantized multiplication then becomes
+ Q_c = S'/S_c Q' + z_c,
+ which is essentially a requantization of tensor Q' into tensor Q_c.
+ */
+
+ auto lhs_shifted = Cast(args.lhs, int32_dtype);
+ auto rhs_shifted = Cast(args.rhs, int32_dtype);
+
+ auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
+ if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
+ lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
+ }
+
+ if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
+ rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point);
+ }
+
+ // Create a new tensor Q'
+ output = Multiply(lhs_shifted, rhs_shifted);
+
+ // Get the adjusted new scale and zero points.
+ float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale);
+ float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale);
+ float new_scale_float = lhs_scale_float * rhs_scale_float;
+ auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float);
+ auto new_input_zero_point = zero_scalar;
+
+ // Requantize to get Q_c
+ output = Requantize(output, input_type.shape, new_input_scale,
new_input_zero_point,
+ args.output_scale, args.output_zero_point,
input_type.dtype);
+ } else if (lhs_axis == rhs_axis) {
Review comment:
This does not handle the negative-axis case (which is fine, but if we
don't want to handle it, we should throw an error).
##########
File path: tests/python/relay/test_pass_fake_quantization_to_integer.py
##########
@@ -600,13 +600,97 @@ def test_fake_quantize_binary(operator):
compare_fq_to_int(op, [x_np, y_np])
[email protected](
+ "operator",
+ [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum,
relay.op.maximum],
+)
+def test_fake_quantize_binary_per_channel(operator):
+ def verify_binary_per_channel(lhs_scale, rhs_scale, lhs_zp, rhs_zp,
out_zp, lhs_axis, rhs_axis):
+ if operator == relay.op.multiply:
+ out_scale = relay.const(2.0)
+ rhs_axis = lhs_axis # TODO: Support different axes for
per-channel quantized multiply
+ else:
+ out_scale = relay.const(0.1)
+
+ x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+ x = relay.qnn.op.dequantize(x, relay.const(lhs_scale),
relay.const(lhs_zp), axis=lhs_axis)
+
+ y = relay.var("y", shape=[1, 3, 224, 224], dtype="int8")
+ y = relay.qnn.op.dequantize(y, relay.const(rhs_scale),
relay.const(rhs_zp), axis=rhs_axis)
+
+ op = operator(x, y)
+
+ op = relay.qnn.op.quantize(op, out_scale, relay.const(out_zp),
out_dtype="int8")
+ x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")
+ y_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")
+
+ compare_fq_to_int(op, [x_np, y_np], allow_rounding_error=True)
+
+ # Same axis
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 3),
+ rhs_scale=np.random.uniform(1.0, 5.0, 3),
+ lhs_zp=0,
+ rhs_zp=0,
+ out_zp=0,
+ lhs_axis=1,
+ rhs_axis=1,
+ )
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 3),
+ rhs_scale=np.random.uniform(1.0, 5.0, 3),
+ lhs_zp=np.random.randint(1, 3),
+ rhs_zp=np.random.randint(1, 3),
+ out_zp=0,
+ lhs_axis=1,
+ rhs_axis=1,
+ )
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 3),
+ rhs_scale=np.random.uniform(1.0, 5.0, 3),
+ lhs_zp=np.random.randint(1, 3),
+ rhs_zp=np.random.randint(1, 3),
+ out_zp=np.random.randint(1, 3),
+ lhs_axis=1,
+ rhs_axis=1,
+ )
+
+ # Different axes
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 224),
+ rhs_scale=np.random.uniform(1.0, 5.0, 224),
+ lhs_zp=0,
+ rhs_zp=0,
+ out_zp=0,
+ lhs_axis=2,
+ rhs_axis=3,
+ )
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 224),
+ rhs_scale=np.random.uniform(1.0, 5.0, 224),
+ lhs_zp=np.random.randint(1, 3),
+ rhs_zp=np.random.randint(1, 3),
+ out_zp=0,
+ lhs_axis=2,
+ rhs_axis=3,
+ )
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 224),
+ rhs_scale=np.random.uniform(1.0, 5.0, 224),
+ lhs_zp=np.random.randint(1, 3),
+ rhs_zp=np.random.randint(1, 3),
+ out_zp=np.random.randint(1, 3),
+ lhs_axis=2,
+ rhs_axis=3,
+ )
+
+
Review comment:
It would be nice to have some examples with rank > 1.
##########
File path: tests/python/relay/test_pass_fake_quantization_to_integer.py
##########
@@ -600,13 +600,97 @@ def test_fake_quantize_binary(operator):
compare_fq_to_int(op, [x_np, y_np])
[email protected](
+ "operator",
+ [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum,
relay.op.maximum],
+)
+def test_fake_quantize_binary_per_channel(operator):
+ def verify_binary_per_channel(lhs_scale, rhs_scale, lhs_zp, rhs_zp,
out_zp, lhs_axis, rhs_axis):
+ if operator == relay.op.multiply:
+ out_scale = relay.const(2.0)
+ rhs_axis = lhs_axis # TODO: Support different axes for
per-channel quantized multiply
+ else:
+ out_scale = relay.const(0.1)
+
+ x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+ x = relay.qnn.op.dequantize(x, relay.const(lhs_scale),
relay.const(lhs_zp), axis=lhs_axis)
+
+ y = relay.var("y", shape=[1, 3, 224, 224], dtype="int8")
+ y = relay.qnn.op.dequantize(y, relay.const(rhs_scale),
relay.const(rhs_zp), axis=rhs_axis)
+
+ op = operator(x, y)
+
+ op = relay.qnn.op.quantize(op, out_scale, relay.const(out_zp),
out_dtype="int8")
+ x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")
+ y_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")
+
+ compare_fq_to_int(op, [x_np, y_np], allow_rounding_error=True)
+
+ # Same axis
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 3),
+ rhs_scale=np.random.uniform(1.0, 5.0, 3),
+ lhs_zp=0,
+ rhs_zp=0,
+ out_zp=0,
+ lhs_axis=1,
+ rhs_axis=1,
+ )
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 3),
+ rhs_scale=np.random.uniform(1.0, 5.0, 3),
+ lhs_zp=np.random.randint(1, 3),
+ rhs_zp=np.random.randint(1, 3),
+ out_zp=0,
+ lhs_axis=1,
+ rhs_axis=1,
+ )
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 3),
+ rhs_scale=np.random.uniform(1.0, 5.0, 3),
+ lhs_zp=np.random.randint(1, 3),
+ rhs_zp=np.random.randint(1, 3),
+ out_zp=np.random.randint(1, 3),
+ lhs_axis=1,
+ rhs_axis=1,
+ )
+
+ # Different axes
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 224),
+ rhs_scale=np.random.uniform(1.0, 5.0, 224),
+ lhs_zp=0,
+ rhs_zp=0,
+ out_zp=0,
+ lhs_axis=2,
+ rhs_axis=3,
+ )
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 224),
+ rhs_scale=np.random.uniform(1.0, 5.0, 224),
+ lhs_zp=np.random.randint(1, 3),
+ rhs_zp=np.random.randint(1, 3),
+ out_zp=0,
+ lhs_axis=2,
+ rhs_axis=3,
+ )
+ verify_binary_per_channel(
+ lhs_scale=np.random.uniform(1.0, 5.0, 224),
+ rhs_scale=np.random.uniform(1.0, 5.0, 224),
+ lhs_zp=np.random.randint(1, 3),
+ rhs_zp=np.random.randint(1, 3),
+ out_zp=np.random.randint(1, 3),
+ lhs_axis=2,
+ rhs_axis=3,
+ )
+
+
Review comment:
It would also be nice to have some examples which broadcast (if this is supported).
##########
File path: src/relay/qnn/op/mul.cc
##########
@@ -51,44 +53,108 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const
Array<Expr>& new_args,
const auto int32_dtype = DataType::Int(32);
const auto float32_dtype = DataType::Float(32);
- /*
- A tensor multiplication c = a * b can be written in terms of respective
- quantized tensors, scales and zero points as
- S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b).
-
- We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different
- quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp' =
- 0. The quantized multiplication then becomes
- Q_c = S'/S_c Q' + z_c,
- which is essentially a requantization of tensor Q' into tensor Q_c.
- */
-
- auto lhs_shifted = Cast(args.lhs, int32_dtype);
- auto rhs_shifted = Cast(args.rhs, int32_dtype);
-
- auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
- if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
- lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
+ const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
+ ICHECK(broadcast_attrs != nullptr);
+
+ auto lhs_axis = broadcast_attrs->lhs_axis;
+ auto rhs_axis = broadcast_attrs->rhs_axis;
+
+ if (lhs_axis == -1 && rhs_axis == -1) {
Review comment:
So what does axis=-1 mean? I think it's an overloaded term.
Here it seems to be the default case where there is no per-channel
quantization, whereas elsewhere it may mean quantize along the last channel.
##########
File path: src/relay/qnn/op/op_common.h
##########
@@ -243,10 +246,62 @@ static inline bool QnnBroadcastRel(const Array<Type>&
types, int num_inputs, con
return false;
}
}
- ICHECK(IsScalarType(types[2], DataType::Float(32))); // lhs_scale
- ICHECK(IsScalarType(types[3], DataType::Int(32))); // lhs_zero_point
- ICHECK(IsScalarType(types[4], DataType::Float(32))); // rhs_scale
- ICHECK(IsScalarType(types[5], DataType::Int(32))); // rhs_zero_point
+
+ const auto* lhs_data = types[0].as<TensorTypeNode>();
+ const auto* rhs_data = types[1].as<TensorTypeNode>();
+
+ if (lhs_data == nullptr || rhs_data == nullptr) {
+ return false;
+ }
+
+ const BroadcastAttrs* broadcast_attrs = attrs.as<BroadcastAttrs>();
+ int lhs_axis = broadcast_attrs->lhs_axis;
+ int rhs_axis = broadcast_attrs->rhs_axis;
+
+ auto lhs_rank = static_cast<int>(lhs_data->shape.size());
+ auto rhs_rank = static_cast<int>(rhs_data->shape.size());
+
+ lhs_axis = (lhs_axis < 0) ? ((lhs_rank > 0) ? lhs_rank + lhs_axis : 0) :
lhs_axis;
Review comment:
Hmm, there are a lot of lhs_rank > 0 checks going on which I think make this
confusing to read — can you just factor out the scalar case separately?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]