This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 57b42a8a62 [Unity][BYOC] Update testcases to follow recent changes (#14339)
57b42a8a62 is described below
commit 57b42a8a625320aacb3dfcf5c1b42cc6a4edf13f
Author: Sunghyun Park <[email protected]>
AuthorDate: Mon Mar 20 07:00:24 2023 -0700
[Unity][BYOC] Update testcases to follow recent changes (#14339)
This PR updates test cases to follow recent changes
---
tests/python/relax/test_transform_codegen_pass.py | 23 ++++++++---------------
1 file changed, 8 insertions(+), 15 deletions(-)
diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py
index 29ce7fd28f..77756dc664 100644
--- a/tests/python/relax/test_transform_codegen_pass.py
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -19,7 +19,7 @@ import pytest
import os
import tvm
import tvm.testing
-from tvm import relax
+from tvm import relax, tir
import numpy as np
from tvm.script import relax as R
from tvm.relax.testing import transform
@@ -43,7 +43,7 @@ has_tensorrt_runtime = pytest.mark.skipif(
pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime]
# Target gpu
-target_str = "nvidia/geforce-rtx-3070" # "nvidia/nvidia-t4"
+target_str = "nvidia/nvidia-t4"
target = tvm.target.Target(target_str)
dev = tvm.cuda()
@@ -58,7 +58,7 @@ def check_roundtrip(exec0, dev, inputs, expected):
exec0.mod.export_library("exec.so")
exec1 = tvm.runtime.load_module("exec.so")
os.remove("exec.so")
- assert exec0.stats() == exec1["stats"]
+ assert exec0.stats() == exec1["stats"]()
assert exec0.as_text() == exec1["as_text"]()
check_executable(exec0, dev, inputs, expected)
@@ -68,18 +68,11 @@ def check_roundtrip(exec0, dev, inputs, expected):
def gen_ground_truth(mod, target, dev, inputs):
# Lower and run tuning
# Since there is no default schedule for GPU in MS yet, this is necessary
- with tempfile.TemporaryDirectory() as work_dir:
- with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0):
- seq = tvm.transform.Sequential(
- [
- relax.transform.LegalizeOps(),
- relax.transform.MetaScheduleTuneIRMod(
- params={}, work_dir=work_dir, max_trials_global=8
- ),
- relax.transform.MetaScheduleApplyDatabase(work_dir),
- ]
- )
- new_mod = seq(mod)
+ with target:
+ seq = tvm.transform.Sequential(
+ [relax.transform.LegalizeOps(), tir.transform.DefaultGPUSchedule()]
+ )
+ new_mod = seq(mod)
assert relax.analysis.well_formed(new_mod)
exec = relax.build(new_mod, target, params={})
vm = relax.VirtualMachine(exec, dev)