This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9cc1df6 [AMP][Pass][Typing] Add faster type inference (#9735)
9cc1df6 is described below
commit 9cc1df60701d6d46577d028c211bb568225fc4f1
Author: AndrewZhaoLuo <[email protected]>
AuthorDate: Mon Jan 3 19:24:45 2022 -0800
[AMP][Pass][Typing] Add faster type inference (#9735)
* reuse checked types
* analogous subgraph
* brr go fast
* clean up src logs
* clean up PR more
* more clean up
* more documentation
* clean up
* formatting
* rename fast --> local
* more comments
* jostle ci
* type inference
* change comment for SameTypedSubgraphExtractor
* get_analogous_expression -> GetAnalogousExpression
* comment in GetAnalogousExpression
* add comment
* replace infer tests
* jostle
---
include/tvm/relay/transform.h | 17 ++++-
python/tvm/relay/transform/transform.py | 19 ++++++
src/relay/transforms/to_mixed_precision.cc | 28 ++++++--
src/relay/transforms/type_infer.cc | 106 +++++++++++++++++++++++++++++
tests/python/relay/test_type_infer.py | 27 ++++----
5 files changed, 176 insertions(+), 21 deletions(-)
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index dfc49cb..3100f14 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -250,7 +250,7 @@ TVM_DLL Pass DynamicToStatic();
/*!
* \brief Infer the type of an expression.
*
- * The result of type checking is a new expression with unambigous
+ * The result of type checking is a new expression with unambiguous
* type information filled in, as well as it's checked type field
* populated with the result type.
*
@@ -259,6 +259,21 @@ TVM_DLL Pass DynamicToStatic();
TVM_DLL Pass InferType();
/*!
+ * \brief Infer the type of an expression, reusing existing type information.
+ *
+ * The result of type checking is a new expression with unambiguous
+ * type information filled in for the given node only. The local
+ * version can use existing type information populated throughout
+ * the expression and assumes this information is correct. The local
+ * version also avoids examining large amounts of the graph, assuming
+ * type information is filled in properly, which makes it much faster when
+ * type inference is called iteratively.
+ *
+ * \return The type of the expression.
+ */
+TVM_DLL Type InferTypeLocal(const Expr& expr);
+
+/*!
* \brief Search and eliminate common subexpression. For example, if there are
 * two expressions evaluated to an identical value, a single variable is created
* and these two expressions are replaced by this variable.
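The header comment above stresses the iterative use case: once a subexpression's
type is populated, later calls do not re-examine it. A minimal Python sketch of
that pattern, using the relay API this change exposes (the shapes and ops here
are illustrative, not from the commit):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8,), dtype="float32")
    a = relay.nn.relu(x)
    # First call infers over the whole (still small) expression and
    # populates a.checked_type.
    relay.transform.InferTypeLocal(a)

    b = relay.add(a, a)
    # Second call reuses a's checked_type: internally `a` is stubbed out by
    # a dummy var of the same type, so only the new `add` node is examined.
    t = relay.transform.InferTypeLocal(b)
    assert t == relay.TensorType((8,), "float32")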
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index bbe4bc2..0123ca0 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -99,6 +99,25 @@ def InferType():
return _ffi_api.InferType()
+def InferTypeLocal(expr):
+ """Infer the type of a single expr, reusing type information to do so.
+
+ This populates the checked_type field in expr. We assume existing type information
+ in the graph is correct!
+
+ Parameters
+ ----------
+ expr: relay.Expr
+ The expression we want to know the type of
+
+ Returns
+ -------
+ type: relay.Type
+ The type of the expression
+ """
+ return _ffi_api.InferTypeLocal(expr)
+
+
def FoldScaleAxis():
"""Fold the scaling of axis into weights of conv2d/dense. This pass will
invoke both forward and backward scale folding.
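As the docstring says, a successful call fills in checked_type on the node
itself, and the returned type agrees with it. A quick hedged example (the
constants are chosen arbitrarily):

    from tvm import relay

    y = relay.add(relay.const(1.0), relay.const(2.0))
    ty = relay.transform.InferTypeLocal(y)
    # InferTypeLocal both returns the type and populates y.checked_type.
    assert ty == y.checked_type
    assert ty == relay.scalar_type("float32")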
diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
index ae10c93..d8c7aa2 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -176,13 +176,17 @@ class MixedPrecisionPass : public MixedModeMutator {
}
Type GetType(const Expr& expr) const {
- auto mod = IRModule::FromExpr(expr);
- mod = transform::InferType()(mod);
- if (expr.as<FunctionNode>()) {
- return mod->Lookup("main")->checked_type();
- } else {
- return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+ // The expression has not been changed AND its existing type
+ // is known to still be valid. (See special handling for tuples etc.
+ // below for where we null out checked_type_ when we cannot be
+ // sure it is still valid.)
+ Type checked_type = expr->checked_type_;
+ if (checked_type.defined()) {
+ return checked_type;
}
+
+ // This also populates the checked_type_ field for expr
+ return transform::InferTypeLocal(expr);
}
bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
@@ -381,6 +385,18 @@ class MixedPrecisionPass : public MixedModeMutator {
return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
}
+ Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
+ // The old checked type in the expression may not be valid so clear it
+ post->checked_type_ = Type(nullptr);
+ return post;
+ }
+
+ Expr Rewrite_(const TupleNode* pre, const Expr& post) {
+ // The old checked type in the expression may not be valid so clear it
+ post->checked_type_ = Type(nullptr);
+ return post;
+ }
+
Expr VisitExpr_(const FunctionNode* func) final {
// Erase the ret_type annotation and let the normal pass recalculate
const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
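The GetType change above is a cache-then-infer pattern: trust checked_type_
when it is defined, fall back to local inference otherwise, and have the tuple
Rewrite_ overrides null the field out for nodes whose type may have gone stale.
A rough Python analogue of the same pattern (a sketch only; the dict stands in
for the C++ checked_type_ field, which is not writable from Python):

    from tvm import relay

    type_cache = {}  # stands in for the mutable checked_type_ field

    def get_type(expr):
        # Reuse a known-valid type; otherwise run the cheap local inference.
        if expr in type_cache:
            return type_cache[expr]
        t = relay.transform.InferTypeLocal(expr)
        type_cache[expr] = t
        return t

    def invalidate(expr):
        # Analogue of `post->checked_type_ = Type(nullptr)` for rewritten
        # tuples / tuple projections whose old type may be stale.
        type_cache.pop(expr, None)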
diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc
index 22bc8f3..456e210 100644
--- a/src/relay/transforms/type_infer.cc
+++ b/src/relay/transforms/type_infer.cc
@@ -824,8 +824,114 @@ void AddGlobalTypes(IRModule mod) {
}
}
+/*!
+ * \brief Returns a possibly much smaller subgraph whose inner nodes have the same type.
+ *
+ * Returns the largest sub-graph whose inner nodes need types and whose leaves are vars
+ * standing in for already typed sub-expressions. This creates a graph whose inner nodes
+ * have the same type as the original graph, so when running type inference we can avoid
+ * copying and recursing through most of the expression graph. Note, this assumes
+ * that currently populated type information is correct!
+ *
+ * ExprMutator is sufficient over MixedModeMutator since we will not recurse much.
+ */
+class SameTypedSubgraphExtractor : public ExprMutator {
+ Expr VisitExpr_(const VarNode* op) { return Var(op->vid, op->type_annotation, op->span); }
+ Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, op->span); }
+ Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
+ Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
+ Expr VisitExpr_(const TupleNode* op) {
+ return Tuple(GetAnalogousExpression(op->fields), op->span);
+ }
+ Expr VisitExpr_(const FunctionNode* op) {
+ // Unfortunately our strategy of inserting variables as dummies would change the signature of
+ // existing function nodes, so we have to copy all used functions always :/
+ return Function(op->params, op->body, op->ret_type, op->type_params, op->attrs, op->span);
+ }
+ Expr VisitExpr_(const CallNode* op) {
+ return Call(op->op, GetAnalogousExpression(op->args), op->attrs, op->type_args, op->span);
+ }
+ Expr VisitExpr_(const LetNode* op) {
+ return Let(op->var, GetAnalogousExpression(op->value), GetAnalogousExpression(op->body),
+ op->span);
+ }
+ Expr VisitExpr_(const IfNode* op) {
+ return If(GetAnalogousExpression(op->cond), GetAnalogousExpression(op->true_branch),
+ GetAnalogousExpression(op->false_branch), op->span);
+ }
+ Expr VisitExpr_(const TupleGetItemNode* op) {
+ return TupleGetItem(GetAnalogousExpression(op->tuple), op->index, op->span);
+ }
+ Expr VisitExpr_(const RefCreateNode* op) {
+ return RefCreate(GetAnalogousExpression(op->value), op->span);
+ }
+ Expr VisitExpr_(const RefReadNode* op) {
+ return RefRead(GetAnalogousExpression(op->ref), op->span);
+ }
+ Expr VisitExpr_(const RefWriteNode* op) {
+ return RefWrite(GetAnalogousExpression(op->ref), GetAnalogousExpression(op->value), op->span);
+ }
+ Expr VisitExpr_(const ConstructorNode* op) {
+ return Constructor(op->name_hint, op->inputs, op->belong_to);
+ }
+ Expr VisitExpr_(const MatchNode* op) {
+ return Match(GetAnalogousExpression(op->data), op->clauses, op->complete, op->span);
+ }
+
+ private:
+ Expr GetAnalogousExpression(const Expr& expr) {
+ // Replace the expression with a potentially simpler expression of the same type.
+ if (expr->checked_type_.defined()) {
+ // Since the expression already has a checked_type, which we assume is correct, we don't
+ // need full type inference to enter it. So stub it out with a dummy var of the same type.
+ return Var("dummy_var", expr->checked_type(), expr->span);
+ }
+
+ return VisitExpr(expr);
+ }
+ Array<Expr> GetAnalogousExpression(const Array<Expr>& fields) {
+ Array<Expr> new_fields;
+ for (Expr expr : fields) {
+ new_fields.push_back(GetAnalogousExpression(expr));
+ }
+ return new_fields;
+ }
+};
+
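The heart of the extractor above is GetAnalogousExpression: any subexpression
that already carries a checked type is replaced by a dummy var of that type, so
type inference never descends into it. A single-node illustration in Python
(stub_if_typed is a hypothetical helper, not part of this commit):

    from tvm import relay

    def stub_if_typed(expr):
        # If type inference already ran on `expr`, stand in a fresh var of
        # the same type; otherwise keep the expression for the checker.
        try:
            t = expr.checked_type  # raises ValueError if not yet populated
        except ValueError:
            return expr
        return relay.var("dummy_var", type_annotation=t)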
namespace transform {
+Type InferTypeLocal(const Expr& expr) {
+ /*
+ This type inference differs from InferType in that it uses existing type information
+ to avoid recursing over much of the graph, and it only examines the type of the input
+ node. This makes it faster if you need to run type inference iteratively throughout
+ a pass, for example.
+
+ However, it assumes any existing populated type information is correct! If some populated
+ type information is incorrect, an incorrect type may be returned or a type error will be
+ raised. If you know not all populated type fields are correct with the current graph,
+ you should use InferType() instead.
+ */
+ SameTypedSubgraphExtractor subgraph_extractor;
+ Expr sub_graph = subgraph_extractor(expr);
+ auto mod = IRModule::FromExpr(sub_graph);
+ mod = transform::InferType()(mod);
+
+ Type result_type;
+ if (expr.as<FunctionNode>()) {
+ result_type = mod->Lookup("main")->checked_type();
+ } else {
+ result_type = mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+ }
+
+ expr->checked_type_ = result_type;
+ return result_type;
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const Expr& expr) {
+ return InferTypeLocal(expr);
+});
+
Pass InferType() {
auto pass_info = PassInfo(0, "InferType", {});
return tvm::transform::CreateModulePass(
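When existing annotations are valid, InferTypeLocal above should agree with
running the full InferType pass over a module built from the same expression; a
quick cross-check sketch (the expression is chosen arbitrarily):

    from tvm import IRModule, relay

    x = relay.var("x", shape=(2, 2))
    e = relay.nn.softmax(x)
    local_t = relay.transform.InferTypeLocal(e)

    mod = relay.transform.InferType()(IRModule.from_expr(e))
    full_t = mod["main"].body.checked_type
    assert local_t == full_t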
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index a0d3784..af64ce7 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -19,9 +19,8 @@
"""
import pytest
import tvm
-
-from tvm import IRModule, te, relay, parser
-from tvm.relay import op, transform, analysis
+from tvm import IRModule, parser, relay, te
+from tvm.relay import analysis, op, transform
from tvm.relay.op import op as _op
@@ -33,12 +32,9 @@ def infer_mod(mod, annotate_spans=True):
return mod
-def infer_expr(expr, annotate_spans=True):
- mod = IRModule.from_expr(expr)
- mod = infer_mod(mod, annotate_spans)
- mod = transform.InferType()(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
+def infer_expr(expr):
+ transform.InferTypeLocal(expr)
+ return expr
def assert_has_type(expr, typ, mod=None):
@@ -68,7 +64,7 @@ def test_monomorphic_let():
# TODO(@jroesch): this seems whack.
sb = relay.ScopeBuilder()
x = relay.var("x", dtype="float64", shape=())
- x = sb.let("x", relay.const(1.0, "float64"))
+ x = sb.let(x, relay.const(1.0, "float64"))
sb.ret(x)
xchecked = infer_expr(sb.get())
assert xchecked.checked_type == relay.scalar_type("float64")
@@ -165,11 +161,11 @@ def test_recursion():
def test_incomplete_call():
tt = relay.scalar_type("int32")
x = relay.var("x", tt)
+ f_type = relay.FuncType([tt], tt)
f = relay.var("f")
func = relay.Function([x, f], relay.Call(f, [x]), tt)
ft = infer_expr(func)
- f_type = relay.FuncType([tt], tt)
assert ft.checked_type == relay.FuncType([tt, f_type], tt)
@@ -245,7 +241,7 @@ def test_ref():
def test_free_expr():
x = relay.var("x", "float32")
y = relay.add(x, x)
- yy = infer_expr(y, annotate_spans=False)
+ yy = infer_expr(y)
assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True)
assert yy.checked_type == relay.scalar_type("float32")
assert x.vid.same_as(yy.args[0].vid)
@@ -255,8 +251,11 @@ def test_type_args():
x = relay.var("x", shape=(10, 10))
y = relay.var("y", shape=(1, 10))
z = relay.add(x, y)
- ty_z = infer_expr(z)
- ty_args = ty_z.type_args
+
+ # InferTypeLocal does not support populating the type_args field
+ mod = infer_mod(IRModule.from_expr(z))
+ mod = infer_mod(mod, annotate_spans=False)
+ ty_args = mod["main"].body.type_args
assert len(ty_args) == 2
assert ty_args[0].dtype == "float32"
assert ty_args[1].dtype == "float32"