This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 972d7b5  [Relay] Add support of conv2d with NHWC for Bifrost (#8430)
972d7b5 is described below

commit 972d7b52d2dbdd7cc1db98c3af04b04c4fc31b75
Author: Anastasia Stulova <[email protected]>
AuthorDate: Sat Jul 10 07:51:07 2021 +0100

    [Relay] Add support of conv2d with NHWC for Bifrost (#8430)
    
    Reuse generic Mali strategy for conv2d with NHWC in
    Bifrost target.
---
 python/tvm/relay/op/strategy/bifrost.py           | 8 ++++++++
 tests/python/topi/python/test_topi_conv2d_nhwc.py | 4 ++++
 2 files changed, 12 insertions(+)

diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py
index 24e68a4..8008391 100644
--- a/python/tvm/relay/op/strategy/bifrost.py
+++ b/python/tvm/relay/op/strategy/bifrost.py
@@ -65,6 +65,14 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target):
                     wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack),
                     name="conv2d_nchw_spatial_pack.bifrost",
                 )
+        elif layout == "NHWC":
+            assert kernel_layout == "HWIO"
+            # For now just reuse general Mali strategy.
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack),
+                wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack),
+                name="conv2d_nhwc_spatial_pack.bifrost",
+            )
         else:
             raise RuntimeError("Unsupported conv2d layout {} for Mali(Bifrost)".format(layout))
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py
index 1a80b8e..eb4c5a3 100644
--- a/tests/python/topi/python/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py
@@ -38,6 +38,10 @@ _conv2d_nhwc_implement = {
         topi.mali.conv2d_nhwc_spatial_pack,
         topi.mali.schedule_conv2d_nhwc_spatial_pack,
     ),
+    "bifrost": (
+        topi.mali.conv2d_nhwc_spatial_pack,
+        topi.mali.schedule_conv2d_nhwc_spatial_pack,
+    ),
     "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc),
 }
 

Reply via email to