This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 38a44c8  [Target] Creating Target from JSON-like Configuration (#6218)
38a44c8 is described below

commit 38a44c82b1af6652b83fbbf9f055e9aa7c0b5ba0
Author: Junru Shao <[email protected]>
AuthorDate: Fri Aug 14 18:33:37 2020 -0700

    [Target] Creating Target from JSON-like Configuration (#6218)
    
    * [Target] Creating Target from JSON-like Configuration
    
    * Address comments from Cody
    
    * fix unittest
    
    * More testcases as suggested by @comaniac
---
 include/tvm/target/target.h      |  46 ++++-
 include/tvm/target/target_kind.h |  20 +-
 src/target/target.cc             | 413 ++++++++++++++++++++++++++++++++++-----
 src/target/target_kind.cc        | 289 ---------------------------
 tests/cpp/target_test.cc         | 131 ++++++++++---
 5 files changed, 516 insertions(+), 383 deletions(-)

diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 4a83579..258b2d8 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -31,6 +31,7 @@
 #include <tvm/target/target_kind.h>
 
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <utility>
 #include <vector>
@@ -62,6 +63,13 @@ class TargetNode : public Object {
     v->Visit("attrs", &attrs);
   }
 
+  /*!
+   * \brief Get an entry from attrs of the target
+   * \tparam TObjectRef Type of the attribute
+   * \param attr_key The name of the attribute key
+   * \param default_value The value returned if the key is not present
+   * \return An optional, NullOpt if not found, otherwise the value found
+   */
   template <typename TObjectRef>
   Optional<TObjectRef> GetAttr(
       const std::string& attr_key,
@@ -75,15 +83,19 @@ class TargetNode : public Object {
       return default_value;
     }
   }
-
+  /*!
+   * \brief Get an entry from attrs of the target
+   * \tparam TObjectRef Type of the attribute
+   * \param attr_key The name of the attribute key
+   * \param default_value The value returned if the key is not present
+   * \return An optional, NullOpt if not found, otherwise the value found
+   */
   template <typename TObjectRef>
   Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef 
default_value) const {
     return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
   }
-
   /*! \brief Get the keys for this target as a vector of string */
   TVM_DLL std::vector<std::string> GetKeys() const;
-
   /*! \brief Get the keys for this target as an unordered_set of string */
   TVM_DLL std::unordered_set<std::string> GetLibs() const;
 
@@ -93,6 +105,26 @@ class TargetNode : public Object {
  private:
   /*! \brief Internal string repr. */
   mutable std::string str_repr_;
+  /*!
+   * \brief Parse a single attribute value from a JSON-like config object
+   * \param obj The attribute to be parsed
+   * \param info The expected type (key/val type info is used for containers)
+   * \return The parsed attribute; fails via CHECK/LOG(FATAL) on type mismatch
+   */
+  ObjectRef ParseAttr(const ObjectRef& obj, const 
TargetKindNode::ValueTypeInfo& info) const;
+  /*!
+   * \brief Parsing TargetNode::attrs from a list of raw strings
+   * \param options The raw string of fields to be parsed
+   * \return The attributes parsed
+   */
+  Map<String, ObjectRef> ParseAttrsFromRaw(const std::vector<std::string>& 
options) const;
+  /*!
+   * \brief Serialize the attributes of a target to raw string
+   * \param attrs The attributes to be converted to string
+   * \return The string converted, NullOpt if attrs is empty
+   */
+  Optional<String> StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) 
const;
+
   friend class Target;
 };
 
@@ -103,10 +135,18 @@ class TargetNode : public Object {
 class Target : public ObjectRef {
  public:
   Target() {}
+  /*! \brief Constructor from ObjectPtr */
   explicit Target(ObjectPtr<Object> n) : ObjectRef(n) {}
   /*!
+   * \brief Create a Target using a JSON-like configuration
+   * \param config The JSON-like configuration
+   * \return The target created
+   */
+  TVM_DLL static Target FromConfig(const Map<String, ObjectRef>& config);
+  /*!
    * \brief Create a Target given a string
    * \param target_str the string to parse
+   * \return The target created
    */
   TVM_DLL static Target Create(const String& target_str);
   /*!
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index a661efa..e4e7c2f 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -45,9 +45,6 @@ struct ValueTypeInfoMaker;
 
 class Target;
 
-/*! \brief Perform schema validation */
-TVM_DLL void TargetValidateSchema(const Map<String, ObjectRef>& config);
-
 template <typename>
 class TargetKindAttrMap;
 
@@ -67,14 +64,14 @@ class TargetKindNode : public Object {
     v->Visit("default_keys", &default_keys);
   }
 
-  Map<String, ObjectRef> ParseAttrsFromRaw(const std::vector<std::string>& 
options) const;
-
-  Optional<String> StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) 
const;
-
   static constexpr const char* _type_key = "TargetKind";
   TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object);
 
  private:
