This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 327891c [Relay][Pass] Add submodule extraction pass (#4960)
327891c is described below
commit 327891cbb863c9ef48ec8ba3cab950dee7c845c3
Author: anwang2009 <[email protected]>
AuthorDate: Fri Mar 13 13:58:49 2020 -0700
[Relay][Pass] Add submodule extraction pass (#4960)
* rebased
* fix lint
---
python/tvm/relay/analysis.py | 22 ++++
src/relay/analysis/extract_fused_functions.cc | 82 +++++++++++++++
.../relay/test_analysis_extract_fused_functions.py | 115 +++++++++++++++++++++
3 files changed, 219 insertions(+)
diff --git a/python/tvm/relay/analysis.py b/python/tvm/relay/analysis.py
index fc4a037..198e0a3 100644
--- a/python/tvm/relay/analysis.py
+++ b/python/tvm/relay/analysis.py
@@ -407,3 +407,25 @@ def structural_hash(value):
msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)
+
+
+def extract_fused_functions(mod):
+ """Pass to extract IRModule of only fused primitive functions.
+
+ The ExtractFusedFunctions pass invokes SimplifyInference, FuseOps(3),
+ and ExtractFusedFunctions in that order
+
+ Parameters
+ ----------
+ mod : tvm.relay.IRModule
+
+ Returns
+ -------
+ ret : Dict[int, tvm.relay.expr.Function]
+ A module containing only fused primitive functions
+ """
+ ret_mod = _analysis.ExtractFusedFunctions()(mod)
+ ret = {}
+ for hash_, func in ret_mod.functions.items():
+ ret[hash_] = func
+ return ret
diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc
new file mode 100644
index 0000000..3667d8a
--- /dev/null
+++ b/src/relay/analysis/extract_fused_functions.cc
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file extract_fused_functions.cc
+ * \brief Apply fusion and extract fused primitive functions from an IRModule
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Collects the primitive (fused) functions found while visiting the
+ *  module's "main" function and rebuilds the module from them.
+ *
+ * Functions are keyed by the decimal string of their StructuralHash, so
+ * structurally identical primitive functions are stored only once. Other
+ * global functions in the module are not visited.
+ */
+class FusedFunctionExtractorWrapper : private ExprVisitor {
+ public:
+  explicit FusedFunctionExtractorWrapper(const IRModule& mod) : mod_(mod) {}
+
+  /*!
+   * \brief Visit "main", then replace the module's function map with the
+   *  collected primitive functions.
+   * \return The same module object, now holding only the extracted functions.
+   * \note This overwrites mod_->functions in place. NOTE(review): this relies
+   *  on the caller passing a module it may modify (e.g. the copy produced by
+   *  the preceding passes in the Sequential) -- confirm before reusing.
+   */
+  IRModule Extract() {
+    VisitExpr(this->mod_->Lookup("main"));
+
+    auto functions = Map<GlobalVar, BaseFunc>();
+    for (const auto& pair : this->functions) {
+      // The hash string becomes the name hint of a fresh GlobalVar.
+      functions.Set(GlobalVar(pair.first), pair.second);
+    }
+
+    this->mod_->functions = functions;
+    return this->mod_;
+  }
+
+ private:
+  const IRModule mod_;
+  // This is not simply Map<GlobalVar, Function> because GlobalVar doesn't
+  // have the desired equals property (presumably GlobalVar keys compare by
+  // reference, not by name), so deduplication is done on the hash string.
+  Map<std::string, Function> functions;
+
+  void VisitExpr_(const FunctionNode* n) final {
+    if (n->HasNonzeroAttr(attr::kPrimitive)) {
+      // Add function to functions, keyed by function hash string
+      Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
+      size_t hash_ = StructuralHash()(func);
+      this->functions.Set(std::to_string(hash_), func);
+    }
+
+    // Keep traversing so primitive functions nested anywhere in main are found.
+    ExprVisitor::VisitExpr_(n);
+  }
+};
+
+namespace transform {
+
+/*!
+ * \brief Build the ExtractFusedFunctions pass.
+ *
+ * A Sequential named "ExtractFusedFunctions" that runs SimplifyInference,
+ * FuseOps(3), and then a module pass (opt_level 1) that keeps only the
+ * functions marked attr::kPrimitive found in "main". In the resulting module
+ * each function is bound to a GlobalVar named by its structural hash.
+ */
+Pass ExtractFusedFunctions() {
+  // The extractor captures nothing and does not consult the PassContext.
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [](IRModule m, PassContext) { return FusedFunctionExtractorWrapper(m).Extract(); };
+  auto fused_function_extractor_pass =
+      CreateModulePass(pass_func, 1, "ExtractFusedFunctions", {});
+
+  return Sequential({SimplifyInference(), FuseOps(3), fused_function_extractor_pass},
+                    "ExtractFusedFunctions");
+}
+
+TVM_REGISTER_GLOBAL("relay._analysis.ExtractFusedFunctions")
+.set_body_typed(ExtractFusedFunctions);
+
+}  // namespace transform
+
+} // namespace relay
+} // namespace tvm
diff --git a/tests/python/relay/test_analysis_extract_fused_functions.py b/tests/python/relay/test_analysis_extract_fused_functions.py
new file mode 100644
index 0000000..1a70ef1
--- /dev/null
+++ b/tests/python/relay/test_analysis_extract_fused_functions.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test function extraction"""
+import tvm
+from tvm import relay
+from tvm.relay.testing.resnet import get_workload
+
+
+def get_conv_net():
+ """This gets the net for a case described in fuse_ops.cc:
+
+ conv2d
+ / | \
+ / | \
+ op op op
+ \ | /
+ \ | /
+ elemwise add
+ |
+ """
+ dshape = (1, 1, 5, 1)
+ x = relay.var("x", shape=dshape)
+ y = relay.nn.conv2d(x, relay.var("w1"),
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ channels=1)
+
+ x1 = relay.nn.conv2d(y, relay.var("w2"),
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ channels=1)
+ x2 = relay.nn.conv2d(y, relay.var("w3"),
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ channels=1)
+ x3 = relay.nn.conv2d(y, relay.var("w4"),
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ channels=1)
+
+ z = relay.add(x1, x2)
+ z = relay.add(x3, z)
+
+ return tvm.IRModule.from_expr(z)
+
+
+def get_conv2d():
+ x = relay.var("x", shape=(1, 56, 56, 64))
+ weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
+ y = relay.nn.conv2d(x, weight1,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout='NHWC',
+ kernel_layout='HWIO')
+ return tvm.IRModule.from_expr(y)
+
+
+def test_extract_identity():
+ mod = get_conv2d()
+ items = relay.analysis.extract_fused_functions(mod)
+ assert len(items) == 1
+
+ mod["main"] = mod["main"].with_attr(
+ "Primitive", tvm.tir.IntImm("int32", 1))
+ relay.analysis.assert_graph_equal(list(items.values())[0], mod["main"])
+
+
+def test_extract_conv_net():
+ mod = get_conv_net()
+ items = relay.analysis.extract_fused_functions(mod)
+ functions = list(items.values())
+ assert len(functions) == 2
+ x = functions[0]
+ y = functions[1]
+
+ def is_conv(func):
+ conv2d = relay.op.op.get("nn.conv2d")
+ call_node = func.body
+ return call_node.op == conv2d
+
+ def is_conv_add(func):
+ add = relay.op.op.get("add")
+ call_node = func.body
+ maybe_conv_module = tvm.IRModule.from_expr(call_node.args[0])
+ return call_node.op == add and is_conv(maybe_conv_module["main"])
+
+ # Function traversal order isn't obvious, so checking both orders is more
consistent
+ assert (is_conv(x) and is_conv_add(y)) or (is_conv_add(x) and is_conv(y))
+
+
+def test_extract_resnet():
+ mod, _params = get_workload()
+ items = relay.analysis.extract_fused_functions(mod)
+ assert len(items) == 34
+
+
+if __name__ == '__main__':
+ test_extract_identity()
+ test_extract_conv_net()
+ test_extract_resnet()