This is an automated email from the ASF dual-hosted git repository. wuwei pushed a commit to branch fix/iter_map in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 693aa52c733c94f82b8c3f75a1bac245aca164ce Author: Wuwei Lin <[email protected]> AuthorDate: Wed Feb 10 12:41:17 2021 -0800 [Arith] Fix iter_affine_map with non-const extent --- src/arith/iter_affine_map.cc | 34 +++++++++++----------- .../python/unittest/test_arith_iter_affine_map.py | 3 ++ 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 7896db7..170e825 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -412,8 +412,8 @@ class IterMapRewriter : public ExprMutator { return analyzer_->CanProve(floormod(lhs, rhs) == 0); } - PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs); - PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs); + PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig); + PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig); static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) { tir::ExprDeepEqual equal; @@ -577,14 +577,14 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef<PrimExpr>(op); } else { - return Mul(a, b); + return GetRef<PrimExpr>(op); } } if (a->IsInstance<IterMapExprNode>() && b->IsInstance<IterMapExprNode>()) { // cannot multiply two iterators, mark as unresolved. ++unresolved_count_; - return Mul(a, b); + return GetRef<PrimExpr>(op); } if (!a->IsInstance<IterMapExprNode>()) { @@ -603,7 +603,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { } } -PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) { +PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig) { // floordiv(x*scale, rhs) if (is_one(rhs)) return std::move(lhs); if (!is_one(lhs->scale)) { @@ -619,7 +619,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) { } else { // mark as unresolved. ++unresolved_count_; - return floordiv(lhs, rhs); + return orig; } } } @@ -641,7 +641,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) { } else { // mark as unresolved. ++unresolved_count_; - return floordiv(lhs, rhs); + return orig; } } @@ -669,7 +669,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (b->IsInstance<IterMapExprNode>()) { // cannot divide an iterator, mark as unresolved. ++unresolved_count_; - return FloorDiv(a, b); + return GetRef<PrimExpr>(op); } if (a->IsInstance<IterSumExprNode>()) { @@ -678,16 +678,16 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { return SplitFloorDivConst(opt.value(), b); } else { ++unresolved_count_; - return FloorDiv(a, b); + return GetRef<PrimExpr>(op); } } else { ICHECK(a->IsInstance<IterSplitExprNode>()); IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a)); - return SplitFloorDivConst(ret, b); + return SplitFloorDivConst(ret, b, GetRef<PrimExpr>(op)); } } -PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) { +PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr &orig) { // floormod(x*scale, rhs) if (is_one(rhs)) return make_zero(lhs->dtype); if (!is_one(lhs->scale)) { @@ -701,7 +701,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) { } else { // mark as unresolved. ++unresolved_count_; - return floormod(lhs, rhs); + return orig; } } } @@ -715,7 +715,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) { } else { // mark as unresolved. ++unresolved_count_; - return floormod(lhs, rhs); + return orig; } } @@ -743,21 +743,21 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (b->IsInstance<IterMapExprNode>()) { // cannot mod an iterator, mark as unresolved. ++unresolved_count_; - return FloorMod(a, b); + return GetRef<PrimExpr>(op); } if (a->IsInstance<IterSumExprNode>()) { IterSumExpr ret = Downcast<IterSumExpr>(a); if (auto opt = TryFuseIters(ret)) { - return SplitFloorModConst(opt.value(), b); + return SplitFloorModConst(opt.value(), b, GetRef<PrimExpr>(op)); } else { ++unresolved_count_; - return FloorMod(a, b); + return GetRef<PrimExpr>(op); } } else { ICHECK(a->IsInstance<IterSplitExprNode>()); IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a)); - return SplitFloorModConst(ret, b); + return SplitFloorModConst(ret, b, GetRef<PrimExpr>(op)) } } diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 620540c..6ab61fd 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -161,6 +161,9 @@ def test_split(): assert len(res) == 1 assert_iter_sum_pattern(res[0], 8, 0, scale=2) + res = tvm.arith.detect_iter_map([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)])) + assert len(res) == 0 + def test_compound(): x = tvm.tir.Var("x", "int32"), 10
