This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 397cf8781e [Arith][Refactor] Return Optional<PrimExpr> from TryConstFold (#12784)
397cf8781e is described below

commit 397cf8781eba7a2bcc35e832130801c1d1419c43
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Sep 15 06:39:20 2022 -0500

    [Arith][Refactor] Return Optional<PrimExpr> from TryConstFold (#12784)
    
    Prior to this commit, the templated `TryConstFold` utility returned an
    undefined `PrimExpr` to represent a failure to perform constant
    folding.  This commit makes the failure case explicit by returning
    `Optional<PrimExpr>` instead, with `NullOpt` signalling that no
    constant folding was possible.
---
 src/arith/canonical_simplify.cc | 21 ++++------
 src/arith/const_fold.h          | 91 +++++++++++++++++++++--------------------
 src/arith/int_set.cc            | 10 +++--
 src/arith/iter_affine_map.cc    | 15 +++----
 src/arith/pattern_match.h       |  3 +-
 src/arith/rewrite_simplify.cc   | 42 +++++++------------
 src/tir/op/op.cc                | 57 +++++++++-----------------
 7 files changed, 99 insertions(+), 140 deletions(-)

diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index 9f45317cba..f5d2667aa6 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -716,8 +716,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const 
AddNode* op) {
   PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<Add>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Add>(a, b)) return const_res.value();
 
   // canonical form simplification.
   SumExpr ret = ToSumExpr(std::move(a));
@@ -741,8 +740,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const 
SubNode* op) {
   PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<Sub>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Sub>(a, b)) return const_res.value();
 
   // canonical form simplification.
   SumExpr ret = ToSumExpr(std::move(a));
@@ -766,8 +764,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const 
MulNode* op) {
   PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<Mul>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Mul>(a, b)) return const_res.value();
 
   // x * c
   if (a.as<IntImmNode>()) {
@@ -870,8 +867,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const 
DivNode* op) {
   PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<Div>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Div>(a, b)) return const_res.value();
   PVar<IntImm> c1;
   // x / c1
   if (c1.Match(b) && c1.Eval()->value > 0) {
@@ -928,8 +924,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const 
FloorDivNode* op) {
   PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<FloorDiv>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<FloorDiv>(a, b)) return const_res.value();
   PVar<IntImm> c1;
   // x / c1
   if (c1.Match(b) && c1.Eval()->value > 0) {
@@ -1037,8 +1032,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const 
ModNode* op) {
   PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<Mod>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Mod>(a, b)) return const_res.value();
 
   PVar<IntImm> c1;
   // x % c1
@@ -1105,8 +1099,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const 
FloorModNode* op) {
   PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<FloorMod>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<FloorMod>(a, b)) return const_res.value();
 
   PVar<IntImm> c1;
   // x % c1
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index d0e09a1a74..a7466cf38c 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -24,6 +24,7 @@
 #ifndef TVM_ARITH_CONST_FOLD_H_
 #define TVM_ARITH_CONST_FOLD_H_
 
+#include <tvm/runtime/container/optional.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
 
@@ -44,10 +45,10 @@ namespace arith {
  * \tparam Op The operator type.
  *
  * \note a and b Must already matched data types with each other.
- * \return nullptr if constant fold fails, otherwise return folded result.
+ * \return NullOpt if constant fold fails, otherwise return folded result.
  */
 template <typename Op>
-inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b);
+inline Optional<PrimExpr> TryConstFold(PrimExpr a, PrimExpr b);
 
 /*!
  * \brief Try to run unary compute with constant folding.
@@ -56,10 +57,10 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b);
  * \tparam Op The operator type.
  *
  * \note a and b Must already matched data types with each other.
- * \return nullptr if constant fold fails, otherwise return folded result.
+ * \return NullOpt if constant fold fails, otherwise return folded result.
  */
 template <typename Op>
