This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 155f669  [TVMC] Fix tvmc compile to extract target and target_host from --target (#8176)
155f669 is described below

commit 155f669849f10ca4128b2cbd3feb7d4511538c02
Author: Leandro Nunes <[email protected]>
AuthorDate: Thu Jun 3 17:20:59 2021 +0100

    [TVMC] Fix tvmc compile to extract target and target_host from --target (#8176)
    
    * [TVMC] Fix tvmc compile to extract target and target_host from --target
    
     * Relaxes validation to accept up to two TVM targets and
       sets them as target and target_host
    
    * Update python/tvm/driver/tvmc/common.py
    
    Co-authored-by: Cody Yu <[email protected]>
    
    Co-authored-by: Cody Yu <[email protected]>
---
 python/tvm/driver/tvmc/common.py             | 27 +++++++++++++++++++++------
 tests/python/driver/tvmc/test_tvmc_common.py | 17 +++++++++++++++++
 2 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py
index 34f59aa..48e18fb 100644
--- a/python/tvm/driver/tvmc/common.py
+++ b/python/tvm/driver/tvmc/common.py
@@ -97,10 +97,10 @@ def validate_targets(parse_targets):
         )
 
     tvm_targets = [t for t in targets if t in tvm_target_kinds]
-    if len(tvm_targets) > 1:
+    if len(tvm_targets) > 2:
         verbose_tvm_targets = ", ".join(tvm_targets)
         raise TVMCException(
-            "Only one of the following targets can be used at a time. "
+            "Only two of the following targets can be used at a time. "
             f"Found: {verbose_tvm_targets}."
         )
 
@@ -199,6 +199,7 @@ def parse_target(target):
     """
     codegens = []
 
+    tvm_target_kinds = tvm.target.Target.list_kinds()
     parsed_tokens = tokenize_target(target)
 
     split_codegens = []
@@ -222,6 +223,7 @@ def parse_target(target):
     for codegen_def in split_codegens:
         # the first is expected to be the name
         name = codegen_def[0]
+        is_tvm_target = name in tvm_target_kinds
         raw_target = " ".join(codegen_def)
         all_opts = codegen_def[1:] if len(codegen_def) > 1 else []
         opts = {}
@@ -244,7 +246,9 @@ def parse_target(target):
 
             opts[opt_name] = opt_value
 
-        codegens.append({"name": name, "opts": opts, "raw": raw_target})
+        codegens.append(
+            {"name": name, "opts": opts, "raw": raw_target, "is_tvm_target": 
is_tvm_target}
+        )
 
     return codegens
 
@@ -295,10 +299,21 @@ def target_from_cli(target):
             raise TVMCException(f"Error parsing target string '{target}'.\nThe 
error was: {ex}")
 
         validate_targets(parsed_targets)
-        target = parsed_targets[-1]["raw"]
-        extra_targets = parsed_targets[:-1] if len(parsed_targets) > 1 else []
+        tvm_targets = [t for t in parsed_targets if t["is_tvm_target"]]
+
+        # Validated target strings have 1 or 2 tvm targets, otherwise
+        # `validate_targets` above will fail.
+        if len(tvm_targets) == 1:
+            target = tvm_targets[0]["raw"]
+            target_host = None
+        else:
+            assert len(tvm_targets) == 2
+            target = tvm_targets[0]["raw"]
+            target_host = tvm_targets[1]["raw"]
+
+        extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]]
 
-    return tvm.target.Target(target), extra_targets
+    return tvm.target.Target(target, host=target_host), extra_targets
 
 
 def tracker_host_port_from_cli(rpc_tracker_str):
diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py
index 078076b..476fac5 100644
--- a/tests/python/driver/tvmc/test_tvmc_common.py
+++ b/tests/python/driver/tvmc/test_tvmc_common.py
@@ -192,6 +192,11 @@ def test_target_from_cli__error_duplicate():
         _ = tvmc.common.target_from_cli("llvm, llvm")
 
 
+def test_target_invalid_more_than_two_tvm_targets():
+    with pytest.raises(TVMCException):
+        _ = tvmc.common.target_from_cli("cuda, opencl, llvm")
+
+
 def test_target_from_cli__error_target_not_found():
     with pytest.raises(TVMCException):
         _ = tvmc.common.target_from_cli("invalidtarget")
@@ -202,6 +207,18 @@ def test_target_from_cli__error_no_tvm_target():
         _ = tvmc.common.target_from_cli("ethos-n77")
 
 
+def test_target_two_tvm_targets():
+    tvm_target, extra_targets = tvmc.common.target_from_cli(
+        "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu"
+    )
+
+    assert "opencl" in str(tvm_target)
+    assert "llvm" in str(tvm_target.host)
+
+    # No extra targets
+    assert 0 == len(extra_targets)
+
+
 def test_tokenize_target_with_opts():
     tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar 
-opt2=value2")
     expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar", 
"-opt2=value2"]

Reply via email to