This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 37f6aa0c7e [MetaSchedule] Fix tensorcore winograd task extraction (#13625)
37f6aa0c7e is described below
commit 37f6aa0c7e1610a9635550e120f14abbdda8df48
Author: masahi <[email protected]>
AuthorDate: Fri Dec 16 11:08:43 2022 +0900
[MetaSchedule] Fix tensorcore winograd task extraction (#13625)
* [MetaSchedule] Fix tensorcore winograd task extraction
* add test
* fixed target
---
python/tvm/relay/op/strategy/cuda.py | 2 ++
.../unittest/test_meta_schedule_relay_integration.py | 18 ++++++++++++++++++
2 files changed, 20 insertions(+)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 312ec0fe2f..cc43809266 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -261,6 +261,8 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
)
if (
target.kind.name == "cuda"
+ and not is_auto_scheduler_enabled()
+ and not is_meta_schedule_enabled()
and nvcc.have_tensorcore(target=target)
and (
(N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 062da0b00c..604f337099 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -108,6 +108,24 @@ def test_meta_schedule_integration_extract_from_resnet():
assert t.task_name in expected_task_names, t.task_name
+@requires_torch
+def test_task_extraction_winograd_tensorcore():
+ mod, params, _ = get_network(name="resnet_50", input_shape=[16, 3, 224,
224])
+ seq = tvm.transform.Sequential(
+ [
+ relay.transform.ToMixedPrecision("float16"),
+ relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "HWIO"]}),
+ ]
+ )
+ with tvm.transform.PassContext(opt_level=3):
+ mod = seq(mod)
+
+ target = tvm.target.Target("nvidia/geforce-rtx-3070")
+ extracted_tasks = ms.relay_integration.extract_tasks(mod, target=target,
params=params)
+
+ assert len([t for t in extracted_tasks if "winograd" in t.task_name]) == 4
+
+
@requires_torch
def test_task_extraction_anchor_block():
mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224,
224])