This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0812c07 Change Call with TIRCallAttrs to call_lowered op (#9312)
0812c07 is described below
commit 0812c078bfc8595d079fbb15cdc4a64bf1a4def2
Author: Lily Orth-Smith <[email protected]>
AuthorDate: Wed Nov 10 14:20:01 2021 -0800
Change Call with TIRCallAttrs to call_lowered op (#9312)
* Introduce call_lowered op
Add op vm.call_tir
Change from checking if CallNode has CallTIRAttrs to checking if the Op is
vm.call_tir
Change device_domains to use vm.call_tir op more explicitly
Fixed issue in type checker, now have seg fault :(
Fix typo -- most of VM tests pass now
Interpreter now deals with call_tir properly
Fix typo in te_compiler
Use InvokeTVMOp and CallTIR
Add some checks to graph_plan_memory.cc
Make GetToken skip function types
C++ TESTS PASS WOOHOO
Remove prints
formatting
vm.call_tir -> call_tir and more comment removals
call_tir -> call_lowered
fix lint
clang format
Remove compute from non computational vm ops
missed some semicolons in prev commit
Fix warning
Move call_lowered to relay/op/call/call.cc and rename util func
Add helper fn that returns lowered_call op
fix import order
clang format
Add constraint to call_lowered type rel
clean up empty token vector
comment
Move CallTIRAttrs to include/tvm/relay/attrs/call.h
Rename TIRCallAttrs as CallLoweredAttrs
lint
Add helper for extracting func and args from call_lowered
Change graph_executor_codegen to use helper function
Update interpreter to use helper
Fix device_domains.cc -- could still use cleanup, also I am not sure why
there are still direct calls to primfns in DomainforCallee
Clean up DeviceCopyProps and lint
lint
return CallLoweredAttrs with the extern func
comment
note in comment
Progress & notes. Realized that I am not handling externs correctly
not sure why this ever worked before?
Clean up CreateFuncCall signature, notes
comments
Fix extern function handling
extern_function -> extern_func
fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls
yay passes AOT tests!
formatting and comment removal
cleanup
Introduce call_lowered op
* lint
* Fix AOT to deal with externs
* add const auto&
* Fix aot crt test
---
include/tvm/relay/attrs/annotation.h | 11 --
.../op/vm/vm.h => include/tvm/relay/attrs/call.h | 30 ++--
src/relay/backend/aot_executor_codegen.cc | 77 +++++++---
.../contrib/example_target_hooks/relay_to_tir.cc | 12 +-
src/relay/backend/graph_executor_codegen.cc | 89 ++++++-----
src/relay/backend/graph_plan_memory.cc | 52 ++++---
src/relay/backend/interpreter.cc | 152 +++++++++---------
src/relay/backend/te_compiler.cc | 169 +++++++++++----------
src/relay/op/call/call.cc | 116 ++++++++++++++
src/relay/op/call/call.h | 74 +++++++++
src/relay/op/memory/device_copy.cc | 17 +++
src/relay/op/vm/vm.h | 2 +-
src/relay/transforms/device_domains.cc | 33 ++--
src/relay/transforms/memory_alloc.cc | 16 +-
14 files changed, 574 insertions(+), 276 deletions(-)
diff --git a/include/tvm/relay/attrs/annotation.h
b/include/tvm/relay/attrs/annotation.h
index 85ac3f36..f88ca8e 100644
--- a/include/tvm/relay/attrs/annotation.h
+++ b/include/tvm/relay/attrs/annotation.h
@@ -116,17 +116,6 @@ struct CompilerAttrs : public
tvm::AttrsNode<CompilerAttrs> {
}
};
-/*!
- * \brief Metadata for calls to TIR functions, useful for program analysis
crossing Relay and TIR.
- */
-struct TIRCallAttrs : public tvm::AttrsNode<TIRCallAttrs> {
- /*! \brief The metadata attached to the call node. */
- Map<String, ObjectRef> metadata;
-
- TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") {
- TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function
call.");
- }
-};
} // namespace relay
} // namespace tvm
diff --git a/src/relay/op/vm/vm.h b/include/tvm/relay/attrs/call.h
similarity index 57%
copy from src/relay/op/vm/vm.h
copy to include/tvm/relay/attrs/call.h
index 802c810..2b02c6a 100644
--- a/src/relay/op/vm/vm.h
+++ b/include/tvm/relay/attrs/call.h
@@ -18,23 +18,31 @@
*/
/*!
- * \file src/relay/op/vm/vm.h
- * \brief Dialect operators for Relay VM.
+ * \file tvm/relay/attrs/call.h
+ * \brief Attribute for call_lowered operator.
*/
-#ifndef TVM_RELAY_OP_VM_VM_H_
-#define TVM_RELAY_OP_VM_VM_H_
+#ifndef TVM_RELAY_ATTRS_CALL_H_
+#define TVM_RELAY_ATTRS_CALL_H_
-#include "tvm/relay/expr.h"
+#include <tvm/ir/attrs.h>
+
+#include <string>
namespace tvm {
namespace relay {
-Expr InvokeTVMOp(Expr func, Expr inputs, Expr outputs);
-Expr ShapeFunc(Expr func, Expr inputs, Expr outputs, Array<tvm::Integer>
is_input);
-Expr ShapeOf(Expr expr);
-Expr ReshapeTensor(Expr data, Expr shape, Array<PrimExpr> newshape);
+/*!
+ * \brief Metadata for calls to TIR functions, useful for program analysis
crossing Relay and TIR.
+ */
+struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
+ /*! \brief The metadata attached to the call node. */
+ Map<String, ObjectRef> metadata;
+
+ TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
+ TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered
function call.");
+ }
+};
} // namespace relay
} // namespace tvm
-
-#endif // TVM_RELAY_OP_VM_VM_H_
+#endif // TVM_RELAY_ATTRS_CALL_H_
diff --git a/src/relay/backend/aot_executor_codegen.cc
b/src/relay/backend/aot_executor_codegen.cc
index 7e57022..58bcccf 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -24,6 +24,7 @@
#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
@@ -40,6 +41,7 @@
#include <vector>
#include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler.h"
#include "./utils.h"
@@ -72,14 +74,34 @@ class AOTOnDemandAllocator : public
transform::DeviceAwareExprVisitor {
AssignReturnSid(GetRef<Expr>(op));
}
- void DeviceAwareVisitExpr_(const CallNode* op) final {
- // create token for the call node.
- VisitExpr(op->op);
- CreateStorage(op);
- for (Expr arg : op->args) {
+ void DeviceAwareVisitExpr_(const CallNode* call_node) final {
+ // AOTOnDemandAllocator is run both before and after lowering, so we need
to handle the case
+ // where the op of the call is a generic function
+
+ Expr func;
+ Array<Expr> args;
+
+ if (call_node->op == CallLoweredOp()) {
+ CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+ func = call_lowered_props.lowered_func;
+ args = call_lowered_props.arguments;
+ } else { // Relay functions that have not been lowered and lowered extern
functions
+ func = call_node->op;
+ args = call_node->args;
+ if (call_node->op.as<GlobalVarNode>()) { // Lowered extern function
+ ICHECK(!(call_node->attrs.defined())) << "Extern functions should have
null attributes.";
+ } else { // Relay function which has not been lowered yet
+ ICHECK(call_node->op.as<FunctionNode>())
+ << "Expected the call to be to a lowered primfunc, a lowered
extern function or a "
+ "unlowered Relay function.";
+ }
+ }
+ VisitExpr(func);
+ CreateStorage(call_node);
+ for (const Expr& arg : args) {
GetStorage(arg);
}
- AssignReturnSid(GetRef<Expr>(op));
+ AssignReturnSid(GetRef<Expr>(call_node));
}
void VisitExpr_(const VarNode* op) final {
AssignReturnSid(GetRef<Expr>(op)); }
@@ -287,13 +309,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
/*!
- * brief Call a function with a given name
+ * brief Create a function call
+ * \param call_lowered_props The lowered function and the arguments to call
it with
+ * \param call The call we got func and args from
*/
- void CreateFuncCall(Call call, std::string func_name) {
+ void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) {
+ std::string func_name = call_lowered_props.lowered_func->name_hint;
+
tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
std::vector<tir::Stmt> create_func_call_stmts;
+
// Pack the inputs
- for (Expr arg : call->args) {
+ for (const Expr& arg : call_lowered_props.arguments) {
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
auto param_handle = tvm::tir::Call(DataType::Handle(),
tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[arg])});
@@ -371,21 +398,25 @@ class AOTExecutorCodegen : public MixedModeVisitor {
return ss.str();
}
- void VisitExpr_(const CallNode* op) override {
+ void VisitExpr_(const CallNode* call_node) override {
// Descend the call tree
- for (auto arg : op->args) {
- VisitExpr(arg);
- }
-
- if (op->op.as<OpNode>()) {
- LOG(FATAL) << "Operators should be transformed away; try applying"
- << "the fuse_ops transformation to the expression.";
- } else if (op->op.as<GlobalVarNode>()) {
- GlobalVar node = GetRef<GlobalVar>(op->op.as<GlobalVarNode>());
- CreateFuncCall(GetRef<Call>(op), node->name_hint);
+ CallLoweredProps call_lowered_props;
+ if (const auto* gvn = call_node->op.as<GlobalVarNode>()) { // Lowered
extern function
+ ICHECK(!(call_node->attrs.defined())) << "Extern functions should have
null attributes.";
+ for (const auto& arg : call_node->args) {
+ VisitExpr(arg);
+ }
+ call_lowered_props = CallLoweredProps{GetRef<GlobalVar>(gvn),
call_node->args, {}};
} else {
- LOG(FATAL) << "TVM runtime does not support calls to " <<
op->op->GetTypeKey();
+ ICHECK(call_node->op == CallLoweredOp()) << "Operators should be
transformed away; Try "
+ "applying the fuse_ops
transformation to the "
+ "expression.";
+ call_lowered_props = GetCallLoweredProps(call_node);
+ for (const auto& arg : call_lowered_props.arguments) {
+ VisitExpr(arg);
+ }
}
+ CreateFuncCall(call_lowered_props, GetRef<Call>(call_node));
}
void VisitExpr_(const VarNode* op) override {
@@ -443,7 +474,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple);
}
void VisitExpr_(const OpNode* op) override {
- LOG(FATAL) << "All OpNodes should have been expanded";
+ if (GetRef<Op>(op) != CallLoweredOp()) {
+ LOG(FATAL) << "All OpNodes except for call_lowered should have been
expanded";
+ }
}
void VisitExpr_(const IfNode* op) override {
LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's
Codegen is called";
diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
index cae2021..c41399e 100644
--- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
+++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
@@ -17,14 +17,18 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
+#include <tvm/runtime/memory.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
+#include "../../../op/call/call.h"
+
namespace tvm {
namespace relay {
namespace contrib {
@@ -109,7 +113,13 @@ class ConvertAddToSubtract : public MixedModeMutator {
GlobalVar new_global_var(func_name.value());
new_global_var->checked_type_ = func->checked_type();
ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
- return Call(new_global_var, call->args, call->attrs, call->type_args,
call->span);
+
+ // Since we are replacing the Relay function with a call to a TIR
function, we must use the
+ // call_lowered op.
+ auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+ call_lowered_attrs->metadata.Set("relay_attrs", call->attrs);
+ return CallLowered(std::move(new_global_var), call->args,
+ std::move(Attrs(call_lowered_attrs)),
call->type_args, call->span);
}
}
diff --git a/src/relay/backend/graph_executor_codegen.cc
b/src/relay/backend/graph_executor_codegen.cc
index d32ded3..ac3c835 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -26,6 +26,7 @@
#include <dmlc/json.h>
#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
@@ -37,6 +38,7 @@
#include <vector>
#include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler.h"
#include "./utils.h"
@@ -403,64 +405,75 @@ class GraphExecutorCodegen : public
backend::MemoizedExprTranslator<std::vector<
return lhs_storage_id == rhs_storage_id;
}
- std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const
std::string& func_name,
- GraphAttrs attrs) {
+ std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* call_node,
GraphAttrs attrs) {
+ Call call = GetRef<Call>(call_node);
std::vector<GraphNodeRef> inputs;
- for (auto arg : op->args) {
- auto res = VisitExpr(arg);
- for (auto nr : res) {
- inputs.push_back(nr);
- }
- }
+ std::string func_name;
- /// An adapted version of the storage optimization for the time being.
- bool reshape_only = false;
- if (op->attrs.defined()) {
- if (auto tir_call_attrs = op->attrs.as<TIRCallAttrs>()) {
- Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
- if (metadata.count(attr::kReshapeOnly) &&
- Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1) {
- reshape_only = true;
- }
+ if (call->op == CallLoweredOp()) {
+ // Extract function and arguments from the call_lowered op
+ CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
- auto relay_attrs =
Downcast<DictAttrs>(tir_call_attrs->metadata["relay_attrs"]);
+ func_name = call_lowered_props.lowered_func->name_hint;
- for (auto p : relay_attrs->dict) {
- if (p.second.as<StringObj>()) {
- attrs[p.first] = std::string(Downcast<String>(p.second));
+ for (const Expr& arg : call_lowered_props.arguments) {
+ for (auto n : VisitExpr(arg)) {
+ inputs.push_back(n);
+ }
+ }
+ if (call_lowered_props.attrs.metadata.count("relay_attrs")) {
+ if (auto relay_attrs =
+
call_lowered_props.attrs.metadata["relay_attrs"].as<DictAttrsNode>()) {
+ for (auto p : relay_attrs->dict) {
+ if (p.second.as<StringObj>()) {
+ attrs[p.first] = std::string(Downcast<String>(p.second));
+ }
}
}
}
- }
-
- if (reshape_only && ShareSameStorage(GetRef<Expr>(op), op->args[0])) {
- auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(),
"__nop", inputs, attrs);
- return AddNode(node, GetRef<Expr>(op));
+ bool reshape_only = false;
+ if (call_lowered_props.attrs.metadata.count(attr::kReshapeOnly) &&
+
Downcast<tvm::Integer>(call_lowered_props.attrs.metadata[attr::kReshapeOnly])->value
==
+ 1) {
+ reshape_only = true;
+ }
+ if (reshape_only &&
+ ShareSameStorage(GetRef<Expr>(call_node),
call_lowered_props.arguments[0])) {
+ auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(),
"__nop", inputs, attrs);
+ return AddNode(node, call);
+ }
+ } else if (!call_node->attrs.defined()) { // Call is an extern function
+ std::cout << "call_node: \n" << PrettyPrint(call) << std::endl;
+ const auto* func = call_node->op.as<GlobalVarNode>();
+ ICHECK(func) << "Expected the operator to be a global var, but got "
+ << call_node->op->GetTypeKey(); // getting a relay fn
here, not sure why.
+ func_name = func->name_hint;
+
+ for (const Expr& arg : call_node->args) {
+ for (auto n : VisitExpr(arg)) {
+ inputs.push_back(n);
+ }
+ }
+ } else {
+ LOG(FATAL) << "Non-primitive-call nodes should have been transformed
away.\n"
+ << "The graph executor code generator expects all calls to be
call_lowered, "
+ << "but found: " << std::endl
+ << PrettyPrint(call);
}
// Compute the operator name, because we used the get unique name when
generating the kernel.
auto op_name = _GetUniqueName(func_name);
auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name,
inputs, attrs);
- return AddNode(node, GetRef<Expr>(op));
+ return AddNode(node, call);
}
std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
- relay::Call call = GetRef<Call>(call_node);
auto props = GetOnDeviceProps(call_node);
if (props.body.defined()) {
// See through "on_device" calls.
return VisitExpr(props.body);
}
-
- const auto* global_node = call->op.as<GlobalVarNode>();
- ICHECK(global_node)
- << "Non-primitive-call nodes should have been transformed away.\n"
- << "The graph executor code generator expects all calls to have their
callee "
- "normalized to a GlobalVar, but found:"
- << std::endl
- << PrettyPrint(call);
- auto prim_fn_name = global_node->name_hint;
- return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs());
+ return GraphAddCallNode(call_node, GraphAttrs());
}
std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
diff --git a/src/relay/backend/graph_plan_memory.cc
b/src/relay/backend/graph_plan_memory.cc
index 961252a..4031dfd 100644
--- a/src/relay/backend/graph_plan_memory.cc
+++ b/src/relay/backend/graph_plan_memory.cc
@@ -24,6 +24,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
@@ -32,6 +33,7 @@
#include "../../support/arena.h"
#include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
#include "../op/memory/memory.h"
#include "../transforms/device_aware_visitors.h"
#include "./utils.h"
@@ -139,6 +141,8 @@ class StorageAllocaBaseVisitor : public
transform::DeviceAwareExprVisitor {
protected:
/*! \brief internal token map */
std::unordered_map<const ExprNode*, std::vector<StorageToken*>> token_map_;
+ /*! \brief empty token map */
+ const std::vector<StorageToken*> no_tokens_;
/*!
* \brief Get the necessary token.
@@ -146,6 +150,11 @@ class StorageAllocaBaseVisitor : public
transform::DeviceAwareExprVisitor {
* \return The corresponding token.
*/
const std::vector<StorageToken*>& GetToken(const Expr& expr) {
+ this->VisitExpr(expr);
+ // Functions don't require data storage, represented by the empty token
+ if (expr->checked_type().as<FuncTypeNode>()) {
+ return no_tokens_;
+ }
// See through on_device calls.
Expr real_expr = IgnoreOnDevice(expr);
this->VisitExpr(real_expr);
@@ -159,8 +168,9 @@ class StorageAllocaBaseVisitor : public
transform::DeviceAwareExprVisitor {
* \brief Allocates (or reuses if \p can_realloc is true) a storage token
for holding
* the result of evaluating \p op.
*/
- void CreateToken(const ExprNode* op, bool can_realloc) {
- return CreateTokenOnDevice(op, GetInScopeDeviceType(GetRef<Expr>(op)),
can_realloc);
+ void CreateToken(const ExprNode* expr_node, bool can_realloc) {
+ return CreateTokenOnDevice(expr_node,
GetInScopeDeviceType(GetRef<Expr>(expr_node)),
+ can_realloc);
}
/*!
@@ -203,12 +213,12 @@ class StorageAllocaInit : protected
StorageAllocaBaseVisitor {
using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_;
- void DeviceAwareVisitExpr_(const CallNode* op) final {
+ void DeviceAwareVisitExpr_(const CallNode* call_node) final {
// create token for the call node.
- CreateToken(op, true);
+ CreateToken(call_node, true);
// for each input, visit argument token.
- for (Expr arg : op->args) {
+ for (Expr arg : call_node->args) {
for (StorageToken* tok : GetToken(arg)) {
tok->ref_counter += 1;
}
@@ -273,7 +283,6 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
<< "expressions are assigned with virtual device types.
Either all "
"or none of the expressions are expected to be annotated.";
}
-
return backend::StaticMemoryPlan(smap);
}
@@ -320,10 +329,13 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_;
// The call map
- void DeviceAwareVisitExpr_(const CallNode* op) final {
+ void DeviceAwareVisitExpr_(const CallNode* call_node) final {
std::vector<StorageToken*> args;
// for each input, visit argument token.
- for (Expr arg : op->args) {
+
+ for (const Expr& arg : call_node->args) {
+ // Note: GetToken skips GlobalVars and handles tuples properly, so we
don't need to treat
+ // call_lowered specially.
for (StorageToken* tok : GetToken(arg)) {
args.push_back(tok);
}
@@ -337,20 +349,17 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
//
// TODO(tvm-team) Update checks of flat memory enablement when we support
// opaque-nd memory planning to skip this path.
- if (IsReshape(op)) {
- // TODO(@electriclilies, jroesch): This check is failing because the
size of args is 3
- // I can't figure out where the extra args are coming from, I assume it
must be related
- // to the relay_attrs field we added to the TIRCallArgs, but I don't
know where / how
- // that's happening...
+
+ if (IsReshape(call_node)) {
ICHECK_EQ(args.size(), 1U);
- ReuseInputToken(op, args[0]);
+ ReuseInputToken(call_node, args[0]);
} else {
// create token for the call node.
- CreateToken(op, true);
+ CreateToken(call_node, true);
}
// check if there is orphaned output that can be released immediately.
- for (StorageToken* tok : token_map_.at(op)) {
+ for (StorageToken* tok : token_map_.at(call_node)) {
CheckForRelease(tok);
}
for (StorageToken* tok : args) {
@@ -376,12 +385,11 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
return fn->HasNonzeroAttr(attr::kReshapeOnly);
}
- if (call->attrs.defined()) {
- if (auto tir_call_attrs = call->attrs.as<TIRCallAttrs>()) {
- Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
- return metadata.count(attr::kReshapeOnly) &&
- (Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value ==
1);
- }
+ if (call->op == CallLoweredOp()) {
+ CallLoweredProps call_lowered_props = GetCallLoweredProps(call);
+ Map<String, ObjectRef> metadata = call_lowered_props.attrs.metadata;
+ return metadata.count(attr::kReshapeOnly) &&
+ (Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value ==
1);
}
return false;
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 13b8556..4835d76 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -25,6 +25,7 @@
#include <tvm/driver/driver_api.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
@@ -37,6 +38,7 @@
#include <tvm/target/compilation_config.h>
#include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
#include "../transforms/pass_utils.h"
#include "te_compiler.h"
@@ -682,82 +684,94 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
}
ObjectRef VisitExpr_(const CallNode* call_node) final {
- std::vector<ObjectRef> args;
- for (auto arg : call_node->args) {
- args.push_back(Eval(arg));
- }
+ if (call_node->op == CallLoweredOp()) { // Special case: Call a lowered
TIR function.
+ CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
- if (call_node->op == OnDeviceOp()) {
- // Special case: The call 'on_device(expr)' denotes that expr should be
executed on
- // a particular device. We can ignore this during interpretation.
- ICHECK_EQ(call_node->args.size(), 1UL);
- return args[0];
- }
+ // Evaluate only function args
+ std::vector<ObjectRef> args;
+ for (auto arg : call_lowered_props.arguments) {
+ args.push_back(Eval(arg));
+ }
- // We should not find calls to operators after running fusion and lowering.
- if (const OpNode* op_node = call_node->op.as<OpNode>()) {
- LOG(FATAL) << "found " << op_node->name
- << "; operators should have been removed by previous passes;
try "
- "fusing and lowering";
- }
+ // TODO(mbs): Make calling convention first-class in Relay.
+ Array<GlobalVar> all_prim_fn_vars;
+ if (call_lowered_props.attrs.metadata.count("all_prim_fn_vars")) {
+ all_prim_fn_vars =
+
Downcast<Array<GlobalVar>>(call_lowered_props.attrs.metadata.at("all_prim_fn_vars"));
+ }
+ GlobalVar prim_shape_fn_var;
+ if (call_lowered_props.attrs.metadata.count("prim_shape_fn_var")) {
+ prim_shape_fn_var =
+
Downcast<GlobalVar>(call_lowered_props.attrs.metadata.at("prim_shape_fn_var"));
+ }
+ Array<GlobalVar> all_prim_shape_fn_vars;
+ if (call_lowered_props.attrs.metadata.count("all_prim_shape_fn_vars")) {
+ all_prim_shape_fn_vars = Downcast<Array<GlobalVar>>(
+ call_lowered_props.attrs.metadata.at("all_prim_shape_fn_vars"));
+ }
+ Array<Integer> prim_shape_fn_states;
+ if (call_lowered_props.attrs.metadata.count("prim_shape_fn_states")) {
+ prim_shape_fn_states =
+
Downcast<Array<Integer>>(call_lowered_props.attrs.metadata.at("prim_shape_fn_states"));
+ }
- if (const ConstructorNode* con = call_node->op.as<ConstructorNode>()) {
- // Special case: ADT constructor
- return ConstructorValue(con->tag, args, GetRef<Constructor>(con));
- }
+ size_t num_shape_inputs = 0;
+ if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_inputs"))
{
+ num_shape_inputs = static_cast<size_t>(
+
Downcast<Integer>(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_inputs"))
+ ->value);
+ }
+ size_t num_shape_outputs = 0;
+ if
(call_lowered_props.attrs.metadata.count("prim_shape_fn_num_outputs")) {
+ num_shape_outputs = static_cast<size_t>(
+
Downcast<Integer>(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_outputs"))
+ ->value);
+ }
+ ICHECK(config_->optional_homogeneous_target.defined());
+ return InvokePrimitiveOp(call_lowered_props.lowered_func,
all_prim_fn_vars,
+ config_->optional_homogeneous_target,
prim_shape_fn_var,
+ all_prim_shape_fn_vars, prim_shape_fn_states,
num_shape_inputs,
+ num_shape_outputs,
config_->host_se_scope->target, args);
+ } else { // All other calls
+ // Evaluate all arguments
+ std::vector<ObjectRef> args;
+ for (auto arg : call_node->args) {
+ args.push_back(Eval(arg));
+ }
- if (const GlobalVarNode* gvn = call_node->op.as<GlobalVarNode>()) {
- if (const TIRCallAttrs* attrs = call_node->attrs.as<TIRCallAttrs>()) {
- // Special case: Call a lowered TIR function.
- // TODO(mbs): Make calling convention first-class in Relay.
- Array<GlobalVar> all_prim_fn_vars;
- if (attrs->metadata.count("all_prim_fn_vars")) {
- all_prim_fn_vars =
Downcast<Array<GlobalVar>>(attrs->metadata.at("all_prim_fn_vars"));
- }
- GlobalVar prim_shape_fn_var;
- if (attrs->metadata.count("prim_shape_fn_var")) {
- prim_shape_fn_var =
Downcast<GlobalVar>(attrs->metadata.at("prim_shape_fn_var"));
- }
- Array<GlobalVar> all_prim_shape_fn_vars;
- if (attrs->metadata.count("all_prim_shape_fn_vars")) {
- all_prim_shape_fn_vars =
-
Downcast<Array<GlobalVar>>(attrs->metadata.at("all_prim_shape_fn_vars"));
- }
- Array<Integer> prim_shape_fn_states;
- if (attrs->metadata.count("prim_shape_fn_states")) {
- prim_shape_fn_states =
-
Downcast<Array<Integer>>(attrs->metadata.at("prim_shape_fn_states"));
- }
- size_t num_shape_inputs = 0;
- if (attrs->metadata.count("prim_shape_fn_num_inputs")) {
- num_shape_inputs = static_cast<size_t>(
-
Downcast<Integer>(attrs->metadata.at("prim_shape_fn_num_inputs"))->value);
- }
- size_t num_shape_outputs = 0;
- if (attrs->metadata.count("prim_shape_fn_num_outputs")) {
- num_shape_outputs = static_cast<size_t>(
-
Downcast<Integer>(attrs->metadata.at("prim_shape_fn_num_outputs"))->value);
- }
+ if (call_node->op == OnDeviceOp()) {
+ // Special case: The call 'on_device(expr)' denotes that expr should
be executed on
+ // a particular device. We can ignore this during interpretation.
+ ICHECK_EQ(call_node->args.size(), 1UL);
+ return args[0];
+ }
+ if (const ConstructorNode* con = call_node->op.as<ConstructorNode>()) {
+ // Special case: ADT constructor
- ICHECK(config_->optional_homogeneous_target.defined());
- return InvokePrimitiveOp(GetRef<GlobalVar>(gvn), all_prim_fn_vars,
- config_->optional_homogeneous_target,
prim_shape_fn_var,
- all_prim_shape_fn_vars, prim_shape_fn_states,
num_shape_inputs,
- num_shape_outputs,
config_->host_se_scope->target, args);
+ return ConstructorValue(con->tag, args, GetRef<Constructor>(con));
}
- }
- // Now we just evaluate and expect to find a closure.
- ObjectRef fn_val = Eval(call_node->op);
- if (const InterpreterClosureObj* closure_node =
fn_val.as<InterpreterClosureObj>()) {
- auto closure = GetRef<InterpreterClosure>(closure_node);
- return Invoke(closure, args);
- } else if (const RecClosureObj* closure_node = fn_val.as<RecClosureObj>())
{
- return Invoke(closure_node->clos, args, closure_node->bind);
- } else {
- LOG(FATAL) << "internal error: type error, expected function value in
the call "
- << "position";
- return ObjectRef();
+ if (const OpNode* op_node = call_node->op.as<OpNode>()) {
+ // Except for call_lowered and on_device, we should not find calls to
operators after
+ // running fusion and lowering.
+ LOG(FATAL) << "found " << op_node->name
+ << "; operators should have been removed by previous
passes; try "
+ "fusing and lowering";
+ }
+
+ // Now we just evaluate and expect to find a closure.
+ // TODO(@electriclilies): How should call_lowered behave with closures?
+ ObjectRef fn_val = Eval(call_node->op);
+ if (const InterpreterClosureObj* closure_node =
fn_val.as<InterpreterClosureObj>()) {
+ auto closure = GetRef<InterpreterClosure>(closure_node);
+ return Invoke(closure, args);
+ } else if (const RecClosureObj* closure_node =
fn_val.as<RecClosureObj>()) {
+ return Invoke(closure_node->clos, args, closure_node->bind);
+ } else {
+ LOG(FATAL) << "internal error: type error, expected function value in
the call "
+ << "position";
+ return ObjectRef();
+ }
}
}
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 163bb9f..915fc22 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -24,6 +24,7 @@
#include <tvm/ir/function.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
@@ -43,6 +44,7 @@
#include <vector>
#include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler_cache.h"
#include "./utils.h"
@@ -460,7 +462,8 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
* to the TIR implementation, and attributes to attach to the call to
identify it as
* a TIR call.
*/
- std::pair<GlobalVar, Attrs> LowerFunction(Function func, Target target) {
+ Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type>
type_args, Span span,
+ Target target) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
// BYOC flow.
CCacheKey key = CCacheKey(func, target);
@@ -468,6 +471,7 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
<< ext_func->prim_fn_var->name_hint;
+ // TODO(@areusch, @jroesch): this metadata is for AOT, this should be
our interface for AOT
Map<GlobalVar, tir::PrimFunc> prim_fns;
relay::Function func_with_metadata = func;
func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
ext_func->prim_fn_var);
@@ -478,87 +482,91 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
// act when we process a function.
this->process_fn_(func_with_metadata);
- // TODO(mbs): Need TIRCallAttrs or equiv so targets know this is an
extern.
// TODO(mbs): Dynamic shapes?
- return {ext_func->prim_fn_var, Attrs()};
- }
+ // TODO(@mbs, electriclilies): Make extern functions explicit
+ return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args,
span);
- // Non-External Relay Function
- VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n"
<< PrettyPrint(func);
- CCacheKey key = CCacheKey(func, target);
- CachedFunc lowered_func = compiler_->Lower(key, module_name_);
- VLOG(1) << "lowered primitive bound to '" <<
PrettyPrint(lowered_func->prim_fn_var) << "'";
-
- // Collect all the lowered functions produced for this primitive function.
- Map<GlobalVar, tir::PrimFunc> prim_fns;
- Array<GlobalVar> all_prim_fn_vars;
- for (auto prim_fn : lowered_func->funcs->functions) {
- CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
- prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
- all_prim_fn_vars.push_back(prim_fn.first);
- VLOG(1) << "lowered primitive includes bindings for '" <<
PrettyPrint(prim_fn.first) << "'";
- }
+ } else {
+ // Non-External Relay Function
+ VLOG(1) << "lowering to target '" << target->str() << "' for
primitive:\n"
+ << PrettyPrint(func);
+ CCacheKey key = CCacheKey(func, target);
+ CachedFunc lowered_func = compiler_->Lower(key, module_name_);
+ VLOG(1) << "lowered primitive bound to '" <<
PrettyPrint(lowered_func->prim_fn_var) << "'";
- // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our
interface for AOT
- relay::Function func_with_metadata = func;
- func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
lowered_func->prim_fn_var);
- func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
- func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
lowered_func->target);
+ // Collect all the lowered functions produced for this primitive
function.
+ Map<GlobalVar, tir::PrimFunc> prim_fns;
+ Array<GlobalVar> all_prim_fn_vars;
+ for (auto prim_fn : lowered_func->funcs->functions) {
+ CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+ prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
+ all_prim_fn_vars.push_back(prim_fn.first);
+ VLOG(1) << "lowered primitive includes bindings for '" <<
PrettyPrint(prim_fn.first) << "'";
+ }
- // Provide a callback hook which allows one-level up code generators to
- // act when we process a function.
- this->process_fn_(func_with_metadata);
+ // TODO(@areusch, @jroesch): this metadata is for AOT, this should be
our interface for AOT
+ relay::Function func_with_metadata = func;
+ func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
lowered_func->prim_fn_var);
+ func_with_metadata = WithAttr(func_with_metadata, "prim_funcs",
prim_fns);
+ func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
lowered_func->target);
- auto tir_call_attrs = make_object<TIRCallAttrs>();
- if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
- tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
- }
+ // Provide a callback hook which allows one-level up code generators to
+ // act when we process a function.
+ this->process_fn_(func_with_metadata);
- auto device_copy = IsDeviceCopy(func);
- if (std::get<0>(device_copy)) {
- // Record that device copy source and destination devices so the device
planner can
- // still follow along.
- auto source_device = std::get<1>(device_copy);
- auto dst_device = std::get<2>(device_copy);
- tir_call_attrs->metadata.Set("source_device",
tvm::Integer(source_device));
- tir_call_attrs->metadata.Set("dst_device", tvm::Integer(dst_device));
- }
+ auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+ if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
+ call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
+ }
- tir_call_attrs->metadata.Set("relay_attrs", func->attrs);
- tir_call_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
-
- if (IsDynamic(func->ret_type)) {
- // Also lower the dynamic shape function.
- // Shape function keys use the underlying primitive function as their
'function',
- // but the generic 'cpu' target as the target since all shape functions
run
- // on the host cpu irrespective of where the primitive runs.
- // TODO(mbs): Cleanup target handling.
- Target shape_target("llvm");
- VLOG(1) << "lowering to target '" << shape_target->str()
- << "' for dynamic shape function for primitive";
- CCacheKey shape_key(func, shape_target);
- CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
- // Capture the shape function's global var and parameters 'states' in
call
- // annotations so calling convention can be recovered.
- // TODO(mbs): Capture all this as part of a 'call into TIR' construct
once available.
- // The way the shape function calling convention is derived and passed
to call sites
- // via the 'parameter states' could be improved.
- tir_call_attrs->metadata.Set("prim_shape_fn_var",
lowered_shape_func->prim_fn_var);
- tir_call_attrs->metadata.Set("prim_shape_fn_states",
-
lowered_shape_func->shape_func_param_states);
- tir_call_attrs->metadata.Set("prim_shape_fn_num_inputs",
-
Integer(static_cast<int>(lowered_shape_func->inputs.size())));
- tir_call_attrs->metadata.Set("prim_shape_fn_num_outputs",
-
Integer(static_cast<int>(lowered_shape_func->outputs.size())));
- Array<GlobalVar> all_prim_shape_fn_vars;
- for (auto prim_shape_fn : lowered_shape_func->funcs->functions) {
- CHECK(prim_shape_fn.second.as<tir::PrimFuncNode>()) << "must be a prim
fn";
- all_prim_shape_fn_vars.push_back(prim_shape_fn.first);
+ auto device_copy = IsDeviceCopy(func);
+ if (std::get<0>(device_copy)) {
+ // Record that device copy source and destination devices so the
device planner can
+ // still follow along.
+ auto source_device = std::get<1>(device_copy);
+ auto dst_device = std::get<2>(device_copy);
+ call_lowered_attrs->metadata.Set("source_device",
tvm::Integer(source_device));
+ call_lowered_attrs->metadata.Set("dst_device",
tvm::Integer(dst_device));
}
- tir_call_attrs->metadata.Set("all_prim_shape_fn_vars",
all_prim_shape_fn_vars);
- }
- return {lowered_func->prim_fn_var, Attrs(tir_call_attrs)};
+ call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
+ call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
+
+ if (IsDynamic(func->ret_type)) {
+ // Also lower the dynamic shape function.
+ // Shape function keys use the underlying primitive function as their
'function',
+ // but the generic 'cpu' target as the target since all shape
functions run
+ // on the host cpu irrespective of where the primitive runs.
+ // TODO(mbs): Cleanup target handling.
+ Target shape_target("llvm");
+ VLOG(1) << "lowering to target '" << shape_target->str()
+ << "' for dynamic shape function for primitive";
+ CCacheKey shape_key(func, shape_target);
+ CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
+ // Capture the shape function's global var and parameters 'states' in
call
+ // annotations so calling convention can be recovered.
+ // TODO(mbs): Capture all this as part of a 'call into TIR' construct
once available.
+ // The way the shape function calling convention is derived and passed
to call sites
+ // via the 'parameter states' could be improved.
+ call_lowered_attrs->metadata.Set("prim_shape_fn_var",
lowered_shape_func->prim_fn_var);
+ call_lowered_attrs->metadata.Set("prim_shape_fn_states",
+
lowered_shape_func->shape_func_param_states);
+ call_lowered_attrs->metadata.Set(
+ "prim_shape_fn_num_inputs",
+ Integer(static_cast<int>(lowered_shape_func->inputs.size())));
+ call_lowered_attrs->metadata.Set(
+ "prim_shape_fn_num_outputs",
+ Integer(static_cast<int>(lowered_shape_func->outputs.size())));
+ Array<GlobalVar> all_prim_shape_fn_vars;
+ for (auto prim_shape_fn : lowered_shape_func->funcs->functions) {
+ CHECK(prim_shape_fn.second.as<tir::PrimFuncNode>()) << "must be a
prim fn";
+ all_prim_shape_fn_vars.push_back(prim_shape_fn.first);
+ }
+ call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars",
all_prim_shape_fn_vars);
+ }
+ return CallLowered(lowered_func->prim_fn_var, visited_args,
Attrs(call_lowered_attrs),
+ type_args, span);
+ }
}
std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value)
final {
@@ -593,6 +601,9 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
}
Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
+ // Passes before lowering might insert a call_lowered to call a function
that has already
+ // been lowered. Therefore we might see call_lowered ops here, but we
don't need to do anything
+ // because ResolveToPrimitive returns null for all calls where the
call_node->op is an OpNode
Call call = GetRef<Call>(call_node);
// Look for (indirect) calls to primitives.
@@ -628,15 +639,13 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
// TODO(mbs): Replace device_type with target so this lookup is
unnecessary.
target = GetTargetFromInteger(device_type, targets_);
}
-
+ Array<Expr> visited_args;
+ for (const auto& arg : call_node->args) {
+ visited_args.push_back(VisitExpr(arg));
+ }
// Lower the primitive function for that target.
Function func = Downcast<Function>(prim_func);
- std::pair<GlobalVar, Attrs> pair = LowerFunction(func, target);
-
- // Replace with direct call to lowered primitive, and attach annotations
to record calling
- // convention.
- // =====> in new call_lowered form
- return Call(pair.first, args, pair.second);
+ return MakeLoweredCall(func, visited_args, call_node->type_args,
call_node->span, target);
}
IRModule module_;
diff --git a/src/relay/op/call/call.cc b/src/relay/op/call/call.cc
new file mode 100644
index 0000000..9485b72
--- /dev/null
+++ b/src/relay/op/call/call.cc
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/call/call.cc
+ * \brief Operators for calling lowered functions.
+ */
+
+#include "./call.h"
+
+#include <tvm/relay/attrs/call.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include "../../transforms/infer_layout_utils.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(CallLoweredAttrs);
+
+// call_lowered
+bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs&
attrs,
+ const TypeReporter& reporter) {
+ // Types = [func, call_args, ret_type]
+ if (types.size() != 3u) {
+ return false;
+ }
+ const auto* func_type = types[0].as<FuncTypeNode>();
+ if (!func_type) {
+ return false;
+ }
+
+ const auto* tuple_type_node = types[1].as<TupleTypeNode>();
+ if (!tuple_type_node) {
+ return false;
+ }
+
+ // Constraint to ensure function arguments are the same type as the inputs
to the function (modulo
+ // the Tuple wrapper)
+ reporter->Assign(GetRef<TupleType>(tuple_type_node),
TupleType(func_type->arg_types, {}));
+ // Constraint to ensure the output of call_lowered is the same as the
function's return type
+ reporter->Assign(types[2], func_type->ret_type);
+ return true;
+}
+
+const Op& CallLoweredOp() { return Op::Get("call_lowered"); }
+
+Expr CallLowered(Expr func, Array<Expr> inputs, Attrs attrs, Array<Type>
type_args, Span span) {
+ // Right now, call_lowered only supports func being a global var pointing to
the lowered
+ // function.
+ ICHECK(func.as<GlobalVarNode>())
+ << "Function to call should be GlobalVarNode, but got " <<
func->GetTypeKey();
+ ICHECK(attrs.as<CallLoweredAttrs>())
+ << "Expected attributes to be CallLoweredAttrs, but got " <<
attrs->GetTypeKey();
+ return Call(CallLoweredOp(), {std::move(func), Tuple(std::move(inputs))},
std::move(attrs),
+ std::move(type_args), std::move(span));
+}
+
+TVM_REGISTER_GLOBAL("relay.op.call_lowered")
+ .set_body_typed([](Expr func, Array<Expr> inputs, Attrs attrs, Array<Type>
type_args,
+ Span span) {
+ const TupleNode* tuple_node = inputs.as<TupleNode>();
+ return CallLowered(func, tuple_node->fields, attrs, type_args, span);
+ });
+
+RELAY_REGISTER_OP("call_lowered")
+ .describe(R"code(Invoke an operation compiled by TVM.)code"
TVM_ADD_FILELINE)
+ .set_num_inputs(2)
+ .set_attrs_type<CallLoweredAttrs>()
+ .add_argument("func", "Function", "The lowered function to call.")
+ .add_argument("call_args", "Tuple", "The input tensors.")
+ .add_type_rel("CallLoweredRel", CallLoweredRel)
+ .set_support_level(10)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<TNonComputational>("TNonComputational", true)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout);
+
+CallLoweredProps GetCallLoweredProps(const CallNode* call_node) {
+ ICHECK(call_node->op == CallLoweredOp())
+ << "GetCallLoweredProps expects the op to be call_lowered. ";
+ ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2
arguments. ";
+ const auto* function = call_node->args[0].as<GlobalVarNode>();
+ ICHECK(function) << "Expected first arg to call_lowered to be a GlobalVar. ";
+
+ const auto* tuple_args = call_node->args[1].as<TupleNode>();
+ ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple. ";
+
+ ICHECK(call_node->attrs.defined()) << "Attributes for call_lowered should be
defined!";
+ const auto* attrs = call_node->attrs.as<CallLoweredAttrs>();
+ ICHECK(attrs) << "Expected call_lowered op to have CallLoweredAttrs, but
found "
+ << call_node->attrs->GetTypeKey();
+ return CallLoweredProps{std::move(GetRef<GlobalVar>(function)),
std::move(tuple_args->fields),
+ std::move(*attrs)};
+}
+
+} // namespace relay
+} // namespace tvm
diff --git a/src/relay/op/call/call.h b/src/relay/op/call/call.h
new file mode 100644
index 0000000..381be67
--- /dev/null
+++ b/src/relay/op/call/call.h
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/call/call.h
+ * \brief Operators for calling lowered functions.
+ */
+#ifndef TVM_RELAY_OP_CALL_CALL_H_
+#define TVM_RELAY_OP_CALL_CALL_H_
+
+#include <tvm/relay/attrs/call.h>
+#include <tvm/relay/expr.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Helper to construct a Relay call with the call_lowered op.
+ * \param func Lowered function to call with call_lowered.
+ * \param inputs Arguments to be passed to the function.
+ * \param attrs Function attributes, should be CallLoweredAttrs.
+ * \param type_args Type arguments for the call.
+ * \param span TVM span for propagating debugging info.
+ * \return Call to the call_lowered op invoking func with the given inputs.
+ */
+Expr CallLowered(Expr func, Array<Expr> inputs, Attrs attrs, Array<Type>
type_args, Span span);
+
+/*!
+ * \brief Returns the Relay call_lowered op. Use this helper to avoid
extraneous calls to
+ * Registry::Get.
+ */
+const Op& CallLoweredOp();
+
+/*!
+ * \brief Lowered function and the arguments to call it with.
+ */
+struct CallLoweredProps {
+ /*! \brief Global variable pointing to the lowered function. */
+ GlobalVar lowered_func;
+ /*! \brief Array of the arguments to call lowered_func with. */
+ Array<Expr> arguments;
+ /*! \brief Attributes from the call_lowered op. */
+ CallLoweredAttrs attrs;
+};
+
+/*!
+ * \brief Helper to extract the lowered function and its arguments from
Call("call_lowered", ...).
+ * Will fail if called on a Call whose op is not "call_lowered".
+ * \param call_node CallNode that we want to get the function and its
arguments from.
+ */
+CallLoweredProps GetCallLoweredProps(const CallNode* call_node);
+
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_OP_CALL_CALL_H_
diff --git a/src/relay/op/memory/device_copy.cc
b/src/relay/op/memory/device_copy.cc
index dce89aa..9106b95 100644
--- a/src/relay/op/memory/device_copy.cc
+++ b/src/relay/op/memory/device_copy.cc
@@ -24,6 +24,7 @@
#include "./device_copy.h"
+#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
@@ -31,6 +32,8 @@
#include <tvm/topi/elemwise.h>
#include "../../transforms/infer_layout_utils.h"
+#include "../annotation/annotation.h"
+#include "../call/call.h"
#include "../type_relations.h"
namespace tvm {
@@ -86,6 +89,7 @@ on different devices.
return {topi::identity(inputs[0])};
});
+// Get device copy props for original device copy op
DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) {
if (call_node->op == DeviceCopyOp()) {
ICHECK_EQ(call_node->args.size(), 1) << "device_copy expects one argument";
@@ -103,6 +107,19 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode*
call_node) {
} else {
return {call_node->args[0], src_dev_type, dst_dev_type};
}
+ } else if (call_node->op == CallLoweredOp()) {
+ /* Get device props for a TIR function */
+ CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+
+ if (call_lowered_props.attrs.metadata.count("source_device") == 1 &&
+ call_lowered_props.attrs.metadata.count("dst_device") == 1) {
+ ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of
arity 1";
+ return {call_lowered_props.lowered_func,
+ static_cast<DLDeviceType>(
+
Downcast<Integer>(call_lowered_props.attrs.metadata["source_device"])->value),
+ static_cast<DLDeviceType>(
+
Downcast<Integer>(call_lowered_props.attrs.metadata["dst_device"])->value)};
+ }
}
return {};
}
diff --git a/src/relay/op/vm/vm.h b/src/relay/op/vm/vm.h
index 802c810..68d25b0 100644
--- a/src/relay/op/vm/vm.h
+++ b/src/relay/op/vm/vm.h
@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_OP_VM_VM_H_
#define TVM_RELAY_OP_VM_VM_H_
-#include "tvm/relay/expr.h"
+#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
diff --git a/src/relay/transforms/device_domains.cc
b/src/relay/transforms/device_domains.cc
index 1578485..b9fa049 100644
--- a/src/relay/transforms/device_domains.cc
+++ b/src/relay/transforms/device_domains.cc
@@ -24,9 +24,11 @@
#include "./device_domains.h"
+#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/memory.h>
#include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
namespace tvm {
@@ -47,20 +49,19 @@ constexpr size_t mix(size_t h1, size_t h2) {
* See te_compiler.cc for where this rewriting occurs.
*/
DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
- auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
- if (tir_call_attrs == nullptr) {
- return {};
- }
- if (tir_call_attrs->metadata.count("source_device") != 1 ||
- tir_call_attrs->metadata.count("dst_device") != 1) {
- return {};
+ if (call_node->op == CallLoweredOp()) {
+ CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+ if (call_lowered_props.attrs.metadata.count("source_device") == 1 &&
+ call_lowered_props.attrs.metadata.count("dst_device") == 1) {
+ ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of
arity 1";
+ return {call_lowered_props.arguments[0],
+ static_cast<DLDeviceType>(
+
Downcast<Integer>(call_lowered_props.attrs.metadata["source_device"])->value),
+ static_cast<DLDeviceType>(
+
Downcast<Integer>(call_lowered_props.attrs.metadata["dst_device"])->value)};
+ }
}
- ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
- return {
- call_node->args[0],
- static_cast<DLDeviceType>(
- Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
-
static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
+ return {};
}
} // namespace
@@ -319,8 +320,12 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call&
call) {
args_and_result.emplace_back(param_domain);
}
args_and_result.emplace_back(result_domain);
+ } else if (call->op == CallLoweredOp()) {
+ CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get());
+ return DomainFor(call_lowered_props.lowered_func);
} else {
- // Defer to normal case where op can be an arbitrary expression.
+ // We still need to handle the case where the function / op is not lowered
+ // because the device planner runs before and after lowering.
return DomainFor(call->op);
}
auto domain = MakeDomain(std::move(args_and_result));
diff --git a/src/relay/transforms/memory_alloc.cc
b/src/relay/transforms/memory_alloc.cc
index 81d704e..a328eaa 100644
--- a/src/relay/transforms/memory_alloc.cc
+++ b/src/relay/transforms/memory_alloc.cc
@@ -26,6 +26,7 @@
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/expr.h>
@@ -44,6 +45,7 @@
#include "../backend/te_compiler.h"
#include "../backend/te_compiler_cache.h"
#include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../op/memory/memory.h"
#include "../op/vm/vm.h"
@@ -74,12 +76,11 @@ bool IsReshapeOnly(const Expr& expr) {
return func->HasNonzeroAttr(attr::kReshapeOnly);
}
if (const CallNode* call = expr.as<CallNode>()) {
- if (call->attrs.defined()) {
- if (auto tir_call_attrs = call->attrs.as<TIRCallAttrs>()) {
- Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
- return metadata.count(attr::kReshapeOnly) &&
- (Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value ==
1);
- }
+ if (call->op == CallLoweredOp()) {
+ CallLoweredProps call_lowered_props = GetCallLoweredProps(call);
+ Map<String, ObjectRef> metadata = call_lowered_props.attrs.metadata;
+ return metadata.count(attr::kReshapeOnly) &&
+ (Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value ==
1);
}
}
return false;
@@ -377,7 +378,8 @@ class DialectRewriter : public
transform::DeviceAwareExprMutator {
}
Tuple tuple_outs(outs);
- auto invoke = OnDevice(InvokeTVMOp(func, ins, tuple_outs),
dev.device_type, /*is_fixed=*/true);
+ auto call = InvokeTVMOp(func, ins, tuple_outs);
+ auto invoke = OnDevice(call, dev.device_type, /*is_fixed=*/true);
scope->Push(invoke);
return ToTupleType(ret_type,
std::vector<Expr>(tuple_outs->fields.begin(),
tuple_outs->fields.end()));