kparzysz-quic commented on a change in pull request #10283: URL: https://github.com/apache/tvm/pull/10283#discussion_r816210034
########## File path: src/runtime/aot_executor/aot_executor.cc ########## @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Defines an implementation of Module-based Model Runtime Interface that works with + * Ahead-of-Time compilation. + * \file aot_executor.cc + */ + +#include "aot_executor.h" + +#include <tvm/runtime/c_runtime_api.h> + +#include <memory> + +#include "../meta_data.h" + +namespace tvm { +namespace runtime { + +AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector<Device>& devs) + : module_{module}, devices_{devs} { + auto fmetadata = module->GetFunction("get_metadata"); + CHECK(fmetadata != nullptr) << "Expected a module with PackedFunc get_metadata"; + auto ret_value = fmetadata(); + metadata_ = ret_value.AsObjectRef<tvm::runtime::metadata::Metadata>(); + + for (auto input : metadata_->inputs()) { + // TODO(areusch): Encode device information in Metadata. 
+ args_.emplace_back(NDArray::Empty(ShapeTuple(input->shape().begin(), input->shape().end()), + input->dtype(), devices_[0])); + } + + for (auto output : metadata_->outputs()) { + args_.emplace_back(NDArray::Empty(ShapeTuple(output->shape().begin(), output->shape().end()), + output->dtype(), devices_[0])); + } + + for (auto pool : metadata_->pools()) { + args_.emplace_back(NDArray::Empty(ShapeTuple(pool->shape().begin(), pool->shape().end()), + pool->dtype(), devices_[0])); + } +} + +PackedFunc AotExecutor::GetFunction(const std::string& name, + const ObjectPtr<Object>& sptr_to_self) { + // Return member functions during query. + if (name == "set_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); + if (in_idx >= 0) this->SetInput(in_idx, args[1]); + } else { + this->SetInput(args[0], args[1]); + } + }); + } else if (name == "set_input_zero_copy") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); + if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]); + } else { + this->SetInputZeroCopy(args[0], args[1]); + } + }); + } else if (name == "set_output_zero_copy") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int out_idx = this->GetOutputIndex(args[0].operator String()); + if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]); + } else { + this->SetOutputZeroCopy(args[0], args[1]); + } + }); + } else if (name == "get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (args.num_args == 2) { + this->CopyOutputTo(args[0], args[1]); + } else { + *rv = this->GetOutput(args[0]); + } + }); + } else if (name == "get_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) 
{ + int in_idx = 0; + if (String::CanConvertFrom(args[0])) { + in_idx = this->GetInputIndex(args[0].operator String()); + } else { + in_idx = args[0]; + } + if (in_idx >= 0) { + *rv = this->GetInput(in_idx); + } + }); + } else if (name == "get_num_outputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); + } else if (name == "get_num_inputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); }); + } else if (name == "run") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); + } else if (name == "get_input_index") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string"; + *rv = this->GetInputIndex(args[0].operator String()); + }); + } else { + return PackedFunc(); + } +} + +void AotExecutor::Run() { + auto pf = module_.GetFunction( + get_name_mangled(metadata_->mod_name(), ::tvm::runtime::symbol::tvm_module_main), + true /* query_imports */); + ICHECK(pf != nullptr) << "Module entrypoint is not defined"; + + const int num_args = args_.size(); + ::std::unique_ptr<TVMValue> call_values{new TVMValue[num_args]}; + ::std::unique_ptr<int> call_type_codes{new int[num_args]}; Review comment: These two lines should be ``` auto call_values = std::make_unique<TVMValue[]>(num_args); auto call_type_codes = std::make_unique<int[]>(num_args); ``` The array form is required, not just nicer: the non-array `std::unique_ptr<T>` releases its pointer with `delete`, which is undefined behavior for memory allocated with `new[]`. ########## File path: src/relay/backend/aot_executor_codegen.cc ########## @@ -177,6 +179,12 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { for (auto sid : sinfo->storage_ids) { return_ids_.push_back(sid); } + return_ttypes_.clear(); + auto ttypes = FlattenTupleType(e->checked_type()); + return_ttypes_.reserve(ttypes.size()); + for (auto ttype : ttypes) { + return_ttypes_.push_back(ttype); + } Review comment: All of this is equivalent to `return_ttypes_ =
FlattenTupleType(e->checked_type());`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
