This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 972d7b5 [Relay] Add support of conv2d with NHWC for Bifrost (#8430)
972d7b5 is described below
commit 972d7b52d2dbdd7cc1db98c3af04b04c4fc31b75
Author: Anastasia Stulova <[email protected]>
AuthorDate: Sat Jul 10 07:51:07 2021 +0100
[Relay] Add support of conv2d with NHWC for Bifrost (#8430)
Reuse the generic Mali strategy for NHWC conv2d on the
Bifrost target.
---
python/tvm/relay/op/strategy/bifrost.py | 8 ++++++++
tests/python/topi/python/test_topi_conv2d_nhwc.py | 4 ++++
2 files changed, 12 insertions(+)
diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py
index 24e68a4..8008391 100644
--- a/python/tvm/relay/op/strategy/bifrost.py
+++ b/python/tvm/relay/op/strategy/bifrost.py
@@ -65,6 +65,14 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.bifrost",
)
+ elif layout == "NHWC":
+ assert kernel_layout == "HWIO"
+ # For now just reuse general Mali strategy.
+ strategy.add_implementation(
+ wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack),
+
wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack),
+ name="conv2d_nhwc_spatial_pack.bifrost",
+ )
else:
raise RuntimeError("Unsupported conv2d layout {} for
Mali(Bifrost)".format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout,
groups):
diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py
index 1a80b8e..eb4c5a3 100644
--- a/tests/python/topi/python/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py
@@ -38,6 +38,10 @@ _conv2d_nhwc_implement = {
topi.mali.conv2d_nhwc_spatial_pack,
topi.mali.schedule_conv2d_nhwc_spatial_pack,
),
+ "bifrost": (
+ topi.mali.conv2d_nhwc_spatial_pack,
+ topi.mali.schedule_conv2d_nhwc_spatial_pack,
+ ),
"hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc),
}