This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 0a1e160  [Topi][Cuda]Optimizations of global_ave_pool for NHWC layout (#5450)
0a1e160 is described below

commit 0a1e160134fe52085f1d9dcba1bb3dfe9439857f
Author: SXM-inspur <61525342+sxm-ins...@users.noreply.github.com>
AuthorDate: Wed Apr 29 04:44:03 2020 +0800

    [Topi][Cuda]Optimizations of global_ave_pool for NHWC layout (#5450)
    
    * Optimizations of global_ave_pool for NHWC layout
    
    * Optimize the code format to pass inspection of pylint
    
    Co-authored-by: Shawn-Inspur <wushao...@inspur.com>
---
 python/tvm/relay/op/strategy/cuda.py   |  2 +-
 topi/python/topi/cuda/pooling.py       |  8 +++--
 topi/tests/python/test_topi_pooling.py | 59 ++++++++++++++++++++++------------
 3 files changed, 46 insertions(+), 23 deletions(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index e976978..9189b5e 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -58,7 +58,7 @@ def schedule_pool_grad_cuda(attrs, outs, target):
 def schedule_adaptive_pool_cuda(attrs, outs, target):
     """schedule adaptive pooling ops for cuda"""
     with target:
-        return topi.cuda.schedule_adaptive_pool(outs)
+        return topi.cuda.schedule_adaptive_pool(outs, attrs.layout)
 
 @softmax_strategy.register(["cuda", "gpu"])
 def softmax_strategy_cuda(attrs, inputs, out_type, target):
diff --git a/topi/python/topi/cuda/pooling.py b/topi/python/topi/cuda/pooling.py
index 26c18ee..9839984 100644
--- a/topi/python/topi/cuda/pooling.py
+++ b/topi/python/topi/cuda/pooling.py
@@ -22,7 +22,7 @@ from .. import tag
 from ..util import traverse_inline
 
 
-def schedule_adaptive_pool(outs):
+def schedule_adaptive_pool(outs, layout='NCHW'):
     """Schedule for adaptive_pool.
 
     Parameters
@@ -51,8 +51,12 @@ def schedule_adaptive_pool(outs):
         else:
             Out = outs[0].op.output(0)
             s[Pool].set_scope("local")
+
         by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
-        bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
+        if layout == 'NHWC':
+            bx, tx = s[Out].split(s[Out].op.axis[3], factor=num_thread)
+        else:
+            bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
         s[Out].reorder(by, bx, ty, tx)
         s[Out].bind(ty, thread_y)
         s[Out].bind(tx, thread_x)
diff --git a/topi/tests/python/test_topi_pooling.py b/topi/tests/python/test_topi_pooling.py
index 9bdbb10..9f71a31 100644
--- a/topi/tests/python/test_topi_pooling.py
+++ b/topi/tests/python/test_topi_pooling.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
 """Test code for pooling"""
 import math
 import numpy as np
@@ -44,6 +45,7 @@ _pool_grad_schedule = {
 }
 
 def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
+    """verify function of pool"""
     iw = ih
     kw = kh
     sw = sh
@@ -76,15 +78,17 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
         for i in range(oh):
             for j in range(ow):
                 if count_include_pad:
-                    b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
+                    b_np[:, :, i, j] = \
+                            np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3))
                 else:
-                    pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3))
-                    b_np[:,:,i,j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) / np.maximum(pad_count, 1)
+                    pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2, 3))
+                    b_np[:, :, i, j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3)) \
+                                       / np.maximum(pad_count, 1)
 
-    elif pool_type =='max':
+    elif pool_type == 'max':
         for i in range(oh):
             for j in range(ow):
-                b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
+                b_np[:, :, i, j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3))
     b_np = np.maximum(b_np, 0.0)
 
     def check_device(device):
@@ -108,11 +112,11 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
 
 def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True,
                      add_relu=False):
+    """verify function of pool_grad"""
     iw = ih
     kw = kh
     sw = sh
     pt, pl, pb, pr = padding
-    layout = "NCHW"
     A = te.placeholder((n, ic, ih, iw), name='A')
     B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                      pool_type=pool_type, ceil_mode=ceil_mode,
@@ -164,6 +168,7 @@ def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_inc
         check_device(device)
 
 def test_pool():
+    """test cases of pool"""
     verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
     verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
     verify_pool(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False)
@@ -179,6 +184,7 @@ def test_pool():
     verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)
 
 def test_pool_grad():
+    """test cases of pool_grad"""
     verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False)
     verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
     verify_pool_grad(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
@@ -200,10 +206,10 @@ def test_pool_grad():
     verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True)
 
 
-def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
-
+def verify_global_pool(dshape, pool_type, layout='NCHW'):
+    """verify function of global_pool"""
     assert layout in ["NCHW", "NHWC"]
-    A = te.placeholder((n, c, h, w), name='A')
+    A = te.placeholder(shape=dshape, name='A')
     B = topi.nn.global_pool(A, pool_type=pool_type, layout=layout)
     B = topi.nn.relu(B)
 
@@ -212,7 +218,7 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
     axis = (layout.find('H'), layout.find('W'))
     if pool_type == 'avg':
         b_np = np.mean(a_np, axis=axis, keepdims=True)
-    elif pool_type =='max':
+    elif pool_type == 'max':
         b_np = np.max(a_np, axis=axis, keepdims=True)
     b_np = np.maximum(b_np, 0.0)
 
@@ -224,7 +230,10 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s_func = topi.testing.dispatch(device, _adaptive_pool_schedule)
-            s = s_func(B)
+            if device == "cuda":
+                s = s_func(B, layout)
+            else:
+                s = s_func(B)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         f = tvm.build(s, [A, B], device)
@@ -235,17 +244,19 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
         check_device(device)
 
 def test_global_pool():
-    verify_global_pool(1, 1024, 7, 7, 'avg')
-    verify_global_pool(4, 1024, 7, 7, 'avg')
-    verify_global_pool(1, 1024, 7, 7, 'max')
-    verify_global_pool(4, 1024, 7, 7, 'max')
-    verify_global_pool(1, 1024, 7, 7, 'avg', 'NHWC')
-    verify_global_pool(4, 1024, 7, 7, 'avg', 'NHWC')
-    verify_global_pool(1, 1024, 7, 7, 'max', 'NHWC')
-    verify_global_pool(4, 1024, 7, 7, 'max', 'NHWC')
+    """test cases of global_pool"""
+    verify_global_pool((1, 1024, 7, 7), 'avg')
+    verify_global_pool((4, 1024, 7, 7), 'avg')
+    verify_global_pool((1, 1024, 7, 7), 'max')
+    verify_global_pool((4, 1024, 7, 7), 'max')
+    verify_global_pool((1, 7, 7, 1024), 'avg', 'NHWC')
+    verify_global_pool((4, 7, 7, 1024), 'avg', 'NHWC')
+    verify_global_pool((1, 7, 7, 1024), 'max', 'NHWC')
+    verify_global_pool((4, 7, 7, 1024), 'max', 'NHWC')
 
 
 def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
+    """verify function of adaptive_pool"""
     np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
     np_out = topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)
     oshape = np_out.shape
@@ -265,7 +276,10 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="floa
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s_func = topi.testing.dispatch(device, _adaptive_pool_schedule)
-            s = s_func(out)
+            if device == "cuda":
+                s = s_func(out, layout)
+            else:
+                s = s_func(out)
         a = tvm.nd.array(np_data, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(oshape), dtype=out.dtype), ctx)
         f = tvm.build(s, [data, out], device)
@@ -277,6 +291,7 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="floa
 
 
 def test_adaptive_pool():
+    """test cases of adaptive_pool"""
     verify_adaptive_pool((1, 3, 224, 224), (1, 1), "max")
     verify_adaptive_pool((1, 3, 224, 224), (1, 1), "avg")
     verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max")
@@ -295,6 +310,7 @@ def test_adaptive_pool():
 
 def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
                   ceil_mode, count_include_pad=True, layout='NCDHW'):
+    """verify function of pool3d"""
     id = iw = ih
     kd = kw = kh
     sd = sw = sh
@@ -334,6 +350,7 @@ def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
 
 
 def test_pool3d():
+    """test cases of pool3d"""
     verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'avg', False, True)
     verify_pool3d(1, 256, 31, 3, 3, [1, 1, 2, 2, 2, 1], 'avg', False, True)
     verify_pool3d(1, 256, 32, 2, 2, [1, 1, 2, 2, 2, 1], 'avg', False, False)
@@ -351,6 +368,7 @@ def test_pool3d():
 
 def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type,
                   ceil_mode, count_include_pad=True, layout='NCW'):
+    """verify function of pool1d"""
     input_shape = (n, ic, iw)
     kernel = [kw]
     stride = [sw]
@@ -387,6 +405,7 @@ def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type,
 
 
 def test_pool1d():
+    """test cases of pool1d"""
     verify_pool1d(1, 256, 32, 2, 2, [0, 0], 'avg', False, True)
     verify_pool1d(1, 256, 31, 3, 3, [1, 2], 'avg', False, True)
     verify_pool1d(1, 256, 32, 2, 2, [1, 2], 'avg', False, False)
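
For reference, a minimal usage sketch of the layout-aware schedule introduced by this commit, mirroring the updated verify_global_pool test above. The shape, pool type, and target are illustrative assumptions, not part of the commit; it assumes a CUDA-enabled TVM build of this vintage with the standalone topi package available.

    # Illustrative sketch only: global average pooling on an NHWC tensor,
    # scheduled with the new layout argument to topi.cuda.schedule_adaptive_pool.
    import numpy as np
    import tvm
    from tvm import te
    import topi

    dshape = (1, 7, 7, 1024)                      # NHWC input (assumed shape)
    A = te.placeholder(shape=dshape, name='A')
    B = topi.nn.relu(topi.nn.global_pool(A, pool_type='avg', layout='NHWC'))

    with tvm.target.create('cuda'):
        # With layout='NHWC' the schedule splits the channel axis (axis 3)
        # instead of axis 1 as it does for NCHW.
        s = topi.cuda.schedule_adaptive_pool(B, 'NHWC')

    ctx = tvm.gpu(0)
    a = tvm.nd.array(np.random.uniform(size=dshape).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros((1, 1, 1, 1024), dtype=B.dtype), ctx)
    f = tvm.build(s, [A, B], 'cuda')
    f(a, b)

    # Reference result: mean over the spatial axes (H, W), then ReLU.
    ref = np.maximum(np.mean(a.asnumpy(), axis=(1, 2), keepdims=True), 0.0)
    np.testing.assert_allclose(b.asnumpy(), ref, rtol=1e-5)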