+  /*! \brief Return the index stored in attr registry */
+  uint32_t AttrRegistryIndex() const { return index_; }
+  /*! \brief Return the name stored in attr registry */
+  String AttrRegistryName() const { return name; }
   /*! \brief Stores the required type_key and type_index of a specific attr of 
a target */
   struct ValueTypeInfo {
     String type_key;
@@ -82,21 +79,14 @@ class TargetKindNode : public Object {
     std::unique_ptr<ValueTypeInfo> key;
     std::unique_ptr<ValueTypeInfo> val;
   };
-
-  uint32_t AttrRegistryIndex() const { return index_; }
-  String AttrRegistryName() const { return name; }
-  /*! \brief Perform schema validation */
-  void ValidateSchema(const Map<String, ObjectRef>& config) const;
-  /*! \brief Verify if the obj is consistent with the type info */
-  void VerifyTypeInfo(const ObjectRef& obj, const 
TargetKindNode::ValueTypeInfo& info) const;
   /*! \brief A hash table that stores the type information of each attr of the 
target key */
   std::unordered_map<String, ValueTypeInfo> key2vtype_;
   /*! \brief A hash table that stores the default value of each attr of the 
target key */
   std::unordered_map<String, ObjectRef> key2default_;
   /*! \brief Index used for internal lookup of attribute registry */
   uint32_t index_;
-  friend void TargetValidateSchema(const Map<String, ObjectRef>&);
   friend class Target;
+  friend class TargetNode;
   friend class TargetKind;
   template <typename, typename>
   friend class AttrRegistry;
diff --git a/src/target/target.cc b/src/target/target.cc
index 94b5b03..6a24597 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -30,12 +30,201 @@
 #include <algorithm>
 #include <stack>
 
+#include "../runtime/object_internal.h"
+
 namespace tvm {
 
 using runtime::PackedFunc;
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
 
+TVM_REGISTER_NODE_TYPE(TargetNode);
+
+static std::vector<String> DeduplicateKeys(const std::vector<String>& keys) {
+  std::vector<String> new_keys;
+  for (size_t i = 0; i < keys.size(); ++i) {
+    bool found = false;
+    for (size_t j = 0; j < i; ++j) {
+      if (keys[i] == keys[j]) {
+        found = true;
+        break;
+      }
+    }
+    if (!found) {
+      new_keys.push_back(keys[i]);
+    }
+  }
+  return new_keys;
+}
+
+static inline std::string RemovePrefixDashes(const std::string& s) {
+  size_t n_dashes = 0;
+  for (; n_dashes < s.length() && s[n_dashes] == '-'; ++n_dashes) {
+  }
+  CHECK(0 < n_dashes && n_dashes < s.size()) << "ValueError: Not an attribute 
key \"" << s << "\"";
+  return s.substr(n_dashes);
+}
+
+static inline int FindUniqueSubstr(const std::string& str, const std::string& substr) {
+  size_t pos = str.find(substr);  // whole-substring match (find_first_of matched any char)
+  if (pos == std::string::npos) {
+    return -1;
+  }
+  size_t next_pos = pos + substr.size();
+  CHECK(next_pos >= str.size() || str.find(substr, next_pos) == std::string::npos)
+      << "ValueError: At most one \"" << substr << "\" is allowed in "
+      << "the given string \"" << str << "\"";
+  return pos;
+}
+
+static inline ObjectRef ParseAtomicType(uint32_t type_index, const 
std::string& str) {
+  std::istringstream is(str);
+  if (type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    int v;
+    is >> v;
+    return is.fail() ? ObjectRef(nullptr) : Integer(v);
+  } else if (type_index == 
String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    std::string v;
+    is >> v;
+    return is.fail() ? ObjectRef(nullptr) : String(v);
+  }
+  return ObjectRef(nullptr);
+}
+
+Map<String, ObjectRef> TargetNode::ParseAttrsFromRaw(
+    const std::vector<std::string>& options) const {
+  std::unordered_map<String, ObjectRef> attrs;
+  for (size_t iter = 0, end = options.size(); iter < end;) {
+    // remove the prefix dashes
+    std::string s = RemovePrefixDashes(options[iter++]);
+    // parse name-obj pair
+    std::string name;
+    std::string obj;
+    int pos;
+    if ((pos = FindUniqueSubstr(s, "=")) != -1) {
+      // case 1. --key=value
+      name = s.substr(0, pos);
+      obj = s.substr(pos + 1);
+      CHECK(!name.empty()) << "ValueError: Empty attribute key in \"" << 
options[iter - 1] << "\"";
+      CHECK(!obj.empty()) << "ValueError: Empty attribute in \"" << 
options[iter - 1] << "\"";
+    } else if (iter < end && options[iter][0] != '-') {
+      // case 2. --key value
+      name = s;
+      obj = options[iter++];
+    } else {
+      // case 3. --boolean-key
+      name = s;
+      obj = "1";
+    }
+    // check if `name` is invalid
+    auto it = this->kind->key2vtype_.find(name);
+    if (it == this->kind->key2vtype_.end()) {
+      std::ostringstream os;
+      os << "AttributeError: Invalid config option, cannot recognize \'" << 
name
+         << "\'. Candidates are:";
+      for (const auto& kv : this->kind->key2vtype_) {
+        os << "\n  " << kv.first;
+      }
+      LOG(FATAL) << os.str();
+    }
+    // check if `name` has been set once
+    CHECK(!attrs.count(name)) << "AttributeError: key \"" << name
+                              << "\" appears more than once in the target 
string";
+    // then `name` is valid, let's parse them
+    // only several types are supported when parsing raw string
+    const auto& info = it->second;
+    ObjectRef parsed_obj(nullptr);
+    if (info.type_index != ArrayNode::_type_index) {
+      parsed_obj = ParseAtomicType(info.type_index, obj);
+    } else {
+      Array<ObjectRef> array;
+      std::string item;
+      bool failed = false;
+      uint32_t type_index = info.key->type_index;
+      for (std::istringstream is(obj); std::getline(is, item, ',');) {
+        ObjectRef parsed_obj = ParseAtomicType(type_index, item);
+        if (parsed_obj.defined()) {
+          array.push_back(parsed_obj);
+        } else {
+          failed = true;
+          break;
+        }
+      }
+      if (!failed) {
+        parsed_obj = std::move(array);
+      }
+    }
+    if (!parsed_obj.defined()) {
+      LOG(FATAL) << "ValueError: Cannot parse type \"" << info.type_key << "\""
+                 << ", where attribute key is \"" << name << "\""
+                 << ", and attribute is \"" << obj << "\"";
+    }
+    attrs[name] = std::move(parsed_obj);
+  }
+  // set default attribute values if they do not exist
+  for (const auto& kv : this->kind->key2default_) {
+    if (!attrs.count(kv.first)) {
+      attrs[kv.first] = kv.second;
+    }
+  }
+  return attrs;
+}
+
+static inline Optional<String> StringifyAtomicType(const ObjectRef& obj) {
+  if (const auto* p = obj.as<IntImmNode>()) {
+    return String(std::to_string(p->value));
+  }
+  if (const auto* p = obj.as<StringObj>()) {
+    return GetRef<String>(p);
+  }
+  return NullOpt;
+}
+
+static inline Optional<String> JoinString(const std::vector<String>& array, 
char separator) {
+  if (array.empty()) {
+    return NullOpt;
+  }
+  std::ostringstream os;
+  os << array[0];
+  for (size_t i = 1; i < array.size(); ++i) {
+    os << separator << array[i];
+  }
+  return String(os.str());
+}
+
+Optional<String> TargetNode::StringifyAttrsToRaw(const Map<String, ObjectRef>& 
attrs) const {
+  std::ostringstream os;
+  std::vector<String> keys;
+  for (const auto& kv : attrs) {
+    keys.push_back(kv.first);
+  }
+  std::sort(keys.begin(), keys.end());
+  std::vector<String> result;
+  for (const auto& key : keys) {
+    const ObjectRef& obj = attrs[key];
+    Optional<String> value = NullOpt;
+    if (const auto* array = obj.as<ArrayNode>()) {
+      std::vector<String> items;
+      for (const ObjectRef& item : *array) {
+        Optional<String> str = StringifyAtomicType(item);
+        if (str.defined()) {
+          items.push_back(str.value());
+        } else {
+          items.clear();
+          break;
+        }
+      }
+      value = JoinString(items, ',');
+    } else {
+      value = StringifyAtomicType(obj);
+    }
+    if (value.defined()) {
+      result.push_back("-" + key + "=" + value.value());
+    }
+  }
+  return JoinString(result, ' ');
+}
+
 Target Target::CreateTarget(const std::string& name, const 
std::vector<std::string>& options) {
   TargetKind kind = TargetKind::Get(name);
   ObjectPtr<TargetNode> target = make_object<TargetNode>();
@@ -43,7 +232,7 @@ Target Target::CreateTarget(const std::string& name, const 
std::vector<std::stri
   // tag is always empty
   target->tag = "";
   // parse attrs
-  target->attrs = kind->ParseAttrsFromRaw(options);
+  target->attrs = target->ParseAttrsFromRaw(options);
   String device_name = target->GetAttr<String>("device", "").value();
   // set up keys
   {
@@ -62,48 +251,11 @@ Target Target::CreateTarget(const std::string& name, const 
std::vector<std::stri
       keys.push_back(key);
     }
     // de-duplicate keys
-    size_t new_size = 0;
-    for (size_t i = 0; i < keys.size(); ++i) {
-      if (keys[i] == "") {
-        continue;
-      }
-      keys[new_size++] = keys[i];
-      for (size_t j = i + 1; j < keys.size(); ++j) {
-        if (keys[j] == keys[i]) {
-          keys[j] = String("");
-        }
-      }
-    }
-    keys.resize(new_size);
-    target->keys = std::move(keys);
+    target->keys = DeduplicateKeys(keys);
   }
   return Target(target);
 }
 
-TVM_REGISTER_NODE_TYPE(TargetNode);
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const TargetNode*>(node.get());
-      p->stream << op->str();
-    });
-
-TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, 
TVMRetValue* ret) {
-  std::string name = args[0];
-  std::vector<std::string> options;
-  for (int i = 1; i < args.num_args; ++i) {
-    std::string arg = args[i];
-    options.push_back(arg);
-  }
-
-  *ret = Target::CreateTarget(name, options);
-});
-
-TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args, 
TVMRetValue* ret) {
-  std::string target_str = args[0];
-  *ret = Target::Create(target_str);
-});
-
 std::vector<std::string> TargetNode::GetKeys() const {
   std::vector<std::string> result;
   for (auto& expr : keys) {
@@ -140,7 +292,7 @@ const std::string& TargetNode::str() const {
         os << s;
       }
     }
-    if (Optional<String> attrs_str = kind->StringifyAttrsToRaw(attrs)) {
+    if (Optional<String> attrs_str = this->StringifyAttrsToRaw(attrs)) {
       os << ' ' << attrs_str.value();
     }
     str_repr_ = os.str();
@@ -162,6 +314,160 @@ Target Target::Create(const String& target_str) {
   return CreateTarget(splits[0], {splits.begin() + 1, splits.end()});
 }
 
+ObjectRef TargetNode::ParseAttr(const ObjectRef& obj,
+                                const TargetKindNode::ValueTypeInfo& info) const {
+  if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    const auto* v = obj.as<IntImmNode>();
+    CHECK(v != nullptr) << "Expect type 'int', but get: " << obj->GetTypeKey();
+    return GetRef<Integer>(v);
+  }
+  if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    const auto* v = obj.as<StringObj>();
+    CHECK(v != nullptr) << "Expect type 'str', but get: " << obj->GetTypeKey();
+    return GetRef<String>(v);
+  }
+  if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+    CHECK(obj->IsInstance<MapNode>())
+        << "Expect type 'dict' to construct Target, but get: " << obj->GetTypeKey();
+    return Target::FromConfig(Downcast<Map<String, ObjectRef>>(obj));
+  }
+  if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
+    CHECK(obj->IsInstance<ArrayNode>()) << "Expect type 'list', but get: " << obj->GetTypeKey();
+    Array<ObjectRef> array = Downcast<Array<ObjectRef>>(obj);
+    std::vector<ObjectRef> result;
+    int i = 0;  // 1-based index of the element being parsed, for error messages
+    for (const ObjectRef& elem : array) {
+      ++i;
+      try {
+        result.push_back(TargetNode::ParseAttr(elem, *info.key));
+      } catch (const dmlc::Error& e) {  // CHECK/LOG(FATAL) throw dmlc::Error
+        LOG(FATAL) << "Error occurred when parsing element " << i << " of the array: " << array
+                   << ". Details:\n"
+                   << e.what();
+      }
+    }
+    return Array<ObjectRef>(result);
+  }
+  if (info.type_index == MapNode::_GetOrAllocRuntimeTypeIndex()) {
+    CHECK(obj->IsInstance<MapNode>()) << "Expect type 'dict', but get: " << obj->GetTypeKey();
+    std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> result;
+    for (const auto& kv : Downcast<Map<ObjectRef, ObjectRef>>(obj)) {
+      ObjectRef key, val;
+      try {
+        key = TargetNode::ParseAttr(kv.first, *info.key);
+      } catch (const dmlc::Error& e) {  // same exception type as the array branch
+        LOG(FATAL) << "Error occurred when parsing a key of the dict: " << kv.first
+                   << ". Details:\n"
+                   << e.what();
+      }
+      try {
+        val = TargetNode::ParseAttr(kv.second, *info.val);
+      } catch (const dmlc::Error& e) {  // same exception type as the array branch
+        LOG(FATAL) << "Error occurred when parsing a value of the dict: " << kv.second
+                   << ". Details:\n"
+                   << e.what();
+      }
+      result[key] = val;
+    }
+    return Map<ObjectRef, ObjectRef>(result);
+  }
+  LOG(FATAL) << "Unsupported type registered: \"" << info.type_key
+             << "\", and the type given is: " << obj->GetTypeKey();
+  throw;  // unreachable: LOG(FATAL) throws; silences missing-return warnings
+}
+
+Target Target::FromConfig(const Map<String, ObjectRef>& config_dict) {
+  const String kKind = "kind";
+  const String kTag = "tag";
+  const String kKeys = "keys";
+  const String kDeviceName = "device";
+  std::unordered_map<std::string, ObjectRef> config(config_dict.begin(), config_dict.end());
+  ObjectPtr<TargetNode> target = make_object<TargetNode>();
+  // parse 'kind'
+  if (config.count(kKind)) {
+    const auto* kind = config[kKind].as<StringObj>();
+    CHECK(kind != nullptr) << "AttributeError: Expect type of field 'kind' is string, but get: "
+                           << config[kKind]->GetTypeKey();
+    target->kind = TargetKind::Get(GetRef<String>(kind));
+    config.erase(kKind);
+  } else {
+    LOG(FATAL) << "AttributeError: Field 'kind' is not found";
+  }
+  // parse "tag"
+  if (config.count(kTag)) {
+    const auto* tag = config[kTag].as<StringObj>();
+    CHECK(tag != nullptr) << "AttributeError: Expect type of field 'tag' is string, but get: "
+                          << config[kTag]->GetTypeKey();
+    target->tag = GetRef<String>(tag);
+    config.erase(kTag);
+  } else {
+    target->tag = "";
+  }
+  // parse "keys"
+  if (config.count(kKeys)) {
+    std::vector<String> keys;
+    // user provided keys
+    const auto* cfg_keys = config[kKeys].as<ArrayNode>();
+    CHECK(cfg_keys != nullptr)
+        << "AttributeError: Expect type of field 'keys' is an Array, but get: "
+        << config[kKeys]->GetTypeKey();  // report the offending 'keys' value, not 'tag'
+    for (const ObjectRef& e : *cfg_keys) {
+      const auto* key = e.as<StringObj>();
+      CHECK(key != nullptr) << "AttributeError: Expect 'keys' to be an array of strings, but it "
+                               "contains an element of type: "
+                            << e->GetTypeKey();
+      keys.push_back(GetRef<String>(key));
+    }
+    // add device name
+    if (config_dict.count(kDeviceName)) {
+      if (const auto* device = config_dict.at(kDeviceName).as<StringObj>()) {
+        keys.push_back(GetRef<String>(device));
+      }
+    }
+    // add default keys
+    for (const auto& key : target->kind->default_keys) {
+      keys.push_back(key);
+    }
+    // de-duplicate keys
+    target->keys = DeduplicateKeys(keys);
+    config.erase(kKeys);
+  } else {
+    target->keys = {};
+  }
+  // parse attrs
+  std::unordered_map<String, ObjectRef> attrs;
+  const auto& key2vtype = target->kind->key2vtype_;
+  for (const auto& cfg_kv : config) {
+    const String& name = cfg_kv.first;
+    const ObjectRef& obj = cfg_kv.second;
+    if (!key2vtype.count(name)) {
+      std::ostringstream os;
+      os << "AttributeError: Unrecognized config option: \"" << name << "\". Candidates are:";
+      for (const auto& kv : key2vtype) {
+        os << " " << kv.first;
+      }
+      LOG(FATAL) << os.str();
+    }
+    ObjectRef val;
+    try {
+      val = target->ParseAttr(obj, key2vtype.at(name));
+    } catch (const dmlc::Error& e) {
+      LOG(FATAL) << "AttributeError: Error occurred in parsing the config key \"" << name
+                 << "\". Details:\n"
+                 << e.what();
+    }
+    attrs[name] = val;
+  }
+  // set default attribute values if they do not exist
+  for (const auto& kv : target->kind->key2default_) {
+    if (!attrs.count(kv.first)) {
+      attrs[kv.first] = kv.second;
+    }
+  }
+  target->attrs = attrs;
+  return Target(target);
+}
+
 /*! \brief Entry to hold the Target context stack. */
 struct TVMTargetThreadLocalEntry {
   /*! \brief The current target context */
@@ -169,7 +475,7 @@ struct TVMTargetThreadLocalEntry {
 };
 
 /*! \brief Thread local store to hold the Target context stack. */
-typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> 
TVMTargetThreadLocalStore;
+using TVMTargetThreadLocalStore = 
dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry>;
 
 void Target::EnterWithScope() {
   TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get();
@@ -194,20 +500,37 @@ tvm::Target Target::Current(bool allow_not_defined) {
   return Target();
 }
 
-TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body([](TVMArgs args, 
TVMRetValue* ret) {
-  bool allow_not_defined = args[0];
-  *ret = Target::Current(allow_not_defined);
-});
 class Target::Internal {
  public:
   static void EnterScope(Target target) { target.EnterWithScope(); }
   static void ExitScope(Target target) { target.ExitWithScope(); }
 };
 
+TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, 
TVMRetValue* ret) {
+  std::string name = args[0];
+  std::vector<std::string> options;
+  for (int i = 1; i < args.num_args; ++i) {
+    std::string arg = args[i];
+    options.push_back(arg);
+  }
+
+  *ret = Target::CreateTarget(name, options);
+});
+
 
