This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 38a44c8 [Target] Creating Target from JSON-like Configuration (#6218)
38a44c8 is described below
commit 38a44c82b1af6652b83fbbf9f055e9aa7c0b5ba0
Author: Junru Shao <[email protected]>
AuthorDate: Fri Aug 14 18:33:37 2020 -0700
[Target] Creating Target from JSON-like Configuration (#6218)
* [Target] Creating Target from JSON-like Configuration
* Address comments from Cody
* fix unittest
* More testcases as suggested by @comaniac
---
include/tvm/target/target.h | 46 ++++-
include/tvm/target/target_kind.h | 20 +-
src/target/target.cc | 413 ++++++++++++++++++++++++++++++++++-----
src/target/target_kind.cc | 289 ---------------------------
tests/cpp/target_test.cc | 131 ++++++++++---
5 files changed, 516 insertions(+), 383 deletions(-)
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 4a83579..258b2d8 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -31,6 +31,7 @@
#include <tvm/target/target_kind.h>
#include <string>
+#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@@ -62,6 +63,13 @@ class TargetNode : public Object {
v->Visit("attrs", &attrs);
}
+ /*!
+ * \brief Get an entry from attrs of the target
+ * \tparam TObjectRef Type of the attribute
+ * \param attr_key The name of the attribute key
+ * \param default_value The value returned if the key is not present
+ * \return The value found, or default_value (which may be NullOpt) if the key is not present
+ */
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
@@ -75,15 +83,19 @@ class TargetNode : public Object {
return default_value;
}
}
-
+ /*!
+ * \brief Get an entry from attrs of the target
+ * \tparam TObjectRef Type of the attribute
+ * \param attr_key The name of the attribute key
+ * \param default_value The value returned if the key is not present
+ * \return The value found, or default_value if the key is not present
+ */
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef
default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}
-
/*! \brief Get the keys for this target as a vector of string */
TVM_DLL std::vector<std::string> GetKeys() const;
-
/*! \brief Get the keys for this target as an unordered_set of string */
TVM_DLL std::unordered_set<std::string> GetLibs() const;
@@ -93,6 +105,26 @@ class TargetNode : public Object {
private:
/*! \brief Internal string repr. */
mutable std::string str_repr_;
+ /*!
+ * \brief Type-check a single attribute value and convert it according to its registered type information
+ * \param obj The attribute to be parsed
+ * \param info The runtime type information for parsing
+ * \return The attribute parsed
+ */
+ ObjectRef ParseAttr(const ObjectRef& obj, const
TargetKindNode::ValueTypeInfo& info) const;
+ /*!
+ * \brief Parsing TargetNode::attrs from a list of raw strings
+ * \param options The raw string of fields to be parsed
+ * \return The attributes parsed
+ */
+ Map<String, ObjectRef> ParseAttrsFromRaw(const std::vector<std::string>&
options) const;
+ /*!
+ * \brief Serialize the attributes of a target to raw string
+ * \param attrs The attributes to be converted to string
+ * \return The string converted, or NullOpt if no attribute can be serialized (e.g. attrs is empty)
+ */
+ Optional<String> StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs)
const;
+
friend class Target;
};
@@ -103,10 +135,18 @@ class TargetNode : public Object {
class Target : public ObjectRef {
public:
Target() {}
+ /*! \brief Constructor from ObjectPtr */
explicit Target(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
+ * \brief Create a Target using a JSON-like configuration
+ * \param config The JSON-like configuration
+ * \return The target created
+ */
+ TVM_DLL static Target FromConfig(const Map<String, ObjectRef>& config);
+ /*!
* \brief Create a Target given a string
* \param target_str the string to parse
+ * \return The target created
*/
TVM_DLL static Target Create(const String& target_str);
/*!
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index a661efa..e4e7c2f 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -45,9 +45,6 @@ struct ValueTypeInfoMaker;
class Target;
-/*! \brief Perform schema validation */
-TVM_DLL void TargetValidateSchema(const Map<String, ObjectRef>& config);
-
template <typename>
class TargetKindAttrMap;
@@ -67,14 +64,14 @@ class TargetKindNode : public Object {
v->Visit("default_keys", &default_keys);
}
- Map<String, ObjectRef> ParseAttrsFromRaw(const std::vector<std::string>&
options) const;
-
- Optional<String> StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs)
const;
-
static constexpr const char* _type_key = "TargetKind";
TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object);
private:
+ /*! \brief Return the index stored in attr registry */
+ uint32_t AttrRegistryIndex() const { return index_; }
+ /*! \brief Return the name stored in attr registry */
+ String AttrRegistryName() const { return name; }
/*! \brief Stores the required type_key and type_index of a specific attr of
a target */
struct ValueTypeInfo {
String type_key;
@@ -82,21 +79,14 @@ class TargetKindNode : public Object {
std::unique_ptr<ValueTypeInfo> key;
std::unique_ptr<ValueTypeInfo> val;
};
-
- uint32_t AttrRegistryIndex() const { return index_; }
- String AttrRegistryName() const { return name; }
- /*! \brief Perform schema validation */
- void ValidateSchema(const Map<String, ObjectRef>& config) const;
- /*! \brief Verify if the obj is consistent with the type info */
- void VerifyTypeInfo(const ObjectRef& obj, const
TargetKindNode::ValueTypeInfo& info) const;
/*! \brief A hash table that stores the type information of each attr of the
target key */
std::unordered_map<String, ValueTypeInfo> key2vtype_;
/*! \brief A hash table that stores the default value of each attr of the
target key */
std::unordered_map<String, ObjectRef> key2default_;
/*! \brief Index used for internal lookup of attribute registry */
uint32_t index_;
- friend void TargetValidateSchema(const Map<String, ObjectRef>&);
friend class Target;
+ friend class TargetNode;
friend class TargetKind;
template <typename, typename>
friend class AttrRegistry;
diff --git a/src/target/target.cc b/src/target/target.cc
index 94b5b03..6a24597 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -30,12 +30,201 @@
#include <algorithm>
#include <stack>
+#include "../runtime/object_internal.h"
+
namespace tvm {
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
+TVM_REGISTER_NODE_TYPE(TargetNode);
+
+/*!
+ * \brief Remove duplicated and empty keys, keeping the first occurrence of each key.
+ * \param keys The keys to be de-duplicated, in priority order
+ * \return The de-duplicated keys, in their original relative order
+ * \note Empty keys are dropped, matching the inline de-duplication this helper
+ *       replaces in Target::CreateTarget.
+ */
+static std::vector<String> DeduplicateKeys(const std::vector<String>& keys) {
+  std::vector<String> new_keys;
+  new_keys.reserve(keys.size());
+  for (const String& key : keys) {
+    // an empty key (e.g. an empty "device" entry) carries no information
+    if (key.empty()) {
+      continue;
+    }
+    if (std::find(new_keys.begin(), new_keys.end(), key) == new_keys.end()) {
+      new_keys.push_back(key);
+    }
+  }
+  return new_keys;
+}
+
+/*!
+ * \brief Strip the leading dashes of a command-line style attribute option.
+ * \param s The raw option, e.g. "-mcpu" or "--mcpu=..."
+ * \return s without its leading dashes
+ * \note Fails a CHECK if s has no leading dash or consists only of dashes.
+ */
+static inline std::string RemovePrefixDashes(const std::string& s) {
+  const size_t first_non_dash = std::min(s.find_first_not_of('-'), s.size());
+  CHECK(first_non_dash > 0 && first_non_dash < s.size())
+      << "ValueError: Not an attribute key \"" << s << "\"";
+  return s.substr(first_non_dash);
+}
+
+/*!
+ * \brief Find the position of a substring that may occur at most once.
+ * \param str The string to search in
+ * \param substr The substring to look for
+ * \return The position of substr in str, or -1 if it does not occur (or is empty)
+ * \note Fails a CHECK if substr occurs more than once.
+ */
+static inline int FindUniqueSubstr(const std::string& str, const std::string& substr) {
+  if (substr.empty()) {
+    return -1;
+  }
+  // std::string::find matches the whole substring; find_first_of would match any
+  // single character of it, which only coincides for one-character substrings.
+  size_t pos = str.find(substr);
+  if (pos == std::string::npos) {
+    return -1;
+  }
+  size_t next_pos = pos + substr.size();
+  CHECK(next_pos >= str.size() || str.find(substr, next_pos) == std::string::npos)
+      << "ValueError: At most one \"" << substr << "\" is allowed in "
+      << "the given string \"" << str << "\"";
+  return static_cast<int>(pos);
+}
+
+/*!
+ * \brief Parse a scalar attribute value from its raw string form.
+ * \param type_index The runtime type index of the expected value type
+ * \param str The raw string to parse
+ * \return The parsed Integer or String, or ObjectRef(nullptr) if parsing fails
+ *         or the type cannot be parsed from a raw string
+ */
+static inline ObjectRef ParseAtomicType(uint32_t type_index, const std::string& str) {
+  std::istringstream is(str);
+  if (type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    int v;
+    is >> v;
+    if (is.fail()) {
+      return ObjectRef(nullptr);
+    }
+    // `is >> v` alone would accept "12abc" as 12; only trailing whitespace is allowed
+    is >> std::ws;
+    return is.eof() ? Integer(v) : ObjectRef(nullptr);
+  } else if (type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    // only the first whitespace-delimited token is kept
+    std::string v;
+    is >> v;
+    return is.fail() ? ObjectRef(nullptr) : String(v);
+  }
+  return ObjectRef(nullptr);
+}
+
+/*!
+ * \brief Parse command-line style options into target attributes.
+ *
+ * Each option is "-key=value", "-key value", or a bare "-flag" (parsed as "1").
+ * Attributes that are not given receive the defaults registered for the kind.
+ * Unknown keys, repeated keys and unparseable values are fatal errors.
+ *
+ * \param options The raw option strings, e.g. those following the kind name in a target string
+ * \return The attributes parsed
+ */
+Map<String, ObjectRef> TargetNode::ParseAttrsFromRaw(
+    const std::vector<std::string>& options) const {
+  std::unordered_map<String, ObjectRef> attrs;
+  const auto& key2vtype = this->kind->key2vtype_;
+  const size_t n_options = options.size();
+  size_t cursor = 0;
+  while (cursor < n_options) {
+    const std::string& raw = options[cursor];
+    ++cursor;
+    // strip the leading dashes, then split into the key and its raw value
+    std::string s = RemovePrefixDashes(raw);
+    std::string name;
+    std::string obj;
+    int pos = FindUniqueSubstr(s, "=");
+    if (pos != -1) {
+      // case 1. --key=value
+      name = s.substr(0, pos);
+      obj = s.substr(pos + 1);
+      CHECK(!name.empty()) << "ValueError: Empty attribute key in \"" << raw << "\"";
+      CHECK(!obj.empty()) << "ValueError: Empty attribute in \"" << raw << "\"";
+    } else if (cursor < n_options && options[cursor][0] != '-') {
+      // case 2. --key value
+      name = s;
+      obj = options[cursor];
+      ++cursor;
+    } else {
+      // case 3. --boolean-key
+      name = s;
+      obj = "1";
+    }
+    // the key must be registered for this target kind
+    auto it = key2vtype.find(name);
+    if (it == key2vtype.end()) {
+      std::ostringstream os;
+      os << "AttributeError: Invalid config option, cannot recognize \'" << name
+         << "\'. Candidates are:";
+      for (const auto& kv : key2vtype) {
+        os << "\n " << kv.first;
+      }
+      LOG(FATAL) << os.str();
+    }
+    // each key may be given at most once
+    CHECK(!attrs.count(name)) << "AttributeError: key \"" << name
+                              << "\" appears more than once in the target string";
+    // raw strings can only express atomic values and comma-separated arrays of them
+    const auto& info = it->second;
+    ObjectRef parsed_obj(nullptr);
+    if (info.type_index == ArrayNode::_type_index) {
+      Array<ObjectRef> array;
+      bool ok = true;
+      std::istringstream is(obj);
+      std::string item;
+      while (ok && std::getline(is, item, ',')) {
+        ObjectRef parsed_item = ParseAtomicType(info.key->type_index, item);
+        if (parsed_item.defined()) {
+          array.push_back(parsed_item);
+        } else {
+          ok = false;
+        }
+      }
+      if (ok) {
+        parsed_obj = std::move(array);
+      }
+    } else {
+      parsed_obj = ParseAtomicType(info.type_index, obj);
+    }
+    if (!parsed_obj.defined()) {
+      LOG(FATAL) << "ValueError: Cannot parse type \"" << info.type_key << "\""
+                 << ", where attribute key is \"" << name << "\""
+                 << ", and attribute is \"" << obj << "\"";
+    }
+    attrs[name] = std::move(parsed_obj);
+  }
+  // emplace never overwrites, so only missing attributes receive their defaults
+  for (const auto& kv : this->kind->key2default_) {
+    attrs.emplace(kv.first, kv.second);
+  }
+  return attrs;
+}
+
+/*!
+ * \brief Convert an atomic attribute value back to its raw string form.
+ * \param obj The attribute value
+ * \return The String itself, the decimal form of an IntImm, or NullOpt for any other type
+ */
+static inline Optional<String> StringifyAtomicType(const ObjectRef& obj) {
+  if (const auto* str_node = obj.as<StringObj>()) {
+    return GetRef<String>(str_node);
+  }
+  if (const auto* int_node = obj.as<IntImmNode>()) {
+    return String(std::to_string(int_node->value));
+  }
+  return NullOpt;
+}
+
+/*!
+ * \brief Join strings with a single-character separator.
+ * \param array The strings to join
+ * \param separator The character placed between consecutive strings
+ * \return The joined string, or NullOpt if array is empty
+ */
+static inline Optional<String> JoinString(const std::vector<String>& array, char separator) {
+  if (array.empty()) {
+    return NullOpt;
+  }
+  std::ostringstream os;
+  bool first = true;
+  for (const String& item : array) {
+    if (!first) {
+      os << separator;
+    }
+    os << item;
+    first = false;
+  }
+  return String(os.str());
+}
+
+/*!
+ * \brief Serialize attributes as space-separated "-key=value" options.
+ *
+ * Keys are emitted in sorted order. Arrays become comma-separated lists.
+ * Values that cannot be expressed as raw strings are omitted, including arrays
+ * containing any non-atomic element and empty arrays.
+ *
+ * \param attrs The attributes to be converted to string
+ * \return The string converted, or NullOpt if no attribute can be serialized
+ */
+Optional<String> TargetNode::StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) const {
+  // sort the keys so the output does not depend on the map's iteration order
+  std::vector<String> keys;
+  for (const auto& kv : attrs) {
+    keys.push_back(kv.first);
+  }
+  std::sort(keys.begin(), keys.end());
+  std::vector<String> result;
+  for (const auto& key : keys) {
+    const ObjectRef& obj = attrs[key];
+    Optional<String> value = NullOpt;
+    if (const auto* array = obj.as<ArrayNode>()) {
+      std::vector<String> items;
+      for (const ObjectRef& item : *array) {
+        Optional<String> str = StringifyAtomicType(item);
+        if (!str.defined()) {
+          // one non-atomic element makes the whole array unserializable
+          items.clear();
+          break;
+        }
+        items.push_back(str.value());
+      }
+      value = JoinString(items, ',');
+    } else {
+      value = StringifyAtomicType(obj);
+    }
+    if (value.defined()) {
+      result.push_back("-" + key + "=" + value.value());
+    }
+  }
+  return JoinString(result, ' ');
+}
+
Target Target::CreateTarget(const std::string& name, const
std::vector<std::string>& options) {
TargetKind kind = TargetKind::Get(name);
ObjectPtr<TargetNode> target = make_object<TargetNode>();
@@ -43,7 +232,7 @@ Target Target::CreateTarget(const std::string& name, const
std::vector<std::stri
// tag is always empty
target->tag = "";
// parse attrs
- target->attrs = kind->ParseAttrsFromRaw(options);
+ target->attrs = target->ParseAttrsFromRaw(options);
String device_name = target->GetAttr<String>("device", "").value();
// set up keys
{
@@ -62,48 +251,11 @@ Target Target::CreateTarget(const std::string& name, const
std::vector<std::stri
keys.push_back(key);
}
// de-duplicate keys
- size_t new_size = 0;
- for (size_t i = 0; i < keys.size(); ++i) {
- if (keys[i] == "") {
- continue;
- }
- keys[new_size++] = keys[i];
- for (size_t j = i + 1; j < keys.size(); ++j) {
- if (keys[j] == keys[i]) {
- keys[j] = String("");
- }
- }
- }
- keys.resize(new_size);
- target->keys = std::move(keys);
+ target->keys = DeduplicateKeys(keys);
}
return Target(target);
}
-TVM_REGISTER_NODE_TYPE(TargetNode);
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const TargetNode*>(node.get());
- p->stream << op->str();
- });
-
-TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args,
TVMRetValue* ret) {
- std::string name = args[0];
- std::vector<std::string> options;
- for (int i = 1; i < args.num_args; ++i) {
- std::string arg = args[i];
- options.push_back(arg);
- }
-
- *ret = Target::CreateTarget(name, options);
-});
-
-TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args,
TVMRetValue* ret) {
- std::string target_str = args[0];
- *ret = Target::Create(target_str);
-});
-
std::vector<std::string> TargetNode::GetKeys() const {
std::vector<std::string> result;
for (auto& expr : keys) {
@@ -140,7 +292,7 @@ const std::string& TargetNode::str() const {
os << s;
}
}
- if (Optional<String> attrs_str = kind->StringifyAttrsToRaw(attrs)) {
+ if (Optional<String> attrs_str = this->StringifyAttrsToRaw(attrs)) {
os << ' ' << attrs_str.value();
}
str_repr_ = os.str();
@@ -162,6 +314,160 @@ Target Target::Create(const String& target_str) {
return CreateTarget(splits[0], {splits.begin() + 1, splits.end()});
}
+/*!
+ * \brief Type-check an attribute value from a JSON-like config and convert it.
+ *
+ * Supported registered types are int, str, Target (built from a nested dict),
+ * and Array/Map whose elements are parsed recursively with info.key / info.val.
+ *
+ * \param obj The attribute value to be parsed
+ * \param info The registered type information of the attribute
+ * \return The attribute parsed
+ */
+ObjectRef TargetNode::ParseAttr(const ObjectRef& obj,
+                                const TargetKindNode::ValueTypeInfo& info) const {
+  // Reject null up front: every branch below dereferences obj to report its type key.
+  CHECK(obj.defined()) << "Expect a value of type \"" << info.type_key << "\", but get: None";
+  if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    const auto* v = obj.as<IntImmNode>();
+    CHECK(v != nullptr) << "Expect type 'int', but get: " << obj->GetTypeKey();
+    return GetRef<Integer>(v);
+  }
+  if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    const auto* v = obj.as<StringObj>();
+    CHECK(v != nullptr) << "Expect type 'str', but get: " << obj->GetTypeKey();
+    return GetRef<String>(v);
+  }
+  if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    CHECK(obj->IsInstance<MapNode>())
+        << "Expect type 'dict' to construct Target, but get: " << obj->GetTypeKey();
+    return Target::FromConfig(Downcast<Map<String, ObjectRef>>(obj));
+  }
+  if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
+    CHECK(obj->IsInstance<ArrayNode>()) << "Expect type 'list', but get: " << obj->GetTypeKey();
+    Array<ObjectRef> array = Downcast<Array<ObjectRef>>(obj);
+    std::vector<ObjectRef> result;
+    result.reserve(array.size());
+    int i = 0;
+    for (const ObjectRef& elem : array) {
+      ++i;  // 1-based position, used in the error message
+      try {
+        result.push_back(TargetNode::ParseAttr(elem, *info.key));
+      } catch (const dmlc::Error& err) {
+        LOG(FATAL) << "Error occurred when parsing element " << i << " of the array: " << array
+                   << ". Details:\n"
+                   << err.what();
+      }
+    }
+    return Array<ObjectRef>(result);
+  }
+  if (info.type_index == MapNode::_GetOrAllocRuntimeTypeIndex()) {
+    CHECK(obj->IsInstance<MapNode>()) << "Expect type 'dict', but get: " << obj->GetTypeKey();
+    std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> result;
+    for (const auto& kv : Downcast<Map<ObjectRef, ObjectRef>>(obj)) {
+      ObjectRef key, val;
+      // Catch dmlc::Error here too, as the array branch and Target::FromConfig do, so
+      // CHECK/LOG(FATAL) failures from the nested parse also get the key/value context.
+      try {
+        key = TargetNode::ParseAttr(kv.first, *info.key);
+      } catch (const dmlc::Error& err) {
+        LOG(FATAL) << "Error occurred when parsing a key of the dict: " << kv.first
+                   << ". Details:\n"
+                   << err.what();
+      }
+      try {
+        val = TargetNode::ParseAttr(kv.second, *info.val);
+      } catch (const dmlc::Error& err) {
+        LOG(FATAL) << "Error occurred when parsing a value of the dict: " << kv.second
+                   << ". Details:\n"
+                   << err.what();
+      }
+      result[key] = val;
+    }
+    return Map<ObjectRef, ObjectRef>(result);
+  }
+  LOG(FATAL) << "Unsupported type registered: \"" << info.type_key
+             << "\", and the type given is: " << obj->GetTypeKey();
+  // unreachable: LOG(FATAL) throws; this only silences missing-return warnings
+  throw;
+}
+
+/*!
+ * \brief Create a Target from a JSON-like configuration.
+ *
+ * Handling of the config fields:
+ * - "kind" is required.
+ * - "tag" is optional.
+ * - "keys" is optional. The target keys are always assembled from the user-provided
+ *   "keys", the "device" name, and the kind's default keys, then de-duplicated.
+ * - Every other field must be an attribute registered for the kind, and is type-checked
+ *   via TargetNode::ParseAttr. "device" is also kept as a regular attribute.
+ *
+ * \param config_dict The JSON-like configuration
+ * \return The target created
+ */
+Target Target::FromConfig(const Map<String, ObjectRef>& config_dict) {
+  const String kKind = "kind";
+  const String kTag = "tag";
+  const String kKeys = "keys";
+  const String kDeviceName = "device";
+  // Type key for error messages; a null ObjectRef must not be dereferenced.
+  auto type_key_of = [](const ObjectRef& obj) -> std::string {
+    return obj.defined() ? obj->GetTypeKey() : std::string("None");
+  };
+  std::unordered_map<std::string, ObjectRef> config(config_dict.begin(), config_dict.end());
+  ObjectPtr<TargetNode> target = make_object<TargetNode>();
+  // parse "kind" (required)
+  auto kind_it = config.find(kKind);
+  if (kind_it == config.end()) {
+    LOG(FATAL) << "AttributeError: Field 'kind' is not found";
+  }
+  const auto* kind = kind_it->second.as<StringObj>();
+  CHECK(kind != nullptr) << "AttributeError: Expect type of field 'kind' is string, but get: "
+                         << type_key_of(kind_it->second);
+  target->kind = TargetKind::Get(GetRef<String>(kind));
+  config.erase(kind_it);
+  // parse "tag" (optional, defaults to empty)
+  auto tag_it = config.find(kTag);
+  if (tag_it != config.end()) {
+    const auto* tag = tag_it->second.as<StringObj>();
+    CHECK(tag != nullptr) << "AttributeError: Expect type of field 'tag' is string, but get: "
+                          << type_key_of(tag_it->second);
+    target->tag = GetRef<String>(tag);
+    config.erase(tag_it);
+  } else {
+    target->tag = "";
+  }
+  // parse "keys"
+  {
+    std::vector<String> keys;
+    // user provided keys
+    auto keys_it = config.find(kKeys);
+    if (keys_it != config.end()) {
+      const auto* cfg_keys = keys_it->second.as<ArrayNode>();
+      CHECK(cfg_keys != nullptr)
+          << "AttributeError: Expect type of field 'keys' is an Array, but get: "
+          << type_key_of(keys_it->second);
+      for (const ObjectRef& e : *cfg_keys) {
+        const auto* key = e.as<StringObj>();
+        CHECK(key != nullptr) << "AttributeError: Expect 'keys' to be an array of strings, but it "
+                                 "contains an element of type: "
+                              << type_key_of(e);
+        keys.push_back(GetRef<String>(key));
+      }
+      config.erase(keys_it);
+    }
+    // add device name (it also stays in `config`, so it is parsed as an attribute below)
+    if (config_dict.count(kDeviceName)) {
+      if (const auto* device = config_dict.at(kDeviceName).as<StringObj>()) {
+        String device_name = GetRef<String>(device);
+        if (!device_name.empty()) {
+          keys.push_back(device_name);
+        }
+      }
+    }
+    // add default keys
+    for (const auto& key : target->kind->default_keys) {
+      keys.push_back(key);
+    }
+    // de-duplicate keys
+    target->keys = DeduplicateKeys(keys);
+  }
+  // parse the remaining fields as attributes registered for this target kind
+  std::unordered_map<String, ObjectRef> attrs;
+  const auto& key2vtype = target->kind->key2vtype_;
+  for (const auto& cfg_kv : config) {
+    const String name(cfg_kv.first);
+    const ObjectRef& obj = cfg_kv.second;
+    auto vtype_it = key2vtype.find(name);
+    if (vtype_it == key2vtype.end()) {
+      std::ostringstream os;
+      os << "AttributeError: Unrecognized config option: \"" << name << "\". Candidates are:";
+      for (const auto& kv : key2vtype) {
+        os << " " << kv.first;
+      }
+      LOG(FATAL) << os.str();
+    }
+    ObjectRef val;
+    try {
+      val = target->ParseAttr(obj, vtype_it->second);
+    } catch (const dmlc::Error& e) {
+      LOG(FATAL) << "AttributeError: Error occurred in parsing the config key \"" << name
+                 << "\". Details:\n"
+                 << e.what();
+    }
+    attrs[name] = val;
+  }
+  // emplace never overwrites, so only missing attributes receive their defaults
+  for (const auto& kv : target->kind->key2default_) {
+    attrs.emplace(kv.first, kv.second);
+  }
+  target->attrs = attrs;
+  return Target(target);
+}
+
/*! \brief Entry to hold the Target context stack. */
struct TVMTargetThreadLocalEntry {
/*! \brief The current target context */
@@ -169,7 +475,7 @@ struct TVMTargetThreadLocalEntry {
};
/*! \brief Thread local store to hold the Target context stack. */
-typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry>
TVMTargetThreadLocalStore;
+using TVMTargetThreadLocalStore =
dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry>;
void Target::EnterWithScope() {
TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get();
@@ -194,20 +500,37 @@ tvm::Target Target::Current(bool allow_not_defined) {
return Target();
}
-TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body([](TVMArgs args,
TVMRetValue* ret) {
- bool allow_not_defined = args[0];
- *ret = Target::Current(allow_not_defined);
-});
class Target::Internal {
public:
static void EnterScope(Target target) { target.EnterWithScope(); }
static void ExitScope(Target target) { target.ExitWithScope(); }
};
+// Packed function: args[0] is the target kind name, every following argument is one raw option.
+TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
+  const std::string kind_name = args[0];
+  std::vector<std::string> raw_options;
+  for (int idx = 1; idx < args.num_args; ++idx) {
+    std::string option = args[idx];
+    raw_options.push_back(std::move(option));
+  }
+  *ret = Target::CreateTarget(kind_name, raw_options);
+});
+
TVM_REGISTER_GLOBAL("target.EnterTargetScope").set_body_typed(Target::Internal::EnterScope);
TVM_REGISTER_GLOBAL("target.ExitTargetScope").set_body_typed(Target::Internal::ExitScope);
+TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body_typed(Target::Current);
+
+TVM_REGISTER_GLOBAL("target.TargetFromString").set_body_typed(Target::Create);
+
+// Print a target using its string representation.
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
+      const auto* target_node = static_cast<const TargetNode*>(node.get());
+      p->stream << target_node->str();
+    });
+
namespace target {
std::vector<std::string> MergeOptions(std::vector<std::string> opts,
const std::vector<std::string>&
new_opts) {
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index e6f7c5c..3e35e5b 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -26,7 +26,6 @@
#include <algorithm>
#include "../node/attr_registry.h"
-#include "../runtime/object_internal.h"
namespace tvm {
@@ -60,294 +59,6 @@ const TargetKind& TargetKind::Get(const String&
target_kind_name) {
return reg->kind_;
}
-void TargetKindNode::VerifyTypeInfo(const ObjectRef& obj,
- const TargetKindNode::ValueTypeInfo& info)
const {
- CHECK(obj.defined()) << "Object is None";
- if (!runtime::ObjectInternal::DerivedFrom(obj.get(), info.type_index)) {
- LOG(FATAL) << "AttributeError: expect type \"" << info.type_key << "\" but
get "
- << obj->GetTypeKey();
- throw;
- }
- if (info.type_index == ArrayNode::_type_index) {
- int i = 0;
- for (const auto& e : *obj.as<ArrayNode>()) {
- try {
- VerifyTypeInfo(e, *info.key);
- } catch (const tvm::Error& e) {
- LOG(FATAL) << "The i-th element of array failed type checking, where i
= " << i
- << ", and the error is:\n"
- << e.what();
- throw;
- }
- ++i;
- }
- } else if (info.type_index == MapNode::_type_index) {
- for (const auto& kv : *obj.as<MapNode>()) {
- try {
- VerifyTypeInfo(kv.first, *info.key);
- } catch (const tvm::Error& e) {
- LOG(FATAL) << "The key of map failed type checking, where key = \"" <<
kv.first
- << "\", value = \"" << kv.second << "\", and the error
is:\n"
- << e.what();
- throw;
- }
- try {
- VerifyTypeInfo(kv.second, *info.val);
- } catch (const tvm::Error& e) {
- LOG(FATAL) << "The value of map failed type checking, where key = \""
<< kv.first
- << "\", value = \"" << kv.second << "\", and the error
is:\n"
- << e.what();
- throw;
- }
- }
- }
-}
-
-void TargetKindNode::ValidateSchema(const Map<String, ObjectRef>& config)
const {
- const String kTargetKind = "kind";
- for (const auto& kv : config) {
- const String& name = kv.first;
- const ObjectRef& obj = kv.second;
- if (name == kTargetKind) {
- CHECK(obj->IsInstance<StringObj>())
- << "AttributeError: \"kind\" is not a string, but its type is \"" <<
obj->GetTypeKey()
- << "\"";
- CHECK(Downcast<String>(obj) == this->name)
- << "AttributeError: \"kind\" = \"" << obj << "\" is inconsistent
with TargetKind \""
- << this->name << "\"";
- continue;
- }
- auto it = key2vtype_.find(name);
- if (it == key2vtype_.end()) {
- std::ostringstream os;
- os << "AttributeError: Invalid config option, cannot recognize \"" <<
name
- << "\". Candidates are:";
- for (const auto& kv : key2vtype_) {
- os << "\n " << kv.first;
- }
- LOG(FATAL) << os.str();
- throw;
- }
- const auto& info = it->second;
- try {
- VerifyTypeInfo(obj, info);
- } catch (const tvm::Error& e) {
- LOG(FATAL) << "AttributeError: Schema validation failed for TargetKind
\"" << this->name
- << "\", details:\n"
- << e.what() << "\n"
- << "The config is:\n"
- << config;
- throw;
- }
- }
-}
-
-inline String GetKind(const Map<String, ObjectRef>& target, const char* name) {
- const String kTargetKind = "kind";
- CHECK(target.count(kTargetKind))
- << "AttributeError: \"kind\" does not exist in \"" << name << "\"\n"
- << name << " = " << target;
- const ObjectRef& obj = target[kTargetKind];
- CHECK(obj->IsInstance<StringObj>()) << "AttributeError: \"kind\" is not a
string in \"" << name
- << "\", but its type is \"" <<
obj->GetTypeKey() << "\"\n"
- << name << " = \"" << target << '"';
- return Downcast<String>(obj);
-}
-
-void TargetValidateSchema(const Map<String, ObjectRef>& config) {
- try {
- const String kTargetHost = "target_host";
- Map<String, ObjectRef> target = config;
- Map<String, ObjectRef> target_host;
- String target_kind = GetKind(target, "target");
- String target_host_kind;
- if (config.count(kTargetHost)) {
- target.erase(kTargetHost);
- target_host = Downcast<Map<String, ObjectRef>>(config[kTargetHost]);
- target_host_kind = GetKind(target_host, "target_host");
- }
- TargetKind::Get(target_kind)->ValidateSchema(target);
- if (!target_host.empty()) {
- TargetKind::Get(target_host_kind)->ValidateSchema(target_host);
- }
- } catch (const tvm::Error& e) {
- LOG(FATAL) << "AttributeError: schedule validation fails:\n"
- << e.what() << "\nThe configuration is:\n"
- << config;
- }
-}
-
-static inline size_t CountNumPrefixDashes(const std::string& s) {
- size_t i = 0;
- for (; i < s.length() && s[i] == '-'; ++i) {
- }
- return i;
-}
-
-static inline int FindUniqueSubstr(const std::string& str, const std::string&
substr) {
- size_t pos = str.find_first_of(substr);
- if (pos == std::string::npos) {
- return -1;
- }
- size_t next_pos = pos + substr.size();
- CHECK(next_pos >= str.size() || str.find_first_of(substr, next_pos) ==
std::string::npos)
- << "ValueError: At most one \"" << substr << "\" is allowed in "
- << "the the given string \"" << str << "\"";
- return pos;
-}
-
-static inline ObjectRef ParseScalar(uint32_t type_index, const std::string&
str) {
- std::istringstream is(str);
- if (type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
- int v;
- is >> v;
- return is.fail() ? ObjectRef(nullptr) : Integer(v);
- } else if (type_index ==
String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
- std::string v;
- is >> v;
- return is.fail() ? ObjectRef(nullptr) : String(v);
- }
- return ObjectRef(nullptr);
-}
-
-static inline Optional<String> StringifyScalar(const ObjectRef& obj) {
- if (const auto* p = obj.as<IntImmNode>()) {
- return String(std::to_string(p->value));
- }
- if (const auto* p = obj.as<StringObj>()) {
- return GetRef<String>(p);
- }
- return NullOpt;
-}
-
-static inline Optional<String> Join(const std::vector<String>& array, char
separator) {
- if (array.empty()) {
- return NullOpt;
- }
- std::ostringstream os;
- os << array[0];
- for (size_t i = 1; i < array.size(); ++i) {
- os << separator << array[i];
- }
- return String(os.str());
-}
-
-Map<String, ObjectRef> TargetKindNode::ParseAttrsFromRaw(
- const std::vector<std::string>& options) const {
- std::unordered_map<String, ObjectRef> attrs;
- for (size_t iter = 0, end = options.size(); iter < end;) {
- std::string s = options[iter++];
- // remove the prefix dashes
- size_t n_dashes = CountNumPrefixDashes(s);
- CHECK(0 < n_dashes && n_dashes < s.size())
- << "ValueError: Not an attribute key \"" << s << "\"";
- s = s.substr(n_dashes);
- // parse name-obj pair
- std::string name;
- std::string obj;
- int pos;
- if ((pos = FindUniqueSubstr(s, "=")) != -1) {
- // case 1. --key=value
- name = s.substr(0, pos);
- obj = s.substr(pos + 1);
- CHECK(!name.empty()) << "ValueError: Empty attribute key in \"" <<
options[iter - 1] << "\"";
- CHECK(!obj.empty()) << "ValueError: Empty attribute in \"" <<
options[iter - 1] << "\"";
- } else if (iter < end && options[iter][0] != '-') {
- // case 2. --key value
- name = s;
- obj = options[iter++];
- } else {
- // case 3. --boolean-key
- name = s;
- obj = "1";
- }
- // check if `name` is invalid
- auto it = key2vtype_.find(name);
- if (it == key2vtype_.end()) {
- std::ostringstream os;
- os << "AttributeError: Invalid config option, cannot recognize \'" <<
name
- << "\'. Candidates are:";
- for (const auto& kv : key2vtype_) {
- os << "\n " << kv.first;
- }
- LOG(FATAL) << os.str();
- }
- // check if `name` has been set once
- CHECK(!attrs.count(name)) << "AttributeError: key \"" << name
- << "\" appears more than once in the target
string";
- // then `name` is valid, let's parse them
- // only several types are supported when parsing raw string
- const auto& info = it->second;
- ObjectRef parsed_obj(nullptr);
- if (info.type_index != ArrayNode::_type_index) {
- parsed_obj = ParseScalar(info.type_index, obj);
- } else {
- Array<ObjectRef> array;
- std::string item;
- bool failed = false;
- uint32_t type_index = info.key->type_index;
- for (std::istringstream is(obj); std::getline(is, item, ',');) {
- ObjectRef parsed_obj = ParseScalar(type_index, item);
- if (parsed_obj.defined()) {
- array.push_back(parsed_obj);
- } else {
- failed = true;
- break;
- }
- }
- if (!failed) {
- parsed_obj = std::move(array);
- }
- }
- if (!parsed_obj.defined()) {
- LOG(FATAL) << "ValueError: Cannot parse type \"" << info.type_key << "\""
- << ", where attribute key is \"" << name << "\""
- << ", and attribute is \"" << obj << "\"";
- }
- attrs[name] = std::move(parsed_obj);
- }
- // set default attribute values if they do not exist
- for (const auto& kv : key2default_) {
- if (!attrs.count(kv.first)) {
- attrs[kv.first] = kv.second;
- }
- }
- return attrs;
-}
-
-Optional<String> TargetKindNode::StringifyAttrsToRaw(const Map<String,
ObjectRef>& attrs) const {
- std::ostringstream os;
- std::vector<String> keys;
- for (const auto& kv : attrs) {
- keys.push_back(kv.first);
- }
- std::sort(keys.begin(), keys.end());
- std::vector<String> result;
- for (const auto& key : keys) {
- const ObjectRef& obj = attrs[key];
- Optional<String> value = NullOpt;
- if (const auto* array = obj.as<ArrayNode>()) {
- std::vector<String> items;
- for (const ObjectRef& item : *array) {
- Optional<String> str = StringifyScalar(item);
- if (str.defined()) {
- items.push_back(str.value());
- } else {
- items.clear();
- break;
- }
- }
- value = Join(items, ',');
- } else {
- value = StringifyScalar(obj);
- }
- if (value.defined()) {
- result.push_back("-" + key + "=" + value.value());
- }
- }
- return Join(result, ' ');
-}
-
// TODO(@junrushao1994): remove some redundant attributes
TVM_REGISTER_TARGET_KIND("llvm")
diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc
index 8bee707..e8748e6 100644
--- a/tests/cpp/target_test.cc
+++ b/tests/cpp/target_test.cc
@@ -19,7 +19,7 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
-#include <tvm/target/target_kind.h>
+#include <tvm/target/target.h>
#include <cmath>
#include <string>
@@ -39,48 +39,117 @@ TEST(TargetKind, GetAttrMap) {
CHECK_EQ(result, "Value1");
}
-TEST(TargetKind, SchemaValidation) {
- tvm::Map<String, ObjectRef> target;
- {
- tvm::Array<String> your_names{"junru", "jian"};
- tvm::Map<String, Integer> her_maps{
- {"a", 1},
- {"b", 2},
- };
- target.Set("my_bool", Bool(true));
- target.Set("your_names", your_names);
- target.Set("her_maps", her_maps);
- target.Set("kind", String("TestTargetKind"));
+TEST(TargetCreation, NestedConfig) {
+  const Map<String, ObjectRef> config = {
+      {"my_bool", Bool(true)},
+      {"your_names", Array<String>{"junru", "jian"}},
+      {"kind", String("TestTargetKind")},
+      {"her_maps", Map<String, Integer>{{"a", 1}, {"b", 2}}},
+  };
+  const Target target = Target::FromConfig(config);
+  // fields that are not stored in attrs
+  CHECK_EQ(target->kind, TargetKind::Get("TestTargetKind"));
+  CHECK_EQ(target->tag, "");
+  CHECK(target->keys.empty());
+  // scalar attribute
+  const Bool my_bool = target->GetAttr<Bool>("my_bool").value();
+  CHECK_EQ(my_bool.operator bool(), true);
+  // array attribute
+  const Array<String> your_names = target->GetAttr<Array<String>>("your_names").value();
+  CHECK_EQ(your_names.size(), 2U);
+  CHECK_EQ(your_names[0], "junru");
+  CHECK_EQ(your_names[1], "jian");
+  // map attribute
+  const Map<String, Integer> her_maps = target->GetAttr<Map<String, Integer>>("her_maps").value();
+  CHECK_EQ(her_maps.size(), 2U);
+  CHECK_EQ(her_maps["a"], 1);
+  CHECK_EQ(her_maps["b"], 2);
+}
+
+TEST(TargetCreationFail, UnrecognizedConfigOption) {
+  // "bad" is not an attribute of TestTargetKind, so creation must fail
+  Map<String, ObjectRef> config = {
+      {"my_bool", Bool(true)},
+      {"your_names", Array<String>{"junru", "jian"}},
+      {"kind", String("TestTargetKind")},
+      {"bad", ObjectRef(nullptr)},
+      {"her_maps", Map<String, Integer>{{"a", 1}, {"b", 2}}},
+  };
+  bool failed = false;
+  try {
+    Target::FromConfig(config);
+  } catch (...) {
+    failed = true;
}
- TargetValidateSchema(target);
- tvm::Map<String, ObjectRef> target_host(target.begin(), target.end());
- target.Set("target_host", target_host);
- TargetValidateSchema(target);
+  ASSERT_EQ(failed, true);
}
-TEST(TargetKind, SchemaValidationFail) {
- tvm::Map<String, ObjectRef> target;
- {
- tvm::Array<String> your_names{"junru", "jian"};
- tvm::Map<String, Integer> her_maps{
- {"a", 1},
- {"b", 2},
- };
- target.Set("my_bool", Bool(true));
- target.Set("your_names", your_names);
- target.Set("her_maps", her_maps);
- target.Set("ok", ObjectRef(nullptr));
- target.Set("kind", String("TestTargetKind"));
+TEST(TargetCreationFail, TypeMismatch) {
+  // "my_bool" is given as a String instead of a boolean, so creation must fail
+  Map<String, ObjectRef> config = {
+      {"my_bool", String("true")},
+      {"your_names", Array<String>{"junru", "jian"}},
+      {"kind", String("TestTargetKind")},
+      {"her_maps", Map<String, Integer>{{"a", 1}, {"b", 2}}},
+  };
+  bool failed = false;
+  try {
+    Target::FromConfig(config);
+  } catch (...) {
+    failed = true;
}
+  ASSERT_EQ(failed, true);
+}
+
+TEST(TargetCreationFail, TargetKindNotFound) {
+  // the required "kind" field is missing, so creation must fail
+  Map<String, ObjectRef> config = {
+      // Bool(true), not Bool("true"): a string literal only reaches Bool's constructor through
+      // pointer-to-bool conversion, so any literal (even "false") would yield true
+      {"my_bool", Bool(true)},
+      {"your_names", Array<String>{"junru", "jian"}},
+      {"her_maps", Map<String, Integer>{{"a", 1}, {"b", 2}}},
+  };
bool failed = false;
try {
- TargetValidateSchema(target);
+    Target::FromConfig(config);
} catch (...) {
failed = true;
}
ASSERT_EQ(failed, true);
}
+TEST(TargetCreation, DeduplicateKeys) {
+  // "arm_cpu" arrives both via "keys" and via "device"; duplicates must be merged
+  const Map<String, ObjectRef> config = {
+      {"kind", String("llvm")},
+      {"keys", Array<String>{"cpu", "arm_cpu"}},
+      {"device", String("arm_cpu")},
+  };
+  const Target target = Target::FromConfig(config);
+  CHECK_EQ(target->kind, TargetKind::Get("llvm"));
+  CHECK_EQ(target->tag, "");
+  // keys keep their first-occurrence order
+  CHECK_EQ(target->keys.size(), 2U);
+  CHECK_EQ(target->keys[0], "cpu");
+  CHECK_EQ(target->keys[1], "arm_cpu");
+  // "device" is also kept as a regular attribute
+  CHECK_EQ(target->attrs.size(), 1U);
+  CHECK_EQ(target->GetAttr<String>("device"), "arm_cpu");
+}
+
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";