mbs-octoml commented on a change in pull request #9542:
URL: https://github.com/apache/tvm/pull/9542#discussion_r754747093



##########
File path: src/relay/transforms/dead_code.cc
##########
@@ -18,158 +18,565 @@
  */
 
 /*!
+ * \file src/relay/transforms/dead_code.cc
+ * \brief Elides or inlines let-bindings.
  *
- * \file dead_code.cc
- *
- * \brief Remove code that does not effect the program result.
- *
- * The algorithm is implemented by two visitor:
- * CalcDep turn an expr into a dependency graph of expr,
- * GenLet turn the dependency graph into a let list, taking only the used value.
+ * TODO(mbs): Track dead writes into references.
  */
+
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/transform.h>
 
-#include "let_list.h"
+#include "../op/call/call.h"
 
 namespace tvm {
 namespace relay {
+namespace {
 
-template <typename X>
-using VarMap = std::unordered_map<Var, X, ObjectPtrHash, ObjectPtrEqual>;
-using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+/*! \brief Maximum depth of calls to analyze. */
+constexpr int kMaxCallDepth = 25;
+
+/*!
+ * \brief Captures (an approximation of) the purity for a Relay sub-expression. A pure
+ * sub-expression is guaranteed never to access or mutate state. Thus the sub-expression
+ * can safely be elided (if its result is never used), or inlined (which may change the
+ * number of times, and the program order in which, it is evaluated).
+struct Purity {
+  /*!
+   * \brief True if evaling the sub-expression itself is pure.
+   */
+  bool pure_eval;
+  /*!
+   * \brief If the sub-expression is first-order then always true. Otherwise true only if evaling
+   * a call to the sub-expression is pure. See [RULE A] below.
+   */
+  bool pure_call;
+};
+
+/*!
+ * \brief Visits all the global functions in a module and records the purity of every let-bound
+ * value.
+ *
+ * (See also inline.cc for function inlining.)
+ *
+ * Generally we track whether evaluation of a sub-expression is definitely pure. However for
+ * sub-expressions f of higher-order type we also track the 'call purity' of evaling a call to f:
+ *  - [RULE A] If f's result is itself higher-order then f is call-pure only if the result of f is
+ *    also call-pure.
+ *  - [RULE B] Higher-order function arguments are assumed call impure.
+ *  - [RULE C] We assume functions extracted from tuples are call impure.
+ *  - [RULE D] We assume functions extracted from references are call impure.
+ *  - [RULE E] We assume functions extracted from ADTs are call impure.
+ *  - [RULE F] We assume all external Functions and PrimFuncs are call impure.
+ */
+class PurityVisitor : ExprFunctor<Purity(const Expr&)> {
+ public:
+  explicit PurityVisitor(IRModule mod) : mod_(std::move(mod)), current_call_depth_(0) {}
+
+  /*! \brief Visit all the functions in the module. */
+  void VisitModule() {
+    VLOG_CONTEXT << "PurityVisitor";
+    // It is safe to visit the global functions in any order. Recursive global functions are
+    // allowed.
+    for (const auto& kv : mod_->functions) {
+      if (const auto* function_node = kv.second.as<FunctionNode>()) {
+        if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
+            function_node->GetAttr<String>(attr::kExternalSymbol)) {
+          // Ignore primitive and external functions.
+          continue;
+        }
+        // Everything of interest will be recorded in the purity maps so we ignore the result.
+        (void)VisitGlobalFunction(kv.first, GetRef<Function>(function_node));
+      }
+    }
+  }
+
+  /*!
+   * \brief Returns a map from every let-bound variable to whether its let-bound value is
+   * definitely pure.
+   */
+  std::unordered_map<const VarNode*, bool> GetPurityMap() const {
+    std::unordered_map<const VarNode*, bool> result;
+    for (const auto& kv : var_to_purity_) {
+      result.emplace(kv.first, kv.second.pure_eval);
+    }
+    return result;
+  }
 
-class CalcDep;
-class FindDef : private ExprVisitor {
  private:
-  VarMap<Expr> expr_map_;
+  Purity VisitExpr(const Expr& expr) final {
+    auto it = memo_.find(expr.get());
+    if (it != this->memo_.end()) {
+      return it->second;
+    } else {
+      Purity result = ExprFunctor::VisitExpr(expr);
+      memo_[expr.get()] = result;
+      return result;
+    }
+  }
 
-  void VisitExpr_(const LetNode* l) final {
-    auto pre_visit = [this](const LetNode* op) {
-      ICHECK_EQ(expr_map_.count(op->var), 0);
-      expr_map_[op->var] = op->value;
-      this->VisitExpr(op->value);
-    };
-    auto post_visit = [this](const LetNode* op) {
-      this->VisitExpr(op->body);
-      this->visit_counter_[op] += 1;
-    };
-    ExpandANormalForm(l, pre_visit, post_visit);
+  Purity VisitExpr_(const ConstantNode*) final { return {/*pure_eval=*/true, /*pure_call=*/true}; }
+
+  Purity VisitExpr_(const ConstructorNode*) final {
+    return {/*pure_eval=*/true, /*pure_call=*/true};
+  }
+
+  Purity VisitExpr_(const OpNode* op_node) final {
+    // Primitive operators are pure unless marked as 'stateful'.
+    static OpAttrMap<bool> attr_map = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
+    bool is_statefull = attr_map.count(GetRef<Op>(op_node)) && attr_map[GetRef<Op>(op_node)];

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to