LiangHao151941 commented on a change in pull request #4828: [QNN][TFLite] TFLite rounding mode support
URL: https://github.com/apache/incubator-tvm/pull/4828#discussion_r376854118
##########
File path: src/relay/qnn/util.cc
##########
@@ -100,35 +134,54 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
// (from the right, rightmost bit is bit 0). The computation is performed in
// higher precision to avoid overflow in multiplying two int32 values.
Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
- tensor = Multiply(tensor, scalar);
+ Expr scaled_tensor = Multiply(tensor, scalar);
// 4) Find the rounding scalar. This depends on where the final decimal
// point sits. As we will be right shifting the multiplied_t, we need to
// first calculate the total_right_shift.
int total_right_shift = right_shift + 31;
int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
+ auto nearest_rounding_scalar =
+ [&](const Expr& input_tensor, int right_shift) -> Expr {
+ int64_t pos_rounding_value = (1ll << (right_shift - 1));
+ auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+ auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+ auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+ auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+
+ auto zero_t = Zeros(input_shape, hp_dtype);
+ return Where(
+ GreaterEqual(input_tensor, zero_t), pos_rounder_t, neg_rounder_t);
Review comment:
I'll add some comments to this lambda function, if I understand the
suggestion here correctly.
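
For reference, a scalar sketch of what the lambda in the hunk above appears to
select per element, followed by the right shift that the step 4 comment
describes. The helper names NearestRounder and RoundingRightShift are
hypothetical and not part of the PR. The sketch assumes right_shift >= 1 and
that >> on a negative signed value is an arithmetic shift.

```cpp
#include <cstdint>

// Hypothetical scalar counterpart of nearest_rounding_scalar: the offset is
// 2^(right_shift - 1) for non-negative inputs and one less for negative
// inputs, mirroring Where(GreaterEqual(x, 0), pos_rounder, neg_rounder).
constexpr int64_t NearestRounder(int64_t value, int right_shift) {
  const int64_t pos_rounding_value = int64_t{1} << (right_shift - 1);
  return value >= 0 ? pos_rounding_value : pos_rounding_value - 1;
}

// Adds the offset and then shifts right. Assumption: >> on a negative value
// is arithmetic (guaranteed from C++20, implementation-defined before that).
constexpr int64_t RoundingRightShift(int64_t value, int right_shift) {
  return (value + NearestRounder(value, right_shift)) >> right_shift;
}

// With these offsets, exact halves round away from zero.
static_assert(RoundingRightShift(3, 1) == 2, "1.5 rounds to 2");
static_assert(RoundingRightShift(-3, 1) == -2, "-1.5 rounds to -2");
static_assert(RoundingRightShift(5, 2) == 1, "1.25 rounds to 1");
static_assert(RoundingRightShift(-5, 2) == -1, "-1.25 rounds to -1");
```

The smaller offset for negative inputs is what keeps the result symmetric
around zero: using 2^(right_shift - 1) for both signs would round -1.5 to -1
while 1.5 rounds to 2.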