electriclilies commented on a change in pull request #9312: URL: https://github.com/apache/tvm/pull/9312#discussion_r744014917
########## File path: src/relay/op/call/call.cc ########## @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/call/call.cc + * \brief Operators for calling lowered functions. + */ + +#include "./call.h" + +#include <tvm/relay/attrs/call.h> +#include <tvm/relay/expr.h> +#include <tvm/relay/op.h> +#include <tvm/relay/op_attr_types.h> + +#include "../../transforms/infer_layout_utils.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(CallLoweredAttrs); + +// call_lowered +bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // Types = [func, args, ret_type] + ICHECK_EQ(types.size(), 3u); + auto func_type = types[0].as<FuncTypeNode>(); + ICHECK(func_type != nullptr) << "input must be operator with known type"; + auto input_type = types[1].as<TupleTypeNode>(); + ICHECK(input_type != nullptr) + << "internal invariant violated: call_lowered inputs must be a tuple"; + + // Constraint to ensure function arguments are the same type as the inputs to the function (modulo + // the Tuple wrapper) + reporter->Assign(types[1], TupleType(func_type->arg_types, {})); + // Constraint to ensure the output 
of call_lowered is the same as the function's return type + reporter->Assign(types[2], func_type->ret_type); + return true; +} + +const Op& CallLoweredOp() { return Op::Get("call_lowered"); } + +Expr CallLowered(Expr func, Expr inputs, Attrs attrs, Array<Type> type_args, Span span) { + // Right now, call_lowered only supports func being a global var pointing to the lowered + // function. + ICHECK(func.as<GlobalVarNode>()) + << "Function to call should be GlobalVarNode, but got " << func->GetTypeKey(); + ICHECK(inputs.as<TupleNode>()) << "Inputs to call_lowered should be TupleNode, but got " + << inputs->GetTypeKey(); + return Call(CallLoweredOp(), {func, inputs}, attrs); +} + +TVM_REGISTER_GLOBAL("relay.op.call_lowered").set_body_typed(CallLowered); + +RELAY_REGISTER_OP("call_lowered") + .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .add_argument("op", "Function", "The operation to call") + .add_argument("ins", "Tuple", "The input tensors.") + .add_type_rel("CallLoweredRel", CallLoweredRel) + .set_support_level(10) + .set_attr<TOpPattern>("TOpPattern", kOpaque) + .set_attr<TOpIsStateful>("TOpIsStateful", false) + .set_attr<TNonComputational>("TNonComputational", true) + .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout); + +std::pair<GlobalVar, Array<Expr>> ExtractFunctionAndArgs(const CallNode* call_node) { + ICHECK(call_node->op == CallLoweredOp()) + << "ExtractFunctionAndArgs expects the op to be call_lowered. "; + ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments. "; + const GlobalVarNode* function = call_node->args[0].as<GlobalVarNode>(); Review comment: I think I'll just leave it for now instead of introducing a new helper -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
