This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 155f669 [TVMC] Fix tvmc compile to extract target and target_host
from --target (#8176)
155f669 is described below
commit 155f669849f10ca4128b2cbd3feb7d4511538c02
Author: Leandro Nunes <[email protected]>
AuthorDate: Thu Jun 3 17:20:59 2021 +0100
[TVMC] Fix tvmc compile to extract target and target_host from --target
(#8176)
* [TVMC] Fix tvmc compile to extract target and target_host from --target
* Relax validation to accept up to two TVM targets and
use them as target and target_host, respectively
* Update python/tvm/driver/tvmc/common.py
Co-authored-by: Cody Yu <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
---
python/tvm/driver/tvmc/common.py | 27 +++++++++++++++++++++------
tests/python/driver/tvmc/test_tvmc_common.py | 17 +++++++++++++++++
2 files changed, 38 insertions(+), 6 deletions(-)
diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py
index 34f59aa..48e18fb 100644
--- a/python/tvm/driver/tvmc/common.py
+++ b/python/tvm/driver/tvmc/common.py
@@ -97,10 +97,10 @@ def validate_targets(parse_targets):
)
tvm_targets = [t for t in targets if t in tvm_target_kinds]
- if len(tvm_targets) > 1:
+ if len(tvm_targets) > 2:
verbose_tvm_targets = ", ".join(tvm_targets)
raise TVMCException(
- "Only one of the following targets can be used at a time. "
+ "Only two of the following targets can be used at a time. "
f"Found: {verbose_tvm_targets}."
)
@@ -199,6 +199,7 @@ def parse_target(target):
"""
codegens = []
+ tvm_target_kinds = tvm.target.Target.list_kinds()
parsed_tokens = tokenize_target(target)
split_codegens = []
@@ -222,6 +223,7 @@ def parse_target(target):
for codegen_def in split_codegens:
# the first is expected to be the name
name = codegen_def[0]
+ is_tvm_target = name in tvm_target_kinds
raw_target = " ".join(codegen_def)
all_opts = codegen_def[1:] if len(codegen_def) > 1 else []
opts = {}
@@ -244,7 +246,9 @@ def parse_target(target):
opts[opt_name] = opt_value
- codegens.append({"name": name, "opts": opts, "raw": raw_target})
+ codegens.append(
+ {"name": name, "opts": opts, "raw": raw_target, "is_tvm_target":
is_tvm_target}
+ )
return codegens
@@ -295,10 +299,21 @@ def target_from_cli(target):
raise TVMCException(f"Error parsing target string '{target}'.\nThe
error was: {ex}")
validate_targets(parsed_targets)
- target = parsed_targets[-1]["raw"]
- extra_targets = parsed_targets[:-1] if len(parsed_targets) > 1 else []
+ tvm_targets = [t for t in parsed_targets if t["is_tvm_target"]]
+
+ # Validated target strings have 1 or 2 tvm targets, otherwise
+ # `validate_targets` above will fail.
+ if len(tvm_targets) == 1:
+ target = tvm_targets[0]["raw"]
+ target_host = None
+ else:
+ assert len(tvm_targets) == 2
+ target = tvm_targets[0]["raw"]
+ target_host = tvm_targets[1]["raw"]
+
+ extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]]
- return tvm.target.Target(target), extra_targets
+ return tvm.target.Target(target, host=target_host), extra_targets
def tracker_host_port_from_cli(rpc_tracker_str):
diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py
index 078076b..476fac5 100644
--- a/tests/python/driver/tvmc/test_tvmc_common.py
+++ b/tests/python/driver/tvmc/test_tvmc_common.py
@@ -192,6 +192,11 @@ def test_target_from_cli__error_duplicate():
_ = tvmc.common.target_from_cli("llvm, llvm")
+def test_target_invalid_more_than_two_tvm_targets():
+ with pytest.raises(TVMCException):
+ _ = tvmc.common.target_from_cli("cuda, opencl, llvm")
+
+
def test_target_from_cli__error_target_not_found():
with pytest.raises(TVMCException):
_ = tvmc.common.target_from_cli("invalidtarget")
@@ -202,6 +207,18 @@ def test_target_from_cli__error_no_tvm_target():
_ = tvmc.common.target_from_cli("ethos-n77")
+def test_target_two_tvm_targets():
+ tvm_target, extra_targets = tvmc.common.target_from_cli(
+ "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu"
+ )
+
+ assert "opencl" in str(tvm_target)
+ assert "llvm" in str(tvm_target.host)
+
+ # No extra targets
+ assert 0 == len(extra_targets)
+
+
def test_tokenize_target_with_opts():
tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar
-opt2=value2")
expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar",
"-opt2=value2"]