This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d38bef5  [Target] Allow spaces in target attributes (#8587)
d38bef5 is described below

commit d38bef57d803a4cae6be11181156bbf4243a694c
Author: Lunderberg <[email protected]>
AuthorDate: Tue Aug 3 21:43:39 2021 -0500

    [Target] Allow spaces in target attributes (#8587)
    
    * [Target] Allow for spaces in target attributes.
    
    Some target parameters, such as the device_name on vulkan, have spaces
    in them.  This prevented round-trips between string and Target
    objects, which can occur in some cases.
    
    * [Vulkan] Fixed "device_name" property querying.
    
    * [Target] Switched from escaped spaces to quoted spaces.
    
    Instead of -attr=value\ with\ spaces, such attributes will be written
    as -attr='value with spaces'.
    
    Co-authored-by: Eric Lunderberg <[email protected]>
---
 src/runtime/vulkan/vulkan_device_api.cc     |   2 +-
 src/target/target.cc                        | 106 +++++++++++++++++++++-------
 tests/python/unittest/test_target_target.py |   8 +++
 3 files changed, 90 insertions(+), 26 deletions(-)

diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
index b4987eb..1c190f3 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -122,7 +122,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
       break;
     }
     case kDeviceName:
-      *rv = prop.device_name;
+      *rv = String(prop.device_name);
       break;
 
     case kMaxClockRate:
diff --git a/src/target/target.cc b/src/target/target.cc
index d8e71de..50067d4 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -30,6 +30,7 @@
 
 #include <algorithm>
 #include <cctype>
+#include <cstring>
 #include <stack>
 
 #include "../runtime/object_internal.h"
@@ -147,17 +148,83 @@ static int FindFirstSubstr(const std::string& str, const std::string& substr) {
 }
 
 static Optional<String> JoinString(const std::vector<String>& array, char separator) {
+  char escape = '\\';
+  char quote = '\'';
+
   if (array.empty()) {
     return NullOpt;
   }
+
   std::ostringstream os;
-  os << array[0];
-  for (size_t i = 1; i < array.size(); ++i) {
-    os << separator << array[i];
+
+  for (size_t i = 0; i < array.size(); ++i) {
+    if (i > 0) {
+      os << separator;
+    }
+
+    std::string str = array[i];
+
+    if ((str.find(separator) == std::string::npos) && (str.find(quote) == std::string::npos)) {
+      os << str;
+    } else {
+      os << quote;
+      for (char c : str) {
+        if (c == separator || c == quote) {
+          os << escape;
+        }
+        os << c;
+      }
+      os << quote;
+    }
   }
   return String(os.str());
 }
 
+static std::vector<std::string> SplitString(const std::string& str, char separator) {
+  char escape = '\\';
+  char quote = '\'';
+
+  std::vector<std::string> output;
+
+  const char* start = str.data();
+  const char* end = start + str.size();
+  const char* pos = start;
+
+  std::stringstream current_word;
+
+  auto finish_word = [&]() {
+    std::string word = current_word.str();
+    if (word.size()) {
+      output.push_back(word);
+      current_word.str("");
+    }
+  };
+
+  bool pos_quoted = false;
+
+  while (pos < end) {
+    if ((*pos == separator) && !pos_quoted) {
+      finish_word();
+      pos++;
+    } else if ((*pos == escape) && (pos + 1 < end) && (pos[1] == quote)) {
+      current_word << quote;
+      pos += 2;
+    } else if (*pos == quote) {
+      pos_quoted = !pos_quoted;
+      pos++;
+    } else {
+      current_word << *pos;
+      pos++;
+    }
+  }
+
+  ICHECK(!pos_quoted) << "Mismatched quotes '' in string";
+
+  finish_word();
+
+  return output;
+}
+
 static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key,
                        std::string* value) {
   int pos;
@@ -207,9 +274,9 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi
 
 ObjectRef TargetInternal::ParseType(const std::string& str,
                                     const TargetKindNode::ValueTypeInfo& info) {
-  std::istringstream is(str);
   if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
     // Parsing integer
+    std::istringstream is(str);
     int v;
     if (!(is >> v)) {
       std::string lower(str.size(), '\x0');
@@ -226,19 +293,18 @@ ObjectRef TargetInternal::ParseType(const std::string& str,
     }
     return Integer(v);
   } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
-    // Parsing string
-    std::string v;
-    if (!(is >> v)) {
-      throw Error(": Cannot parse into type \"String\" from string: " + str);
-    }
-    return String(v);
+    // Parsing string, strip leading/trailing spaces
+    auto start = str.find_first_not_of(' ');
+    auto end = str.find_last_not_of(' ');
+    return String(str.substr(start, (end - start + 1)));
+
   } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
     // Parsing target
     return Target(TargetInternal::FromString(str));
   } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
     // Parsing array
     std::vector<ObjectRef> result;
-    for (std::string substr; std::getline(is, substr, ',');) {
+    for (const std::string& substr : SplitString(str, ',')) {
       try {
         ObjectRef parsed = TargetInternal::ParseType(substr, *info.key);
         result.push_back(parsed);
@@ -550,24 +616,14 @@ ObjectPtr<Object> TargetInternal::FromConfigString(const String& config_str) {
 }
 
 ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) {
+  ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string";
   // Split the string by empty spaces
-  std::string name;
-  std::vector<std::string> options;
-  std::string str;
-  for (std::istringstream is(target_str); is >> str;) {
-    if (name.empty()) {
-      name = str;
-    } else {
-      options.push_back(str);
-    }
-  }
-  if (name.empty()) {
-    throw Error(": Cannot parse empty target string");
-  }
+  std::vector<std::string> options = SplitString(std::string(target_str), ' ');
+  std::string name = options[0];
   // Create the target config
   std::unordered_map<String, ObjectRef> config = {{"kind", String(name)}};
   TargetKind kind = GetTargetKind(name);
-  for (size_t iter = 0, end = options.size(); iter < end;) {
+  for (size_t iter = 1, end = options.size(); iter < end;) {
     std::string key, value;
     try {
       // Parse key-value pair
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index bb3aa9e..5007ef1 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -76,6 +76,14 @@ def test_target_string_parse():
     assert tvm.target.arm_cpu().device_name == "arm_cpu"
 
 
+def test_target_string_with_spaces():
+    target = tvm.target.Target(
+        "vulkan -device_name='Name of GPU with spaces' -device_type=discrete"
+    )
+    assert target.attrs["device_name"] == "Name of GPU with spaces"
+    assert target.attrs["device_type"] == "discrete"
+
+
 def test_target_create():
     targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), vta(), bifrost()]
     for tgt in targets:

Reply via email to