-inline PrimExpr TryConstFold(PrimExpr a);
+inline Optional<PrimExpr> TryConstFold(PrimExpr a);
 
 /*!
  * \brief Check whether type is used to represent index.
@@ -126,7 +127,7 @@ inline double GetFoldResultDoubleRepr(float x) {
 
 // specialization of constant folders.
 template <>
-inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
     if (pa && pb) {
@@ -142,17 +143,17 @@ inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, 
PrimExpr b) {
       } else if (rtype.bits() == 64) {
         return FloatImm(rtype, fa->value + fb->value);
       } else {
-        return PrimExpr();
+        return NullOpt;
       }
     }
     if (fa && fa->value == 0) return b;
     if (fb && fb->value == 0) return a;
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) &&
              (pb && pb->dtype.is_uint() && pb->value > 0U)))
@@ -171,16 +172,16 @@ inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, 
PrimExpr b) {
       } else if (rtype.bits() == 64) {
         return FloatImm(rtype, fa->value - fb->value);
       } else {
-        return PrimExpr();
+        return NullOpt;
       }
     }
     if (fb && fb->value == 0) return a;
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
     if (pa && pb) {
@@ -202,7 +203,7 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr 
b) {
       } else if (rtype.bits() == 64) {
         return FloatImm(rtype, fa->value * fb->value);
       } else {
-        return PrimExpr();
+        return NullOpt;
       }
     }
     if (fa) {
@@ -214,11 +215,11 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, 
PrimExpr b) {
       if (fb->value == 0) return b;
     }
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
     if (pa && pb) {
@@ -242,7 +243,7 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr 
b) {
       } else if (rtype.bits() == 64) {
         return FloatImm(rtype, fa->value / fb->value);
       } else {
-        return PrimExpr();
+        return NullOpt;
       }
     }
     if (fa && fa->value == 0) return a;
@@ -251,11 +252,11 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, 
PrimExpr b) {
       ICHECK_NE(fb->value, 0) << "Divide by zero";
     }
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
     if (pa && pb) {
@@ -271,11 +272,11 @@ inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, 
PrimExpr b) {
       ICHECK_NE(pb->value, 0) << "Divide by zero";
     }
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
     if (pa && pb) {
@@ -297,7 +298,7 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, 
PrimExpr b) {
       } else if (rtype.bits() == 64) {
         return FloatImm(rtype, std::floor(fa->value / fb->value));
       } else {
-        return PrimExpr();
+        return NullOpt;
       }
     }
     if (fa && fa->value == 0) return a;
@@ -306,11 +307,11 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, 
PrimExpr b) {
       ICHECK_NE(fb->value, 0) << "Divide by zero";
     }
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
     if (pa && pb) {
@@ -326,114 +327,114 @@ inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, 
PrimExpr b) {
       ICHECK_NE(pb->value, 0) << "Divide by zero";
     }
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::Min>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Min>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
     if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
     if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value));
   });
   if (a.same_as(b)) return a;
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
     if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
     if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value));
   });
   if (a.same_as(b)) return a;
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
     if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
     if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
     if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
     if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
     if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
     if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
   });
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::And>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::And>(PrimExpr a, PrimExpr b) {
   const IntImmNode* pa = a.as<IntImmNode>();
   const IntImmNode* pb = b.as<IntImmNode>();
   if (pa && pa->value) return b;
   if (pa && !pa->value) return a;
   if (pb && pb->value) return a;
   if (pb && !pb->value) return b;
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::Or>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Or>(PrimExpr a, PrimExpr b) {
   const IntImmNode* pa = a.as<IntImmNode>();
   const IntImmNode* pb = b.as<IntImmNode>();
   if (pa && pa->value) return a;
   if (pa && !pa->value) return b;
   if (pb && pb->value) return b;
   if (pb && !pb->value) return a;
-  return PrimExpr();
+  return NullOpt;
 }
 
 template <>
-inline PrimExpr TryConstFold<tir::Not>(PrimExpr a) {
+inline Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) {
   const IntImmNode* pa = a.as<IntImmNode>();
   if (pa) {
     return IntImm(DataType::UInt(1), !(pa->value));
   }
-  return PrimExpr();
+  return NullOpt;
 }
 
 /*! \brief Helper namespace for symbolic value limits */
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index e8e223ceca..35b12bb352 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -108,9 +108,13 @@ TVM_DECLARE_LOGICAL_OP(Not);
 template <typename Op>
 inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, 
