This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2c61afa Switch from CompileEngine to TECompiler in Interpreter (#8486)
2c61afa is described below
commit 2c61afa4d1f67c6b41467e94868f6813feb89d6c
Author: mbs-octoml <[email protected]>
AuthorDate: Fri Jul 16 23:57:18 2021 -0700
Switch from CompileEngine to TECompiler in Interpreter (#8486)
This continues on:
https://discuss.tvm.apache.org/t/rfc-relay-tecompiler-rewrite-existing-compile-engine-to-match-updated-compiler-flow/9233
and #751, this time just replacing CompileEngine with TECompiler in the
Interpreter, using the JIT helper that was added to ease the transition.
Some whitespace improvements are included as well.
---
include/tvm/relay/function.h | 2 +-
include/tvm/runtime/device_api.h | 17 +++++++-------
python/tvm/relay/backend/graph_executor_codegen.py | 6 ++---
src/parser/source_map.cc | 1 -
src/relay/backend/interpreter.cc | 27 +++++++++++-----------
src/relay/backend/te_compiler.cc | 2 +-
src/relay/backend/te_compiler.h | 11 +++++----
7 files changed, 32 insertions(+), 34 deletions(-)
diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index 95eaad0..fccd1f9 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -126,7 +126,7 @@ namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
- * \brief Indicate the compiler that should be used for builing this function.
+ * \brief Indicate the compiler that should be used for building this function.
* When this is unset or set to "default", the default compilation pipeline
will be used.
*/
constexpr const char* kCompiler = "Compiler";
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 58b9ff1..7118857 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -291,7 +291,14 @@ inline Device RemoveRPCSessionMask(Device dev) {
return dev;
}
-inline std::ostream& operator<<(std::ostream& os, DLDevice dev);
+inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*)
+ if (tvm::runtime::IsRPCSessionDevice(dev)) {
+ os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-";
+ dev = tvm::runtime::RemoveRPCSessionMask(dev);
+ }
+ os << tvm::runtime::DeviceName(static_cast<int>(dev.device_type)) << "(" <<
dev.device_id << ")";
+ return os;
+}
/*!
* \brief Add a RPC session mask to a Device.
@@ -308,14 +315,6 @@ inline Device AddRPCSessionMask(Device dev, int
session_table_index) {
return dev;
}
-inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*)
- if (IsRPCSessionDevice(dev)) {
- os << "remote[" << GetRPCSessionIndex(dev) << "]-";
- dev = RemoveRPCSessionMask(dev);
- }
- os << runtime::DeviceName(static_cast<int>(dev.device_type)) << "(" <<
dev.device_id << ")";
- return os;
-}
} // namespace runtime
} // namespace tvm
diff --git a/python/tvm/relay/backend/graph_executor_codegen.py
b/python/tvm/relay/backend/graph_executor_codegen.py
index 11274b9..58717a0 100644
--- a/python/tvm/relay/backend/graph_executor_codegen.py
+++ b/python/tvm/relay/backend/graph_executor_codegen.py
@@ -20,14 +20,14 @@ A compiler from a Relay expression to TVM's graph executor.
The compiler is built from a few pieces.
First we define a compiler from a single Relay expression to the
-graph langauge. We require the expression to be a function.
+graph language. We require the expression to be a function.
The function's parameters correspond to the placeholder/inputs
and model parameters found in the computation graph representation.
The body of the function represents the computation graph.
The compiler's output is a program in the graph language, which is composed of
-graph langauge is composed of Node, NodeRef, InputNode, OpNode.
-This "little language" represents programs in TVM's graph format.
+Node, NodeRef, InputNode, OpNode. This "little language" represents programs in
+TVM's graph format.
To connect to the graph executor, we use a printer that converts our graph
format
into TVM's JSON format. The resulting string can be loaded by
diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc
index 7340f69..4e79d0e 100644
--- a/src/parser/source_map.cc
+++ b/src/parser/source_map.cc
@@ -40,7 +40,6 @@ Source::Source(SourceName src_name, std::string source) {
// NB(@jroesch):
std::string source_str = n->source;
for (auto c : source_str) {
- DLOG(INFO) << "char=" << c;
if (c == '\n') {
// Record the length of the line.
n->line_map.back().second = length;
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 53985c7..6ebb17e 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -34,6 +34,7 @@
#include "../transforms/pass_utils.h"
#include "compile_engine.h"
+#include "te_compiler.h"
namespace tvm {
namespace relay {
@@ -214,9 +215,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)>
{
public:
Interpreter(IRModule mod, Device device, Target target)
- : mod_(mod), device_(device), target_(target),
debug_op_(Op::Get("debug")) {
- engine_ = CompileEngine::Global();
- }
+ : mod_(mod), device_(device), target_(target),
debug_op_(Op::Get("debug")) {}
template <typename T>
T WithFrame(const Frame& fr, const std::function<T()>& f) {
@@ -286,7 +285,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
Array<Shape> ComputeDynamicShape(const Function& func, const
Array<ObjectRef>& args) {
CCacheKey key(func, Target("llvm"));
- auto cfunc = engine_->LowerShapeFunc(key);
+ auto cfunc = compiler_->LowerShapeFunc(key);
size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
std::vector<TVMValue> values(arity);
@@ -485,7 +484,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
out_shapes = ComputeDynamicShape(func, args);
}
- PackedFunc packed_func = engine_->JIT(CCacheKey(func, target_));
+ PackedFunc packed_func = compiler_->JIT(CCacheKey(func, target_));
TVMRetValue rv;
if (const TupleTypeNode* rtype =
func->body->checked_type().as<TupleTypeNode>()) {
ICHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
@@ -555,11 +554,11 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
// We should not find operators after running fusion,
// and operator lowering.
//
- // We have some functions cotaining chunks of operators
+ // We have some functions containing chunks of operators
// which will be loaded into operator map.
if (const auto* op_node = call->op.as<OpNode>()) {
LOG(FATAL) << "found " << op_node->name
- << "; operators should be removed by future passes; try "
+ << "; operators should have been removed by previous passes;
try "
"fusing and lowering";
}
if (auto con = call->op.as<ConstructorNode>()) {
@@ -569,9 +568,9 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
ObjectRef fn_val = Eval(call->op);
if (const InterpreterClosureObj* closure_node =
fn_val.as<InterpreterClosureObj>()) {
auto closure = GetRef<InterpreterClosure>(closure_node);
- return this->Invoke(closure, args);
+ return Invoke(closure, args);
} else if (const RecClosureObj* closure_node = fn_val.as<RecClosureObj>())
{
- return this->Invoke(closure_node->clos, args, closure_node->bind);
+ return Invoke(closure_node->clos, args, closure_node->bind);
} else {
LOG(FATAL) << "internal error: type error, expected function value in
the call "
<< "position";
@@ -710,17 +709,17 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
Target target_;
// Object stack.
Stack stack_;
- // Backend compile engine.
- CompileEngine engine_;
+ // TE-to-TIR lowerer (compiler).
+ TECompiler compiler_;
// Cache ops that need to be frequently used later to reduce lookup overhead.
const Op& debug_op_;
};
TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, Device
device, Target target) {
if (mod.defined()) {
- // eta expand to support constructors in argument position
- transform::Sequential seq({transform::EtaExpand(
- /* expand_constructor */ true, /*
expand_global_var */ false),
+ transform::Sequential seq({// eta expand to support constructors in
argument position
+ transform::EtaExpand(
+ /*expand_constructor=*/true,
/*expand_global_var=*/false),
transform::InferType()});
transform::PassContext pass_ctx = transform::PassContext::Current();
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 93b9c6f..9a0d2a2 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -444,7 +444,7 @@ class LowerTensorExpr : public ExprMutator {
tir_call_attrs->metadata.Set("relay_attrs", func->attrs);
Expr ret_call = Call(lowered_func->prim_fn_var, args,
Attrs(tir_call_attrs));
- return ret_call;
+ return std::move(ret_call);
}
IRModule module_;
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index a32eefb..8376b99 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -76,7 +76,7 @@ using ProcessFn = std::function<void(Function)>;
/*!
* \brief A compiler which lowers primitive Relay functions to tensor
expressions
- * and schdules them into TIR functions.
+ * and schedules them into TIR functions.
*/
class TECompilerNode : public Object {
public:
@@ -178,10 +178,11 @@ Target GetTargetFromInteger(DLDeviceType dev_type,
TargetMap targets);
* This is the "back half" of the Relay compiler which lowers "primitive
functions"
* to TE expressions, schedules them, and then to TIR.
*
- * /param module The IRModule.
- * /param targets The mapping for devices to targets.
- * /param device_map An analysis result mapping each sub-expression to a
device.
- * /return The lowered module, see above.
+ * \param compiler The TE-to-TIR compiler (which caches lowered functions)
+ * \param module The IRModule.
+ * \param targets The mapping for devices to targets.
+ * \param device_map An analysis result mapping each sub-expression to a
device.
+ * \return The lowered module, see above.
*/
// TODO(@electriclilies): Not sure if this default initialization is correct...
LoweredModule LowerTE(