tkonolige commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r602419800
##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -181,23 +186,56 @@ class GraphOpNode : public GraphNode {
const std::string op_type_name_{"tvm_op"};
};
-/*! \brief Code generator for graph runtime */
+/*! \brief Code generator for the graph runtime, produces a module containing
the graph JSON,
+ * module, and parameters.
+ */
class GraphRuntimeCodegen : public
backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
public:
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) :
mod_(mod) {
- compile_engine_ = CompileEngine::Global();
targets_ = targets;
}
LoweredOutput Codegen(relay::Function func) {
- auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
- storage_device_map_ = (*pf)(func);
+ // Jared: why do we do this? just call C++ API.
+ // auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
+ // storage_device_map_ = (*pf)(func);
+ storage_device_map_ = GraphPlanMemory(func);
+
+ // This first phase moves from implicit use of compile engine,
+ // to instead the lower the incoming IRModule, and then performing
Review comment:
```suggestion
// to instead lowering the incoming IRModule, and then performing
```
##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -181,23 +186,56 @@ class GraphOpNode : public GraphNode {
const std::string op_type_name_{"tvm_op"};
};
-/*! \brief Code generator for graph runtime */
+/*! \brief Code generator for the graph runtime, produces a module containing
the graph JSON,
+ * module, and parameters.
+ */
class GraphRuntimeCodegen : public
backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
public:
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) :
mod_(mod) {
- compile_engine_ = CompileEngine::Global();
targets_ = targets;
}
LoweredOutput Codegen(relay::Function func) {
- auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
- storage_device_map_ = (*pf)(func);
+ // Jared: why do we do this? just call C++ API.
+ // auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
+ // storage_device_map_ = (*pf)(func);
+ storage_device_map_ = GraphPlanMemory(func);
+
+ // This first phase moves from implicit use of compile engine,
+ // to instead the lower the incoming IRModule, and then performing
+ // the pre-exiting graph runtime code generation phase.
Review comment:
```suggestion
// the pre-existing graph runtime code generation phase.
```
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+ // Lower the function.
+ CachedFunc Lower(const CCacheKey& key) { return
LowerInternal(key)->cached_func; }
+
+ // For now, build one module per function.
+ PackedFunc JIT(const CCacheKey& key) final {
+ CCacheValue value = LowerInternal(key);
+ if (value->packed_func != nullptr) {
+ return value->packed_func;
+ }
+ auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+ value->packed_func =
m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+ return value->packed_func;
+ }
+
+ CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+ return LowerShapeFuncInternal(key)->cached_func;
+ }
+
+ Map<String, IRModule> GetLoweredFunctions() {
+ Map<String, IRModule> lowered_functions;
+ for (const auto& it : cache_) {
+ auto source_func = it.first;
+ auto lowered_func = it.second;
+ auto target = source_func->target;
+
+ if (!lowered_functions.count(target->str())) {
+ lowered_functions.Set(target->str(), IRModule(Map<GlobalVar,
BaseFunc>({})));
+ }
+
+
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+ }
+ return lowered_functions;
+ }
+
+ Array<tvm::runtime::Module> LowerExternalFunctions() {
+ Array<tvm::runtime::Module> ret;
+ std::unordered_map<std::string, std::string> cached_symbol;
+ std::vector<CCacheKey> cached_ext_funcs;
+ for (const auto& it : cache_) {
+ auto src_func = it.first->source_func;
+ ICHECK(src_func.defined());
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+ ICHECK(code_gen.defined()) << "No external codegen is set";
+ std::string code_gen_name = code_gen.value();
+ cached_ext_funcs.push_back(it.first);
+
+ auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+ << AsText(src_func, false);
+
+ std::string sn = symbol_name.value();
+ if (cached_symbol.count(sn)) {
+ cached_symbol[sn] = code_gen_name;
+ } else {
+ ICHECK_NE(sn, code_gen_name)
+ << "Found duplicated symbol: " << sn << " for: " <<
code_gen_name;
+ }
+
+ std::string ext_name = "relay.ext." + code_gen_name;
+ auto pf = tvm::runtime::Registry::Get(ext_name);
+ ICHECK(pf) << "Failed to find the codegen tool for " << ext_name <<
"\n";
Review comment:
```suggestion
ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
```
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+ // Lower the function.
+ CachedFunc Lower(const CCacheKey& key) { return
LowerInternal(key)->cached_func; }
+
+ // For now, build one module per function.
+ PackedFunc JIT(const CCacheKey& key) final {
+ CCacheValue value = LowerInternal(key);
+ if (value->packed_func != nullptr) {
+ return value->packed_func;
+ }
+ auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+ value->packed_func =
m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+ return value->packed_func;
+ }
+
+ CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+ return LowerShapeFuncInternal(key)->cached_func;
+ }
+
+ Map<String, IRModule> GetLoweredFunctions() {
+ Map<String, IRModule> lowered_functions;
+ for (const auto& it : cache_) {
+ auto source_func = it.first;
+ auto lowered_func = it.second;
+ auto target = source_func->target;
+
+ if (!lowered_functions.count(target->str())) {
+ lowered_functions.Set(target->str(), IRModule(Map<GlobalVar,
BaseFunc>({})));
+ }
+
+
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+ }
+ return lowered_functions;
+ }
+
+ Array<tvm::runtime::Module> LowerExternalFunctions() {
+ Array<tvm::runtime::Module> ret;
+ std::unordered_map<std::string, std::string> cached_symbol;
+ std::vector<CCacheKey> cached_ext_funcs;
+ for (const auto& it : cache_) {
+ auto src_func = it.first->source_func;
+ ICHECK(src_func.defined());
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+ ICHECK(code_gen.defined()) << "No external codegen is set";
+ std::string code_gen_name = code_gen.value();
+ cached_ext_funcs.push_back(it.first);
+
+ auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+ << AsText(src_func, false);
Review comment:
```suggestion
ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
                              << AsText(src_func, false) << "\n"
                              << "Functions with external codegen must have the "
                              << tvm::attr::kGlobalSymbol << " attr set.";
```
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+ // Lower the function.
+ CachedFunc Lower(const CCacheKey& key) { return
LowerInternal(key)->cached_func; }
+
+ // For now, build one module per function.
+ PackedFunc JIT(const CCacheKey& key) final {
+ CCacheValue value = LowerInternal(key);
+ if (value->packed_func != nullptr) {
+ return value->packed_func;
+ }
+ auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+ value->packed_func =
m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+ return value->packed_func;
+ }
+
+ CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+ return LowerShapeFuncInternal(key)->cached_func;
+ }
+
+ Map<String, IRModule> GetLoweredFunctions() {
+ Map<String, IRModule> lowered_functions;
+ for (const auto& it : cache_) {
+ auto source_func = it.first;
+ auto lowered_func = it.second;
+ auto target = source_func->target;
+
+ if (!lowered_functions.count(target->str())) {
+ lowered_functions.Set(target->str(), IRModule(Map<GlobalVar,
BaseFunc>({})));
+ }
+
+
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+ }
+ return lowered_functions;
+ }
+
+ Array<tvm::runtime::Module> LowerExternalFunctions() {
+ Array<tvm::runtime::Module> ret;
+ std::unordered_map<std::string, std::string> cached_symbol;
+ std::vector<CCacheKey> cached_ext_funcs;
+ for (const auto& it : cache_) {
+ auto src_func = it.first->source_func;
+ ICHECK(src_func.defined());
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+ ICHECK(code_gen.defined()) << "No external codegen is set";
+ std::string code_gen_name = code_gen.value();
+ cached_ext_funcs.push_back(it.first);
+
+ auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+ << AsText(src_func, false);
+
+ std::string sn = symbol_name.value();
+ if (cached_symbol.count(sn)) {
+ cached_symbol[sn] = code_gen_name;
+ } else {
+ ICHECK_NE(sn, code_gen_name)
+ << "Found duplicated symbol: " << sn << " for: " <<
code_gen_name;
+ }
+
+ std::string ext_name = "relay.ext." + code_gen_name;
+ auto pf = tvm::runtime::Registry::Get(ext_name);
+ ICHECK(pf) << "Failed to find the codegen tool for " << ext_name <<
"\n";
+ // No need to keep compiler attribute at this point, functions have
been
+ // extracted for specific codegen.
+ src_func = WithAttr(std::move(src_func), attr::kCompiler,
NullValue<ObjectRef>());
+ runtime::Module ext_mod = (*pf)(src_func);
+
+ ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+ ret.push_back(ext_mod);
+ }
+ }
+
+ // No need to cache external functions as we collected them all to create
+ // external runtime modules.
+ for (const auto& it : cached_ext_funcs) {
+ cache_.erase(it);
+ }
+ return ret;
+ }
+
+ void Clear() final { cache_.clear(); }
+
+ // List all items in the cache.
+ Array<ObjectRef> ListItems() {
+ std::lock_guard<std::mutex> lock(mutex_);
+ Array<ObjectRef> items;
+ for (auto& kv : cache_) {
+ items.push_back(kv.first);
+ items.push_back(kv.second);
+ }
+ return items;
+ }
+
+ /*!
+ * \brief Get the cache key of the function that is being lowered currently
+ * \return the cache key
+ */
+ CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+ // implement lowered func
+ CCacheValue LowerInternal(const CCacheKey& key) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ CCacheValue value;
+ auto it = cache_.find(key);
+ if (it != cache_.end()) {
+ it->second->use_count += 1;
+ if (it->second->cached_func.defined()) return it->second;
+ value = it->second;
+ } else {
+ value = CCacheValue(make_object<CCacheValueNode>());
+ value->use_count = 0;
+ if (!backend::IsCompileEngineCacheDisabled()) {
+ cache_[key] = value;
+ }
+ }
+ cur_ccache_key_ = key;
+
+ // No need to lower external functions for now. We will invoke the external
+ // codegen tool once and lower all functions together.
+ if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto ir_module = IRModule();
+ const auto name_node =
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(name_node.defined()) << "External function has not been attached
a name yet.";
+ auto func_name = std::string(name_node.value());
+ auto target = Target("ext_dev");
+ auto global_var = GlobalVar(func_name);
+ global_var->checked_type_ = key->source_func->checked_type();
+ ir_module->Add(global_var, key->source_func);
+ value->cached_func = CachedFunc(target, global_var, {}, {},
te::Schedule(), {}, ir_module);
+ return value;
+ }
+ // Enforce use the target.
+ With<Target> target_scope(key->target);
+
+ ICHECK(!value->cached_func.defined());
+ auto cfunc = PrimFuncFor(key->source_func, key->target,
+ [&](std::string name) { return
GetUniqueName(name, name_map_); });
+
+ // Skip lowering for device copy node.
+ const Expr body = (key->source_func)->body;
+ if (const CallNode* call_node = body.as<CallNode>()) {
+ if (call_node->attrs.as<DeviceCopyAttrs>()) {
+ value->cached_func = cfunc;
+ return value;
+ }
+ }
+
+ std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+ std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;
+ // NOTE: array will copy on write.
+ Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+ for (te::Tensor arg : cfunc->outputs) {
+ all_args.push_back(arg);
+ }
+
+ std::cout << "Allargs Size: " << all_args.size() << std::endl;
+
+ using tvm::transform::PassContext;
+ With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
+ std::unordered_map<te::Tensor, tir::Buffer> binds;
+ auto func_name = cfunc->prim_fn_var->name_hint;
+ cfunc->funcs->Update(tvm::lower(cfunc->schedule, all_args, func_name,
binds));
+ value->cached_func = cfunc;
+ return value;
+ }
+
+ // implement lowered shape func
+ CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ CCacheValue value;
+ auto it = shape_func_cache_.find(key);
+ if (it != shape_func_cache_.end()) {
+ it->second->use_count += 1;
+ if (it->second->cached_func.defined()) return it->second;
+ value = it->second;
+ } else {
+ value = CCacheValue(make_object<CCacheValueNode>());
+ value->use_count = 0;
+ shape_func_cache_[key] = value;
+ }
+ // Enforce use the target.
+ With<Target> target_scope(key->target);
+
+ ICHECK(!value->cached_func.defined());
+ auto cached_func = ShapeFuncFor(key->source_func, key->target,
[&](std::string name) {
+ return GetUniqueName(name, name_map_);
+ });
+
+ value->cached_func = cached_func;
+ return value;
+ }
+
+ /*! \brief compiler cache lock*/
+ std::mutex mutex_;
+ /*! \brief internal name map to get an unique name */
+ std::unordered_map<std::string, int> name_map_;
+ /*! \brief internal compiler cache */
+ std::unordered_map<CCacheKey, CCacheValue> cache_;
+ /*! \brief internal compiler cache for shape funcs */
+ std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
+ /*! \brief the cache key of the function that is being lowered currently*/
+ CCacheKey cur_ccache_key_;
+};
+
+TECompiler::TECompiler() {
+ auto object = make_object<TECompilerImpl>();
+ data_ = object;
+}
+
+class LowerTensorExpr : public ExprMutator {
+ public:
+ LowerTensorExpr(const IRModule& module, const TargetsMap& targets,
+ const DeviceContextMap& device_ctx_map, TECompiler compiler)
+ : module_(module),
+ targets_(targets),
+ device_context_map_(device_ctx_map),
+ compiler_(compiler) {}
+
+ Expr VisitExpr_(const CallNode* call) override {
+ Call expr = GetRef<Call>(call);
+ Function func;
+
+ if (call->op.as<FunctionNode>()) {
+ func = GetRef<Function>(call->op.as<FunctionNode>());
+ } else {
+ return ExprMutator::VisitExpr_(call);
+ }
+
+ if (!func->HasNonzeroAttr(attr::kPrimitive)) {
+ // LOG(FATAL) << "TVM only support calls to primitive functions "
+ // << "(i.e functions composed of fusable operator
invocations)";
+ return ExprMutator::VisitExpr_(call);
+ }
+
+ // Process inputs.
+ Array<Expr> args;
+ for (size_t i = 0; i < expr->args.size(); i++) {
+ args.push_back(VisitExpr(expr->args[i]));
+ }
+
+ Target target;
+
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ target = Target("ext_dev");
+ CCacheKey key = CCacheKey(func, target);
+ CachedFunc ext_func = compiler_->Lower(key);
+ ICHECK(ext_func.defined()) << "External function is not defined.";
Review comment:
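Can you add the function name to this error message (e.g. the `kGlobalSymbol` of `func`) so it is clear which external function failed to lower?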
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation
impl) {
+ auto n = make_object<LoweredOutputNode>();
+ n->outputs = std::move(outputs);
+ n->implementation = std::move(impl);
+ data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+ auto n = make_object<CCacheKeyNode>();
+ n->source_func = std::move(source_func);
+ n->target = std::move(target);
+ data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var,
tvm::Array<te::Tensor> inputs,
+ tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+ tvm::Array<Integer> shape_func_param_states, IRModule
funcs) {
+ auto n = make_object<CachedFuncNode>();
+ n->target = target;
+ n->prim_fn_var = prim_fn_var;
+ n->inputs = inputs;
+ n->outputs = outputs;
+ n->schedule = schedule;
+ n->shape_func_param_states = shape_func_param_states;
+ n->funcs = funcs;
+ data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+ // for now, we always use int32 shape when possible
+ // even if the result of shape inference becomes int64.
+ Array<IndexExpr> res;
+ for (IndexExpr val : shape) {
+ const int64_t* pval = tir::as_const_int(val);
+ if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+ ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+ ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+ res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+ res.push_back(val);
+#endif // TVM_INDEX_DEFAULT_I64
+ } else if (val->IsInstance<tir::AnyNode>()) {
+ res.push_back(val.as<tir::AnyNode>()->ToVar());
+ } else {
+ res.push_back(val);
+ }
+ }
+ return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public
backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+ explicit ScheduleGetter(Target target)
+ : target_(target), device_copy_op_(Op::Get("device_copy")) {
+ // Whether to use auto_scheduler schedule.
+ use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+ }
+
+ CachedFunc Create(const Function& prim_func,
std::function<std::string(std::string)> renamer) {
+ Array<tvm::te::Tensor> fn_inputs;
+ for (Var param : prim_func->params) {
+ Array<tvm::te::Tensor> inputs;
+ if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+ tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape),
ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ } else {
+ // flatten tuple of tensor type.
+ const auto* tuple_type = param->type_as<TupleTypeNode>();
+ for (Type field : tuple_type->fields) {
+ const auto* ttype = field.as<TensorTypeNode>();
+ // TODO(@icemelon): Allow recursive tuple
+ ICHECK(ttype != nullptr);
+ tvm::te::Tensor tensor =
tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ }
+ }
+ memo_[param] = inputs;
+ }
+ readable_name_stream_ << "fused";
+ auto outputs = this->VisitExpr(prim_func->body);
+ auto candidate_name = readable_name_stream_.str();
+ constexpr static size_t kMaxFuncNameLength = 80;
+ if (candidate_name.size() > kMaxFuncNameLength) {
+ std::stringstream truncated_name;
+ truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+ truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+ candidate_name = truncated_name.str();
+ }
+
+ ICHECK(anchor_op_.defined());
+ // Fusion over tupled results may leave identity relationships
+ // between inputs and outputs, and those should not be scheduled.
+ // Hence schedule only non PlaceholderOp outputs.
+ tvm::Array<te::Tensor> tensor_outs;
+ for (const auto& tensor : outputs) {
+ if (!tensor->op.as<te::PlaceholderOpNode>()) {
+ tensor_outs.push_back(tensor);
+ }
+ }
+
+ te::Schedule schedule;
+ // No need to register schedule for device copy op.
+ if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+ if (use_auto_scheduler_) {
+ const auto* fauto_schedule =
+
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+ ICHECK(fauto_schedule != nullptr)
+ << "auto_scheduler.relay_integration.auto_schedule_topi_compute is
not registered";
+ ObjectRef obj = (*fauto_schedule)(tensor_outs);
+ if (obj.defined()) {
+ schedule = Downcast<te::Schedule>(obj);
+ }
+ }
+
+ // Use TOPI schdule if user specificed, or the function has no
auto_scheduler schedule.
+ if (!schedule.defined()) {
+ ICHECK(anchor_implementation_.defined());
+ schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs,
target_);
+ }
+ for (const auto& scalar : scalars_) {
+ if (schedule->Contain(scalar)) {
+ schedule[scalar].compute_inline();
+ }
+ }
+ }
+
+ auto prim_fn_var = GlobalVar(candidate_name);
+ prim_fn_var->checked_type_ = prim_func->checked_type();
+ return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+ }
+
+ Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+ LOG(FATAL) << "Free variable " << op->name_hint();
Review comment:
Could you make this a little more detailed? Maybe even `"Unexpected free variable " << op->name_hint() << ", expected " ...`
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+ // Lower the function.
+ CachedFunc Lower(const CCacheKey& key) { return
LowerInternal(key)->cached_func; }
+
+ // For now, build one module per function.
+ PackedFunc JIT(const CCacheKey& key) final {
+ CCacheValue value = LowerInternal(key);
+ if (value->packed_func != nullptr) {
+ return value->packed_func;
+ }
+ auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+ value->packed_func =
m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+ return value->packed_func;
+ }
+
+ CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+ return LowerShapeFuncInternal(key)->cached_func;
+ }
+
+ Map<String, IRModule> GetLoweredFunctions() {
+ Map<String, IRModule> lowered_functions;
+ for (const auto& it : cache_) {
+ auto source_func = it.first;
+ auto lowered_func = it.second;
+ auto target = source_func->target;
+
+ if (!lowered_functions.count(target->str())) {
+ lowered_functions.Set(target->str(), IRModule(Map<GlobalVar,
BaseFunc>({})));
+ }
+
+
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+ }
+ return lowered_functions;
+ }
+
+ Array<tvm::runtime::Module> LowerExternalFunctions() {
+ Array<tvm::runtime::Module> ret;
+ std::unordered_map<std::string, std::string> cached_symbol;
+ std::vector<CCacheKey> cached_ext_funcs;
+ for (const auto& it : cache_) {
+ auto src_func = it.first->source_func;
+ ICHECK(src_func.defined());
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+ ICHECK(code_gen.defined()) << "No external codegen is set";
Review comment:
Is this check necessary? Didn't the enclosing `if` on line 100 already check that `code_gen`
(`src_func->GetAttr<String>(attr::kCompiler)`) is defined?
##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -349,65 +380,47 @@ class GraphRuntimeCodegen : public
backend::MemoizedExprTranslator<std::vector<G
return AddNode(node, GetRef<Expr>(op));
}
- std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
- Expr expr = GetRef<Expr>(op);
- Function func;
- if (op->op.as<OpNode>()) {
- LOG(FATAL) << "Operators should be transformed away; try applying"
- << "the fuse_ops transformation to the expression.";
- } else if (op->op.as<GlobalVarNode>()) {
- LOG(FATAL) << "Not implemented";
- } else if (op->op.as<FunctionNode>()) {
- func = GetRef<Function>(op->op.as<FunctionNode>());
- } else {
- LOG(FATAL) << "TVM runtime does not support calls to " <<
op->op->GetTypeKey();
- }
- if (!func->HasNonzeroAttr(attr::kPrimitive)) {
- LOG(FATAL) << "TVM only support calls to primitive functions "
- << "(i.e functions composed of fusable operator invocations)";
- }
+ std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
+ relay::Call call = GetRef<Call>(call_node);
- auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
- auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
- Target target;
- // Handle external function
- if (func->GetAttr<String>(attr::kCompiler).defined()) {
- target = Target("ext_dev");
- CCacheKey key = (*pf0)(func, target);
- CachedFunc ext_func = (*pf1)(compile_engine_, key);
- ICHECK(ext_func.defined()) << "External function is not defined.";
- UpdateConstants(func, ¶ms_);
- return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
- }
+ if (auto global_node = call->op.as<GlobalVarNode>()) {
+ auto prim_fn_name = global_node->name_hint;
- ICHECK_GE(storage_device_map_.count(expr), 0);
- auto& device_type = storage_device_map_[expr][1];
- auto call_dev_type = device_type[0]->value;
- // Normal Relay Function
- if (targets_.size() == 1) {
- // homogeneous execution.
- const auto& it = targets_.begin();
- target = (*it).second;
- } else {
- // heterogeneous execution.
- std::string call_dev_name;
- if (call_dev_type == 0) {
- call_dev_name = "llvm";
+ Target target;
+
+ // // Handle external function
+ // if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ // UpdateConstants(func, ¶ms_);
+ // return GraphAddCallNode(call_node, prim_fn_name, prim_fn_name);
+ // }
+
+ ICHECK_GE(storage_device_map_.count(call), 0);
+ auto& device_type = storage_device_map_[call][1];
+ auto call_dev_type = device_type[0]->value;
+ // Normal Relay Function
+ if (targets_.size() == 1) {
+ // homogeneous execution.
+ const auto& it = targets_.begin();
+ target = (*it).second;
} else {
- call_dev_name = runtime::DeviceName(call_dev_type);
- }
- if (targets_.count(call_dev_type) == 0) {
- LOG(FATAL) << "No target is provided for device " << call_dev_name;
+ // heterogeneous execution.
+ std::string call_dev_name;
+ if (call_dev_type == 0) {
+ call_dev_name = "llvm";
+ } else {
+ call_dev_name = runtime::DeviceName(call_dev_type);
+ }
+ if (targets_.count(call_dev_type) == 0) {
+ LOG(FATAL) << "No target is provided for device " << call_dev_name;
+ }
+ target = targets_[call_dev_type];
}
- target = targets_[call_dev_type];
- }
- CCacheKey key = (*pf0)(func, target);
- CachedFunc lowered_func = (*pf1)(compile_engine_, key);
- if (!lowered_funcs_.count(target->str())) {
- lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
+
+ return GraphAddCallNode(call_node, _GetUniqueName(prim_fn_name),
prim_fn_name);
+ } else {
+ LOG(FATAL) << "BadCase: " << PrettyPrint(call) << std::endl;
Review comment:
```suggestion
LOG(FATAL) << "Graph runtime codegen can only handle calls to global functions, but it got a "
           << call->op->GetTypeKey() << " (should be GlobalVarNode). This is what was provided: "
           << PrettyPrint(call);
```
##########
File path: src/driver/driver_api.cc
##########
@@ -244,14 +244,17 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule
mod_mixed, const Target
}
if (target->kind->device_type == kDLCPU && target_host == target) {
- ICHECK(mdevice->functions.empty()) << "No device code should be generated
when target "
- << "and host_target are both llvm
target."
- << "\n";
+ // We need to relax this check for just TIR functions.
Review comment:
Maybe add a TODO here so this relaxed check gets tightened again once it is no longer needed?
##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -349,65 +380,47 @@ class GraphRuntimeCodegen : public
backend::MemoizedExprTranslator<std::vector<G
return AddNode(node, GetRef<Expr>(op));
}
- std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
- Expr expr = GetRef<Expr>(op);
- Function func;
- if (op->op.as<OpNode>()) {
- LOG(FATAL) << "Operators should be transformed away; try applying"
- << "the fuse_ops transformation to the expression.";
- } else if (op->op.as<GlobalVarNode>()) {
- LOG(FATAL) << "Not implemented";
- } else if (op->op.as<FunctionNode>()) {
- func = GetRef<Function>(op->op.as<FunctionNode>());
- } else {
- LOG(FATAL) << "TVM runtime does not support calls to " <<
op->op->GetTypeKey();
- }
- if (!func->HasNonzeroAttr(attr::kPrimitive)) {
- LOG(FATAL) << "TVM only support calls to primitive functions "
- << "(i.e functions composed of fusable operator invocations)";
- }
+ std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
+ relay::Call call = GetRef<Call>(call_node);
- auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
- auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
- Target target;
- // Handle external function
- if (func->GetAttr<String>(attr::kCompiler).defined()) {
- target = Target("ext_dev");
- CCacheKey key = (*pf0)(func, target);
- CachedFunc ext_func = (*pf1)(compile_engine_, key);
- ICHECK(ext_func.defined()) << "External function is not defined.";
- UpdateConstants(func, ¶ms_);
- return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
- }
+ if (auto global_node = call->op.as<GlobalVarNode>()) {
+ auto prim_fn_name = global_node->name_hint;
- ICHECK_GE(storage_device_map_.count(expr), 0);
- auto& device_type = storage_device_map_[expr][1];
- auto call_dev_type = device_type[0]->value;
- // Normal Relay Function
- if (targets_.size() == 1) {
- // homogeneous execution.
- const auto& it = targets_.begin();
- target = (*it).second;
- } else {
- // heterogeneous execution.
- std::string call_dev_name;
- if (call_dev_type == 0) {
- call_dev_name = "llvm";
+ Target target;
+
+ // // Handle external function
+ // if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ // UpdateConstants(func, ¶ms_);
+ // return GraphAddCallNode(call_node, prim_fn_name, prim_fn_name);
+ // }
+
+ ICHECK_GE(storage_device_map_.count(call), 0);
Review comment:
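```suggestion
ICHECK_GT(storage_device_map_.count(call), 0)
    << "Could not find a storage device for " << prim_fn_name
    << ". This could be caused by an error in GraphPlanMemory.";
```
(Note: `count()` returns an unsigned value, so the existing `ICHECK_GE(..., 0)` can never fail; `ICHECK_GT` makes the check meaningful.)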
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+ // Lower the function.
+ CachedFunc Lower(const CCacheKey& key) { return
LowerInternal(key)->cached_func; }
+
+ // For now, build one module per function.
+ PackedFunc JIT(const CCacheKey& key) final {
+ CCacheValue value = LowerInternal(key);
+ if (value->packed_func != nullptr) {
+ return value->packed_func;
+ }
+ auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+ value->packed_func =
m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+ return value->packed_func;
+ }
+
+ CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+ return LowerShapeFuncInternal(key)->cached_func;
+ }
+
+ Map<String, IRModule> GetLoweredFunctions() {
+ Map<String, IRModule> lowered_functions;
+ for (const auto& it : cache_) {
+ auto source_func = it.first;
+ auto lowered_func = it.second;
+ auto target = source_func->target;
+
+ if (!lowered_functions.count(target->str())) {
+ lowered_functions.Set(target->str(), IRModule(Map<GlobalVar,
BaseFunc>({})));
+ }
+
+
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+ }
+ return lowered_functions;
+ }
+
+ Array<tvm::runtime::Module> LowerExternalFunctions() {
+ Array<tvm::runtime::Module> ret;
+ std::unordered_map<std::string, std::string> cached_symbol;
+ std::vector<CCacheKey> cached_ext_funcs;
+ for (const auto& it : cache_) {
+ auto src_func = it.first->source_func;
+ ICHECK(src_func.defined());
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+ ICHECK(code_gen.defined()) << "No external codegen is set";
+ std::string code_gen_name = code_gen.value();
+ cached_ext_funcs.push_back(it.first);
+
+ auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+ << AsText(src_func, false);
+
+ std::string sn = symbol_name.value();
+ if (cached_symbol.count(sn)) {
+ cached_symbol[sn] = code_gen_name;
+ } else {
+ ICHECK_NE(sn, code_gen_name)
+ << "Found duplicated symbol: " << sn << " for: " <<
code_gen_name;
+ }
+
+ std::string ext_name = "relay.ext." + code_gen_name;
+ auto pf = tvm::runtime::Registry::Get(ext_name);
+ ICHECK(pf) << "Failed to find the codegen tool for " << ext_name <<
"\n";
+ // No need to keep compiler attribute at this point, functions have
been
+ // extracted for specific codegen.
+ src_func = WithAttr(std::move(src_func), attr::kCompiler,
NullValue<ObjectRef>());
+ runtime::Module ext_mod = (*pf)(src_func);
+
+ ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+ ret.push_back(ext_mod);
+ }
+ }
+
+ // No need to cache external functions as we collected them all to create
+ // external runtime modules.
+ for (const auto& it : cached_ext_funcs) {
+ cache_.erase(it);
+ }
+ return ret;
+ }
+
+ void Clear() final { cache_.clear(); }
+
+ // List all items in the cache.
+ Array<ObjectRef> ListItems() {
+ std::lock_guard<std::mutex> lock(mutex_);
+ Array<ObjectRef> items;
+ for (auto& kv : cache_) {
+ items.push_back(kv.first);
+ items.push_back(kv.second);
+ }
+ return items;
+ }
+
+ /*!
+ * \brief Get the cache key of the function that is being lowered currently
+ * \return the cache key
+ */
+ CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+ // implement lowered func
+ CCacheValue LowerInternal(const CCacheKey& key) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ CCacheValue value;
+ auto it = cache_.find(key);
+ if (it != cache_.end()) {
+ it->second->use_count += 1;
+ if (it->second->cached_func.defined()) return it->second;
+ value = it->second;
+ } else {
+ value = CCacheValue(make_object<CCacheValueNode>());
+ value->use_count = 0;
+ if (!backend::IsCompileEngineCacheDisabled()) {
+ cache_[key] = value;
+ }
+ }
+ cur_ccache_key_ = key;
+
+ // No need to lower external functions for now. We will invoke the external
+ // codegen tool once and lower all functions together.
+ if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto ir_module = IRModule();
+ const auto name_node =
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(name_node.defined()) << "External function has not been attached
a name yet.";
+ auto func_name = std::string(name_node.value());
+ auto target = Target("ext_dev");
+ auto global_var = GlobalVar(func_name);
+ global_var->checked_type_ = key->source_func->checked_type();
+ ir_module->Add(global_var, key->source_func);
+ value->cached_func = CachedFunc(target, global_var, {}, {},
te::Schedule(), {}, ir_module);
+ return value;
+ }
+ // Enforce use the target.
+ With<Target> target_scope(key->target);
+
+ ICHECK(!value->cached_func.defined());
+ auto cfunc = PrimFuncFor(key->source_func, key->target,
+ [&](std::string name) { return
GetUniqueName(name, name_map_); });
+
+ // Skip lowering for device copy node.
+ const Expr body = (key->source_func)->body;
+ if (const CallNode* call_node = body.as<CallNode>()) {
+ if (call_node->attrs.as<DeviceCopyAttrs>()) {
+ value->cached_func = cfunc;
+ return value;
+ }
+ }
+
+ std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+ std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;
Review comment:
Remove these leftover debug prints (and the `Allargs Size` print a few lines below)?
##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -349,65 +380,47 @@ class GraphRuntimeCodegen : public
backend::MemoizedExprTranslator<std::vector<G
return AddNode(node, GetRef<Expr>(op));
}
- std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
- Expr expr = GetRef<Expr>(op);
- Function func;
- if (op->op.as<OpNode>()) {
- LOG(FATAL) << "Operators should be transformed away; try applying"
- << "the fuse_ops transformation to the expression.";
- } else if (op->op.as<GlobalVarNode>()) {
- LOG(FATAL) << "Not implemented";
- } else if (op->op.as<FunctionNode>()) {
- func = GetRef<Function>(op->op.as<FunctionNode>());
- } else {
- LOG(FATAL) << "TVM runtime does not support calls to " <<
op->op->GetTypeKey();
- }
- if (!func->HasNonzeroAttr(attr::kPrimitive)) {
- LOG(FATAL) << "TVM only support calls to primitive functions "
- << "(i.e functions composed of fusable operator invocations)";
- }
+ std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
+ relay::Call call = GetRef<Call>(call_node);
- auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
- auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
- Target target;
- // Handle external function
- if (func->GetAttr<String>(attr::kCompiler).defined()) {
- target = Target("ext_dev");
- CCacheKey key = (*pf0)(func, target);
- CachedFunc ext_func = (*pf1)(compile_engine_, key);
- ICHECK(ext_func.defined()) << "External function is not defined.";
- UpdateConstants(func, ¶ms_);
- return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
- }
+ if (auto global_node = call->op.as<GlobalVarNode>()) {
+ auto prim_fn_name = global_node->name_hint;
- ICHECK_GE(storage_device_map_.count(expr), 0);
- auto& device_type = storage_device_map_[expr][1];
- auto call_dev_type = device_type[0]->value;
- // Normal Relay Function
- if (targets_.size() == 1) {
- // homogeneous execution.
- const auto& it = targets_.begin();
- target = (*it).second;
- } else {
- // heterogeneous execution.
- std::string call_dev_name;
- if (call_dev_type == 0) {
- call_dev_name = "llvm";
+ Target target;
+
+ // // Handle external function
+ // if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ // UpdateConstants(func, ¶ms_);
+ // return GraphAddCallNode(call_node, prim_fn_name, prim_fn_name);
+ // }
Review comment:
Delete this commented-out external-function handling, or restore it if it is still needed?
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+ // Lower the function.
+ CachedFunc Lower(const CCacheKey& key) { return
LowerInternal(key)->cached_func; }
+
+ // For now, build one module per function.
+ PackedFunc JIT(const CCacheKey& key) final {
+ CCacheValue value = LowerInternal(key);
+ if (value->packed_func != nullptr) {
+ return value->packed_func;
+ }
+ auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+ value->packed_func =
m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+ return value->packed_func;
+ }
+
+ CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+ return LowerShapeFuncInternal(key)->cached_func;
+ }
+
+ Map<String, IRModule> GetLoweredFunctions() {
+ Map<String, IRModule> lowered_functions;
+ for (const auto& it : cache_) {
+ auto source_func = it.first;
+ auto lowered_func = it.second;
+ auto target = source_func->target;
+
+ if (!lowered_functions.count(target->str())) {
+ lowered_functions.Set(target->str(), IRModule(Map<GlobalVar,
BaseFunc>({})));
+ }
+
+
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+ }
+ return lowered_functions;
+ }
+
+ Array<tvm::runtime::Module> LowerExternalFunctions() {
+ Array<tvm::runtime::Module> ret;
+ std::unordered_map<std::string, std::string> cached_symbol;
+ std::vector<CCacheKey> cached_ext_funcs;
+ for (const auto& it : cache_) {
+ auto src_func = it.first->source_func;
+ ICHECK(src_func.defined());
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+ ICHECK(code_gen.defined()) << "No external codegen is set";
+ std::string code_gen_name = code_gen.value();
+ cached_ext_funcs.push_back(it.first);
+
+ auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+ << AsText(src_func, false);
+
+ std::string sn = symbol_name.value();
+ if (cached_symbol.count(sn)) {
+ cached_symbol[sn] = code_gen_name;
+ } else {
+ ICHECK_NE(sn, code_gen_name)
+ << "Found duplicated symbol: " << sn << " for: " <<
code_gen_name;
+ }
+
+ std::string ext_name = "relay.ext." + code_gen_name;
+ auto pf = tvm::runtime::Registry::Get(ext_name);
+ ICHECK(pf) << "Failed to find the codegen tool for " << ext_name <<
"\n";
+ // No need to keep compiler attribute at this point, functions have
been
+ // extracted for specific codegen.
+ src_func = WithAttr(std::move(src_func), attr::kCompiler,
NullValue<ObjectRef>());
+ runtime::Module ext_mod = (*pf)(src_func);
+
+ ICHECK(ext_mod.defined()) << "No external runtime is generated.";
Review comment:
```suggestion
ICHECK(ext_mod.defined()) << "No external runtime was generated by "
<< ext_name << ".";
```
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation
impl) {
+ auto n = make_object<LoweredOutputNode>();
+ n->outputs = std::move(outputs);
+ n->implementation = std::move(impl);
+ data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+ auto n = make_object<CCacheKeyNode>();
+ n->source_func = std::move(source_func);
+ n->target = std::move(target);
+ data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var,
tvm::Array<te::Tensor> inputs,
+ tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+ tvm::Array<Integer> shape_func_param_states, IRModule
funcs) {
+ auto n = make_object<CachedFuncNode>();
+ n->target = target;
+ n->prim_fn_var = prim_fn_var;
+ n->inputs = inputs;
+ n->outputs = outputs;
+ n->schedule = schedule;
+ n->shape_func_param_states = shape_func_param_states;
+ n->funcs = funcs;
+ data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+ // for now, we always use int32 shape when possible
+ // even if the result of shape inference becomes int64.
+ Array<IndexExpr> res;
+ for (IndexExpr val : shape) {
+ const int64_t* pval = tir::as_const_int(val);
+ if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+ ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
Review comment:
Can we add an error message to these checks (e.g. stating that the shape dimension does not fit in int32)?
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+ // Lower the function.
+ CachedFunc Lower(const CCacheKey& key) { return
LowerInternal(key)->cached_func; }
+
+ // For now, build one module per function.
+ PackedFunc JIT(const CCacheKey& key) final {
+ CCacheValue value = LowerInternal(key);
+ if (value->packed_func != nullptr) {
+ return value->packed_func;
+ }
+ auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+ value->packed_func =
m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+ return value->packed_func;
+ }
+
+ CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+ return LowerShapeFuncInternal(key)->cached_func;
+ }
+
+ Map<String, IRModule> GetLoweredFunctions() {
+ Map<String, IRModule> lowered_functions;
+ for (const auto& it : cache_) {
+ auto source_func = it.first;
+ auto lowered_func = it.second;
+ auto target = source_func->target;
+
+ if (!lowered_functions.count(target->str())) {
+ lowered_functions.Set(target->str(), IRModule(Map<GlobalVar,
BaseFunc>({})));
+ }
+
+
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+ }
+ return lowered_functions;
+ }
+
+ Array<tvm::runtime::Module> LowerExternalFunctions() {
+ Array<tvm::runtime::Module> ret;
+ std::unordered_map<std::string, std::string> cached_symbol;
+ std::vector<CCacheKey> cached_ext_funcs;
+ for (const auto& it : cache_) {
+ auto src_func = it.first->source_func;
+ ICHECK(src_func.defined());
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+ ICHECK(code_gen.defined()) << "No external codegen is set";
+ std::string code_gen_name = code_gen.value();
+ cached_ext_funcs.push_back(it.first);
+
+ auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+ << AsText(src_func, false);
+
+ std::string sn = symbol_name.value();
+ if (cached_symbol.count(sn)) {
+ cached_symbol[sn] = code_gen_name;
+ } else {
+ ICHECK_NE(sn, code_gen_name)
+ << "Found duplicated symbol: " << sn << " for: " <<
code_gen_name;
+ }
+
+ std::string ext_name = "relay.ext." + code_gen_name;
+ auto pf = tvm::runtime::Registry::Get(ext_name);
+ ICHECK(pf) << "Failed to find the codegen tool for " << ext_name <<
"\n";
+ // No need to keep compiler attribute at this point, functions have
been
+ // extracted for specific codegen.
+ src_func = WithAttr(std::move(src_func), attr::kCompiler,
NullValue<ObjectRef>());
+ runtime::Module ext_mod = (*pf)(src_func);
+
+ ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+ ret.push_back(ext_mod);
+ }
+ }
+
+ // No need to cache external functions as we collected them all to create
+ // external runtime modules.
+ for (const auto& it : cached_ext_funcs) {
+ cache_.erase(it);
+ }
+ return ret;
+ }
+
+ void Clear() final { cache_.clear(); }
+
+ // List all items in the cache.
+ Array<ObjectRef> ListItems() {
+ std::lock_guard<std::mutex> lock(mutex_);
+ Array<ObjectRef> items;
+ for (auto& kv : cache_) {
+ items.push_back(kv.first);
+ items.push_back(kv.second);
+ }
+ return items;
+ }
+
+ /*!
+ * \brief Get the cache key of the function that is being lowered currently
+ * \return the cache key
+ */
+ CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+ // implement lowered func
+ CCacheValue LowerInternal(const CCacheKey& key) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ CCacheValue value;
+ auto it = cache_.find(key);
+ if (it != cache_.end()) {
+ it->second->use_count += 1;
+ if (it->second->cached_func.defined()) return it->second;
+ value = it->second;
+ } else {
+ value = CCacheValue(make_object<CCacheValueNode>());
+ value->use_count = 0;
+ if (!backend::IsCompileEngineCacheDisabled()) {
+ cache_[key] = value;
+ }
+ }
+ cur_ccache_key_ = key;
+
+ // No need to lower external functions for now. We will invoke the external
+ // codegen tool once and lower all functions together.
+ if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto ir_module = IRModule();
+ const auto name_node =
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(name_node.defined()) << "External function has not been attached
a name yet.";
+ auto func_name = std::string(name_node.value());
+ auto target = Target("ext_dev");
+ auto global_var = GlobalVar(func_name);
+ global_var->checked_type_ = key->source_func->checked_type();
+ ir_module->Add(global_var, key->source_func);
+ value->cached_func = CachedFunc(target, global_var, {}, {},
te::Schedule(), {}, ir_module);
+ return value;
+ }
+ // Enforce use the target.
+ With<Target> target_scope(key->target);
+
+ ICHECK(!value->cached_func.defined());
+ auto cfunc = PrimFuncFor(key->source_func, key->target,
+ [&](std::string name) { return
GetUniqueName(name, name_map_); });
+
+ // Skip lowering for device copy node.
+ const Expr body = (key->source_func)->body;
+ if (const CallNode* call_node = body.as<CallNode>()) {
+ if (call_node->attrs.as<DeviceCopyAttrs>()) {
+ value->cached_func = cfunc;
+ return value;
+ }
+ }
+
+ std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+ std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;
+ // NOTE: array will copy on write.
+ Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+ for (te::Tensor arg : cfunc->outputs) {
+ all_args.push_back(arg);
+ }
+
+ std::cout << "Allargs Size: " << all_args.size() << std::endl;
+
+ using tvm::transform::PassContext;
+ With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
+ std::unordered_map<te::Tensor, tir::Buffer> binds;
+ auto func_name = cfunc->prim_fn_var->name_hint;
+ cfunc->funcs->Update(tvm::lower(cfunc->schedule, all_args, func_name,
binds));
+ value->cached_func = cfunc;
+ return value;
+ }
+
+ // implement lowered shape func
+ CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ CCacheValue value;
+ auto it = shape_func_cache_.find(key);
+ if (it != shape_func_cache_.end()) {
+ it->second->use_count += 1;
+ if (it->second->cached_func.defined()) return it->second;
+ value = it->second;
+ } else {
+ value = CCacheValue(make_object<CCacheValueNode>());
+ value->use_count = 0;
+ shape_func_cache_[key] = value;
+ }
+ // Enforce use the target.
+ With<Target> target_scope(key->target);
+
+ ICHECK(!value->cached_func.defined());
+ auto cached_func = ShapeFuncFor(key->source_func, key->target,
[&](std::string name) {
+ return GetUniqueName(name, name_map_);
+ });
+
+ value->cached_func = cached_func;
+ return value;
+ }
+
+ /*! \brief compiler cache lock*/
+ std::mutex mutex_;
+ /*! \brief internal name map to get an unique name */
+ std::unordered_map<std::string, int> name_map_;
+ /*! \brief internal compiler cache */
+ std::unordered_map<CCacheKey, CCacheValue> cache_;
+ /*! \brief internal compiler cache for shape funcs */
+ std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
+ /*! \brief the cache key of the function that is being lowered currently*/
+ CCacheKey cur_ccache_key_;
+};
+
+TECompiler::TECompiler() {
+ auto object = make_object<TECompilerImpl>();
+ data_ = object;
+}
+
+class LowerTensorExpr : public ExprMutator {
+ public:
+ LowerTensorExpr(const IRModule& module, const TargetsMap& targets,
+ const DeviceContextMap& device_ctx_map, TECompiler compiler)
+ : module_(module),
+ targets_(targets),
+ device_context_map_(device_ctx_map),
+ compiler_(compiler) {}
+
+ Expr VisitExpr_(const CallNode* call) override {
+ Call expr = GetRef<Call>(call);
+ Function func;
+
+ if (call->op.as<FunctionNode>()) {
+ func = GetRef<Function>(call->op.as<FunctionNode>());
+ } else {
+ return ExprMutator::VisitExpr_(call);
+ }
+
+ if (!func->HasNonzeroAttr(attr::kPrimitive)) {
+ // LOG(FATAL) << "TVM only support calls to primitive functions "
+ // << "(i.e functions composed of fusable operator
invocations)";
Review comment:
Delete or uncomment this `LOG(FATAL)`?
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);  // TE outputs plus the chosen OpImplementation
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);  // lowered function: target, GlobalVar, inputs, outputs, schedule, IRModule
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);  // cache key: (source_func, target)
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);  // cache entry: cached_func, packed_func, use_count
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation
impl) {
+ auto n = make_object<LoweredOutputNode>();
+ n->outputs = std::move(outputs);
+ n->implementation = std::move(impl);
+ data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+ auto n = make_object<CCacheKeyNode>();
+ n->source_func = std::move(source_func);
+ n->target = std::move(target);
+ data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var,
tvm::Array<te::Tensor> inputs,
+ tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+ tvm::Array<Integer> shape_func_param_states, IRModule
funcs) {
+ auto n = make_object<CachedFuncNode>();
+ n->target = target;
+ n->prim_fn_var = prim_fn_var;
+ n->inputs = inputs;
+ n->outputs = outputs;
+ n->schedule = schedule;
+ n->shape_func_param_states = shape_func_param_states;
+ n->funcs = funcs;
+ data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+ // for now, we always use int32 shape when possible
+ // even if the result of shape inference becomes int64.
+ Array<IndexExpr> res;
+ for (IndexExpr val : shape) {
+ const int64_t* pval = tir::as_const_int(val);
+ if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+ ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+ ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+ res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+ res.push_back(val);
+#endif // TVM_INDEX_DEFAULT_I64
+ } else if (val->IsInstance<tir::AnyNode>()) {
+ res.push_back(val.as<tir::AnyNode>()->ToVar());
+ } else {
+ res.push_back(val);
+ }
+ }
+ return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public
backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+ explicit ScheduleGetter(Target target)
+ : target_(target), device_copy_op_(Op::Get("device_copy")) {
+ // Whether to use auto_scheduler schedule.
+ use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+ }
+
+ CachedFunc Create(const Function& prim_func,
std::function<std::string(std::string)> renamer) {
+ Array<tvm::te::Tensor> fn_inputs;
+ for (Var param : prim_func->params) {
+ Array<tvm::te::Tensor> inputs;
+ if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+ tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape),
ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ } else {
+ // flatten tuple of tensor type.
+ const auto* tuple_type = param->type_as<TupleTypeNode>();
+ for (Type field : tuple_type->fields) {
+ const auto* ttype = field.as<TensorTypeNode>();
+ // TODO(@icemelon): Allow recursive tuple
+ ICHECK(ttype != nullptr);
+ tvm::te::Tensor tensor =
tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ }
+ }
+ memo_[param] = inputs;
+ }
+ readable_name_stream_ << "fused";
+ auto outputs = this->VisitExpr(prim_func->body);
+ auto candidate_name = readable_name_stream_.str();
+ constexpr static size_t kMaxFuncNameLength = 80;
+ if (candidate_name.size() > kMaxFuncNameLength) {
+ std::stringstream truncated_name;
+ truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+ truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+ candidate_name = truncated_name.str();
+ }
+
+ ICHECK(anchor_op_.defined());
+ // Fusion over tupled results may leave identity relationships
+ // between inputs and outputs, and those should not be scheduled.
+ // Hence schedule only non PlaceholderOp outputs.
+ tvm::Array<te::Tensor> tensor_outs;
+ for (const auto& tensor : outputs) {
+ if (!tensor->op.as<te::PlaceholderOpNode>()) {
+ tensor_outs.push_back(tensor);
+ }
+ }
+
+ te::Schedule schedule;
+ // No need to register schedule for device copy op.
+ if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+ if (use_auto_scheduler_) {
+ const auto* fauto_schedule =
+
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+ ICHECK(fauto_schedule != nullptr)
+ << "auto_scheduler.relay_integration.auto_schedule_topi_compute is
not registered";
+ ObjectRef obj = (*fauto_schedule)(tensor_outs);
+ if (obj.defined()) {
+ schedule = Downcast<te::Schedule>(obj);
+ }
+ }
+
+ // Use TOPI schdule if user specificed, or the function has no
auto_scheduler schedule.
+ if (!schedule.defined()) {
+ ICHECK(anchor_implementation_.defined());
+ schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs,
target_);
+ }
+ for (const auto& scalar : scalars_) {
+ if (schedule->Contain(scalar)) {
+ schedule[scalar].compute_inline();
+ }
+ }
+ }
+
+ auto prim_fn_var = GlobalVar(candidate_name);
+ prim_fn_var->checked_type_ = prim_func->checked_type();
+ return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+ }
+
+ Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+ LOG(FATAL) << "Free variable " << op->name_hint();
+ return {};
+ }
+
+ Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
+ using tir::make_const;
+ ICHECK(op->is_scalar());
+ void* data = op->data->data;
+ DataType dtype = DataType(op->data->dtype);
+ auto value = te::compute(
+ {},
+ [&](const Array<tvm::tir::Var>&) {
+ if (dtype == DataType::Int(32)) {
+ return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+ } else if (dtype == DataType::Int(64)) {
+ return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+ } else if (dtype == DataType::Float(32)) {
+ return make_const(dtype, static_cast<const float*>(data)[0]);
+ } else if (dtype == DataType::Float(64)) {
+ return make_const(dtype, static_cast<const double*>(data)[0]);
+ } else if (dtype == DataType::Bool()) {
+ return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+ } else {
+ LOG(FATAL) << "not handled";
+ return tvm::PrimExpr();
+ }
+ },
+ "compile_engine_const", topi::kBroadcast);
+ scalars_.push_back(value->op);
+ return {value};
+ }
+
+ Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+ static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+ static auto flower_call =
tvm::runtime::Registry::Get("relay.backend.lower_call");
+ ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+ Array<te::Tensor> inputs;
+ int count_tuple = 0;
+ for (Expr arg : call_node->args) {
+ if (arg->checked_type().as<TupleTypeNode>()) {
+ ++count_tuple;
+ }
+ for (te::Tensor tensor : VisitExpr(arg)) {
+ inputs.push_back(tensor);
+ }
+ }
+ if (count_tuple) {
+ ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a
single tuple input";
Review comment:
```suggestion
ICHECK_EQ(call_node->args.size(), 1U) << "Only functions with a single
tuple input are allowed, but " << call_node->args.size() << " were provided.";
```
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);  // TE outputs plus the chosen OpImplementation
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);  // lowered function: target, GlobalVar, inputs, outputs, schedule, IRModule
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);  // cache key: (source_func, target)
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);  // cache entry: cached_func, packed_func, use_count
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation
impl) {
+ auto n = make_object<LoweredOutputNode>();
+ n->outputs = std::move(outputs);
+ n->implementation = std::move(impl);
+ data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+ auto n = make_object<CCacheKeyNode>();
+ n->source_func = std::move(source_func);
+ n->target = std::move(target);
+ data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var,
tvm::Array<te::Tensor> inputs,
+ tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+ tvm::Array<Integer> shape_func_param_states, IRModule
funcs) {
+ auto n = make_object<CachedFuncNode>();
+ n->target = target;
+ n->prim_fn_var = prim_fn_var;
+ n->inputs = inputs;
+ n->outputs = outputs;
+ n->schedule = schedule;
+ n->shape_func_param_states = shape_func_param_states;
+ n->funcs = funcs;
+ data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+ // for now, we always use int32 shape when possible
+ // even if the result of shape inference becomes int64.
+ Array<IndexExpr> res;
+ for (IndexExpr val : shape) {
+ const int64_t* pval = tir::as_const_int(val);
+ if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+ ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+ ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+ res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+ res.push_back(val);
+#endif // TVM_INDEX_DEFAULT_I64
+ } else if (val->IsInstance<tir::AnyNode>()) {
+ res.push_back(val.as<tir::AnyNode>()->ToVar());
+ } else {
+ res.push_back(val);
+ }
+ }
+ return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public
backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+ explicit ScheduleGetter(Target target)
+ : target_(target), device_copy_op_(Op::Get("device_copy")) {
+ // Whether to use auto_scheduler schedule.
+ use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+ }
+
+ CachedFunc Create(const Function& prim_func,
std::function<std::string(std::string)> renamer) {
+ Array<tvm::te::Tensor> fn_inputs;
+ for (Var param : prim_func->params) {
+ Array<tvm::te::Tensor> inputs;
+ if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+ tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape),
ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ } else {
+ // flatten tuple of tensor type.
+ const auto* tuple_type = param->type_as<TupleTypeNode>();
+ for (Type field : tuple_type->fields) {
+ const auto* ttype = field.as<TensorTypeNode>();
+ // TODO(@icemelon): Allow recursive tuple
+ ICHECK(ttype != nullptr);
+ tvm::te::Tensor tensor =
tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ }
+ }
+ memo_[param] = inputs;
+ }
+ readable_name_stream_ << "fused";
+ auto outputs = this->VisitExpr(prim_func->body);
+ auto candidate_name = readable_name_stream_.str();
+ constexpr static size_t kMaxFuncNameLength = 80;
+ if (candidate_name.size() > kMaxFuncNameLength) {
+ std::stringstream truncated_name;
+ truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+ truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+ candidate_name = truncated_name.str();
+ }
+
+ ICHECK(anchor_op_.defined());
+ // Fusion over tupled results may leave identity relationships
+ // between inputs and outputs, and those should not be scheduled.
+ // Hence schedule only non PlaceholderOp outputs.
+ tvm::Array<te::Tensor> tensor_outs;
+ for (const auto& tensor : outputs) {
+ if (!tensor->op.as<te::PlaceholderOpNode>()) {
+ tensor_outs.push_back(tensor);
+ }
+ }
+
+ te::Schedule schedule;
+ // No need to register schedule for device copy op.
+ if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+ if (use_auto_scheduler_) {
+ const auto* fauto_schedule =
+
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+ ICHECK(fauto_schedule != nullptr)
+ << "auto_scheduler.relay_integration.auto_schedule_topi_compute is
not registered";
+ ObjectRef obj = (*fauto_schedule)(tensor_outs);
+ if (obj.defined()) {
+ schedule = Downcast<te::Schedule>(obj);
+ }
+ }
+
+ // Use TOPI schdule if user specificed, or the function has no
auto_scheduler schedule.
+ if (!schedule.defined()) {
+ ICHECK(anchor_implementation_.defined());
+ schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs,
target_);
+ }
+ for (const auto& scalar : scalars_) {
+ if (schedule->Contain(scalar)) {
+ schedule[scalar].compute_inline();
+ }
+ }
+ }
+
+ auto prim_fn_var = GlobalVar(candidate_name);
+ prim_fn_var->checked_type_ = prim_func->checked_type();
+ return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+ }
+
+ Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+ LOG(FATAL) << "Free variable " << op->name_hint();
+ return {};
+ }
+
+ Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
+ using tir::make_const;
+ ICHECK(op->is_scalar());
+ void* data = op->data->data;
+ DataType dtype = DataType(op->data->dtype);
+ auto value = te::compute(
+ {},
+ [&](const Array<tvm::tir::Var>&) {
+ if (dtype == DataType::Int(32)) {
+ return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+ } else if (dtype == DataType::Int(64)) {
+ return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+ } else if (dtype == DataType::Float(32)) {
+ return make_const(dtype, static_cast<const float*>(data)[0]);
+ } else if (dtype == DataType::Float(64)) {
+ return make_const(dtype, static_cast<const double*>(data)[0]);
+ } else if (dtype == DataType::Bool()) {
+ return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+ } else {
+ LOG(FATAL) << "not handled";
+ return tvm::PrimExpr();
+ }
+ },
+ "compile_engine_const", topi::kBroadcast);
+ scalars_.push_back(value->op);
+ return {value};
+ }
+
+ Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+ static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+ static auto flower_call =
tvm::runtime::Registry::Get("relay.backend.lower_call");
+ ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+ Array<te::Tensor> inputs;
+ int count_tuple = 0;
+ for (Expr arg : call_node->args) {
+ if (arg->checked_type().as<TupleTypeNode>()) {
+ ++count_tuple;
+ }
+ for (te::Tensor tensor : VisitExpr(arg)) {
+ inputs.push_back(tensor);
+ }
+ }
+ if (count_tuple) {
+ ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a
single tuple input";
+ }
+
+ ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call
into primitive ops";
+ Op op = Downcast<Op>(call_node->op);
+
+ Array<te::Tensor> outputs;
+ OpImplementation impl;
+ // Skip fcompute for device copy operators as it is not registered.
+ if (op == device_copy_op_) {
+ const auto* copy_input = inputs[0].operator->();
+ outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype,
te::Operation(), 0));
+ } else {
+ LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node),
inputs, target_);
+ outputs = lowered_out->outputs;
+ impl = lowered_out->implementation;
+ }
+
+ int op_pattern = fpattern[op];
+ if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+ ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+ << "Cannot apply TOPI schedule to a primitive function with two
complicated ops"
+ << " anchor=" << anchor_op_ << " current=" << op;
+ }
+ if (op_pattern >= anchor_op_pattern_) {
+ anchor_op_ = op;
+ anchor_attrs_ = call_node->attrs;
+ anchor_op_pattern_ = op_pattern;
+ anchor_implementation_ = impl;
+ }
+ if (outputs.size() != 1) {
+ const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
+ ICHECK(tuple_type) << "Expect output to be a tuple type";
Review comment:
Please include the actual type in this error message.
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+ // Lower the function.
+ CachedFunc Lower(const CCacheKey& key) { return
LowerInternal(key)->cached_func; }
+
+ // For now, build one module per function.
+ PackedFunc JIT(const CCacheKey& key) final {
+ CCacheValue value = LowerInternal(key);
+ if (value->packed_func != nullptr) {
+ return value->packed_func;
+ }
+ auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+ value->packed_func =
m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+ return value->packed_func;
+ }
+
+ CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+ return LowerShapeFuncInternal(key)->cached_func;
+ }
+
+ Map<String, IRModule> GetLoweredFunctions() {
+ Map<String, IRModule> lowered_functions;
+ for (const auto& it : cache_) {
+ auto source_func = it.first;
+ auto lowered_func = it.second;
+ auto target = source_func->target;
+
+ if (!lowered_functions.count(target->str())) {
+ lowered_functions.Set(target->str(), IRModule(Map<GlobalVar,
BaseFunc>({})));
+ }
+
+
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+ }
+ return lowered_functions;
+ }
+
+ Array<tvm::runtime::Module> LowerExternalFunctions() {
+ Array<tvm::runtime::Module> ret;
+ std::unordered_map<std::string, std::string> cached_symbol;
+ std::vector<CCacheKey> cached_ext_funcs;
+ for (const auto& it : cache_) {
+ auto src_func = it.first->source_func;
+ ICHECK(src_func.defined());
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+ ICHECK(code_gen.defined()) << "No external codegen is set";
+ std::string code_gen_name = code_gen.value();
+ cached_ext_funcs.push_back(it.first);
+
+ auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+ << AsText(src_func, false);
+
+ std::string sn = symbol_name.value();
+ if (cached_symbol.count(sn)) {
+ cached_symbol[sn] = code_gen_name;
+ } else {
+ ICHECK_NE(sn, code_gen_name)
+ << "Found duplicated symbol: " << sn << " for: " <<
code_gen_name;
+ }
+
+ std::string ext_name = "relay.ext." + code_gen_name;
+ auto pf = tvm::runtime::Registry::Get(ext_name);
+ ICHECK(pf) << "Failed to find the codegen tool for " << ext_name <<
"\n";
+ // No need to keep compiler attribute at this point, functions have
been
+ // extracted for specific codegen.
+ src_func = WithAttr(std::move(src_func), attr::kCompiler,
NullValue<ObjectRef>());
+ runtime::Module ext_mod = (*pf)(src_func);
+
+ ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+ ret.push_back(ext_mod);
+ }
+ }
+
+ // No need to cache external functions as we collected them all to create
+ // external runtime modules.
+ for (const auto& it : cached_ext_funcs) {
+ cache_.erase(it);
+ }
+ return ret;
+ }
+
+ void Clear() final { cache_.clear(); }
+
+ // List all items in the cache.
+ Array<ObjectRef> ListItems() {
+ std::lock_guard<std::mutex> lock(mutex_);
+ Array<ObjectRef> items;
+ for (auto& kv : cache_) {
+ items.push_back(kv.first);
+ items.push_back(kv.second);
+ }
+ return items;
+ }
+
+ /*!
+ * \brief Get the cache key of the function that is being lowered currently
+ * \return the cache key
+ */
+ CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+ // implement lowered func
+ CCacheValue LowerInternal(const CCacheKey& key) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ CCacheValue value;
+ auto it = cache_.find(key);
+ if (it != cache_.end()) {
+ it->second->use_count += 1;
+ if (it->second->cached_func.defined()) return it->second;
+ value = it->second;
+ } else {
+ value = CCacheValue(make_object<CCacheValueNode>());
+ value->use_count = 0;
+ if (!backend::IsCompileEngineCacheDisabled()) {
+ cache_[key] = value;
+ }
+ }
+ cur_ccache_key_ = key;
+
+ // No need to lower external functions for now. We will invoke the external
+ // codegen tool once and lower all functions together.
+ if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto ir_module = IRModule();
+ const auto name_node =
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(name_node.defined()) << "External function has not been attached
a name yet.";
Review comment:
Could you include the function name in this error message?
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+ // Lower the function.
+ CachedFunc Lower(const CCacheKey& key) { return
LowerInternal(key)->cached_func; }
+
+ // For now, build one module per function.
+ PackedFunc JIT(const CCacheKey& key) final {
+ CCacheValue value = LowerInternal(key);
+ if (value->packed_func != nullptr) {
+ return value->packed_func;
+ }
+ auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+ value->packed_func =
m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+ return value->packed_func;
+ }
+
+ CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+ return LowerShapeFuncInternal(key)->cached_func;
+ }
+
+ Map<String, IRModule> GetLoweredFunctions() {
+ Map<String, IRModule> lowered_functions;
+ for (const auto& it : cache_) {
+ auto source_func = it.first;
+ auto lowered_func = it.second;
+ auto target = source_func->target;
+
+ if (!lowered_functions.count(target->str())) {
+ lowered_functions.Set(target->str(), IRModule(Map<GlobalVar,
BaseFunc>({})));
+ }
+
+
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+ }
+ return lowered_functions;
+ }
+
+ Array<tvm::runtime::Module> LowerExternalFunctions() {
+ Array<tvm::runtime::Module> ret;
+ std::unordered_map<std::string, std::string> cached_symbol;
+ std::vector<CCacheKey> cached_ext_funcs;
+ for (const auto& it : cache_) {
+ auto src_func = it.first->source_func;
+ ICHECK(src_func.defined());
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+ ICHECK(code_gen.defined()) << "No external codegen is set";
+ std::string code_gen_name = code_gen.value();
+ cached_ext_funcs.push_back(it.first);
+
+ auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+ << AsText(src_func, false);
+
+ std::string sn = symbol_name.value();
+ if (cached_symbol.count(sn)) {
+ cached_symbol[sn] = code_gen_name;
+ } else {
+ ICHECK_NE(sn, code_gen_name)
+ << "Found duplicated symbol: " << sn << " for: " <<
code_gen_name;
+ }
+
+ std::string ext_name = "relay.ext." + code_gen_name;
+ auto pf = tvm::runtime::Registry::Get(ext_name);
+ ICHECK(pf) << "Failed to find the codegen tool for " << ext_name <<
"\n";
+ // No need to keep compiler attribute at this point, functions have
been
+ // extracted for specific codegen.
+ src_func = WithAttr(std::move(src_func), attr::kCompiler,
NullValue<ObjectRef>());
+ runtime::Module ext_mod = (*pf)(src_func);
+
+ ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+ ret.push_back(ext_mod);
+ }
+ }
+
+ // No need to cache external functions as we collected them all to create
+ // external runtime modules.
+ for (const auto& it : cached_ext_funcs) {
+ cache_.erase(it);
+ }
+ return ret;
+ }
+
+ void Clear() final { cache_.clear(); }
+
+ // List all items in the cache.
+ Array<ObjectRef> ListItems() {
+ std::lock_guard<std::mutex> lock(mutex_);
+ Array<ObjectRef> items;
+ for (auto& kv : cache_) {
+ items.push_back(kv.first);
+ items.push_back(kv.second);
+ }
+ return items;
+ }
+
+ /*!
+ * \brief Get the cache key of the function that is being lowered currently
+ * \return the cache key
+ */
+ CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+ // implement lowered func
+ CCacheValue LowerInternal(const CCacheKey& key) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ CCacheValue value;
+ auto it = cache_.find(key);
+ if (it != cache_.end()) {
+ it->second->use_count += 1;
+ if (it->second->cached_func.defined()) return it->second;
+ value = it->second;
+ } else {
+ value = CCacheValue(make_object<CCacheValueNode>());
+ value->use_count = 0;
+ if (!backend::IsCompileEngineCacheDisabled()) {
+ cache_[key] = value;
+ }
+ }
+ cur_ccache_key_ = key;
+
+ // No need to lower external functions for now. We will invoke the external
+ // codegen tool once and lower all functions together.
+ if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto ir_module = IRModule();
+ const auto name_node =
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(name_node.defined()) << "External function has not been attached
a name yet.";
+ auto func_name = std::string(name_node.value());
+ auto target = Target("ext_dev");
+ auto global_var = GlobalVar(func_name);
+ global_var->checked_type_ = key->source_func->checked_type();
+ ir_module->Add(global_var, key->source_func);
+ value->cached_func = CachedFunc(target, global_var, {}, {},
te::Schedule(), {}, ir_module);
+ return value;
+ }
+ // Enforce use the target.
+ With<Target> target_scope(key->target);
+
+ ICHECK(!value->cached_func.defined());
+ auto cfunc = PrimFuncFor(key->source_func, key->target,
+ [&](std::string name) { return
GetUniqueName(name, name_map_); });
+
+ // Skip lowering for device copy node.
+ const Expr body = (key->source_func)->body;
+ if (const CallNode* call_node = body.as<CallNode>()) {
+ if (call_node->attrs.as<DeviceCopyAttrs>()) {
+ value->cached_func = cfunc;
+ return value;
+ }
+ }
+
+ std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+ std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;
+ // NOTE: array will copy on write.
+ Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+ for (te::Tensor arg : cfunc->outputs) {
+ all_args.push_back(arg);
+ }
+
+ std::cout << "Allargs Size: " << all_args.size() << std::endl;
Review comment:
Should these debug `std::cout` prints be removed?
##########
File path: src/runtime/graph/graph_runtime.cc
##########
@@ -428,6 +429,7 @@ std::pair<std::function<void()>,
std::shared_ptr<GraphRuntime::OpArgs> > GraphRu
ICHECK(pf != nullptr) << "no such function in module: " << param.func_name;
auto fexec = [arg_ptr, pf]() {
+ std::cout << "Number of args: " <<
static_cast<int>(arg_ptr->arg_values.size());
Review comment:
Remove this debug `std::cout` statement?
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+// Register the node types used by the TE compiler cache with TVM's node
+// (object/reflection) registry.
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation
impl) {
+ auto n = make_object<LoweredOutputNode>();
+ n->outputs = std::move(outputs);
+ n->implementation = std::move(impl);
+ data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+ auto n = make_object<CCacheKeyNode>();
+ n->source_func = std::move(source_func);
+ n->target = std::move(target);
+ data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var,
tvm::Array<te::Tensor> inputs,
+ tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+ tvm::Array<Integer> shape_func_param_states, IRModule
funcs) {
+ auto n = make_object<CachedFuncNode>();
+ n->target = target;
+ n->prim_fn_var = prim_fn_var;
+ n->inputs = inputs;
+ n->outputs = outputs;
+ n->schedule = schedule;
+ n->shape_func_param_states = shape_func_param_states;
+ n->funcs = funcs;
+ data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+ // for now, we always use int32 shape when possible
+ // even if the result of shape inference becomes int64.
+ Array<IndexExpr> res;
+ for (IndexExpr val : shape) {
+ const int64_t* pval = tir::as_const_int(val);
+ if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+ ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+ ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+ res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+ res.push_back(val);
+#endif // TVM_INDEX_DEFAULT_I64
+ } else if (val->IsInstance<tir::AnyNode>()) {
+ res.push_back(val.as<tir::AnyNode>()->ToVar());
+ } else {
+ res.push_back(val);
+ }
+ }
+ return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public
backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+ explicit ScheduleGetter(Target target)
+ : target_(target), device_copy_op_(Op::Get("device_copy")) {
+ // Whether to use auto_scheduler schedule.
+ use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+ }
+
+ CachedFunc Create(const Function& prim_func,
std::function<std::string(std::string)> renamer) {
+ Array<tvm::te::Tensor> fn_inputs;
+ for (Var param : prim_func->params) {
+ Array<tvm::te::Tensor> inputs;
+ if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+ tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape),
ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ } else {
+ // flatten tuple of tensor type.
+ const auto* tuple_type = param->type_as<TupleTypeNode>();
+ for (Type field : tuple_type->fields) {
+ const auto* ttype = field.as<TensorTypeNode>();
+ // TODO(@icemelon): Allow recursive tuple
+ ICHECK(ttype != nullptr);
+ tvm::te::Tensor tensor =
tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ }
+ }
+ memo_[param] = inputs;
+ }
+ readable_name_stream_ << "fused";
+ auto outputs = this->VisitExpr(prim_func->body);
+ auto candidate_name = readable_name_stream_.str();
+ constexpr static size_t kMaxFuncNameLength = 80;
+ if (candidate_name.size() > kMaxFuncNameLength) {
+ std::stringstream truncated_name;
+ truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+ truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+ candidate_name = truncated_name.str();
+ }
+
+ ICHECK(anchor_op_.defined());
+ // Fusion over tupled results may leave identity relationships
+ // between inputs and outputs, and those should not be scheduled.
+ // Hence schedule only non PlaceholderOp outputs.
+ tvm::Array<te::Tensor> tensor_outs;
+ for (const auto& tensor : outputs) {
+ if (!tensor->op.as<te::PlaceholderOpNode>()) {
+ tensor_outs.push_back(tensor);
+ }
+ }
+
+ te::Schedule schedule;
+ // No need to register schedule for device copy op.
+ if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+ if (use_auto_scheduler_) {
+ const auto* fauto_schedule =
+
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+ ICHECK(fauto_schedule != nullptr)
+ << "auto_scheduler.relay_integration.auto_schedule_topi_compute is
not registered";
+ ObjectRef obj = (*fauto_schedule)(tensor_outs);
+ if (obj.defined()) {
+ schedule = Downcast<te::Schedule>(obj);
+ }
+ }
+
+ // Use TOPI schdule if user specificed, or the function has no
auto_scheduler schedule.
+ if (!schedule.defined()) {
+ ICHECK(anchor_implementation_.defined());
+ schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs,
target_);
+ }
+ for (const auto& scalar : scalars_) {
+ if (schedule->Contain(scalar)) {
+ schedule[scalar].compute_inline();
+ }
+ }
+ }
+
+ auto prim_fn_var = GlobalVar(candidate_name);
+ prim_fn_var->checked_type_ = prim_func->checked_type();
+ return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+ }
+
+ Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+ LOG(FATAL) << "Free variable " << op->name_hint();
+ return {};
+ }
+
+ Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
+ using tir::make_const;
+ ICHECK(op->is_scalar());
+ void* data = op->data->data;
+ DataType dtype = DataType(op->data->dtype);
+ auto value = te::compute(
+ {},
+ [&](const Array<tvm::tir::Var>&) {
+ if (dtype == DataType::Int(32)) {
+ return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+ } else if (dtype == DataType::Int(64)) {
+ return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+ } else if (dtype == DataType::Float(32)) {
+ return make_const(dtype, static_cast<const float*>(data)[0]);
+ } else if (dtype == DataType::Float(64)) {
+ return make_const(dtype, static_cast<const double*>(data)[0]);
+ } else if (dtype == DataType::Bool()) {
+ return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+ } else {
+ LOG(FATAL) << "not handled";
+ return tvm::PrimExpr();
+ }
+ },
+ "compile_engine_const", topi::kBroadcast);
+ scalars_.push_back(value->op);
+ return {value};
+ }
+
+ Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+ static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+ static auto flower_call =
tvm::runtime::Registry::Get("relay.backend.lower_call");
+ ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+ Array<te::Tensor> inputs;
+ int count_tuple = 0;
+ for (Expr arg : call_node->args) {
+ if (arg->checked_type().as<TupleTypeNode>()) {
+ ++count_tuple;
+ }
+ for (te::Tensor tensor : VisitExpr(arg)) {
+ inputs.push_back(tensor);
+ }
+ }
+ if (count_tuple) {
+ ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a
single tuple input";
+ }
+
+ ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call
into primitive ops";
+ Op op = Downcast<Op>(call_node->op);
+
+ Array<te::Tensor> outputs;
+ OpImplementation impl;
+ // Skip fcompute for device copy operators as it is not registered.
+ if (op == device_copy_op_) {
+ const auto* copy_input = inputs[0].operator->();
+ outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype,
te::Operation(), 0));
+ } else {
+ LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node),
inputs, target_);
+ outputs = lowered_out->outputs;
+ impl = lowered_out->implementation;
+ }
+
+ int op_pattern = fpattern[op];
+ if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+ ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+ << "Cannot apply TOPI schedule to a primitive function with two
complicated ops"
+ << " anchor=" << anchor_op_ << " current=" << op;
+ }
+ if (op_pattern >= anchor_op_pattern_) {
+ anchor_op_ = op;
+ anchor_attrs_ = call_node->attrs;
+ anchor_op_pattern_ = op_pattern;
+ anchor_implementation_ = impl;
+ }
+ if (outputs.size() != 1) {
+ const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
+ ICHECK(tuple_type) << "Expect output to be a tuple type";
+ ICHECK_EQ(tuple_type->fields.size(), outputs.size());
+ }
+ // Set the name to `__copy`. It will be detected in graph runtime to
perform
+ // data copy across devices.
+ if (op == device_copy_op_) {
+ readable_name_stream_.str(std::string());
+ readable_name_stream_ << "__copy";
+ } else {
+ readable_name_stream_ << '_' << op->name;
+ }
+ return outputs;
+ }
+
+ Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
+ LOG(FATAL) << "Do not support sub function";
Review comment:
Can you add to this error message what is supported?
##########
File path: tests/python/relay/test_backend_graph_runtime.py
##########
@@ -231,10 +232,12 @@ def test_graph_executor_nested_tuples():
if __name__ == "__main__":
- test_plan_memory()
- test_with_params()
- test_add_op_scalar()
test_add_op_tensor()
- test_add_op_broadcast()
- test_gru_like()
- test_compile_nested_tuples()
+ # test_plan_memory()
+ # test_with_params()
+ # test_add_op_scalar()
+ # test_add_op_tensor()
+ # test_add_op_broadcast()
+ # test_gru_like()
+ # test_compile_nested_tuples()
+ # test_add_op_tensor()
Review comment:
Uncomment these tests?
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+// Register the node types used by the TE compiler cache with TVM's node
+// (object/reflection) registry.
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation
impl) {
+ auto n = make_object<LoweredOutputNode>();
+ n->outputs = std::move(outputs);
+ n->implementation = std::move(impl);
+ data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+ auto n = make_object<CCacheKeyNode>();
+ n->source_func = std::move(source_func);
+ n->target = std::move(target);
+ data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var,
tvm::Array<te::Tensor> inputs,
+ tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+ tvm::Array<Integer> shape_func_param_states, IRModule
funcs) {
+ auto n = make_object<CachedFuncNode>();
+ n->target = target;
+ n->prim_fn_var = prim_fn_var;
+ n->inputs = inputs;
+ n->outputs = outputs;
+ n->schedule = schedule;
+ n->shape_func_param_states = shape_func_param_states;
+ n->funcs = funcs;
+ data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+ // for now, we always use int32 shape when possible
+ // even if the result of shape inference becomes int64.
+ Array<IndexExpr> res;
+ for (IndexExpr val : shape) {
+ const int64_t* pval = tir::as_const_int(val);
+ if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+ ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+ ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+ res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+ res.push_back(val);
+#endif // TVM_INDEX_DEFAULT_I64
+ } else if (val->IsInstance<tir::AnyNode>()) {
+ res.push_back(val.as<tir::AnyNode>()->ToVar());
+ } else {
+ res.push_back(val);
+ }
+ }
+ return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public
backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+ explicit ScheduleGetter(Target target)
+ : target_(target), device_copy_op_(Op::Get("device_copy")) {
+ // Whether to use auto_scheduler schedule.
+ use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+ }
+
+ CachedFunc Create(const Function& prim_func,
std::function<std::string(std::string)> renamer) {
+ Array<tvm::te::Tensor> fn_inputs;
+ for (Var param : prim_func->params) {
+ Array<tvm::te::Tensor> inputs;
+ if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+ tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape),
ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ } else {
+ // flatten tuple of tensor type.
+ const auto* tuple_type = param->type_as<TupleTypeNode>();
+ for (Type field : tuple_type->fields) {
+ const auto* ttype = field.as<TensorTypeNode>();
+ // TODO(@icemelon): Allow recursive tuple
+ ICHECK(ttype != nullptr);
+ tvm::te::Tensor tensor =
tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ }
+ }
+ memo_[param] = inputs;
+ }
+ readable_name_stream_ << "fused";
+ auto outputs = this->VisitExpr(prim_func->body);
+ auto candidate_name = readable_name_stream_.str();
+ constexpr static size_t kMaxFuncNameLength = 80;
+ if (candidate_name.size() > kMaxFuncNameLength) {
+ std::stringstream truncated_name;
+ truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+ truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+ candidate_name = truncated_name.str();
+ }
+
+ ICHECK(anchor_op_.defined());
+ // Fusion over tupled results may leave identity relationships
+ // between inputs and outputs, and those should not be scheduled.
+ // Hence schedule only non PlaceholderOp outputs.
+ tvm::Array<te::Tensor> tensor_outs;
+ for (const auto& tensor : outputs) {
+ if (!tensor->op.as<te::PlaceholderOpNode>()) {
+ tensor_outs.push_back(tensor);
+ }
+ }
+
+ te::Schedule schedule;
+ // No need to register schedule for device copy op.
+ if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+ if (use_auto_scheduler_) {
+ const auto* fauto_schedule =
+
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+ ICHECK(fauto_schedule != nullptr)
+ << "auto_scheduler.relay_integration.auto_schedule_topi_compute is
not registered";
+ ObjectRef obj = (*fauto_schedule)(tensor_outs);
+ if (obj.defined()) {
+ schedule = Downcast<te::Schedule>(obj);
+ }
+ }
+
+ // Use TOPI schdule if user specificed, or the function has no
auto_scheduler schedule.
+ if (!schedule.defined()) {
+ ICHECK(anchor_implementation_.defined());
+ schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs,
target_);
+ }
+ for (const auto& scalar : scalars_) {
+ if (schedule->Contain(scalar)) {
+ schedule[scalar].compute_inline();
+ }
+ }
+ }
+
+ auto prim_fn_var = GlobalVar(candidate_name);
+ prim_fn_var->checked_type_ = prim_func->checked_type();
+ return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+ }
+
+ Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+ LOG(FATAL) << "Free variable " << op->name_hint();
+ return {};
+ }
+
+ Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
+ using tir::make_const;
+ ICHECK(op->is_scalar());
+ void* data = op->data->data;
+ DataType dtype = DataType(op->data->dtype);
+ auto value = te::compute(
+ {},
+ [&](const Array<tvm::tir::Var>&) {
+ if (dtype == DataType::Int(32)) {
+ return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+ } else if (dtype == DataType::Int(64)) {
+ return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+ } else if (dtype == DataType::Float(32)) {
+ return make_const(dtype, static_cast<const float*>(data)[0]);
+ } else if (dtype == DataType::Float(64)) {
+ return make_const(dtype, static_cast<const double*>(data)[0]);
+ } else if (dtype == DataType::Bool()) {
+ return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+ } else {
+ LOG(FATAL) << "not handled";
+ return tvm::PrimExpr();
+ }
+ },
+ "compile_engine_const", topi::kBroadcast);
+ scalars_.push_back(value->op);
+ return {value};
+ }
+
+ Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+ static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+ static auto flower_call =
tvm::runtime::Registry::Get("relay.backend.lower_call");
+ ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+ Array<te::Tensor> inputs;
+ int count_tuple = 0;
+ for (Expr arg : call_node->args) {
+ if (arg->checked_type().as<TupleTypeNode>()) {
+ ++count_tuple;
+ }
+ for (te::Tensor tensor : VisitExpr(arg)) {
+ inputs.push_back(tensor);
+ }
+ }
+ if (count_tuple) {
+ ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a
single tuple input";
+ }
+
+ ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call
into primitive ops";
+ Op op = Downcast<Op>(call_node->op);
+
+ Array<te::Tensor> outputs;
+ OpImplementation impl;
+ // Skip fcompute for device copy operators as it is not registered.
+ if (op == device_copy_op_) {
+ const auto* copy_input = inputs[0].operator->();
+ outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype,
te::Operation(), 0));
+ } else {
+ LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node),
inputs, target_);
+ outputs = lowered_out->outputs;
+ impl = lowered_out->implementation;
+ }
+
+ int op_pattern = fpattern[op];
+ if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+ ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+ << "Cannot apply TOPI schedule to a primitive function with two
complicated ops"
+ << " anchor=" << anchor_op_ << " current=" << op;
+ }
+ if (op_pattern >= anchor_op_pattern_) {
+ anchor_op_ = op;
+ anchor_attrs_ = call_node->attrs;
+ anchor_op_pattern_ = op_pattern;
+ anchor_implementation_ = impl;
+ }
+ if (outputs.size() != 1) {
+ const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
+ ICHECK(tuple_type) << "Expect output to be a tuple type";
+ ICHECK_EQ(tuple_type->fields.size(), outputs.size());
+ }
+ // Set the name to `__copy`. It will be detected in graph runtime to
perform
+ // data copy across devices.
+ if (op == device_copy_op_) {
+ readable_name_stream_.str(std::string());
+ readable_name_stream_ << "__copy";
+ } else {
+ readable_name_stream_ << '_' << op->name;
+ }
+ return outputs;
+ }
+
+ Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
+ LOG(FATAL) << "Do not support sub function";
+ return Array<te::Tensor>();
+ }
+
+ Array<te::Tensor> VisitExpr_(const LetNode* op) final {
+ Array<te::Tensor> val = VisitExpr(op->value);
+ ICHECK(!memo_.count(op->var));
+ memo_[op->var] = val;
+ return VisitExpr(op->body);
+ }
+
+ Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
+ Array<te::Tensor> fields;
+ for (Expr field : op->fields) {
+ ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple
of Tensor";
+ Array<te::Tensor> res = VisitExpr(field);
+ ICHECK_EQ(res.size(), 1);
+ fields.push_back(res[0]);
+ }
+ return fields;
+ }
+
+ Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
+ const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
+ Array<te::Tensor> tuple = VisitExpr(op->tuple);
+ ICHECK_EQ(tuple_type->fields.size(), tuple.size());
+ ICHECK_GE(op->index, 0);
+ ICHECK_LT(static_cast<size_t>(op->index), tuple.size());
+ return {tuple[op->index]};
+ }
+
+ private:
+ tvm::Target target_;
+ Op anchor_op_;
+ Attrs anchor_attrs_;
+ int anchor_op_pattern_{0};
+ OpImplementation anchor_implementation_;
+ std::ostringstream readable_name_stream_;
+ Array<te::Operation> scalars_;
+ bool use_auto_scheduler_;
+ // Cache device copy op for equivalence checking to reduce registry lookup
+ // overhead for each invocation of call node when retrieving schedules.
+ const Op& device_copy_op_;
+};
+
+/*!
+ * \brief Create schedule for target.
+ * \param source_func The primitive function to be lowered.
+ * \param target The target we want to create schedule for.
+ * \return Pair of schedule and cache.
+ * The funcs field in cache is not yet populated.
+ */
+CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
+ std::function<std::string(std::string)> renamer) {
+ return ScheduleGetter(target).Create(source_func, renamer);
+}
+
+// Creates shape function from functor.
+class MakeShapeFunc : public
backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+ MakeShapeFunc() {}
+
+ CachedFunc Create(const Function& prim_func, const Target& target,
+ std::function<std::string(std::string)> renamer) {
+ Array<te::Tensor> inputs;
+ TShapeDataDependent shape_func_param_states;
+
+ for (auto param : prim_func->params) {
+ param_states_[param] = kNoNeed;
+ Array<tvm::te::Tensor> data_inputs;
+ Array<tvm::te::Tensor> shape_inputs;
+
+ auto add_placeholder = [&data_inputs, &shape_inputs](const
TensorTypeNode* ttype) {
+ // Add data placeholder
+ Shape shape = GetShape(ttype->shape);
+ tvm::te::Tensor data_tensor = tvm::te::placeholder(shape,
ttype->dtype);
+ data_inputs.push_back(data_tensor);
+ // Add shape placeholder
+ int64_t ndim = shape.size();
+ Shape sshape;
+ if (ndim > 0) {
+ sshape.push_back(tvm::Integer(ndim));
+ }
+ tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape,
DataType::Int(64));
+ shape_inputs.push_back(shape_tensor);
+ };
+
+ if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+ add_placeholder(ttype);
+ } else {
+ // flatten tuple of tensor type.
+ const auto* tuple_type = param->type_as<TupleTypeNode>();
+ // TODO(@icemelon): Support recursive tuple
+ ICHECK(tuple_type);
+ for (Type field : tuple_type->fields) {
+ const auto* ttype = field.as<TensorTypeNode>();
+ ICHECK(ttype);
+ add_placeholder(ttype);
+ }
+ }
+ param_data_[param] = data_inputs;
+ param_shapes_[param] = shape_inputs;
+ }
+
+ // Setup the name;
+ readable_name_stream_ << "shape_func";
+
+ // Create the `te::Tensor`s which represent the output.
+ auto outputs = VisitExpr(prim_func->body);
+
+ // Generate a name.
+ auto candidate_name = readable_name_stream_.str();
+ constexpr static size_t kMaxFuncNameLength = 80;
+ if (candidate_name.size() > kMaxFuncNameLength) {
+ std::stringstream truncated_name;
+ truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+ truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+ candidate_name = truncated_name.str();
+ }
+
+ auto func_name = renamer(candidate_name);
+
+ // Set all the inputs correctly.
+ for (auto param : prim_func->params) {
+ int state = param_states_[param];
+ shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
+ if (state & kNeedInputData) {
+ for (auto t : param_data_[param]) {
+ inputs.push_back(t);
+ }
+ }
+ if (state & kNeedInputShape) {
+ for (auto t : param_shapes_[param]) {
+ inputs.push_back(t);
+ }
+ }
+ }
+
+ auto prim_fn_gvar = GlobalVar(func_name);
+ prim_fn_gvar->checked_type_ = prim_func->checked_type();
+
+ // generate schedule for shape func
+ Array<te::Operation> out_ops;
+ for (auto t : outputs) {
+ out_ops.push_back(t->op);
+ }
+ auto schedule = te::create_schedule(out_ops);
+ tvm::te::AutoInlineInjective(schedule);
+ for (const auto& scalar : scalars_) {
+ auto scalar_op = scalar->op;
+ if (schedule->Contain(scalar_op)) {
+ schedule[scalar_op].compute_inline();
+ }
+ }
+
+ Array<te::Tensor> all_args = Array<te::Tensor>(inputs);
+ for (te::Tensor arg : outputs) {
+ all_args.push_back(arg);
+ }
+
+ using tvm::transform::PassContext;
+ With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
+ std::unordered_map<te::Tensor, tir::Buffer> binds;
+ auto ir_module = tvm::lower(schedule, all_args, func_name, binds);
+
+ return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule,
shape_func_param_states,
+ ir_module);
+ }
+
+ Array<te::Tensor> VisitExpr(const Expr& expr) final {
+ if (expr.as<VarNode>()) {
+ // Do not memoize vars because shape functions could use either the data
+ // or the shape of a var each time.
+ return ExprFunctor::VisitExpr(expr);
+ }
+ // For other case, do memoized visit
+ return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
+ }
+
+ Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
+ auto var = GetRef<Var>(var_node);
+ auto it = param_states_.find(var);
+ if (it == param_states_.end()) {
+ LOG(FATAL) << "Free variable " << var->name_hint();
Review comment:
```suggestion
LOG(FATAL) << "Unexpected free variable " << var->name_hint();
```
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation
impl) {
+ auto n = make_object<LoweredOutputNode>();
+ n->outputs = std::move(outputs);
+ n->implementation = std::move(impl);
+ data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+ auto n = make_object<CCacheKeyNode>();
+ n->source_func = std::move(source_func);
+ n->target = std::move(target);
+ data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var,
tvm::Array<te::Tensor> inputs,
+ tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+ tvm::Array<Integer> shape_func_param_states, IRModule
funcs) {
+ auto n = make_object<CachedFuncNode>();
+ n->target = target;
+ n->prim_fn_var = prim_fn_var;
+ n->inputs = inputs;
+ n->outputs = outputs;
+ n->schedule = schedule;
+ n->shape_func_param_states = shape_func_param_states;
+ n->funcs = funcs;
+ data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+ // for now, we always use int32 shape when possible
+ // even if the result of shape inference becomes int64.
+ Array<IndexExpr> res;
+ for (IndexExpr val : shape) {
+ const int64_t* pval = tir::as_const_int(val);
+ if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+ ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+ ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+ res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+ res.push_back(val);
+#endif // TVM_INDEX_DEFAULT_I64
+ } else if (val->IsInstance<tir::AnyNode>()) {
+ res.push_back(val.as<tir::AnyNode>()->ToVar());
+ } else {
+ res.push_back(val);
+ }
+ }
+ return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public
backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+ explicit ScheduleGetter(Target target)
+ : target_(target), device_copy_op_(Op::Get("device_copy")) {
+ // Whether to use auto_scheduler schedule.
+ use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+ }
+
+ CachedFunc Create(const Function& prim_func,
std::function<std::string(std::string)> renamer) {
+ Array<tvm::te::Tensor> fn_inputs;
+ for (Var param : prim_func->params) {
+ Array<tvm::te::Tensor> inputs;
+ if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+ tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape),
ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ } else {
+ // flatten tuple of tensor type.
+ const auto* tuple_type = param->type_as<TupleTypeNode>();
+ for (Type field : tuple_type->fields) {
+ const auto* ttype = field.as<TensorTypeNode>();
+ // TODO(@icemelon): Allow recursive tuple
+ ICHECK(ttype != nullptr);
+ tvm::te::Tensor tensor =
tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+ fn_inputs.push_back(tensor);
+ inputs.push_back(tensor);
+ }
+ }
+ memo_[param] = inputs;
+ }
+ readable_name_stream_ << "fused";
+ auto outputs = this->VisitExpr(prim_func->body);
+ auto candidate_name = readable_name_stream_.str();
+ constexpr static size_t kMaxFuncNameLength = 80;
+ if (candidate_name.size() > kMaxFuncNameLength) {
+ std::stringstream truncated_name;
+ truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+ truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+ candidate_name = truncated_name.str();
+ }
+
+ ICHECK(anchor_op_.defined());
+ // Fusion over tupled results may leave identity relationships
+ // between inputs and outputs, and those should not be scheduled.
+ // Hence schedule only non PlaceholderOp outputs.
+ tvm::Array<te::Tensor> tensor_outs;
+ for (const auto& tensor : outputs) {
+ if (!tensor->op.as<te::PlaceholderOpNode>()) {
+ tensor_outs.push_back(tensor);
+ }
+ }
+
+ te::Schedule schedule;
+ // No need to register schedule for device copy op.
+ if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+ if (use_auto_scheduler_) {
+ const auto* fauto_schedule =
+
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+ ICHECK(fauto_schedule != nullptr)
+ << "auto_scheduler.relay_integration.auto_schedule_topi_compute is
not registered";
+ ObjectRef obj = (*fauto_schedule)(tensor_outs);
+ if (obj.defined()) {
+ schedule = Downcast<te::Schedule>(obj);
+ }
+ }
+
+ // Use TOPI schdule if user specificed, or the function has no
auto_scheduler schedule.
+ if (!schedule.defined()) {
+ ICHECK(anchor_implementation_.defined());
+ schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs,
target_);
+ }
+ for (const auto& scalar : scalars_) {
+ if (schedule->Contain(scalar)) {
+ schedule[scalar].compute_inline();
+ }
+ }
+ }
+
+ auto prim_fn_var = GlobalVar(candidate_name);
+ prim_fn_var->checked_type_ = prim_func->checked_type();
+ return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+ }
+
+ Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+ LOG(FATAL) << "Free variable " << op->name_hint();
+ return {};
+ }
+
+ Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
+ using tir::make_const;
+ ICHECK(op->is_scalar());
+ void* data = op->data->data;
+ DataType dtype = DataType(op->data->dtype);
+ auto value = te::compute(
+ {},
+ [&](const Array<tvm::tir::Var>&) {
+ if (dtype == DataType::Int(32)) {
+ return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+ } else if (dtype == DataType::Int(64)) {
+ return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+ } else if (dtype == DataType::Float(32)) {
+ return make_const(dtype, static_cast<const float*>(data)[0]);
+ } else if (dtype == DataType::Float(64)) {
+ return make_const(dtype, static_cast<const double*>(data)[0]);
+ } else if (dtype == DataType::Bool()) {
+ return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+ } else {
+ LOG(FATAL) << "not handled";
+ return tvm::PrimExpr();
+ }
+ },
+ "compile_engine_const", topi::kBroadcast);
+ scalars_.push_back(value->op);
+ return {value};
+ }
+
+ Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+ static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+ static auto flower_call =
tvm::runtime::Registry::Get("relay.backend.lower_call");
+ ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+ Array<te::Tensor> inputs;
+ int count_tuple = 0;
+ for (Expr arg : call_node->args) {
+ if (arg->checked_type().as<TupleTypeNode>()) {
+ ++count_tuple;
+ }
+ for (te::Tensor tensor : VisitExpr(arg)) {
+ inputs.push_back(tensor);
+ }
+ }
+ if (count_tuple) {
+ ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a
single tuple input";
+ }
+
+ ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call
into primitive ops";
+ Op op = Downcast<Op>(call_node->op);
+
+ Array<te::Tensor> outputs;
+ OpImplementation impl;
+ // Skip fcompute for device copy operators as it is not registered.
+ if (op == device_copy_op_) {
+ const auto* copy_input = inputs[0].operator->();
+ outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype,
te::Operation(), 0));
+ } else {
+ LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node),
inputs, target_);
+ outputs = lowered_out->outputs;
+ impl = lowered_out->implementation;
+ }
+
+ int op_pattern = fpattern[op];
+ if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+ ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+ << "Cannot apply TOPI schedule to a primitive function with two
complicated ops"
+ << " anchor=" << anchor_op_ << " current=" << op;
+ }
+ if (op_pattern >= anchor_op_pattern_) {
+ anchor_op_ = op;
+ anchor_attrs_ = call_node->attrs;
+ anchor_op_pattern_ = op_pattern;
+ anchor_implementation_ = impl;
+ }
+ if (outputs.size() != 1) {
+ const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
+ ICHECK(tuple_type) << "Expect output to be a tuple type";
+ ICHECK_EQ(tuple_type->fields.size(), outputs.size());
+ }
+ // Set the name to `__copy`. It will be detected in graph runtime to
perform
+ // data copy across devices.
+ if (op == device_copy_op_) {
+ readable_name_stream_.str(std::string());
+ readable_name_stream_ << "__copy";
+ } else {
+ readable_name_stream_ << '_' << op->name;
+ }
+ return outputs;
+ }
+
+ Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
+ LOG(FATAL) << "Do not support sub function";
+ return Array<te::Tensor>();
+ }
+
+ Array<te::Tensor> VisitExpr_(const LetNode* op) final {
+ Array<te::Tensor> val = VisitExpr(op->value);
+ ICHECK(!memo_.count(op->var));
+ memo_[op->var] = val;
+ return VisitExpr(op->body);
+ }
+
+ Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
+ Array<te::Tensor> fields;
+ for (Expr field : op->fields) {
+ ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple
of Tensor";
Review comment:
```suggestion
ICHECK(field->checked_type().as<TensorTypeNode>()) << "Expected a
Tuple of Tensor, but got " << field->checked_type()->GetTypeName();
```
##########
File path: src/relay/ir/function.cc
##########
@@ -62,9 +62,12 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const FunctionNode*>(ref.get());
- p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body
- << ", " << node->type_params << ", " << node->attrs << ")";
+ // auto* node = static_cast<const FunctionNode*>(ref.get());
+ // p->stream << "FunctionNode(" << node->params << ", " <<
node->ret_type << ", " <<
+ // node->body
+ // << ", " << node->type_params << ", " << node->attrs << ")";
Review comment:
remove? (i.e. should this commented-out FunctionNode printer code be deleted rather than left in place?)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org