DataType dtype) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
-    PrimExpr res = TryConstFold<Op>(a->min_value, b->min_value);
-    if (!res.defined()) res = Op(a->min_value, b->min_value);
-    return IntervalSet::SinglePoint(res);
+    PrimExpr expr;
+    if (auto res = TryConstFold<Op>(a->min_value, b->min_value)) {
+      expr = res.value();
+    } else {
+      expr = Op(a->min_value, b->min_value);
+    }
+    return IntervalSet::SinglePoint(expr);
   }
   if (is_logical_op<Op>::value) {
     return IntervalSet(make_const(dtype, 0), make_const(dtype, 1));
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 83e2821c98..182eada24d 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1205,8 +1205,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) {
   PrimExpr b = this->DirectMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<Add>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Add>(a, b)) return const_res.value();
   // does not contain iter map.
   if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
     if (op->a.same_as(a) && op->b.same_as(b)) {
@@ -1240,8 +1239,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) {
   PrimExpr b = this->DirectMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<Sub>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Sub>(a, b)) return const_res.value();
 
   // does not contain iter map.
   if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
@@ -1276,8 +1274,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
   PrimExpr b = this->DirectMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<Mul>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Mul>(a, b)) return const_res.value();
 
   // does not contain iter map.
   if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
@@ -1572,8 +1569,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* 
op) {
   PrimExpr b = this->DirectMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<FloorDiv>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<FloorDiv>(a, b)) return const_res.value();
 
   // does not contain iter map.
   if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
@@ -1657,8 +1653,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* 
op) {
   PrimExpr b = this->DirectMutate(op->b);
 
   // const folding
-  PrimExpr const_res = TryConstFold<FloorMod>(a, b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<FloorMod>(a, b)) return const_res.value();
 
   // does not contain iter map.
   if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h
index 6abcc728fc..69f064e119 100644
--- a/src/arith/pattern_match.h
+++ b/src/arith/pattern_match.h
@@ -330,8 +330,7 @@ class PBinaryExpr : public Pattern<PBinaryExpr<OpType, TA, 
TB>> {
   PrimExpr Eval() const {
     PrimExpr lhs = a_.Eval();
     PrimExpr rhs = b_.Eval();
-    PrimExpr ret = TryConstFold<OpType>(lhs, rhs);
-    if (ret.defined()) return ret;
+    if (auto ret = TryConstFold<OpType>(lhs, rhs)) return ret.value();
     return OpType(lhs, rhs);
   }
 
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index d7866fc130..e3e9db62d0 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -124,8 +124,7 @@ void RewriteSimplifier::Impl::Update(const Var& var, const 
PrimExpr& info, bool
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<AddNode>();
-  PrimExpr const_res = TryConstFold<Add>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Add>(op->a, op->b)) return 
const_res.value();
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
   // Pattern var match IntImm
@@ -258,8 +257,7 @@ std::function<void()> 
RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<SubNode>();
-  PrimExpr const_res = TryConstFold<Sub>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Sub>(op->a, op->b)) return 
const_res.value();
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
   // Pattern var match IntImm
@@ -450,8 +448,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* 
op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<MulNode>();
-  PrimExpr const_res = TryConstFold<Mul>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Mul>(op->a, op->b)) return 
const_res.value();
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
   // Pattern var match IntImm
@@ -490,8 +487,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* 
op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<DivNode>();
-  PrimExpr const_res = TryConstFold<Div>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Div>(op->a, op->b)) return 
const_res.value();
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1;
   // Pattern var match IntImm
