anwang2009 commented on a change in pull request #10718:
URL: https://github.com/apache/tvm/pull/10718#discussion_r833480692
##########
File path: tests/python/relay/test_pass_fake_quantization_to_integer.py
##########
@@ -600,13 +600,100 @@ def test_fake_quantize_binary(operator):
compare_fq_to_int(op, [x_np, y_np])
+@pytest.mark.parametrize(
+ "operator",
+ [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum,
relay.op.maximum],
+)
+def test_fake_quantize_binary_per_channel(operator):
+ def verify_binary_per_channel(lhs_scale, rhs_scale, lhs_zp, rhs_zp,
out_zp, lhs_axis, rhs_axis):
+ if operator == relay.op.multiply:
+ out_scale = relay.const(2.0)
+ rhs_axis = lhs_axis # TODO: Support different axes for
per-channel quantized multiply
+ else:
+ out_scale = relay.const(0.1)
+ x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+ x = relay.qnn.op.dequantize(x, relay.const(lhs_scale),
relay.const(lhs_zp), axis=lhs_axis)
+
+ y = relay.var("y", shape=[1, 3, 224, 224], dtype="int8")
+ y = relay.qnn.op.dequantize(y, relay.const(rhs_scale),
relay.const(rhs_zp), axis=rhs_axis)
+
+ op = operator(x, y)
+ if operator == relay.op.multiply:
+ out_scale = relay.const(2.0)
+ else:
+ out_scale = relay.const(0.1)
Review comment:
nit: this block isn't needed, you've set out_scale above
##########
File path: src/relay/qnn/op/op_common.h
##########
@@ -286,7 +346,10 @@ static inline bool QnnBroadcastRel(const Array<Type>&
types, int num_inputs, con
.add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs
tensor.") \
.add_argument("output_scale", "Tensor", "The scale of the output
tensor.") \
.add_argument("output_zero_point", "Tensor", "The zero_point of the
output tensor.") \
+ .add_argument("lhs_axis", "Tensor", "The channel quantization of the lhs
tensor.") \
+ .add_argument("rhs_axis", "Tensor", "The channel quantization of the rhs
tensor.") \
.add_type_rel("QnnBroadcast", QnnBroadcastRel)
\
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
\
Review comment:
Just confirming my assumptions, this is to prevent qnn ops from being
fused?
##########
File path: src/relay/qnn/op/op_common.h
##########
@@ -243,10 +246,62 @@ static inline bool QnnBroadcastRel(const Array<Type>&
types, int num_inputs, con
return false;
}
}
- ICHECK(IsScalarType(types[2], DataType::Float(32))); // lhs_scale
- ICHECK(IsScalarType(types[3], DataType::Int(32))); // lhs_zero_point
- ICHECK(IsScalarType(types[4], DataType::Float(32))); // rhs_scale
- ICHECK(IsScalarType(types[5], DataType::Int(32))); // rhs_zero_point
+
+ const auto* lhs_data = types[0].as<TensorTypeNode>();
+ const auto* rhs_data = types[1].as<TensorTypeNode>();
+
+ if (lhs_data == nullptr || rhs_data == nullptr) {
+ return false;
+ }
+
+ const BroadcastAttrs* broadcast_attrs = attrs.as<BroadcastAttrs>();
+ int lhs_axis = broadcast_attrs->lhs_axis;
+ int rhs_axis = broadcast_attrs->rhs_axis;
+
+ auto lhs_rank = static_cast<int>(lhs_data->shape.size());
+ auto rhs_rank = static_cast<int>(rhs_data->shape.size());
+
+ lhs_axis = (lhs_axis < 0) ? ((lhs_rank > 0) ? lhs_data->shape.size() +
lhs_axis : 0) : lhs_axis;
+ rhs_axis = (rhs_axis < 0) ? ((rhs_rank > 0) ? rhs_data->shape.size() +
rhs_axis : 0) : rhs_axis;
Review comment:
nit
```suggestion
lhs_axis = (lhs_axis < 0) ? ((lhs_rank > 0) ? lhs_rank + lhs_axis : 0) :
lhs_axis;
rhs_axis = (rhs_axis < 0) ? ((rhs_rank > 0) ? rhs_rank + rhs_axis : 0) :
rhs_axis;
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: reviews-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org