This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9815ae2  [Arith] Simplify cast (#7045)
9815ae2 is described below

commit 9815ae2d9e17eece1a1009eb6436c80f931c734e
Author: Haozheng Fan <[email protected]>
AuthorDate: Fri Jan 8 00:24:31 2021 +0800

    [Arith] Simplify cast (#7045)
---
 src/arith/canonical_simplify.cc                    | 161 +++++++++++++++++++++
 .../unittest/test_arith_canonical_simplify.py      |  41 ++++++
 2 files changed, 202 insertions(+)

diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index d0a0702..ba54995 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -78,6 +78,27 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
 }
 
 /*!
+ * \brief check if value fits in dtype
+ * \param value The value to be analyzed
+ * \param dtype The target dtype
+ * \param analyzer The analyzer
+ * \return whether value fits in dtype
+ */
+bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) {
+  if (!IsIndexType(dtype)) {
+    return false;
+  }
+  ConstIntBound bound = analyzer->const_int_bound(value);
+  int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
+  int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
+  if (value.dtype().bits() <= dtype.bits() ||  // upcast is safe
+      (bound->max_value <= ubound && bound->min_value >= lbound)) {
+    return true;
+  }
+  return false;
+}
+
+/*!
  * \brief Internal "Split normal form" of expression.
  *
  * This is a special expression that represents
@@ -128,6 +149,58 @@ class SplitExprNode : public CanonicalExprNode {
 
   void MulToSelf(int64_t scale) { this->scale *= scale; }
 
+  /*!
+   * \brief check if cast can be pushed to sub-expressions
+   * \param dtype The target datatype
+   * \param analyzer The analyzer used to bound intermediate results
+   * \return whether the cast can be safely pushed to children
+   */
+  bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
+    // cast(dtype, index % upper_factor / lower_factor * scale) ==
+    // cast(dtype, index) % upper_factor / lower_factor * scale
+    // iff it is an upcast, or every intermediate result AND every constant factor fits in
+    // dtype: the push sets this->dtype = dtype, so make_const(this->dtype, c) then narrows c.
+    if (dtype.bits() >= this->dtype.bits() || this->scale == 0) {
+      return true;  // upcast is safe; a zero scale needs no range checks
+    }
+    PrimExpr res = this->index;
+    if (!CastIsSafe(dtype, res, analyzer)) {
+      return false;
+    }
+    if (this->upper_factor != SplitExprNode::kPosInf) {
+      PrimExpr factor = make_const(this->dtype, this->upper_factor);
+      res = ModImpl(res, factor, div_mode);
+      if (!CastIsSafe(dtype, factor, analyzer) || !CastIsSafe(dtype, res, analyzer)) {
+        return false;
+      }
+    }
+    if (this->lower_factor != 1) {
+      PrimExpr factor = make_const(this->dtype, this->lower_factor);
+      res = DivImpl(res, factor, div_mode);
+      if (!CastIsSafe(dtype, factor, analyzer) || !CastIsSafe(dtype, res, analyzer)) {
+        return false;
+      }
+    }
+    if (this->scale != 1) {
+      ICHECK(!this->dtype.is_uint() || this->scale > 0);
+      PrimExpr factor = make_const(this->dtype, this->scale);
+      res = res * factor;
+      if (!CastIsSafe(dtype, factor, analyzer) || !CastIsSafe(dtype, res, analyzer)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Rewrite self in place as cast(dtype, self) by casting the index.
+   * \param dtype The target datatype, also adopted as this node's dtype
+   */
+  void PushCastToChildren(DataType dtype) {
+    this->dtype = dtype;
+    index = cast(dtype, index);
+  }
+
   inline bool IndexEqual(const SplitExpr& other) const;
   inline bool DivModeCompatibleTo(DivMode mode) const;
 
@@ -255,6 +328,69 @@ class SumExprNode : public CanonicalExprNode {
 
   void AddToSelf(const SumExpr& other, int64_t scale);
 
+  /*!
+   * \brief check if cast can be pushed to sub-expressions
+   * \param dtype The target datatype
+   * \param analyzer The analyzer used to bound the partial sums
+   * \return whether the cast can be safely pushed to children
+   */
+  bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
+    // cast(dtype, sum(args) + base) == sum(cast(dtype, args)) + base iff it is an upcast, or
+    // every partial sum (built in the source dtype, not the narrower target) and each child
+    // fits in dtype; base is kept as-is by the push, so base (and -base) must fit too.
+    if (dtype.bits() >= this->dtype.bits()) {
+      return true;  // upcast is safe
+    }
+    PrimExpr res = make_const(this->dtype, 0);
+    for (const SplitExpr& arg : args) {
+      if (arg->scale > 0) {
+        res = res + arg->Normalize();
+        if (!arg->CanPushCastToChildren(dtype, analyzer) || !CastIsSafe(dtype, res, analyzer)) {
+          return false;
+        }
+      }
+    }
+    if (base > 0) {
+      PrimExpr c = make_const(this->dtype, base);
+      res = res + c;
+      if (!CastIsSafe(dtype, c, analyzer) || !CastIsSafe(dtype, res, analyzer)) {
+        return false;
+      }
+    }
+    // negative scales follow using sub.
+    for (const SplitExpr& arg : args) {
+      if (arg->scale < 0) {
+        res = res - arg->NormalizeWithScale(-1);
+        if (!arg->CanPushCastToChildren(dtype, analyzer) || !CastIsSafe(dtype, res, analyzer)) {
+          return false;
+        }
+      }
+    }
+    if (base < 0) {
+      // base is tested first, so -base is never formed for INT64_MIN (signed overflow).
+      if (!CastIsSafe(dtype, make_const(this->dtype, base), analyzer) ||
+          !CastIsSafe(dtype, make_const(this->dtype, -base), analyzer)) {
+        return false;
+      }
+      res = res - make_const(this->dtype, -base);
+      if (!CastIsSafe(dtype, res, analyzer)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Rewrite self in place as cast(dtype, self).
+   * \param dtype The target datatype; every split term is cast, base is left unchanged
+   */
+  void PushCastToChildren(DataType dtype) {
+    this->dtype = dtype;
+    for (SplitExpr& term : args) {
+      term.CopyOnWrite()->PushCastToChildren(dtype);
+    }
+  }
+
   static constexpr const char* _type_key = "arith.SumExpr";
   TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);
 
@@ -430,6 +566,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
   PrimExpr VisitExpr_(const FloorDivNode* op) final;
   PrimExpr VisitExpr_(const FloorModNode* op) final;
   PrimExpr VisitExpr_(const ReduceNode* op) final;
+  PrimExpr VisitExpr_(const CastNode* op) final;  // may push index casts into Sum/SplitExpr
 
  private:
   /*!
@@ -1071,6 +1208,30 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
   return ret;
 }
 
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
+  // Only index-typed casts are candidates for being pushed into canonical forms.
+  if (IsIndexType(op->dtype)) {
+    // Canonicalize the operand so that its SumExpr/SplitExpr structure is visible.
+    PrimExpr value = this->CanonicalMutate(op->value);
+    if (value.as<SumExprNode>() != nullptr) {
+      SumExpr sum = Downcast<SumExpr>(value);
+      if (sum->CanPushCastToChildren(op->dtype, analyzer_)) {
+        sum.CopyOnWrite()->PushCastToChildren(op->dtype);
+        return std::move(sum);
+      }
+    }
+    if (value.as<SplitExprNode>() != nullptr) {
+      SplitExpr split = Downcast<SplitExpr>(value);
+      if (split->CanPushCastToChildren(op->dtype, analyzer_)) {
+        split.CopyOnWrite()->PushCastToChildren(op->dtype);
+        return std::move(split);
+      }
+    }
+  }
+  // Otherwise fall back to the generic rewrite of the cast.
+  return Rewriter::VisitExpr_(op);
+}
+
 PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
   return impl_->CanonicalSimplify(expr);
 }
diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py
index 65c8ec3..c241b81 100644
--- a/tests/python/unittest/test_arith_canonical_simplify.py
+++ b/tests/python/unittest/test_arith_canonical_simplify.py
@@ -310,6 +310,46 @@ def test_complex_cases():
     ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4))
 
 
+def test_simplify_cast():
+    ck = CanonicalChecker()
+    tcast = tvm.tir.Cast
+    fld = tvm.te.floordiv
+    flm = tvm.te.floormod
+    # cast(i64, i + j + 1) - cast(i64, i)
+    i = te.var("i", dtype="int32")
+    j = te.var("j", dtype="int32")
+    res = tcast("int64", i + j + 1) - tcast("int64", i)
+    ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64"))
+    # cast(i32, i + j + 1) - cast(i32, i)
+    i = te.var("i", dtype="int64")
+    j = te.var("j", dtype="int64")
+    ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))
+    ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
+    res = tcast("int32", i + j + 1) - tcast("int32", i)
+    ck.verify(res, tcast("int32", j) + 1)
+    # cast(i32, i + j - 100)
+    i = te.var("i", dtype="int64")
+    j = te.var("j", dtype="int64")
+    ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2 ** 31 - 1))
+    ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
+    res = tcast("int32", i + j - 100)
+    ck.verify(res, res)
+    # cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
+    # - cast(i32, flm(axis, 7i64) * 2i64)
+    axis = te.var("axis", dtype="int64")
+    ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))
+    res = (
+        tcast(
+            "int32",
+            flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64")
+            + tvm.tir.const(1, "int64"),
+        )
+        + tvm.tir.const(1, "int32")
+        - tcast("int32", flm(axis, tvm.tir.const(7, "int64")) * 
tvm.tir.const(2, "int64"))
+    )
+    ck.verify(res, 2)
+
+
 if __name__ == "__main__":
     test_floormod_simplify()
     test_mul_sum_simplify()
@@ -321,3 +361,4 @@ if __name__ == "__main__":
     test_split_index_simplify()
     test_canonical_mixed()
     test_complex_cases()
+    test_simplify_cast()

Reply via email to