masahi commented on code in PR #13080:
URL: https://github.com/apache/tvm/pull/13080#discussion_r1005269639
##########
src/target/intrin_rule.cc:
##########
@@ -194,40 +234,34 @@ TVM_REGISTER_OP("tir.q_multiply_shift")
}
} else {
// Only int32 types are supported (any number of lanes is allowed)
- ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits()
== 32);
ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits()
== 32);
- DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
- DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
-
- // 1) Calculating the integer multiplier and integer shift
+ // Calculating integer shifts
PrimExpr zero = make_const(s.dtype(), 0);
PrimExpr left_shift = tir::Select(s > zero, s, zero);
PrimExpr right_shift = tir::Select(s > zero, zero, -s);
+ PrimExpr is_left_shift_required = (left_shift != zero);
- // 2) Cast and Multiply the integer multiplier
- PrimExpr one = make_const(hp_dtype, 1);
- x = cast(hp_dtype, x);
- y = cast(hp_dtype, y);
- x = tir::Select(left_shift != zero, x << left_shift, x);
-
- // 3) Perform the multiplication in higher precision.
- x = x * y;
-
- // 4) Find the rounding scalar
- PrimExpr total_right_shift = right_shift + q;
- PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
- x = x + pos_rounding_value;
-
- // 5) Simply right shift the result to get the final output.
- x = x >> total_right_shift;
-
- // 6) The fixed point multiplication keeps the value in int32 range.
Casting back to
- // int32.
- return cast(lp_dtype, x);
+ return QMultiplyShift(x, y, q, left_shift, right_shift,
is_left_shift_required);
}
});
+TVM_REGISTER_OP("tir.q_multiply_shift_per_axis")
+ .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) ->
PrimExpr {
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ ICHECK(call != nullptr);
+
+ PrimExpr x = call->args[0];
+ PrimExpr y = call->args[1];
+ PrimExpr left_shift = call->args[2];
+ PrimExpr right_shift = call->args[3];
+ PrimExpr q = call->args[4];
+ PrimExpr is_lshift_required = call->args[5];
+ // Note, 7th argument is "is_rshift_required" flag, but we do need that
here.
Review Comment:
Typo in the comment above: did you mean "we don't need that here"?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]