@@ -666,8 +662,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* 
op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<ModNode>();
-  PrimExpr const_res = TryConstFold<Mod>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Mod>(op->a, op->b)) return 
const_res.value();
 
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1;
@@ -748,8 +743,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* 
op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<FloorDivNode>();
-  PrimExpr const_res = TryConstFold<FloorDiv>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<FloorDiv>(op->a, op->b)) return 
const_res.value();
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1;
   // Pattern var match IntImm
@@ -895,8 +889,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
FloorDivNode* op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<FloorModNode>();
-  PrimExpr const_res = TryConstFold<FloorMod>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<FloorMod>(op->a, op->b)) return 
const_res.value();
 
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1;
@@ -977,8 +970,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
FloorModNode* op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<MinNode>();
-  PrimExpr const_res = TryConstFold<Min>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Min>(op->a, op->b)) return 
const_res.value();
 
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, s1, s2;
@@ -1149,8 +1141,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
MinNode* op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<MaxNode>();
-  PrimExpr const_res = TryConstFold<Max>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Max>(op->a, op->b)) return 
const_res.value();
 
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, s1, s2;
@@ -1327,8 +1318,7 @@ Optional<PrimExpr> 
RewriteSimplifier::Impl::TryMatchLiteralConstraint(const Prim
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<EQNode>();
-  PrimExpr const_res = TryConstFold<EQ>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<EQ>(op->a, op->b)) return 
const_res.value();
   if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
 
   // Pattern var to match any expression
@@ -1376,8 +1366,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
GENode* op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<LTNode>();
-  PrimExpr const_res = TryConstFold<LT>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<LT>(op->a, op->b)) return 
const_res.value();
   if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
 
   // Pattern var to match any expression
@@ -1508,8 +1497,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
LTNode* op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<NotNode>();
-  PrimExpr const_res = TryConstFold<Not>(op->a);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Not>(op->a)) return const_res.value();
   if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
 
   // Pattern var to match any expression
@@ -1534,8 +1522,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
NotNode* op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<AndNode>();
-  PrimExpr const_res = TryConstFold<And>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<And>(op->a, op->b)) return 
const_res.value();
   if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
 
   // Pattern var to match any expression
@@ -1574,8 +1561,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
AndNode* op) {
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
   PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<OrNode>();
-  PrimExpr const_res = TryConstFold<Or>(op->a, op->b);
-  if (const_res.defined()) return const_res;
+  if (auto const_res = TryConstFold<Or>(op->a, op->b)) return 
const_res.value();
   if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
 
   // Pattern var to match any expression
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index b9e0c3c370..509badbebb 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -327,8 +327,7 @@ PrimExpr operator+(PrimExpr a, PrimExpr b) { return add(a, 
b); }
 
 PrimExpr add(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::Add>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::Add>(a, b)) return ret.value();
   return tir::Add(a, b, span);
 }
 
@@ -349,23 +348,20 @@ PrimExpr operator-(PrimExpr a, PrimExpr b) { return 
sub(a, b); }
 
 PrimExpr sub(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::Sub>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::Sub>(a, b)) return ret.value();
   return tir::Sub(a, b, span);
 }
 
 PrimExpr operator*(PrimExpr a, PrimExpr b) { return mul(a, b); }
 PrimExpr mul(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::Mul>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::Mul>(a, b)) return ret.value();
   return tir::Mul(a, b, span);
 }
 
 PrimExpr div(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::Div>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::Div>(a, b)) return ret.value();
   return tir::Div(a, b, span);
 }
 
@@ -377,8 +373,7 @@ PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span) {
 
 PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::Mod>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::Mod>(a, b)) return ret.value();
   return tir::Mod(a, b, span);
 }
 
@@ -397,8 +392,7 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
   ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
   ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::FloorDiv>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::FloorDiv>(a, b)) return ret.value();
   return tir::FloorDiv(a, b, span);
 }
 
