This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9815ae2 [Arith] Simplify cast (#7045)
9815ae2 is described below
commit 9815ae2d9e17eece1a1009eb6436c80f931c734e
Author: Haozheng Fan <[email protected]>
AuthorDate: Fri Jan 8 00:24:31 2021 +0800
[Arith] Simplify cast (#7045)
---
src/arith/canonical_simplify.cc | 161 +++++++++++++++++++++
.../unittest/test_arith_canonical_simplify.py | 41 ++++++
2 files changed, 202 insertions(+)
diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index d0a0702..ba54995 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -78,6 +78,27 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode
mode) {
}
/*!
+ * \brief Check whether casting value to dtype is considered safe (index dtypes only)
+ * \param dtype The target dtype
+ * \param value The value to be analyzed
+ * \param analyzer The analyzer used to query the constant integer bound of value
+ * \return true if value is no wider than dtype or its bound fits in dtype; false otherwise
+ */
+bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) {
+ if (!IsIndexType(dtype)) { // a non-index target dtype is never reported as safe
+ return false;
+ }
+ ConstIntBound bound = analyzer->const_int_bound(value);
+ int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
+ int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
+ if (value.dtype().bits() <= dtype.bits() || // same-or-wider target is treated as a safe upcast
+ (bound->max_value <= ubound && bound->min_value >= lbound)) {
+ return true;
+ }
+ return false;
+}
+
+/*!
* \brief Internal "Split normal form" of expression.
*
* This is a special expression that represents
@@ -128,6 +149,58 @@ class SplitExprNode : public CanonicalExprNode {
void MulToSelf(int64_t scale) { this->scale *= scale; }
+ /*!
+ * \brief Check whether cast(dtype, self) can be replaced by casting only the index
+ * \param dtype The target datatype
+ * \param analyzer The analyzer forwarded to CastIsSafe for each intermediate result
+ * \return whether the cast can be safely pushed to children
+ */
+ bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
+ // cast(dtype, index % upper_factor / lower_factor * scale) ==
+ // cast(dtype, index) % upper_factor / lower_factor * scale
+ // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
+ // its intermediate results (index, then mod, then div, then mul) pass CastIsSafe
+ if (dtype.bits() >= this->dtype.bits()) {
+ return true; // upcast is safe
+ }
+ PrimExpr res = this->index;
+ if (this->scale == 0) { // a zero scale is accepted without checking index
+ return true;
+ }
+ if (!CastIsSafe(dtype, res, analyzer)) {
+ return false;
+ }
+ if (this->upper_factor != SplitExprNode::kPosInf) { // a modulo step is present
+ res = ModImpl(res, make_const(this->dtype, this->upper_factor),
div_mode);
+ if (!CastIsSafe(dtype, res, analyzer)) {
+ return false;
+ }
+ }
+ if (this->lower_factor != 1) { // a division step is present
+ res = DivImpl(res, make_const(this->dtype, this->lower_factor),
div_mode);
+ if (!CastIsSafe(dtype, res, analyzer)) {
+ return false;
+ }
+ }
+ if (this->scale != 1) { // a multiplication step is present
+ ICHECK(!this->dtype.is_uint() || this->scale > 0); // unsigned expressions require a positive scale
+ res = res * make_const(this->dtype, this->scale);
+ if (!CastIsSafe(dtype, res, analyzer)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /*!
+ * \brief self = cast(dtype, self), done by casting index and updating this node's dtype
+ * \param dtype The target datatype
+ */
+ void PushCastToChildren(DataType dtype) {
+ this->index = cast(dtype, this->index); // upper_factor, lower_factor and scale are not modified
+ this->dtype = dtype;
+ }
+
inline bool IndexEqual(const SplitExpr& other) const;
inline bool DivModeCompatibleTo(DivMode mode) const;
@@ -255,6 +328,69 @@ class SumExprNode : public CanonicalExprNode {
void AddToSelf(const SumExpr& other, int64_t scale);
+ /*!
+ * \brief Check whether cast(dtype, self) can be distributed over every term of the sum
+ * \param dtype The target datatype
+ * \param analyzer The analyzer forwarded to CastIsSafe and to each term's own check
+ * \return whether the cast can be safely pushed to children
+ */
+ bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
+ // cast(dtype, arg_1 + arg_2 + ... arg_n) ==
+ // cast(dtype, arg_1) + ... + cast(dtype, arg_n)
+ // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
+ // its intermediate results (every partial sum built below) pass CastIsSafe
+ if (dtype.bits() >= this->dtype.bits()) {
+ return true; // upcast is safe
+ }
+ PrimExpr res = make_const(dtype, 0); // running partial sum, checked after every step
+ for (size_t i = 0; i < args.size(); ++i) {
+ if (args[i]->scale > 0) { // terms with positive scale are added first
+ res = res + args[i]->Normalize();
+ if (!CastIsSafe(dtype, res, analyzer)) {
+ return false;
+ }
+ }
+ }
+ if (base > 0) {
+ res = res + make_const(dtype, base);
+ if (!CastIsSafe(dtype, res, analyzer)) {
+ return false;
+ }
+ }
+ // terms with a negative scale follow, applied via subtraction.
+ for (size_t i = 0; i < args.size(); ++i) {
+ if (args[i]->scale < 0) {
+ res = res - args[i]->NormalizeWithScale(-1);
+ if (!CastIsSafe(dtype, res, analyzer)) {
+ return false;
+ }
+ }
+ }
+ if (base < 0) {
+ res = res - make_const(dtype, -base); // NOTE(review): -base would overflow if base were the most negative value of its type - confirm
+ if (!CastIsSafe(dtype, res, analyzer)) {
+ return false;
+ }
+ }
+ for (const auto& arg : args) { // finally, every term must itself accept the cast
+ if (!arg->CanPushCastToChildren(dtype, analyzer)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /*!
+ * \brief self = cast(dtype, self), done by pushing the cast into every term of args
+ * \param dtype The target datatype
+ */
+ void PushCastToChildren(DataType dtype) {
+ for (auto& arg : args) {
+ arg.CopyOnWrite()->PushCastToChildren(dtype); // base is not modified here
+ }
+ this->dtype = dtype;
+ }
+
static constexpr const char* _type_key = "arith.SumExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);
@@ -430,6 +566,7 @@ class CanonicalSimplifier::Impl : public
RewriteSimplifier::Impl {
PrimExpr VisitExpr_(const FloorDivNode* op) final;
PrimExpr VisitExpr_(const FloorModNode* op) final;
PrimExpr VisitExpr_(const ReduceNode* op) final;
+ PrimExpr VisitExpr_(const CastNode* op) final;
private:
/*!
@@ -1071,6 +1208,30 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const
ReduceNode* op) {
return ret;
}
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { // pushes index-type casts into canonical sum/split forms when allowed
+ if (!IsIndexType(op->dtype)) { // casts to non-index dtypes are left to Rewriter::VisitExpr_
+ return Rewriter::VisitExpr_(op);
+ }
+ // canonicalize the operand first so it may become a SumExpr or SplitExpr
+ PrimExpr value = this->CanonicalMutate(op->value);
+ // push the cast into the canonical form if CanPushCastToChildren allows it
+ if (value.as<SumExprNode>()) {
+ SumExpr se = Downcast<SumExpr>(value);
+ if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+ se.CopyOnWrite()->PushCastToChildren(op->dtype);
+ return std::move(se);
+ }
+ }
+ if (value.as<SplitExprNode>()) {
+ SplitExpr se = Downcast<SplitExpr>(value);
+ if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+ se.CopyOnWrite()->PushCastToChildren(op->dtype);
+ return std::move(se);
+ }
+ }
+ return Rewriter::VisitExpr_(op); // the cast could not be pushed: fall back on the original node
+}
+
PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
return impl_->CanonicalSimplify(expr);
}
diff --git a/tests/python/unittest/test_arith_canonical_simplify.py
b/tests/python/unittest/test_arith_canonical_simplify.py
index 65c8ec3..c241b81 100644
--- a/tests/python/unittest/test_arith_canonical_simplify.py
+++ b/tests/python/unittest/test_arith_canonical_simplify.py
@@ -310,6 +310,46 @@ def test_complex_cases():
ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4))
+def test_simplify_cast():
+ ck = CanonicalChecker()
+ tcast = tvm.tir.Cast
+ fld = tvm.te.floordiv  # not referenced in this test
+ flm = tvm.te.floormod
+ # upcast int32 -> int64: cast(i64, i + j + 1) - cast(i64, i) should become cast(i64, j) + 1
+ i = te.var("i", dtype="int32")
+ j = te.var("j", dtype="int32")
+ res = tcast("int64", i + j + 1) - tcast("int64", i)
+ ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64"))
+ # downcast int64 -> int32 with i, j bounded to [0, 10]: cast(i32, i + j + 1) - cast(i32, i)
+ i = te.var("i", dtype="int64")
+ j = te.var("j", dtype="int64")
+ ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))
+ ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
+ res = tcast("int32", i + j + 1) - tcast("int32", i)
+ ck.verify(res, tcast("int32", j) + 1)
+ # cast(i32, i + j - 100): i can reach 2 ** 31 - 1, so i + j may leave int32; expect res unchanged
+ i = te.var("i", dtype="int64")
+ j = te.var("j", dtype="int64")
+ ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2 ** 31 - 1))
+ ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
+ res = tcast("int32", i + j - 100)
+ ck.verify(res, res)
+ # cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
+ # - cast(i32, flm(axis, 7i64) * 2i64), with axis in [0, 42]; expected to fold to 2
+ axis = te.var("axis", dtype="int64")
+ ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))
+ res = (
+ tcast(
+ "int32",
+ flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64")
+ + tvm.tir.const(1, "int64"),
+ )
+ + tvm.tir.const(1, "int32")
+ - tcast("int32", flm(axis, tvm.tir.const(7, "int64")) *
tvm.tir.const(2, "int64"))
+ )
+ ck.verify(res, 2)
+
+
if __name__ == "__main__":
test_floormod_simplify()
test_mul_sum_simplify()
@@ -321,3 +361,4 @@ if __name__ == "__main__":
test_split_index_simplify()
test_canonical_mixed()
test_complex_cases()
+ test_simplify_cast()