shoubhik commented on a change in pull request #5153: Adding support for QNN
subtract op
URL: https://github.com/apache/incubator-tvm/pull/5153#discussion_r398806276
##########
File path: src/relay/qnn/op/add.cc
##########
@@ -97,39 +66,29 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const
Array<Expr>& new_args,
// Q_c = Q_a' + Q_b' - zp_c
// The add op is done in int32 precision.
- // Requantize LHS if necessary.
- auto requantized_lhs = lhs;
- if (!IsEqualScalar(lhs_scale, output_scale) ||
- !IsEqualScalar(lhs_zero_point, output_zero_point)) {
- requantized_lhs = Requantize(lhs, input_shape, lhs_scale, lhs_zero_point,
output_scale,
- output_zero_point, DataType::Int(32));
- } else {
- requantized_lhs = Cast(requantized_lhs, DataType::Int(32));
- }
- // Requantize RHS if necessary.
- auto requantized_rhs = rhs;
- if (!IsEqualScalar(rhs_scale, output_scale) ||
- !IsEqualScalar(rhs_zero_point, output_zero_point)) {
- requantized_rhs = Requantize(rhs, input_shape, rhs_scale, rhs_zero_point,
output_scale,
- output_zero_point, DataType::Int(32));
- } else {
- requantized_rhs = Cast(requantized_rhs, DataType::Int(32));
- }
+ // Requantize LHS if necessary. Computes Q_a'
+ auto requantized_lhs = requantizeIfNeeded(args.lhs, args.lhs_scale,
+ args.lhs_zero_point,
+ args.output_scale,
args.output_zero_point,
+ inputShapeAndDtype.input_shape);
+ // Requantize RHS if necessary. Computes Q_b'
+ auto requantized_rhs = requantizeIfNeeded(args.rhs, args.rhs_scale,
+ args.rhs_zero_point,
+ args.output_scale,
args.output_zero_point,
+ inputShapeAndDtype.input_shape);
+ // Computes Q_a' + Q_b'
auto output = Add(requantized_lhs, requantized_rhs);
- // Subtract zero point.
+ // Subtract zero point. Computes (Q_a' + Q_b') - zp_c
auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
- if (!IsEqualScalar(output_zero_point, zero_scalar)) {
- output = Subtract(output, output_zero_point);
+ if (!IsEqualScalar(args.output_zero_point, zero_scalar)) {
+ output = Subtract(output, args.output_zero_point);
}
// Go back to lower precision.
- auto q_min = GetQmin(input_dtype);
- auto q_max = GetQmax(input_dtype);
- output = Clip(output, q_min, q_max);
- return Cast(output, input_dtype);
+ return lowerPrecision(output, inputShapeAndDtype.input_dtype);
Review comment:
I think `ConvertDtype` makes sense?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services