This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9cc1df6  [AMP][Pass][Typing] Add faster type inference (#9735)
9cc1df6 is described below

commit 9cc1df60701d6d46577d028c211bb568225fc4f1
Author: AndrewZhaoLuo <[email protected]>
AuthorDate: Mon Jan 3 19:24:45 2022 -0800

    [AMP][Pass][Typing] Add faster type inference (#9735)
    
    * reuse checked types
    
    * analogous subgraph
    
    * brr go fast
    
    * clean up src logs
    
    * clean up PR more
    
    * more clean up
    
    * more documenetation
    
    * clean up
    
    * formatting
    
    * rename fast --> local
    
    * more ocmments
    
    * jostle ci
    
    * type inference
    
    * change comment for SameTypedSubgraphExtractor
    
    * get_analogous_expression -> GetAnalogousExpression
    
    * comment in GetAnaalogousExpression
    
    * add comment
    
    * replace infer tests
    
    * jostle
---
 include/tvm/relay/transform.h              |  17 ++++-
 python/tvm/relay/transform/transform.py    |  19 ++++++
 src/relay/transforms/to_mixed_precision.cc |  28 ++++++--
 src/relay/transforms/type_infer.cc         | 106 +++++++++++++++++++++++++++++
 tests/python/relay/test_type_infer.py      |  27 ++++----
 5 files changed, 176 insertions(+), 21 deletions(-)

diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index dfc49cb..3100f14 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -250,7 +250,7 @@ TVM_DLL Pass DynamicToStatic();
 /*!
  * \brief Infer the type of an expression.
  *
- * The result of type checking is a new expression with unambigous
+ * The result of type checking is a new expression with unambiguous
  * type information filled in, as well as it's checked type field
  * populated with the result type.
  *
@@ -259,6 +259,21 @@ TVM_DLL Pass DynamicToStatic();
 TVM_DLL Pass InferType();
 
 /*!
+ * \brief Infer the type of an expression, reusing existing type information.
+ *
+ * The result of type checking is a new expression with unambiguous
+ * type information filled in for the given node only. The local
+ * version can use existing type information populated throughout
+ * the expression and assumes this information is correct. The local
+ * version also avoids examining large amounts of the graph assuming
+ * type information is filled in properly which makes it much faster if we
+ * iteratively call type inference.
+ *
+ * \return The type of the expression.
+ */
+TVM_DLL Type InferTypeLocal(const Expr& expr);
+
+/*!
  * \brief Search and eliminate common subexpression. For example, if there are
  * two expressions evaluated to an identical value, a single variable is 
created
  * and these two expressions are replaced by this variable.
diff --git a/python/tvm/relay/transform/transform.py 
b/python/tvm/relay/transform/transform.py
index bbe4bc2..0123ca0 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -99,6 +99,25 @@ def InferType():
     return _ffi_api.InferType()
 
 
+def InferTypeLocal(expr):
+    """Infer the type of a single expr, reusing type information to do so.
+
+    This populates the checked_type field in expr. We assume existing type 
information
+    in the graph is correct!
+
+    Parameters
+    ----------
+    expr: relay.Expr
+        The expression we want to know the type of
+
+    Returns
+    -------
+    type: relay.Type
+        The type of the expression
+    """
+    return _ffi_api.InferTypeLocal(expr)
+
+
 def FoldScaleAxis():
     """Fold the scaling of axis into weights of conv2d/dense. This pass will
     invoke both forward and backward scale folding.
diff --git a/src/relay/transforms/to_mixed_precision.cc 
b/src/relay/transforms/to_mixed_precision.cc
index ae10c93..d8c7aa2 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -176,13 +176,17 @@ class MixedPrecisionPass : public MixedModeMutator {
   }
 
   Type GetType(const Expr& expr) const {
-    auto mod = IRModule::FromExpr(expr);
-    mod = transform::InferType()(mod);
-    if (expr.as<FunctionNode>()) {
-      return mod->Lookup("main")->checked_type();
-    } else {
-      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    // The expression has not been changed AND its existing type
+    // is known to still be valid. (See the special handling for tuples etc.
+    // below, where we null out checked_type_ when we cannot be
+    // sure it is still valid.)
+    Type checked_type = expr->checked_type_;
+    if (checked_type.defined()) {
+      return checked_type;
     }
+
+    // This also populates the checked_type_ field for expr
+    return transform::InferTypeLocal(expr);
   }
 
   bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) 
const {
@@ -381,6 +385,18 @@ class MixedPrecisionPass : public MixedModeMutator {
     return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, 
pre_call_node->span);
   }
 
+  Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
+    // The old checked type in the expression may not be valid so clear it
+    post->checked_type_ = Type(nullptr);
+    return post;
+  }
+
+  Expr Rewrite_(const TupleNode* pre, const Expr& post) {
+    // The old checked type in the expression may not be valid so clear it
+    post->checked_type_ = Type(nullptr);
+    return post;
+  }
+
   Expr VisitExpr_(const FunctionNode* func) final {
     // Erase the ret_type annotation and let the normal pass recalculate
     const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
diff --git a/src/relay/transforms/type_infer.cc 
b/src/relay/transforms/type_infer.cc
index 22bc8f3..456e210 100644
--- a/src/relay/transforms/type_infer.cc
+++ b/src/relay/transforms/type_infer.cc
@@ -824,8 +824,114 @@ void AddGlobalTypes(IRModule mod) {
   }
 }
 
+/*!
+ * \brief Returns a possibly much smaller subgraph whose inner nodes have the 
same type.
+ *
+ * Returns the largest sub-graph whose inner nodes need types and whose leaves are 
vars standing in
+ * for already typed sub-expressions. This creates a graph whose inner nodes 
have the same
+ * type as the original graph, so that when running type inference we can avoid 
copying and
+ * recursing through most of the expression graph. 
Note, this assumes
+ * that the currently populated type information is correct!
+ *
+ * ExprMutator is sufficient over MixedModeMutator since we will not recurse 
much.
+ */
+class SameTypedSubgraphExtractor : public ExprMutator {
+  Expr VisitExpr_(const VarNode* op) { return Var(op->vid, 
op->type_annotation, op->span); }
+  Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, 
op->span); }
+  Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
+  Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
+  Expr VisitExpr_(const TupleNode* op) {
+    return Tuple(GetAnalogousExpression(op->fields), op->span);
+  }
+  Expr VisitExpr_(const FunctionNode* op) {
+    // Unfortunately our strategy of inserting variables as dummies would 
change the signature of
+    // existing function nodes so we have to copy all used functions always :/
+    return Function(op->params, op->body, op->ret_type, op->type_params, 
op->attrs, op->span);
+  }
+  Expr VisitExpr_(const CallNode* op) {
+    return Call(op->op, GetAnalogousExpression(op->args), op->attrs, 
op->type_args, op->span);
+  }
+  Expr VisitExpr_(const LetNode* op) {
+    return Let(op->var, GetAnalogousExpression(op->value), 
GetAnalogousExpression(op->body),
+               op->span);
+  }
+  Expr VisitExpr_(const IfNode* op) {
+    return If(GetAnalogousExpression(op->cond), 
GetAnalogousExpression(op->true_branch),
+              GetAnalogousExpression(op->false_branch), op->span);
+  }
+  Expr VisitExpr_(const TupleGetItemNode* op) {
+    return TupleGetItem(GetAnalogousExpression(op->tuple), op->index, 
op->span);
+  }
+  Expr VisitExpr_(const RefCreateNode* op) {
+    return RefCreate(GetAnalogousExpression(op->value), op->span);
+  }
+  Expr VisitExpr_(const RefReadNode* op) {
+    return RefRead(GetAnalogousExpression(op->ref), op->span);
+  }
+  Expr VisitExpr_(const RefWriteNode* op) {
+    return RefWrite(GetAnalogousExpression(op->ref), 
GetAnalogousExpression(op->value), op->span);
+  }
+  Expr VisitExpr_(const ConstructorNode* op) {
+    return Constructor(op->name_hint, op->inputs, op->belong_to);
+  }
+  Expr VisitExpr_(const MatchNode* op) {
+    return Match(GetAnalogousExpression(op->data), op->clauses, op->complete, 
op->span);
+  }
+
+ private:
+  Expr GetAnalogousExpression(const Expr& expr) {
+    // Replace the expression with a potentially simpler expression of the 
same type
+    if (expr->checked_type_.defined()) {
+      // Since the expression already has a checked_type which we assume is 
correct we don't need
+      // full type inference to enter it. So stub it out with a dummy var of 
the same type.
+      return Var("dummy_var", expr->checked_type(), expr->span);
+    }
+
+    return VisitExpr(expr);
+  }
+  Array<Expr> GetAnalogousExpression(const Array<Expr>& fields) {
+    Array<Expr> new_fields;
+    for (Expr expr : fields) {
+      new_fields.push_back(GetAnalogousExpression(expr));
+    }
+    return new_fields;
+  }
+};
+
 namespace transform {
 
+Type InferTypeLocal(const Expr& expr) {
+  /*
+  This type inference differs from InferType in that it uses existing type 
information
+  to avoid recursing over much of the graph, and it only examines the type of 
the input
+  node. This makes it faster if you need to run type inference iteratively 
throughout
+  a pass for example.
+
+  However, it assumes any existing populated type information is correct! If 
some populated
+  type information is incorrect, an incorrect type may be returned or a type 
error will be
+  raised. If you know not all populated type fields are correct with the 
current graph,
+  you should use InferType() instead.
+  */
+  SameTypedSubgraphExtractor subgraph_extractor;
+  Expr sub_graph = subgraph_extractor(expr);
+  auto mod = IRModule::FromExpr(sub_graph);
+  mod = transform::InferType()(mod);
+
+  Type result_type;
+  if (expr.as<FunctionNode>()) {
+    result_type = mod->Lookup("main")->checked_type();
+  } else {
+    result_type = mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+  }
+
+  expr->checked_type_ = result_type;
+  return result_type;
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const 
Expr& expr) {
+  return InferTypeLocal(expr);
+});
+
 Pass InferType() {
   auto pass_info = PassInfo(0, "InferType", {});
   return tvm::transform::CreateModulePass(
diff --git a/tests/python/relay/test_type_infer.py 
b/tests/python/relay/test_type_infer.py
index a0d3784..af64ce7 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -19,9 +19,8 @@
 """
 import pytest
 import tvm
-
-from tvm import IRModule, te, relay, parser
-from tvm.relay import op, transform, analysis
+from tvm import IRModule, parser, relay, te
+from tvm.relay import analysis, op, transform
 from tvm.relay.op import op as _op
 
 
@@ -33,12 +32,9 @@ def infer_mod(mod, annotate_spans=True):
     return mod
 
 
-def infer_expr(expr, annotate_spans=True):
-    mod = IRModule.from_expr(expr)
-    mod = infer_mod(mod, annotate_spans)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
+def infer_expr(expr):
+    transform.InferTypeLocal(expr)
+    return expr
 
 
 def assert_has_type(expr, typ, mod=None):
@@ -68,7 +64,7 @@ def test_monomorphic_let():
     # TODO(@jroesch): this seems whack.
     sb = relay.ScopeBuilder()
     x = relay.var("x", dtype="float64", shape=())
-    x = sb.let("x", relay.const(1.0, "float64"))
+    x = sb.let(x, relay.const(1.0, "float64"))
     sb.ret(x)
     xchecked = infer_expr(sb.get())
     assert xchecked.checked_type == relay.scalar_type("float64")
@@ -165,11 +161,11 @@ def test_recursion():
 def test_incomplete_call():
     tt = relay.scalar_type("int32")
     x = relay.var("x", tt)
+    f_type = relay.FuncType([tt], tt)
     f = relay.var("f")
     func = relay.Function([x, f], relay.Call(f, [x]), tt)
 
     ft = infer_expr(func)
-    f_type = relay.FuncType([tt], tt)
     assert ft.checked_type == relay.FuncType([tt, f_type], tt)
 
 
@@ -245,7 +241,7 @@ def test_ref():
 def test_free_expr():
     x = relay.var("x", "float32")
     y = relay.add(x, x)
-    yy = infer_expr(y, annotate_spans=False)
+    yy = infer_expr(y)
     assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True)
     assert yy.checked_type == relay.scalar_type("float32")
     assert x.vid.same_as(yy.args[0].vid)
@@ -255,8 +251,11 @@ def test_type_args():
     x = relay.var("x", shape=(10, 10))
     y = relay.var("y", shape=(1, 10))
     z = relay.add(x, y)
-    ty_z = infer_expr(z)
-    ty_args = ty_z.type_args
+
+    # InferTypeLocal does not support populating the type_args field
+    mod = infer_mod(IRModule.from_expr(z))
+    mod = infer_mod(mod, annotate_spans=False)
+    ty_args = mod["main"].body.type_args
     assert len(ty_args) == 2
     assert ty_args[0].dtype == "float32"
     assert ty_args[1].dtype == "float32"

Reply via email to