This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fb64be3f78 [ARITH] Allow Analyzer to MarkGlobalNonNegValue (#15193)
fb64be3f78 is described below
commit fb64be3f7807df18c2df6ebf5e68178e564ab0b4
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jul 2 23:02:11 2023 -0400
[ARITH] Allow Analyzer to MarkGlobalNonNegValue (#15193)
This PR introduces a utility function, MarkGlobalNonNegValue.
This function allows the analyzer to mark buffer shapes in function arguments
as non-negative globally, which opens the door to more symbolic simplification.
---
include/tvm/arith/analyzer.h | 16 +++++++++
src/arith/analyzer.cc | 33 +++++++++++++++++++
src/arith/const_int_bound.cc | 38 ++++++++++++++++++----
src/arith/ir_mutator_with_analyzer.cc | 9 +++++
src/arith/ir_mutator_with_analyzer.h | 8 +++++
src/tir/transforms/flatten_buffer.cc | 1 +
src/tir/transforms/simplify.cc | 21 +++++-------
.../python/unittest/test_arith_const_int_bound.py | 17 ++++++++++
.../python/unittest/test_tir_transform_simplify.py | 16 +++++++++
9 files changed, 140 insertions(+), 19 deletions(-)
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index efa26a31d0..fb837ca3d0 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -618,6 +618,22 @@ class TVM_DLL Analyzer {
TransitiveComparisonAnalyzer transitive_comparisons;
/*! \brief constructor */
Analyzer();
+ /*!
+ * \brief Mark the value as non-negative value globally in analyzer.
+ *
+ * Only call this function if the non-neg condition is global and
+ * not context-dependent.
+ *
+ * This function does best-effort propagations to the sub-analyzers
+ *
+ * \note We expose this function because non-negative global values,
+ * such as symbolic buffer shapes in function arguments are really
+ * important to ensure the best simplification, and usually they
+ * can be handled in a simpler way than the generic constraints.
+ *
+ * This function may call into the Update function of the sub-analyzers.
+ */
+ void MarkGlobalNonNegValue(const PrimExpr& value);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 722a2cd00e..9e5b1414ed 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -25,6 +25,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
+#include "const_fold.h"
#include "product_normal_form.h"
namespace tvm {
@@ -63,6 +64,38 @@ void Analyzer::Bind(const Var& var, const Range& range, bool
allow_override) {
// skip rewrite simplify
}
+void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
+ // split out the symbolic and non-symbolic part
+ int64_t cscale = 1;
+ PrimExpr symbolic = tir::make_const(value.dtype(), 1);
+ auto fcollect = [&](PrimExpr val) {
+ if (const auto* intimm = val.as<IntImmNode>()) {
+ cscale *= intimm->value;
+ } else {
+ symbolic = symbolic * val;
+ }
+ };
+ UnpackReduction<tir::MulNode>(value, fcollect);
+ if (cscale <= 0) return;
+ // override the constant int bound by marking it as non-negative
+ // NOTE: there might be future opportunities of more bound hint
+ // this is a simple step and covers all the current needs
+ //
+ // We may consider enhance the sub analyzer to directly take
+ // MarkPositiveVar so their bounds do not overlap
+ if (const auto* var_ptr = symbolic.as<VarNode>()) {
+ Var var = GetRef<Var>(var_ptr);
+ // skip non-index type, keep it to be compatible
+ // with any_dim that do not represent any value
+ if (!IsIndexType(var.dtype())) return;
+ bool allow_override = true;
+ // mark the constant bound is sufficient
+ // we cannot mark interval set as that will cause relaxation of the var
+ // during bound proof which is not our intention
+ this->const_int_bound.Update(var, ConstIntBound(0,
ConstIntBound::kPosInf), allow_override);
+ }
+}
+
void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
for (const auto& iter : variables) {
this->Bind(iter.first, iter.second, allow_override);
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 68ade3bb54..8ce5025231 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -195,6 +195,31 @@ class ConstIntBoundAnalyzer::Impl
return Intersect(a, b);
}
+ /*!
+ * \brief Process the divisor by making assumption that divide by zero
+ * won't happen in a valid program.
+ *
+ * This is important for us to get a lot of symbolic shape bound right
+ * now that the shape n >= 0, but in cases
+ * when mod or divide of n occur, the intention is actually n > 0
+ *
+ * \param divisor The input divsor entry
+ * \return The processed entry
+ */
+ Entry AssumeNoZeroDivisor(Entry divisor) {
+ ICHECK(!divisor.is_const(0)) << "Find divide by zero";
+ // NOTE: here we make the assumption that
+ // divide by zero won't happen in a valid program
+ // this is important for us to get a lot of symbolic shape bound right
+ // where most conditions know that the shape n >= 0, but in cases
+ // when mod or divide of n occur, the intention is actually n > 0
+ if (divisor.min_value == 0) {
+ divisor.min_value = 1;
+ ICHECK_GE(divisor.max_value, 1);
+ }
+ return divisor;
+ }
+
Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value,
op->value); }
Entry VisitExpr_(const AddNode* op) final {
@@ -223,14 +248,14 @@ class ConstIntBoundAnalyzer::Impl
Entry VisitExpr_(const DivNode* op) final {
Entry a = VisitExpr(op->a);
- Entry b = VisitExpr(op->b);
- ICHECK(!b.is_const(0)) << "divide by zero";
+ Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));
return HandleDivision(a, b, op->dtype, InfAwareDiv);
}
Entry VisitExpr_(const ModNode* op) final {
Entry a = VisitExpr(op->a);
- Entry b = VisitExpr(op->b);
+ Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));
+
if (b.min_value > 0) {
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
if (a.min_value >= 0) {
@@ -252,8 +277,7 @@ class ConstIntBoundAnalyzer::Impl
Entry VisitExpr_(const FloorDivNode* op) final {
Entry a = VisitExpr(op->a);
- Entry b = VisitExpr(op->b);
- ICHECK(!b.is_const(0)) << "floordiv by zero";
+ Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));
return HandleDivision(a, b, op->dtype, InfAwareFloorDiv);
}
@@ -276,7 +300,8 @@ class ConstIntBoundAnalyzer::Impl
* That is, min(0, b_min + 1) <= floormod(a, b) <= max(0, b_max - 1)
*/
Entry a = VisitExpr(op->a);
- Entry b = VisitExpr(op->b);
+ Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));
+
if (b.min_value > 0) {
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
if (a.min_value >= 0) {
@@ -457,7 +482,6 @@ class ConstIntBoundAnalyzer::Impl
// at a negative value and ends at a positive one, narrow it down to
// be closer to 0, because BinaryOpBoundary only checks end-points of
// the domain ranges.
-
// If the range of b contains 0, then some infinity will be involved
if (b.min_value <= 0 && 0 <= b.max_value && dt.is_int()) {
Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) :
Everything(dt);
diff --git a/src/arith/ir_mutator_with_analyzer.cc
b/src/arith/ir_mutator_with_analyzer.cc
index c201a245e1..1f087d9934 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -30,6 +30,15 @@ namespace arith {
using namespace tir;
+void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) {
+ // Mark the all the symbolic buffer shape values in the buffer map as
positive value.
+ for (auto kv : func->buffer_map) {
+ for (PrimExpr shape : kv.second->shape) {
+ analyzer_->MarkGlobalNonNegValue(shape);
+ }
+ }
+}
+
Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
// record the loop variable as iterators
Range dom = Range::FromMinExtent(op->min, op->extent);
diff --git a/src/arith/ir_mutator_with_analyzer.h
b/src/arith/ir_mutator_with_analyzer.h
index ed62c91df9..f04b40e7ae 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -62,6 +62,14 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
PrimExpr VisitExpr_(const tir::ReduceNode* op) override;
protected:
+ /*!
+ * \brief Mark the all the buffer shape values in the buffer map as positive
value.
+ *
+ * \note call this function before Visit function's body to maximize
+ * simplification efficiency
+ */
+ void MarkBufferMapShapes(const tir::PrimFunc& func);
+
/*! \brief internal analyzer field. */
Analyzer* analyzer_;
// the following two fields are useful in case we want
diff --git a/src/tir/transforms/flatten_buffer.cc
b/src/tir/transforms/flatten_buffer.cc
index ffdff45a7d..f37c21593f 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -42,6 +42,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
arith::Analyzer ana;
auto pass = BufferFlattener(&ana);
auto writer = func.CopyOnWrite();
+ pass.MarkBufferMapShapes(func);
writer->body = pass.VisitStmt(func->body);
// The buffers in func->buffer_map are deliberately left
// unflattened, as they are used for validation of user-provided
diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc
index 130cbe37c1..44d64df63d 100644
--- a/src/tir/transforms/simplify.cc
+++ b/src/tir/transforms/simplify.cc
@@ -142,20 +142,24 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify",
SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
- static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional<SimplifyConfig>
config_opt = NullOpt) {
+ static PrimFunc Apply(PrimFunc func, Analyzer* analyzer,
+ Optional<SimplifyConfig> config_opt = NullOpt) {
auto config =
config_opt.value_or(AttrsWithDefaultValues<arith::SimplifyConfig>());
analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());
std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
if (config->propagate_knowns_to_prove_conditional ||
config->propagate_knowns_to_simplify_expressions) {
- touch_pattern = ControlFlowGraph(stmt);
+ touch_pattern = ControlFlowGraph(func->body);
}
- std::unordered_set<const VarNode*> used_in_buffer_def =
CollectVarsUsedInBufferDefinition(stmt);
+ std::unordered_set<const VarNode*> used_in_buffer_def =
+ CollectVarsUsedInBufferDefinition(func->body);
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
std::move(used_in_buffer_def));
- return simplifier(std::move(stmt));
+ simplifier.MarkBufferMapShapes(func);
+ func.CopyOnWrite()->body = simplifier(func->body);
+ return func;
}
private:
@@ -335,11 +339,6 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
} // namespace arith
namespace tir {
-
-Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) {
- return arith::StmtSimplifier::Apply(stmt, analyzer);
-}
-
namespace transform {
Pass Simplify() {
@@ -347,9 +346,7 @@ Pass Simplify() {
arith::Analyzer analyzer;
auto cfg = ctx->GetConfig<arith::SimplifyConfig>("tir.Simplify");
- auto* n = f.CopyOnWrite();
- n->body = arith::StmtSimplifier::Apply(std::move(n->body), &analyzer, cfg);
- return f;
+ return arith::StmtSimplifier::Apply(f, &analyzer, cfg);
};
return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
}
diff --git a/tests/python/unittest/test_arith_const_int_bound.py
b/tests/python/unittest/test_arith_const_int_bound.py
index d9ea36206b..5667c79aac 100644
--- a/tests/python/unittest/test_arith_const_int_bound.py
+++ b/tests/python/unittest/test_arith_const_int_bound.py
@@ -339,6 +339,23 @@ def test_floormod_negative_divisor():
assert bd.max_value == 6
+def test_divmod_assume_no_zero_divsor():
+ # Divmod non negative expression makes assumption that divide by zero
won't occur
+ # this assumption is important to get best result from symbolic shape
programs
+ analyzer = tvm.arith.Analyzer()
+ flm, fld = tvm.te.floormod, tvm.te.floordiv
+ a, b = te.var("a"), te.var("b")
+ analyzer.update(a, tvm.arith.ConstIntBound(0, 6))
+ analyzer.update(b, tvm.arith.ConstIntBound(0,
tvm.arith.ConstIntBound.POS_INF))
+ bd = analyzer.const_int_bound(fld(a, b))
+ assert bd.min_value == 0
+ assert bd.max_value == 6
+
+ bd = analyzer.const_int_bound(flm(a, b))
+ assert bd.min_value == 0
+ assert bd.max_value == 6
+
+
def test_multiple_condition():
analyzer = tvm.arith.Analyzer()
flm, fld = tvm.te.floormod, tvm.te.floordiv
diff --git a/tests/python/unittest/test_tir_transform_simplify.py
b/tests/python/unittest/test_tir_transform_simplify.py
index 1f25405ec9..79fd5e1434 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -1733,5 +1733,21 @@ class TestSimplifyTrivialLetStride(BaseBeforeAfter):
expected = before
+class TestBufferShapeConstraint(BaseBeforeAfter):
+ """If enabled, rewrite boolean expressions into AND of OR"""
+
+ convert_boolean_to_and_of_ors = True
+
+ def before(a: T.handle):
+ n = T.int64()
+ A = T.match_buffer(a, (n * 32,), "float32")
+ A[T.min(T.int64(0), n)] = T.float32(0)
+
+ def expected(a: T.handle):
+ n = T.int64()
+ A = T.match_buffer(a, (n * 32,), "float32")
+ A[T.int64(0)] = T.float32(0)
+
+
if __name__ == "__main__":
tvm.testing.main()