This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new b3b2705  [topi] fix sparse dense schedule on cuda (#5803)
b3b2705 is described below

commit b3b27057dfb902de612dacdcd6b4b9c24e119abf
Author: Zijing Gu <jingjing...@live.com>
AuthorDate: Sun Jun 14 17:40:20 2020 -0400

    [topi] fix sparse dense schedule on cuda (#5803)
---
 topi/python/topi/cuda/sparse.py       |  5 +++++
 topi/tests/python/test_topi_sparse.py | 12 ++++++++++--
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/topi/python/topi/cuda/sparse.py b/topi/python/topi/cuda/sparse.py
index 037eea4..fb875b7 100644
--- a/topi/python/topi/cuda/sparse.py
+++ b/topi/python/topi/cuda/sparse.py
@@ -69,6 +69,11 @@ def schedule_sparse_dense(cfg, outs):
             y_bsrmm = op.input_tensors[0]
             assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
             out = s.outputs[0].output(0)
+
+            if op not in s.outputs:
+                y_reshape = op.output(0)
+                s[y_reshape].compute_at(s[out], s[out].op.axis[1])
+
             (_, c) = s[y_bsrmm].op.reduce_axis
 
             (m_o, n_o) = s[out].op.axis
diff --git a/topi/tests/python/test_topi_sparse.py b/topi/tests/python/test_topi_sparse.py
index 3290fc0..748181d 100644
--- a/topi/tests/python/test_topi_sparse.py
+++ b/topi/tests/python/test_topi_sparse.py
@@ -288,12 +288,13 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
     assert s.indptr.shape == (M // BS_R + 1, )
     return s
 
-def test_sparse_dense_bsr():
-    M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9
+def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu):
     X_np = np.random.randn(M, K).astype("float32")
     W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
     W_np = W_sp_np.todense()
     Y_np = X_np.dot(W_np.T)
+    if use_relu:
+        Y_np = np.maximum(Y_np, 0.0)
 
     W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype))
     W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
@@ -309,6 +310,8 @@ def test_sparse_dense_bsr():
         fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement)
         with tvm.target.create(device):
             Y = fcompute(X, W_data, W_indices, W_indptr)
+            if use_relu:
+                Y = topi.nn.relu(Y)
             s = fschedule([Y])
             func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
             Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
@@ -322,6 +325,11 @@ def test_sparse_dense_bsr():
     for device in ['llvm', 'cuda']:
         check_device(device)
 
+def test_sparse_dense_bsr():
+    M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9
+    verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=True)
+    verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=False)
+
 def test_sparse_dense_bsr_randomized():
     for _ in range(20):
         BS_R = np.random.randint(1, 16)

Reply via email to