This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d38bef5 [Target] Allow spaces in target attributes (#8587)
d38bef5 is described below
commit d38bef57d803a4cae6be11181156bbf4243a694c
Author: Lunderberg <[email protected]>
AuthorDate: Tue Aug 3 21:43:39 2021 -0500
[Target] Allow spaces in target attributes (#8587)
* [Target] Allow for spaces in target attributes.
Some target parameters, such as device_name on Vulkan, can contain
spaces. Previously this broke round-tripping between strings and
Target objects, which happens in some code paths.
* [Vulkan] Fixed "device_name" property querying.
* [Target] Switched from escaped spaces to quoted spaces.
Instead of being written as -attr=value\ with\ spaces, such values are
now written as -attr='value with spaces'.
Co-authored-by: Eric Lunderberg <[email protected]>
---
src/runtime/vulkan/vulkan_device_api.cc | 2 +-
src/target/target.cc | 106 +++++++++++++++++++++-------
tests/python/unittest/test_target_target.py | 8 +++
3 files changed, 90 insertions(+), 26 deletions(-)
diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
index b4987eb..1c190f3 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -122,7 +122,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
break;
}
case kDeviceName:
- *rv = prop.device_name;
+ *rv = String(prop.device_name);
break;
case kMaxClockRate:
diff --git a/src/target/target.cc b/src/target/target.cc
index d8e71de..50067d4 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -30,6 +30,7 @@
#include <algorithm>
#include <cctype>
+#include <cstring>
#include <stack>
#include "../runtime/object_internal.h"
@@ -147,17 +148,83 @@ static int FindFirstSubstr(const std::string& str, const std::string& substr) {
}
/*!
 * \brief Join strings with a separator, quoting elements that need it.
 *
 * An element is emitted verbatim unless it contains the separator or the
 * single-quote character, in which case it is wrapped in single quotes and
 * every embedded quote is prefixed with a backslash. This is the format
 * SplitString() reads back: inside quotes the separator is taken literally,
 * and the sequence \' is turned back into a quote.
 *
 * \param array The strings to join.
 * \param separator The character placed between consecutive elements.
 * \return The joined string, or NullOpt if \p array is empty.
 */
static Optional<String> JoinString(const std::vector<String>& array, char separator) {
  const char escape = '\\';
  const char quote = '\'';

  if (array.empty()) {
    return NullOpt;
  }

  std::ostringstream os;

  for (size_t i = 0; i < array.size(); ++i) {
    if (i > 0) {
      os << separator;
    }

    std::string str = array[i];

    if ((str.find(separator) == std::string::npos) && (str.find(quote) == std::string::npos)) {
      // Nothing that SplitString would interpret specially; emit as-is.
      os << str;
    } else {
      os << quote;
      for (char c : str) {
        // Only the quote needs escaping. Inside quotes SplitString already
        // keeps the separator literally and does not strip a backslash in
        // front of it, so escaping the separator here would leave stray
        // backslashes in the value after a string -> Target round-trip.
        if (c == quote) {
          os << escape;
        }
        os << c;
      }
      os << quote;
    }
  }
  return String(os.str());
}
+static std::vector<std::string> SplitString(const std::string& str, char
separator) {
+ char escape = '\\';
+ char quote = '\'';
+
+ std::vector<std::string> output;
+
+ const char* start = str.data();
+ const char* end = start + str.size();
+ const char* pos = start;
+
+ std::stringstream current_word;
+
+ auto finish_word = [&]() {
+ std::string word = current_word.str();
+ if (word.size()) {
+ output.push_back(word);
+ current_word.str("");
+ }
+ };
+
+ bool pos_quoted = false;
+
+ while (pos < end) {
+ if ((*pos == separator) && !pos_quoted) {
+ finish_word();
+ pos++;
+ } else if ((*pos == escape) && (pos + 1 < end) && (pos[1] == quote)) {
+ current_word << quote;
+ pos += 2;
+ } else if (*pos == quote) {
+ pos_quoted = !pos_quoted;
+ pos++;
+ } else {
+ current_word << *pos;
+ pos++;
+ }
+ }
+
+ ICHECK(!pos_quoted) << "Mismatched quotes '' in string";
+
+ finish_word();
+
+ return output;
+}
+
static int ParseKVPair(const std::string& s, const std::string& s_next,
std::string* key,
std::string* value) {
int pos;
@@ -207,9 +274,9 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi
ObjectRef TargetInternal::ParseType(const std::string& str,
const TargetKindNode::ValueTypeInfo& info)
{
- std::istringstream is(str);
if (info.type_index ==
Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing integer
+ std::istringstream is(str);
int v;
if (!(is >> v)) {
std::string lower(str.size(), '\x0');
@@ -226,19 +293,18 @@ ObjectRef TargetInternal::ParseType(const std::string& str,
}
return Integer(v);
} else if (info.type_index ==
String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
- // Parsing string
- std::string v;
- if (!(is >> v)) {
- throw Error(": Cannot parse into type \"String\" from string: " + str);
- }
- return String(v);
+ // Parsing string, strip leading/trailing spaces
+ auto start = str.find_first_not_of(' ');
+ auto end = str.find_last_not_of(' ');
+ return String(str.substr(start, (end - start + 1)));
+
} else if (info.type_index ==
Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing target
return Target(TargetInternal::FromString(str));
} else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
// Parsing array
std::vector<ObjectRef> result;
- for (std::string substr; std::getline(is, substr, ',');) {
+ for (const std::string& substr : SplitString(str, ',')) {
try {
ObjectRef parsed = TargetInternal::ParseType(substr, *info.key);
result.push_back(parsed);
@@ -550,24 +616,14 @@ ObjectPtr<Object> TargetInternal::FromConfigString(const String& config_str) {
}
ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) {
+ ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string";
// Split the string by empty spaces
- std::string name;
- std::vector<std::string> options;
- std::string str;
- for (std::istringstream is(target_str); is >> str;) {
- if (name.empty()) {
- name = str;
- } else {
- options.push_back(str);
- }
- }
- if (name.empty()) {
- throw Error(": Cannot parse empty target string");
- }
+ std::vector<std::string> options = SplitString(std::string(target_str), ' ');
+ std::string name = options[0];
// Create the target config
std::unordered_map<String, ObjectRef> config = {{"kind", String(name)}};
TargetKind kind = GetTargetKind(name);
- for (size_t iter = 0, end = options.size(); iter < end;) {
+ for (size_t iter = 1, end = options.size(); iter < end;) {
std::string key, value;
try {
// Parse key-value pair
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index bb3aa9e..5007ef1 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -76,6 +76,14 @@ def test_target_string_parse():
assert tvm.target.arm_cpu().device_name == "arm_cpu"
def test_target_string_with_spaces():
    """A single-quoted attribute value may contain spaces and is parsed intact."""
    expected_attrs = {
        "device_name": "Name of GPU with spaces",
        "device_type": "discrete",
    }
    tgt = tvm.target.Target(
        "vulkan -device_name='Name of GPU with spaces' -device_type=discrete"
    )
    # Checked in insertion order: device_name first, then device_type.
    for attr_name, attr_value in expected_attrs.items():
        assert tgt.attrs[attr_name] == attr_value
+
+
def test_target_create():
targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"),
vta(), bifrost()]
for tgt in targets: