This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 4fbfaca  [TIR] Improve Let/LetStmt support. (#5949)
4fbfaca is described below

commit 4fbfaca5fc417b09330f7592b21efdc6fb1cf51b
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 28 09:22:11 2020 -0700

    [TIR] Improve Let/LetStmt support. (#5949)
    
    Let/LetStmt are useful primitives to create variable bindings.
    While let bindings are harmful to simplification and integer analysis,
    they are useful for other cases:
    
    - C0: LetStmt is useful to represent a step that has side effects (e.g. calling a PRNG)
    - C1: Let expressions can be used to create deeply nested expressions for complicated functions.
    
    This PR improves the let support in the following ways:
    - Enable vectorization support for let
    - Change the let simplification strategy to simplify only the most trivial cases
      while ignoring more complicated ones (to avoid deeply nested expression explosion)
    - Enhance arith module to handle const bound and modular set for let.
    
    The overall recommendation is to use Let only when necessary (C0, C1).
---
 include/tvm/arith/analyzer.h                       |  32 +++---
 include/tvm/tir/op.h                               |  26 ++++-
 src/arith/analyzer.cc                              |  24 ++---
 src/arith/const_int_bound.cc                       |  33 ++++--
 src/arith/modular_set.cc                           |  21 +++-
 src/arith/rewrite_simplify.cc                      |  16 ++-
 src/arith/rewrite_simplify.h                       |   7 ++
 src/contrib/hybrid/codegen_hybrid.cc               |   4 +-
 src/printer/tir_text_printer.cc                    |  88 +++++++++-------
 src/target/source/codegen_c.cc                     |   2 +-
 src/target/source/codegen_cuda.cc                  |   2 +-
 src/target/stackvm/codegen_stackvm.cc              |   2 +-
 src/tir/analysis/verify_ssa.cc                     |  53 ++++++----
 src/tir/op/op.cc                                   |   2 +-
 src/tir/transforms/loop_partition.cc               |   4 +-
 src/tir/transforms/simplify.cc                     |  11 +-
 src/tir/transforms/split_host_device.cc            |   6 +-
 src/tir/transforms/vectorize_loop.cc               | 112 +++++++++++++++++----
 .../python/unittest/test_arith_const_int_bound.py  |   9 ++
 tests/python/unittest/test_arith_modular_set.py    |   9 ++
 .../unittest/test_tir_analysis_verify_ssa.py       |  10 ++
 tests/python/unittest/test_tir_nodes.py            |   3 +-
 .../unittest/test_tir_transform_vectorize.py       |  15 +++
 23 files changed, 354 insertions(+), 137 deletions(-)

diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 8033294..c4ee7b5 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -128,17 +128,17 @@ class ConstIntBoundAnalyzer {
    *
    * \param var The variable of interest.
    * \param info The bound information.
-   * \param override Whether do we allow override of existing information.
+   * \param allow_override Whether do we allow override of existing 
information.
    */
-  TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool override 
= false);
+  TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool 
allow_override = false);
   /*!
    * \brief Bind variable to a range.
    *
    * \param var The variable.
    * \param range The range we bind to.
-   * \param override Whether we allow overriding an existing var's range.
+   * \param allow_override Whether we allow overriding an existing var's range.
    */
-  TVM_DLL void Bind(const Var& var, const Range& range, bool override = false);
+  TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = 
false);
 
  private:
   friend class Analyzer;
@@ -217,9 +217,9 @@ class ModularSetAnalyzer {
    *
    * \param var The variable of interest.
    * \param info The bound information.
-   * \param override Whether do we allow override of existing information.
+   * \param allow_override Whether do we allow override of existing 
information.
    */
-  TVM_DLL void Update(const Var& var, const ModularSet& info, bool override = 
false);
+  TVM_DLL void Update(const Var& var, const ModularSet& info, bool 
allow_override = false);
 
  private:
   friend class Analyzer;
@@ -256,9 +256,9 @@ class RewriteSimplifier {
    *
    * \param var The variable of interest.
    * \param new_expr
-   * \param override Whether do we allow override of existing information.
+   * \param allow_override Whether do we allow override of existing 
information.
    */
-  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override 
= false);
+  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool 
allow_override = false);
 
   std::function<void()> EnterConstraint(const PrimExpr& constraint);
 
@@ -290,9 +290,9 @@ class CanonicalSimplifier {
    *
    * \param var The variable of interest.
    * \param new_expr
-   * \param override Whether do we allow override of existing information.
+   * \param allow_override Whether do we allow override of existing 
information.
    */
-  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override 
= false);
+  TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool 
allow_override = false);
 
  private:
   friend class Analyzer;
@@ -404,9 +404,9 @@ class TVM_DLL Analyzer {
    *
    * \param var The variable.
    * \param expr The expression we bind to.
-   * \param override Whether we allow overriding an existing var's expression.
+   * \param allow_override Whether we allow overriding an existing var's 
expression.
    */
-  void Bind(const Var& var, const PrimExpr& expr, bool override = false);
+  void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
   /*!
    * \brief Notify all the sub-analyzers that var
    *        is created and binded to a range.
@@ -415,16 +415,16 @@ class TVM_DLL Analyzer {
    *
    * \param var The variable.
    * \param range The range we bind to.
-   * \param override Whether we allow overriding an existing var's expression.
+   * \param allow_override Whether we allow overriding an existing var's 
expression.
    */
-  void Bind(const Var& var, const Range& range, bool override = false);
+  void Bind(const Var& var, const Range& range, bool allow_override = false);
   /*!
    * \brief Bind all the vars in the Map
    *
    * \param variables The {variable -> range} map.
-   * \param override Whether we allow overriding an existing var's expression.
+   * \param allow_override Whether we allow overriding an existing var's 
expression.
    */
-  void Bind(const Map<Var, Range>& variables, bool override = false);
+  void Bind(const Map<Var, Range>& variables, bool allow_override = false);
   /*!
    * \brief Whether can we prove expr >= val.
 
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 34cb52f..31ce13c 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -671,11 +671,18 @@ inline bool is_one(const PrimExpr& x) { return 
is_const_int(x, 1); }
 inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
 
 /*!
- * \brief Check whether x is a constant.
+ * \brief Check whether x is an integer constant.
  * \note This only return true for integer types.
  * \return whether x is constant
  */
-inline bool is_const(const PrimExpr& x);
+inline bool is_const_int(const PrimExpr& x);
+
+/*!
+ * \brief Check whether x is an integer/float constant.
+ * \note Returns true for integer and float immediates, and for broadcasts of them.
+ * \return whether x is constant
+ */
+inline bool is_const_number(const PrimExpr& x);
 
 /*!
  * \brief Left fold.
@@ -699,7 +706,7 @@ inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, 
const Array<PrimExpr
 TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
 
 // Implementation details after this
-inline bool is_const(const PrimExpr& x) {
+inline bool is_const_int(const PrimExpr& x) {
   if (x.as<tir::IntImmNode>()) {
     return true;
   } else if (const auto* op = x.as<tir::BroadcastNode>()) {
@@ -711,6 +718,17 @@ inline bool is_const(const PrimExpr& x) {
   return false;
 }
 
+inline bool is_const_number(const PrimExpr& x) {
+  if (x.as<tir::IntImmNode>()) {
+    return true;
+  } else if (x.as<tir::FloatImmNode>()) {
+    return true;
+  } else if (const auto* op = x.as<tir::BroadcastNode>()) {
+    return (op->value->IsInstance<tir::IntImmNode>() || 
op->value->IsInstance<tir::FloatImmNode>());
+  }
+  return false;
+}
+
 inline bool is_positive_const(const PrimExpr& a) {
   if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
     return op->value > 0;
@@ -742,7 +760,7 @@ inline bool is_const_int(const PrimExpr& x, int64_t value) {
 inline bool is_no_op(const tir::Stmt& stmt) {
   if (!stmt.defined()) return true;
   if (const auto* op = stmt.as<tir::EvaluateNode>()) {
-    return is_const(op->value);
+    return is_const_int(op->value);
   }
   if (const auto* op = stmt.as<tir::SeqStmtNode>()) {
     return op->seq.size() == 0;
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 037c766..c7a8365 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -35,31 +35,31 @@ Analyzer::Analyzer()
       canonical_simplify(this),
       int_set(this) {}
 
-void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
+void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) 
{
   PrimExpr new_expr = expr;
   new_expr = this->canonical_simplify(new_expr);
   new_expr = this->rewrite_simplify(new_expr);
 
-  this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
-  this->modular_set.Update(var, this->modular_set(new_expr), override);
-  this->rewrite_simplify.Update(var, new_expr, override);
-  this->canonical_simplify.Update(var, new_expr, override);
+  this->const_int_bound.Update(var, this->const_int_bound(new_expr), 
allow_override);
+  this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
+  this->rewrite_simplify.Update(var, new_expr, allow_override);
+  this->canonical_simplify.Update(var, new_expr, allow_override);
 }
 
-void Analyzer::Bind(const Var& var, const Range& range, bool override) {
+void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
   CHECK(range.defined());
   if (tir::is_one(range->extent)) {
-    this->Bind(var, range->min, override);
+    this->Bind(var, range->min, allow_override);
   } else {
-    this->const_int_bound.Bind(var, range, override);
+    this->const_int_bound.Bind(var, range, allow_override);
   }
   // skip modular_set
   // skip rewrite simplify
 }
 
-void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
+void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
   for (const auto& iter : variables) {
-    this->Bind(iter.first, iter.second, override);
+    this->Bind(iter.first, iter.second, allow_override);
   }
 }
 
@@ -116,9 +116,9 @@ bool Analyzer::CanProve(const PrimExpr& expr) {
 }
 
 PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
-  if (tir::is_const(expr)) return expr;
+  if (tir::is_const_int(expr)) return expr;
   auto res = this->rewrite_simplify(expr);
-  if (tir::is_const(res)) return res;
+  if (tir::is_const_int(res)) return res;
   res = this->canonical_simplify(res);
   return res;
 }
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 8c90249..be830d3 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -96,17 +96,17 @@ class ConstIntBoundAnalyzer::Impl
     BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {}
   };
 
-  void Bind(const Var& var, const Range& range, bool override) {
+  void Bind(const Var& var, const Range& range, bool allow_override) {
     Entry a = VisitExpr(range->min);
     Entry b = VisitExpr(range->extent);
     Entry ret;
     ret.min_value = a.min_value;
     ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
-    Update(var, ret, override);
+    Update(var, ret, allow_override);
   }
 
-  void Update(const Var& var, const Entry& info, bool override) {
-    if (!override) {
+  void Update(const Var& var, const Entry& info, bool allow_override) {
+    if (!allow_override) {
       auto it = var_map_.find(var);
       if (it != var_map_.end()) {
         CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
@@ -119,8 +119,21 @@ class ConstIntBoundAnalyzer::Impl
     var_map_[var] = info;
   }
 
-  void Update(const Var& var, const ConstIntBound& info, bool override) {
-    Update(var, MakeBound(info->min_value, info->max_value), override);
+  Entry VisitExpr_(const LetNode* op) final {
+    auto it = var_map_.find(op->var);
+    // If the var has not been bound yet, bind it while visiting the body.
+    if (it == var_map_.end()) {
+      var_map_[op->var] = this->VisitExpr(op->value);
+      Entry ret = VisitExpr(op->body);
+      var_map_.erase(op->var);
+      return ret;
+    } else {
+      return VisitExpr(op->body);
+    }
+  }
+
+  void Update(const Var& var, const ConstIntBound& info, bool allow_override) {
+    Update(var, MakeBound(info->min_value, info->max_value), allow_override);
   }
 
   // Override visitor behaviors
@@ -558,12 +571,12 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const 
PrimExpr& expr, BoundMapTy
   return ConstIntBound(ret.min_value, ret.max_value);
 }
 
-void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, 
bool override) {
-  impl_->Update(var, info, override);
+void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, 
bool allow_override) {
+  impl_->Update(var, info, allow_override);
 }
 
-void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool 
override) {
-  impl_->Bind(var, range, override);
+void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool 
allow_override) {
+  impl_->Bind(var, range, allow_override);
 }
 
 std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& 
constraint) {
diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc
index 108f08c..8c41760 100644
--- a/src/arith/modular_set.cc
+++ b/src/arith/modular_set.cc
@@ -89,8 +89,8 @@ class ModularSetAnalyzer::Impl : public 
ExprFunctor<ModularSetAnalyzer::Entry(co
  public:
   explicit Impl(Analyzer* parent) : parent_(parent) {}
 
-  void Update(const Var& var, const ModularSet& info, bool override) {
-    if (!override) {
+  void Update(const Var& var, const ModularSet& info, bool allow_override) {
+    if (!allow_override) {
       auto it = var_map_.find(var);
       if (it != var_map_.end()) {
         CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
@@ -118,6 +118,19 @@ class ModularSetAnalyzer::Impl : public 
ExprFunctor<ModularSetAnalyzer::Entry(co
   // Override visitor behaviors
   Entry VisitExprDefault_(const Object* op) final { return Everything(); }
 
+  Entry VisitExpr_(const LetNode* op) final {
+    auto it = var_map_.find(op->var);
+    // If the var has not been bound yet, bind it while visiting the body.
+    if (it == var_map_.end()) {
+      var_map_[op->var] = this->VisitExpr(op->value);
+      Entry ret = VisitExpr(op->body);
+      var_map_.erase(op->var);
+      return ret;
+    } else {
+      return VisitExpr(op->body);
+    }
+  }
+
   Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }
 
   Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); }
@@ -315,8 +328,8 @@ ModularSet ModularSetAnalyzer::operator()(const PrimExpr& 
expr) {
   return ModularSet(ret.coeff, ret.base);
 }
 
-void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool 
override) {
-  impl_->Update(var, info, override);
+void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool 
allow_override) {
+  impl_->Update(var, info, allow_override);
 }
 
 std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr& 
constraint) {
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 898eecc..e9d640a 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1519,7 +1519,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
CallNode* op) {
   op = ret.as<CallNode>();
   if (op == nullptr) return ret;
 
-  if (op->op.same_as(tir::builtin::likely()) && is_const(op->args[0])) {
+  if (op->op.same_as(tir::builtin::likely()) && is_const_int(op->args[0])) {
     return op->args[0];
   } else if (op->op.same_as(tir::builtin::shift_right())) {
     if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
@@ -1559,9 +1559,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
CastNode* op) {
   return cast(op->dtype, op->value);
 }
 
+bool RewriteSimplifier::Impl::CanInlineLet(const LetNode* op) {
+  // Only inline trivial bindings to avoid deep expression explosion
+  // when we need let to construct complicated expressions.
+  if (is_const_number(op->value)) return true;
+  if (op->value.as<VarNode>()) return true;
+  return false;
+}
+
 PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) {
   PrimExpr value = this->VisitExpr(op->value);
-  if (!tir::HasSideEffect(value)) {
+  if (CanInlineLet(op)) {
     // it is fine to discard the let binding
     // because the value will always be inlined in the simplifier.
     analyzer_->Bind(op->var, value);
@@ -1587,8 +1595,8 @@ PrimExpr RewriteSimplifier::operator()(const PrimExpr& 
expr) {
   return res;
 }
 
-void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool 
override) {
-  impl_->Update(var, info, override);
+void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool 
allow_override) {
+  impl_->Update(var, info, allow_override);
 }
 
 std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr& 
constraint) {
diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h
index 68c0dd2..258f833 100644
--- a/src/arith/rewrite_simplify.h
+++ b/src/arith/rewrite_simplify.h
@@ -98,6 +98,13 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer 
{
    */
   CompareResult TryCompare(const PrimExpr& x, int64_t val);
 
+  /*!
+   * \brief Internal function to check whether or not to inline let.
+   * \param op The let expr.
+   * \return The inline decision.
+   */
+  bool CanInlineLet(const LetNode* op);
+
  private:
   // Whether x >= val
   bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
diff --git a/src/contrib/hybrid/codegen_hybrid.cc 
b/src/contrib/hybrid/codegen_hybrid.cc
index b65ae91..67765f0 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -381,7 +381,7 @@ void CodeGenHybrid::VisitStmt_(const ForNode* op) {
 
 bool is_noop(const Stmt& stmt) {
   if (!stmt.defined()) return true;
-  if (auto eval = stmt.as<EvaluateNode>()) return is_const(eval->value);
+  if (auto eval = stmt.as<EvaluateNode>()) return is_const_int(eval->value);
   return false;
 }
 
@@ -409,7 +409,7 @@ void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) {
 }
 
 void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
-  if (is_const(op->value)) return;
+  if (is_const_int(op->value)) return;
   std::string str = PrintExpr(op->value);
   if (!str.empty()) stream << str << "\n";
 }
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index 7ab26fa..ca038ab 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -71,12 +71,14 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) {
   }
 }
 
-Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
-  const auto* op = primFunc.operator->();
+Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
+  const auto* op = prim_func.operator->();
   const auto& signature = op->func_type_annotation();
   // collect Meta in DictAttr
-  for (const auto& it : primFunc->attrs->dict) {
-    meta_collector_.Collect(it.second);
+  if (prim_func->attrs.defined()) {
+    for (const auto& it : prim_func->attrs->dict) {
+      meta_collector_.Collect(it.second);
+    }
   }
   // collect buffers in buffer_map
   memo_var_.clear();
@@ -100,46 +102,54 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& 
primFunc) {
   // print attr
   Doc attr_doc;
   std::vector<Doc> attr_docs;
-  for (const auto& it : op->attrs->dict) {
-    attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+  if (prim_func->attrs.defined()) {
+    for (const auto& it : op->attrs->dict) {
+      attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << 
Print(it.second));
+    }
+    attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, 
Doc::Text(", ")) << "}";
+    doc << Doc::Indent(2, attr_doc);
   }
-  attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", 
")) << "}";
-  doc << Doc::Indent(2, attr_doc);
+
   // print all the buffers in the tree
-  Doc buffer_doc;
-  std::vector<Doc> buffer_docs;
-  for (const auto& it : memo_buf_) {
-    const auto& buf = it.first;
-    buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") << 
Print(buf->data) << ", "
-                                     << PrintDType(buf->dtype) << ", " << 
Print(buf->shape) << ", "
-                                     << Print(buf->strides));
-    if (!is_zero(buf->elem_offset)) {
-      buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
-    }
-    if (buf->scope != "global") {
-      buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
-    }
-    if (buf->data_alignment != 128) {
-      buffer_docs.back() << ", align=" << buf->data_alignment;
+  if (memo_buf_.size() != 0) {
+    Doc buffer_doc;
+    std::vector<Doc> buffer_docs;
+    for (const auto& it : memo_buf_) {
+      const auto& buf = it.first;
+      buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") << 
Print(buf->data) << ", "
+                                       << PrintDType(buf->dtype) << ", " << 
Print(buf->shape)
+                                       << ", " << Print(buf->strides));
+      if (!is_zero(buf->elem_offset)) {
+        buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
+      }
+      if (buf->scope != "global") {
+        buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
+      }
+      if (buf->data_alignment != 128) {
+        buffer_docs.back() << ", align=" << buf->data_alignment;
+      }
+      if (buf->offset_factor != 1) {
+        buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
+      }
+      if (buf->buffer_type != 1) {
+        buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
+      }
+      buffer_docs.back() << ")";
     }
-    if (buf->offset_factor != 1) {
-      buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
-    }
-    if (buf->buffer_type != 1) {
-      buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
-    }
-    buffer_docs.back() << ")";
+    buffer_doc << Doc::NewLine() << "buffers = {";
+    buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << 
Doc::NewLine()));
+    doc << Doc::Indent(2, buffer_doc) << "}";
   }
