This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 1a0f44d [Refactor][std::string --> String] IR is updated with String (#5547)
1a0f44d is described below
commit 1a0f44d4c8dead73a47fbdf44e50c0b8edde5f00
Author: ANSHUMAN TRIPATHY <[email protected]>
AuthorDate: Tue May 12 00:38:04 2020 +0530
[Refactor][std::string --> String] IR is updated with String (#5547)
* [std::string --> String] GlobalTypeVar is updated with String
* [std::string --> String] GlobalVar is updated with String
* [std::string --> String][IR] ADT is updated with String
* [std::string --> String][IR] Op is updated with String
* [std::string --> String][IR] Attrs is updated with String input
* [std::string --> String][IR] GlobalVar is updated with String
* [std::string --> String][Test] py_converter is updated with String change
---
include/tvm/ir/adt.h | 2 +-
include/tvm/ir/env_func.h | 2 +-
include/tvm/ir/expr.h | 4 ++--
include/tvm/ir/op.h | 9 +++++----
include/tvm/ir/transform.h | 8 ++++----
include/tvm/ir/type.h | 4 ++--
include/tvm/runtime/container.h | 9 +++++++++
python/tvm/ir/json_compact.py | 4 ++--
python/tvm/relay/testing/py_converter.py | 4 ++--
src/ir/adt.cc | 4 ++--
src/ir/env_func.cc | 2 +-
src/ir/expr.cc | 4 ++--
src/ir/function.cc | 2 +-
src/ir/op.cc | 10 +++++-----
src/ir/transform.cc | 14 +++++++-------
src/ir/type.cc | 4 ++--
src/printer/relay_text_printer.cc | 4 +++-
17 files changed, 51 insertions(+), 39 deletions(-)
diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h
index 9d45dc1..9b45c66 100644
--- a/include/tvm/ir/adt.h
+++ b/include/tvm/ir/adt.h
@@ -91,7 +91,7 @@ class Constructor : public RelayExpr {
* \param inputs The input types.
* \param belong_to The data type var the constructor will construct.
*/
- TVM_DLL Constructor(std::string name_hint, Array<Type> inputs, GlobalTypeVar
belong_to);
+ TVM_DLL Constructor(String name_hint, Array<Type> inputs, GlobalTypeVar
belong_to);
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
};
diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h
index 320d6e3..2f80367 100644
--- a/include/tvm/ir/env_func.h
+++ b/include/tvm/ir/env_func.h
@@ -92,7 +92,7 @@ class EnvFunc : public ObjectRef {
* \return The created global function.
* \note The function can be unique
*/
- TVM_DLL static EnvFunc Get(const std::string& name);
+ TVM_DLL static EnvFunc Get(const String& name);
/*! \brief specify container node */
using ContainerType = EnvFuncNode;
};
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 717ffb1..6797f16 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -188,7 +188,7 @@ class GlobalVar;
class GlobalVarNode : public RelayExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint. */
- std::string name_hint;
+ String name_hint;
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
@@ -216,7 +216,7 @@ class GlobalVarNode : public RelayExprNode {
*/
class GlobalVar : public RelayExpr {
public:
- TVM_DLL explicit GlobalVar(std::string name_hint);
+ TVM_DLL explicit GlobalVar(String name_hint);
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};
diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h
index 7fafb5a..aeda4fa 100644
--- a/include/tvm/ir/op.h
+++ b/include/tvm/ir/op.h
@@ -185,7 +185,7 @@ class Op : public RelayExpr {
* \param op_name Name of the operator.
* \return Pointer to a Op, valid throughout program lifetime.
*/
- TVM_DLL static const Op& Get(const std::string& op_name);
+ TVM_DLL static const Op& Get(const String& op_name);
/*! \brief specify container node */
using ContainerType = OpNode;
@@ -196,13 +196,13 @@ class Op : public RelayExpr {
* \param key The attribute key
* \return reference to GenericOpMap
*/
- TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
+ TVM_DLL static const GenericOpMap& GetGenericAttr(const String& key);
/*!
* \brief Checks if the key is present in the registry
* \param key The attribute key
* \return bool True if the key is present
*/
- TVM_DLL static bool HasGenericAttr(const std::string& key);
+ TVM_DLL static bool HasGenericAttr(const String& key);
};
/*!
@@ -303,7 +303,8 @@ class OpRegistry {
// return internal pointer to op.
inline OpNode* get();
// update the attribute OpMap
- TVM_DLL void UpdateAttr(const std::string& key, runtime::TVMRetValue value,
int plevel);
+
+ TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int
plevel);
};
/*!
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 558d2da..a825b95 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -224,7 +224,7 @@ class PassInfo : public ObjectRef {
* \param name Name of the pass.
* \param required The passes that are required to perform the current pass.
*/
- TVM_DLL PassInfo(int opt_level, std::string name, Array<runtime::String>
required);
+ TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String>
required);
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
@@ -327,7 +327,7 @@ class Sequential : public Pass {
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
- TVM_DLL Sequential(Array<Pass> passes, std::string name = "sequential");
+ TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");
Sequential() = default;
explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
@@ -348,7 +348,7 @@ class Sequential : public Pass {
*/
TVM_DLL Pass
CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule,
PassContext)>& pass_func,
- int opt_level, const std::string& name, const
Array<runtime::String>& required);
+ int opt_level, const String& name, const
Array<runtime::String>& required);
/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
@@ -356,7 +356,7 @@ CreateModulePass(const
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>
* \param show_meta_data Whether should we show meta data.
* \return The pass.
*/
-TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);
+TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
} // namespace transform
} // namespace tvm
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index ed64841..65b454f 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -267,7 +267,7 @@ class GlobalTypeVarNode : public TypeNode {
* this only acts as a hint to the user,
* and is not used for equality.
*/
- std::string name_hint;
+ String name_hint;
/*! \brief The kind of type parameter */
TypeKind kind;
@@ -301,7 +301,7 @@ class GlobalTypeVar : public Type {
* \param name_hint The name of the type var.
* \param kind The kind of the type var.
*/
- TVM_DLL GlobalTypeVar(std::string name_hint, TypeKind kind);
+ TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
};
diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h
index 49c005e..e2f2453 100644
--- a/include/tvm/runtime/container.h
+++ b/include/tvm/runtime/container.h
@@ -564,6 +564,15 @@ inline String String::operator=(std::string other) {
return Downcast<String>(*this);
}
+inline String operator+(const std::string lhs, const String& rhs) {
+ return lhs + rhs.operator std::string();
+}
+
+inline std::ostream& operator<<(std::ostream& out, const String& input) {
+ out.write(input.data(), input.size());
+ return out;
+}
+
inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count,
size_t rhs_count) {
if (lhs == rhs && lhs_count == rhs_count) return 0;
diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py
index fcea9d8..a3ff499 100644
--- a/python/tvm/ir/json_compact.py
+++ b/python/tvm/ir/json_compact.py
@@ -111,7 +111,7 @@ def create_updater_06_to_07():
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
- "relay.GlobalTypeVar": _ftype_var,
+ "relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
"relay.TypeConstraint": _rename("TypeConstraint"),
@@ -122,7 +122,7 @@ def create_updater_06_to_07():
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
- "relay.GlobalVar": _rename("GlobalVar"),
+ "relay.GlobalVar": [_rename("GlobalVar"),
_update_from_std_str("name_hint")],
"relay.Pass": _rename("transform.Pass"),
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
diff --git a/python/tvm/relay/testing/py_converter.py
b/python/tvm/relay/testing/py_converter.py
index 61a04ec..89c3393 100644
--- a/python/tvm/relay/testing/py_converter.py
+++ b/python/tvm/relay/testing/py_converter.py
@@ -190,7 +190,7 @@ class PythonConverter(ExprFunctor):
if name_var is None:
func_name = self.generate_function_name('_anon_func')
if isinstance(name_var, GlobalVar):
- func_name = name_var.name_hint
+ func_name = str(name_var.name_hint)
if isinstance(name_var, Var):
func_name = self.get_var_name(name_var)
@@ -411,7 +411,7 @@ class PythonConverter(ExprFunctor):
def visit_global_var(self, gvar: Expr):
# we don't need to add numbers to global var names because
# the *names* are checked for uniqueness in the mod
- return (Name(gvar.name_hint, Load()), [])
+ return (Name(str(gvar.name_hint), Load()), [])
def visit_let(self, letexp: Expr):
diff --git a/src/ir/adt.cc b/src/ir/adt.cc
index 957905d..f0ce859 100644
--- a/src/ir/adt.cc
+++ b/src/ir/adt.cc
@@ -26,7 +26,7 @@
namespace tvm {
-Constructor::Constructor(std::string name_hint, tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
+Constructor::Constructor(String name_hint, tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
ObjectPtr<ConstructorNode> n = make_object<ConstructorNode>();
n->name_hint = std::move(name_hint);
n->inputs = std::move(inputs);
@@ -37,7 +37,7 @@ Constructor::Constructor(std::string name_hint,
tvm::Array<Type> inputs, GlobalT
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_GLOBAL("ir.Constructor")
- .set_body_typed([](std::string name_hint, tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
+ .set_body_typed([](String name_hint, tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
return Constructor(name_hint, inputs, belong_to);
});
diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc
index 7deff90..7b0d6e6 100644
--- a/src/ir/env_func.cc
+++ b/src/ir/env_func.cc
@@ -45,7 +45,7 @@ ObjectPtr<Object> CreateEnvNode(const std::string& name) {
return n;
}
-EnvFunc EnvFunc::Get(const std::string& name) { return
EnvFunc(CreateEnvNode(name)); }
+EnvFunc EnvFunc::Get(const String& name) { return
EnvFunc(CreateEnvNode(name)); }
TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get);
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 000305b..8b2656b 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -137,7 +137,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
-GlobalVar::GlobalVar(std::string name_hint) {
+GlobalVar::GlobalVar(String name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
data_ = std::move(n);
@@ -145,7 +145,7 @@ GlobalVar::GlobalVar(std::string name_hint) {
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
-TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](std::string name) {
+TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) {
return GlobalVar(name);
});
diff --git a/src/ir/function.cc b/src/ir/function.cc
index 57d62b4..c0cda70 100644
--- a/src/ir/function.cc
+++ b/src/ir/function.cc
@@ -38,7 +38,7 @@
TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { retu
TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) {
return func; });
TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
- .set_body_typed([](BaseFunc func, std::string key, ObjectRef value) ->
BaseFunc {
+ .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc
{
if (func->IsInstance<tir::PrimFuncNode>()) {
return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
} else if (func->IsInstance<relay::FunctionNode>()) {
diff --git a/src/ir/op.cc b/src/ir/op.cc
index 8f58768..3a6bcbc 100644
--- a/src/ir/op.cc
+++ b/src/ir/op.cc
@@ -61,7 +61,7 @@ struct OpManager {
};
// find operator by name
-const Op& Op::Get(const std::string& name) {
+const Op& Op::Get(const String& name) {
const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
CHECK(reg != nullptr) << "Operator " << name << " is not registered";
return reg->op();
@@ -75,7 +75,7 @@ OpRegistry::OpRegistry() {
}
// Get attribute map by key
-const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
+const GenericOpMap& Op::GetGenericAttr(const String& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
@@ -86,7 +86,7 @@ const GenericOpMap& Op::GetGenericAttr(const std::string&
key) {
}
// Check if a key is present in the registry.
-bool Op::HasGenericAttr(const std::string& key) {
+bool Op::HasGenericAttr(const String& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
@@ -110,7 +110,7 @@ void OpRegistry::reset_attr(const std::string& key) {
}
}
-void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, int
plevel) {
+void OpRegistry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
@@ -141,7 +141,7 @@
TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() {
return ret;
});
-TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](std::string name) ->
Op {
+TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op {
return Op::Get(name);
});
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index d7d9b06..59e0c1c 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -201,7 +201,7 @@ class SequentialNode : public PassNode {
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};
-PassInfo::PassInfo(int opt_level, std::string name,
tvm::Array<runtime::String> required) {
+PassInfo::PassInfo(int opt_level, String name, tvm::Array<runtime::String>
required) {
auto pass_info = make_object<PassInfoNode>();
pass_info->opt_level = opt_level;
pass_info->name = std::move(name);
@@ -238,7 +238,7 @@ Sequential::Sequential(tvm::Array<Pass> passes, PassInfo
pass_info) {
data_ = std::move(n);
}
-Sequential::Sequential(tvm::Array<Pass> passes, std::string name) {
+Sequential::Sequential(tvm::Array<Pass> passes, String name) {
auto n = make_object<SequentialNode>();
n->passes = std::move(passes);
PassInfo pass_info = PassInfo(2, std::move(name), {});
@@ -282,10 +282,10 @@ bool SequentialNode::PassEnabled(const PassInfo& info)
const {
return ctx->opt_level >= info->opt_level;
}
-Pass GetPass(const std::string& pass_name) {
+Pass GetPass(const String& pass_name) {
using tvm::runtime::Registry;
const runtime::PackedFunc* f = nullptr;
- if (pass_name.find("transform.") != std::string::npos) {
+ if (pass_name.operator std::string().find("transform.") !=
std::string::npos) {
f = Registry::Get(pass_name);
} else if ((f = Registry::Get("transform." + pass_name))) {
// pass
@@ -313,7 +313,7 @@ IRModule SequentialNode::operator()(IRModule mod, const
PassContext& pass_ctx) c
}
Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule,
PassContext)>& pass_func,
- int opt_level, const std::string& name,
+ int opt_level, const String& name,
const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return ModulePass(pass_func, pass_info);
@@ -322,7 +322,7 @@ Pass CreateModulePass(const
runtime::TypedPackedFunc<IRModule(IRModule, PassCont
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("transform.PassInfo")
- .set_body_typed([](int opt_level, std::string name,
tvm::Array<runtime::String> required) {
+ .set_body_typed([](int opt_level, String name, tvm::Array<runtime::String>
required) {
return PassInfo(opt_level, name, required);
});
@@ -439,7 +439,7 @@
TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::In
TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope);
-Pass PrintIR(std::string header, bool show_meta_data) {
+Pass PrintIR(String header, bool show_meta_data) {
auto pass_func = [header, show_meta_data](IRModule mod, const PassContext&
ctx) {
LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data);
return mod;
diff --git a/src/ir/type.cc b/src/ir/type.cc
index 212a6e5..38a6ec3 100644
--- a/src/ir/type.cc
+++ b/src/ir/type.cc
@@ -81,7 +81,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")";
});
-GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {
+GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind) {
ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
@@ -90,7 +90,7 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind)
{
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
-TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](std::string name,
int kind) {
+TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](String name, int
kind) {
return GlobalTypeVar(name, static_cast<TypeKind>(kind));
});
diff --git a/src/printer/relay_text_printer.cc
b/src/printer/relay_text_printer.cc
index 3c545ef..5166a48 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -446,7 +446,9 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}
-Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return
Doc::Text('@' + op->name_hint); }
+Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
+ return Doc::Text('@' + op->name_hint.operator std::string());
+}
Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return
Doc::Text(op->name); }