abhikran-quic commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1220062219
##########
tests/python/unittest/test_tir_schedule_compute_at.py:
##########
@@ -995,6 +995,41 @@ def floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.han
Y[v_i] = temp[v_i // 16, v_i % 16]
+@T.prim_func
+def recursive_floordiv_and_floormod_indices(a: T.handle, b: T.handle) -> None:
+ X = T.match_buffer(a, [16, 16])
+ Y = T.match_buffer(b, [256])
+ temp = T.alloc_buffer([16, 4, 2, 2])
+ for i, j in T.grid(16, 16):
+ with T.block("A"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ temp[v_i, v_j // 4, (v_j % 4) // 2, v_j % 2] = X[v_j, v_i] + 1.0
+ for i, j in T.grid(16, 16):
+ with T.block("B"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ Y[v_i * 16 + v_j] = temp[v_i, v_j // 4, (v_j % 4) // 2, (v_j % 2)]
+
+
+@T.prim_func
+def recursive_floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.handle) -> None:
+ X = T.match_buffer(a, [16, 16])
+ Y = T.match_buffer(b, [256])
+ temp = T.alloc_buffer((16, 4, 2, 2))
+ for i in range(16):
+ for j in range(16):
+ with T.block("A"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(X[v_j, v_i])
+ T.writes(temp[v_i, v_j // 4, v_j % 4 // 2, v_j % 2])
+ temp[v_i, v_j // 4, v_j % 4 // 2, v_j % 2] = X[v_j, v_i] + T.float32(1)
Review Comment:
Sure. I've added a test case that exercises the lines changed in this PR. It is
reduced from an issue I hit while trying to optimize the schedule of a model
using `compute_at`.
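To show why the intermediate buffer in the test has shape `[16, 4, 2, 2]`, here is a small pure-Python check (no TVM required) that the nested index expression `(v_j // 4, (v_j % 4) // 2, v_j % 2)` written by block "A" and read by block "B" maps each `v_j` in `[0, 16)` to a distinct cell of a `4 x 2 x 2` region. The helper name `decompose` is hypothetical. Python's `//` and `%` are floor division and floor modulo, so they agree with TIR's `floordiv` and `floormod`.

```python
from itertools import product


def decompose(v_j):
    """Hypothetical helper: the last three `temp` indices used for a given v_j."""
    return (v_j // 4, (v_j % 4) // 2, v_j % 2)


# Every v_j in [0, 16) lands in a distinct cell of the 4 x 2 x 2 region ...
cells = [decompose(v_j) for v_j in range(16)]
assert len(set(cells)) == 16
# ... and together the cells cover the whole region, so the mapping is a bijection.
assert set(cells) == set(product(range(4), range(2), range(2)))

# The last two indices recombine into v_j % 4, i.e. the middle index is a
# floordiv applied to a floormod rather than directly to v_j.
for v_j in range(16):
    a, b, c = decompose(v_j)
    assert b * 2 + c == v_j % 4
    assert a * 4 + b * 2 + c == v_j
```

The middle index `(v_j % 4) // 2` is the "recursive" pattern the test is named after: its numerator is itself a `floormod` expression rather than a bare variable.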
##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
var_dom = arith::IntSet::Interval(required_min, required_max);
var_bound = arith::IntSet::Interval(0, dim_max);
} else {
- arith::PVar<PrimExpr> p_f;
- if ((floordiv(p_v, p_f)).Match(provided_min)) {
+ arith::PVar<PrimExpr> p_f1, p_f2;
+ if ((floordiv(p_v, p_f1)).Match(provided_min)) {
// a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
- PrimExpr fac = p_f.Eval();
+ PrimExpr fac = p_f1.Eval();
if (analyzer->CanProveGreaterEqual(fac, 1)) {
var = p_v.Eval();
var_dom = arith::IntSet::Interval(required_min * fac, analyzer->Simplify(required_max * fac + fac - 1));
var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
}
- } else if ((floormod(p_v, p_f).Match(provided_min))) {
+ } else if ((floormod(p_v, p_f1).Match(provided_min))) {
// generally domain of (x % fac) enforces no constraints on domain of x
return {p_v.Eval(), BlockVarDomainInfo()};
+ } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {
Review Comment:
Thank you for the suggestion. I've merged the separate `floordiv` conditions
into a single branch, and likewise for the `floormod` conditions.
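The interval rule in the `floordiv` branch (`a <= x // fac <= b, fac > 0 ==> a * fac <= x <= b * fac + fac - 1`) and the reason the `floormod` branch returns no constraint can both be checked by brute force. The sketch below is plain Python, not TVM's C++ API, and `floordiv_preimage` and `preimage` are hypothetical names. It also shows that a `floordiv` whose numerator is a `floormod`, such as the test's `(v_j % 4) // 2`, constrains `v_j` no more tightly than the `floormod` alone. This illustrates only the arithmetic; it is not a claim about exactly how the merged branch in this PR handles the nested case.

```python
def floordiv_preimage(a, b, fac):
    """Tightest interval for x given a <= x // fac <= b (requires fac >= 1)."""
    assert fac >= 1
    return a * fac, b * fac + fac - 1


def preimage(expr, lo, hi, extent):
    """All x in [0, extent) for which lo <= expr(x) <= hi."""
    return [x for x in range(extent) if lo <= expr(x) <= hi]


# floordiv: the closed-form interval matches the brute-force preimage exactly,
# and the preimage is contiguous, so an interval describes it without loss.
for fac in (1, 2, 4):
    for a in range(4):
        for b in range(a, 4):
            xs = preimage(lambda x: x // fac, a, b, 64)
            assert (xs[0], xs[-1]) == floordiv_preimage(a, b, fac)
            assert len(xs) == xs[-1] - xs[0] + 1

# floormod: x % 4 == 0 is hit by x values scattered across the range; the
# smallest enclosing interval, [0, 12], is nearly the whole of [0, 16).
assert preimage(lambda x: x % 4, 0, 0, 16) == [0, 4, 8, 12]

# nested: (x % 4) // 2 == 0 is just as scattered (enclosing interval [0, 13]),
# so it gives no tighter bound on x than the floormod case.
assert preimage(lambda x: (x % 4) // 2, 0, 0, 16) == [0, 1, 4, 5, 8, 9, 12, 13]
```

Calling the same helper with `a = 0` and `b = dim_max` reproduces the `var_bound` expression `dim_max * fac + fac - 1` from the hunk above.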
--
This is an automated message from the Apache Git Service.