junrushao1994 commented on a change in pull request #6218:
URL: https://github.com/apache/incubator-tvm/pull/6218#discussion_r468930805
##########
File path: src/target/target.cc
##########
@@ -162,14 +313,149 @@ Target Target::Create(const String& target_str) {
return CreateTarget(splits[0], {splits.begin() + 1, splits.end()});
}
+ObjectRef TargetNode::ParseAttr(const ObjectRef& obj,
+ const TargetKindNode::ValueTypeInfo& info)
const {
+ if (info.type_index ==
Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
+ const auto* v = obj.as<IntImmNode>();
+ CHECK(v != nullptr) << "Expect type 'int', but get: " << obj->GetTypeKey();
+ return GetRef<Integer>(v);
+ }
+ if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex())
{
+ const auto* v = obj.as<StringObj>();
+ CHECK(v != nullptr) << "Expect type 'str', but get: " << obj->GetTypeKey();
+ return GetRef<String>(v);
+ }
+ if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex())
{
+ CHECK(obj->IsInstance<MapNode>())
+ << "Expect type 'dict' to construct Target, but get: " <<
obj->GetTypeKey();
+ return Target::FromConfig(Downcast<Map<String, ObjectRef>>(obj));
+ }
+ if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
+ CHECK(obj->IsInstance<ArrayNode>()) << "Expect type 'list', but get: " <<
obj->GetTypeKey();
+ Array<ObjectRef> array = Downcast<Array<ObjectRef>>(obj);
+ std::vector<ObjectRef> result;
+ int i = 0;
+ for (const ObjectRef& e : array) {
+ ++i;
+ try {
+ result.push_back(TargetNode::ParseAttr(e, *info.key));
+ } catch (const dmlc::Error& e) {
+ LOG(FATAL) << "Error occurred when parsing element " << i << " of the
array: " << array
+ << ". Details:\n"
+ << e.what();
+ }
+ }
+ return Array<ObjectRef>(result);
+ }
+ if (info.type_index == MapNode::_GetOrAllocRuntimeTypeIndex()) {
+ CHECK(obj->IsInstance<MapNode>()) << "Expect type 'dict', but get: " <<
obj->GetTypeKey();
+ std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> result;
+ for (const auto& kv : Downcast<Map<ObjectRef, ObjectRef>>(obj)) {
+ ObjectRef key, val;
+ try {
+ key = TargetNode::ParseAttr(kv.first, *info.key);
+ } catch (const tvm::Error& e) {
+ LOG(FATAL) << "Error occurred when parsing a key of the dict: " <<
kv.first
+ << ". Details:\n"
+ << e.what();
+ }
+ try {
+ val = TargetNode::ParseAttr(kv.second, *info.val);
+ } catch (const tvm::Error& e) {
+ LOG(FATAL) << "Error occurred when parsing a value of the dict: " <<
kv.second
+ << ". Details:\n"
+ << e.what();
+ }
+ result[key] = val;
+ }
+ return Map<ObjectRef, ObjectRef>(result);
+ }
+ LOG(FATAL) << "Unsupported type registered: \"" << info.type_key
+ << "\", and the type given is: " << obj->GetTypeKey();
+ throw;
+}
+
+Target Target::FromConfig(const Map<String, ObjectRef>& config_dict) {
+ const String kKind = "kind";
+ const String kTag = "tag";
+ const String kKeys = "keys";
+ std::unordered_map<std::string, ObjectRef> config(config_dict.begin(),
config_dict.end());
+ ObjectPtr<TargetNode> target = make_object<TargetNode>();
+ // parse 'kind'
+ if (config.count(kKind)) {
+ const auto* kind = config[kKind].as<StringObj>();
+ CHECK(kind != nullptr) << "AttributeError: Expect type of field 'kind' is
string, but get: "
+ << config[kKind]->GetTypeKey();
+ target->kind = TargetKind::Get(GetRef<String>(kind));
+ config.erase(kKind);
+ } else {
+ LOG(FATAL) << "AttributeError: Field 'kind' is not found";
+ }
+ // parse "tag"
+ if (config.count(kTag)) {
+ const auto* tag = config[kTag].as<StringObj>();
+ CHECK(tag != nullptr) << "AttributeError: Expect type of field 'tag' is
string, but get: "
+ << config[kTag]->GetTypeKey();
+ target->tag = GetRef<String>(tag);
+ config.erase(kTag);
+ } else {
+ target->tag = "";
+ }
+ // parse "keys"
+ // TODO(@junrushao1994): add more keys according to CreateTarget
Review comment:
In addition to the keys given by users, there are several keys we have to
append, following the convention of CreateTarget:
1) the device name;
2) the default keys;
3) and then de-duplicate all of those keys.
I will fix this in this PR.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]