-  buffer_doc << Doc::NewLine() << "buffers = {";
-  buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << 
Doc::NewLine()));
-  doc << Doc::Indent(2, buffer_doc) << "}";
-  // print buffer_map
-  std::vector<Doc> buffer_map_doc;
-  for (const auto& it : op->buffer_map) {
-    buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+
+  if (op->buffer_map.size() != 0) {
+    // print buffer_map
+    std::vector<Doc> buffer_map_doc;
+    for (const auto& it : op->buffer_map) {
+      buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+    }
+    doc << Doc::Indent(
+        2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, 
Doc::Text(", ")) << "}");
   }
-  doc << Doc::Indent(
-      2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, 
Doc::Text(", ")) << "}");
   doc << PrintBody(op->body);
   return doc;
 }
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 05582fb..7c3c830 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -914,7 +914,7 @@ void CodeGenC::VisitStmt_(const SeqStmtNode* op) {
 }
 
 void CodeGenC::VisitStmt_(const EvaluateNode* op) {
-  if (is_const(op->value)) return;
+  if (is_const_int(op->value)) return;
   const CallNode* call = op->value.as<CallNode>();
   if (call) {
     if (call->op.same_as(builtin::tvm_storage_sync())) {
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index ae5e40a..7dc63d4 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -609,7 +609,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
 }
 
 void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
-  if (is_const(op->value)) return;
+  if (is_const_int(op->value)) return;
   const CallNode* call = op->value.as<CallNode>();
   if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
     PrintIndent();
diff --git a/src/target/stackvm/codegen_stackvm.cc 
b/src/target/stackvm/codegen_stackvm.cc
index 84b1492..9cad92d 100644
--- a/src/target/stackvm/codegen_stackvm.cc
+++ b/src/target/stackvm/codegen_stackvm.cc
@@ -429,7 +429,7 @@ void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) {
 }
 
 void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) {
-  if (is_const(ev->value)) return;
+  if (is_const_int(ev->value)) return;
   const CallNode* op = ev->value.as<CallNode>();
   if (op && op->op.same_as(builtin::tvm_struct_set())) {
     CHECK_EQ(op->args.size(), 4U);
diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc
index c57cbf7..834ad09 100644
--- a/src/tir/analysis/verify_ssa.cc
+++ b/src/tir/analysis/verify_ssa.cc
@@ -35,44 +35,60 @@
 namespace tvm {
 namespace tir {
 
-class IRVerifySSA final : public StmtExprVisitor {
+class SSAVerifier final : public StmtExprVisitor {
  public:
-  bool is_ssa{true};
+  bool is_ssa_{true};
 
   void VisitExpr(const PrimExpr& n) final {
-    if (!is_ssa) return;
+    if (!is_ssa_) return;
     StmtExprVisitor::VisitExpr(n);
   }
   void VisitStmt(const Stmt& n) final {
-    if (!is_ssa) return;
+    if (!is_ssa_) return;
     StmtExprVisitor::VisitStmt(n);
   }
   void VisitExpr_(const LetNode* op) final {
-    MarkDef(op->var.get());
+    // Weaker SSA condition
+    // A single var can be bound in multiple lets,
+    // but they have to bind to the same value.
+    // This is used to enable cases when we reuse a single let
+    // expression to construct a nested expr.
+    // (let x = 1 in x + 1) * (let x = 1 in x + 1)
+    auto it = def_map_.find(op->var);
+    if (it != def_map_.end()) {
+      if (!deep_equal_(it->second, op->value)) {
+        is_ssa_ = false;
+        return;
+      }
+    } else {
+      MarkDef(op->var, op->value);
+    }
     StmtExprVisitor::VisitExpr_(op);
   }
+
   void VisitStmt_(const LetStmtNode* op) final {
-    MarkDef(op->var.get());
+    MarkDef(op->var, op->value);
     StmtExprVisitor::VisitStmt_(op);
   }
   void VisitStmt_(const ForNode* op) final {
-    MarkDef(op->loop_var.get());
+    MarkDef(op->loop_var, op->loop_var);
     StmtExprVisitor::VisitStmt_(op);
   }
   void VisitStmt_(const AllocateNode* op) final {
-    MarkDef(op->buffer_var.get());
+    MarkDef(op->buffer_var, op->buffer_var);
     StmtExprVisitor::VisitStmt_(op);
   }
 
   void VisitExpr_(const VarNode* node) final {
+    auto var = GetRef<Var>(node);
     if (match_scope_) {
-      MarkDef(node, true);
+      MarkDef(var, var, true);
     }
   }
 
   void Run(const PrimFunc& func) {
     for (auto param : func->params) {
-      MarkDef(param.get());
+      MarkDef(param, param);
     }
 
     for (auto kv : func->buffer_map) {
@@ -99,25 +115,28 @@ class IRVerifySSA final : public StmtExprVisitor {
   }
 
  private:
-  void MarkDef(const VarNode* v, bool allow_dup = false) {
-    if (defined_.count(v) != 0) {
+  void MarkDef(const Var& var, PrimExpr value, bool allow_dup = false) {
+    if (def_map_.count(var) != 0) {
       if (!allow_dup) {
-        is_ssa = false;
+        is_ssa_ = false;
         return;
       }
     } else {
-      defined_[v] = 1;
+      def_map_[var] = value;
     }
   }
   // whether we are in match scope, where a var can occur multiple times.
   bool match_scope_{false};
-  std::unordered_map<const VarNode*, int> defined_;
+  // deep equal
+  ExprDeepEqual deep_equal_;
+  // def map, for let, maps to the bind value, for others maps to self.
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> def_map_;
 };
 
 bool VerifySSA(const PrimFunc& func) {
-  IRVerifySSA visitor;
+  SSAVerifier visitor;
   visitor.Run(func);
-  return visitor.is_ssa;
+  return visitor.is_ssa_;
 }
 
 TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 0f67126..a0ba8d6 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -395,7 +395,7 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, 
PrimExpr false_value)
 
 // likely
 PrimExpr likely(PrimExpr cond) {
-  if (is_const(cond)) return cond;
+  if (is_const_int(cond)) return cond;
   return tir::Call(cond.dtype(), tir::builtin::likely(), {cond});
 }
 
diff --git a/src/tir/transforms/loop_partition.cc 
b/src/tir/transforms/loop_partition.cc
index 2fb8003..1876dfe 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -96,7 +96,7 @@ class CandidateSelector final : public StmtExprVisitor {
 
   void VisitStmt_(const ForNode* op) final {
     // partition const loop when sets partition_const_loop_
-    if (!is_const(op->min) || !is_const(op->extent) || partition_const_loop_) {
+    if (!is_const_int(op->min) || !is_const_int(op->extent) || 
partition_const_loop_) {
       const VarNode* var = op->loop_var.get();
       record_.insert({var, false});
       StmtExprVisitor::VisitStmt_(op);
@@ -115,7 +115,7 @@ class CandidateSelector final : public StmtExprVisitor {
       CHECK(iv);
       Var var = iv->var;
       runtime::ThreadScope scope = 
runtime::ThreadScope::Create(iv->thread_tag);
-      if ((scope.rank == 0) && (!is_const(op->value) || 
partition_const_loop_)) {
+      if ((scope.rank == 0) && (!is_const_int(op->value) || 
partition_const_loop_)) {
         record_.insert({var.get(), false});
         StmtExprVisitor::VisitStmt_(op);
         if (record_.at(var.get()) && !no_split_) {
diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc
index 3be2329..3c8a934 100644
--- a/src/tir/transforms/simplify.cc
+++ b/src/tir/transforms/simplify.cc
@@ -54,9 +54,18 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
     return Parent::VisitStmt_(op);
   }
 
+  bool CanInlineLetStmt(const LetStmtNode* op) {
+    if (is_const_number(op->value)) return true;
+    if (op->value.as<VarNode>()) return true;
+    // Won't face the deep expression explosion problem as in Let expression.
+    // Attempt to inline as much as possible if the value has integer type (can be an index).
+    if (!op->value.dtype().is_int()) return false;
+    return !tir::HasSideEffect(op->value);
+  }
+
   Stmt VisitStmt_(const LetStmtNode* op) {
     PrimExpr value = this->VisitExpr(op->value);
-    if (!tir::HasSideEffect(value)) {
+    if (CanInlineLetStmt(op)) {
       // it is fine to discard the let binding
       // because the call to simplify will always inline the var.
       analyzer_->Bind(op->var, value);
diff --git a/src/tir/transforms/split_host_device.cc 
b/src/tir/transforms/split_host_device.cc
index f339c56..75ae743 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -70,7 +70,7 @@ class VarUseDefAnalysis : public StmtExprMutator {
     this->HandleDef(op->var.get());
     Stmt body = this->VisitStmt(op->body);
     // eliminate unreferenced let
-    if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) {
+    if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && 
simplify_let_) {
       return body;
     } else {
       PrimExpr value = this->VisitExpr(op->value);
@@ -101,7 +101,7 @@ class VarUseDefAnalysis : public StmtExprMutator {
     this->HandleDef(op->var.get());
     PrimExpr body = this->VisitExpr(op->body);
     // eliminate unreferenced let
-    if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) {
+    if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && 
simplify_let_) {
       return body;
     } else {
       PrimExpr value = this->VisitExpr(op->value);
@@ -149,6 +149,7 @@ class VarUseDefAnalysis : public StmtExprMutator {
   // The fields are publically readible to
   // be accessible to the users.
   bool visit_thread_extent_{true};
+  bool simplify_let_{true};
   Array<Var> undefined_;
   Array<IterVar> thread_axis_;
   Array<PrimExpr> thread_extent_;
@@ -158,6 +159,7 @@ class VarUseDefAnalysis : public StmtExprMutator {
 
 Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
   VarUseDefAnalysis m;
+  m.simplify_let_ = false;
   for (Var arg : args) {
     m.use_count_[arg.get()] = 0;
   }
diff --git a/src/tir/transforms/vectorize_loop.cc 
b/src/tir/transforms/vectorize_loop.cc
index e015990..bf54ada 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -23,6 +23,7 @@
 // Loop vectorizer as in Halide pipeline.
 #include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
@@ -91,15 +92,21 @@ class VecAllocAccess : public StmtExprMutator {
   int var_lanes_;
 };
 
-class Vectorizer : public StmtExprMutator {
+// We use ExprFunctor directly instead of StmtExprMutator
+// This is because the transformation can change the dtype of the Expr
+// The existing ExprMutator transformation rules may not be well defined.
+class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const 
PrimExpr&)> {
  public:
+  using ExprFunctor::VisitExpr;
+  using StmtMutator::operator();
+
   Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
     ramp_ = Ramp(0, 1, var_lanes);
   }
 
   Stmt VisitStmt(const Stmt& stmt) final {
     CHECK(!need_scalarize_);
-    Stmt ret = StmtExprMutator::VisitStmt(stmt);
+    Stmt ret = StmtMutator::VisitStmt(stmt);
     if (need_scalarize_) {
       need_scalarize_ = false;
       return Scalarize(stmt);
@@ -108,6 +115,8 @@ class Vectorizer : public StmtExprMutator {
     }
   }
 
+  PrimExpr VisitExpr(const PrimExpr& e) final { return ExprFunctor::VisitExpr(e); }  // per-node dispatch via ExprFunctor
+
   PrimExpr VisitExpr_(const AddNode* op) final {
     return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
   }
@@ -151,6 +160,16 @@ class Vectorizer : public StmtExprMutator {
   PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec<GE>(op); }
   PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec<And>(op); }
   PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec<Or>(op); }
+
+  PrimExpr VisitExpr_(const NotNode* op) final {  // logical not: vectorize the single operand
+    PrimExpr a = this->VisitExpr(op->a);
+    if (a.same_as(op->a)) {
+      return GetRef<PrimExpr>(op);  // operand unchanged: reuse the original node
+    } else {
+      return !(a);  // rebuild around the transformed operand
+    }
+  }
+
   PrimExpr VisitExpr_(const RampNode* op) final {
     PrimExpr base = this->VisitExpr(op->base);
     PrimExpr stride = this->VisitExpr(op->stride);
@@ -170,6 +189,20 @@ class Vectorizer : public StmtExprMutator {
     }
     return Shuffle::Concat(elems);
   }
+
+  PrimExpr VisitExpr_(const BroadcastNode* op) final {
+    PrimExpr value = this->VisitExpr(op->value);
+    if (value.dtype().lanes() != 1) {
+      need_scalarize_ = true;  // operand became a vector: VisitStmt scalarizes the enclosing stmt
+      return GetRef<PrimExpr>(op);
+    }
+    if (value.same_as(op->value)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return Broadcast(value, op->lanes);  // rebuild from the visited value, not op->value
+    }
+  }
+
   PrimExpr VisitExpr_(const SelectNode* op) final {
     PrimExpr cond = this->VisitExpr(op->condition);
     PrimExpr t = this->VisitExpr(op->true_value);
@@ -189,14 +222,25 @@ class Vectorizer : public StmtExprMutator {
       return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
     }
   }
+  // Immediates (float, int, string) are returned unchanged.
+  PrimExpr VisitExpr_(const FloatImmNode* op) final { return GetRef<PrimExpr>(op); }
+
+  PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef<PrimExpr>(op); }
+
+  PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef<PrimExpr>(op); }
+
   // Variable
-  PrimExpr VisitExpr_(const VarNode* v) final {
-    if (v == var_.get()) {
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    Var var = GetRef<Var>(op);
+    // The vectorized loop var becomes ramp_; let-bound vars map to their recorded binding.
+    if (var.same_as(var_)) {
       return ramp_;
-    } else if (lets_.count(v)) {
-      return lets_[v];
+    }
+    auto it = let_binding_.find(var);
+    if (it != let_binding_.end()) {
+      return it->second;
     } else {
-      return GetRef<PrimExpr>(v);
+      return std::move(var);  // any other var is returned unchanged
     }
   }
   // IfThenElse expr
@@ -267,12 +311,28 @@ class Vectorizer : public StmtExprMutator {
   // Let
   PrimExpr VisitExpr_(const LetNode* op) final {
     PrimExpr value = this->VisitExpr(op->value);
-    CHECK(!lets_.count(op->var.get())) << "not SSA";
+    // Weaker SSA condition
+    // A single var can be bound in multiple lets; this allows reusing a single
+    // let expression to construct a nested expr:
+    // (let x = 1 in x + 1) * (let x = 1 in x + 1)
+    // Each binding is erased once its body is visited, so sibling lets never
+    // see each other; a var still bound here comes from an enclosing let.
+    // NOTE(review): let_binding_ maps to the replacement var, not the bound
+    // value, so this check effectively rejects such nested rebinding.
+    auto it = let_binding_.find(op->var);
+    if (it != let_binding_.end()) {
+      CHECK(deep_equal_(it->second, value))
+          << "Let cannot bind the same var to two different values";
+    }
     if (value.dtype().lanes() != op->value.dtype().lanes()) {
-      Var v(op->var->name_hint, value.dtype());
-      lets_[op->var.get()] = v;
-      return Let(v, value, this->VisitExpr(op->body));
+      Var new_var(op->var->name_hint, value.dtype());
+      let_binding_[op->var] = new_var;
+      PrimExpr body = this->VisitExpr(op->body);
+      let_binding_.erase(op->var);  // the binding is only visible inside the body
+      return Let(new_var, value, body);
     } else {
+      let_binding_[op->var] = op->var;
       PrimExpr body = this->VisitExpr(op->body);
+      let_binding_.erase(op->var);  // the binding is only visible inside the body
       if (value.same_as(op->value) && body.same_as(op->body)) {
         return GetRef<PrimExpr>(op);
@@ -281,10 +336,6 @@ class Vectorizer : public StmtExprMutator {
       }
     }
   }
-  Stmt VisitStmt_(const ProducerStoreNode* op) final {
-    LOG(FATAL) << "ProducerProvide is cannot appear in a TIR PrimFunc";
-    return Stmt();
-  }
   // Store
   Stmt VisitStmt_(const StoreNode* op) final {
     PrimExpr value = this->VisitExpr(op->value);
@@ -338,8 +389,23 @@ class Vectorizer : public StmtExprMutator {
   }
   // LetStmt
   Stmt VisitStmt_(const LetStmtNode* op) final {
-    LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize";
-    return Scalarize(GetRef<Stmt>(op));
+    PrimExpr value = this->VisitExpr(op->value);
+    CHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice";
+    // The binding is recorded in each branch below, before the body is visited.
+
+    if (value.dtype().lanes() != op->value.dtype().lanes()) {
+      Var new_var(op->var->name_hint, value.dtype());  // value was vectorized: widen the var
+      let_binding_[op->var] = new_var;
+      return LetStmt(new_var, value, this->VisitStmt(op->body));
+    } else {
+      let_binding_[op->var] = op->var;  // lanes unchanged: keep referring to the original var
+      Stmt body = this->VisitStmt(op->body);
+      if (value.same_as(op->value) && body.same_as(op->body)) {
+        return GetRef<Stmt>(op);
+      } else {
+        return LetStmt(op->var, value, body);
+      }
+    }
   }
   // Allocate
   Stmt VisitStmt_(const AllocateNode* op) final {
@@ -364,6 +430,7 @@ class Vectorizer : public StmtExprMutator {
     body = this->VisitStmt(body);
     return Allocate(op->buffer_var, op->dtype, extents, condition, body);
   }
+
   // scalarize the statment
   Stmt Scalarize(Stmt stmt) {
     Var idx(var_->name_hint + ".s", var_->dtype);
@@ -371,10 +438,17 @@ class Vectorizer : public StmtExprMutator {
     stmt = Substitute(stmt, values);
     return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
   }
+  // ProducerStore is rejected by this pass.
+  Stmt VisitStmt_(const ProducerStoreNode* op) final {
+    LOG(FATAL) << "ProducerStore cannot appear in a TIR PrimFunc";
+    return Stmt();  // not reached; LOG(FATAL) does not return normally
+  }
 
  private:
   // analyzer
   arith::Analyzer analyzer_;
+  // deep equal
+  ExprDeepEqual deep_equal_;
   // variable to be replaced
   Var var_;
   // the lanes.
@@ -383,8 +457,8 @@ class Vectorizer : public StmtExprMutator {
   PrimExpr ramp_;
   // flag to mark requirment of scalarization.
   bool need_scalarize_{false};
-  // The lets
-  std::unordered_map<const VarNode*, PrimExpr> lets_;
+  // Let binding
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
let_binding_;
   // vectorizable property
   OpAttrMap<TVectorizable> op_vectorizable_ = 
Op::GetAttrMap<TVectorizable>("TVectorizable");
 
diff --git a/tests/python/unittest/test_arith_const_int_bound.py 
b/tests/python/unittest/test_arith_const_int_bound.py
index 4829b97..c5794cd 100644
--- a/tests/python/unittest/test_arith_const_int_bound.py
+++ b/tests/python/unittest/test_arith_const_int_bound.py
@@ -284,7 +284,16 @@ def test_size_var_bound():
     assert bd.max_value == bd.POS_INF
 
 
+def test_let_bound():
+    analyzer = tvm.arith.Analyzer()
+    x = te.var("x")
+    bd = analyzer.const_int_bound(tvm.tir.Let(x, 1, x + 1))  # let x = 1 in x + 1 is exactly 2
+    assert bd.min_value == 2
+    assert bd.max_value == 2
+
+
 if __name__ == "__main__":
+    test_let_bound()
     test_dtype_bound()
     test_cast_bound()
     test_add_sub_bound()
diff --git a/tests/python/unittest/test_arith_modular_set.py 
b/tests/python/unittest/test_arith_modular_set.py
index 01180d2..7d9f739 100644
--- a/tests/python/unittest/test_arith_modular_set.py
+++ b/tests/python/unittest/test_arith_modular_set.py
@@ -159,8 +159,17 @@ def test_intersect():
                 assert m.coeff == 105
                 assert m.base == 23
 
+def test_let():
+    analyzer = tvm.arith.Analyzer()
+    x = te.var("x")
+    y = te.var("y")
+    m = analyzer.modular_set(tvm.tir.Let(x, y * 10, x + 1))  # x = 10*y, so x + 1 == 1 (mod 10)
+    assert m.coeff == 10
+    assert m.base == 1
+
 
 if __name__ == "__main__":
+    test_let()
     test_cast()
     test_add_sub()
     test_mul()
diff --git a/tests/python/unittest/test_tir_analysis_verify_ssa.py 
b/tests/python/unittest/test_tir_analysis_verify_ssa.py
index 8a15c36..57dd826 100644
--- a/tests/python/unittest/test_tir_analysis_verify_ssa.py
+++ b/tests/python/unittest/test_tir_analysis_verify_ssa.py
@@ -27,6 +27,16 @@ def test_verify_ssa():
     assert(not tvm.tir.analysis.verify_ssa(
         tvm.tir.PrimFunc([x, y], tvm.tir.LetStmt(x, 1, z))))
 
+def test_verify_weak_let_ssa():
+    x = te.var('x')
+    z1 = tvm.tir.Let(x, 1, x + 1)
+    z2 = tvm.tir.Let(x, 2, x + 2)  # binds the same var x to a different value
+
+    assert(tvm.tir.analysis.verify_ssa(
+        tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 + z1))))  # reusing one let is accepted
+    assert(not tvm.tir.analysis.verify_ssa(
+        tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 * z2))))  # conflicting bindings are rejected
 
 if __name__ == "__main__":
     test_verify_ssa()
+    test_verify_weak_let_ssa()
diff --git a/tests/python/unittest/test_tir_nodes.py 
b/tests/python/unittest/test_tir_nodes.py
index ab730cd..c182d9e 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -274,7 +274,8 @@ def test_prim_func():
 
     func = tvm.tir.PrimFunc(
         [x, y, b], stmt)
-
+    # make sure we can print
+    func.astext()
     assert func.buffer_map[func.params[2]].same_as(b)
 
     assert len(func.buffer_map) == 1
diff --git a/tests/python/unittest/test_tir_transform_vectorize.py 
b/tests/python/unittest/test_tir_transform_vectorize.py
index a69c9d3..0516b4a 100644
--- a/tests/python/unittest/test_tir_transform_vectorize.py
+++ b/tests/python/unittest/test_tir_transform_vectorize.py
@@ -81,6 +81,20 @@ def test_vectorize_with_if():
     assert isinstance(stmt.else_case, tvm.tir.For)
 
 
+def test_vectorize_let():
+    v = tvm.tir.Var("v", "float32")
+    ib = tvm.tir.ir_builder.create()
+    A = ib.pointer("float32", name="A")
+    with ib.for_range(0, 4, for_type="vectorize") as i:
+        ib.emit(lambda body: tvm.tir.LetStmt(v, A[i] + 1, body))  # the let value depends on the loop var
+        A[i] = v + 2
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], ib.get()))
+    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+    assert isinstance(stmt, tvm.tir.LetStmt)  # the LetStmt is kept rather than scalarized into a loop
+    assert stmt.value.dtype == "float32x4"  # and its value is widened to 4 lanes
+
+
 def test_vectorize_with_le_cond():
     n = te.var('n')
     ib = tvm.tir.ir_builder.create()
@@ -153,3 +167,4 @@ if __name__ == "__main__":
     test_vectorize_if_then_else()
     test_vectorize_with_le_cond()
     test_vectorize_with_ge_cond()
+    test_vectorize_let()

Reply via email to