shoubhik commented on a change in pull request #5153: Adding support for QNN subtract op
URL: https://github.com/apache/incubator-tvm/pull/5153#discussion_r398857587
##########
File path: src/relay/qnn/op/add.cc
##########
@@ -97,39 +66,29 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   // Q_c = Q_a' + Q_b' - zp_c
   // The add op is done in int32 precision.
-  // Requantize LHS if necessary.
-  auto requantized_lhs = lhs;
-  if (!IsEqualScalar(lhs_scale, output_scale) ||
-      !IsEqualScalar(lhs_zero_point, output_zero_point)) {
-    requantized_lhs = Requantize(lhs, input_shape, lhs_scale, lhs_zero_point, output_scale,
-                                 output_zero_point, DataType::Int(32));
-  } else {
-    requantized_lhs = Cast(requantized_lhs, DataType::Int(32));
-  }
-  // Requantize RHS if necessary.
-  auto requantized_rhs = rhs;
-  if (!IsEqualScalar(rhs_scale, output_scale) ||
-      !IsEqualScalar(rhs_zero_point, output_zero_point)) {
-    requantized_rhs = Requantize(rhs, input_shape, rhs_scale, rhs_zero_point, output_scale,
-                                 output_zero_point, DataType::Int(32));
-  } else {
-    requantized_rhs = Cast(requantized_rhs, DataType::Int(32));
-  }
+  // Requantize LHS if necessary. Computes Q_a'
+  auto requantized_lhs = requantizeIfNeeded(args.lhs, args.lhs_scale, args.lhs_zero_point,
+                                            args.output_scale, args.output_zero_point,
+                                            inputShapeAndDtype.input_shape);
+  // Requantize RHS if necessary. Computes Q_b'
+  auto requantized_rhs = requantizeIfNeeded(args.rhs, args.rhs_scale, args.rhs_zero_point,
+                                            args.output_scale, args.output_zero_point,
+                                            inputShapeAndDtype.input_shape);
+  // Computes Q_a' + Q_b'
   auto output = Add(requantized_lhs, requantized_rhs);
-  // Subtract zero point.
+  // Subtract zero point. Computes (Q_a' + Q_b') - zp_c
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
-  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
-    output = Subtract(output, output_zero_point);
+  if (!IsEqualScalar(args.output_zero_point, zero_scalar)) {
+    output = Subtract(output, args.output_zero_point);
   }
   // Go back to lower precision.
-  auto q_min = GetQmin(input_dtype);
-  auto q_max = GetQmax(input_dtype);
-  output = Clip(output, q_min, q_max);
-  return Cast(output, input_dtype);
+  return lowerPrecision(output, inputShapeAndDtype.input_dtype);
Review comment:
Well, `ShrinkBackToOutDtype` implies that the output will be converted to the output dtype, but we always convert to the input dtype (what goes in comes out). `ConvertDtype(fromExpression, toTargetDtype)` is the convention I am going with. What do you think?
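Under that convention, the helper called `lowerPrecision` in the diff might read as follows. This is a sketch built from the `GetQmin`/`GetQmax`/`Clip`/`Cast` lines removed above; the name `ConvertDtype` is the one proposed in this comment, not anything merged.

```cpp
// Sketch of the proposed ConvertDtype(fromExpression, toTargetDtype) helper;
// the body mirrors the clip-and-cast sequence deleted in the diff.
static Expr ConvertDtype(const Expr& expr, const DataType& target_dtype) {
  // Saturate to the representable range of the target dtype, then cast.
  auto q_min = GetQmin(target_dtype);
  auto q_max = GetQmax(target_dtype);
  auto clipped = Clip(expr, q_min, q_max);
  return Cast(clipped, target_dtype);
}
```

Since add.cc always converts back to the dtype of its inputs, the call site would be `ConvertDtype(output, inputShapeAndDtype.input_dtype)`, which makes the "what goes in comes out" behavior explicit where the conversion happens.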