sfvaroglu commented on a change in pull request #10718:
URL: https://github.com/apache/tvm/pull/10718#discussion_r836722294



##########
File path: src/relay/qnn/op/op_common.h
##########
@@ -243,10 +246,62 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
       return false;
     }
   }
-  ICHECK(IsScalarType(types[2], DataType::Float(32)));  // lhs_scale
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // lhs_zero_point
-  ICHECK(IsScalarType(types[4], DataType::Float(32)));  // rhs_scale
-  ICHECK(IsScalarType(types[5], DataType::Int(32)));    // rhs_zero_point
+
+  const auto* lhs_data = types[0].as<TensorTypeNode>();
+  const auto* rhs_data = types[1].as<TensorTypeNode>();
+
+  if (lhs_data == nullptr || rhs_data == nullptr) {
+    return false;
+  }
+
+  const BroadcastAttrs* broadcast_attrs = attrs.as<BroadcastAttrs>();
+  int lhs_axis = broadcast_attrs->lhs_axis;
+  int rhs_axis = broadcast_attrs->rhs_axis;
+
+  auto lhs_rank = static_cast<int>(lhs_data->shape.size());
+  auto rhs_rank = static_cast<int>(rhs_data->shape.size());
+
+  lhs_axis = (lhs_axis < 0) ? ((lhs_rank > 0) ? lhs_rank + lhs_axis : 0) : lhs_axis;

Review comment:
       I couldn't find a less confusing (or less ugly) way to refactor this out.

##########
File path: src/relay/qnn/op/mul.cc
##########
@@ -51,44 +53,108 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   const auto int32_dtype = DataType::Int(32);
   const auto float32_dtype = DataType::Float(32);
 
-  /*
-  A tensor multiplication c = a * b can be written in terms of respective
-  quantized tensors, scales and zero points as
-  S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b).
-
-  We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different
-  quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp' =
-  0. The quantized multiplication then becomes
-  Q_c = S'/S_c Q' + z_c,
-  which is essentially a requantization of tensor Q' into tensor Q_c.
-  */
-
-  auto lhs_shifted = Cast(args.lhs, int32_dtype);
-  auto rhs_shifted = Cast(args.rhs, int32_dtype);
-
-  auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
-  if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
-    lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
+  const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
+  ICHECK(broadcast_attrs != nullptr);
+
+  auto lhs_axis = broadcast_attrs->lhs_axis;
+  auto rhs_axis = broadcast_attrs->rhs_axis;
+
+  if (lhs_axis == -1 && rhs_axis == -1) {

Review comment:
       Fixed




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to