This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 327891c  [Relay][Pass] Add submodule extraction pass (#4960)
327891c is described below

commit 327891cbb863c9ef48ec8ba3cab950dee7c845c3
Author: anwang2009 <[email protected]>
AuthorDate: Fri Mar 13 13:58:49 2020 -0700

    [Relay][Pass] Add submodule extraction pass (#4960)
    
    * rebased
    
    * fix lint
---
 python/tvm/relay/analysis.py                       |  22 ++++
 src/relay/analysis/extract_fused_functions.cc      |  82 +++++++++++++++
 .../relay/test_analysis_extract_fused_functions.py | 115 +++++++++++++++++++++
 3 files changed, 219 insertions(+)

diff --git a/python/tvm/relay/analysis.py b/python/tvm/relay/analysis.py
index fc4a037..198e0a3 100644
--- a/python/tvm/relay/analysis.py
+++ b/python/tvm/relay/analysis.py
@@ -407,3 +407,25 @@ def structural_hash(value):
         msg = ("found value of type {0} expected" +
                "relay.Expr or relay.Type").format(type(value))
         raise TypeError(msg)
+
+
+def extract_fused_functions(mod):
+    """Fuse ``mod`` and return the primitive functions created by fusion.
+
+    Runs SimplifyInference, FuseOps(3) and ExtractFusedFunctions in that
+    order; only primitive functions found while visiting ``mod["main"]`` count.
+
+    Parameters
+    ----------
+    mod : tvm.relay.IRModule
+
+    Returns
+    -------
+    ret : Dict[tvm.relay.GlobalVar, tvm.relay.expr.Function]
+        The fused primitive functions, keyed by a GlobalVar whose name is the
+        decimal string of the function's structural hash, so structurally
+        identical functions appear only once.
+    """
+    ret_mod = _analysis.ExtractFusedFunctions()(mod)
+    # Copy the module's function map into a plain Python dict.
+    return dict(ret_mod.functions.items())
diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc
new file mode 100644
index 0000000..3667d8a
--- /dev/null
+++ b/src/relay/analysis/extract_fused_functions.cc
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file extract_fused_functions.cc
+ * \brief Apply fusion and extract fused primitive functions from an IRModule
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+class FusedFunctionExtractorWrapper : private ExprVisitor {
+ public:
+  explicit FusedFunctionExtractorWrapper(const IRModule& mod) : mod_(mod) {}
+
+  IRModule Extract() {
+    VisitExpr(this->mod_->Lookup("main"));  // only "main" is searched
+
+    auto functions = Map<GlobalVar, BaseFunc>();
+    for (auto pair : this->functions) {
+      functions.Set(GlobalVar(pair.first), pair.second);  // named by hash string
+    }
+
+    this->mod_->functions = functions;  // updates mod_ in place, not a new IRModule
+    return this->mod_;
+  }
+
+ private:
+  const IRModule mod_;
+  // Keyed by hash string rather than GlobalVar, which lacks the needed equality
+  // (presumably reference-based); setting an existing hash overwrites it.
+  Map<std::string, Function> functions;
+
+  void VisitExpr_(const FunctionNode* n) final {
+    if (n->HasNonzeroAttr(attr::kPrimitive)) {
+      // Copy the primitive function and record it under its structural hash.
+      Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
+      size_t hash_ = StructuralHash()(func);
+      this->functions.Set(std::to_string(hash_), func);
+    }
+
+    ExprVisitor::VisitExpr_(n);  // keep descending into nested functions
+  }
+};
+
+namespace transform {
+
+Pass ExtractFusedFunctions() {  // SimplifyInference -> FuseOps(3) -> extraction
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) { return FusedFunctionExtractorWrapper(m).Extract(); };
+  auto fused_function_extractor_pass = CreateModulePass(pass_func, 1, "ExtractFusedFunctions", {});
+
+  return Sequential({SimplifyInference(), FuseOps(3), fused_function_extractor_pass},
+                    "ExtractFusedFunctions");
+}
+
+TVM_REGISTER_GLOBAL("relay._analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_analysis_extract_fused_functions.py b/tests/python/relay/test_analysis_extract_fused_functions.py
new file mode 100644
index 0000000..1a70ef1
--- /dev/null
+++ b/tests/python/relay/test_analysis_extract_fused_functions.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test function extraction"""
+import tvm
+from tvm import relay
+from tvm.relay.testing.resnet import get_workload
+
+
+def get_conv_net():
+    r"""Return an IRModule for the fan-out/fan-in case described in fuse_ops.cc:
+
+            conv2d
+            /  |  \
+           /   |   \
+         op    op   op
+          \    |    /
+           \   |   /
+          elemwise add
+               |
+    """
+    dshape = (1, 1, 5, 1)
+    x = relay.var("x", shape=dshape)
+    y = relay.nn.conv2d(x, relay.var("w1"),
+                        kernel_size=(3, 3),
+                        padding=(1, 1),
+                        channels=1)
+
+    x1 = relay.nn.conv2d(y, relay.var("w2"),
+                         kernel_size=(3, 3),
+                         padding=(1, 1),
+                         channels=1)
+    x2 = relay.nn.conv2d(y, relay.var("w3"),
+                         kernel_size=(3, 3),
+                         padding=(1, 1),
+                         channels=1)
+    x3 = relay.nn.conv2d(y, relay.var("w4"),
+                         kernel_size=(3, 3),
+                         padding=(1, 1),
+                         channels=1)
+
+    z = relay.add(x1, x2)
+    z = relay.add(x3, z)
+
+    return tvm.IRModule.from_expr(z)
+
+
+def get_conv2d():
+    """Return a module whose main is one 3x3 NHWC/HWIO conv2d (64 -> 32)."""
+    data = relay.var("x", shape=(1, 56, 56, 64))
+    kernel = relay.var('weight1', shape=(3, 3, 64, 32))
+    conv_attrs = dict(channels=32, kernel_size=(3, 3), padding=(1, 1),
+                      data_layout='NHWC', kernel_layout='HWIO')
+    conv = relay.nn.conv2d(data, kernel, **conv_attrs)
+
+    # Wrap the single call as the module's "main" function.
+    return tvm.IRModule.from_expr(conv)
+
+
+def test_extract_identity():
+    mod = get_conv2d()
+    fused = relay.analysis.extract_fused_functions(mod)
+    assert len(fused) == 1
+    (only_func,) = fused.values()
+    # The lone conv2d should come back as main itself, marked Primitive=1.
+    mod["main"] = mod["main"].with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+    relay.analysis.assert_graph_equal(only_func, mod["main"])
+
+
+def test_extract_conv_net():
+    mod = get_conv_net()
+    items = relay.analysis.extract_fused_functions(mod)
+    functions = list(items.values())
+    assert len(functions) == 2  # one bare conv2d and one add-of-conv2d
+    x = functions[0]
+    y = functions[1]
+
+    def is_conv(func):  # body is a bare nn.conv2d call
+        conv2d = relay.op.op.get("nn.conv2d")
+        call_node = func.body
+        return call_node.op == conv2d
+
+    def is_conv_add(func):  # body is add() whose first argument is a conv2d call
+        add = relay.op.op.get("add")
+        call_node = func.body
+        maybe_conv_module = tvm.IRModule.from_expr(call_node.args[0])
+        return call_node.op == add and is_conv(maybe_conv_module["main"])
+
+    # The order of the extracted functions isn't obvious, so accept either order.
+    assert (is_conv(x) and is_conv_add(y)) or (is_conv_add(x) and is_conv(y))
+
+
+def test_extract_resnet():
+    mod, _ = get_workload()
+    fused = relay.analysis.extract_fused_functions(mod)
+    assert len(fused) == 34  # distinct fused functions for default get_workload()
+
+
+if __name__ == '__main__':
+    for case in (test_extract_identity, test_extract_conv_net,
+                 test_extract_resnet):
+        case()

Reply via email to