gemini-code-assist[bot] commented on code in PR #18694:
URL: https://github.com/apache/tvm/pull/18694#discussion_r2746185371
##########
src/tir/transforms/lower_intrin.cc:
##########
@@ -118,10 +118,20 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// If the numerator's lower bound is known, express the floordiv
// in terms of truncdiv using only positive operands.
- arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
- if (const_int_bound->min_value < 0 &&
- const_int_bound->min_value >
- -(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
+
+ // The optimization below rewrites expressions involving `-a_min + (b - 1)`.
+ // Without proper bounds checking, this expression may overflow the dtype
+ // maximum, leading to non-equivalent transformations.
+ // To ensure safety, we require:
+ // b_max - a_min <= max_value_of_dtype + 1
+ // This provides a conservative upper bound that prevents overflow and
+ // preserves the original semantics.
+ arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a);
+ arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b);
+ const int64_t max_value_of_dtype =
+ Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
+ if (const_int_bound_a->min_value < 0 &&
+ const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) {
Review Comment:

This block of logic for checking the safety of the optimization is
duplicated in `VisitExpr_(const FloorModNode* op)` at lines 228-240. To improve
maintainability and reduce code duplication, consider extracting this logic
into a private helper function within the `IntrinInjecter` class.
For example, you could create a helper function like this:
```cpp
private:
  bool IsFloorDivModOptSafe(const PrimExpr& a, const PrimExpr& b) {
    arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(a);
    if (const_int_bound_a->min_value >= 0) {
      return false;
    }
    // The optimization below rewrites expressions involving `-a_min + (b - 1)`.
    // Without proper bounds checking, this expression may overflow the dtype
    // maximum, leading to non-equivalent transformations.
    // To ensure safety, we require:
    //   b_max - a_min <= max_value_of_dtype + 1
    // This provides a conservative upper bound that prevents overflow and
    // preserves the original semantics.
    arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(b);
    const int64_t max_value_of_dtype =
        Downcast<IntImm>(tvm::max_value(a->dtype.element_of()))->value;
    // Evaluated as b_max - 1 <= max_value_of_dtype + a_min, which is
    // algebraically equivalent (a_min < 0 here) but cannot overflow int64_t
    // when b_max is the +inf sentinel or the dtype is int64, where
    // max_value_of_dtype + 1 would overflow.
    return const_int_bound_b->max_value - 1 <=
           max_value_of_dtype + const_int_bound_a->min_value;
  }
```
Then, you could simplify the call sites:
```cpp
// in VisitExpr_(const FloorDivNode* op)
if (IsFloorDivModOptSafe(op->a, op->b)) {
  arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a);
  // ... rest of the logic
}

// in VisitExpr_(const FloorModNode* op)
if (IsFloorDivModOptSafe(op->a, op->b)) {
  arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a);
  // ... rest of the logic
}
```
This would centralize the safety check and make the code cleaner.
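For reference, the bound in the comment follows from requiring `-a_min + (b - 1) <= max_value_of_dtype` for every divisor value up to `b_max`: the left side is largest at `b = b_max`, which gives `b_max - a_min <= max_value_of_dtype + 1`. The hunk does not show the body of the rewrite, so the standalone sketch below assumes it follows the offset pattern `floordiv(a, b) == truncdiv(a + b * c, b) - c` with `c = truncdiv(-a_min + (b - 1), b)` and a positive divisor. It simulates an int32 dtype with `int64_t` arithmetic; `FloorDiv`, `IsRewriteSafe`, `OffsetNumeratorFits`, and `LoweredFloorDiv` are hypothetical names for illustration, not TVM APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Sketch only (hypothetical helpers, not part of TVM): simulates a signed
// 32-bit dtype with int64_t arithmetic so intermediates can be compared with
// the dtype maximum without undefined behavior. Assumes the divisor b > 0.
constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();

// Reference floor division: rounds the quotient toward negative infinity.
int64_t FloorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;  // built-in '/' truncates toward zero
  if (a % b != 0 && ((a < 0) != (b < 0))) --q;
  return q;
}

// The guard discussed in the review, for a numerator bounded below by a_min
// and a positive divisor bounded above by b_max.
bool IsRewriteSafe(int64_t a_min, int64_t b_max, int64_t dtype_max) {
  return a_min < 0 && b_max - a_min <= dtype_max + 1;
}

// The intermediate the guard protects, checked for a single divisor value.
bool OffsetNumeratorFits(int64_t a_min, int64_t b, int64_t dtype_max) {
  return -a_min + (b - 1) <= dtype_max;
}

// Assumed shape of the rewrite: c = truncdiv(-a_min + (b - 1), b) equals
// ceildiv(-a_min, b), so a + b * c >= 0 and truncdiv on the shifted
// numerator agrees with floordiv after subtracting c.
int64_t LoweredFloorDiv(int64_t a, int64_t b, int64_t a_min) {
  int64_t c = (-a_min + (b - 1)) / b;
  return (a + b * c) / b - c;
}
```

For example, with `a_min = -kInt32Max` the guard accepts `b_max = 1` but rejects `b_max = 2`, because `-a_min + (2 - 1)` would be `kInt32Max + 1`, one past the int32 maximum.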
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]