mbaret commented on a change in pull request #10759:
URL: https://github.com/apache/tvm/pull/10759#discussion_r834323451
##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -446,7 +451,7 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
@_register_external_dynamic_check_func("nn.batch_matmul")
-def batch_matmul_annotate_fn(expr):
+def batch_matmul_annotate_fn(expr): # pylint: disable=unused-variable
Review comment:
Is this change required?
##########
File path: src/relay/transforms/unmerge_composites.cc
##########
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/unmerge_composites.cc
+ * \brief Undo the partioned graphs originate from merge composite.
+ */
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../analysis/call_graph.h"
+#include "../op/call/call.h"
+
+using namespace tvm::runtime;
+
+namespace tvm {
+
+namespace relay {
+
+class Unmerger : ExprMutator {
+ public:
+ explicit Unmerger(CallGraphEntry* cur_node, CallGraphNode* call_graph)
+ : cur_node_(cur_node), call_graph_(call_graph) {}
+
+ Expr VisitExpr_(const CallNode* call_node) final {
+ Call vanilla_call = GetAnyCall(call_node);
+ const auto* global_var_node = vanilla_call->op.as<GlobalVarNode>();
+ const auto* function_var_node = vanilla_call->op.as<FunctionNode>();
+
+ if (function_var_node) {
+ Function gv = GetRef<Function>(function_var_node);
+ const auto* fn = gv.as<FunctionNode>();
+
+ Array<Expr> new_args;
+ new_args.reserve(vanilla_call->args.size());
+ for (auto arg : vanilla_call->args) {
+ new_args.push_back(VisitExpr(arg));
+ }
+
+ Map<Var, Expr> bind_map;
+ for (size_t i = 0; i < new_args.size(); i++) {
+ bind_map.Set(fn->params[i], new_args[i]);
+ }
+
+ // Attrs need to be empty at this point to avoid propagating Composite and
+ // PartitionedFromPattern that fiddling TRT code gen for registered ops.
+ auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, {});
+ return Bind(func->body, bind_map);
Review comment:
Not sure I understand this — why can't we just do
`return Bind(fn->body, bind_map);`?
##########
File path: src/relay/transforms/unmerge_composites.cc
##########
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/unmerge_composites.cc
+ * \brief Undo the partioned graphs originate from merge composite.
+ */
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../analysis/call_graph.h"
+#include "../op/call/call.h"
+
+using namespace tvm::runtime;
+
+namespace tvm {
+
+namespace relay {
+
+class Unmerger : ExprMutator {
+ public:
+ explicit Unmerger(CallGraphEntry* cur_node, CallGraphNode* call_graph)
+ : cur_node_(cur_node), call_graph_(call_graph) {}
+
+ Expr VisitExpr_(const CallNode* call_node) final {
+ Call vanilla_call = GetAnyCall(call_node);
+ const auto* global_var_node = vanilla_call->op.as<GlobalVarNode>();
+ const auto* function_var_node = vanilla_call->op.as<FunctionNode>();
+
+ if (function_var_node) {
+ Function gv = GetRef<Function>(function_var_node);
+ const auto* fn = gv.as<FunctionNode>();
Review comment:
Is this needed? We already start with a `FunctionNode` (`function_var_node`), so wrapping it in a `Function` and casting it back to a `FunctionNode` looks redundant.
##########
File path: src/relay/transforms/unmerge_composites.cc
##########
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/unmerge_composites.cc
+ * \brief Undo the partioned graphs originate from merge composite.
Review comment:
I think 'Inline composite functions for a given target' describes this a
bit better.
##########
File path: src/relay/transforms/unmerge_composites.cc
##########
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/unmerge_composites.cc
Review comment:
My personal preference for the name here would be either InlineComposite or
RemoveComposite. It's not a huge deal, though; if no one else agrees, we can
keep it as Unmerge.
##########
File path: src/relay/transforms/unmerge_composites.cc
##########
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/unmerge_composites.cc
+ * \brief Undo the partioned graphs originate from merge composite.
+ */
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../analysis/call_graph.h"
+#include "../op/call/call.h"
+
+using namespace tvm::runtime;
+
+namespace tvm {
+
+namespace relay {
+
+class Unmerger : ExprMutator {
Review comment:
MixedModeMutator is now preferred where possible.
##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -543,7 +543,10 @@ def MergeComposite(pattern_table):
for tup in pattern_table:
if len(tup) == 2:
pattern_name, pattern = tup
- check = lambda extract: True
Review comment:
Could you explain this change?
##########
File path: src/relay/transforms/unmerge_composites.cc
##########
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/unmerge_composites.cc
+ * \brief Undo the partioned graphs originate from merge composite.
+ */
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../analysis/call_graph.h"
+#include "../op/call/call.h"
+
+using namespace tvm::runtime;
+
+namespace tvm {
+
+namespace relay {
+
+class Unmerger : ExprMutator {
+ public:
+ explicit Unmerger(CallGraphEntry* cur_node, CallGraphNode* call_graph)
+ : cur_node_(cur_node), call_graph_(call_graph) {}
+
+ Expr VisitExpr_(const CallNode* call_node) final {
+ Call vanilla_call = GetAnyCall(call_node);
+ const auto* global_var_node = vanilla_call->op.as<GlobalVarNode>();
Review comment:
`global_var_node` is unused.
##########
File path: src/relay/transforms/unmerge_composites.cc
##########
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/unmerge_composites.cc
+ * \brief Undo the partioned graphs originate from merge composite.
+ */
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../analysis/call_graph.h"
+#include "../op/call/call.h"
+
+using namespace tvm::runtime;
+
+namespace tvm {
+
+namespace relay {
+
+class Unmerger : ExprMutator {
+ public:
+ explicit Unmerger(CallGraphEntry* cur_node, CallGraphNode* call_graph)
+ : cur_node_(cur_node), call_graph_(call_graph) {}
+
+ Expr VisitExpr_(const CallNode* call_node) final {
+ Call vanilla_call = GetAnyCall(call_node);
+ const auto* global_var_node = vanilla_call->op.as<GlobalVarNode>();
+ const auto* function_var_node = vanilla_call->op.as<FunctionNode>();
+
+ if (function_var_node) {
+ Function gv = GetRef<Function>(function_var_node);
+ const auto* fn = gv.as<FunctionNode>();
+
+ Array<Expr> new_args;
+ new_args.reserve(vanilla_call->args.size());
+ for (auto arg : vanilla_call->args) {
+ new_args.push_back(VisitExpr(arg));
+ }
+
+ Map<Var, Expr> bind_map;
+ for (size_t i = 0; i < new_args.size(); i++) {
+ bind_map.Set(fn->params[i], new_args[i]);
+ }
+
+ // Attrs need to be empty at this point to avoid propagating Composite and
+ // PartitionedFromPattern that fiddling TRT code gen for registered ops.
+ auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, {});
+ return Bind(func->body, bind_map);
+ }
+
+ return ExprMutator::VisitExpr_(call_node);
+ }
+
+ Function Unmerge(const Function& func) {
+ return WithFields(func, func->params, VisitExpr(func->body));
+ }
+
+ private:
+ /*!
+ * \brief The current call graph entry that is being handled. Each entry
+ * contains a global function.
+ */
+ CallGraphEntry* cur_node_;
+ /*! \brief The call graph that is used for global function lookup. */
+ const CallGraphNode* call_graph_;
+};
+
+IRModule UnmergeComposites(const IRModule& module, runtime::String target) {
+ CallGraph cg(module);
+ auto topo = cg->TopologicalOrder();
+ std::reverse(topo.begin(), topo.end());
+ std::unordered_set<CallGraphEntry*> original_entry;
+ ICHECK(target.defined());
+ for (auto* it : topo) {
+ auto base_func = module->Lookup(it->GetNameHint());
+
+ if (!base_func->GetAttr<String>(attr::kCompiler).defined() &&
+ base_func->GetAttr<String>(attr::kCompiler) != target) {
+ return module;
Review comment:
I think it'd be better to `continue;` here rather than `return`; otherwise,
if any partitioning for a different target has taken place, this will bail out
without processing the remaining functions in the module.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]