electriclilies commented on a change in pull request #9542:
URL: https://github.com/apache/tvm/pull/9542#discussion_r754716450
##########
File path: include/tvm/relay/transform.h
##########
@@ -540,7 +550,7 @@ TVM_DLL Function ToCPS(const Function& f, const IRModule&
mod);
/*!
* \brief Remove the continuation argument of a CPS function.
*
- * Note that this only transform the type back into un-CPS form
+ * Note that this only transform the type back into un-CPS formA
Review comment:
You added an "A" here by accident :)
##########
File path: src/relay/transforms/dead_code.cc
##########
@@ -18,158 +18,565 @@
*/
/*!
+ * \file src/relay/transforms/dead_code.cc
+ * \brief Elides or inlines let-bindings.
*
- * \file dead_code.cc
- *
- * \brief Remove code that does not effect the program result.
- *
- * The algorithm is implemented by two visitor:
- * CalcDep turn an expr into a dependency graph of expr,
- * GenLet turn the dependency graph into a let list, taking only the used
value.
+ * TODO(mbs): Track dead writes into references.
*/
+
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/transform.h>
-#include "let_list.h"
+#include "../op/call/call.h"
namespace tvm {
namespace relay {
+namespace {
-template <typename X>
-using VarMap = std::unordered_map<Var, X, ObjectPtrHash, ObjectPtrEqual>;
-using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+/*! \brief Maximum depth of calls to analyize. */
+constexpr int kMaxCallDepth = 25;
+
+/*!
+ * \brief Captures (an approximation of) the purity for a Relay
sub-expression. A pure
+ * sub-expression is guaranteed never to access or mutate state. Thus the
sub-expression
+ * can safely be elided (if its result is never used), or inlined (which may
change the
+ * number of times and program order for the evaluation.)
+ */
+struct Purity {
+ /*!
+ * \brief True if evaling the sub-expression itself is pure.
+ */
+ bool pure_eval;
+ /*!
+ * \brief If the sub-expression is first-order then always true. Otherwise
true only if evaling
+ * a call to the the sub-expression is pure. See [RULE A] below.
+ */
+ bool pure_call;
+};
+
+/*!
+ * \brief Visits all the global functions in a module and records the purity
of every let-bound
+ * value.
+ *
+ * (See also inline.cc for function inlining.)
+ *
+ * Generally we track whether evaluation of a sub-expression is definitely
pure. However for
+ * sub-expressions f of higher-order type we also track the 'call purity' of
evaling a call to f:
+ * - [RULE A] If f's result is itself higher-order then f is call-pure only
if the result of f is
+ * also call-pure.
+ * - [RULE B] Higher-order function arguments are assumed call impure.
+ * - [RULE C] We assume functions extracted from tuples are call impure.
+ * - [RULE D] We assume functions extracted from references are call impure.
+ * - [RULE E] We assume functions extracted from ADTs are call impure.
+ * - [RULE F] We assume all external Functions and PrimFuncs are call impure.
+ */
+class PurityVisitor : ExprFunctor<Purity(const Expr&)> {
+ public:
+ explicit PurityVisitor(IRModule mod) : mod_(std::move(mod)),
current_call_depth_(0) {}
+
+ /*! \brief Visit all the functions in the module. */
+ void VisitModule() {
+ VLOG_CONTEXT << "PurityVisitor";
+ // It is safe to visit the global functions in any order. Recursive global
functions are
+ // allowed.
+ for (const auto& kv : mod_->functions) {
+ if (const auto* function_node = kv.second.as<FunctionNode>()) {
+ if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
+ function_node->GetAttr<String>(attr::kExternalSymbol)) {
+ // Ignore primitive and external functions.
+ continue;
+ }
+ // Everything of interest will be recorded in the purity maps so we
ignore the result.
+ (void)VisitGlobalFunction(kv.first, GetRef<Function>(function_node));
+ }
+ }
+ }
+
+ /*!
+ * \brief Returns a map from every let-bound variable to whether its
let-bound value is
+ * definitely pure.
+ */
+ std::unordered_map<const VarNode*, bool> GetPurityMap() const {
+ std::unordered_map<const VarNode*, bool> result;
+ for (const auto& kv : var_to_purity_) {
+ result.emplace(kv.first, kv.second.pure_eval);
+ }
+ return result;
+ }
-class CalcDep;
-class FindDef : private ExprVisitor {
private:
- VarMap<Expr> expr_map_;
+ Purity VisitExpr(const Expr& expr) final {
+ auto it = memo_.find(expr.get());
+ if (it != this->memo_.end()) {
+ return it->second;
+ } else {
+ Purity result = ExprFunctor::VisitExpr(expr);
+ memo_[expr.get()] = result;
+ return result;
+ }
+ }
- void VisitExpr_(const LetNode* l) final {
- auto pre_visit = [this](const LetNode* op) {
- ICHECK_EQ(expr_map_.count(op->var), 0);
- expr_map_[op->var] = op->value;
- this->VisitExpr(op->value);
- };
- auto post_visit = [this](const LetNode* op) {
- this->VisitExpr(op->body);
- this->visit_counter_[op] += 1;
- };
- ExpandANormalForm(l, pre_visit, post_visit);
+ Purity VisitExpr_(const ConstantNode*) final { return {/*pure_eval=*/true,
/*pure_call=*/true}; }
+
+ Purity VisitExpr_(const ConstructorNode*) final {
+ return {/*pure_eval=*/true, /*pure_call=*/true};
+ }
+
+ Purity VisitExpr_(const OpNode* op_node) final {
+ // Primitive operators are pure unless marked as 'stateful'.
+ static OpAttrMap<bool> attr_map =
Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
+ bool is_statefull = attr_map.count(GetRef<Op>(op_node)) &&
attr_map[GetRef<Op>(op_node)];
+ return {/*pure_eval=*/true, /*pure_call=*/!is_statefull};
+ }
+
+ Purity VisitExpr_(const GlobalVarNode* global_var_node) final {
+ auto global_var = GetRef<GlobalVar>(global_var_node);
+ auto func = mod_->Lookup(global_var);
+ if (const auto* function_node = func.as<FunctionNode>()) {
+ if (!function_node->GetAttr<String>(attr::kExternalSymbol)) {
+ return VisitGlobalFunction(global_var,
GetRef<Function>(function_node));
+ }
+ }
+ // Assume externals and PrimFuncs are call-impure [RULE F].
+ // (If they are pure then we should have dealt with them before lowering.)
+ return {/*pure_eval==*/true, /*pure_call=*/false};
+ }
+
+ Purity VisitExpr_(const VarNode* var_node) final {
+ // The var is bound to a value, but if that value is a function we need to
propagate the
+ // function body's purity.
+ ICHECK(var_to_purity_.count(var_node)) <<
PrettyPrint(GetRef<Var>(var_node));
+ return {/*pure_eval=*/true,
/*pure_call=*/var_to_purity_[var_node].pure_call};
+ }
+
+ Purity VisitExpr_(const FunctionNode* function_node) final {
+ for (const auto& param : function_node->params) {
+ // Any higher-order parameters are assumed to be call-impure [RULE B]
+ var_to_purity_[param.get()] = {/*pure_eval=*/true,
/*pure_call=*/IsFirstOrder(param)};
+ }
+ Purity body_purity = VisitExpr(function_node->body);
+ // The function itself is a value and thus pure. If the function returns
+ // a function we'll fold its purity in here [RULE A]
+ return {/*pure_eval=*/true, /*pure_call=*/body_purity.pure_eval &&
body_purity.pure_call};
+ }
+
+ Purity VisitExpr_(const LetNode* let_node) final {
+ Expr expr = GetRef<Expr>(let_node);
+ bool all_values_pure_eval = true;
+ while (const auto* inner_let_node = expr.as<LetNode>()) {
+ // In case the value is a recursive function assume the let-bound
variable is call-pure.
+ var_to_purity_[inner_let_node->var.get()] = {/*pure_eval=*/true,
/*pure_call=*/true};
+ Purity value_purity = VisitExpr(inner_let_node->value);
+ // Now revise the variable to it's true purity.
+ var_to_purity_[inner_let_node->var.get()] = value_purity;
+ VLOG(2) << (value_purity.pure_eval ? "pure" : "impure") << "
expression:" << std::endl
+ << PrettyPrint(inner_let_node->value) << std::endl
+ << "let-bound to variable:" << std::endl
+ << PrettyPrint(inner_let_node->var);
+ all_values_pure_eval = all_values_pure_eval && value_purity.pure_eval;
+ expr = inner_let_node->body;
+ }
+ Purity body_purity = VisitExpr(expr);
+ return {/*pure_eval=*/all_values_pure_eval && body_purity.pure_eval,
+ /*pure_call=*/body_purity.pure_call};
+ }
+
+ Purity VisitExpr_(const CallNode* call_node) final {
+ if (current_call_depth_ >= kMaxCallDepth) {
+ // Assume impure.
+ VLOG(2) << "assuming call is impure since too deeply nested";
+ return {/*pure_eval=*/false, /*pure_call*/
IsFirstOrder(GetRef<Call>(call_node))};
+ }
+
+ ++current_call_depth_;
+
+ // We can work with the call in both pre- and post-lowered form.
+ Expr callee;
+ Array<Expr> args;
+ if (call_node->op == CallLoweredOp()) {
+ CallLoweredProps props = GetCallLoweredProps(call_node);
+ callee = props.lowered_func;
+ args = props.arguments;
+ } else {
+ callee = call_node->op;
+ args = call_node->args;
+ }
+
+ // Find purity for the callee and the args.
+ Purity callee_purity = VisitExpr(callee);
+ bool all_args_pure_eval = true;
+ for (const auto& arg : args) {
+ Purity arg_purity = VisitExpr(arg);
+ all_args_pure_eval = all_args_pure_eval && arg_purity.pure_eval;
+ }
+
+ VLOG(2) << (callee_purity.pure_call ? "pure" : "impure") << " call to:" <<
std::endl
+ << PrettyPrint(callee);
+
+ ICHECK_GT(current_call_depth_, 0);
+ --current_call_depth_;
+
+ // If the callee's result is itself a function then by [RULE A] its purity
+ // is given by callee_purity.pure_call.
+ return {/*pure_eval=*/all_args_pure_eval && callee_purity.pure_eval &&
callee_purity.pure_call,
+ /*pure_call=*/IsFirstOrder(GetRef<Call>(call_node)) ||
callee_purity.pure_call};
+ }
+
+ Purity VisitExpr_(const IfNode* if_node) final {
+ Purity cond_purity = VisitExpr(if_node->cond);
+ ICHECK(cond_purity.pure_call); // conditional is first-order
+ Purity true_purity = VisitExpr(if_node->true_branch);
+ Purity false_purity = VisitExpr(if_node->false_branch);
+ return {/*pure_eval=*/cond_purity.pure_eval && true_purity.pure_eval &&
false_purity.pure_eval,
+ /*pure_call=*/true_purity.pure_call && false_purity.pure_call};
+ }
+
+ Purity VisitExpr_(const TupleNode* tuple_node) final {
+ bool all_fields_pure = true;
+ for (const auto& field : tuple_node->fields) {
+ // The call purity of each tuple field is lost [RULE C].
+ Purity field_purity = VisitExpr(field);
+ if (!field_purity.pure_eval) {
+ all_fields_pure = false;
+ }
+ }
+ return {/*pure_eval=*/all_fields_pure, /*pure_call=*/true};
+ }
+
+ Purity VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+ Purity tuple_purity = VisitExpr(tuple_get_item_node->tuple);
+ ICHECK(tuple_purity.pure_call); // tuple is first-order
+ // We don't track call purity through tuple fields, so if the result is a
function type we
+ // must assume it is call impure [RULE C].
+ return {/*pure_eval=*/tuple_purity.pure_eval,
+
/*pure_call=*/IsFirstOrder(GetRef<TupleGetItem>(tuple_get_item_node))};
+ }
+
+ Purity VisitExpr_(const RefCreateNode*) final {
+ // The creation of the ref itself is unobservable other than via the
reads/writes into it.
+ return {/*pure_eval=*/true, /*pure_call=*/true};
+ }
+
+ Purity VisitExpr_(const RefWriteNode* ref_write_node) final {
+ Purity ref_purity = VisitExpr(ref_write_node->ref);
+ ICHECK(ref_purity.pure_call); // reference is first-order
+ // The call purity of the written value is lost [RULE D].
Review comment:
If the call purity of the written value is lost, why do we need to visit
it?
##########
File path: src/relay/transforms/dead_code.cc
##########
@@ -18,158 +18,565 @@
*/
/*!
+ * \file src/relay/transforms/dead_code.cc
+ * \brief Elides or inlines let-bindings.
*
- * \file dead_code.cc
- *
- * \brief Remove code that does not effect the program result.
- *
- * The algorithm is implemented by two visitor:
- * CalcDep turn an expr into a dependency graph of expr,
- * GenLet turn the dependency graph into a let list, taking only the used
value.
+ * TODO(mbs): Track dead writes into references.
*/
+
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/transform.h>
-#include "let_list.h"
+#include "../op/call/call.h"
namespace tvm {
namespace relay {
+namespace {
-template <typename X>
-using VarMap = std::unordered_map<Var, X, ObjectPtrHash, ObjectPtrEqual>;
-using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+/*! \brief Maximum depth of calls to analyize. */
+constexpr int kMaxCallDepth = 25;
+
+/*!
+ * \brief Captures (an approximation of) the purity for a Relay
sub-expression. A pure
+ * sub-expression is guaranteed never to access or mutate state. Thus the
sub-expression
+ * can safely be elided (if its result is never used), or inlined (which may
change the
+ * number of times and program order for the evaluation.)
+ */
+struct Purity {
+ /*!
+ * \brief True if evaling the sub-expression itself is pure.
+ */
+ bool pure_eval;
+ /*!
+ * \brief If the sub-expression is first-order then always true. Otherwise
true only if evaling
+ * a call to the the sub-expression is pure. See [RULE A] below.
+ */
+ bool pure_call;
+};
+
+/*!
+ * \brief Visits all the global functions in a module and records the purity
of every let-bound
+ * value.
+ *
+ * (See also inline.cc for function inlining.)
+ *
+ * Generally we track whether evaluation of a sub-expression is definitely
pure. However for
+ * sub-expressions f of higher-order type we also track the 'call purity' of
evaling a call to f:
+ * - [RULE A] If f's result is itself higher-order then f is call-pure only
if the result of f is
+ * also call-pure.
+ * - [RULE B] Higher-order function arguments are assumed call impure.
+ * - [RULE C] We assume functions extracted from tuples are call impure.
+ * - [RULE D] We assume functions extracted from references are call impure.
+ * - [RULE E] We assume functions extracted from ADTs are call impure.
+ * - [RULE F] We assume all external Functions and PrimFuncs are call impure.
+ */
Review comment:
This description is very nice.
##########
File path: src/relay/transforms/dead_code.cc
##########
@@ -18,158 +18,565 @@
*/
/*!
+ * \file src/relay/transforms/dead_code.cc
+ * \brief Elides or inlines let-bindings.
*
- * \file dead_code.cc
- *
- * \brief Remove code that does not effect the program result.
- *
- * The algorithm is implemented by two visitor:
- * CalcDep turn an expr into a dependency graph of expr,
- * GenLet turn the dependency graph into a let list, taking only the used
value.
+ * TODO(mbs): Track dead writes into references.
*/
+
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/transform.h>
-#include "let_list.h"
+#include "../op/call/call.h"
namespace tvm {
namespace relay {
+namespace {
-template <typename X>
-using VarMap = std::unordered_map<Var, X, ObjectPtrHash, ObjectPtrEqual>;
-using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+/*! \brief Maximum depth of calls to analyize. */
+constexpr int kMaxCallDepth = 25;
+
+/*!
+ * \brief Captures (an approximation of) the purity for a Relay
sub-expression. A pure
+ * sub-expression is guaranteed never to access or mutate state. Thus the
sub-expression
+ * can safely be elided (if its result is never used), or inlined (which may
change the
+ * number of times and program order for the evaluation.)
+ */
+struct Purity {
+ /*!
+ * \brief True if evaling the sub-expression itself is pure.
+ */
+ bool pure_eval;
+ /*!
+ * \brief If the sub-expression is first-order then always true. Otherwise
true only if evaling
+ * a call to the the sub-expression is pure. See [RULE A] below.
+ */
+ bool pure_call;
+};
+
+/*!
+ * \brief Visits all the global functions in a module and records the purity
of every let-bound
+ * value.
+ *
+ * (See also inline.cc for function inlining.)
+ *
+ * Generally we track whether evaluation of a sub-expression is definitely
pure. However for
+ * sub-expressions f of higher-order type we also track the 'call purity' of
evaling a call to f:
+ * - [RULE A] If f's result is itself higher-order then f is call-pure only
if the result of f is
+ * also call-pure.
+ * - [RULE B] Higher-order function arguments are assumed call impure.
+ * - [RULE C] We assume functions extracted from tuples are call impure.
+ * - [RULE D] We assume functions extracted from references are call impure.
+ * - [RULE E] We assume functions extracted from ADTs are call impure.
+ * - [RULE F] We assume all external Functions and PrimFuncs are call impure.
+ */
+class PurityVisitor : ExprFunctor<Purity(const Expr&)> {
+ public:
+ explicit PurityVisitor(IRModule mod) : mod_(std::move(mod)),
current_call_depth_(0) {}
+
+ /*! \brief Visit all the functions in the module. */
+ void VisitModule() {
+ VLOG_CONTEXT << "PurityVisitor";
+ // It is safe to visit the global functions in any order. Recursive global
functions are
+ // allowed.
+ for (const auto& kv : mod_->functions) {
+ if (const auto* function_node = kv.second.as<FunctionNode>()) {
+ if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
+ function_node->GetAttr<String>(attr::kExternalSymbol)) {
+ // Ignore primitive and external functions.
+ continue;
+ }
+ // Everything of interest will be recorded in the purity maps so we
ignore the result.
+ (void)VisitGlobalFunction(kv.first, GetRef<Function>(function_node));
+ }
+ }
+ }
+
+ /*!
+ * \brief Returns a map from every let-bound variable to whether its
let-bound value is
+ * definitely pure.
+ */
+ std::unordered_map<const VarNode*, bool> GetPurityMap() const {
+ std::unordered_map<const VarNode*, bool> result;
+ for (const auto& kv : var_to_purity_) {
+ result.emplace(kv.first, kv.second.pure_eval);
+ }
+ return result;
+ }
-class CalcDep;
-class FindDef : private ExprVisitor {
private:
- VarMap<Expr> expr_map_;
+ Purity VisitExpr(const Expr& expr) final {
+ auto it = memo_.find(expr.get());
+ if (it != this->memo_.end()) {
+ return it->second;
+ } else {
+ Purity result = ExprFunctor::VisitExpr(expr);
+ memo_[expr.get()] = result;
+ return result;
+ }
+ }
- void VisitExpr_(const LetNode* l) final {
- auto pre_visit = [this](const LetNode* op) {
- ICHECK_EQ(expr_map_.count(op->var), 0);
- expr_map_[op->var] = op->value;
- this->VisitExpr(op->value);
- };
- auto post_visit = [this](const LetNode* op) {
- this->VisitExpr(op->body);
- this->visit_counter_[op] += 1;
- };
- ExpandANormalForm(l, pre_visit, post_visit);
+ Purity VisitExpr_(const ConstantNode*) final { return {/*pure_eval=*/true,
/*pure_call=*/true}; }
+
+ Purity VisitExpr_(const ConstructorNode*) final {
+ return {/*pure_eval=*/true, /*pure_call=*/true};
+ }
+
+ Purity VisitExpr_(const OpNode* op_node) final {
+ // Primitive operators are pure unless marked as 'stateful'.
+ static OpAttrMap<bool> attr_map =
Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
+ bool is_statefull = attr_map.count(GetRef<Op>(op_node)) &&
attr_map[GetRef<Op>(op_node)];
Review comment:
Spelling: statefull -> stateful
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]