This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 397cf8781e [Arith][Refactor] Return Optional<PrimExpr> from TryConstFold (#12784)
397cf8781e is described below
commit 397cf8781eba7a2bcc35e832130801c1d1419c43
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Sep 15 06:39:20 2022 -0500
[Arith][Refactor] Return Optional<PrimExpr> from TryConstFold (#12784)
Prior to this commit, the templated `TryConstFold` utility returned an
undefined `PrimExpr` to represent a failure to perform constant
folding. This commit makes this explicit by returning
`Optional<PrimExpr>` instead.
---
src/arith/canonical_simplify.cc | 21 ++++------
src/arith/const_fold.h | 91 +++++++++++++++++++++--------------------
src/arith/int_set.cc | 10 +++--
src/arith/iter_affine_map.cc | 15 +++----
src/arith/pattern_match.h | 3 +-
src/arith/rewrite_simplify.cc | 42 +++++++------------
src/tir/op/op.cc | 57 +++++++++-----------------
7 files changed, 99 insertions(+), 140 deletions(-)
diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index 9f45317cba..f5d2667aa6 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -716,8 +716,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<Add>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Add>(a, b)) return const_res.value();
// canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a));
@@ -741,8 +740,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<Sub>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Sub>(a, b)) return const_res.value();
// canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a));
@@ -766,8 +764,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<Mul>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Mul>(a, b)) return const_res.value();
// x * c
if (a.as<IntImmNode>()) {
@@ -870,8 +867,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<Div>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Div>(a, b)) return const_res.value();
PVar<IntImm> c1;
// x / c1
if (c1.Match(b) && c1.Eval()->value > 0) {
@@ -928,8 +924,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<FloorDiv>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<FloorDiv>(a, b)) return const_res.value();
PVar<IntImm> c1;
// x / c1
if (c1.Match(b) && c1.Eval()->value > 0) {
@@ -1037,8 +1032,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<Mod>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Mod>(a, b)) return const_res.value();
PVar<IntImm> c1;
// x % c1
@@ -1105,8 +1099,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<FloorMod>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<FloorMod>(a, b)) return const_res.value();
PVar<IntImm> c1;
// x % c1
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index d0e09a1a74..a7466cf38c 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -24,6 +24,7 @@
#ifndef TVM_ARITH_CONST_FOLD_H_
#define TVM_ARITH_CONST_FOLD_H_
+#include <tvm/runtime/container/optional.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
@@ -44,10 +45,10 @@ namespace arith {
* \tparam Op The operator type.
*
* \note a and b Must already matched data types with each other.
- * \return nullptr if constant fold fails, otherwise return folded result.
+ * \return NullOpt if constant fold fails, otherwise return folded result.
*/
template <typename Op>
-inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b);
+inline Optional<PrimExpr> TryConstFold(PrimExpr a, PrimExpr b);
/*!
* \brief Try to run unary compute with constant folding.
@@ -56,10 +57,10 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b);
* \tparam Op The operator type.
*
* \note a and b Must already matched data types with each other.
- * \return nullptr if constant fold fails, otherwise return folded result.
+ * \return NullOpt if constant fold fails, otherwise return folded result.
*/
template <typename Op>
-inline PrimExpr TryConstFold(PrimExpr a);
+inline Optional<PrimExpr> TryConstFold(PrimExpr a);
/*!
* \brief Check whether type is used to represent index.
@@ -126,7 +127,7 @@ inline double GetFoldResultDoubleRepr(float x) {
// specialization of constant folders.
template <>
-inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
@@ -142,17 +143,17 @@ inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value + fb->value);
} else {
- return PrimExpr();
+ return NullOpt;
}
}
if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a;
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) &&
(pb && pb->dtype.is_uint() && pb->value > 0U)))
@@ -171,16 +172,16 @@ inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value - fb->value);
} else {
- return PrimExpr();
+ return NullOpt;
}
}
if (fb && fb->value == 0) return a;
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
@@ -202,7 +203,7 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value * fb->value);
} else {
- return PrimExpr();
+ return NullOpt;
}
}
if (fa) {
@@ -214,11 +215,11 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
if (fb->value == 0) return b;
}
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
@@ -242,7 +243,7 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value / fb->value);
} else {
- return PrimExpr();
+ return NullOpt;
}
}
if (fa && fa->value == 0) return a;
@@ -251,11 +252,11 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
ICHECK_NE(fb->value, 0) << "Divide by zero";
}
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
@@ -271,11 +272,11 @@ inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
@@ -297,7 +298,7 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
} else if (rtype.bits() == 64) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
} else {
- return PrimExpr();
+ return NullOpt;
}
}
if (fa && fa->value == 0) return a;
@@ -306,11 +307,11 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
ICHECK_NE(fb->value, 0) << "Divide by zero";
}
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
@@ -326,114 +327,114 @@ inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::Min>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Min>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value));
});
if (a.same_as(b)) return a;
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value));
});
if (a.same_as(b)) return a;
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
});
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::And>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::And>(PrimExpr a, PrimExpr b) {
const IntImmNode* pa = a.as<IntImmNode>();
const IntImmNode* pb = b.as<IntImmNode>();
if (pa && pa->value) return b;
if (pa && !pa->value) return a;
if (pb && pb->value) return a;
if (pb && !pb->value) return b;
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::Or>(PrimExpr a, PrimExpr b) {
+inline Optional<PrimExpr> TryConstFold<tir::Or>(PrimExpr a, PrimExpr b) {
const IntImmNode* pa = a.as<IntImmNode>();
const IntImmNode* pb = b.as<IntImmNode>();
if (pa && pa->value) return a;
if (pa && !pa->value) return b;
if (pb && pb->value) return b;
if (pb && !pb->value) return a;
- return PrimExpr();
+ return NullOpt;
}
template <>
-inline PrimExpr TryConstFold<tir::Not>(PrimExpr a) {
+inline Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) {
const IntImmNode* pa = a.as<IntImmNode>();
if (pa) {
return IntImm(DataType::UInt(1), !(pa->value));
}
- return PrimExpr();
+ return NullOpt;
}
/*! \brief Helper namespace for symbolic value limits */
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index e8e223ceca..35b12bb352 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -108,9 +108,13 @@ TVM_DECLARE_LOGICAL_OP(Not);
template <typename Op>
inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b,
DataType dtype) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
- PrimExpr res = TryConstFold<Op>(a->min_value, b->min_value);
- if (!res.defined()) res = Op(a->min_value, b->min_value);
- return IntervalSet::SinglePoint(res);
+ PrimExpr expr;
+ if (auto res = TryConstFold<Op>(a->min_value, b->min_value)) {
+ expr = res.value();
+ } else {
+ expr = Op(a->min_value, b->min_value);
+ }
+ return IntervalSet::SinglePoint(expr);
}
if (is_logical_op<Op>::value) {
return IntervalSet(make_const(dtype, 0), make_const(dtype, 1));
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 83e2821c98..182eada24d 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1205,8 +1205,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) {
PrimExpr b = this->DirectMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<Add>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Add>(a, b)) return const_res.value();
// does not contain iter map.
if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
if (op->a.same_as(a) && op->b.same_as(b)) {
@@ -1240,8 +1239,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) {
PrimExpr b = this->DirectMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<Sub>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Sub>(a, b)) return const_res.value();
// does not contain iter map.
if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
@@ -1276,8 +1274,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
PrimExpr b = this->DirectMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<Mul>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Mul>(a, b)) return const_res.value();
// does not contain iter map.
if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
@@ -1572,8 +1569,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
PrimExpr b = this->DirectMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<FloorDiv>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<FloorDiv>(a, b)) return const_res.value();
// does not contain iter map.
if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
@@ -1657,8 +1653,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
PrimExpr b = this->DirectMutate(op->b);
// const folding
- PrimExpr const_res = TryConstFold<FloorMod>(a, b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<FloorMod>(a, b)) return const_res.value();
// does not contain iter map.
if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h
index 6abcc728fc..69f064e119 100644
--- a/src/arith/pattern_match.h
+++ b/src/arith/pattern_match.h
@@ -330,8 +330,7 @@ class PBinaryExpr : public Pattern<PBinaryExpr<OpType, TA, TB>> {
PrimExpr Eval() const {
PrimExpr lhs = a_.Eval();
PrimExpr rhs = b_.Eval();
- PrimExpr ret = TryConstFold<OpType>(lhs, rhs);
- if (ret.defined()) return ret;
+ if (auto ret = TryConstFold<OpType>(lhs, rhs)) return ret.value();
return OpType(lhs, rhs);
}
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index d7866fc130..e3e9db62d0 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -124,8 +124,7 @@ void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<AddNode>();
- PrimExpr const_res = TryConstFold<Add>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Add>(op->a, op->b)) return const_res.value();
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
@@ -258,8 +257,7 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<SubNode>();
- PrimExpr const_res = TryConstFold<Sub>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Sub>(op->a, op->b)) return const_res.value();
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
@@ -450,8 +448,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<MulNode>();
- PrimExpr const_res = TryConstFold<Mul>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Mul>(op->a, op->b)) return const_res.value();
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
@@ -490,8 +487,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<DivNode>();
- PrimExpr const_res = TryConstFold<Div>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Div>(op->a, op->b)) return const_res.value();
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
@@ -666,8 +662,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<ModNode>();
- PrimExpr const_res = TryConstFold<Mod>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Mod>(op->a, op->b)) return const_res.value();
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
@@ -748,8 +743,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDivNode>();
- PrimExpr const_res = TryConstFold<FloorDiv>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<FloorDiv>(op->a, op->b)) return const_res.value();
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
@@ -895,8 +889,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorModNode>();
- PrimExpr const_res = TryConstFold<FloorMod>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<FloorMod>(op->a, op->b)) return const_res.value();
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
@@ -977,8 +970,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<MinNode>();
- PrimExpr const_res = TryConstFold<Min>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Min>(op->a, op->b)) return const_res.value();
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
@@ -1149,8 +1141,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<MaxNode>();
- PrimExpr const_res = TryConstFold<Max>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Max>(op->a, op->b)) return const_res.value();
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
@@ -1327,8 +1318,7 @@ Optional<PrimExpr> RewriteSimplifier::Impl::TryMatchLiteralConstraint(const Prim
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<EQNode>();
- PrimExpr const_res = TryConstFold<EQ>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<EQ>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
// Pattern var to match any expression
@@ -1376,8 +1366,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<LTNode>();
- PrimExpr const_res = TryConstFold<LT>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<LT>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
// Pattern var to match any expression
@@ -1508,8 +1497,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<NotNode>();
- PrimExpr const_res = TryConstFold<Not>(op->a);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Not>(op->a)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
// Pattern var to match any expression
@@ -1534,8 +1522,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<AndNode>();
- PrimExpr const_res = TryConstFold<And>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<And>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
// Pattern var to match any expression
@@ -1574,8 +1561,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<OrNode>();
- PrimExpr const_res = TryConstFold<Or>(op->a, op->b);
- if (const_res.defined()) return const_res;
+ if (auto const_res = TryConstFold<Or>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
// Pattern var to match any expression
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index b9e0c3c370..509badbebb 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -327,8 +327,7 @@ PrimExpr operator+(PrimExpr a, PrimExpr b) { return add(a, b); }
PrimExpr add(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::Add>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::Add>(a, b)) return ret.value();
return tir::Add(a, b, span);
}
@@ -349,23 +348,20 @@ PrimExpr operator-(PrimExpr a, PrimExpr b) { return sub(a, b); }
PrimExpr sub(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::Sub>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::Sub>(a, b)) return ret.value();
return tir::Sub(a, b, span);
}
PrimExpr operator*(PrimExpr a, PrimExpr b) { return mul(a, b); }
PrimExpr mul(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::Mul>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::Mul>(a, b)) return ret.value();
return tir::Mul(a, b, span);
}
PrimExpr div(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::Div>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::Div>(a, b)) return ret.value();
return tir::Div(a, b, span);
}
@@ -377,8 +373,7 @@ PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span) {
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::Mod>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::Mod>(a, b)) return ret.value();
return tir::Mod(a, b, span);
}
@@ -397,8 +392,7 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::FloorDiv>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::FloorDiv>(a, b)) return ret.value();
return tir::FloorDiv(a, b, span);
}
@@ -406,8 +400,7 @@ PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::FloorDiv>(a + b - 1, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::FloorDiv>(a + b - 1, b)) return ret.value();
return tir::FloorDiv(a + b - 1, b, span);
}
@@ -415,8 +408,7 @@ PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::FloorMod>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::FloorMod>(a, b)) return ret.value();
return tir::FloorMod(a, b, span);
}
@@ -429,8 +421,7 @@ PrimExpr min(PrimExpr a, PrimExpr b, Span span) {
if (is_pos_inf(b)) return a;
if (is_neg_inf(b)) return b;
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::Min>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::Min>(a, b)) return ret.value();
return tir::Min(a, b, span);
}
@@ -443,8 +434,7 @@ PrimExpr max(PrimExpr a, PrimExpr b, Span span) {
if (is_pos_inf(b)) return b;
if (is_neg_inf(b)) return a;
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::Max>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::Max>(a, b)) return ret.value();
return tir::Max(a, b, span);
}
@@ -475,48 +465,42 @@ PrimExpr likely(PrimExpr cond, Span span) {
PrimExpr operator>(PrimExpr a, PrimExpr b) { return greater(a, b); }
PrimExpr greater(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::GT>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::GT>(a, b)) return ret.value();
return tir::GT(a, b, span);
}
PrimExpr operator>=(PrimExpr a, PrimExpr b) { return greater_equal(a, b); }
PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::GE>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::GE>(a, b)) return ret.value();
return tir::GE(a, b, span);
}
PrimExpr operator<(PrimExpr a, PrimExpr b) { return less(a, b); }
PrimExpr less(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::LT>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::LT>(a, b)) return ret.value();
return tir::LT(a, b, span);
}
PrimExpr operator<=(PrimExpr a, PrimExpr b) { return less_equal(a, b); }
PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::LE>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::LE>(a, b)) return ret.value();
return tir::LE(a, b, span);
}
PrimExpr operator==(PrimExpr a, PrimExpr b) { return equal(a, b); }
PrimExpr equal(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::EQ>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::EQ>(a, b)) return ret.value();
return tir::EQ(a, b, span);
}
PrimExpr operator!=(PrimExpr a, PrimExpr b) { return not_equal(a, b); }
PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
- PrimExpr ret = arith::TryConstFold<tir::NE>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::NE>(a, b)) return ret.value();
return tir::NE(a, b, span);
}
@@ -551,24 +535,21 @@ void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const cha
PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) {
type_check_boolean_args(a, b, "&& operator (logical AND)");
- PrimExpr ret = arith::TryConstFold<tir::And>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::And>(a, b)) return ret.value();
return tir::And(a, b, span);
}
PrimExpr operator||(PrimExpr a, PrimExpr b) { return logical_or(a, b); }
PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) {
type_check_boolean_args(a, b, "|| operator (logical OR)");
- PrimExpr ret = arith::TryConstFold<tir::Or>(a, b);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::Or>(a, b)) return ret.value();
return tir::Or(a, b, span);
}
PrimExpr operator!(PrimExpr a) { return logical_not(a); }
PrimExpr logical_not(PrimExpr a, Span span) {
type_check_boolean_args(a, "! operator (logical NOT)");
- PrimExpr ret = arith::TryConstFold<tir::Not>(a);
- if (ret.defined()) return ret;
+ if (auto ret = arith::TryConstFold<tir::Not>(a)) return ret.value();
return tir::Not(a, span);
}