mbs-octoml commented on a change in pull request #9312:
URL: https://github.com/apache/tvm/pull/9312#discussion_r743103504



##########
File path: include/tvm/relay/attrs/call.h
##########
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/attrs/annotation.h

Review comment:
       nit: update the \file comment; this file is call.h, not annotation.h

##########
File path: src/relay/op/call/call.cc
##########
@@ -0,0 +1,97 @@
+
+/*!
+ * \file src/relay/op/call/call.cc
+ * \brief Operators for calling lowered functions.
+ */
+
+#include "./call.h"
+
+#include <tvm/relay/attrs/call.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include "../../transforms/infer_layout_utils.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(CallLoweredAttrs);
+
+// call_lowered
+bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // Types = [func, args, ret_type]
+  ICHECK_EQ(types.size(), 3u);
+  auto func_type = types[0].as<FuncTypeNode>();

Review comment:
       Return false for these checks, since we know nothing about how the expression was constructed.
   Even for the arity check, I think?
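   A minimal sketch of the suggested shape, using hypothetical stand-in types (`Kind`, `FakeType`, `FakeCallLoweredRel`) rather than TVM's real `Type`/`TypeReporter` API. As we understand TVM's type solver, a relation that returns `false` is treated as "not solvable yet" and retried once more types are known, so returning `false` instead of `ICHECK`-failing lets inference carry on when the inputs are not resolved:

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the kinds of type the relation can see.
enum class Kind { kUnknown, kFunc, kTuple };

// Hypothetical stand-in for relay::Type.
struct FakeType {
  Kind kind;
};

// Sketch of a call_lowered-style relation over Types = [func, args, ret_type].
// Returning false means "cannot decide yet" rather than aborting compilation.
bool FakeCallLoweredRel(const std::vector<FakeType>& types) {
  if (types.size() != 3) {
    return false;  // unexpected arity: no ICHECK, just report no progress
  }
  if (types[0].kind != Kind::kFunc) {
    return false;  // callee type not known (yet): nothing to unify against
  }
  // The real relation would Assign() types[1] to the function's argument
  // TupleType and types[2] to its return type here, then return true.
  return true;
}
```

   There is deliberately no separate check on `types[1]`: the `Assign` against the `TupleType` already constrains it.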

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -289,11 +308,21 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   /*!
    * brief Call a function with a given name
    */
-  void CreateFuncCall(Call call, std::string func_name) {
+  void CreateFuncCall(const CallNode* call_node) {

Review comment:
       nit: pass the already deconstructed args so you don't need to call ExtractFunctionAndArgs again. Not for the efficiency (whatever!) but so the function is self-contained.

##########
File path: src/relay/op/call/call.h
##########
@@ -0,0 +1,41 @@
+
+/*!
+ * \file src/relay/op/call/call.h
+ * \brief Operators for calling lowered functions.
+ */
+#ifndef TVM_RELAY_OP_CALL_CALL_H_
+#define TVM_RELAY_OP_CALL_CALL_H_
+
+#include <tvm/relay/expr.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+Expr CallLowered(Expr func, Expr inputs, Attrs attrs, Array<Type> type_args, Span span);
+const Op& CallLoweredOp();
+std::pair<GlobalVar, Array<Expr>> ExtractFunctionAndArgs(const CallNode* call_node);

Review comment:
       nit: it's a lot of boilerplate, I know, but a struct for the return result is probably best long term.
   
   super duper nit: for on_device and device_copy I called these Get... functions. Do I get to claim precedence?
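   A sketch of the kind of result struct being suggested. `CallLoweredProps`, `GetCallLoweredProps` and the `Fake*` types are hypothetical names for illustration only; a real version would hold a `GlobalVar`, an `Array<Expr>` and the `CallLoweredAttrs` instead of these simplified stand-ins:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for GlobalVar, Array<Expr> and the attrs' metadata map.
struct FakeGlobalVar {
  std::string name_hint;
};
using FakeArgs = std::vector<std::string>;
using FakeMetadata = std::map<std::string, int>;

// Named fields read better at call sites than pair.first / pair.second, and
// new fields (such as the attrs) can be added without touching every caller.
struct CallLoweredProps {
  FakeGlobalVar lowered_func;
  FakeArgs arguments;
  FakeMetadata metadata;
};

// Hypothetical stand-in for a call_lowered CallNode: [func, Tuple(args)] plus attrs.
struct FakeCallLowered {
  FakeGlobalVar func;
  FakeArgs tuple_fields;
  FakeMetadata metadata;
};

CallLoweredProps GetCallLoweredProps(const FakeCallLowered& call) {
  return {call.func, call.tuple_fields, call.metadata};
}
```

   Carrying the attrs in the same struct also means callers do not need a second `attrs.as<...>()` downcast.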

##########
File path: src/relay/backend/graph_plan_memory.cc
##########
@@ -139,13 +141,20 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor {
  protected:
   /*! \brief internal token map */
   std::unordered_map<const ExprNode*, std::vector<StorageToken*>> token_map_;
+  /*! \brief empty token map */
+  const std::vector<StorageToken*> no_tokens_;
 
   /*!
    * \brief Get the necessary token.
    * \param expr The expression.
    * \return The corresponding token.
    */
   const std::vector<StorageToken*>& GetToken(const Expr& expr) {
+    this->VisitExpr(expr);
+    // Return empty if called on a Function

Review comment:
       nit: "describe the intent, not the behavior", i.e. functions don't require data storage, which is represented by an empty token list.

##########
File path: src/relay/op/call/call.cc
##########
@@ -0,0 +1,97 @@
+
+/*!
+ * \file src/relay/op/call/call.cc
+ * \brief Operators for calling lowered functions.
+ */
+
+#include "./call.h"
+
+#include <tvm/relay/attrs/call.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include "../../transforms/infer_layout_utils.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(CallLoweredAttrs);
+
+// call_lowered
+bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // Types = [func, args, ret_type]
+  ICHECK_EQ(types.size(), 3u);
+  auto func_type = types[0].as<FuncTypeNode>();
+  ICHECK(func_type != nullptr) << "input must be operator with known type";
+  auto input_type = types[1].as<TupleTypeNode>();
+  ICHECK(input_type != nullptr)
+      << "internal invariant violated: call_lowered inputs must be a tuple";
+
+  // Constraint to ensure function arguments are the same type as the inputs to the function (modulo
+  // the Tuple wrapper)
+  reporter->Assign(types[1], TupleType(func_type->arg_types, {}));
+  // Constraint to ensure the output of call_lowered is the same as the function's return type
+  reporter->Assign(types[2], func_type->ret_type);
+  return true;
+}
+
+const Op& CallLoweredOp() { return Op::Get("call_lowered"); }
+
+Expr CallLowered(Expr func, Expr inputs, Attrs attrs, Array<Type> type_args, Span span) {

Review comment:
       nit: for symmetry with the Extract helper, inputs can be Array<Expr> args.
   
   You can std::move into the Call ctor (kind of a micro-optimization, but a good habit to get into for when the perf really does matter).
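   A self-contained sketch of the take-by-value-then-`std::move` idiom. `Tracked` is a hypothetical copy-counting type standing in for `Expr`/`Attrs`, and `FakeCall`/`MakeCallLowered` are hypothetical stand-ins for `Call` and `CallLowered`. When the caller hands over ownership with `std::move`, nothing is copied on the way into the constructor:

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an Expr/Attrs handle that counts how often it is copied.
struct Tracked {
  static int copies;
  std::string name;
  explicit Tracked(std::string n) : name(std::move(n)) {}
  Tracked(const Tracked& other) : name(other.name) { ++copies; }
  Tracked(Tracked&& other) noexcept : name(std::move(other.name)) {}
};
int Tracked::copies = 0;

// Hypothetical stand-in for relay::Call: takes its fields by value and moves them in.
struct FakeCall {
  Tracked op;
  std::vector<Tracked> args;
  Tracked attrs;
  FakeCall(Tracked op_in, std::vector<Tracked> args_in, Tracked attrs_in)
      : op(std::move(op_in)), args(std::move(args_in)), attrs(std::move(attrs_in)) {}
};

// Hypothetical CallLowered taking the arguments as an array (the Array<Expr>
// suggestion) and forwarding everything into the constructor with std::move.
FakeCall MakeCallLowered(Tracked func, std::vector<Tracked> args, Tracked attrs) {
  return FakeCall(std::move(func), std::move(args), std::move(attrs));
}
```

   Passing an lvalue vector instead copies it once at the call boundary, and everything after that is still a move. For TVM's `ObjectRef` types a copy is only a reference-count bump, which is why this counts as a micro-optimization.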

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -371,21 +400,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     return ss.str();
   }
 
-  void VisitExpr_(const CallNode* op) override {
+  void VisitExpr_(const CallNode* call_node) override {
     // Descend the call tree
-    for (auto arg : op->args) {
+    ICHECK(call_node->op == CallLoweredOp()) << "Only expect call_lowered op at this point";

Review comment:
       Let's preserve the original failure message, since it's pretty informative.

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -473,9 +475,11 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
       // act when we process a function.
       this->process_fn_(func_with_metadata);
 
-      // TODO(mbs): Need TIRCallAttrs or equiv so targets know this is an extern.
       // TODO(mbs): Dynamic shapes?
-      return {ext_func->prim_fn_var, Attrs()};
+      auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+      // Mark the function as a extern function so that AOT knows what to do with
+      call_lowered_attrs->metadata.Set("extern_func", Integer(1));

Review comment:
       I see the check in graph_executor_codegen.cc but not in aot_executor_codegen.cc. I wonder if we should leave the extern calls as regular old boring calls with no attributes and tackle them separately. Perhaps this method can return the call instead of a pair so that the different call representations can be handled.

##########
File path: src/relay/op/memory/device_copy.cc
##########
@@ -103,6 +107,20 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) {
     } else {
       return {call_node->args[0], src_dev_type, dst_dev_type};
     }
+  } else if (call_node->op == CallLoweredOp()) {

Review comment:
       I'm a bit nervous about this one, since there are a lot of usage patterns such as
   
   props = GetDeviceCopyProps(...)
   if (props.body.defined()) { ... MakeDeviceCopy(...) ... }
   
   which would silently demote call_lowered copies to regular copies.
   
   I think it would be safer to have a GetLoweredDeviceCopyProps so it's always clear which form is being matched against.
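   A sketch of keeping the two forms behind separate matchers, with a hypothetical `FakeCall` and a simplified `DeviceCopyProps` (the real struct holds an `Expr` body plus device types). Each matcher only recognizes its own form, so code written against `GetDeviceCopyProps` can never pick up a `call_lowered` copy and rebuild it as a plain `device_copy`:

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for a Relay call: which op it invokes plus the copy's details.
struct FakeCall {
  std::string op;  // "device_copy", "call_lowered", or anything else
  std::string arg;
  int src_dev_type;
  int dst_dev_type;
};

// Simplified DeviceCopyProps: an empty body means "did not match".
struct DeviceCopyProps {
  std::string body;
  int src_dev_type;
  int dst_dev_type;
  bool defined() const { return !body.empty(); }
};

// Matches only the original device_copy form.
DeviceCopyProps GetDeviceCopyProps(const FakeCall& call) {
  if (call.op != "device_copy") return {};
  return {call.arg, call.src_dev_type, call.dst_dev_type};
}

// Hypothetical separate matcher for the call_lowered form, so callers must opt in.
// A real version would also check (e.g. via the attrs' metadata) that the lowered
// function really is a device copy; that check is omitted here.
DeviceCopyProps GetLoweredDeviceCopyProps(const FakeCall& call) {
  if (call.op != "call_lowered") return {};
  return {call.arg, call.src_dev_type, call.dst_dev_type};
}
```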

##########
File path: src/relay/transforms/device_domains.h
##########
@@ -275,6 +275,7 @@ class DeviceDomains {
   const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor");
   const Op& shape_of_op = Op::Get("vm.shape_of");
   const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op");
+  const Op& call_lowered = Op::Get("call_lowered");

Review comment:
       nit: your CallLoweredOp() helper is good enough.
   Some day all these will use the same idiom.

##########
File path: src/relay/transforms/device_domains.cc
##########
@@ -47,20 +49,22 @@ constexpr size_t mix(size_t h1, size_t h2) {
  * See te_compiler.cc for where this rewriting occurs.
  */
 DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
-  auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
-  if (tir_call_attrs == nullptr) {
-    return {};
-  }
-  if (tir_call_attrs->metadata.count("source_device") != 1 ||
-      tir_call_attrs->metadata.count("dst_device") != 1) {
-    return {};
+  if (call_node->op == CallLoweredOp()) {

Review comment:
       I think this would become the GetLoweredDeviceCopyProps mentioned above, so no need for the helper here.

##########
File path: src/relay/backend/graph_plan_memory.cc
##########
@@ -376,12 +385,12 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
       return fn->HasNonzeroAttr(attr::kReshapeOnly);
     }
 
-    if (call->attrs.defined()) {
-      if (auto tir_call_attrs = call->attrs.as<TIRCallAttrs>()) {
-        Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
-        return metadata.count(attr::kReshapeOnly) &&
-               (Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1);
-      }
+    if (call->op == CallLoweredOp()) {
+      auto call_lowered_attrs = call->attrs.as<CallLoweredAttrs>();

Review comment:
       Looks like 'is reshape only' was added by inlining the checks instead of using a predicate. Sigh.
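   A sketch of the predicate, over a simplified `std::map<std::string, int>` standing in for `CallLoweredAttrs::metadata`. `IsReshapeOnly` is a hypothetical name, and the key string `"relay.reshape_only"` is an assumption about the value of `attr::kReshapeOnly`:

```cpp
#include <cassert>
#include <map>
#include <string>

// Simplified stand-in for CallLoweredAttrs::metadata (a Map<String, ObjectRef> in TVM).
using FakeMetadata = std::map<std::string, int>;

// Assumed value of attr::kReshapeOnly; check the real constant before relying on it.
const char* const kReshapeOnlyKey = "relay.reshape_only";

// Hypothetical predicate: one place that encodes "present and equal to 1",
// instead of repeating the count-and-downcast at every call site.
bool IsReshapeOnly(const FakeMetadata& metadata) {
  auto it = metadata.find(kReshapeOnlyKey);
  return it != metadata.end() && it->second == 1;
}
```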

##########
File path: src/relay/op/call/call.cc
##########
@@ -0,0 +1,97 @@
+
+/*!
+ * \file src/relay/op/call/call.cc
+ * \brief Operators for calling lowered functions.
+ */
+
+#include "./call.h"
+
+#include <tvm/relay/attrs/call.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include "../../transforms/infer_layout_utils.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(CallLoweredAttrs);
+
+// call_lowered
+bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // Types = [func, args, ret_type]
+  ICHECK_EQ(types.size(), 3u);
+  auto func_type = types[0].as<FuncTypeNode>();

Review comment:
       and then the Assign for the TupleType can just do the check for types[1]

##########
File path: src/relay/op/call/call.cc
##########
@@ -0,0 +1,97 @@
+
+/*!
+ * \file src/relay/op/call/call.cc
+ * \brief Operators for calling lowered functions.
+ */
+
+#include "./call.h"
+
+#include <tvm/relay/attrs/call.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include "../../transforms/infer_layout_utils.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(CallLoweredAttrs);
+
+// call_lowered
+bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // Types = [func, args, ret_type]
+  ICHECK_EQ(types.size(), 3u);
+  auto func_type = types[0].as<FuncTypeNode>();
+  ICHECK(func_type != nullptr) << "input must be operator with known type";
+  auto input_type = types[1].as<TupleTypeNode>();
+  ICHECK(input_type != nullptr)
+      << "internal invariant violated: call_lowered inputs must be a tuple";
+
+  // Constraint to ensure function arguments are the same type as the inputs to the function (modulo
+  // the Tuple wrapper)
+  reporter->Assign(types[1], TupleType(func_type->arg_types, {}));
+  // Constraint to ensure the output of call_lowered is the same as the function's return type
+  reporter->Assign(types[2], func_type->ret_type);
+  return true;
+}
+
+const Op& CallLoweredOp() { return Op::Get("call_lowered"); }
+
+Expr CallLowered(Expr func, Expr inputs, Attrs attrs, Array<Type> type_args, Span span) {
+  // Right now, call_lowered only supports func being a global var pointing to the lowered
+  // function.
+  ICHECK(func.as<GlobalVarNode>())
+      << "Function to call should be GlobalVarNode, but got " << func->GetTypeKey();
+  ICHECK(inputs.as<TupleNode>()) << "Inputs to call_lowered should be TupleNode, but got "
+                                 << inputs->GetTypeKey();
+  return Call(CallLoweredOp(), {func, inputs}, attrs);
+}
+
+TVM_REGISTER_GLOBAL("relay.op.call_lowered").set_body_typed(CallLowered);
+
+RELAY_REGISTER_OP("call_lowered")
+    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("op", "Function", "The operation to call")
+    .add_argument("ins", "Tuple", "The input tensors.")
+    .add_type_rel("CallLoweredRel", CallLoweredRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+std::pair<GlobalVar, Array<Expr>> ExtractFunctionAndArgs(const CallNode* call_node) {
+  ICHECK(call_node->op == CallLoweredOp())
+      << "ExtractFunctionAndArgs expects the op to be call_lowered. ";
+  ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments. ";
+  const GlobalVarNode* function = call_node->args[0].as<GlobalVarNode>();

Review comment:
       nit: I also notice some uses of the form 'just give me the func and args, I don't care what form the call actually takes'. Perhaps a helper for that case would improve the call sites? But that's your call.
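   One possible shape for that helper, sketched with hypothetical `FakeCall`, `CalleeAndArgs` and `GetCalleeAndArgs` names. It hides whether the callee is the call's op (a plain call) or `args[0]` with the real arguments wrapped in the `args[1]` tuple (`call_lowered`):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a Relay call. For a plain call, `op` is the callee and
// `args` are the arguments. For call_lowered, args[0] is the callee and the real
// arguments are the fields of the tuple in args[1] (modelled as `tuple_fields`).
struct FakeCall {
  std::string op;
  std::vector<std::string> args;
  std::vector<std::string> tuple_fields;
};

// Hypothetical result type.
struct CalleeAndArgs {
  std::string callee;
  std::vector<std::string> args;
};

// "Just give me the func and args": unwraps call_lowered, passes plain calls through.
CalleeAndArgs GetCalleeAndArgs(const FakeCall& call) {
  if (call.op == "call_lowered") {
    return {call.args.at(0), call.tuple_fields};
  }
  return {call.op, call.args};
}
```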

##########
File path: src/relay/op/call/call.h
##########
@@ -0,0 +1,41 @@
+
+/*!
+ * \file src/relay/op/call/call.h
+ * \brief Operators for calling lowered functions.
+ */
+#ifndef TVM_RELAY_OP_CALL_CALL_H_
+#define TVM_RELAY_OP_CALL_CALL_H_
+
+#include <tvm/relay/expr.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+Expr CallLowered(Expr func, Expr inputs, Attrs attrs, Array<Type> type_args, Span span);

Review comment:
       nit: doc comments for these declarations, thanks!

##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -403,64 +405,75 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     return lhs_storage_id == rhs_storage_id;
   }
 
-  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& func_name,
-                                             GraphAttrs attrs) {
+  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* call_node, GraphAttrs attrs) {
+    Call call = GetRef<Call>(call_node);
+    ICHECK(call->op == CallLoweredOp())
+        << "Non-primitive-call nodes should have been transformed away.\n"
+        << "The graph executor code generator expects all calls to be call_lowered, "
+        << "but found: " << std::endl
+        << PrettyPrint(call);
+
+    // Extract function and arguments from the call_lowered op
+    std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
+    GlobalVar func = func_and_args.first;
+    Array<Expr> call_args = func_and_args.second;
+
+    // std::string func_name = name;
+    std::string func_name = func->name_hint;
+
     std::vector<GraphNodeRef> inputs;
-    for (auto arg : op->args) {
-      auto res = VisitExpr(arg);
-      for (auto nr : res) {
-        inputs.push_back(nr);
+    // Visit all the arguments to call_lowered
+    for (Expr arg : call_args) {
+      for (auto n : VisitExpr(arg)) {
+        inputs.push_back(n);
       }
     }
 
     /// An adapted version of the storage optimization for the time being.
     bool reshape_only = false;
-    if (op->attrs.defined()) {
-      if (auto tir_call_attrs = op->attrs.as<TIRCallAttrs>()) {
-        Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
-        if (metadata.count(attr::kReshapeOnly) &&
-            Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1) {
-          reshape_only = true;
-        }
-
-        auto relay_attrs = Downcast<DictAttrs>(tir_call_attrs->metadata["relay_attrs"]);
-
-        for (auto p : relay_attrs->dict) {
-          if (p.second.as<StringObj>()) {
-            attrs[p.first] = std::string(Downcast<String>(p.second));
+    ICHECK(call_node->attrs.defined()) << "Attrs should be defined!";
+    auto call_lowered_attrs = call_node->attrs.as<CallLoweredAttrs>();
+    ICHECK(call_lowered_attrs) << "Expected call_lowered to have CallLoweredAttrs";
+
+    Map<String, ObjectRef> metadata = call_lowered_attrs->metadata;
+    if (!metadata.count("extern_func")) {
+      // Extern funcs don't have relay attrs, so only extract them for non-externs funcs.
+      ICHECK(call_lowered_attrs->metadata.count("relay_attrs"))

Review comment:
       nit: use metadata
   nit: you've already ICHECKed, so no need for the if, right?

##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -403,64 +405,75 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     return lhs_storage_id == rhs_storage_id;
   }
 
-  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& func_name,
-                                             GraphAttrs attrs) {
+  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* call_node, GraphAttrs attrs) {
+    Call call = GetRef<Call>(call_node);
+    ICHECK(call->op == CallLoweredOp())
+        << "Non-primitive-call nodes should have been transformed away.\n"
+        << "The graph executor code generator expects all calls to be call_lowered, "
+        << "but found: " << std::endl
+        << PrettyPrint(call);
+
+    // Extract function and arguments from the call_lowered op
+    std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
+    GlobalVar func = func_and_args.first;
+    Array<Expr> call_args = func_and_args.second;
+
+    // std::string func_name = name;
+    std::string func_name = func->name_hint;
+
     std::vector<GraphNodeRef> inputs;
-    for (auto arg : op->args) {
-      auto res = VisitExpr(arg);
-      for (auto nr : res) {
-        inputs.push_back(nr);
+    // Visit all the arguments to call_lowered
+    for (Expr arg : call_args) {

Review comment:
       nit: good practice to use const Expr& or const auto& to avoid the copy. Not a big deal here, but it will bite you one day, I promise!
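   A self-contained illustration with a hypothetical copy-counting type: the by-value loop (the `for (Expr arg : call_args)` form) copies every element, while `const auto&` binds to each element in place. For `Expr` each copy is only a reference-count increment, which is why it is not a big deal here, but the same loop over a heavier type pays for a full copy per iteration:

```cpp
#include <cassert>
#include <vector>

// Hypothetical element type that counts how often it is copied.
struct CopyCounter {
  static int copies;
  CopyCounter() = default;
  CopyCounter(const CopyCounter&) { ++copies; }
};
int CopyCounter::copies = 0;

// By-value loop variable: copy-constructs a fresh element on every iteration.
int CountByValue(const std::vector<CopyCounter>& items) {
  int n = 0;
  for (CopyCounter item : items) {
    (void)item;
    ++n;
  }
  return n;
}

// const-reference loop variable: binds to each element in place, no copies.
int CountByConstRef(const std::vector<CopyCounter>& items) {
  int n = 0;
  for (const auto& item : items) {
    (void)item;
    ++n;
  }
  return n;
}
```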

##########
File path: src/relay/backend/graph_plan_memory.cc
##########
@@ -376,12 +385,12 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
       return fn->HasNonzeroAttr(attr::kReshapeOnly);
     }
 
-    if (call->attrs.defined()) {
-      if (auto tir_call_attrs = call->attrs.as<TIRCallAttrs>()) {
-        Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
-        return metadata.count(attr::kReshapeOnly) &&
-               (Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1);
-      }
+    if (call->op == CallLoweredOp()) {
+      auto call_lowered_attrs = call->attrs.as<CallLoweredAttrs>();

Review comment:
       nit: just calling your Extract helper would be easier, assuming it gave you the attrs in the returned struct (as per the other comment).

##########
File path: src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
##########
@@ -109,7 +113,13 @@ class ConvertAddToSubtract : public MixedModeMutator {
         GlobalVar new_global_var(func_name.value());
         new_global_var->checked_type_ = func->checked_type();
         ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
-        return Call(new_global_var, call->args, call->attrs, call->type_args, call->span);
+
+        // Since we are replacing the Relay function with a call to a TIR function, we must use the
+        // call_lowered op.
+        auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+        call_lowered_attrs->metadata.Set("relay_attrs", call->attrs);
+        return CallLowered(new_global_var, Tuple(call->args), Attrs(call_lowered_attrs),

Review comment:
       nit: again, a micro-optimization, but for the practice: std::move the global var and the attrs. However, you can't move call->args since the call is const.
   
   It just helps make the ownership flow clearer. But don't get hung up on it.
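   A sketch of why `call->args` cannot be moved, using hypothetical `Counted`, `FakeCallNode` and `Sink` types: `std::move` on a `const` member yields a `const T&&`, which cannot bind to the move constructor, so overload resolution silently falls back to the copy constructor:

```cpp
#include <cassert>
#include <utility>

// Hypothetical type that counts copies and moves separately.
struct Counted {
  static int copies;
  static int moves;
  Counted() = default;
  Counted(const Counted&) { ++copies; }
  Counted(Counted&&) noexcept { ++moves; }
};
int Counted::copies = 0;
int Counted::moves = 0;

// Hypothetical stand-in for a CallNode that a mutator only ever sees through const.
struct FakeCallNode {
  Counted args;
};

// Hypothetical sink that takes ownership by value, like the CallLowered parameters.
struct Sink {
  Counted held;
  explicit Sink(Counted c) : held(std::move(c)) {}
};
```

   The code still compiles either way, which is why moving from a `const` object is easy to miss: no error, just a copy.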

##########
File path: src/relay/op/call/call.cc
##########
@@ -0,0 +1,97 @@
+
+/*!
+ * \file src/relay/op/call/call.cc
+ * \brief Operators for calling lowered functions.
+ */
+
+#include "./call.h"
+
+#include <tvm/relay/attrs/call.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include "../../transforms/infer_layout_utils.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(CallLoweredAttrs);
+
+// call_lowered
+bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // Types = [func, args, ret_type]
+  ICHECK_EQ(types.size(), 3u);
+  auto func_type = types[0].as<FuncTypeNode>();
+  ICHECK(func_type != nullptr) << "input must be operator with known type";
+  auto input_type = types[1].as<TupleTypeNode>();
+  ICHECK(input_type != nullptr)
+      << "internal invariant violated: call_lowered inputs must be a tuple";
+
+  // Constraint to ensure function arguments are the same type as the inputs to the function (modulo
+  // the Tuple wrapper)
+  reporter->Assign(types[1], TupleType(func_type->arg_types, {}));
+  // Constraint to ensure the output of call_lowered is the same as the function's return type
+  reporter->Assign(types[2], func_type->ret_type);
+  return true;
+}
+
+const Op& CallLoweredOp() { return Op::Get("call_lowered"); }
+
+Expr CallLowered(Expr func, Expr inputs, Attrs attrs, Array<Type> type_args, Span span) {
+  // Right now, call_lowered only supports func being a global var pointing to the lowered
+  // function.
+  ICHECK(func.as<GlobalVarNode>())
+      << "Function to call should be GlobalVarNode, but got " << func->GetTypeKey();
+  ICHECK(inputs.as<TupleNode>()) << "Inputs to call_lowered should be TupleNode, but got "
+                                 << inputs->GetTypeKey();
+  return Call(CallLoweredOp(), {func, inputs}, attrs);
+}
+
+TVM_REGISTER_GLOBAL("relay.op.call_lowered").set_body_typed(CallLowered);
+
+RELAY_REGISTER_OP("call_lowered")
+    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("op", "Function", "The operation to call")
+    .add_argument("ins", "Tuple", "The input tensors.")
+    .add_type_rel("CallLoweredRel", CallLoweredRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+std::pair<GlobalVar, Array<Expr>> ExtractFunctionAndArgs(const CallNode* call_node) {
+  ICHECK(call_node->op == CallLoweredOp())
+      << "ExtractFunctionAndArgs expects the op to be call_lowered. ";
+  ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments. ";
+  const GlobalVarNode* function = call_node->args[0].as<GlobalVarNode>();

Review comment:
       nit: I think const auto* function_node = ... is a nice convention. It saves all the type repetition without going 'full Haskell mode' with auto.

##########
File path: src/relay/op/vm/vm.cc
##########
@@ -22,8 +22,6 @@
  * \brief Dialect operators for Relay VM.
  */
 
-#include "vm.h"

Review comment:
       nit: might as well keep it; it's the standard idiom even though it's not strictly needed.

##########
File path: src/relay/op/call/call.cc
##########
@@ -0,0 +1,97 @@
+
+/*!
+ * \file src/relay/op/call/call.cc
+ * \brief Operators for calling lowered functions.
+ */
+
+#include "./call.h"
+
+#include <tvm/relay/attrs/call.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include "../../transforms/infer_layout_utils.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(CallLoweredAttrs);
+
+// call_lowered
+bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // Types = [func, args, ret_type]
+  ICHECK_EQ(types.size(), 3u);
+  auto func_type = types[0].as<FuncTypeNode>();
+  ICHECK(func_type != nullptr) << "input must be operator with known type";
+  auto input_type = types[1].as<TupleTypeNode>();
+  ICHECK(input_type != nullptr)
+      << "internal invariant violated: call_lowered inputs must be a tuple";
+
+  // Constraint to ensure function arguments are the same type as the inputs to the function (modulo
+  // the Tuple wrapper)
+  reporter->Assign(types[1], TupleType(func_type->arg_types, {}));
+  // Constraint to ensure the output of call_lowered is the same as the function's return type
+  reporter->Assign(types[2], func_type->ret_type);
+  return true;
+}
+
+const Op& CallLoweredOp() { return Op::Get("call_lowered"); }
+
+Expr CallLowered(Expr func, Expr inputs, Attrs attrs, Array<Type> type_args, Span span) {
+  // Right now, call_lowered only supports func being a global var pointing to the lowered
+  // function.
+  ICHECK(func.as<GlobalVarNode>())
+      << "Function to call should be GlobalVarNode, but got " << func->GetTypeKey();
+  ICHECK(inputs.as<TupleNode>()) << "Inputs to call_lowered should be TupleNode, but got "
+                                 << inputs->GetTypeKey();
+  return Call(CallLoweredOp(), {func, inputs}, attrs);
+}
+
+TVM_REGISTER_GLOBAL("relay.op.call_lowered").set_body_typed(CallLowered);
+
+RELAY_REGISTER_OP("call_lowered")
+    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("op", "Function", "The operation to call")
+    .add_argument("ins", "Tuple", "The input tensors.")
+    .add_type_rel("CallLoweredRel", CallLoweredRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+std::pair<GlobalVar, Array<Expr>> ExtractFunctionAndArgs(const CallNode* call_node) {
+  ICHECK(call_node->op == CallLoweredOp())
+      << "ExtractFunctionAndArgs expects the op to be call_lowered. ";
+  ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments. ";
+  const GlobalVarNode* function = call_node->args[0].as<GlobalVarNode>();

Review comment:
       nit: for ergonomics you might want this to return the struct with a null global & empty array to signal 'not a lowered call', then test for that in all the places you currently call it.
   
   Again, that was the style I settled on for the on_device etc. stuff, but it's kind of a personal choice at this point.
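   A sketch of the "empty result means no match" style, with hypothetical `FakeCall`, `CallLoweredProps`, `GetCallLoweredProps` and `CountLoweredArgs` names. An empty `lowered_func` plays the role of the null `GlobalVar`, and callers test the result instead of pre-checking the op:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a Relay call: for call_lowered, `callee` models the
// GlobalVar in args[0] and `fields` models the Tuple in args[1].
struct FakeCall {
  std::string op;
  std::string callee;
  std::vector<std::string> fields;
};

// Hypothetical props struct; an empty lowered_func stands in for a null GlobalVar.
struct CallLoweredProps {
  std::string lowered_func;
  std::vector<std::string> arguments;
};

// Returns empty props for anything that is not call_lowered, instead of ICHECK-failing.
CallLoweredProps GetCallLoweredProps(const FakeCall& call) {
  if (call.op != "call_lowered") {
    return {};
  }
  return {call.callee, call.fields};
}

// Example call site: test the result rather than pre-checking the op.
int CountLoweredArgs(const FakeCall& call) {
  CallLoweredProps props = GetCallLoweredProps(call);
  if (props.lowered_func.empty()) {
    return -1;  // not a call_lowered
  }
  return static_cast<int>(props.arguments.size());
}
```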




-- 
This is an automated message from the Apache Git Service.