giuseros commented on a change in pull request #7785: URL: https://github.com/apache/tvm/pull/7785#discussion_r608186592
########## File path: src/relay/backend/aot_codegen.cc ########## @@ -0,0 +1,675 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/graph_codegen.cc + * \brief Graph runtime codegen + */ + +#include <dmlc/any.h> +#include <tvm/ir/module.h> +#include <tvm/relay/expr_functor.h> +#include <tvm/runtime/device_api.h> +#include <tvm/tir/builtin.h> +#include <tvm/tir/expr.h> +#include <tvm/tir/stmt.h> + +#include <algorithm> +#include <list> +#include <string> +#include <vector> + +#include "../../runtime/meta_data.h" +#include "compile_engine.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace backend { + +using IntegerArray = Array<Integer>; +using ShapeVector = std::vector<std::vector<int64_t>>; +using GraphAttrs = std::unordered_map<std::string, dmlc::any>; +using TargetsMap = std::unordered_map<int, Target>; + +/*! 
\brief Lowered outputs */ +struct AOTLoweredOutput { + std::string graph_tir; + Map<String, IRModule> lowered_funcs; + Array<tvm::runtime::Module> external_mods; + std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>> params; + runtime::AOTMetadata aot_metadata; +}; + +class AotReturnSidVisitor : public ExprVisitor { + public: + explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> storage_device_map) + : storage_device_map_{storage_device_map}, return_sid_{-1} {} + + IntegerArray FindReturnSid(Function func) { + VisitExpr(func->body); + return return_sid_; + } + + protected: + void AssignReturnSid(Expr e) { + auto iter = storage_device_map_.find(e); + if (iter != storage_device_map_.end()) { + return_sid_ = (*iter).second[0]; + } + } + + void VisitExpr_(const ConstantNode* cn) override { + ExprVisitor::VisitExpr_(cn); + AssignReturnSid(GetRef<Expr>(cn)); + } + + void VisitExpr_(const VarNode* vn) override { + ExprVisitor::VisitExpr_(vn); + AssignReturnSid(GetRef<Expr>(vn)); + } + + void VisitExpr_(const CallNode* cn) override { + ExprVisitor::VisitExpr_(cn); + AssignReturnSid(GetRef<Expr>(cn)); + } + + void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); } + + void VisitExpr_(const TupleNode* tn) override { + ExprVisitor::VisitExpr_(tn); + AssignReturnSid(GetRef<Expr>(tn)); + } + + private: + Map<Expr, Array<IntegerArray>> storage_device_map_; + IntegerArray return_sid_; +}; + +using TIRNetwork = tvm::Array<tir::Stmt>; + +/*! \brief Code generator for graph runtime */ +class AOTCodegen : public ExprVisitor { + protected: + /*! 
+ * \brief Utility function to allocate a DLTensor or TVMValue + * \param type the type of allocation + * \param num the number of variable to allocate on the stack + * \return PrimExpr representing the allocated object + */ + PrimExpr StackAlloca(std::string type, size_t num) { + Array<PrimExpr> args = {tir::StringImm(type), ConstInt32(num)}; + return tir::Call(DataType::Handle(), tir::builtin::tvm_stack_alloca(), args); + } + + /*! + * \brief Utility function to allocate memory for storage identifiers + * \param memory_size_byte size in bytes of the allocation + * \return PrimExpr representing the allocated memory + */ + PrimExpr AllocateBackendMemory(int memory_size_byte) { + // TODO(giuseros): use tir::Allocate instead of TVMBackendAllocWorkspace + // to enable unified memory planning + static const Op& op = Op::Get("tir.TVMBackendAllocWorkspace"); + return tvm::tir::Call(DataType::Handle(), op, {1, 0, memory_size_byte, 2, 8}); + } + + /*! + * \brief Utility function to convert a concrete integer to a PrimExpr. + * \param num the number to convert + * \return PrimExpr representing num + */ + inline PrimExpr ConstInt32(size_t num) { + ICHECK_LE(num, std::numeric_limits<int>::max()); + return tir::make_const(DataType::Int(32), static_cast<int>(num)); + } + + /*! 
+ * \brief Return a vector of variables that represents the sids for the given Relay Expr + */ + std::vector<tir::Var> pack_sid(Expr expr) { + Array<IntegerArray> sids = storage_device_map_[expr]; + std::vector<tir::Var> sid_vars; + + // Note that an expression can have multiple sids associated with it + // e.g., returning multiple values from a function + for (const auto& sid : sids[0]) { + // Determine if an sid is an output buffer + int sid_int = static_cast<int>((sid.as<IntImmNode>())->value); + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int); + if (output_iter != return_sid_.end()) { + int output_index = std::distance(return_sid_.begin(), output_iter); + sid_vars.push_back(main_signature_[input_vars_.size() + output_index]); + continue; + } + // Pack the sid inside the TVMValue + auto sid_array = te::Var(make_string("sid_", sid, "_value"), DataType::Handle()); + auto sid_value = sids_table_[sid]; + tvm::PrimExpr set_tensor = + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {sid_array, 0, tir::builtin::kArrData, sid_value}); + stmts_.push_back(tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor))); + sid_vars.push_back(sid_array); + } + return sid_vars; + } + + /*! + * \brief Utility function to return a parameter associated with an expression + * \param expr Relay Expression assicated with the parameter + * \return Variable that represents the DLTensor associated with the parameters + */ + tir::Var pack_param(Expr expr) { + // TODO(giuseros): Using call_extern to call into lookup_linked_param. This is because the + // builtin::ret is not supported yet in the c target. Once return is supported we can use + // tvm_call_packed_lowered(). 
+ int param_sid = param_storage_ids_[reverse_params_lookup_[expr]]; + auto lookup_linked_param_fn = tir::StringImm(::tvm::runtime::symbol::tvm_lookup_linked_param); + auto param_array = te::Var(make_string("param_", param_sid, "_array"), DataType::Handle()); + + // Compose the lookup_call using a local stack + Array<tir::Stmt> lookup_call; + auto param_var = te::Var(make_string("param_", param_sid, "_value"), DataType::Handle()); + auto ret_var = te::Var("ret_value", DataType::Handle()); + auto ret_code = te::Var("ret_value", DataType::Handle()); + + lookup_call.push_back(tir::Evaluate( + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {param_var, 0, tir::builtin::kTVMValueContent, ConstInt32(param_sid)}))); + lookup_call.push_back(tir::Evaluate( + tvm::tir::Call(DataType::Handle(), tir::builtin::call_extern(), + {lookup_linked_param_fn, param_var, 0, 0, ret_var, ret_code, 0}))); + auto ret_var_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), + {ret_var, 0, tir::builtin::kTVMValueContent}); + + // Set the param to the value returned by lookup_call + tvm::PrimExpr set_param_array = + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {param_array, 0, tir::builtin::kArrData, ret_var_handle}); + lookup_call.push_back(tir::Evaluate(set_param_array)); + + tir::Stmt lookup_body = tir::SeqStmt(lookup_call); + + // Allocate the DLTensors on the stack + lookup_body = tir::LetStmt(param_var, StackAlloca("arg_value", 1), lookup_body); + lookup_body = tir::LetStmt(ret_var, StackAlloca("arg_value", 1), lookup_body); + lookup_body = tir::LetStmt(ret_code, StackAlloca("arg_value", 1), lookup_body); + lookup_body = tir::LetStmt(param_array, StackAlloca("arg_value", 1), lookup_body); + stmts_.push_back(lookup_body); + return param_array; + } + + /*! 
+ * brief Given an expression return the variable(s) associated with that expression + */ + std::vector<te::Var> find_expr(Expr arg) { + auto input_iter = std::find(input_vars_.begin(), input_vars_.end(), arg); + if (input_iter != input_vars_.end()) { + // Input variable + int main_index = std::distance(input_vars_.begin(), input_iter); + return {main_signature_[main_index]}; + } else if (reverse_params_lookup_.find(arg) != reverse_params_lookup_.end()) { + // Parameter of the network + return {pack_param(arg)}; + } else { + // Storage identifier (i.e., intermediate memory) + return pack_sid(arg); + } + } + + /*! + * brief Call a function with a given name + */ + void func_call(Call call, std::string func_name) { + tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)}; + std::vector<tir::Stmt> func_call_stmts; + + // Pack the inputs + for (Expr arg : call->args) { + auto var_arg = find_expr(arg); + args.push_back(var_arg[0]); + } + + auto ret_expr = Downcast<Expr>(call); + + // Pack the return(s) value. A call node can produce multiple outputs + for (const auto& var : pack_sid(ret_expr)) { + args.push_back(var); + } + + // Use tvm_call_packed to execute the function + func_call_stmts.push_back(tir::Evaluate( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args))); + tir::Stmt body = tir::SeqStmt(func_call_stmts); + stmts_.push_back(body); + } + + /*! + * brief Copy a variable to the output. This function is mainly used in edge cases + * when we want to return an input or a parameter. 
+ */ + void copy_to_output(te::Var out, te::Var in, size_t size) { + auto retval_get = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), + {in, 0, tir::builtin::kArrData}); + + // Define intermediate DLTensor to load/store the data + auto tmp0 = te::Var("tmp0", DataType::Handle()); + auto tmp1 = te::Var("tmp1", DataType::Handle()); + te::Var loop_idx("i", DataType::Int(32)); + auto retval_i = tir::Load(DataType::UInt(8), tmp0, loop_idx, tir::const_true()); + auto tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), + {out, 0, tir::builtin::kArrData}); + + // Copy the variable from the input to the output + tir::Stmt copy = tir::For( + loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, + tir::Store(tmp1, tir::Let(tmp0, retval_get, retval_i), loop_idx, tir::const_true())); + stmts_.push_back(tir::LetStmt(tmp1, tostore, copy)); + } + + /*! + * Utility function to string together different arguments + */ + template <typename... Args> + std::string make_string(Args const&... 
args) { + std::ostringstream ss; + using List = int[]; + (void)List{0, ((void)(ss << args), 0)...}; + + return ss.str(); + } + + void VisitExpr_(const CallNode* op) override { + // Descend the call tree + for (auto arg : op->args) { + VisitExpr(arg); + } + + Expr expr = GetRef<Expr>(op); + Function func; + if (op->op.as<OpNode>()) { + LOG(FATAL) << "Operators should be transformed away; try applying" + << "the fuse_ops transformation to the expression."; + } else if (op->op.as<GlobalVarNode>()) { + LOG(FATAL) << "Not implemented"; + } else if (op->op.as<FunctionNode>()) { + func = GetRef<Function>(op->op.as<FunctionNode>()); + } else { + LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); + } + if (!func->HasNonzeroAttr(attr::kPrimitive)) { + LOG(FATAL) << "TVM only support calls to primitive functions " + << "(i.e functions composed of fusable operator invocations)"; + } + + auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); + auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); + Target target; + // Handle external function + if (func->GetAttr<String>(attr::kCompiler).defined()) { + target = Target("ext_dev"); + CCacheKey key = (*pf0)(func, target); + CachedFunc ext_func = (*pf1)(compile_engine_, key); + ICHECK(ext_func.defined()) << "External function is not defined."; + UpdateConstants(func, &params_); + + // Generate the TIR function call + func_call(GetRef<Call>(op), ext_func->func_name); + } + + ICHECK_GE(storage_device_map_.count(expr), 0); + auto& device_type = storage_device_map_[expr][1]; + auto call_dev_type = device_type[0]->value; + // Normal Relay Function + if (targets_.size() == 1) { + // homogeneous execution. + const auto& it = targets_.begin(); + target = (*it).second; + } else { + // heterogeneous execution. 
+ std::string call_dev_name; + if (call_dev_type == 0) { + call_dev_name = "llvm"; + } else { + call_dev_name = runtime::DeviceName(call_dev_type); + } + if (targets_.count(call_dev_type) == 0) { + LOG(FATAL) << "No target is provided for device " << call_dev_name; + } + target = targets_[call_dev_type]; + } + CCacheKey key = (*pf0)(func, target); + CachedFunc lowered_func = (*pf1)(compile_engine_, key); + if (!lowered_funcs_.count(target->str())) { + lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({})); + } + lowered_funcs_[target->str()]->Update(lowered_func->funcs); + + // Generate the TIR function call + func_call(GetRef<Call>(op), lowered_func->func_name); + } + + void VisitExpr_(const VarNode* op) override { + Expr expr = GetRef<Expr>(op); + + // If the Var node is an output node we need to copy the content of the variable to the output + // A Var node can only produce a single output + Array<IntegerArray> sids = storage_device_map_[expr]; + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), + static_cast<int>((sids[0][0].as<IntImmNode>())->value)); + if (output_iter != return_sid_.end()) { + int output_index = std::distance(return_sid_.begin(), output_iter); + auto var_expr = find_expr(expr); + copy_to_output(main_signature_[input_vars_.size() + output_index], var_expr[0], sids[2][0]); + } + } + + void VisitExpr_(const ConstantNode* op) override { + Expr expr = GetRef<Expr>(op); + size_t index = params_.size(); + std::string name = "p" + std::to_string(index); + + param_storage_ids_[name] = storage_device_map_[expr][0][0]->value; + params_[name] = op->data; + reverse_params_lookup_.Set(expr, name); + + // If the Constant node is an output node we need to copy the content of the parameter to the + // output A Var node can only produce a single output + Array<IntegerArray> sids = storage_device_map_[expr]; + auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), + 
static_cast<int>((sids[0][0].as<IntImmNode>())->value)); + if (output_iter != return_sid_.end()) { + int output_index = std::distance(return_sid_.begin(), output_iter); + copy_to_output(main_signature_[input_vars_.size() + output_index], pack_param(expr), + sids[2][0]); + } + } + + void VisitExpr_(const TupleNode* op) override { + for (auto field : op->fields) { + VisitExpr(field); + } + } + + void VisitExpr_(const LetNode* op) override { + // TODO(giuseros): support Let nodes in AOT + throw std::invalid_argument("Let not yet implemented in AOT"); + } + void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } + void VisitExpr_(const OpNode* op) override { + throw std::runtime_error("can not compile op in non-eta expanded form"); + } + void VisitExpr_(const GlobalVarNode* op) override { throw std::runtime_error(""); } + void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); } + void VisitExpr_(const FunctionNode* op) override { + ICHECK(op->GetAttr<String>(attr::kCompiler).defined()) + << "Only functions supported by custom codegen"; + } + void VisitExpr_(const RefCreateNode* op) override { + throw std::invalid_argument("reference not supported"); + } + void VisitExpr_(const RefReadNode* op) override { + throw std::invalid_argument("reference not supported"); + } + void VisitExpr_(const RefWriteNode* op) override { + throw std::invalid_argument("reference not supported"); + } + void VisitExpr_(const ConstructorNode* op) override { + throw std::invalid_argument("ADT constructor case not yet implemented"); + } + void VisitExpr_(const MatchNode* op) override { + throw std::invalid_argument("match case not yet implemented"); + } + + // Create the main PrimFunc to execute the graph + tir::PrimFunc CreateMainFunc(unsigned int relay_params) { + tir::Stmt body = tir::SeqStmt(stmts_); + + // Allocate the sids + std::unordered_map<int, bool> allocated; + + for (auto kv : storage_device_map_) { + // Only allocate 
sids that are needed + const bool is_input = + (std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end()); + const bool is_param = (reverse_params_lookup_.find(kv.first) != reverse_params_lookup_.end()); + if (is_input || is_param) { + continue; + } + + for (unsigned int i = 0; i < kv.second[0].size(); i++) { + int size = kv.second[2][i]; + int sid = static_cast<int>((kv.second[0][i].as<IntImmNode>())->value); + + if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) { + continue; + } + + if (!allocated[sid]) { + body = tir::LetStmt(sids_table_[sid], AllocateBackendMemory(size), body); Review comment: That would be nice. Can we do that? As in, can we add a tir::let statement to an entire PrimFunc? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