TVM_REGISTER_GLOBAL("target.EnterTargetScope").set_body_typed(Target::Internal::EnterScope);
 
 
TVM_REGISTER_GLOBAL("target.ExitTargetScope").set_body_typed(Target::Internal::ExitScope);
 
+TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body_typed(Target::Current);
+
+TVM_REGISTER_GLOBAL("target.TargetFromString").set_body_typed(Target::Create);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const TargetNode*>(node.get());
+      p->stream << op->str();
+    });
+
 namespace target {
 std::vector<std::string> MergeOptions(std::vector<std::string> opts,
                                       const std::vector<std::string>& 
new_opts) {
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index e6f7c5c..3e35e5b 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -26,7 +26,6 @@
 #include <algorithm>
 
 #include "../node/attr_registry.h"
-#include "../runtime/object_internal.h"
 
 namespace tvm {
 
@@ -60,294 +59,6 @@ const TargetKind& TargetKind::Get(const String& 
target_kind_name) {
   return reg->kind_;
 }
 
-void TargetKindNode::VerifyTypeInfo(const ObjectRef& obj,
-                                    const TargetKindNode::ValueTypeInfo& info) 
const {
-  CHECK(obj.defined()) << "Object is None";
-  if (!runtime::ObjectInternal::DerivedFrom(obj.get(), info.type_index)) {
-    LOG(FATAL) << "AttributeError: expect type \"" << info.type_key << "\" but 
get "
-               << obj->GetTypeKey();
-    throw;
-  }
-  if (info.type_index == ArrayNode::_type_index) {
-    int i = 0;
-    for (const auto& e : *obj.as<ArrayNode>()) {
-      try {
-        VerifyTypeInfo(e, *info.key);
-      } catch (const tvm::Error& e) {
-        LOG(FATAL) << "The i-th element of array failed type checking, where i 
= " << i
-                   << ", and the error is:\n"
-                   << e.what();
-        throw;
-      }
-      ++i;
-    }
-  } else if (info.type_index == MapNode::_type_index) {
-    for (const auto& kv : *obj.as<MapNode>()) {
-      try {
-        VerifyTypeInfo(kv.first, *info.key);
-      } catch (const tvm::Error& e) {
-        LOG(FATAL) << "The key of map failed type checking, where key = \"" << 
kv.first
-                   << "\", value = \"" << kv.second << "\", and the error 
is:\n"
-                   << e.what();
-        throw;
-      }
-      try {
-        VerifyTypeInfo(kv.second, *info.val);
-      } catch (const tvm::Error& e) {
-        LOG(FATAL) << "The value of map failed type checking, where key = \"" 
<< kv.first
-                   << "\", value = \"" << kv.second << "\", and the error 
is:\n"
-                   << e.what();
-        throw;
-      }
-    }
-  }
-}
-
-void TargetKindNode::ValidateSchema(const Map<String, ObjectRef>& config) 
const {
-  const String kTargetKind = "kind";
-  for (const auto& kv : config) {
-    const String& name = kv.first;
-    const ObjectRef& obj = kv.second;
-    if (name == kTargetKind) {
-      CHECK(obj->IsInstance<StringObj>())
-          << "AttributeError: \"kind\" is not a string, but its type is \"" << 
obj->GetTypeKey()
-          << "\"";
-      CHECK(Downcast<String>(obj) == this->name)
-          << "AttributeError: \"kind\" = \"" << obj << "\" is inconsistent 
with TargetKind \""
-          << this->name << "\"";
-      continue;
-    }
-    auto it = key2vtype_.find(name);
-    if (it == key2vtype_.end()) {
-      std::ostringstream os;
-      os << "AttributeError: Invalid config option, cannot recognize \"" << 
name
-         << "\". Candidates are:";
-      for (const auto& kv : key2vtype_) {
-        os << "\n  " << kv.first;
-      }
-      LOG(FATAL) << os.str();
-      throw;
-    }
-    const auto& info = it->second;
-    try {
-      VerifyTypeInfo(obj, info);
-    } catch (const tvm::Error& e) {
-      LOG(FATAL) << "AttributeError: Schema validation failed for TargetKind 
\"" << this->name
-                 << "\", details:\n"
-                 << e.what() << "\n"
-                 << "The config is:\n"
-                 << config;
-      throw;
-    }
-  }
-}
-
-inline String GetKind(const Map<String, ObjectRef>& target, const char* name) {
-  const String kTargetKind = "kind";
-  CHECK(target.count(kTargetKind))
-      << "AttributeError: \"kind\" does not exist in \"" << name << "\"\n"
-      << name << " = " << target;
-  const ObjectRef& obj = target[kTargetKind];
-  CHECK(obj->IsInstance<StringObj>()) << "AttributeError: \"kind\" is not a 
string in \"" << name
-                                      << "\", but its type is \"" << 
obj->GetTypeKey() << "\"\n"
-                                      << name << " = \"" << target << '"';
-  return Downcast<String>(obj);
-}
-
-void TargetValidateSchema(const Map<String, ObjectRef>& config) {
-  try {
-    const String kTargetHost = "target_host";
-    Map<String, ObjectRef> target = config;
-    Map<String, ObjectRef> target_host;
-    String target_kind = GetKind(target, "target");
-    String target_host_kind;
-    if (config.count(kTargetHost)) {
-      target.erase(kTargetHost);
-      target_host = Downcast<Map<String, ObjectRef>>(config[kTargetHost]);
-      target_host_kind = GetKind(target_host, "target_host");
-    }
-    TargetKind::Get(target_kind)->ValidateSchema(target);
-    if (!target_host.empty()) {
-      TargetKind::Get(target_host_kind)->ValidateSchema(target_host);
-    }
-  } catch (const tvm::Error& e) {
-    LOG(FATAL) << "AttributeError: schedule validation fails:\n"
-               << e.what() << "\nThe configuration is:\n"
-               << config;
-  }
-}
-
-static inline size_t CountNumPrefixDashes(const std::string& s) {
-  size_t i = 0;
-  for (; i < s.length() && s[i] == '-'; ++i) {
-  }
-  return i;
-}
-
-static inline int FindUniqueSubstr(const std::string& str, const std::string& 
substr) {
-  size_t pos = str.find_first_of(substr);
-  if (pos == std::string::npos) {
-    return -1;
-  }
-  size_t next_pos = pos + substr.size();
-  CHECK(next_pos >= str.size() || str.find_first_of(substr, next_pos) == 
std::string::npos)
-      << "ValueError: At most one \"" << substr << "\" is allowed in "
-      << "the the given string \"" << str << "\"";
-  return pos;
-}
-
-static inline ObjectRef ParseScalar(uint32_t type_index, const std::string& 
str) {
-  std::istringstream is(str);
-  if (type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
-    int v;
-    is >> v;
-    return is.fail() ? ObjectRef(nullptr) : Integer(v);
-  } else if (type_index == 
String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
-    std::string v;
-    is >> v;
-    return is.fail() ? ObjectRef(nullptr) : String(v);
-  }
-  return ObjectRef(nullptr);
-}
-
-static inline Optional<String> StringifyScalar(const ObjectRef& obj) {
-  if (const auto* p = obj.as<IntImmNode>()) {
-    return String(std::to_string(p->value));
-  }
-  if (const auto* p = obj.as<StringObj>()) {
-    return GetRef<String>(p);
-  }
-  return NullOpt;
-}
-
-static inline Optional<String> Join(const std::vector<String>& array, char 
separator) {
-  if (array.empty()) {
-    return NullOpt;
-  }
-  std::ostringstream os;
-  os << array[0];
-  for (size_t i = 1; i < array.size(); ++i) {
-    os << separator << array[i];
-  }
-  return String(os.str());
-}
-
-Map<String, ObjectRef> TargetKindNode::ParseAttrsFromRaw(
-    const std::vector<std::string>& options) const {
-  std::unordered_map<String, ObjectRef> attrs;
-  for (size_t iter = 0, end = options.size(); iter < end;) {
-    std::string s = options[iter++];
-    // remove the prefix dashes
-    size_t n_dashes = CountNumPrefixDashes(s);
-    CHECK(0 < n_dashes && n_dashes < s.size())
-        << "ValueError: Not an attribute key \"" << s << "\"";
-    s = s.substr(n_dashes);
-    // parse name-obj pair
-    std::string name;
-    std::string obj;
-    int pos;
-    if ((pos = FindUniqueSubstr(s, "=")) != -1) {
-      // case 1. --key=value
-      name = s.substr(0, pos);
-      obj = s.substr(pos + 1);
-      CHECK(!name.empty()) << "ValueError: Empty attribute key in \"" << 
options[iter - 1] << "\"";
-      CHECK(!obj.empty()) << "ValueError: Empty attribute in \"" << 
options[iter - 1] << "\"";
-    } else if (iter < end && options[iter][0] != '-') {
-      // case 2. --key value
-      name = s;
-      obj = options[iter++];
-    } else {
-      // case 3. --boolean-key
-      name = s;
-      obj = "1";
-    }
-    // check if `name` is invalid
-    auto it = key2vtype_.find(name);
-    if (it == key2vtype_.end()) {
-      std::ostringstream os;
-      os << "AttributeError: Invalid config option, cannot recognize \'" << 
name
-         << "\'. Candidates are:";
-      for (const auto& kv : key2vtype_) {
-        os << "\n  " << kv.first;
-      }
-      LOG(FATAL) << os.str();
-    }
-    // check if `name` has been set once
-    CHECK(!attrs.count(name)) << "AttributeError: key \"" << name
-                              << "\" appears more than once in the target 
string";
-    // then `name` is valid, let's parse them
-    // only several types are supported when parsing raw string
-    const auto& info = it->second;
-    ObjectRef parsed_obj(nullptr);
-    if (info.type_index != ArrayNode::_type_index) {
-      parsed_obj = ParseScalar(info.type_index, obj);
-    } else {
-      Array<ObjectRef> array;
-      std::string item;
-      bool failed = false;
-      uint32_t type_index = info.key->type_index;
-      for (std::istringstream is(obj); std::getline(is, item, ',');) {
-        ObjectRef parsed_obj = ParseScalar(type_index, item);
-        if (parsed_obj.defined()) {
-          array.push_back(parsed_obj);
-        } else {
-          failed = true;
-          break;
-        }
-      }
-      if (!failed) {
-        parsed_obj = std::move(array);
-      }
-    }
-    if (!parsed_obj.defined()) {
-      LOG(FATAL) << "ValueError: Cannot parse type \"" << info.type_key << "\""
-                 << ", where attribute key is \"" << name << "\""
-                 << ", and attribute is \"" << obj << "\"";
-    }
-    attrs[name] = std::move(parsed_obj);
-  }
-  // set default attribute values if they do not exist
-  for (const auto& kv : key2default_) {
-    if (!attrs.count(kv.first)) {
-      attrs[kv.first] = kv.second;
-    }
-  }
-  return attrs;
-}
-
-Optional<String> TargetKindNode::StringifyAttrsToRaw(const Map<String, 
ObjectRef>& attrs) const {
-  std::ostringstream os;
-  std::vector<String> keys;
-  for (const auto& kv : attrs) {
-    keys.push_back(kv.first);
-  }
-  std::sort(keys.begin(), keys.end());
-  std::vector<String> result;
-  for (const auto& key : keys) {
-    const ObjectRef& obj = attrs[key];
-    Optional<String> value = NullOpt;
-    if (const auto* array = obj.as<ArrayNode>()) {
-      std::vector<String> items;
-      for (const ObjectRef& item : *array) {
-        Optional<String> str = StringifyScalar(item);
-        if (str.defined()) {
-          items.push_back(str.value());
-        } else {
-          items.clear();
-          break;
-        }
-      }
-      value = Join(items, ',');
-    } else {
-      value = StringifyScalar(obj);
-    }
-    if (value.defined()) {
-      result.push_back("-" + key + "=" + value.value());
-    }
-  }
-  return Join(result, ' ');
-}
-
 // TODO(@junrushao1994): remove some redundant attributes
 
 TVM_REGISTER_TARGET_KIND("llvm")
diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc
index 8bee707..e8748e6 100644
--- a/tests/cpp/target_test.cc
+++ b/tests/cpp/target_test.cc
@@ -19,7 +19,7 @@
 
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/target/target_kind.h>
+#include <tvm/target/target.h>
 
 #include <cmath>
 #include <string>
@@ -39,48 +39,117 @@ TEST(TargetKind, GetAttrMap) {
   CHECK_EQ(result, "Value1");
 }
 
-TEST(TargetKind, SchemaValidation) {
-  tvm::Map<String, ObjectRef> target;
-  {
-    tvm::Array<String> your_names{"junru", "jian"};
-    tvm::Map<String, Integer> her_maps{
-        {"a", 1},
-        {"b", 2},
-    };
-    target.Set("my_bool", Bool(true));
-    target.Set("your_names", your_names);
-    target.Set("her_maps", her_maps);
-    target.Set("kind", String("TestTargetKind"));
+TEST(TargetCreation, NestedConfig) {  // array- and map-valued attrs survive FromConfig
+  Map<String, ObjectRef> config = {
+      {"my_bool", Bool(true)},
+      {"your_names", Array<String>{"junru", "jian"}},
+      {"kind", String("TestTargetKind")},
+      {
+          "her_maps",
+          Map<String, Integer>{
+              {"a", 1},
+              {"b", 2},
+          },
+      },
+  };
+  Target target = Target::FromConfig(config);
+  CHECK_EQ(target->kind, TargetKind::Get("TestTargetKind"));
+  CHECK_EQ(target->tag, "");
+  CHECK(target->keys.empty());
+  Bool my_bool = target->GetAttr<Bool>("my_bool").value();
+  CHECK_EQ(my_bool.operator bool(), true);
+  Array<String> your_names = target->GetAttr<Array<String>>("your_names").value();
+  CHECK_EQ(your_names.size(), 2U);
+  CHECK_EQ(your_names[0], "junru");
+  CHECK_EQ(your_names[1], "jian");
+  Map<String, Integer> her_maps = target->GetAttr<Map<String, Integer>>("her_maps").value();
+  CHECK_EQ(her_maps.size(), 2U);
+  CHECK_EQ(her_maps["a"], 1);
+  CHECK_EQ(her_maps["b"], 2);
+}
+
+TEST(TargetCreationFail, UnrecognizedConfigOption) {
+  Map<String, ObjectRef> config = {
+      {"my_bool", Bool(true)},
+      {"your_names", Array<String>{"junru", "jian"}},
+      {"kind", String("TestTargetKind")},
+      {"bad", ObjectRef(nullptr)},  // the unrecognized option under test
+      {
+          "her_maps",
+          Map<String, Integer>{
+              {"a", 1},
+              {"b", 2},
+          },
+      },
+  };
+  bool failed = false;
+  try {
+    Target::FromConfig(config);
+  } catch (...) {  // any exception counts as the expected rejection
+    failed = true;
   }
-  TargetValidateSchema(target);
-  tvm::Map<String, ObjectRef> target_host(target.begin(), target.end());
-  target.Set("target_host", target_host);
-  TargetValidateSchema(target);
+  ASSERT_EQ(failed, true);
 }
 
-TEST(TargetKind, SchemaValidationFail) {
-  tvm::Map<String, ObjectRef> target;
-  {
-    tvm::Array<String> your_names{"junru", "jian"};
-    tvm::Map<String, Integer> her_maps{
-        {"a", 1},
-        {"b", 2},
-    };
-    target.Set("my_bool", Bool(true));
-    target.Set("your_names", your_names);
-    target.Set("her_maps", her_maps);
-    target.Set("ok", ObjectRef(nullptr));
-    target.Set("kind", String("TestTargetKind"));
+TEST(TargetCreationFail, TypeMismatch) {
+  Map<String, ObjectRef> config = {
+      {"my_bool", String("true")},  // a String where NestedConfig passes a Bool
+      {"your_names", Array<String>{"junru", "jian"}},
+      {"kind", String("TestTargetKind")},
+      {
+          "her_maps",
+          Map<String, Integer>{
+              {"a", 1},
+              {"b", 2},
+          },
+      },
+  };
+  bool failed = false;
+  try {
+    Target::FromConfig(config);
+  } catch (...) {
+    failed = true;
   }
+  ASSERT_EQ(failed, true);
+}
+
+TEST(TargetCreationFail, TargetKindNotFound) {
+  Map<String, ObjectRef> config = {  // deliberately has no "kind" entry
+      {"my_bool", Bool(true)},
+      {"your_names", Array<String>{"junru", "jian"}},
+      {
+          "her_maps",
+          Map<String, Integer>{
+              {"a", 1},
+              {"b", 2},
+          },
+      },
+  };
   bool failed = false;
   try {
-    TargetValidateSchema(target);
+    Target::FromConfig(config);
   } catch (...) {
     failed = true;
   }
   ASSERT_EQ(failed, true);
 }
 
+TEST(TargetCreation, DeduplicateKeys) {  // keys come out unique, in the given order
+  Map<String, ObjectRef> config = {
+      {"kind", String("llvm")},
+      {"keys", Array<String>{"cpu", "arm_cpu"}},
+      {"device", String("arm_cpu")},  // presumably also implies an "arm_cpu" key -- confirm
+  };
+  Target target = Target::FromConfig(config);
+  CHECK_EQ(target->kind, TargetKind::Get("llvm"));
+  CHECK_EQ(target->tag, "");
+  CHECK_EQ(target->keys.size(), 2U);  // no repeated "arm_cpu" entry
+  CHECK_EQ(target->keys[0], "cpu");
+  CHECK_EQ(target->keys[1], "arm_cpu");
+  CHECK_EQ(target->attrs.size(), 1U);  // only "device" is expected to land in attrs
+  CHECK_EQ(target->GetAttr<String>("device"), "arm_cpu");
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";

Reply via email to