zxybazh commented on a change in pull request #7534:
URL: https://github.com/apache/tvm/pull/7534#discussion_r583829327
##########
File path: tests/python/unittest/test_target_target.py
##########
@@ -216,7 +216,35 @@ def test_target_host_warning():
attributes fails as expected.
"""
with pytest.raises(ValueError):
- tgt = tvm.target.Target("cuda --host nvidia/jetson-nano", "llvm")
+ tvm.target.Target("cuda --host nvidia/jetson-nano", "llvm")
+
+
+def test_target_host_merge_0():
+ tgt = tvm.target.Target(tvm.target.Target("cuda --host nvidia/jetson-nano"), None)
+ assert tgt.kind.name == "cuda"
+ assert tgt.host.kind.name == "cuda"
+ assert tgt.host.attrs["arch"] == "sm_53"
+ assert tgt.host.attrs["shared_memory_per_block"] == 49152
+ assert tgt.host.attrs["max_threads_per_block"] == 1024
+ assert tgt.host.attrs["thread_warp_size"] == 32
+ assert tgt.host.attrs["registers_per_block"] == 32768
+
+
+def test_target_host_merge_1():
+ tgt = tvm.target.Target("cuda --host llvm")
+ tgt = tvm.target.Target(tgt, tgt.host)
+ assert tgt.kind.name == "cuda"
+ assert tgt.host.kind.name == "llvm"
+
+
+def test_target_host_merge_2():
+ with pytest.raises(ValueError):
+ tvm.target.Target(tvm.target.Target("cuda --host llvm"), tvm.target.Target("llvm"))
+
+
+def test_target_host_merge_3():
+ with pytest.raises(ValueError):
+ tvm.target.Target(tvm.target.Target("cuda --host llvm"), 12.34)
if __name__ == "__main__":
Review comment:
Wow, works like a charm. Good point!
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]