This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 4fbfaca [TIR] Improve Let/LetStmt support. (#5949)
4fbfaca is described below
commit 4fbfaca5fc417b09330f7592b21efdc6fb1cf51b
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 28 09:22:11 2020 -0700
[TIR] Improve Let/LetStmt support. (#5949)
Let/LetStmt are useful primitives to create variable bindings.
While let bindings are harmful for simplification and integer analysis,
they are useful for other cases:
- C0: LetStmt is useful to represent a step that has side effects (e.g. calling
a PRNG).
- C1: Let expressions can be used to create deeply nested expressions for
complicated functions.
This PR improves the let support in the following ways:
- Enable vectorization support for let
- Change the let simplification strategy to inline only the most trivial cases
while leaving more complicated ones intact (to avoid deep nesting explosion)
- Enhance the arith module to handle const-int-bound and modular-set analysis for let.
The overall recommendation is to use Let only when it is
necessary (C0, C1).
---
include/tvm/arith/analyzer.h | 32 +++---
include/tvm/tir/op.h | 26 ++++-
src/arith/analyzer.cc | 24 ++---
src/arith/const_int_bound.cc | 33 ++++--
src/arith/modular_set.cc | 21 +++-
src/arith/rewrite_simplify.cc | 16 ++-
src/arith/rewrite_simplify.h | 7 ++
src/contrib/hybrid/codegen_hybrid.cc | 4 +-
src/printer/tir_text_printer.cc | 88 +++++++++-------
src/target/source/codegen_c.cc | 2 +-
src/target/source/codegen_cuda.cc | 2 +-
src/target/stackvm/codegen_stackvm.cc | 2 +-
src/tir/analysis/verify_ssa.cc | 53 ++++++----
src/tir/op/op.cc | 2 +-
src/tir/transforms/loop_partition.cc | 4 +-
src/tir/transforms/simplify.cc | 11 +-
src/tir/transforms/split_host_device.cc | 6 +-
src/tir/transforms/vectorize_loop.cc | 112 +++++++++++++++++----
.../python/unittest/test_arith_const_int_bound.py | 9 ++
tests/python/unittest/test_arith_modular_set.py | 9 ++
.../unittest/test_tir_analysis_verify_ssa.py | 10 ++
tests/python/unittest/test_tir_nodes.py | 3 +-
.../unittest/test_tir_transform_vectorize.py | 15 +++
23 files changed, 354 insertions(+), 137 deletions(-)
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 8033294..c4ee7b5 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -128,17 +128,17 @@ class ConstIntBoundAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
- * \param override Whether do we allow override of existing information.
+ * \param allow_override Whether do we allow override of existing
information.
*/
- TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool override
= false);
+ TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool
allow_override = false);
/*!
* \brief Bind variable to a range.
*
* \param var The variable.
* \param range The range we bind to.
- * \param override Whether we allow overriding an existing var's range.
+ * \param allow_override Whether we allow overriding an existing var's range.
*/
- TVM_DLL void Bind(const Var& var, const Range& range, bool override = false);
+ TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override =
false);
private:
friend class Analyzer;
@@ -217,9 +217,9 @@ class ModularSetAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
- * \param override Whether do we allow override of existing information.
+ * \param allow_override Whether do we allow override of existing
information.
*/
- TVM_DLL void Update(const Var& var, const ModularSet& info, bool override =
false);
+ TVM_DLL void Update(const Var& var, const ModularSet& info, bool
allow_override = false);
private:
friend class Analyzer;
@@ -256,9 +256,9 @@ class RewriteSimplifier {
*
* \param var The variable of interest.
* \param new_expr
- * \param override Whether do we allow override of existing information.
+ * \param allow_override Whether do we allow override of existing
information.
*/
- TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override
= false);
+ TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool
allow_override = false);
std::function<void()> EnterConstraint(const PrimExpr& constraint);
@@ -290,9 +290,9 @@ class CanonicalSimplifier {
*
* \param var The variable of interest.
* \param new_expr
- * \param override Whether do we allow override of existing information.
+ * \param allow_override Whether do we allow override of existing
information.
*/
- TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override
= false);
+ TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool
allow_override = false);
private:
friend class Analyzer;
@@ -404,9 +404,9 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param expr The expression we bind to.
- * \param override Whether we allow overriding an existing var's expression.
+ * \param allow_override Whether we allow overriding an existing var's
expression.
*/
- void Bind(const Var& var, const PrimExpr& expr, bool override = false);
+ void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
@@ -415,16 +415,16 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param range The range we bind to.
- * \param override Whether we allow overriding an existing var's expression.
+ * \param allow_override Whether we allow overriding an existing var's
expression.
*/
- void Bind(const Var& var, const Range& range, bool override = false);
+ void Bind(const Var& var, const Range& range, bool allow_override = false);
/*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
- * \param override Whether we allow overriding an existing var's expression.
+ * \param allow_override Whether we allow overriding an existing var's
expression.
*/
- void Bind(const Map<Var, Range>& variables, bool override = false);
+ void Bind(const Map<Var, Range>& variables, bool allow_override = false);
/*!
* \brief Whether can we prove expr >= val.
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 34cb52f..31ce13c 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -671,11 +671,18 @@ inline bool is_one(const PrimExpr& x) { return
is_const_int(x, 1); }
inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
/*!
- * \brief Check whether x is a constant.
+ * \brief Check whether x is an integer constant.
* \note This only return true for integer types.
* \return whether x is constant
*/
-inline bool is_const(const PrimExpr& x);
+inline bool is_const_int(const PrimExpr& x);
+
+/*!
+ * \brief Check whether x is an integer/float constant.
+ * \note Returns true for IntImm and FloatImm, and for a Broadcast of either.
+ * \return whether x is constant
+ */
+inline bool is_const_number(const PrimExpr& x);
/*!
* \brief Left fold.
@@ -699,7 +706,7 @@ inline PrimExpr foldl(FReduce freduce, PrimExpr init_value,
const Array<PrimExpr
TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
// Implementation details after this
-inline bool is_const(const PrimExpr& x) {
+inline bool is_const_int(const PrimExpr& x) {
if (x.as<tir::IntImmNode>()) {
return true;
} else if (const auto* op = x.as<tir::BroadcastNode>()) {
@@ -711,6 +718,17 @@ inline bool is_const(const PrimExpr& x) {
return false;
}
+inline bool is_const_number(const PrimExpr& x) {
+ if (x.as<tir::IntImmNode>()) {
+ return true;
+ } else if (x.as<tir::FloatImmNode>()) {
+ return true;
+ } else if (const auto* op = x.as<tir::BroadcastNode>()) {
+ return (op->value->IsInstance<tir::IntImmNode>() ||
op->value->IsInstance<tir::FloatImmNode>());
+ }
+ return false;
+}
+
inline bool is_positive_const(const PrimExpr& a) {
if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
return op->value > 0;
@@ -742,7 +760,7 @@ inline bool is_const_int(const PrimExpr& x, int64_t value) {
inline bool is_no_op(const tir::Stmt& stmt) {
if (!stmt.defined()) return true;
if (const auto* op = stmt.as<tir::EvaluateNode>()) {
- return is_const(op->value);
+ return is_const_int(op->value);
}
if (const auto* op = stmt.as<tir::SeqStmtNode>()) {
return op->seq.size() == 0;
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 037c766..c7a8365 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -35,31 +35,31 @@ Analyzer::Analyzer()
canonical_simplify(this),
int_set(this) {}
-void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
+void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override)
{
PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);
- this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
- this->modular_set.Update(var, this->modular_set(new_expr), override);
- this->rewrite_simplify.Update(var, new_expr, override);
- this->canonical_simplify.Update(var, new_expr, override);
+ this->const_int_bound.Update(var, this->const_int_bound(new_expr),
allow_override);
+ this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
+ this->rewrite_simplify.Update(var, new_expr, allow_override);
+ this->canonical_simplify.Update(var, new_expr, allow_override);
}
-void Analyzer::Bind(const Var& var, const Range& range, bool override) {
+void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
CHECK(range.defined());
if (tir::is_one(range->extent)) {
- this->Bind(var, range->min, override);
+ this->Bind(var, range->min, allow_override);
} else {
- this->const_int_bound.Bind(var, range, override);
+ this->const_int_bound.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
}
-void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
+void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
for (const auto& iter : variables) {
- this->Bind(iter.first, iter.second, override);
+ this->Bind(iter.first, iter.second, allow_override);
}
}
@@ -116,9 +116,9 @@ bool Analyzer::CanProve(const PrimExpr& expr) {
}
PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
- if (tir::is_const(expr)) return expr;
+ if (tir::is_const_int(expr)) return expr;
auto res = this->rewrite_simplify(expr);
- if (tir::is_const(res)) return res;
+ if (tir::is_const_int(res)) return res;
res = this->canonical_simplify(res);
return res;
}
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 8c90249..be830d3 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -96,17 +96,17 @@ class ConstIntBoundAnalyzer::Impl
BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {}
};
- void Bind(const Var& var, const Range& range, bool override) {
+ void Bind(const Var& var, const Range& range, bool allow_override) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Entry ret;
ret.min_value = a.min_value;
ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
- Update(var, ret, override);
+ Update(var, ret, allow_override);
}
- void Update(const Var& var, const Entry& info, bool override) {
- if (!override) {
+ void Update(const Var& var, const Entry& info, bool allow_override) {
+ if (!allow_override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
@@ -119,8 +119,21 @@ class ConstIntBoundAnalyzer::Impl
var_map_[var] = info;
}
- void Update(const Var& var, const ConstIntBound& info, bool override) {
- Update(var, MakeBound(info->min_value, info->max_value), override);
+ Entry VisitExpr_(const LetNode* op) final {
+ auto it = var_map_.find(op->var);
+ // if the var has not been binded, update the info.
+ if (it == var_map_.end()) {
+ var_map_[op->var] = this->VisitExpr(op->value);
+ Entry ret = VisitExpr(op->body);
+ var_map_.erase(op->var);
+ return ret;
+ } else {
+ return VisitExpr(op->body);
+ }
+ }
+
+ void Update(const Var& var, const ConstIntBound& info, bool allow_override) {
+ Update(var, MakeBound(info->min_value, info->max_value), allow_override);
}
// Override visitor behaviors
@@ -558,12 +571,12 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const
PrimExpr& expr, BoundMapTy
return ConstIntBound(ret.min_value, ret.max_value);
}
-void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info,
bool override) {
- impl_->Update(var, info, override);
+void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info,
bool allow_override) {
+ impl_->Update(var, info, allow_override);
}
-void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool
override) {
- impl_->Bind(var, range, override);
+void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool
allow_override) {
+ impl_->Bind(var, range, allow_override);
}
std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr&
constraint) {
diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc
index 108f08c..8c41760 100644
--- a/src/arith/modular_set.cc
+++ b/src/arith/modular_set.cc
@@ -89,8 +89,8 @@ class ModularSetAnalyzer::Impl : public
ExprFunctor<ModularSetAnalyzer::Entry(co
public:
explicit Impl(Analyzer* parent) : parent_(parent) {}
- void Update(const Var& var, const ModularSet& info, bool override) {
- if (!override) {
+ void Update(const Var& var, const ModularSet& info, bool allow_override) {
+ if (!allow_override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
@@ -118,6 +118,19 @@ class ModularSetAnalyzer::Impl : public
ExprFunctor<ModularSetAnalyzer::Entry(co
// Override visitor behaviors
Entry VisitExprDefault_(const Object* op) final { return Everything(); }
+ Entry VisitExpr_(const LetNode* op) final {
+ auto it = var_map_.find(op->var);
+ // if the var has not been binded, update the info.
+ if (it == var_map_.end()) {
+ var_map_[op->var] = this->VisitExpr(op->value);
+ Entry ret = VisitExpr(op->body);
+ var_map_.erase(op->var);
+ return ret;
+ } else {
+ return VisitExpr(op->body);
+ }
+ }
+
Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }
Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); }
@@ -315,8 +328,8 @@ ModularSet ModularSetAnalyzer::operator()(const PrimExpr&
expr) {
return ModularSet(ret.coeff, ret.base);
}
-void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool
override) {
- impl_->Update(var, info, override);
+void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool
allow_override) {
+ impl_->Update(var, info, allow_override);
}
std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr&
constraint) {
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 898eecc..e9d640a 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1519,7 +1519,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
CallNode* op) {
op = ret.as<CallNode>();
if (op == nullptr) return ret;
- if (op->op.same_as(tir::builtin::likely()) && is_const(op->args[0])) {
+ if (op->op.same_as(tir::builtin::likely()) && is_const_int(op->args[0])) {
return op->args[0];
} else if (op->op.same_as(tir::builtin::shift_right())) {
if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
@@ -1559,9 +1559,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
CastNode* op) {
return cast(op->dtype, op->value);
}
+bool RewriteSimplifier::Impl::CanInlineLet(const LetNode* op) {
+ // Only inline trivial bindings to avoid deep expression explosion
+ // when we need let to construct complicated expressions.
+ if (is_const_number(op->value)) return true;
+ if (op->value.as<VarNode>()) return true;
+ return false;
+}
+
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value);
- if (!tir::HasSideEffect(value)) {
+ if (CanInlineLet(op)) {
// it is fine to discard the let binding
// because the value will always be inlined in the simplifier.
analyzer_->Bind(op->var, value);
@@ -1587,8 +1595,8 @@ PrimExpr RewriteSimplifier::operator()(const PrimExpr&
expr) {
return res;
}
-void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool
override) {
- impl_->Update(var, info, override);
+void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool
allow_override) {
+ impl_->Update(var, info, allow_override);
}
std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr&
constraint) {
diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h
index 68c0dd2..258f833 100644
--- a/src/arith/rewrite_simplify.h
+++ b/src/arith/rewrite_simplify.h
@@ -98,6 +98,13 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer
{
*/
CompareResult TryCompare(const PrimExpr& x, int64_t val);
+ /*!
+ * \brief Internal function to check whether or not to inline let.
+ * \param op The let expr.
+ * \return The inline decision.
+ */
+ bool CanInlineLet(const LetNode* op);
+
private:
// Whether x >= val
bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
diff --git a/src/contrib/hybrid/codegen_hybrid.cc
b/src/contrib/hybrid/codegen_hybrid.cc
index b65ae91..67765f0 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -381,7 +381,7 @@ void CodeGenHybrid::VisitStmt_(const ForNode* op) {
bool is_noop(const Stmt& stmt) {
if (!stmt.defined()) return true;
- if (auto eval = stmt.as<EvaluateNode>()) return is_const(eval->value);
+ if (auto eval = stmt.as<EvaluateNode>()) return is_const_int(eval->value);
return false;
}
@@ -409,7 +409,7 @@ void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) {
}
void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
- if (is_const(op->value)) return;
+ if (is_const_int(op->value)) return;
std::string str = PrintExpr(op->value);
if (!str.empty()) stream << str << "\n";
}
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index 7ab26fa..ca038ab 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -71,12 +71,14 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) {
}
}
-Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
- const auto* op = primFunc.operator->();
+Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
+ const auto* op = prim_func.operator->();
const auto& signature = op->func_type_annotation();
// collect Meta in DictAttr
- for (const auto& it : primFunc->attrs->dict) {
- meta_collector_.Collect(it.second);
+ if (prim_func->attrs.defined()) {
+ for (const auto& it : prim_func->attrs->dict) {
+ meta_collector_.Collect(it.second);
+ }
}
// collect buffers in buffer_map
memo_var_.clear();
@@ -100,46 +102,54 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc&
primFunc) {
// print attr
Doc attr_doc;
std::vector<Doc> attr_docs;
- for (const auto& it : op->attrs->dict) {
- attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+ if (prim_func->attrs.defined()) {
+ for (const auto& it : op->attrs->dict) {
+ attr_docs.push_back(Doc::StrLiteral(it.first) << ": " <<
Print(it.second));
+ }
+ attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs,
Doc::Text(", ")) << "}";
+ doc << Doc::Indent(2, attr_doc);
}
- attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(",
")) << "}";
- doc << Doc::Indent(2, attr_doc);
+
// print all the buffers in the tree
- Doc buffer_doc;
- std::vector<Doc> buffer_docs;
- for (const auto& it : memo_buf_) {
- const auto& buf = it.first;
- buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") <<
Print(buf->data) << ", "
- << PrintDType(buf->dtype) << ", " <<
Print(buf->shape) << ", "
- << Print(buf->strides));
- if (!is_zero(buf->elem_offset)) {
- buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
- }
- if (buf->scope != "global") {
- buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
- }
- if (buf->data_alignment != 128) {
- buffer_docs.back() << ", align=" << buf->data_alignment;
+ if (memo_buf_.size() != 0) {
+ Doc buffer_doc;
+ std::vector<Doc> buffer_docs;
+ for (const auto& it : memo_buf_) {
+ const auto& buf = it.first;
+ buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") <<
Print(buf->data) << ", "
+ << PrintDType(buf->dtype) << ", " <<
Print(buf->shape)
+ << ", " << Print(buf->strides));
+ if (!is_zero(buf->elem_offset)) {
+ buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
+ }
+ if (buf->scope != "global") {
+ buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
+ }
+ if (buf->data_alignment != 128) {
+ buffer_docs.back() << ", align=" << buf->data_alignment;
+ }
+ if (buf->offset_factor != 1) {
+ buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
+ }
+ if (buf->buffer_type != 1) {
+ buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
+ }
+ buffer_docs.back() << ")";
}
- if (buf->offset_factor != 1) {
- buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
- }
- if (buf->buffer_type != 1) {
- buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
- }
- buffer_docs.back() << ")";
+ buffer_doc << Doc::NewLine() << "buffers = {";
+ buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") <<
Doc::NewLine()));
+ doc << Doc::Indent(2, buffer_doc) << "}";
}
- buffer_doc << Doc::NewLine() << "buffers = {";
- buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") <<
Doc::NewLine()));
- doc << Doc::Indent(2, buffer_doc) << "}";
- // print buffer_map
- std::vector<Doc> buffer_map_doc;
- for (const auto& it : op->buffer_map) {
- buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+
+ if (op->buffer_map.size() != 0) {
+ // print buffer_map
+ std::vector<Doc> buffer_map_doc;
+ for (const auto& it : op->buffer_map) {
+ buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+ }
+ doc << Doc::Indent(
+ 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc,
Doc::Text(", ")) << "}");
}
- doc << Doc::Indent(
- 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc,
Doc::Text(", ")) << "}");
doc << PrintBody(op->body);
return doc;
}
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 05582fb..7c3c830 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -914,7 +914,7 @@ void CodeGenC::VisitStmt_(const SeqStmtNode* op) {
}
void CodeGenC::VisitStmt_(const EvaluateNode* op) {
- if (is_const(op->value)) return;
+ if (is_const_int(op->value)) return;
const CallNode* call = op->value.as<CallNode>();
if (call) {
if (call->op.same_as(builtin::tvm_storage_sync())) {
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index ae5e40a..7dc63d4 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -609,7 +609,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
}
void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
- if (is_const(op->value)) return;
+ if (is_const_int(op->value)) return;
const CallNode* call = op->value.as<CallNode>();
if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
PrintIndent();
diff --git a/src/target/stackvm/codegen_stackvm.cc
b/src/target/stackvm/codegen_stackvm.cc
index 84b1492..9cad92d 100644
--- a/src/target/stackvm/codegen_stackvm.cc
+++ b/src/target/stackvm/codegen_stackvm.cc
@@ -429,7 +429,7 @@ void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) {
}
void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) {
- if (is_const(ev->value)) return;
+ if (is_const_int(ev->value)) return;
const CallNode* op = ev->value.as<CallNode>();
if (op && op->op.same_as(builtin::tvm_struct_set())) {
CHECK_EQ(op->args.size(), 4U);
diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc
index c57cbf7..834ad09 100644
--- a/src/tir/analysis/verify_ssa.cc
+++ b/src/tir/analysis/verify_ssa.cc
@@ -35,44 +35,60 @@
namespace tvm {
namespace tir {
-class IRVerifySSA final : public StmtExprVisitor {
+class SSAVerifier final : public StmtExprVisitor {
public:
- bool is_ssa{true};
+ bool is_ssa_{true};
void VisitExpr(const PrimExpr& n) final {
- if (!is_ssa) return;
+ if (!is_ssa_) return;
StmtExprVisitor::VisitExpr(n);
}
void VisitStmt(const Stmt& n) final {
- if (!is_ssa) return;
+ if (!is_ssa_) return;
StmtExprVisitor::VisitStmt(n);
}
void VisitExpr_(const LetNode* op) final {
- MarkDef(op->var.get());
+ // Weaker SSA condition
+ // A single var can be binded in multiple lets
+ // but they have to bind to the same value.
+ // This is used to enable cases when we reuse a single let
+ // expression to cosntruct a nested expr.
+ // (let x = 1 in x + 1) * (let x = 1 in x + 1)
+ auto it = def_map_.find(op->var);
+ if (it != def_map_.end()) {
+ if (!deep_equal_(it->second, op->value)) {
+ is_ssa_ = false;
+ return;
+ }
+ } else {
+ MarkDef(op->var, op->value);
+ }
StmtExprVisitor::VisitExpr_(op);
}
+
void VisitStmt_(const LetStmtNode* op) final {
- MarkDef(op->var.get());
+ MarkDef(op->var, op->value);
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const ForNode* op) final {
- MarkDef(op->loop_var.get());
+ MarkDef(op->loop_var, op->loop_var);
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AllocateNode* op) final {
- MarkDef(op->buffer_var.get());
+ MarkDef(op->buffer_var, op->buffer_var);
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const VarNode* node) final {
+ auto var = GetRef<Var>(node);
if (match_scope_) {
- MarkDef(node, true);
+ MarkDef(var, var, true);
}
}
void Run(const PrimFunc& func) {
for (auto param : func->params) {
- MarkDef(param.get());
+ MarkDef(param, param);
}
for (auto kv : func->buffer_map) {
@@ -99,25 +115,28 @@ class IRVerifySSA final : public StmtExprVisitor {
}
private:
- void MarkDef(const VarNode* v, bool allow_dup = false) {
- if (defined_.count(v) != 0) {
+ void MarkDef(const Var& var, PrimExpr value, bool allow_dup = false) {
+ if (def_map_.count(var) != 0) {
if (!allow_dup) {
- is_ssa = false;
+ is_ssa_ = false;
return;
}
} else {
- defined_[v] = 1;
+ def_map_[var] = value;
}
}
// whether we are in match scope, where a var can occur multiple times.
bool match_scope_{false};
- std::unordered_map<const VarNode*, int> defined_;
+ // deep equal
+ ExprDeepEqual deep_equal_;
+ // def map, for let, maps to the bind value, for others maps to self.
+ std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> def_map_;
};
bool VerifySSA(const PrimFunc& func) {
- IRVerifySSA visitor;
+ SSAVerifier visitor;
visitor.Run(func);
- return visitor.is_ssa;
+ return visitor.is_ssa_;
}
TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 0f67126..a0ba8d6 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -395,7 +395,7 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value,
PrimExpr false_value)
// likely
PrimExpr likely(PrimExpr cond) {
- if (is_const(cond)) return cond;
+ if (is_const_int(cond)) return cond;
return tir::Call(cond.dtype(), tir::builtin::likely(), {cond});
}
diff --git a/src/tir/transforms/loop_partition.cc
b/src/tir/transforms/loop_partition.cc
index 2fb8003..1876dfe 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -96,7 +96,7 @@ class CandidateSelector final : public StmtExprVisitor {
void VisitStmt_(const ForNode* op) final {
// partition const loop when sets partition_const_loop_
- if (!is_const(op->min) || !is_const(op->extent) || partition_const_loop_) {
+ if (!is_const_int(op->min) || !is_const_int(op->extent) ||
partition_const_loop_) {
const VarNode* var = op->loop_var.get();
record_.insert({var, false});
StmtExprVisitor::VisitStmt_(op);
@@ -115,7 +115,7 @@ class CandidateSelector final : public StmtExprVisitor {
CHECK(iv);
Var var = iv->var;
runtime::ThreadScope scope =
runtime::ThreadScope::Create(iv->thread_tag);
- if ((scope.rank == 0) && (!is_const(op->value) ||
partition_const_loop_)) {
+ if ((scope.rank == 0) && (!is_const_int(op->value) ||
partition_const_loop_)) {
record_.insert({var.get(), false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var.get()) && !no_split_) {
diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc
index 3be2329..3c8a934 100644
--- a/src/tir/transforms/simplify.cc
+++ b/src/tir/transforms/simplify.cc
@@ -54,9 +54,18 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Parent::VisitStmt_(op);
}
+ bool CanInlineLetStmt(const LetStmtNode* op) {
+ if (is_const_number(op->value)) return true;
+ if (op->value.as<VarNode>()) return true;
+ // Won't face the deep expression explosion problem as in Let expression.
+ // attempt to inline as much as possible if the value integer type(can be
index).
+ if (!op->value.dtype().is_int()) return false;
+ return !tir::HasSideEffect(op->value);
+ }
+
Stmt VisitStmt_(const LetStmtNode* op) {
PrimExpr value = this->VisitExpr(op->value);
- if (!tir::HasSideEffect(value)) {
+ if (CanInlineLetStmt(op)) {
// it is fine to discard the let binding
// because the call to simplify will always inline the var.
analyzer_->Bind(op->var, value);
diff --git a/src/tir/transforms/split_host_device.cc
b/src/tir/transforms/split_host_device.cc
index f339c56..75ae743 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -70,7 +70,7 @@ class VarUseDefAnalysis : public StmtExprMutator {
this->HandleDef(op->var.get());
Stmt body = this->VisitStmt(op->body);
// eliminate unreferenced let
- if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) {
+ if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) &&
simplify_let_) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
@@ -101,7 +101,7 @@ class VarUseDefAnalysis : public StmtExprMutator {
this->HandleDef(op->var.get());
PrimExpr body = this->VisitExpr(op->body);
// eliminate unreferenced let
- if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) {
+ if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) &&
simplify_let_) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
@@ -149,6 +149,7 @@ class VarUseDefAnalysis : public StmtExprMutator {
// The fields are publically readible to
// be accessible to the users.
bool visit_thread_extent_{true};
+ bool simplify_let_{true};
Array<Var> undefined_;
Array<IterVar> thread_axis_;
Array<PrimExpr> thread_extent_;
@@ -158,6 +159,7 @@ class VarUseDefAnalysis : public StmtExprMutator {
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
VarUseDefAnalysis m;
+ m.simplify_let_ = false;
for (Var arg : args) {
m.use_count_[arg.get()] = 0;
}
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index e015990..bf54ada 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -23,6 +23,7 @@
// Loop vectorizer as in Halide pipeline.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
@@ -91,15 +92,21 @@ class VecAllocAccess : public StmtExprMutator {
int var_lanes_;
};
-class Vectorizer : public StmtExprMutator {
+// We use ExprFunctor directly instead of StmtExprMutator.
+// The transformation can change the dtype of an Expr, so the
+// existing ExprMutator transformation rules may not be well defined.
+class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
public:
+ using ExprFunctor::VisitExpr;
+ using StmtMutator::operator();
+
Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(0, 1, var_lanes);
}
Stmt VisitStmt(const Stmt& stmt) final {
CHECK(!need_scalarize_);
- Stmt ret = StmtExprMutator::VisitStmt(stmt);
+ Stmt ret = StmtMutator::VisitStmt(stmt);
if (need_scalarize_) {
need_scalarize_ = false;
return Scalarize(stmt);
@@ -108,6 +115,8 @@ class Vectorizer : public StmtExprMutator {
}
}
+ PrimExpr VisitExpr(const PrimExpr& e) final { return ExprFunctor::VisitExpr(e); }
+
PrimExpr VisitExpr_(const AddNode* op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
}
@@ -151,6 +160,16 @@ class Vectorizer : public StmtExprMutator {
PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec<GE>(op); }
PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec<And>(op); }
PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec<Or>(op); }
+
+ PrimExpr VisitExpr_(const NotNode* op) final {
+ PrimExpr a = this->VisitExpr(op->a);
+ if (a.same_as(op->a)) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return !(a);
+ }
+ }
+
PrimExpr VisitExpr_(const RampNode* op) final {
PrimExpr base = this->VisitExpr(op->base);
PrimExpr stride = this->VisitExpr(op->stride);
@@ -170,6 +189,20 @@ class Vectorizer : public StmtExprMutator {
}
return Shuffle::Concat(elems);
}
+
+ PrimExpr VisitExpr_(const BroadcastNode* op) final {
+ PrimExpr value = this->VisitExpr(op->value);
+ if (value.dtype().lanes() != 1) {  // a vector operand cannot be broadcast; scalarize instead
+ need_scalarize_ = true;
+ return GetRef<PrimExpr>(op);
+ }
+ if (value.same_as(op->value)) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return Broadcast(value, op->lanes);  // rebuild from the rewritten operand, not op->value
+ }
+ }
+
PrimExpr VisitExpr_(const SelectNode* op) final {
PrimExpr cond = this->VisitExpr(op->condition);
PrimExpr t = this->VisitExpr(op->true_value);
@@ -189,14 +222,25 @@ class Vectorizer : public StmtExprMutator {
return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
}
}
+
+ PrimExpr VisitExpr_(const FloatImmNode* op) final { return GetRef<PrimExpr>(op); }
+
+ PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef<PrimExpr>(op); }
+
+ PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef<PrimExpr>(op); }
+
// Variable
- PrimExpr VisitExpr_(const VarNode* v) final {
- if (v == var_.get()) {
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ Var var = GetRef<Var>(op);
+
+ if (var.same_as(var_)) {
return ramp_;
- } else if (lets_.count(v)) {
- return lets_[v];
+ }
+ auto it = let_binding_.find(var);
+ if (it != let_binding_.end()) {
+ return it->second;
} else {
- return GetRef<PrimExpr>(v);
+ return std::move(var);
}
}
// IfThenElse expr
@@ -267,12 +311,26 @@ class Vectorizer : public StmtExprMutator {
// Let
PrimExpr VisitExpr_(const LetNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
- CHECK(!lets_.count(op->var.get())) << "not SSA";
+ // Weaker SSA condition
+ // A single var can be bound in multiple lets
+ // but they have to bind to the same value.
+ // This is used to allow cases when we reuse a single let
+ // expression to construct a nested expr.
+ // (let x = 1 in x + 1) * (let x = 1 in x + 1)
+ // Compare against the recorded bound value: let_binding_ holds the
+ // replacement var, which would never deep-equal the new value.
+ auto it = let_value_.find(op->var);
+ if (it != let_value_.end()) {
+ CHECK(deep_equal_(it->second, value))
+ << "Let cannot bind the same var to two different values";
+ }
+ let_value_[op->var] = value;
if (value.dtype().lanes() != op->value.dtype().lanes()) {
- Var v(op->var->name_hint, value.dtype());
- lets_[op->var.get()] = v;
- return Let(v, value, this->VisitExpr(op->body));
+ Var new_var(op->var->name_hint, value.dtype());
+ let_binding_[op->var] = new_var;
+ return Let(new_var, value, this->VisitExpr(op->body));
} else {
+ let_binding_[op->var] = op->var;
PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
@@ -281,10 +339,6 @@ class Vectorizer : public StmtExprMutator {
}
}
}
- Stmt VisitStmt_(const ProducerStoreNode* op) final {
- LOG(FATAL) << "ProducerProvide is cannot appear in a TIR PrimFunc";
- return Stmt();
- }
// Store
Stmt VisitStmt_(const StoreNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
@@ -338,8 +392,23 @@ class Vectorizer : public StmtExprMutator {
}
// LetStmt
Stmt VisitStmt_(const LetStmtNode* op) final {
- LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize";
- return Scalarize(GetRef<Stmt>(op));
+ PrimExpr value = this->VisitExpr(op->value);
+ CHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice";
+ let_value_[op->var] = value;
+
+ if (value.dtype().lanes() != op->value.dtype().lanes()) {
+ Var new_var(op->var->name_hint, value.dtype());
+ let_binding_[op->var] = new_var;
+ return LetStmt(new_var, value, this->VisitStmt(op->body));
+ } else {
+ let_binding_[op->var] = op->var;
+ Stmt body = this->VisitStmt(op->body);
+ if (value.same_as(op->value) && body.same_as(op->body)) {
+ return GetRef<Stmt>(op);
+ } else {
+ return LetStmt(op->var, value, body);
+ }
+ }
}
// Allocate
Stmt VisitStmt_(const AllocateNode* op) final {
@@ -364,6 +433,7 @@ class Vectorizer : public StmtExprMutator {
body = this->VisitStmt(body);
return Allocate(op->buffer_var, op->dtype, extents, condition, body);
}
+
// scalarize the statment
Stmt Scalarize(Stmt stmt) {
Var idx(var_->name_hint + ".s", var_->dtype);
@@ -371,10 +441,17 @@ class Vectorizer : public StmtExprMutator {
stmt = Substitute(stmt, values);
return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
}
+ // ProducerStore
+ Stmt VisitStmt_(const ProducerStoreNode* op) final {
+ LOG(FATAL) << "ProducerStore cannot appear in a TIR PrimFunc";
+ return Stmt();
+ }
private:
// analyzer
arith::Analyzer analyzer_;
+ // deep equal
+ ExprDeepEqual deep_equal_;
// variable to be replaced
Var var_;
// the lanes.
@@ -383,8 +460,10 @@ class Vectorizer : public StmtExprMutator {
PrimExpr ramp_;
// flag to mark requirment of scalarization.
bool need_scalarize_{false};
- // The lets
- std::unordered_map<const VarNode*, PrimExpr> lets_;
+ // Let binding: maps a let-bound var to the var that replaces it.
+ std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
+ // Value each let-bound var was bound to; used to check the weak SSA condition.
+ std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_value_;
// vectorizable property
OpAttrMap<TVectorizable> op_vectorizable_ =
Op::GetAttrMap<TVectorizable>("TVectorizable");
diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py
index 4829b97..c5794cd 100644
--- a/tests/python/unittest/test_arith_const_int_bound.py
+++ b/tests/python/unittest/test_arith_const_int_bound.py
@@ -284,7 +284,16 @@ def test_size_var_bound():
assert bd.max_value == bd.POS_INF
+def test_let_bound():
+ analyzer = tvm.arith.Analyzer()
+ x = te.var("x")
+ bd = analyzer.const_int_bound(tvm.tir.Let(x, 1, x + 1))
+ assert bd.min_value == 2
+ assert bd.max_value == 2
+
+
if __name__ == "__main__":
+ test_let_bound()
test_dtype_bound()
test_cast_bound()
test_add_sub_bound()
diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py
index 01180d2..7d9f739 100644
--- a/tests/python/unittest/test_arith_modular_set.py
+++ b/tests/python/unittest/test_arith_modular_set.py
@@ -159,8 +159,17 @@ def test_intersect():
assert m.coeff == 105
assert m.base == 23
+def test_let():
+ analyzer = tvm.arith.Analyzer()
+ x = te.var("x")
+ y = te.var("y")
+ m = analyzer.modular_set(tvm.tir.Let(x, y * 10, x + 1))
+ assert m.coeff == 10
+ assert m.base == 1
+
if __name__ == "__main__":
+ test_let()
test_cast()
test_add_sub()
test_mul()
diff --git a/tests/python/unittest/test_tir_analysis_verify_ssa.py b/tests/python/unittest/test_tir_analysis_verify_ssa.py
index 8a15c36..57dd826 100644
--- a/tests/python/unittest/test_tir_analysis_verify_ssa.py
+++ b/tests/python/unittest/test_tir_analysis_verify_ssa.py
@@ -27,6 +27,16 @@ def test_verify_ssa():
assert(not tvm.tir.analysis.verify_ssa(
tvm.tir.PrimFunc([x, y], tvm.tir.LetStmt(x, 1, z))))
+def test_verify_weak_let_ssa():
+ x = te.var('x')
+ z1 = tvm.tir.Let(x, 1, x + 1)
+ z2 = tvm.tir.Let(x, 2, x + 2)
+
+ assert(tvm.tir.analysis.verify_ssa(
+ tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 + z1))))
+ assert(not tvm.tir.analysis.verify_ssa(
+ tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 * z2))))
if __name__ == "__main__":
test_verify_ssa()
+ test_verify_weak_let_ssa()
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index ab730cd..c182d9e 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -274,7 +274,8 @@ def test_prim_func():
func = tvm.tir.PrimFunc(
[x, y, b], stmt)
-
+ # make sure we can print
+ func.astext()
assert func.buffer_map[func.params[2]].same_as(b)
assert len(func.buffer_map) == 1
diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py
index a69c9d3..0516b4a 100644
--- a/tests/python/unittest/test_tir_transform_vectorize.py
+++ b/tests/python/unittest/test_tir_transform_vectorize.py
@@ -81,6 +81,20 @@ def test_vectorize_with_if():
assert isinstance(stmt.else_case, tvm.tir.For)
+def test_vectorize_let():
+ v = tvm.tir.Var("v", "float32")
+ ib = tvm.tir.ir_builder.create()
+ A = ib.pointer("float32", name="A")
+ with ib.for_range(0, 4, for_type="vectorize") as i:
+ ib.emit(lambda body: tvm.tir.LetStmt(v, A[i] + 1, body))
+ A[i] = v + 2
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], ib.get()))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+ assert isinstance(stmt, tvm.tir.LetStmt)
+ assert stmt.value.dtype == "float32x4"
+
+
def test_vectorize_with_le_cond():
n = te.var('n')
ib = tvm.tir.ir_builder.create()
@@ -153,3 +167,4 @@ if __name__ == "__main__":
test_vectorize_if_then_else()
test_vectorize_with_le_cond()
test_vectorize_with_ge_cond()
+ test_vectorize_let()