anijain2305 commented on a change in pull request #5153: Adding support for QNN 
subtract op
URL: https://github.com/apache/incubator-tvm/pull/5153#discussion_r398788973
 
 

 ##########
 File path: src/relay/qnn/op/op_common.h
 ##########
 @@ -30,14 +30,155 @@
 #include <tvm/relay/qnn/attrs.h>
 #include <vector>
 #include "../../op/type_relations.h"
+#include "../../transforms/infer_layout_util.h"
 
 namespace tvm {
 namespace relay {
 namespace qnn {
 
-static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+/*!
+ * \brief Number of input expressions every QNN binary operator takes:
+ * lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point,
+ * output_scale and output_zero_point, in that order. Refer to the
+ * QNN_REGISTER_BINARY_OP macro for the operators this applies to.
+ */
+constexpr int numQnnBinaryOpInputs{8};
+
+/*!
+ * \brief Number of entries expected in the arg_types array handed to
+ * QnnBinaryOpDtypeAndShape.
+ * NOTE(review): presumably the eight input types plus one output type;
+ * confirm against the operator registration.
+ */
+constexpr int numQnnBinaryOpArgTypes{9};
+
+/*!
+ * \brief Unpacks the positional argument list of a QNN binary operator
+ * into named fields, so that the argument-count check is performed in
+ * one central place rather than in every operator.
+ */
+struct QnnBinaryOpArguments {
+  Expr lhs;
+  Expr rhs;
+  Expr lhs_scale;
+  Expr lhs_zero_point;
+  Expr rhs_scale;
+  Expr rhs_zero_point;
+  Expr output_scale;
+  Expr output_zero_point;
+
+  /*!
+   * \brief Binds each field to its positional argument.
+   * \param new_args The operator arguments; must hold exactly
+   * numQnnBinaryOpInputs expressions, ordered like the fields above.
+   */
+  explicit QnnBinaryOpArguments(const Array<Expr>& new_args) {
+    // Validate the count before any element is read.
+    CHECK_EQ(new_args.size(), numQnnBinaryOpInputs);
+    lhs = new_args[0];
+    rhs = new_args[1];
+    lhs_scale = new_args[2];
+    lhs_zero_point = new_args[3];
+    rhs_scale = new_args[4];
+    rhs_zero_point = new_args[5];
+    output_scale = new_args[6];
+    output_zero_point = new_args[7];
+  }
+};
+
+/*!
+ * \brief Captures the dtype and shape of the first argument type of a
+ * QNN binary operator. Construction checks that exactly
+ * numQnnBinaryOpArgTypes argument types were supplied and that the
+ * first one is a TensorType.
+ * NOTE(review): only arg_types[0] is inspected; the remaining argument
+ * types are not validated here.
+ */
+struct QnnBinaryOpDtypeAndShape {
+  DataType input_dtype;
+  Array<PrimExpr> input_shape;
+
+  /*!
+   * \param arg_types The argument types of the operator; must contain
+   * exactly numQnnBinaryOpArgTypes entries, the first of which must be
+   * a TensorType.
+   */
+  explicit QnnBinaryOpDtypeAndShape(const Array<tvm::relay::Type>& arg_types) {
+    // Attach context to the failures so that they are diagnosable without
+    // reading this header.
+    CHECK_EQ(arg_types.size(), numQnnBinaryOpArgTypes)
+        << "QNN binary operators expect " << numQnnBinaryOpArgTypes
+        << " argument types.";
+    const auto* tensor_type = arg_types[0].as<TensorTypeNode>();
+    CHECK(tensor_type != nullptr)
+        << "The first argument of a QNN binary operator must be a TensorType.";
+    input_dtype = tensor_type->dtype;
+    input_shape = tensor_type->shape;
+  }
+};
+
+/*!
+ * \brief Clips an expression to [GetQmin(target_dtype),
+ * GetQmax(target_dtype)] and casts the clipped result to target_dtype.
+ * This is mainly used for bringing computations done in Int32 back
+ * down to the lower-precision Int8 or UInt8.
+ * \param expr The expression whose dtype needs conversion.
+ * \param target_dtype The dtype to convert the expression to.
+ * \return The clipped expression cast to target_dtype, possibly with
+ * lower precision than the input.
+ */
+inline Expr lowerPrecision(const Expr& expr,
+                           const DataType& target_dtype) {
+  const auto lowest = GetQmin(target_dtype);
+  const auto highest = GetQmax(target_dtype);
+  return Cast(Clip(expr, lowest, highest), target_dtype);
+}
+
+/*!
+ * \brief The signed 32-bit integer dtype. Int8/UInt8 inputs are
+ * explicitly cast to it (and Int32 constants are built with it) so
+ * that intermediate arithmetic runs at full precision.
+ */
+const DataType fullPrecisionInt32 = DataType::Int(32);
+
+/*
+ * \brief Requantizes the given expression if either its scale differs
+ * from the target scale or its zero point differs from the target zero
+ * point; otherwise the expression is only cast to Int32. This is mainly
+ * needed for requantizing the input tensors to the output tensor's
+ * scale and zero point, which eases the computation of the final
+ * quantized tensor.
+ * \param expr The expression on which the check needs to be performed.
+ * \param expr_scale The scale of the expression.
+ * \param expr_zero_point The zero point of the expression.
+ * \param target_scale The scale of the output tensor.
+ * \param target_zero_point The zero point of the output tensor.
+ * \param expr_shape The shape of the input expression.
+ * \return New expression that is requantized to target scale and zero
+ * point if the expression scale and zero points are different otherwise
+ * it simply casts the given expression to Int32 as no requantization is
+ * needed in this case.
+ */
+inline Expr requantizeIfNeeded(const Expr& expr,
+                               const Expr& expr_scale,
+                               const Expr& expr_zero_point,
+                               const Expr& target_scale,
+                               const Expr& target_zero_point,
+                               const Array <PrimExpr>& expr_shape) {
+  auto result = expr;
+  if (!IsEqualScalar(expr_scale, target_scale) ||
+     !IsEqualScalar(expr_zero_point, target_zero_point)) {
+    result = Requantize(expr, expr_shape, expr_scale, expr_zero_point,
+                        target_scale, target_zero_point, fullPrecisionInt32);
+  } else {
+    result = Cast(result, fullPrecisionInt32);
 
 Review comment:
  Maybe we need to come up with a better name for the function. Here, even when we do not requantize, we still upcast, and that is not clear from the name.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to