@@ -406,8 +400,7 @@ PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) {
   ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
   ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::FloorDiv>(a + b - 1, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::FloorDiv>(a + b - 1, b)) return 
ret.value();
   return tir::FloorDiv(a + b - 1, b, span);
 }
 
@@ -415,8 +408,7 @@ PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) {
   ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
   ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::FloorMod>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::FloorMod>(a, b)) return ret.value();
   return tir::FloorMod(a, b, span);
 }
 
@@ -429,8 +421,7 @@ PrimExpr min(PrimExpr a, PrimExpr b, Span span) {
   if (is_pos_inf(b)) return a;
   if (is_neg_inf(b)) return b;
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::Min>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::Min>(a, b)) return ret.value();
   return tir::Min(a, b, span);
 }
 
@@ -443,8 +434,7 @@ PrimExpr max(PrimExpr a, PrimExpr b, Span span) {
   if (is_pos_inf(b)) return b;
   if (is_neg_inf(b)) return a;
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::Max>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::Max>(a, b)) return ret.value();
   return tir::Max(a, b, span);
 }
 
@@ -475,48 +465,42 @@ PrimExpr likely(PrimExpr cond, Span span) {
 PrimExpr operator>(PrimExpr a, PrimExpr b) { return greater(a, b); }
 PrimExpr greater(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::GT>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::GT>(a, b)) return ret.value();
   return tir::GT(a, b, span);
 }
 
 PrimExpr operator>=(PrimExpr a, PrimExpr b) { return greater_equal(a, b); }
 PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::GE>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::GE>(a, b)) return ret.value();
   return tir::GE(a, b, span);
 }
 
 PrimExpr operator<(PrimExpr a, PrimExpr b) { return less(a, b); }
 PrimExpr less(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::LT>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::LT>(a, b)) return ret.value();
   return tir::LT(a, b, span);
 }
 
 PrimExpr operator<=(PrimExpr a, PrimExpr b) { return less_equal(a, b); }
 PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::LE>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::LE>(a, b)) return ret.value();
   return tir::LE(a, b, span);
 }
 
 PrimExpr operator==(PrimExpr a, PrimExpr b) { return equal(a, b); }
 PrimExpr equal(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::EQ>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::EQ>(a, b)) return ret.value();
   return tir::EQ(a, b, span);
 }
 
 PrimExpr operator!=(PrimExpr a, PrimExpr b) { return not_equal(a, b); }
 PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
-  PrimExpr ret = arith::TryConstFold<tir::NE>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::NE>(a, b)) return ret.value();
   return tir::NE(a, b, span);
 }
 
@@ -551,24 +535,21 @@ void type_check_integer_args(const PrimExpr& lhs, const 
PrimExpr& rhs, const cha
 PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
 PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) {
   type_check_boolean_args(a, b, "&& operator (logical AND)");
-  PrimExpr ret = arith::TryConstFold<tir::And>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::And>(a, b)) return ret.value();
   return tir::And(a, b, span);
 }
 
 PrimExpr operator||(PrimExpr a, PrimExpr b) { return logical_or(a, b); }
 PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) {
   type_check_boolean_args(a, b, "|| operator (logical OR)");
-  PrimExpr ret = arith::TryConstFold<tir::Or>(a, b);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::Or>(a, b)) return ret.value();
   return tir::Or(a, b, span);
 }
 
 PrimExpr operator!(PrimExpr a) { return logical_not(a); }
 PrimExpr logical_not(PrimExpr a, Span span) {
   type_check_boolean_args(a, "! operator (logical NOT)");
-  PrimExpr ret = arith::TryConstFold<tir::Not>(a);
-  if (ret.defined()) return ret;
+  if (auto ret = arith::TryConstFold<tir::Not>(a)) return ret.value();
   return tir::Not(a, span);
 }
 

Reply via email to