This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch sparse
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/sparse by this push:
     new 0c1e53e  Extending the GPU dot operator (#7226)
0c1e53e is described below

commit 0c1e53edcdae0fbb508feb232b3f1a0f1530c271
Author: Stefan Henneking <stefan.hennek...@gmail.com>
AuthorDate: Wed Aug 9 11:34:39 2017 -0700

    Extending the GPU dot operator (#7226)
    
    * Added GPU DotCsrRspDnsImpl declaration and TODOs
    
    * cleaning up function doc, variable types, and code-style
    
    * minor bug fixes
    
    * enable GPU dot(csr,rsp)=dns unit test
    
    * extend sparse dot unit test
    
    * adding GPU impl of DotCsrRspDns and its kernels
    
    * add TODO
    
    * changed variable types from index_t to dim_t
    
    * fix function description
    
    * added DotCsrRspRspImpl and its kernels (baseline, functionality)
    
    * added DotCsrDnsRspImpl and its kernels (baseline, functionality); plus code documentation
    
    * refactored dot benchmark
    
    * optimized DotCsrTransDnsRsp GPU kernel
    
    * change of dot impl interface to include OpContext, for temp storage
    
    * removing __device__ flag from CPU kernels
    
    * minor fixes and changing variable data types
    
    * minor fixes based on code reviews
---
 benchmark/python/dot.py                       | 263 ++++++++
 benchmark/python/sparse_op.py                 | 228 -------
 src/operator/tensor/dot-inl.cuh               | 832 ++++++++++++++++++++++----
 src/operator/tensor/dot-inl.h                 | 332 +++++-----
 tests/python/gpu/test_operator_gpu.py         |   2 +-
 tests/python/unittest/test_sparse_operator.py |  54 +-
 6 files changed, 1184 insertions(+), 527 deletions(-)
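
For context, the operator combinations this change brings to the GPU can be exercised roughly as
follows. This is an illustrative sketch only, using helpers that the new benchmark below also
imports (rand_ndarray and set_default_context from mxnet.test_utils); the output storage types
follow the kernel headers in dot-inl.cuh and the benchmark banners.

    import mxnet as mx
    from mxnet.test_utils import rand_ndarray, set_default_context

    set_default_context(mx.gpu(0))  # the new kernels target the GPU context

    csr = rand_ndarray((512, 50000), 'csr', density=0.01)            # lhs, shape (m, k)
    dns_kn = mx.nd.random_uniform(low=0, high=1, shape=(50000, 64))  # dense rhs, shape (k, n)
    dns_mn = mx.nd.random_uniform(low=0, high=1, shape=(512, 64))    # dense rhs, shape (m, n)
    rsp_kn = rand_ndarray((50000, 64), 'row_sparse', density=0.05)   # row-sparse rhs, shape (k, n)

    out1 = mx.nd.dot(csr, dns_kn)                    # dot(csr, dns)   = dns
    out2 = mx.nd.dot(csr, dns_mn, transpose_a=True)  # dot(csr.T, dns) = rsp
    out3 = mx.nd.dot(csr, rsp_kn)                    # dot(csr, rsp)   = dns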

diff --git a/benchmark/python/dot.py b/benchmark/python/dot.py
new file mode 100644
index 0000000..2e5821a
--- /dev/null
+++ b/benchmark/python/dot.py
@@ -0,0 +1,263 @@
+import ctypes
+
+from mxnet.test_utils import *
+import scipy.sparse as sp
+import os
+import time
+import argparse
+
+from mxnet.base import check_call, _LIB
+from util import get_data, estimate_density
+
+parser = argparse.ArgumentParser(description="Benchmark sparse operators",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
+args = parser.parse_args()
+
+# some data information
+kdda = {
+    'data_mini': 'kdda.t.mini',
+    'data_name': 'kdda.t',
+    'data_origin_name': 'kdda.t.bz2',
+    'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
+    'feature_dim': 20216830,
+    'm': 200,
+    'batch_size': [64]
+}
+
+avazu = {
+    'data_mini': 'avazu-app.t.mini',
+    'data_name': 'avazu-app.t',
+    'data_origin_name': 'avazu-app.t.bz2',
+    'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
+    'feature_dim': 1000000,
+    'm': 500,
+    'batch_size': [64, 128]
+}
+
+
+def measure_cost(repeat, f, *args, **kwargs):
+    mx.nd.waitall()
+    start = time.time()
+    for i in range(repeat):
+        f(*args, **kwargs)
+    mx.nd.waitall()
+    end = time.time()
+    diff = end - start
+    return diff / repeat
+
+
+def test_dot_real(data_dict):
+    def get_iter(path, data_shape, batch_size):
+        data_train = mx.io.LibSVMIter(data_libsvm=path,
+                                      data_shape=data_shape,
+                                      batch_size=batch_size)
+        data_iter = iter(data_train)
+        return data_iter
+
+    data_dir = os.path.join(os.getcwd(), 'data')
+
+    path = os.path.join(data_dir, data_dict['data_name'])
+    if not os.path.exists(path):
+        get_data(
+            data_dir,
+            data_dict['data_name'],
+            data_dict['url'],
+            data_dict['data_origin_name']
+        )
+        assert os.path.exists(path)
+    
+    k = data_dict['feature_dim']
+    m = data_dict['m']
+    density = estimate_density(path, data_dict['feature_dim'])
+
+    mini_path = os.path.join(data_dir, data_dict['data_mini'])
+    if not os.path.exists(mini_path):
+        os.system("head -n 2000 %r > %r" % (path, mini_path))
+        assert os.path.exists(mini_path)
+    
+    print "Running Benchmarking on %r data" % data_dict['data_mini']
+    for batch_size in data_dict['batch_size']:  # iterate over the batch sizes of choice
+        print "batch_size is %d" % batch_size
+        # model
+        data_shape = (k, )
+        train_iter = get_iter(mini_path, data_shape, batch_size)
+        weight = mx.nd.random_uniform(low=0, high=1, shape=(k, m))
+
+        csr_data = []
+        dns_data = []
+        num_batch = 0
+        for batch in train_iter:
+            data = train_iter.getdata()
+            csr_data.append(data)
+            dns_data.append(data.todense())
+            num_batch += 1
+        bag_of_data = [csr_data, dns_data]
+        num_repeat = 5
+        costs = []
+        for d in bag_of_data:
+            weight.wait_to_read()
+            cost = 0.
+            count = 0
+            for d_batch in d:
+                d_batch.wait_to_read()
+                cost += measure_cost(num_repeat, mx.nd.dot, d_batch, weight)
+                count += 1
+            costs.append(cost/count)
+        t_sparse = costs[0]
+        t_dense = costs[1]
+        ratio = t_dense / t_sparse
+        print('density(%)\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse')
+        fmt = "%0.4f\t\t%d\t%d\t%d\t%0.2f\t\t\t%0.4f\t%0.6f"
+        print(fmt % (density * 100, batch_size, m, k, ratio, t_dense, t_sparse))
+
+
+def test_dot_synthetic():
+    """benchmark sparse mxnet dot and scipy dot operator with matrices of 
given density.
+    `t_sparse` is the runtime of the invoked sparse dot operator in ms, while 
`t_dense` is the 
+    runtime of dot(dns, dns), with the same matrices except that they are in 
default storage type.
+    """
+    # Benchmark MXNet's sparse dot operator
+    def bench_mx_dot(lhs_shape, rhs_shape, lhs_stype, rhs_stype, lhs_den, rhs_den, trans_lhs, ctx, repeat):
+        set_default_context(ctx)
+        # Create matrix instances
+        lhs_nd = rand_ndarray(lhs_shape, lhs_stype, density=lhs_den)
+        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den)
+        lhs_dns = lhs_nd if lhs_stype == 'default' else lhs_nd.todense()
+        rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
+        # One warm up run, verify correctness
+        out = mx.nd.dot(lhs_nd, rhs_dns, trans_lhs)
+        out_expected = mx.nd.dot(lhs_dns, rhs_dns, trans_lhs)
+        assert_almost_equal(out.asnumpy(), out_expected.asnumpy(), rtol=1e-2, atol=1e-3)
+        # Start benchmarking
+        lhs_nd.wait_to_read()
+        rhs_nd.wait_to_read()
+        sparse_cost = measure_cost(repeat, mx.nd.dot, lhs_nd, rhs_nd, trans_lhs)
+        dense_cost = measure_cost(repeat, mx.nd.dot, lhs_dns, rhs_dns, trans_lhs)
+        speedup = dense_cost / sparse_cost
+        # Print results
+        m = lhs_shape[0]
+        k = lhs_shape[1]
+        n = rhs_shape[1]
+        results = '{:15.1f} {:15.1f} {:>10} {:8d} {:8d} {:8d} {:13.2f} {:13.2f} {:8.2f}'.format(lhs_den*100, rhs_den*100, str(ctx), m, k, n, sparse_cost*1000, dense_cost*1000, speedup)
+        print(results)
+
+    # Benchmark Scipy's sparse dot operator
+    def bench_sp_dot(lhs_shape, rhs_shape, lhs_stype, rhs_stype, lhs_den, rhs_den, trans_lhs, ctx, repeat):
+        set_default_context(ctx)
+        assert default_context().device_type == 'cpu'
+        assert lhs_stype == 'csr'
+        assert rhs_stype == 'default'
+        # Create matrix instances
+        lhs_nd = rand_ndarray(lhs_shape, lhs_stype, density=lhs_den)
+        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den)
+        lhs_nd.wait_to_read()
+        rhs_nd.wait_to_read()
+        lhs_dns_np = np.transpose(lhs_nd.asnumpy()) if trans_lhs else lhs_nd.asnumpy()
+        rhs_dns_np = rhs_nd.asnumpy()
+        lhs_csr_sp = sp.spmatrix.transpose(sp.csr_matrix(lhs_nd.asnumpy())) if trans_lhs else sp.csr_matrix(lhs_nd.asnumpy())
+        # One warm up run
+        out = sp.spmatrix.dot(lhs_csr_sp, rhs_dns_np)
+        # Start benchmarking
+        sparse_cost = measure_cost(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np)
+        dense_cost = measure_cost(repeat, np.dot, lhs_dns_np, rhs_dns_np)
+        speedup = dense_cost / sparse_cost
+        # Print results
+        m = lhs_shape[0]
+        k = lhs_shape[1]
+        n = rhs_shape[1]
+        results = '{:15.1f} {:15.1f} {:>10} {:8d} {:8d} {:8d} {:13.2f} {:13.2f} {:8.2f}'.format(lhs_den*100, rhs_den*100, str(ctx), m, k, n, sparse_cost*1000, dense_cost*1000, speedup)
+        print(results)
+
+    check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
+    # TODO(haibin): make these runtime options
+    # params
+    # m, n, k        rows and columns of lhs and rhs matrix
+    #                forward  pass:  m x k    * k x n = m x n
+    #                backward pass: (m x k)^T * m x n = k x n
+    # density_lhs    density of the left-hand side matrix
+    # density_rhs    density of the right-hand side matrix, if applicable
+    # num_repeat     number of benchmark runs to average over
+    # context        mx.cpu(), mx.gpu()
+    #                note: benchmark different contexts separately; to benchmark cpu, compile without CUDA
+    # mx_benchmarks  csr_dns, csr.T_dns, csr_rsp
+    # sp_benchmarks  csr_dns, csr.T_dns
+    #                note: scipy benchmarks are only conducted if context is mx.cpu()
+    m = 512
+    k = [50000, 100000]
+    n = [64, 128]
+    density_lhs = [0.64, 0.32, 0.16, 0.08, 0.04, 0.02, 0.01]
+    density_rhs = [0.64, 0.32, 0.16, 0.08, 0.04, 0.02, 0.01]
+    num_repeat = 10
+    context = mx.gpu()
+    mx_benchmarks = ["csr_dns", "csr.T_dns", "csr_rsp"]
+    sp_benchmarks = ["csr_dns", "csr.T_dns"]
+
+    headline = '{:>15} {:>15} {:>10} {:>8} {:>8} {:>8} {:>13} {:>13} {:>8}'.format('lhs_density(%)', 'rhs_density(%)', 'context', 'm', 'k', 'n', 't_sparse(ms)', 't_dense(ms)', 'speedup')
+    if "csr_dns" in mx_benchmarks:
+        print("==================================================")
+        print("  mxnet sparse dot benchmark: dot(csr, dns) = dns ")
+        print("  (matrix multiplication: m x k * k x n = m x n)  ")
+        print("==================================================")
+        print(headline)
+        transpose_lhs = False
+        for i in range(len(n)):
+            for d_lhs in density_lhs:
+                bench_mx_dot((m, k[i]), (k[i], n[i]), 'csr', 'default', d_lhs, 1, transpose_lhs, context, num_repeat)
+            print ""
+
+    if "csr_dns" in sp_benchmarks and mx.cpu() == context:
+        print("==================================================")
+        print("  scipy sparse dot benchmark: dot(csr, dns) = dns ")
+        print("  (matrix multiplication: m x k * k x n = m x n)  ")
+        print("==================================================")
+        print(headline)
+        transpose_lhs = False
+        for i in range(len(n)):
+            for d_lhs in density_lhs:
+                bench_sp_dot((m, k[i]), (k[i], n[i]), 'csr', 'default', d_lhs, 1, transpose_lhs, context, num_repeat)
+            print ""
+
+    if "csr.T_dns" in mx_benchmarks:
+        print("==================================================")
+        print(" mxnet sparse dot benchmark: dot(csr.T, dns) = rsp")
+        print("(matrix multiplication: (m x k)^T * m x n = k x n)")
+        print("==================================================")
+        print(headline)
+        transpose_lhs = True
+        for i in range(len(n)):
+            for d_lhs in density_lhs:
+                bench_mx_dot((m, k[i]), (m, n[i]), 'csr', 'default', d_lhs, 1, transpose_lhs, context, num_repeat)
+            print ""
+
+    if "csr.T_dns" in sp_benchmarks and mx.cpu() == context:
+        print("==================================================")
+        print(" scipy sparse dot benchmark: dot(csr.T, dns) = dns")
+        print("(matrix multiplication: (m x k)^T * m x n = k x n)")
+        print("==================================================")
+        print(headline)
+        transpose_lhs = True
+        for i in range(len(n)):
+            for d_lhs in density_lhs:
+                bench_sp_dot((m, k[i]), (m, n[i]), 'csr', 'default', d_lhs, 1, transpose_lhs, context, num_repeat)
+            print ""
+
+    if "csr_rsp" in mx_benchmarks:
+        print("==================================================")
+        print("  mxnet sparse dot benchmark: dot(csr, rsp) = dns ")
+        print("  (matrix multiplication: m x k * k x n = m x n)  ")
+        print("==================================================")
+        print(headline)
+        transpose_lhs = False
+        for i in range(len(n)):
+            for d_lhs in density_lhs:
+              for d_rhs in density_rhs:
+                bench_mx_dot((m, k[i]), (k[i], n[i]), 'csr', 'row_sparse', d_lhs, d_rhs, transpose_lhs, context, num_repeat)
+              print ""
+            print ""
+
+
+if __name__ == "__main__":
+    test_dot_synthetic()
+    test_dot_real(avazu)
+    test_dot_real(kdda)
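
The new benchmark above can be driven standalone; an illustrative invocation only (the path comes
from the diffstat, the flag from the script's argparse setup; the datasets are downloaded on first
run):

    import subprocess
    # run the dot benchmark with 4 OpenMP threads (adjust to the machine at hand)
    subprocess.check_call(["python", "benchmark/python/dot.py", "--num-omp-threads", "4"])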
diff --git a/benchmark/python/sparse_op.py b/benchmark/python/sparse_op.py
deleted file mode 100644
index 15ca4df..0000000
--- a/benchmark/python/sparse_op.py
+++ /dev/null
@@ -1,228 +0,0 @@
-import ctypes
-
-from mxnet.test_utils import *
-import scipy.sparse as sp
-import os
-import time
-import argparse
-
-from mxnet.base import check_call, _LIB
-from util import get_data, estimate_density
-
-parser = argparse.ArgumentParser(description="Benchmark sparse operators",
-                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
-args = parser.parse_args()
-
-# some data information
-kdda = {
-    'data_mini': 'kdda.t.mini',
-    'data_name': 'kdda.t',
-    'data_origin_name': 'kdda.t.bz2',
-    'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
-    'feature_dim': 20216830,
-    'm': 200,
-    'batch_size': [64]
-}
-
-avazu = {
-    'data_mini': 'avazu-app.t.mini',
-    'data_name': 'avazu-app.t',
-    'data_origin_name': 'avazu-app.t.bz2',
-    'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
-    'feature_dim': 1000000,
-    'm': 500,
-    'batch_size': [64, 128]
-}
-
-
-def measure_cost(repeat, f, *args, **kwargs):
-    # start bench
-    start = time.time()
-    results = []
-    for i in range(repeat):
-        results.append(f(*args, **kwargs))
-    for result in results:
-        result.wait_to_read()
-    end = time.time()
-    diff = end - start
-    return diff / repeat
-
-
-def test_dot_real(data_dict):
-    def get_iter(path, data_shape, batch_size):
-        data_train = mx.io.LibSVMIter(data_libsvm=path,
-                                      data_shape=data_shape,
-                                      batch_size=batch_size)
-        data_iter = iter(data_train)
-        return data_iter
-
-    data_dir = os.path.join(os.getcwd(), 'data')
-
-    path = os.path.join(data_dir, data_dict['data_name'])
-    if not os.path.exists(path):
-        get_data(
-            data_dir,
-            data_dict['data_name'],
-            data_dict['url'],
-            data_dict['data_origin_name']
-        )
-        assert os.path.exists(path)
-    
-    k = data_dict['feature_dim']
-    m = data_dict['m']
-    density = estimate_density(path, data_dict['feature_dim'])
-
-    mini_path = os.path.join(data_dir, data_dict['data_mini'])
-    if not os.path.exists(mini_path):
-        os.system("head -n 2000 %r > %r" % (path, mini_path))
-        assert os.path.exists(mini_path)
-    
-    print "Running Benchmarking on %r data" % data_dict['data_mini']
-    for batch_size in data_dict['batch_size']:  # iterator through different batch size of choice
-        print "batch_size is %d" % batch_size
-        # model
-        data_shape = (k, )
-        train_iter = get_iter(mini_path, data_shape, batch_size)
-        weight = mx.nd.random_uniform(low=0, high=1, shape=(k, m))
-
-        csr_data = []
-        dns_data = []
-        num_batch = 0
-        for batch in train_iter:
-            data = train_iter.getdata()
-            csr_data.append(data)
-            dns_data.append(data.todense())
-            num_batch += 1
-        bag_of_data = [csr_data, dns_data]
-        num_repeat = 5
-        costs = []
-        for d in bag_of_data:
-            weight.wait_to_read()
-            cost = 0.
-            count = 0
-            for d_batch in d:
-                d_batch.wait_to_read()
-                cost += measure_cost(num_repeat, mx.nd.dot, d_batch, weight)
-                count += 1
-            costs.append(cost/count)
-        t_sparse = costs[0]
-        t_dense = costs[1]
-        ratio = t_dense / t_sparse
-        print('density(%)\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse')
-        fmt = "%0.4f\t\t%d\t%d\t%d\t%0.2f\t\t\t%0.4f\t%0.6f"
-        print(fmt % (density * 100, batch_size, m, k, ratio, t_dense, t_sparse))
-
-
-def test_dot_synthetic():
-    """benchmark mx.nd.dot(sparse_ndarray, dense_ndarray) with given density.
-    `t_sparse` is the time cost of dot(csr, dns), while `t_dense` is the time cost
-    of dot(dns, dns), with the same matrix except that it is in default storage type.
-    """
-    def measure_cost_forward_baseline(repeat, dot, lhs, rhs):
-        start = time.time()
-        for i in range(repeat):
-            dot(lhs, rhs)
-        end = time.time()
-        diff = end - start
-        return diff / repeat
-
-    def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs):
-        start = time.time()
-        for i in range(repeat):
-            dot(transpose(lhs), rhs)
-        end = time.time()
-        diff = end - start
-        return diff / repeat
-
-    def bench_dot_forward(m, k, n, density, ctx, repeat):
-        set_default_context(ctx)
-        dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx)
-        data_shape = (m, k)
-        csr_data = rand_ndarray(data_shape, 'csr', density)
-        dns_data = csr_data.todense()
-        rhs_dns_np = dns.asnumpy()
-        lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy())  # csr in scipy
-        lhs_dns_np = lhs_csr_sp.todense()
-
-        data = [dns_data, csr_data]
-        costs = []
-        for d in data:
-            dns.wait_to_read()
-            d.wait_to_read()
-            cost = measure_cost(repeat, mx.nd.dot, d, dns)
-            costs.append(cost)
-        ratio = costs[0] / costs[1]
-
-        costs_baseline = []
-        cost = measure_cost_forward_baseline(repeat, np.dot, lhs_dns_np, rhs_dns_np)
-        costs_baseline.append(cost)
-        cost = measure_cost_forward_baseline(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np)
-        costs_baseline.append(cost)
-        ratio_baseline = costs_baseline[0] / costs_baseline[1]
-        fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f"
-        print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1],
-                     ratio_baseline, costs_baseline[0], costs_baseline[1]))
-
-    def bench_dot_backward(m, k, n, density, ctx, repeat):
-        set_default_context(ctx)
-        dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx)
-        data_shape = (m, k)
-        csr_data = rand_ndarray(data_shape, 'csr', density)
-        dns_data = csr_data.todense()
-        rhs_dns_np = dns.asnumpy()
-        lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy())
-        lhs_dns_np = lhs_csr_sp.todense()
-
-        data = [dns_data, csr_data]
-        costs = []
-        for d in data:
-            dns.wait_to_read()
-            d.wait_to_read()
-            cost = measure_cost(repeat, mx.nd.dot, d, dns, transpose_a=True)
-            costs.append(cost)
-        ratio = costs[0] / costs[1]
-
-        costs_baseline = []
-        cost = measure_cost_backward_baseline(repeat, np.dot, np.transpose, lhs_dns_np, rhs_dns_np)
-        costs_baseline.append(cost)
-        cost = measure_cost_backward_baseline(repeat, sp.spmatrix.dot, sp.spmatrix.transpose, lhs_csr_sp, rhs_dns_np)
-        costs_baseline.append(cost)
-        ratio_baseline = costs_baseline[0] / costs_baseline[1]
-        fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f"
-        print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1],
-                     ratio_baseline, costs_baseline[0], costs_baseline[1]))
-
-    print("A = sparse NDArray of shape(m, k)")
-    print("B = dense NDArray of shape(k, n)")
-    print("dot_forward\tdot(csr, dns)")
-    print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse'
-          '\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse')
-
-    check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
-    # TODO(haibin) make these runtime options
-    m = 512
-    k = [50000, 100000]
-    n = [64, 128]
-    density = [1.00, 0.90, 0.70, 0.50, 0.30, 0.20, 0.10, 0.07, 0.05, 0.02, 0.01, 0.005, 0.001]
-    num_repeat = 10
-    # contexts = [mx.cpu(), mx.gpu(0)]
-    contexts = [mx.cpu()]
-    for i in range(2):
-        for ctx in contexts:
-            for den in density:
-                bench_dot_forward(m, k[i], n[i], den, ctx, num_repeat)
-
-    print("dot_backward\tdot(csr.T, dns)")
-    print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse'
-          '\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse')
-    for i in range(2):
-        for ctx in contexts:
-            for den in density:
-                bench_dot_backward(m, k[i], n[i], den, ctx, num_repeat)
-
-
-if __name__ == "__main__":
-    test_dot_real(avazu)
-    test_dot_real(kdda)
-    test_dot_synthetic()
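
For dot(csr.T, dns) = rsp, the dot-inl.cuh changes below determine the output's non-zero rows by
flagging the occupied columns of the csr lhs, running an inclusive prefix sum over the flags with
cub::DeviceScan::InclusiveSum, and compacting the surviving row ids. A rough NumPy sketch of that
scheme, for illustration only (the function and variable names here are hypothetical, not the
committed CUDA code):

    import numpy as np

    def rsp_rows_for_csr_trans_dot(col_idx_l, num_cols_l):
        # col_idx_l: concatenated csr column indices of the lhs matrix
        # 1) flag every lhs column holding at least one non-zero
        #    (MarkCsrZeroColsWarpKernel, one warp per csr row)
        row_flg = np.zeros(num_cols_l, dtype=np.int64)
        row_flg[col_idx_l] = 1
        # 2) inclusive prefix sum over the 0/1 flags; the last entry is the
        #    number of non-zero output rows (cub::DeviceScan::InclusiveSum)
        row_flg_sum = np.cumsum(row_flg)
        nnr_out = int(row_flg_sum[-1])
        # 3) compact the flagged row ids into the output's row_idx array
        #    (FillRspRowIdxKernel, one thread per row)
        row_idx_out = np.flatnonzero(row_flg)
        assert len(row_idx_out) == nnr_out
        return nnr_out, row_flg_sum, row_idx_out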
diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh
index 8960798..562bf10 100644
--- a/src/operator/tensor/dot-inl.cuh
+++ b/src/operator/tensor/dot-inl.cuh
@@ -9,66 +9,163 @@
 #include <mxnet/base.h>
 #include <mxnet/operator.h>
 
+#include <cub/cub.cuh>
+
 namespace mxnet {
 namespace op {
-using mshadow::cuda::kBaseThreadNum;
 
 /*!
- * \brief Scalar kernel of dot(csr, dns1) = dns2
+ * \brief GPU auxiliary kernel to flag non-zero rows of an rsp matrix with indices.
+ * Parallelized by matrix rows: 1 thread/row
+ */
+struct SetRspRowFlgKernel {
+  /*!
+   * \brief
+   * \param tid      global thread id
+   * \param row_flg  array to flag storage indices of non-zero rows
+   * \param row_idx  rsp matrix row index array storing indices of non-zero rows
+   * \param nnr      rsp matrix number of non-zero rows (storage shape)
+   */
+  template<typename RType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             RType* row_flg,
+                                             const RType* row_idx,
+                                             const nnvm::dim_t nnr) {
+    if (tid < nnr) {
+      row_flg[row_idx[tid]] = tid+1;
+    }
+  }
+};
+
+/*!
+ * \brief GPU auxiliary kernel for marking non-zero columns of a csr matrix.
+ * Parallelized by matrix rows: 1 warp/row
+ */
+struct MarkCsrZeroColsWarpKernel {
+  /*!
+   * \brief
+   * \param tid       global thread id
+   * \param col_idx   csr matrix column indices
+   * \param indptr    csr matrix row index pointer
+   * \param num_rows  csr matrix number of rows
+   * \param num_cols  csr matrix number of columns
+   */
+  template<typename CType, typename IType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             nnvm::dim_t* flg,
+                                             const CType* col_idx,
+                                             const IType* indptr,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t num_cols) {
+    typedef unsigned long long int uint64_cu;
+    static_assert(sizeof(uint64_cu) == sizeof(nnvm::dim_t), "unexpected sizeof dim_t");
+
+    const nnvm::dim_t warp_id = tid / 32;      // global warp   id
+    const nnvm::dim_t lane    = tid & (32-1);  // local  thread id within warp
+
+    if (warp_id < num_rows) {
+      uint64_cu zero = 0;
+      uint64_cu one = 1;
+      for (IType j = indptr[warp_id]+lane; j < indptr[warp_id+1]; j+=32) {
+        atomicCAS(reinterpret_cast<uint64_cu*>(flg+col_idx[j]), zero, one);
+      }
+    }
+  }
+};
+
+/*!
+ * \brief GPU auxiliary kernel for filling the row index array of an rsp matrix.
+ * Parallelized by matrix rows: 1 thread/row
+ */
+struct FillRspRowIdxKernel {
+  /*!
+   * \brief
+   * \param tid          global thread id
+   * \param row_idx      row index array to store indices of non-zero rows
+   * \param row_flg_sum  inclusive prefix sum array over 0/1 marked row flag array
+   * \param num_rows     rsp matrix number of rows (shape)
+   */
+  template<typename RType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             RType* row_idx,
+                                             const nnvm::dim_t* row_flg_sum,
+                                             const nnvm::dim_t num_rows) {
+    if (tid < num_rows) {
+      nnvm::dim_t prev = (tid == 0)? 0 : row_flg_sum[tid-1];
+      if (row_flg_sum[tid] > prev) {
+        row_idx[prev] = static_cast<RType>(tid);
+      }
+    }
+  }
+};
+
+/*!
+ * \brief GPU scalar kernel of dot(csr, dns1) = dns2
  * Parallelization by output matrix elements: 1 thread/element
  */
 template<int req>
 struct DotCsrDnsDnsScalarKernel {
   /*!
   * \brief This function represents performing an inner product between a row of lhs
-   * and a column of rhs and then assigning the value to out[i].
-   * \param i i-th element in out 1D view
-   * \param out output matrix
-   * \param data_l csr values of lhs
-   * \param indptr_l csr indptr of lhs
-   * \param col_idx_l csr col_idx of lhs
-   * \param data_r dense data of rhs
-   * \param num_cols number of columns of output
+   * and a column of rhs and then assigning the value to out[tid].
+   * \param tid         global thread id
+   * \param out         output matrix data
+   * \param data_l      csr matrix data
+   * \param indptr_l    csr matrix row index pointer
+   * \param col_idx_l   csr matrix column indices
+   * \param data_r      dns1 matrix data of rhs
+   * \param num_cols_r  dns1 matrix number of columns
    */
   template<typename DType, typename IType, typename CType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l,
-                                  const CType* col_idx_l, const DType* data_r,
-                                  const int num_cols) {
-    const int irow = i / num_cols;  // row id of the lhs
-    const int icol = i % num_cols;  // col id of the rhs
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const nnvm::dim_t num_cols_r) {
+    const nnvm::dim_t irow = tid / num_cols_r;  // row id of the lhs
+    const nnvm::dim_t icol = tid % num_cols_r;  // col id of the rhs
     DType sum = 0;
     for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) {
       const CType cur_col = col_idx_l[j];  // corresponding row id of the rhs
-      sum += data_l[j] * data_r[cur_col*num_cols+icol];
+      sum += data_l[j] * data_r[cur_col*num_cols_r+icol];
     }
-    KERNEL_ASSIGN(out[i], req, sum);
+    KERNEL_ASSIGN(out[tid], req, sum);
   }
 };
 
 /*!
- * \brief Vector kernel of dot(csr, dns1) = dns2
+ * \brief GPU vector kernel of dot(csr, dns1) = dns2
  * Parallelization by output matrix elements: 1 warp/element
  */
 template<int req>
 struct DotCsrDnsDnsVectorKernel {
+  /*!
+   * \brief see DotCsrDnsDnsScalarKernel Map for documentation.
+   */
   template<typename DType, typename IType, typename CType>
-  __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
-                                             const CType* col_idx_l, const DType* data_r,
-                                             const int num_cols_r) {
-    __shared__ volatile DType vals[kBaseThreadNum];
-
-    const int warp_id = tid / 32;           // global warp id
-    const int lane = tid & (32-1);          // local thread id within warp
-    const int irow = warp_id / num_cols_r;  // lhs row that this warp computes
-    const int kcol = warp_id % num_cols_r;  // rhs column that this warp computes
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    __shared__ volatile DType vals[mshadow::cuda::kBaseThreadNum];
+    const dim_t warp_id = tid / 32;           // global warp id
+    const dim_t lane = tid & (32-1);          // local thread id within warp
+    const dim_t irow = warp_id / num_cols_r;  // lhs row that this warp computes
+    const dim_t kcol = warp_id % num_cols_r;  // rhs column that this warp computes
 
     // Range of nnz elements in this row
-    const int low  = static_cast<int>(indptr_l[irow]);
-    const int high = static_cast<int>(indptr_l[irow+1]);
+    const dim_t low  = static_cast<dim_t>(indptr_l[irow]);
+    const dim_t high = static_cast<dim_t>(indptr_l[irow+1]);
 
     // Compute running sum per thread
     DType sum = 0;
-    for (int j = low+lane; j < high; j+=32) {
+    for (dim_t j = low+lane; j < high; j+=32) {
       sum += data_l[j] * data_r[col_idx_l[j]*num_cols_r + kcol];
     }
     vals[threadIdx.x] = sum; __syncwarp();
@@ -87,39 +184,45 @@ struct DotCsrDnsDnsVectorKernel {
 };
 
 /*!
- * \brief Scalar kernel of dot(csr.T(), dns1) = dns2
+ * \brief GPU scalar kernel of dot(csr.T, dns1) = dns2
  * Parallelization by output matrix elements: 1 thread/element
  */
 template<int req>
 struct DotCsrTransDnsDnsScalarKernel {
   /*!
   * \brief This function represents performing an inner product between a column of lhs
-   * and a column of rhs and then assigning the value to out[i].
-   * \param i i-th element in out 1D view
-   * \param out output matrix
-   * \param data_l csr values of lhs
-   * \param indptr_l csr indptr of lhs
-   * \param col_idx_l csr col_idx of lhs
-   * \param data_r dense data of rhs
-   * \param num_rows_l number of rows of lhs
-   * \param num_cols number of columns of outputs
+   * and a column of rhs and then assigning the value to out[tid].
+   * \param tid         global thread id
+   * \param out         output matrix
+   * \param data_l      csr matrix data
+   * \param indptr_l    csr matrix row index pointer
+   * \param col_idx_l   csr matrix column indices
+   * \param data_r      dns1 matrix data of rhs
+   * \param num_rows_l  csr matrix number of rows (= number of columns of csr.T)
+   * \param num_cols_r  dns1 matrix number of columns
    */
   template<typename DType, typename IType, typename CType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l,
-                                  const CType* col_idx_l, const DType* data_r, const int num_rows_l,
-                                  const int num_cols) {
-    const int irow = i / num_cols;  // col id of the lhs
-    const int icol = i % num_cols;  // col id of the rhs
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const nnvm::dim_t num_rows_l,
+                                             const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t irow = tid / num_cols_r;  // col id of the lhs
+    const dim_t icol = tid % num_cols_r;  // col id of the rhs
     DType sum = 0;
 
     // Each thread scans each column with binary search to find nnz elements in its row
-    for (int k = 0; k < num_rows_l; ++k) {
-      const IType low = indptr_l[k];
-      const IType high = indptr_l[k+1];
+    for (dim_t k = 0; k < num_rows_l; ++k) {
+      const dim_t low = static_cast<dim_t>(indptr_l[k]);
+      const dim_t high = static_cast<dim_t>(indptr_l[k+1]);
       if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue;
-      int j = -1, l = low, r = high - 1;
+      dim_t j = high, l = low, r = high - 1;
       while (l <= r) {
-        int m = l + (r - l) / 2;
+        dim_t m = l + (r - l) / 2;
         if (col_idx_l[m] == irow) {
           j = m; break;
         }
@@ -129,36 +232,43 @@ struct DotCsrTransDnsDnsScalarKernel {
           r = m - 1;
         }
       }
-      if (j >= 0) {
-        sum += data_l[j] * data_r[k*num_cols+icol];
+      if (j < high) {
+        sum += data_l[j] * data_r[k*num_cols_r+icol];
       }
     }
-    KERNEL_ASSIGN(out[i], req, sum);
+    KERNEL_ASSIGN(out[tid], req, sum);
   }
 };
 
 /*!
- * \brief Warp kernel of dot(csr.T(), dns1) = dns2
+ * \brief GPU warp kernel of dot(csr.T, dns1) = dns2
  * Parallelization by columns: 1 warp computes one lhs column for one rhs column
  */
-template<int req>
 struct DotCsrTransDnsDnsWarpKernel {
+  /*!
+   * \brief see DotCsrTransDnsDnsScalarKernel Map for documentation.
+   */
   template<typename DType, typename IType, typename CType>
-  __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
-                                             const CType* col_idx_l, const DType* data_r,
-                                             const int num_cols_r) {
-    const int warp_id = tid / 32;           // global warp id
-    const int lane = tid & (32-1);          // local thread id within warp
-    const int icol = warp_id / num_cols_r;  // lhs column that this warp computes
-    const int kcol = warp_id % num_cols_r;  // rhs column that this warp computes
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t warp_id = tid / 32;           // global warp id
+    const dim_t lane = tid & (32-1);          // local thread id within warp
+    const dim_t icol = warp_id / num_cols_r;  // lhs column that this warp computes
+    const dim_t kcol = warp_id % num_cols_r;  // rhs column that this warp computes
 
     // Compute range of nnz elements in this column
-    const int low  = static_cast<int>(indptr_l[icol]);
-    const int high = static_cast<int>(indptr_l[icol+1]);
+    const dim_t low  = static_cast<dim_t>(indptr_l[icol]);
+    const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);
 
     // Iterate through the nnz elements in this column
-    for (int j = low+lane; j < high; j+=32) {
-      const int irow = static_cast<int>(col_idx_l[j]);
+    for (dim_t j = low+lane; j < high; j+=32) {
+      const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
       const DType val = data_l[j]*data_r[icol*num_cols_r+kcol];
       atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+kcol])), val);
     }
@@ -166,31 +276,38 @@ struct DotCsrTransDnsDnsWarpKernel {
 };
 
 /*!
- * \brief Thread block kernel of dot(csr.T(), dns1) = dns2
+ * \brief GPU thread block kernel of dot(csr.T, dns1) = dns2
  * Parallelization by columns: 1 thread block computes one lhs column for all rhs columns
  */
-template<int req>
 struct DotCsrTransDnsDnsThreadBlockKernel {
+  /*!
+   * \brief see DotCsrTransDnsDnsScalarKernel Map for documentation.
+   */
   template<typename DType, typename IType, typename CType>
-  __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
-                                             const CType* col_idx_l, const DType* data_r,
-                                             const int num_cols_r) {
-    const int warps_per_block = blockDim.x / 32;  // number of warps in this thread block
-    const int warp_id = tid / 32;                 // global warp id
-    const int lane = tid & (32-1);                // local thread id within warp
-    const int icol = blockIdx.x;                  // lhs column that this thread block computes
-    const int kcol = warp_id % warps_per_block;   // rhs column where warp starts computing (offset)
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t warps_per_block = blockDim.x / 32;  // number of warps in this thread block
+    const dim_t warp_id = tid / 32;                 // global warp id
+    const dim_t lane = tid & (32-1);                // local thread id within warp
+    const dim_t icol = blockIdx.x;                  // lhs column that this thread block computes
+    const dim_t kcol = warp_id % warps_per_block;   // rhs column where warp starts computing (offset)
 
     // Compute range of nnz elements in this lhs column
-    const int low  = static_cast<int>(indptr_l[icol]);
-    const int high = static_cast<int>(indptr_l[icol+1]);
+    const dim_t low  = static_cast<dim_t>(indptr_l[icol]);
+    const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);
 
     // Iterate through the nnz elements in this lhs column
-    for (int j = low+lane; j < high; j+=32) {
-      const int irow = static_cast<int>(col_idx_l[j]);
+    for (dim_t j = low+lane; j < high; j+=32) {
+      const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
       const DType datum_l = data_l[j];
       // Iterate over rhs columns that this warp computes
-      for (int k = kcol; k < num_cols_r; k+=warps_per_block) {
+      for (dim_t k = kcol; k < num_cols_r; k+=warps_per_block) {
         const DType val = datum_l*data_r[icol*num_cols_r+k];
         atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+k])), val);
       }
@@ -199,29 +316,36 @@ struct DotCsrTransDnsDnsThreadBlockKernel {
 };
 
 /*!
- * \brief Warp block kernel of dot(csr.T(), dns1) = dns2
+ * \brief GPU warp block kernel of dot(csr.T, dns1) = dns2
  * Parallelization by columns: 1 warp computes one lhs column for all rhs columns
  */
-template<int req>
 struct DotCsrTransDnsDnsWarpBlockKernel {
+  /*!
+   * \brief see DotCsrTransDnsDnsScalarKernel Map for documentation.
+   */
   template<typename DType, typename IType, typename CType>
-  __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
-                                             const CType* col_idx_l, const DType* data_r,
-                                             const int num_cols_r) {
-    const int warp_id = tid / 32;   // global warp id
-    const int lane = tid & (32-1);  // local thread id within warp
-    const int icol = warp_id;       // lhs column that this warp computes
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t warp_id = tid / 32;   // global warp id
+    const dim_t lane = tid & (32-1);  // local thread id within warp
+    const dim_t icol = warp_id;       // lhs column that this warp computes
 
     // Compute range of nnz elements in this column
-    const int low  = static_cast<int>(indptr_l[icol]);
-    const int high = static_cast<int>(indptr_l[icol+1]);
+    const dim_t low  = static_cast<dim_t>(indptr_l[icol]);
+    const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);
 
     // Iterate through the nnz elements in lhs column
-    for (int j = low+lane; j < high; j+=32) {
-      const int irow = static_cast<int>(col_idx_l[j]);
+    for (dim_t j = low+lane; j < high; j+=32) {
+      const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
       const DType datum_l = data_l[j];
       // Iterate over all rhs columns
-      for (int k = 0; k < num_cols_r; k++) {
+      for (dim_t k = 0; k < num_cols_r; k++) {
         const DType val = datum_l*data_r[icol*num_cols_r+k];
         atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+k])), val);
       }
@@ -229,7 +353,166 @@ struct DotCsrTransDnsDnsWarpBlockKernel {
   }
 };
 
-inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
+/*!
+ * \brief GPU warp kernel of dot(csr.T, dns) = rsp
+ * Parallelization by columns: 1 warp computes one lhs column for one rhs column
+ */
+struct DotCsrTransDnsRspWarpKernel {
+  /*!
+   * \brief
+   * \param tid              global thread id
+   * \param out              output rsp matrix data
+   * \param row_flg_sum_out  inclusive prefix sum array over 0/1 marked row flag array
+   * \param data_l           csr matrix data
+   * \param indptr_l         csr matrix row index pointer
+   * \param col_idx_l        csr matrix column indices
+   * \param data_r           dns matrix data
+   * \param num_cols_r       dns matrix number of columns
+   */
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const nnvm::dim_t* row_flg_sum_out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t warp_id = tid / 32;           // global warp id
+    const dim_t lane = tid & (32-1);          // local thread id within warp
+    const dim_t icol = warp_id / num_cols_r;  // lhs column that this warp computes
+    const dim_t kcol = warp_id % num_cols_r;  // rhs column that this warp computes
+
+    // Compute range of nnz elements in this column
+    const dim_t low  = static_cast<dim_t>(indptr_l[icol]);
+    const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);
+
+    // Iterate through the nnz elements in this column
+    for (dim_t j = low+lane; j < high; j+=32) {
+      const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
+      const dim_t rsp_row = row_flg_sum_out[irow]-1;
+      const DType val = data_l[j]*data_r[icol*num_cols_r+kcol];
+      atomicAdd(static_cast<DType *>(&(out[rsp_row*num_cols_r+kcol])), val);
+    }
+  }
+};
+
+/*!
+ * \brief GPU Kernel of dot(csr.T, rsp1) = rsp2
+ * Parallelization by rows: 1 thread/row
+ * TODO: write a faster kernel optimized for GPU
+ */
+struct DotCsrTransRspRspByRowsKernel {
+  /*!
+   * \brief
+   * \param tid           global thread id
+   * \param out           output rsp matrix data
+   * \param row_idx_out   output rsp matrix non-zero row indices
+   * \param data_l        csr matrix data
+   * \param indptr_l      csr matrix row index pointer
+   * \param col_idx_l     csr matrix column indices
+   * \param data_r        rsp1 matrix data
+   * \param row_idx_r     rsp1 matrix non-zero row indices
+   * \param num_cols_r    rsp1 matrix number of cols
+   * \param nnr_r         rsp1 matrix number of non-zero rows
+   * \param nnr_out       output rsp matrix number of non-zero rows
+   */
+  template<typename DType, typename IType, typename CType, typename RType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const RType* row_idx_out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const RType* row_idx_r,
+                                             const nnvm::dim_t num_cols_r,
+                                             const nnvm::dim_t nnr_r,
+                                             const nnvm::dim_t nnr_out) {
+    using nnvm::dim_t;
+    // This thread computes non-zero row 'tid' of the output matrix
+    // The actual row id corresponding to the lhs row is row_idx_out[tid]
+    if (tid < nnr_out) {
+      const dim_t offset_out = tid * num_cols_r;
+      // Iterate over rhs matrix rows (or, equivalently, the lhs columns worth considering)
+      for (dim_t i = 0; i < nnr_r; i++) {
+        const RType j = row_idx_r[i];  // j is the actual rhs row id (= lhs column id)
+        if (indptr_l[j] == indptr_l[j+1]) continue;
+        const dim_t offset_r = i * num_cols_r;
+        // Iterate over lhs column j to find possible non-zero value in this row
+        // TODO: remove sequential search, this is a bottleneck
+        for (IType k = indptr_l[j]; k < indptr_l[j+1]; k++) {
+          const CType col_idx = col_idx_l[k];
+          if (col_idx == row_idx_out[tid]) {
+            for (dim_t l = 0; l < num_cols_r; l++) {
+              out[offset_out+l] += data_l[k] * data_r[offset_r+l];
+            }
+          } else if (col_idx > row_idx_out[tid]) {
+            break;
+          }
+        }
+      }
+    }
+  }
+};
+
+/*!
+ * \brief GPU Kernel of dot(csr, rsp) = dns
+ * Parallelization by output elements: 1 thread/element
+ */
+struct DotCsrRspDnsScalarKernel {
+  /*!
+   * \brief
+   * \param tid        global thread id
+   * \param out        output dns matrix data
+   * \param data_l     csr matrix data
+   * \param indptr_l   csr matrix row index pointer
+   * \param col_idx_l  csr matrix column indices
+   * \param data_r     rsp matrix data
+   * \param row_idx_r  rsp matrix non-zero row indices
+   * \param row_flg_r  rsp matrix auxiliary array holding storage indices of non-zero rows
+   * \param nnr_r      rsp matrix number of non-zero rows
+   * \param num_rows   output dns matrix number of rows
+   * \param num_cols   output dns matrix number of columns
+   */
+  template<typename DType, typename IType, typename CType, typename RType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const RType* row_idx_r,
+                                             const RType* row_flg_r,
+                                             const nnvm::dim_t nnr_r,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t num_cols) {
+    using nnvm::dim_t;
+    if (tid < num_rows*num_cols) {
+      const dim_t i = static_cast<dim_t>(tid) / num_cols;  // i = row this thread computes
+      const dim_t k = static_cast<dim_t>(tid) % num_cols;  // k = col this thread computes
+      // Compute inner product of i-th row and k-th col
+      DType sum = 0;
+      for (IType j = indptr_l[i]; j < indptr_l[i+1]; j++) {
+        const dim_t csr_col = col_idx_l[j];
+        const dim_t rsp_row_idx = row_flg_r[csr_col];
+        if (rsp_row_idx > 0) {
+          sum += data_l[j] * data_r[(rsp_row_idx-1)*num_cols+k];
+        }
+      }
+      if (sum != 0) {
+        out[i*num_cols+k] += sum;
+      }
+    }
+  }
+};
+
+/*!
+ * \brief GPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
+ */
+inline void DotCsrDnsDnsImpl(const OpContext& ctx,
+                             const gpu& gpu_dev,
                              const NDArray& lhs,
                              const TBlob& rhs,
                              const OpReqType req,
@@ -239,6 +522,22 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   if (!lhs.storage_initialized()) return;
 
+  using mshadow::cuda::kBaseThreadNum;
+  using mxnet_op::Kernel;
+  using mxnet_op::set_zero;
+  using nnvm::dim_t;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+
+  const dim_t num_rows_l = lhs.shape()[0];
+  const dim_t num_cols_r = rhs.shape_[1];
+  const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
+  const dim_t threads_per_block = kBaseThreadNum;
+  dim_t num_threads;
+  // TODO: remove kernel dependency on warpSize=32
+  if (threads_per_warp != 32) {
+    LOG(FATAL) << "DotCsrDnsDnsImpl GPU kernels expect warpSize=32";
+  }
+
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
@@ -249,13 +548,9 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
         if (kWriteTo == req) {
-          mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, data_out.Size(), data_out.dptr<DType>());
+          num_threads = data_out.Size();
+          Kernel<set_zero, gpu>::Launch(s, num_threads, data_out.dptr<DType>());
         }
-        int num_threads;
-        const int threads_per_warp = 32;
-        const int threads_per_block = kBaseThreadNum;
-        const int num_rows_l = lhs.shape()[0];
-        const int num_cols_r = rhs.shape_[1];
         if (trans_lhs) {
          // Different kernel versions are optimized for different matrix instances
           // TODO: switch between kernel versions depending on input
@@ -268,42 +563,34 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
             case 1:
               num_threads = data_out.Size();
               MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-                mxnet_op::Kernel<DotCsrTransDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
+                Kernel<DotCsrTransDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_rows_l, num_cols_r);
               });
               break;
             case 2:
               num_threads = threads_per_warp * num_rows_l * num_cols_r;
-              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-                mxnet_op::Kernel<DotCsrTransDnsDnsWarpKernel<ReqType>, gpu>::Launch(s, num_threads,
-                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
-              });
+              Kernel<DotCsrTransDnsDnsWarpKernel, gpu>::Launch(s, num_threads,
+                  data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                  col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
               break;
             case 3:
               num_threads = threads_per_block * num_rows_l;
-              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-                mxnet_op::Kernel<DotCsrTransDnsDnsThreadBlockKernel<ReqType>, gpu>::Launch(s, num_threads,
-                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
-              });
+              Kernel<DotCsrTransDnsDnsThreadBlockKernel, gpu>::Launch(s, num_threads,
+                  data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                  col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
               break;
             case 4:
               num_threads = threads_per_warp * num_rows_l;
-              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-                mxnet_op::Kernel<DotCsrTransDnsDnsWarpBlockKernel<ReqType>, gpu>::Launch(s, num_threads,
-                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
-              });
+              Kernel<DotCsrTransDnsDnsWarpBlockKernel, gpu>::Launch(s, num_threads,
+                  data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                  col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
               break;
             default:
               num_threads = threads_per_warp * num_rows_l * num_cols_r;
-              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-                mxnet_op::Kernel<DotCsrTransDnsDnsWarpKernel<ReqType>, gpu>::Launch(s, num_threads,
-                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
-              });
+              Kernel<DotCsrTransDnsDnsWarpKernel, gpu>::Launch(s, num_threads,
+                  data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                  col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
               break;
           }
         } else {
@@ -315,7 +602,7 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
             case 1:
               num_threads = data_out.Size();
               MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-                mxnet_op::Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
+                Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
                     col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
               });
@@ -323,7 +610,7 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
             case 2:
               num_threads = threads_per_warp * num_rows_l * num_cols_r;
               MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-                mxnet_op::Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
+                Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
                     col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
               });
@@ -332,14 +619,14 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
               if (num_cols_r > 4) {
                 num_threads = data_out.Size();
                 MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-                  mxnet_op::Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
+                  Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
                      data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
                      col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
                 });
               } else {
                 num_threads = threads_per_warp * num_rows_l * num_cols_r;
                 MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-                  mxnet_op::Kernel<DotCsrDnsDnsVectorKernel<ReqType>, 
gpu>::Launch(s, num_threads,
+                  Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, 
num_threads,
                       data_out.dptr<DType>(), data_l.dptr<DType>(), 
indptr_l.dptr<IType>(),
                       col_idx_l.dptr<CType>(), data_r.dptr<DType>(), 
num_cols_r);
                 });
@@ -353,27 +640,308 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
 }
 
 /*!
- * \brief Impl of dot(csr.T, dns) = rsp
+ * \brief GPU Impl of dot(csr, dns) = rsp and dot(csr.T, dns) = rsp
  */
-inline void DotCsrDnsRspImpl(mshadow::Stream<gpu>* s,
+inline void DotCsrDnsRspImpl(const OpContext& ctx,
+                             const gpu& gpu_dev,
                              const NDArray& lhs,
                              const TBlob& rhs,
                              const OpReqType req,
                              const bool trans_lhs,
                              NDArray* ret) {
-  LOG(FATAL) << "DotCsrDnsRspImpl gpu version is not implemented.";
+  if (kNullOp == req) return;
+  CHECK_EQ(lhs.storage_type(), kCSRStorage);
+  CHECK_EQ(ret->storage_type(), kRowSparseStorage);
+  CHECK_EQ(req, kWriteTo);
+  if (!lhs.storage_initialized()) return;
+
+  using mshadow::Shape1;
+  using mxnet_op::Kernel;
+  using mxnet_op::set_zero;
+  using nnvm::dim_t;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+
+  const TBlob data_l = lhs.data();
+  const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
+  const TBlob& data_r = rhs;
+
+  const dim_t num_rows_l = lhs.shape()[0];
+  const dim_t num_cols_l = lhs.shape()[1];
+  const dim_t num_cols_r = rhs.shape_[1];
+  const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
+  dim_t num_threads;
+  // TODO: remove kernel dependency on warpSize=32
+  if (threads_per_warp != 32) {
+    LOG(FATAL) << "DotCsrDnsRspImpl GPU kernels expect warpSize=32";
+  }
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
+        if (trans_lhs) {
+          // Compute number of non-zero rows (nnr) of output matrix
+          // - alloc temp storage for row_flg array and for cub's prefix sum
+          // - mark non-zero columns of csr matrix in row_flg
+          // - compute inclusive prefix sum over marked array
+          // - copy last value (nnr_out) from device to host
+          dim_t* row_flg_out = NULL;
+          void* d_temp_storage = NULL;
+          size_t temp_storage_bytes = 0;
+          cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                        temp_storage_bytes,
+                                        row_flg_out,
+                                        row_flg_out,
+                                        num_cols_l,
+                                        mshadow::Stream<gpu>::GetStream(s));
+          mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
+              .get_space_typed<gpu, 1, char>(Shape1(num_cols_l * sizeof(dim_t) +
+                                                    temp_storage_bytes), s);
+          row_flg_out = reinterpret_cast<dim_t*>(workspace.dptr_);
+          d_temp_storage = workspace.dptr_ + num_cols_l*sizeof(dim_t);
+          num_threads = num_cols_l;
+          Kernel<set_zero, gpu>::Launch(s, num_threads, row_flg_out);
+          num_threads = num_rows_l * threads_per_warp;
+          Kernel<MarkCsrZeroColsWarpKernel, gpu>::Launch(s, num_threads,
+              row_flg_out, col_idx_l.dptr<CType>(), indptr_l.dptr<IType>(),
+              num_rows_l, num_cols_l);
+          cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                        temp_storage_bytes,
+                                        row_flg_out,
+                                        row_flg_out,
+                                        num_cols_l,
+                                        mshadow::Stream<gpu>::GetStream(s));
+          dim_t nnr_out = 0;
+          CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], 
sizeof(dim_t),
+                               cudaMemcpyDeviceToHost));
+
+          // Allocate output matrix space
+          ret->CheckAndAlloc({Shape1(nnr_out)});
+          const TBlob data_out_blob = ret->data();
+          const TBlob row_idx_out_blob = ret->aux_data(rowsparse::kIdx);
+          MSHADOW_IDX_TYPE_SWITCH(row_idx_out_blob.type_flag_, RType, {  // 
row idx type
+            DType* data_out = data_out_blob.dptr<DType>();
+            RType* row_idx_out = row_idx_out_blob.dptr<RType>();
+            num_threads = nnr_out * num_cols_r;
+            Kernel<set_zero, gpu>::Launch(s, num_threads, data_out);
+            num_threads = nnr_out;
+            Kernel<set_zero, gpu>::Launch(s, num_threads, row_idx_out);
+
+            // Fill row_idx array of output matrix, using the row_flg values
+            num_threads = num_cols_l;
+            Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_threads,
+                row_idx_out, row_flg_out, num_cols_l);
+
+            // Perform matrix-matrix multiply
+            num_threads = threads_per_warp * num_rows_l * num_cols_r;
+            Kernel<DotCsrTransDnsRspWarpKernel, gpu>::Launch(s, num_threads,
+                data_out, row_flg_out,
+                data_l.dptr<DType>(), indptr_l.dptr<IType>(), 
col_idx_l.dptr<CType>(),
+                data_r.dptr<DType>(), num_cols_r);
+          });
+        } else {
+          LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns) = 
rsp yet.";
+        }
+      });
+    });
+  });
 }
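The trans_lhs branch above counts the non-zero rows of the rsp output by flagging every csr column that holds data and running an inclusive prefix sum over the flags (cub::DeviceScan); the last scan value is nnr_out, and the scanned array doubles as a column-to-storage-row map, which appears to be the role row_flg_out plays in FillRspRowIdxKernel and DotCsrTransDnsRspWarpKernel. A minimal NumPy sketch of that idea, with illustrative toy data (not the kernel code itself):

    import numpy as np

    # Toy csr structure (illustrative data only): 4 rows, 6 columns, 5 non-zeros.
    col_idx_l = np.array([1, 4, 1, 0, 4])      # columns touched by the csr entries
    num_cols_l = 6

    # Marking step: flag every column that holds at least one entry.
    row_flg = np.zeros(num_cols_l, dtype=np.int64)
    row_flg[col_idx_l] = 1                     # -> [1, 1, 0, 0, 1, 0]

    # cub::DeviceScan::InclusiveSum step.
    scan = np.cumsum(row_flg)                  # -> [1, 2, 2, 2, 3, 3]
    nnr_out = scan[-1]                         # 3 non-zero rows in the rsp output

    # Fill-row-idx step: where the scan increases, column j becomes output row scan[j]-1.
    row_idx_out = np.flatnonzero(np.diff(np.concatenate(([0], scan))))
    print(nnr_out, row_idx_out)                # 3 [0 1 4]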
 
 /*!
- * \brief Impl of dot(csr.T, rsp) = rsp2
+ * \brief GPU Impl of dot(csr, rsp1) = rsp2 and dot(csr.T, rsp1) = rsp2
+ * TODO: Optimize for GPU. This is a baseline implementation that provides
+ *       the operator functionality but is not yet tuned for the GPU.
  */
-inline void DotCsrRspRspImpl(mshadow::Stream<gpu>* s,
+inline void DotCsrRspRspImpl(const OpContext& ctx,
+                             const gpu& gpu_dev,
                              const NDArray& lhs,
                              const NDArray& rhs,
                              const OpReqType req,
                              const bool trans_lhs,
                              NDArray* ret) {
-  LOG(FATAL) << "DotCsrRspRspImpl gpu version is not implemented.";
+  if (kNullOp == req) return;
+  // Reuse dot(csr, dns) implementation if rhs rsp matrix is in fact dense
+  if (rhs.storage_shape()[0] == rhs.shape()[0]) {
+    DotCsrDnsRspImpl(ctx, gpu_dev, lhs, rhs.data(), req, trans_lhs, ret);
+    return;
+  }
+  CHECK_EQ(lhs.storage_type(), kCSRStorage);
+  CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
+  CHECK_EQ(ret->storage_type(), kRowSparseStorage);
+  if (!lhs.storage_initialized() || !rhs.storage_initialized()) return;
+  CHECK_EQ(req, kWriteTo);
+
+  using mshadow::Shape1;
+  using mxnet_op::Kernel;
+  using mxnet_op::set_zero;
+  using nnvm::dim_t;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+
+  const TBlob data_l = lhs.data();
+  const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
+  const TBlob data_r = rhs.data();
+  const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx);
+
+  const dim_t num_rows_l = lhs.shape()[0];
+  const dim_t num_cols_l = lhs.shape()[1];
+  const dim_t num_cols_r = rhs.shape()[1];
+  const dim_t nnr_r = rhs.storage_shape()[0];
+  const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
+  dim_t num_threads;
+  // TODO: remove kernel dependency on warpSize=32
+  if (threads_per_warp != 32) {
+    LOG(FATAL) << "DotCsrRspRspImpl GPU kernels expect warpSize=32";
+  }
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
+        MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, {  // row idx type
+          if (trans_lhs) {
+            // Compute number of non-zero rows (nnr) of output matrix
+            // - alloc temp storage for row_flg array and for cub's prefix sum
+            // - mark non-zero columns of csr matrix in row_flg
+            // - compute inclusive prefix sum over marked array
+            // - copy last value (nnr_out) from device to host
+            dim_t* row_flg_out = NULL;
+            void* d_temp_storage = NULL;
+            size_t temp_storage_bytes = 0;
+            cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                          temp_storage_bytes,
+                                          row_flg_out,
+                                          row_flg_out,
+                                          num_cols_l,
+                                          mshadow::Stream<gpu>::GetStream(s));
+            mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
+                .get_space_typed<gpu, 1, char>(Shape1(num_cols_l * 
sizeof(dim_t) +
+                                                      temp_storage_bytes), s);
+            row_flg_out = reinterpret_cast<dim_t*>(workspace.dptr_);
+            d_temp_storage = workspace.dptr_ + num_cols_l*sizeof(dim_t);
+            num_threads = num_cols_l;
+            Kernel<set_zero, gpu>::Launch(s, num_threads, row_flg_out);
+            num_threads = num_rows_l * threads_per_warp;
+            Kernel<MarkCsrZeroColsWarpKernel, gpu>::Launch(s, num_threads,
+                row_flg_out, col_idx_l.dptr<CType>(), indptr_l.dptr<IType>(),
+                num_rows_l, num_cols_l);
+            cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                          temp_storage_bytes,
+                                          row_flg_out,
+                                          row_flg_out,
+                                          num_cols_l,
+                                          mshadow::Stream<gpu>::GetStream(s));
+            dim_t nnr_out = 0;
+            CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], 
sizeof(dim_t),
+                                 cudaMemcpyDeviceToHost));
+
+            // Allocate output matrix space
+            ret->CheckAndAlloc({mshadow::Shape1(nnr_out)});
+            const TBlob data_out_blob = ret->data();
+            const TBlob row_idx_out_blob = ret->aux_data(rowsparse::kIdx);
+            DType* data_out = data_out_blob.dptr<DType>();
+            RType* row_idx_out = row_idx_out_blob.dptr<RType>();
+            num_threads = nnr_out * num_cols_r;
+            Kernel<set_zero, gpu>::Launch(s, num_threads, data_out);
+            num_threads = nnr_out;
+            Kernel<set_zero, gpu>::Launch(s, num_threads, row_idx_out);
+
+            // Fill row_idx array of output matrix, using the row_flg values
+            num_threads = num_cols_l;
+            Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_threads,
+                row_idx_out, row_flg_out, num_cols_l);
+
+            // Perform matrix-matrix multiply
+            num_threads = nnr_out;
+            Kernel<DotCsrTransRspRspByRowsKernel, gpu>::Launch(s, num_threads,
+                data_out, row_idx_out,
+                data_l.dptr<DType>(), indptr_l.dptr<IType>(), 
col_idx_l.dptr<CType>(),
+                data_r.dptr<DType>(), row_idx_r.dptr<RType>(),
+                num_cols_r, nnr_r, nnr_out);
+          } else {
+            LOG(FATAL) << "DotCsrRspRspImpl has not implemented dot(csr, rsp1) 
= rsp2 yet.";
+          }
+        });
+      });
+    });
+  });
+}
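As the TODO notes, DotCsrRspRspImpl is a baseline. For readers who want a reference for what its trans_lhs branch computes, a dense-math equivalent in scipy/NumPy (purely illustrative; the operator itself never densifies the rsp, and the names and shapes below are made up):

    import numpy as np
    import scipy.sparse as sp

    # Illustrative shapes; none of this mirrors the kernel code.
    lhs = sp.random(50, 20, density=0.05, format='csr')     # csr lhs
    rhs = np.zeros((50, 8))
    rhs_rows = np.array([3, 7, 41])                         # rows stored by the rsp rhs
    rhs[rhs_rows] = np.random.rand(len(rhs_rows), 8)

    out = lhs.T.dot(rhs)                                    # dot(csr.T, rsp) in dense math
    # Only columns of lhs holding at least one entry can produce a non-zero output row,
    # which is why the result is returned as row_sparse.
    print(out.shape, np.unique(lhs.indices))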
+
+/*!
+ * \brief GPU Impl of dot(csr, rsp) = dns and dot(csr.T, rsp) = dns
+ */
+inline void DotCsrRspDnsImpl(const OpContext& ctx,
+                             const gpu& gpu_dev,
+                             const NDArray& lhs,
+                             const NDArray& rhs,
+                             const OpReqType req,
+                             const bool trans_lhs,
+                             TBlob* ret) {
+  // Reuse dot(csr, dns) implementation if rhs rsp matrix is in fact dense
+  if (rhs.storage_shape()[0] == rhs.shape()[0]) {
+    DotCsrDnsDnsImpl(ctx, gpu_dev, lhs, rhs.data(), req, trans_lhs, ret);
+    return;
+  }
+  if (kNullOp == req) return;
+  CHECK_EQ(lhs.storage_type(), kCSRStorage);
+  CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
+
+  using mxnet_op::Kernel;
+  using mxnet_op::set_zero;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (!lhs.storage_initialized() || !rhs.storage_initialized()) {
+    if (kWriteTo == req) {
+      MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {  // data type
+        Kernel<set_zero, gpu>::Launch(s, ret->Size(), ret->dptr<DType>());
+      });
+    }
+    return;
+  }
+
+  using nnvm::dim_t;
+  const dim_t num_rows = ret->shape_[0];
+  const dim_t num_cols = ret->shape_[1];
+  const dim_t nnr_r = rhs.storage_shape()[0];
+  dim_t num_threads;
+
+  const TBlob data_l = lhs.data();
+  const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
+  const TBlob data_r = rhs.data();
+  const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx);
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
+        MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, {  // row idx type
+          if (kWriteTo == req) {
+            num_threads = num_rows*num_cols;
+            Kernel<set_zero, gpu>::Launch(s, num_threads, ret->dptr<DType>());
+          }
+          if (trans_lhs) {
+            LOG(FATAL) << "DotCsrRspDnsImpl has not implemented dot(csr.T, 
rsp) = dns yet.";
+          } else {
+            // TODO: Consider implementing a vector kernel for SpMV (similar 
to DotCsrDnsDns)
+            // Alloc temp storage for row_flg array
+            RType* row_flg_r = ctx.requested[0]
+                .get_space_typed<gpu, 1, 
RType>(mshadow::Shape1(rhs.shape()[0]), s).dptr_;
+            num_threads = rhs.shape()[0];
+            Kernel<set_zero, gpu>::Launch(s, num_threads, row_flg_r);
+            // Set row_flg index array
+            num_threads = nnr_r;
+            Kernel<SetRspRowFlgKernel, gpu>::Launch(s, num_threads,
+                row_flg_r, row_idx_r.dptr<RType>(), nnr_r);
+            // Perform sparse matrix-matrix multiply
+            num_threads = num_rows*num_cols;
+            Kernel<DotCsrRspDnsScalarKernel, gpu>::Launch(s, num_threads,
+                ret->dptr<DType>(),
+                data_l.dptr<DType>(), indptr_l.dptr<IType>(), 
col_idx_l.dptr<CType>(),
+                data_r.dptr<DType>(), row_idx_r.dptr<RType>(), row_flg_r, 
rhs.storage_shape()[0],
+                num_rows, num_cols);
+          }
+        });
+      });
+    });
+  });
 }
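The non-transposed branch allocates row_flg_r over all rhs rows so a kernel thread can map a csr column index straight to the matching rsp storage row, or detect that the row is absent. The exact encoding lives in SetRspRowFlgKernel (not shown in this hunk); a common scheme, sketched below in NumPy with made-up data, stores offset+1 so that zero still means "row not stored":

    import numpy as np

    # Made-up rsp rhs: 10 logical rows, 3 of them stored.
    row_idx_r = np.array([2, 5, 9])                         # aux row indices (sorted)
    data_r = np.arange(len(row_idx_r) * 4).reshape(-1, 4)   # nnr_r x num_cols storage

    # Inverse index over all rows: 0 means "row not stored", value-1 is the storage offset.
    row_flg_r = np.zeros(10, dtype=np.int64)
    row_flg_r[row_idx_r] = np.arange(1, len(row_idx_r) + 1)

    def rhs_row(j):
        """Dense view of logical rhs row j (zeros if the rsp does not store it)."""
        k = row_flg_r[j]
        return data_r[k - 1] if k > 0 else np.zeros(data_r.shape[1])

    print(rhs_row(5), rhs_row(3))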
 
 }  // namespace op
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 8aabde0..42aecb4 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -187,8 +187,8 @@ inline bool DotForwardInferStorageType(const 
nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
   // csr has many zero columns, so the result of dot(csr.T, matrix) should be 
rsp
-  // dot(csr.T,dns)=rsp not yet implemented on gpu
-  if (param.transpose_a && kCSRStorage == (*in_attrs)[0] && ctx.dev_type != 
Context::kGPU) {
+  // TODO(stefan/haibin): don't enforce kRowSparseStorage if out_attrs has 
already been set
+  if (param.transpose_a && kCSRStorage == (*in_attrs)[0]) {
     STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
   } else {
     STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
@@ -213,7 +213,7 @@ inline bool DotBackwardInferStorageType(const 
nnvm::NodeAttrs& attrs,
 }
 
 /*!
- * \brief Kernel of dot(csr, dns1) = dns2
+ * \brief CPU Kernel of dot(csr, dns1) = dns2
  * Parallelization by row blocks
  */
 struct DotCsrDnsDnsByRowBlocks {
@@ -222,19 +222,26 @@ struct DotCsrDnsDnsByRowBlocks {
    * \param i the i-th thread
    */
   template<typename DType, typename IType, typename CType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, 
const IType* indptr_l,
-                                  const CType* col_idx_l, const DType* data_r, 
const size_t seg_len,
-                                  const size_t num_rows, const size_t 
num_cols) {
-    const size_t seg_start = i * seg_len;
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const IType* indptr_l,
+                                  const CType* col_idx_l,
+                                  const DType* data_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
     if (seg_start >= num_rows) return;
-    const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : 
num_rows);
-    for (size_t j = seg_start; j < seg_end; ++j) {
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows);
+    for (dim_t j = seg_start; j < seg_end; ++j) {
       if (indptr_l[j] == indptr_l[j+1]) continue;
-      const size_t offset_out = j * num_cols;
-      for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
-        const auto val = data_l[k];
-        const size_t offset_r = col_idx_l[k] * num_cols;
-        for (size_t l = 0; l < num_cols; ++l) {
+      const dim_t offset_out = j * num_cols;
+      for (IType k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
+        const DType val = data_l[k];
+        const dim_t offset_r = col_idx_l[k] * num_cols;
+        for (dim_t l = 0; l < num_cols; ++l) {
           out[offset_out+l] += data_r[offset_r+l] * val;
         }
       }
@@ -243,7 +250,7 @@ struct DotCsrDnsDnsByRowBlocks {
 };
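DotCsrDnsDnsByRowBlocks above, like the other *ByRowBlocks CPU kernels in this file, assigns each OMP thread a contiguous block of output rows of length seg_len, so no two threads ever write the same output row. A quick sketch of that partitioning with arbitrary numbers, mirroring the seg_len arithmetic in the Impl functions below:

    num_rows = 103
    num_threads = 8
    seg_len = (num_rows + num_threads - 1) // num_threads   # ceil(num_rows / num_threads)

    for i in range(num_threads):
        seg_start = i * seg_len
        if seg_start >= num_rows:
            continue                                        # this thread has no rows
        seg_end = min(seg_start + seg_len, num_rows)
        print(i, seg_start, seg_end)                        # disjoint [seg_start, seg_end) ranges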
 
 /*!
- * \brief Kernel of dot(csr.T(), dns1) = dns2
+ * \brief CPU Kernel of dot(csr.T(), dns1) = dns2
  * Parallelization by row blocks
  */
 struct DotCsrTransDnsDnsByRowBlocks {
@@ -252,22 +259,29 @@ struct DotCsrTransDnsDnsByRowBlocks {
    * \param i the i-th thread
    */
   template<typename DType, typename IType, typename CType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, 
const IType* indptr_l,
-                                  const CType* col_idx_l, const DType* data_r, 
const size_t seg_len,
-                                  const size_t num_rows_l, const size_t 
num_rows,
-                                  const size_t num_cols) {
-    const size_t seg_start = i * seg_len;
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const IType* indptr_l,
+                                  const CType* col_idx_l,
+                                  const DType* data_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
     if (seg_start >= num_rows) return;
-    const size_t seg_end = (i + 1) * seg_len;
-    for (size_t j = 0; j < num_rows_l; ++j) {
+    const dim_t seg_end = (i + 1) * seg_len;
+    for (dim_t j = 0; j < num_rows_l; ++j) {
       if (indptr_l[j] == indptr_l[j+1]) continue;
-      const size_t offset_r = j * num_cols;
-      for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
-        const auto col_idx = col_idx_l[k];
+      const dim_t offset_r = j * num_cols;
+      for (IType k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
+        const CType col_idx = col_idx_l[k];
         if (col_idx < seg_start || col_idx >= seg_end) continue;
-        const size_t offset_out = col_idx * num_cols;
-        const auto val = data_l[k];
-        for (size_t l = 0; l < num_cols; ++l) {
+        const dim_t offset_out = col_idx * num_cols;
+        const DType val = data_l[k];
+        for (dim_t l = 0; l < num_cols; ++l) {
           out[offset_out+l] += data_r[offset_r+l] * val;
         }
       }
@@ -276,11 +290,10 @@ struct DotCsrTransDnsDnsByRowBlocks {
 };
 
 /*!
- * \brief Kernel of dot(csr.T(), dns) = rsp
+ * \brief CPU Kernel of dot(csr.T(), dns) = rsp
  * Parallelization by row blocks.
- * This kernel fills up the row_idx array
- * of the rsp with 1 for nonzero rows and 0
- * for zero rows.
+ * This kernel fills up the row_idx array of the rsp 
+ * with 1 for nonzero rows and 0 for zero rows.
  * The matrix will be compacted after this kernel call.
  */
 struct DotCsrTransDnsRspByRowBlocks {
@@ -289,24 +302,31 @@ struct DotCsrTransDnsRspByRowBlocks {
    * \param i the i-th thread
    */
   template<typename DType, typename RType, typename IType, typename CType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx, const 
DType* data_l,
-                                  const IType* indptr_l, const CType* 
col_idx_l,
-                                  const DType* data_r, const size_t seg_len,
-                                  const size_t num_rows_l, const size_t 
num_rows,
-                                  const size_t num_cols) {
-    const size_t seg_start = i * seg_len;
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  RType* row_idx,
+                                  const DType* data_l,
+                                  const IType* indptr_l,
+                                  const CType* col_idx_l,
+                                  const DType* data_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
     if (seg_start >= num_rows) return;
-    const size_t seg_end = (i + 1) * seg_len;
-    for (size_t j = 0; j < num_rows_l; ++j) {
+    const dim_t seg_end = (i + 1) * seg_len;
+    for (dim_t j = 0; j < num_rows_l; ++j) {
       if (indptr_l[j] == indptr_l[j+1]) continue;
-      const size_t offset_r = j * num_cols;
-      for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
-        const auto col_idx = col_idx_l[k];
+      const dim_t offset_r = j * num_cols;
+      for (IType k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
+        const CType col_idx = col_idx_l[k];
         if (col_idx < seg_start || col_idx >= seg_end) continue;
-        const size_t offset_out = col_idx * num_cols;
+        const dim_t offset_out = col_idx * num_cols;
         row_idx[col_idx] = 1;
-        const auto val = data_l[k];
-        for (size_t l = 0; l < num_cols; ++l) {
+        const DType val = data_l[k];
+        for (dim_t l = 0; l < num_cols; ++l) {
           out[offset_out+l] += data_r[offset_r+l] * val;
         }
       }
@@ -315,33 +335,40 @@ struct DotCsrTransDnsRspByRowBlocks {
 };
 
 /*!
- * \brief Kernel of dot(csr, rsp) = dns
+ * \brief CPU Kernel of dot(csr, rsp) = dns
  * Parallelization by row blocks
  */
 struct DotCsrRspDnsByRowBlocks {
   /*!
    * \brief
-   * \param i the i-th thread
-   * \param nnr_r storage_shape[0] of the rsp
-   * \param num_rows dns.shape[0]
-   * \param num_cols dns.shape[1]
+   * \param i         the i-th thread
+   * \param nnr_r     storage_shape[0] of the rsp
+   * \param num_rows  dns.shape[0]
+   * \param num_cols  dns.shape[1]
    */
   template<typename DType, typename IType, typename CType, typename RType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l,
-                                  const IType* indptr_l, const CType* 
col_idx_l,
-                                  const DType* data_r, const RType* row_idx_r,
-                                  const size_t nnr_r, const size_t num_rows,
-                                  const size_t num_cols, const size_t seg_len) 
{
-    const size_t seg_start = i * seg_len;
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const IType* indptr_l,
+                                  const CType* col_idx_l,
+                                  const DType* data_r,
+                                  const RType* row_idx_r,
+                                  const nnvm::dim_t nnr_r,
+                                  const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols,
+                                  const nnvm::dim_t seg_len) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
     if (seg_start >= num_rows) return;
-    const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : 
num_rows);
-    for (size_t j = seg_start; j < seg_end; ++j) {
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows);
+    for (dim_t j = seg_start; j < seg_end; ++j) {
       if (indptr_l[j] == indptr_l[j+1]) continue;
-      const size_t offset_out = j * num_cols;
+      const dim_t offset_out = j * num_cols;
       // Use binary search to find the lower_bound of val in row_idx array
       const RType* first = row_idx_r;
       const RType* last = row_idx_r + nnr_r;
-      const auto val = col_idx_l[indptr_l[j]];
+      const CType val = col_idx_l[indptr_l[j]];
       const RType* it;
       int count = last - first, step;
       while (count > 0) {
@@ -358,10 +385,10 @@ struct DotCsrRspDnsByRowBlocks {
       const RType* row_idx_ptr = first;
       // end of binary search
       if (row_idx_ptr == row_idx_r+nnr_r || *row_idx_ptr> 
col_idx_l[indptr_l[j+1]-1]) continue;
-      for (auto k = indptr_l[j]; k < indptr_l[j+1] && row_idx_ptr != 
row_idx_r+nnr_r;) {
+      for (IType k = indptr_l[j]; k < indptr_l[j+1] && row_idx_ptr != 
row_idx_r+nnr_r;) {
         if (col_idx_l[k] == *row_idx_ptr) {
-          const size_t offset_r = (row_idx_ptr - row_idx_r) * num_cols;
-          for (size_t l = 0; l < num_cols; ++l) {
+          const dim_t offset_r = (row_idx_ptr - row_idx_r) * num_cols;
+          for (dim_t l = 0; l < num_cols; ++l) {
             out[offset_out+l] += data_l[k] * data_r[offset_r+l];
           }
           ++k;
@@ -377,7 +404,7 @@ struct DotCsrRspDnsByRowBlocks {
 };
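The hand-rolled loop above is a textbook lower_bound: it finds the first entry of the sorted row_idx_r array that is not less than the first column index of the csr row, after which the csr column indices and the rsp row indices can be merged in one pass. The same logic in Python, with made-up indices (bisect_left is the equivalent of lower_bound):

    import bisect

    row_idx_r   = [1, 4, 7, 9]      # non-zero rows of the rsp rhs (sorted)
    col_idx_row = [4, 7, 8]         # column indices of one csr row (sorted)

    pos = bisect.bisect_left(row_idx_r, col_idx_row[0])     # lower_bound -> 1 (row 4)

    # Merge the two sorted lists from there; matches are the products to accumulate.
    k = 0
    while k < len(col_idx_row) and pos < len(row_idx_r):
        if col_idx_row[k] == row_idx_r[pos]:
            print('csr col', col_idx_row[k], 'pairs with rsp storage row', pos)
            k += 1
            pos += 1
        elif col_idx_row[k] < row_idx_r[pos]:
            k += 1
        else:
            pos += 1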
 
 /*!
- * \brief Kernel of dot(csr.T(), rsp) = dns with row_idx marked for non-zero 
rows
+ * \brief CPU Kernel of dot(csr.T(), rsp1) = rsp2, with row_idx marked for 
non-zero rows
  * Parallelization by row blocks
  */
 struct DotCsrTransRspRspByRowBlocks {
@@ -390,25 +417,33 @@ struct DotCsrTransRspRspByRowBlocks {
    * \param num_cols number of cols of out matrix
    */
   template<typename DType, typename IType, typename CType, typename RType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx_out,
-                                  const DType* data_l, const IType* indptr_l,
-                                  const CType* col_idx_l, const DType* data_r,
-                                  const RType* row_idx_r, const size_t 
num_rows_l,
-                                  const size_t nnr_r, const size_t num_rows,
-                                  const size_t num_cols, const size_t seg_len) 
{
-    const size_t seg_start = i * seg_len;
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  RType* row_idx_out,
+                                  const DType* data_l,
+                                  const IType* indptr_l,
+                                  const CType* col_idx_l,
+                                  const DType* data_r,
+                                  const RType* row_idx_r,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t nnr_r,
+                                  const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols,
+                                  const nnvm::dim_t seg_len) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
     if (seg_start >= num_rows) return;
-    const size_t seg_end = (i + 1) * seg_len;
-    for (size_t rid = 0; rid < nnr_r; ++rid) {
-      const auto j = row_idx_r[rid];
+    const dim_t seg_end = (i + 1) * seg_len;
+    for (dim_t rid = 0; rid < nnr_r; ++rid) {
+      const RType j = row_idx_r[rid];
       if (indptr_l[j] == indptr_l[j+1]) continue;
-      const size_t offset_r = rid * num_cols;
-      for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
-        const auto col_idx = col_idx_l[k];
+      const dim_t offset_r = rid * num_cols;
+      for (IType k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
+        const CType col_idx = col_idx_l[k];
         if (col_idx < seg_start || col_idx >= seg_end) continue;
         row_idx_out[col_idx] = 1;  // mark nonzero row as 1
-        const size_t offset_out = col_idx * num_cols;
-        for (size_t l = 0; l < num_cols; ++l) {
+        const dim_t offset_out = col_idx * num_cols;
+        for (dim_t l = 0; l < num_cols; ++l) {
           out[offset_out+l] += data_r[offset_r+l] * data_l[k];
         }
       }
@@ -416,7 +451,11 @@ struct DotCsrTransRspRspByRowBlocks {
   }
 };
 
-inline void DotCsrDnsDnsImpl(mshadow::Stream<cpu>* s,
+/*!
+ * \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
+ */
+inline void DotCsrDnsDnsImpl(const OpContext& ctx,
+                             const cpu& cpu_dev,
                              const NDArray& lhs,
                              const TBlob& rhs,
                              const OpReqType req,
@@ -426,6 +465,9 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<cpu>* s,
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   if (!lhs.storage_initialized()) return;
 
+  using nnvm::dim_t;
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
@@ -435,12 +477,14 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<cpu>* s,
   MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
+        dim_t num_threads;
         if (kWriteTo == req) {
+          num_threads = data_out.Size();
           mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
-              s, data_out.Size(), data_out.dptr<DType>());
+              s, num_threads, data_out.dptr<DType>());
         }
-        int num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
-        size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
+        num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
+        dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
         if (trans_lhs) {
           mxnet_op::Kernel<DotCsrTransDnsDnsByRowBlocks, cpu>::Launch(s, 
num_threads,
               data_out.dptr<DType>(), data_l.dptr<DType>(), 
indptr_l.dptr<IType>(),
@@ -458,9 +502,10 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<cpu>* s,
 }
 
 /*!
- * \brief Impl of dot(csr, rsp)
+ * \brief CPU Impl of dot(csr.T, dns) = rsp
  */
-inline void DotCsrDnsRspImpl(mshadow::Stream<cpu>* s,
+inline void DotCsrDnsRspImpl(const OpContext& ctx,
+                             const cpu& cpu_dev,
                              const NDArray& lhs,
                              const TBlob& rhs,
                              const OpReqType req,
@@ -470,7 +515,12 @@ inline void DotCsrDnsRspImpl(mshadow::Stream<cpu>* s,
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
   if (!lhs.storage_initialized()) return;
+  CHECK_EQ(req, kWriteTo);
+
+  using mxnet_op::set_zero;
+  using nnvm::dim_t;
 
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
@@ -484,27 +534,25 @@ inline void DotCsrDnsRspImpl(mshadow::Stream<cpu>* s,
   MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
-        MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, {  // col idx 
type
-          if (kWriteTo == req) {
-            mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
-                s, data_out.Size(), data_out.dptr<DType>());
-          }
+        MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, {  // row idx 
type
+          dim_t num_threads = data_out.Size();
+          mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, 
data_out.dptr<DType>());
           RType* row_idx = row_idx_out.dptr<RType>();
-          mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
-              s, row_idx_out.Size(), row_idx);
-          int num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
-          size_t seg_len = (data_out.shape_[0] + num_threads - 1) / 
num_threads;
+          num_threads = row_idx_out.Size();
+          mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, row_idx);
+          num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
+          dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
           if (trans_lhs) {
             mxnet_op::Kernel<DotCsrTransDnsRspByRowBlocks, cpu>::Launch(s, 
num_threads,
                 data_out.dptr<DType>(), row_idx, data_l.dptr<DType>(),
                 indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(), 
data_r.dptr<DType>(),
                 seg_len, lhs.shape()[0], data_out.shape_[0], 
data_out.shape_[1]);
-            index_t nnr = 0;
+            dim_t nnr = 0;
             nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], 
nnr);
             ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
             if (0 == nnr) return;
             mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, 
DType>(s);
-            size_t idx = 0;
+            dim_t idx = 0;
             for (index_t i = 0; i < ret->shape()[0]; ++i) {
               if (row_idx[i] > 0) {
                 row_idx[idx] = i;
@@ -513,8 +561,7 @@ inline void DotCsrDnsRspImpl(mshadow::Stream<cpu>* s,
               }
             }
           } else {
-            LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, 
dns)=rsp yet."
-                          " Only the cpu version of dot(csr.T, dns)=rsp is 
supported now";
+            LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, 
dns)=rsp yet.";
           }
         });
       });
@@ -522,31 +569,36 @@ inline void DotCsrDnsRspImpl(mshadow::Stream<cpu>* s,
   });
 }
 
-template<typename xpu>
-void DotCsrRspDnsImpl(mshadow::Stream<xpu>* s,
-                      const NDArray& lhs,
-                      const NDArray& rhs,
-                      const OpReqType req,
-                      const bool trans_lhs,
-                      TBlob* ret) {
+/*!
+ * \brief CPU Impl of dot(csr, rsp) = dns
+ */
+inline void DotCsrRspDnsImpl(const OpContext& ctx,
+                             const cpu& cpu_dev,
+                             const NDArray& lhs,
+                             const NDArray& rhs,
+                             const OpReqType req,
+                             const bool trans_lhs,
+                             TBlob* ret) {
+  if (kNullOp == req) return;
   // reuse csr dns implementation when storage_shape == shape for rhs
   if (rhs.storage_shape()[0] == rhs.shape()[0]) {  // if rsp is actually dense
-    DotCsrDnsDnsImpl(s, lhs, rhs.data(), req, trans_lhs, ret);
+    DotCsrDnsDnsImpl(ctx, cpu_dev, lhs, rhs.data(), req, trans_lhs, ret);
     return;
   }
 
-  if (kNullOp == req) return;
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   if (!lhs.storage_initialized() || !rhs.storage_initialized()) {
     if (kWriteTo == req) {
       MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {  // data type
-        mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+        mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
             s, ret->Size(), ret->dptr<DType>());
       });
     }
     return;
   }
+  using nnvm::dim_t;
 
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
@@ -557,17 +609,19 @@ void DotCsrRspDnsImpl(mshadow::Stream<xpu>* s,
   MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
-        MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, {  // col idx type
+        MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, {  // row idx type
+          dim_t num_threads;
           if (kWriteTo == req) {
-            mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
-                s, ret->Size(), ret->dptr<DType>());
+            num_threads = ret->Size();
+            mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_threads,
+                                                              ret->dptr<DType>());
           }
-          int num_threads = mxnet_op::get_num_threads<xpu>(ret->shape_[0]);
-          size_t seg_len = (ret->shape_[0] + num_threads - 1) / num_threads;
+          num_threads = mxnet_op::get_num_threads<cpu>(ret->shape_[0]);
+          dim_t seg_len = (ret->shape_[0] + num_threads - 1) / num_threads;
           if (trans_lhs) {
             LOG(FATAL) << "DotCsrRspDnsImpl has not implemented dot(csr.T, 
rsp) = dns yet";
           } else {
-            mxnet_op::Kernel<DotCsrRspDnsByRowBlocks, xpu>::Launch(s, 
num_threads,
+            mxnet_op::Kernel<DotCsrRspDnsByRowBlocks, cpu>::Launch(s, 
num_threads,
                 ret->dptr<DType>(), data_l.dptr<DType>(),
                 indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(), 
data_r.dptr<DType>(),
                 row_idx_r.dptr<RType>(), rhs.storage_shape()[0],
@@ -580,26 +634,32 @@ void DotCsrRspDnsImpl(mshadow::Stream<xpu>* s,
 }
 
 /*!
- * \brief Impl of dot(csr.T, rsp) = rsp2
+ * \brief CPU Impl of dot(csr.T, rsp1) = rsp2
  */
-inline void DotCsrRspRspImpl(mshadow::Stream<cpu>* s,
+inline void DotCsrRspRspImpl(const OpContext& ctx,
+                             const cpu& cpu_dev,
                              const NDArray& lhs,
                              const NDArray& rhs,
                              const OpReqType req,
                              const bool trans_lhs,
                              NDArray* ret) {
+  if (kNullOp == req) return;
   // reuse csr dns implementation when storage_shape == shape for rhs
   if (rhs.storage_shape()[0] == rhs.shape()[0]) {  // if rsp is actually dense
-    DotCsrDnsRspImpl(s, lhs, rhs.data(), req, trans_lhs, ret);
+    DotCsrDnsRspImpl(ctx, cpu_dev, lhs, rhs.data(), req, trans_lhs, ret);
     return;
   }
 
-  if (kNullOp == req) return;
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
   if (!lhs.storage_initialized() || !rhs.storage_initialized()) return;
+  CHECK_EQ(req, kWriteTo);
+
+  using mxnet_op::set_zero;
+  using nnvm::dim_t;
 
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
@@ -616,28 +676,26 @@ inline void DotCsrRspRspImpl(mshadow::Stream<cpu>* s,
   MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
-        MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, {  // col idx type
-          if (kWriteTo == req) {
-            mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
-                s, data_out.Size(), data_out.dptr<DType>());
-          }
-          int num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
-          size_t seg_len = (data_out.shape_[0] + num_threads - 1) / 
num_threads;
+        MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, {  // row idx type
+          dim_t num_threads = data_out.Size();
+          mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, 
data_out.dptr<DType>());
+          num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
+          dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
           if (trans_lhs) {
             RType* row_idx = row_idx_out.dptr<RType>();
-            mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
-                s, row_idx_out.Size(), row_idx);
+            num_threads = row_idx_out.Size();
+            mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, row_idx);
             mxnet_op::Kernel<DotCsrTransRspRspByRowBlocks, cpu>::Launch(s, 
num_threads,
                 data_out.dptr<DType>(), row_idx, data_l.dptr<DType>(),
                 indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(), 
data_r.dptr<DType>(),
                 row_idx_r.dptr<RType>(), lhs.shape()[0], 
rhs.storage_shape()[0],
                 ret->shape()[0], ret->shape()[1], seg_len);
-            index_t nnr = 0;
+            dim_t nnr = 0;
             nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], 
nnr);
             ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
             if (0 == nnr) return;
             mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, 
DType>(s);
-            size_t idx = 0;
+            dim_t idx = 0;
             for (index_t i = 0; i < ret->shape()[0]; ++i) {
               if (row_idx[i] > 0) {
                 row_idx[idx] = i;
@@ -713,22 +771,21 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
   auto lhs_stype = inputs[0].storage_type();
   auto rhs_stype = inputs[1].storage_type();
   auto out_stype = outputs[0].storage_type();
-  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
   if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == 
kDefaultStorage) {
     TBlob ret = outputs[0].data();
-    DotCsrDnsDnsImpl(s, inputs[0], inputs[1].data(), req[0], 
param.transpose_a, &ret);
+    DotCsrDnsDnsImpl(ctx, xpu(), inputs[0], inputs[1].data(), req[0], 
param.transpose_a, &ret);
   } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage
       && out_stype == kDefaultStorage) {
     TBlob ret = outputs[0].data();
-    DotCsrRspDnsImpl<xpu>(s, inputs[0], inputs[1], req[0], param.transpose_a, 
&ret);
+    DotCsrRspDnsImpl(ctx, xpu(), inputs[0], inputs[1], req[0], 
param.transpose_a, &ret);
   } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage
       && out_stype == kRowSparseStorage) {
     NDArray out = outputs[0];
-    DotCsrDnsRspImpl(s, inputs[0], inputs[1].data(), req[0], 
param.transpose_a, &out);
+    DotCsrDnsRspImpl(ctx, xpu(), inputs[0], inputs[1].data(), req[0], 
param.transpose_a, &out);
   } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage
       && out_stype == kRowSparseStorage) {
     NDArray ret = outputs[0];
-    DotCsrRspRspImpl(s, inputs[0], inputs[1], req[0], param.transpose_a, &ret);
+    DotCsrRspRspImpl(ctx, xpu(), inputs[0], inputs[1], req[0], 
param.transpose_a, &ret);
   } else {
     FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, DotForward_<xpu>, 
"DotForward_");
   }
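Together with the relaxed check in DotForwardInferStorageType earlier in this file, the dispatch above means a transposed csr product now stays sparse on GPU as well as CPU. A short usage sketch along the lines of the updated unit test (shapes and density are arbitrary):

    import mxnet as mx
    from mxnet.test_utils import rand_ndarray

    csr = rand_ndarray((50, 100), 'csr', density=0.1)   # sparse lhs
    dns = mx.nd.ones((50, 4))                           # dense rhs

    out = mx.nd.dot(csr, dns, transpose_a=True)         # dot(csr.T, dns)
    print(out.stype)                                    # expected: 'row_sparse'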
@@ -755,17 +812,16 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs,
   const auto lhs_stype = inputs[1].storage_type();
   const auto rhs_stype = inputs[2].storage_type();
   const auto grad_rhs_stype = outputs[1].storage_type();
-  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
   if (ograd_stype == kDefaultStorage  // ograd dns format
       && lhs_stype == kCSRStorage  // csr input lhs of the op
       && grad_rhs_stype == kDefaultStorage) {  // grad(rhs) dns format
     TBlob ret = outputs[1].data();
-    DotCsrDnsDnsImpl(s, inputs[1], inputs[0].data(), req[1], 
!param.transpose_a, &ret);
+    DotCsrDnsDnsImpl(ctx, xpu(), inputs[1], inputs[0].data(), req[1], 
!param.transpose_a, &ret);
   } else if (ograd_stype == kDefaultStorage
       && lhs_stype == kCSRStorage
       && grad_rhs_stype == kRowSparseStorage) {
     NDArray ret = outputs[1];
-    DotCsrDnsRspImpl(s, inputs[1], inputs[0].data(), req[1], 
!param.transpose_a, &ret);
+    DotCsrDnsRspImpl(ctx, xpu(), inputs[1], inputs[0].data(), req[1], 
!param.transpose_a, &ret);
   } else {
     FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, DotBackward_<xpu>, 
"DotBackward_");
   }
diff --git a/tests/python/gpu/test_operator_gpu.py 
b/tests/python/gpu/test_operator_gpu.py
index ac59811..adc4c3b 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -5,7 +5,7 @@ sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from test_operator import *
 from test_optimizer import *
 from test_random import *
-from test_sparse_operator import test_sparse_dot, test_sparse_nd_zeros
+from test_sparse_operator import test_cast_storage_ex, test_sparse_dot, 
test_sparse_nd_zeros
 from test_sparse_ndarray import test_create_csr, test_create_row_sparse
 import mxnet as mx
 import numpy as np
diff --git a/tests/python/unittest/test_sparse_operator.py 
b/tests/python/unittest/test_sparse_operator.py
index aa7349d..e6e8add 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -67,6 +67,7 @@ def test_elemwise_add_ex_multiple_stages():
     exec_test.backward(out_grads=exec_test.outputs)
     assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy())
 
+
 # TODO(haibin) also add test for backward pass.
 def test_cast_storage_ex():
     def test_rsp_to_dns(shape, density):
@@ -113,46 +114,44 @@ def test_cast_storage_ex():
             test_dns_to_rsp((rnd.randint(1, 10), rnd.randint(512, 1024)), d) # 
test gpu block  kernel
 
 
-
 def test_sparse_dot():
-    def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
-        lhs_nd = rand_ndarray(lhs_shape, 'csr', 1)
+    def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, 
rhs_density):
+        lhs_nd = rand_ndarray(lhs_shape, 'csr', density=lhs_density)
         lhs_dns = lhs_nd.todense()
-        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=density)
+        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_density)
         rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
-        out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
-        if trans_lhs and default_context().device_type is 'cpu':
-            assert out.stype == 'row_sparse'
-        else:
-            assert out.stype == 'default'
-        out_expected = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs)
-        out_np = out_expected.asnumpy()
-        backward_trans = not trans_lhs
-        rhs_backward_grad = mx.nd.dot(lhs_dns, out_expected, 
transpose_a=backward_trans).asnumpy()
+
+        out = mx.nd.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs)
+        out_dns = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs)
+        out_np = out_dns.asnumpy()
         assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5)
 
         # test symbolic forward
         lhs = mx.symbol.Variable('lhs', stype='csr')
         rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
-        test = mx.symbol.dot(lhs, rhs, transpose_a=trans_lhs)
+        out = mx.symbol.dot(lhs, rhs, transpose_a=trans_lhs)
         location = {'lhs': lhs_nd, 'rhs': rhs_nd}
-        expected = {'rhs': rhs_backward_grad}
-        check_symbolic_forward(test, location, [out_np], rtol=1e-3, atol=1e-4)
+        check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)
+
         # test symbolic backward
-        check_symbolic_backward(test, location, [out_np], expected,
+        backward_trans = not trans_lhs
+        rhs_backward_grad = mx.nd.dot(lhs_dns, out_dns, 
transpose_a=backward_trans).asnumpy()
+        expected = {'rhs': rhs_backward_grad}
+        check_symbolic_backward(out, location, [out_np], expected,
                                 grad_req={'lhs': 'null', 'rhs': 'write'},
                                 rtol=1e-3, atol=1e-4)
 
-    lhs_shape = rand_shape_2d(50, 200)
-    test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False) # test gpu 
SpMV
-    test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True ) # (vector 
kernel)
-    test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', 
False) # test gpu SpMM
-    test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', 
True ) # (scalar kernel)
-    if default_context().device_type is 'cpu':
-        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 
'row_sparse', False)
-        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 
'row_sparse', True )
-        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 
'row_sparse', False, 0.05)
-        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 
'row_sparse', True , 0.05)
+    density = [1.00, 0.50, 0.10, 0.05, 0.01]
+    for lhs_d in density:
+        lhs_shape = rand_shape_2d(50, 200)
+        rhs_d = 1
+        test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False, lhs_d, 
rhs_d) # test gpu SpMV
+        test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True , lhs_d, 
rhs_d) # (vector kernel)
+        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', 
False, lhs_d, rhs_d) # test gpu SpMM
+        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', 
True , lhs_d, rhs_d) # (scalar kernel)
+        for rhs_d in density:
+            test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 
'row_sparse', False, lhs_d, rhs_d)
+            test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 
'row_sparse', True, lhs_d, rhs_d)
 
 
 def test_sparse_slice():
@@ -265,7 +264,6 @@ def test_sparse_elementwise_sum():
     maxdim = 5
     for dim in range(2, maxdim):
         shape = tuple(np.random.randint(5, 10, size=dim))
-        print shape
         check_sparse_elementwise_sum_with_shape('row_sparse', shape, 
np.random.randint(1, 9))
 
 

-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].
