shoubhik commented on a change in pull request #5153: Adding support for QNN 
subtract op
URL: https://github.com/apache/incubator-tvm/pull/5153#discussion_r398857587
 
 

 ##########
 File path: src/relay/qnn/op/add.cc
 ##########
 @@ -97,39 +66,29 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const 
Array<Expr>& new_args,
   //          Q_c = Q_a' + Q_b' - zp_c
   // The add op is done in int32 precision.
 
-  // Requantize LHS if necessary.
-  auto requantized_lhs = lhs;
-  if (!IsEqualScalar(lhs_scale, output_scale) ||
-      !IsEqualScalar(lhs_zero_point, output_zero_point)) {
-    requantized_lhs = Requantize(lhs, input_shape, lhs_scale, lhs_zero_point, 
output_scale,
-                                 output_zero_point, DataType::Int(32));
-  } else {
-    requantized_lhs = Cast(requantized_lhs, DataType::Int(32));
-  }
 
-  // Requantize RHS if necessary.
-  auto requantized_rhs = rhs;
-  if (!IsEqualScalar(rhs_scale, output_scale) ||
-      !IsEqualScalar(rhs_zero_point, output_zero_point)) {
-    requantized_rhs = Requantize(rhs, input_shape, rhs_scale, rhs_zero_point, 
output_scale,
-                                 output_zero_point, DataType::Int(32));
-  } else {
-    requantized_rhs = Cast(requantized_rhs, DataType::Int(32));
-  }
 
+  // Requantize LHS if necessary. Computes Q_a'
+  auto requantized_lhs = requantizeIfNeeded(args.lhs, args.lhs_scale,
+                                            args.lhs_zero_point,
+                                            args.output_scale, 
args.output_zero_point,
+                                            inputShapeAndDtype.input_shape);
+  // Requantize RHS if necessary. Computes Q_b'
+  auto requantized_rhs = requantizeIfNeeded(args.rhs, args.rhs_scale,
+                                            args.rhs_zero_point,
+                                            args.output_scale, 
args.output_zero_point,
+                                            inputShapeAndDtype.input_shape);
+  // Computes Q_a' + Q_b'
   auto output = Add(requantized_lhs, requantized_rhs);
 
-  // Subtract zero point.
+  // Subtract zero point. Computes (Q_a' + Q_b') - zp_c
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
-  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
-    output = Subtract(output, output_zero_point);
+  if (!IsEqualScalar(args.output_zero_point, zero_scalar)) {
+    output = Subtract(output, args.output_zero_point);
   }
 
   // Go back to lower precision.
-  auto q_min = GetQmin(input_dtype);
-  auto q_max = GetQmax(input_dtype);
-  output = Clip(output, q_min, q_max);
-  return Cast(output, input_dtype);
+  return lowerPrecision(output, inputShapeAndDtype.input_dtype);
 
 Review comment:
  Well `ShrinkBackToOutDtype` implies that the output will be converted to 
the output dtype, but we always convert to the input dtype (what goes in comes out). 
`ConvertDtype(fromExpression, toTargetDtype)` is the convention I am going 
with. What do you think?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to