Hzfengsy commented on a change in pull request #5498:
URL: https://github.com/apache/incubator-tvm/pull/5498#discussion_r418915913



##########
File path: tests/python/integration/test_reduce.py
##########
@@ -338,6 +338,102 @@ def check_target(device):
     check_target("cuda")
     check_target("vulkan")
 
+def test_warp_reduction1():
+    m = 32
+    n = 128
+    A = te.placeholder((m, n), name='A')
+    k = te.reduce_axis((0, n))
+    B = te.compute((m,), lambda i: te.max(A[i][k], axis=k), name='B')
+
+    nthx = 32
+    nthy = 4
+    block_x = te.thread_axis("blockIdx.x")
+    thread_x = te.thread_axis((0, nthx), "threadIdx.x")
+    thread_y = te.thread_axis((0, nthy), "threadIdx.y")
+    s = te.create_schedule(B.op)
+
+    def check_target(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("skip because %s is not enabled.." % device)
+            return
+
+        # schedule
+        k = s[B].op.reduce_axis[0]
+        ko, _ = s[B].split(k, nparts=nthx)
+        s[B].bind(ko, thread_x)
+        xo, xi = s[B].split(s[B].op.axis[0], factor=nthy)
+        s[B].bind(xi, thread_y)
+        s[B].bind(xo, block_x)
+
+        # validation.
+        func = tvm.build(s, [A, B], "cuda", name="warp_reduction")
+        a_np = np.random.uniform(size=(m,n)).astype(A.dtype)
+        b_np = np.zeros((m,), dtype=A.dtype)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        b_np = np.max(a_np, axis=1)
+        func(a, b)
+        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)
+
+    check_target("cuda")
+
+def test_warp_reduction2():
+    def fcombine(x, y):
+        return x[0] + y[0], x[1] * y[1]
+
+    def fidentity(t0, t1):
+        return tvm.tir.const(0, t0), tvm.tir.const(1, t1)
+
+    add_mul_reducer = te.comm_reducer(fcombine, fidentity, name='add_mul_reducer')
+
+    # compute
+    m = 16
+    n = 256
+    A0 = te.placeholder((m, n), name='A0', dtype='float32')
+    A1 = te.placeholder((m, n), name='Al', dtype='float32')
+    k = te.reduce_axis((0, n), 'k')
+    T0, T1 = te.compute((m, ), lambda i: \
+        add_mul_reducer((A0[i, k], A1[i, k]), axis=k), name='T')
+
+    nthdx, nthdy = 32, 2
+    block_x = te.thread_axis("blockIdx.x")
+    thread_x = te.thread_axis((0, nthdx), "threadIdx.x")
+    thread_y = te.thread_axis((0, nthdy), "threadIdx.y")
+
+    def check_target(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:

Review comment:
       Exactly. We should keep both the old approach and the new optimization to support different CUDA versions and devices.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to