This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 6d0351a Faster sparse_dense on GPUs (#6580)
6d0351a is described below
commit 6d0351a7f0e23eb5428c59a976edd2bfb8207c0d
Author: Tristan Konolige <[email protected]>
AuthorDate: Fri Oct 9 16:15:26 2020 -0700
Faster sparse_dense on GPUs (#6580)
* Faster sparse_dense on GPUs.
This new sparse_dense requires a padded matrix, so a new op
`sparse_dense_padded` has been added. AlterOpLayout should transform
`sparse_dense` to `sparse_dense_padded` when possible on the GPU (see the
sketch after this list).
* formatting
* more formatting
* Check that AlterOpLayout is defined before using it
* check if FTVMAlterOpLayout exists before using it
* formatting
* restore message passing
* Fix sparse_dense and sparse_dense_padded docs
* Fix old sparse_dense; autotvm and sparse_dense don't play well together
* Remove unused imports
* clarify warp count in cuda_transpose
* Document multidimensional access
* Warn users not to use sparse_dense_padded
* rename nn.sparse_dense_padded to nn.internal.sparse_dense_padded
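
A rough numpy sketch of the padding requirement described above
(illustrative only, not part of the commit): each sparse row is padded with
explicit zeros until its length is a multiple of the warp size, so a warp
can consume a row in fixed-size chunks.

    import numpy as np

    warp_size = 32
    row_lengths = np.array([5, 40, 0, 33])  # nnz blocks per row (example values)
    pad = (-row_lengths) % warp_size        # zero blocks to append per row
    padded = row_lengths + pad              # -> [32, 64, 0, 64]
    assert np.all(padded % warp_size == 0)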
---
python/tvm/relay/op/nn/_nn.py | 16 ++
python/tvm/relay/op/nn/nn.py | 6 +-
python/tvm/relay/op/strategy/cuda.py | 13 ++
python/tvm/relay/op/strategy/generic.py | 6 +
python/tvm/tir/ir_builder.py | 39 +++-
python/tvm/topi/cuda/sparse.py | 310 +++++++++++++++++++++++++--
python/tvm/topi/nn/sparse.py | 25 +++
src/relay/op/nn/sparse.cc | 36 +++-
src/relay/transforms/transform_layout.h | 12 ++
src/target/source/codegen_cuda.cc | 3 +-
src/te/operation/compute_op.cc | 4 +-
src/te/operation/op_util.cc | 3 +-
src/te/schedule/schedule_lang.cc | 2 +-
src/tir/transforms/lower_warp_memory.cc | 6 +-
src/tir/transforms/storage_access.cc | 2 +-
tests/python/topi/python/test_topi_sparse.py | 114 +++++++---
16 files changed, 537 insertions(+), 60 deletions(-)
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index c83f6a9..9e47dc0 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -75,6 +75,22 @@ reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy)
reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+@reg.register_alter_op_layout("nn.sparse_dense")
+def alter_op_layout_sparse_dense(attrs, inputs, tinfos, out_type):
+ """Alternate the layout of sparse_dense"""
+ return topi.nn.sparse_dense_alter_layout(attrs, inputs, tinfos, out_type)
+
+
+@reg.register_compute("nn.internal.sparse_dense_padded")
+def compute_sparse_dense_padded(attrs, inputs, out_type):
+ """Compute definition of sparse_dense_padded"""
+ raise NotImplementedError("nn.internal.sparse_dense_padded is only
available on cuda")
+
+
+reg.register_strategy("nn.internal.sparse_dense_padded", strategy.sparse_dense_padded_strategy)
+reg.register_pattern("nn.internal.sparse_dense_padded", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
# sparse_transpose
@reg.register_compute("nn.sparse_transpose")
def compute_sparse_transpose(attrs, inputs, out_type):
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 86a76ff..1aad4e7 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2016,7 +2016,7 @@ def sparse_dense(data, weight):
data : tvm.relay.Expr
The input data for the matrix multiplication
- weight : namedtuple.
+ weight : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
The sparse weight matrix for the matrix multiplication.
Returns
@@ -2024,7 +2024,9 @@ def sparse_dense(data, weight):
result: tvm.relay.Expr
The computed result.
"""
- return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
+ if hasattr(weight, "indices"):
+ return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
+ return _make.sparse_dense(data, weight[0], weight[1], weight[2])
def sparse_transpose(x):
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index baa03f4..7031365 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -633,6 +633,19 @@ def sparse_dense_strategy_cuda(attrs, inputs, out_type, target):
return strategy
+@sparse_dense_padded_strategy.register(["cuda", "gpu"])
+def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
+ """sparse dense cuda strategy"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_sparse_dense(topi.cuda.sparse_dense_padded),
+ wrap_topi_schedule(topi.cuda.schedule_sparse_dense_padded),
+ name="sparse_dense_padded.cuda",
+ plevel=10,
+ )
+ return strategy
+
+
@argsort_strategy.register(["cuda", "gpu"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 56ae976..0f99710 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -724,6 +724,12 @@ def sparse_dense_strategy(attrs, inputs, out_type, target):
return strategy
+@override_native_generic_func("sparse_dense_padded_strategy")
+def sparse_dense_padded_strategy(attrs, inputs, out_type, target):
+ """sparse dense padded generic strategy"""
+ raise NotImplementedError("sparse_dense_padded is only implemented for
cuda")
+
+
# sparse_transpose
@generic_func
def schedule_sparse_transpose(attrs, outs, target):
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 8b999bf..77fe79b 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -42,6 +42,9 @@ class BufferVar(ObjectGeneric):
Do not create it directly; create it using IRBuilder.
+ BufferVars support array access either via a linear index, or, if given a shape, via a multidimensional index.
+
Examples
--------
In the following example, x is a BufferVar.
@@ -55,6 +58,12 @@ class BufferVar(ObjectGeneric):
x = ib.pointer("float32")
x[0] = x[10] + 1
+ y = ib.allocate("float32", (32, 32))
+ # Array access using a linear index
+ y[(2*32) + 31] = 0.
+ # The same array access using a multidimensional index
+ y[2, 31] = 0.
+
See Also
--------
IRBuilder.pointer
@@ -62,9 +71,10 @@ class BufferVar(ObjectGeneric):
IRBuilder.allocate
"""
- def __init__(self, builder, buffer_var, content_type):
+ def __init__(self, builder, buffer_var, shape, content_type):
self._builder = builder
self._buffer_var = buffer_var
+ self._shape = shape
self._content_type = content_type
def asobject(self):
@@ -74,8 +84,23 @@ class BufferVar(ObjectGeneric):
def dtype(self):
return self._content_type
+ def _linear_index(self, index):
+ if not isinstance(index, tuple) or self._shape is None:
+ return index
+ assert len(index) == len(self._shape), "Index size (%s) does not match shape size (%s)" % (
+ len(index),
+ len(self._shape),
+ )
+ dim_size = 1
+ lidx = 0
+ for dim, idx in zip(reversed(self._shape), reversed(index)):
+ lidx += idx * dim_size
+ dim_size *= dim
+ return lidx
+
def __getitem__(self, index):
t = DataType(self._content_type)
+ index = self._linear_index(index)
if t.lanes > 1:
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
@@ -87,6 +112,7 @@ class BufferVar(ObjectGeneric):
raise ValueError(
"data type does not match content type %s vs %s" %
(value.dtype, self._content_type)
)
+ index = self._linear_index(index)
t = DataType(self._content_type)
if t.lanes > 1:
base = index * t.lanes
@@ -341,7 +367,7 @@ class IRBuilder(object):
if scope:
self.scope_attr(buffer_var, "storage_scope", scope)
self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
- return BufferVar(self, buffer_var, dtype)
+ return BufferVar(self, buffer_var, shape, dtype)
def pointer(self, content_type, name="ptr"):
"""Create pointer variable with content type.
@@ -360,9 +386,9 @@ class IRBuilder(object):
The buffer var representing the buffer.
"""
buffer_var = _expr.Var(name, dtype="handle")
- return BufferVar(self, buffer_var, content_type)
+ return BufferVar(self, buffer_var, None, content_type)
- def buffer_ptr(self, buf):
+ def buffer_ptr(self, buf, shape=None):
"""Create pointer variable corresponds to buffer ptr.
Parameters
@@ -370,12 +396,15 @@ class IRBuilder(object):
buf : Buffer
The buffer to be extracted.
+ shape : Tuple
+ Optional shape of the buffer. Overrides existing buffer shape.
+
Returns
-------
ptr : BufferVar
The buffer var representing the buffer.
"""
- return BufferVar(self, buf.data, buf.dtype)
+ return BufferVar(self, buf.data, buf.shape if shape is None else shape, buf.dtype)
def likely(self, expr):
"""Add likely tag for expression.
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index d1d31a6..3fd6fbe 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -16,15 +16,17 @@
# under the License.
"""Sparse operators"""
-from tvm import te
-from tvm import autotvm
-from tvm.autotvm.task.space import SplitEntity
-from ..util import traverse_inline
+import numpy as np
+import scipy.sparse as sp
+
+import tvm
+from tvm import relay, te
+
from .. import nn
+from ..util import traverse_inline
-@autotvm.register_topi_compute("sparse_dense.cuda")
-def sparse_dense(cfg, data, weight_data, weight_indices, weight_indptr):
+def sparse_dense(data, weight_data, weight_indices, weight_indptr):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
@@ -58,8 +60,7 @@ def sparse_dense(cfg, data, weight_data, weight_indices, weight_indptr):
return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr)
-@autotvm.register_topi_schedule("sparse_dense.cuda")
-def schedule_sparse_dense(cfg, outs):
+def schedule_sparse_dense(outs):
"""Create schedule for sparse dense"""
# pylint:disable=invalid-name
s = te.create_schedule([x.op for x in outs])
@@ -83,12 +84,7 @@ def schedule_sparse_dense(cfg, outs):
thread_x = te.thread_axis("threadIdx.x")
- cfg.define_split("tile_c", c, num_outputs=2)
- if cfg.is_fallback:
- cfg["tile_c"] = SplitEntity([-1, 8])
- _, ci = cfg["tile_c"].apply(s, y_bsrmm, c)
-
- y_bsrmm_factored = s.rfactor(y_bsrmm, ci)
+ y_bsrmm_factored = s.rfactor(y_bsrmm, c)
tx = s[y_bsrmm].op.reduce_axis[0]
s[y_bsrmm].bind(tx, thread_x)
s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
@@ -97,3 +93,289 @@ def schedule_sparse_dense(cfg, outs):
traverse_inline(s, outs[0].op, _callback)
return s
+
+
+def schedule_cuda_transpose(s, out):
+ """Schedule for transpose on the gpu.
+
+ Roughly follows this:
+ https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/, but
+ without the padding for shared memory. For better performance, we could
+ rewrite it in tir to add the padding.
+ """
+
+ def _callback(op):
+ # pylint: disable=invalid-name
+ m, n = s[op].op.axis
+ warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
+ no, ni = s[op].split(n, factor=warp_size)
+ mo, mi = s[op].split(m, factor=warp_size)
+ s[op].reorder(mo, no, mi, ni)
+ s[op].bind(mo, te.thread_axis("blockIdx.x"))
+ s[op].bind(no, te.thread_axis("blockIdx.y"))
+ c = s.cache_read(op.input_tensors[0], "shared", op)
+ s[c].compute_at(s[op], no)
+ thread_x = te.thread_axis("threadIdx.x")
+ thread_y = te.thread_axis("threadIdx.y")
+ s[op].bind(ni, thread_x)
+ # This is a hack to make the scheduling language realize that this axis
+ # can be scheduled.
+ a, _ = s[c].split(s[c].op.axis[1], factor=1)
+ s[c].bind(a, thread_x)
+ # Use 4 warps per block. Slightly faster than 1 warp per block
+ ao, _ = s[op].split(mi, nparts=4)
+ s[op].bind(ao, thread_y)
+ ao, _ = s[c].split(s[c].op.axis[0], nparts=4)
+ s[c].bind(ao, thread_y)
+
+ traverse_inline(s, out.op, _callback)
+
+
+def sparse_dense_tir(data, w_data, w_indices, w_indptr):
+ """Compute data * w^T.
+
+ Actually computes (w * data^T) ^ T as data needs to be in column-major
+ format for performance reasons.
+
+ Good resources:
+ Yang, Carl, Aydın Buluç, and John D. Owens. "Design principles for sparse
+ matrix multiplication on the GPU." European Conference on Parallel
+ Processing. Springer, Cham, 2018. <- This code is basically row-split from here.
+ Gale, Trevor, et al. "Sparse GPU Kernels for Deep Learning." arXiv preprint
+ arXiv:2006.10901 (2020).
+
+
+ Profile with
+ `/opt/nvidia/nsight-compute/2020.1.2/ncu -k default_function_kernel1
+ --section '.*' -s 1 -c 1 venv/bin/python3 test_topi_sparse.py manual`
+ with either default_function_kernel0 for the transpose or
+ default_function_kernel1 for the multiply.
+ """
+
+ def ceil_div(a, b):
+ return (a + (b - 1)) // b
+
+ def gen_ir(data, w_data, w_indices, w_indptr, out):
+ # pylint: disable=invalid-name
+ # TODO(tkonolige): use tensorcores for block multiply
+ # TODO(tkonolige): use vectorize on loads
+ # TODO(tkonolige): separate implementation if M is small
+ # TODO(tkonolige): separate implementation for large block sizes
+ ib = tvm.tir.ir_builder.create()
+
+ warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
+ m = data.shape[1]
+ nb = w_indptr.shape[0] - 1
+ nnzb = w_data.shape[0]
+ # treat csr like block size 1 bsr
+ if len(w_data.shape) == 1:
+ bs_n = 1
+ bs_k = 1
+ else:
+ bs_n = w_data.shape[1]
+ bs_k = w_data.shape[2]
+ bs_m = bs_n
+ mb = m // bs_m
+ mi = warp_size
+ assert (
+ mb >= mi
+ ), "Number of block rows in dense matrix must be larger than warp
size: {} vs {}.".format(
+ warp_size, m
+ )
+ mo = ceil_div(mb, mi)
+ ni = 1  # TODO(tkonolige): how do I compute the number of warps per block?
+ no = ceil_div(nb, ni)
+ rowlength_bi = warp_size
+
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(bx, "thread_extent", mo)
+ by = te.thread_axis("blockIdx.y")
+ ib.scope_attr(by, "thread_extent", no)
+ tx = te.thread_axis("threadIdx.x")
+ ib.scope_attr(tx, "thread_extent", warp_size)
+ warp = te.thread_axis("threadIdx.y")
+ ib.scope_attr(warp, "thread_extent", ni)
+
+ out_ptr = ib.buffer_ptr(out)
+ data_ptr = ib.buffer_ptr(data)
+ w_data_ptr = ib.buffer_ptr(w_data, shape=(nnzb, bs_n, bs_k))
+ w_indices_ptr = ib.buffer_ptr(w_indices)
+ w_indptr_ptr = ib.buffer_ptr(w_indptr)
+
+ n_index = by * ni + warp
+ m_index = bx * mi + tx
+ row_start = w_indptr_ptr[n_index]
+
+ # Guaranteed to be evenly divisible
+ rowlength_bo = ceil_div(w_indptr_ptr[n_index + 1] - row_start, rowlength_bi)
+
+ # thread local storage for bs_m x bs_n block
+ block = ib.allocate(data.dtype, (bs_m, bs_n), name="block", scope="local")
+ indices = ib.allocate(w_indices.dtype, (rowlength_bi,), name="indices", scope="warp")
+ data_cache = ib.allocate(data.dtype, (mi, bs_m, bs_k), name="data_cache", scope="local")
+ w_data_cache = ib.allocate(
+ w_data.dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp"
+ )
+
+ # zero block
+ with ib.for_range(0, bs_m, name="x", for_type="unroll") as x:
+ with ib.for_range(0, bs_n, name="y", for_type="unroll") as y:
+ block[x, y] = 0.0
+ # compute into thread local storage using warp_size chunks
+ with ib.for_range(0, rowlength_bo, name="bb") as bb:
+ elem_idx = bb * rowlength_bi + tx
+ # Cache indices. Guaranteed to be multiple of warp_size.
+ indices[elem_idx] = w_indices_ptr[row_start + elem_idx]
+ # cache dense matrix
+ # each thread has a row
+ # TODO: ideally we could vectorize this
+ with ib.for_range(0, rowlength_bi, name="bi") as bi:
+ with ib.for_range(0, bs_m, name="x", for_type="unroll") as x:
+ with ib.for_range(0, bs_k, name="z", for_type="unroll") as z:
+ # This memory access should be out of bounds when
+ # m_index >= mb (which occurs when the dense matrix
+ # rows % 32 != 0), but it seems to work just fine...
+ data_cache[bi, x, z] = data_ptr[indices[bi] * bs_k + z, m_index * bs_m + x]
+ # cache w_data
+ elem_idx = bb * rowlength_bi + tx
+ with ib.for_range(0, bs_n, name="y", for_type="unroll") as y:
+ with ib.for_range(0, bs_k, name="z", for_type="unroll") as z:
+ w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
+ with ib.for_range(0, mi, name="i") as i:
+ # thread local block matmul
+ with ib.for_range(0, bs_m, name="x", for_type="unroll") as x:
+ with ib.for_range(0, bs_n, name="y", for_type="unroll") as y:
+ with ib.for_range(0, bs_k, name="z", for_type="unroll") as z:
+ block[x, y] += data_cache[i, x, z] * w_data_cache[i, y, z]
+ # store results
+ with ib.for_range(0, bs_m, name="x", for_type="unroll") as x:
+ with ib.for_range(0, bs_n, name="y", for_type="unroll") as y:
+ with ib.if_scope(m_index < mb):
+ with ib.if_scope(n_index < nb):
+ # It doesn't seem like we would be getting coalesced
+ # writes here, but it doesn't seem to matter
+ out_ptr[m_index * bs_m + x, n_index * bs_n + y] = block[x, y]
+
+ return ib.get()
+
+ data_t = tvm.topi.transpose(data)
+ # handle csr
+ if len(w_data.shape) == 1:
+ blocksize = 1
+ else:
+ blocksize = w_data.shape[1]
+ out_shape = (data_t.shape[1], (w_indptr.shape[0] - 1) * blocksize)
+ out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
+ out = te.extern(
+ [out_shape],
+ [data_t, w_data, w_indices, w_indptr, data],
+ lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+ dtype=data.dtype,
+ out_buffers=[out_buf],
+ name="sparse_dense_gpu",
+ tag="sparse_dense_gpu",
+ )
+ return out
+
+
+def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr):
+ """
+ Computes sparse-dense matrix multiplication of `data` and
+ `(weight_data, weight_indices, weight_indptr).T`
+
+ This variation uses a padded matrix where all row lengths are a multiple of the warp size.
+
+ Parameters
+ ----------
+
+ data : tvm.te.Tensor
+ 2-D with shape [M, K], float32
+
+ weight_data : tvm.te.Tensor
+ 1-D with shape [nnz] (CSR) or
+ 3-D with shape [num_blocks, bs_r, bs_c] (BSR)
+
+ weight_indices : tvm.te.Tensor
+ 1-D with shape [nnz] (CSR) or
+ 1-D with shape [num_blocks] (BSR)
+
+ weight_indptr : tvm.te.Tensor
+ 1-D with shape [N + 1] (CSR) or
+ 1-D with shape [(N + 1) // bs_r] (BSR)
+
+ Returns
+ -------
+ output : tvm.te.Tensor
+ 2-D with shape [M, N]
+ """
+ return sparse_dense_tir(data, weight_data, weight_indices, weight_indptr)
+
+
+def schedule_sparse_dense_padded(outs):
+ """Create schedule for sparse dense"""
+ # XXX: this will fail if we don't include the data_t Tensor in the schedule
+ # ops. Maybe create_schedule should do some analysis so this isn't
+ # necessary
+ data_t = outs[0].op.input_tensors[0]
+ s = te.create_schedule([outs[0].op, data_t.op])
+ schedule_cuda_transpose(s, outs[0].op.input_tensors[0])
+ return s
+
+
+def pad_sparse_matrix(matrix, blocksize):
+ """Pad rows of sparse matrix matrix so that they are a multiple of
blocksize."""
+ assert isinstance(matrix, sp.bsr_matrix)
+ new_entries = np.zeros(matrix.shape[0], dtype=matrix.indptr.dtype)
+ bsr = matrix.blocksize[0]
+ for i in range(matrix.shape[0] // bsr):
+ row_length = matrix.indptr[i + 1] - matrix.indptr[i]
+ if row_length % blocksize != 0:
+ new_entries[i] = blocksize - (row_length % blocksize)
+ additional = np.sum(new_entries)
+ indices = np.zeros(matrix.indices.shape[0] + additional, dtype=matrix.indices.dtype)
+ data = np.zeros(
+ (matrix.data.shape[0] + additional, matrix.data.shape[1], matrix.data.shape[2]),
+ dtype=matrix.data.dtype,
+ )
+
+ n = matrix.shape[0] // bsr
+ indptr = np.zeros(n + 1, dtype=matrix.indptr.dtype)
+ indptr[: matrix.indptr.shape[0]] = matrix.indptr
+
+ for i in range(matrix.shape[0] // bsr):
+ indptr[i + 1] = indptr[i] + new_entries[i] + (matrix.indptr[i + 1] - matrix.indptr[i])
+ indices[indptr[i] : indptr[i + 1] - new_entries[i]] = matrix.indices[
+ matrix.indptr[i] : matrix.indptr[i + 1]
+ ]
+ data[indptr[i] : indptr[i + 1] - new_entries[i], :, :] = matrix.data[
+ matrix.indptr[i] : matrix.indptr[i + 1], :, :
+ ]
+
+ return sp.bsr_matrix((data, indices, indptr), matrix.shape)
+
+
+@nn.sparse_dense_alter_layout.register(["cuda", "gpu"])
+def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
+ """With cuda, we modify use alter_op_layout to swap the default
+ sparse_dense implementation for one that operates on a padded matrix. We
+ also padd the matrix.
+ """
+ if (
+ isinstance(inputs[1], relay.Constant)
+ and isinstance(inputs[2], relay.Constant)
+ and isinstance(inputs[3], relay.Constant)
+ ):
+ sparse_matrix = sp.bsr_matrix(
+ (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy())
+ )
+ warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
+ sparse_matrix = pad_sparse_matrix(sparse_matrix, warp_size)
+ return relay.nn._make.sparse_dense_padded(
+ inputs[0],
+ relay.Constant(tvm.nd.array(sparse_matrix.data)),
+ relay.Constant(tvm.nd.array(sparse_matrix.indices)),
+ relay.Constant(tvm.nd.array(sparse_matrix.indptr)),
+ )
+ return None
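
A quick usage sketch for pad_sparse_matrix (exercised as
tvm.topi.cuda.pad_sparse_matrix in the tests below; the 64x64 shape and 10%
density here are illustrative):

    import numpy as np
    import scipy.sparse as sp
    import tvm.topi.cuda

    # Random matrix stored as 1x1-block BSR, roughly 10% of entries nonzero.
    dense = np.random.rand(64, 64) * (np.random.rand(64, 64) < 0.1)
    m = sp.bsr_matrix(dense.astype("float32"), blocksize=(1, 1))
    padded = tvm.topi.cuda.pad_sparse_matrix(m, 32)

    # Every row now holds a multiple of 32 blocks; the padding blocks are
    # explicit zeros, so the represented matrix is unchanged.
    assert np.all(np.diff(padded.indptr) % 32 == 0)
    np.testing.assert_allclose(padded.todense(), m.todense())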
diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py
index e3c144a..74a9ad5 100644
--- a/python/tvm/topi/nn/sparse.py
+++ b/python/tvm/topi/nn/sparse.py
@@ -207,3 +207,28 @@ def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr):
last[0] = temp2[0]
return irb.get()
+
+
+@tvm.target.generic_func
+def sparse_dense_alter_layout(_attrs, _inputs, _tinfos, _out_type):
+ """Change Sparse Dense layout.
+
+ This is used for modifying the input weights so they are more amenable to
+ the target.
+
+ Parameters
+ ----------
+ attrs : tvm.ir.Attrs
+ Attributes of the current operator
+ inputs : tvm.relay.Expr
+ Grouped input symbols
+ tinfos : list
+ Input shape and dtype
+ out_type: type
+ The output type
+
+ Note
+ ----
+ Unlike other TOPI functions, this function operates on both graph level and operator level.
+ """
+ return None
diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc
index 0aca00c..f12afe2 100644
--- a/src/relay/op/nn/sparse.cc
+++ b/src/relay/op/nn/sparse.cc
@@ -76,7 +76,41 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense")
});
RELAY_REGISTER_OP("nn.sparse_dense")
- .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T`
with X sparse.
+ .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T`
with W sparse.
+
+- **data**: `(x1, x2, ..., xn, input_dim)`
+- **weight**: `(units, input_dim)`
+- **out**: `(x1, x2, ..., xn, units)`.
+
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<SparseDenseAttrs>()
+ .set_num_inputs(4)
+ .add_argument("data", "nD Tensor", "Input data.")
+ .add_argument("weight_data", "1D Tensor", "Weight data matrix.")
+ .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
+ .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
+ .set_support_level(1)
+ .add_type_rel("SparseDense", SparseDenseRel);
+
+Expr MakeSparseDensePadded(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) {
+ auto attrs = make_object<SparseDenseAttrs>();
+ static const Op& op = Op::Get("nn.internal.sparse_dense_padded");
+ return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense_padded")
+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ runtime::detail::unpack_call<Expr, 4>(MakeSparseDensePadded, args, rv);
+ });
+
+RELAY_REGISTER_OP("nn.internal.sparse_dense_padded")
+ .describe(
+ R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with W
+sparse. This variation uses a matrix with row lengths padded to a
+multiple of 32 for better GPU performance.
+
+This op should not be directly used by a user. Instead, use `sparse_dense`
+which will be converted to this op when running on the GPU.
- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h
index bf9bcb9..61a7440 100644
--- a/src/relay/transforms/transform_layout.h
+++ b/src/relay/transforms/transform_layout.h
@@ -267,6 +267,18 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
}
}
+ // If there is no FInferCorrectLayout for the op, then we just assume the layout is correct.
+ static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");
+ if (Op::HasAttrMap("FTVMAlterOpLayout")) {
+ static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout");
+ if (ref_call->op.as<OpNode>()) {
+ Op op = Downcast<Op>(ref_call->op);
+ if (falter_layout.count(op) && !finfer_layout.count(op)) {
+ return memorizer.CallWithNewLayouts(ref_call, normal_new_args);
+ }
+ }
+ }
+
// old_in, new_in = state[inputs]
Array<Layout> old_in, old_out, new_in, new_out, new_in2;
for (auto inp : inputs) {
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 7dc63d4..d57efa0 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -394,7 +394,8 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
}
void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
- CHECK_NE(scope, "global");
+ CHECK_NE(scope, "global") << "Cannot allocate global memory when targeting
CUDA. You must pass "
+ "all global arrays as input instead";
if (scope == "shared") {
os << "__shared__ ";
}
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index c3b2a0b..527b251 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -46,7 +46,9 @@ using namespace tir;
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ComputeOpNode*>(node.get());
- p->stream << "compute(" << op->name << ", " << op << ")";
+ p->stream << "compute(" << op->name << ", body=" << op->body << ",
axis=" << op->axis
+ << ", reduce_axis=" << op->reduce_axis << ", tag=" << op->tag
+ << ", attrs=" << op->attrs << ")";
});
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc
index db649e5..2abf68a7 100644
--- a/src/te/operation/op_util.cc
+++ b/src/te/operation/op_util.cc
@@ -150,7 +150,8 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
value_map[iv] = dom->min;
} else {
// Always restrict threaded IterVar to starts from 0.
- CHECK(is_zero(dom->min));
+ CHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero, but it starts at "
+ << dom->min;
// annotate the extent of the IterVar
nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op));
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index d6327ff..a8257c0 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -761,7 +761,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StageNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StageNode*>(node.get());
if (op->op.defined()) {
- p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
+ p->stream << "stage(" << op->origin_op->name << ", " << op->op << ")";
} else {
p->stream << "group-stage(" << op << ")";
}
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 8892c32..cb6c609 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -129,7 +129,11 @@ class WarpStoreCoeffFinder : private StmtVisitor {
void UpdatePattern(const PrimExpr& index) {
Array<PrimExpr> m = arith::DetectLinearEquation(index, {warp_index_});
- CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index;
+ CHECK_EQ(m.size(), 2U)
+ << "LowerWarpMemory failed. Could not simplify the store index `" << index
+ << "` into the form ax + by + cz + ... Warp memory is approximated by storing values in "
+ "thread local registers and shuffling values between these registers. Currently only "
+ "linear equation indices are supported.";
PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);
const auto* mcoeff_as_int = mcoeff.as<IntImmNode>();
CHECK(mcoeff_as_int && mcoeff_as_int->value > 0)
diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc
index 1914609..f9adfb8 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -37,7 +37,7 @@ void StorageAccessVisitor::VisitExpr_(const LoadNode* op) {
const VarNode* buf = op->buffer_var.as<VarNode>();
StorageScope scope = GetScope(buf);
if (Enabled(buf, scope)) {
- CHECK(allow_append_);
+ CHECK(allow_append_) << op << " " << scope.to_string();
AccessEntry e;
e.threads = env_threads();
e.buffer = op->buffer_var;
diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py
index b50110a..07af478 100644
--- a/tests/python/topi/python/test_topi_sparse.py
+++ b/tests/python/topi/python/test_topi_sparse.py
@@ -19,6 +19,7 @@ import numpy as np
import tvm
from tvm import te
from tvm import topi
+from tvm import relay
import tvm.topi.testing
from tvm.topi.util import get_const_tuple
import tvm.contrib.sparse as tvmsp
@@ -329,11 +330,11 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
return s
-def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu):
+def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu, ctx, target):
X_np = np.random.randn(M, K).astype("float32")
W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
W_np = W_sp_np.todense()
- Y_np = X_np.dot(W_np.T)
+ Y_np = X_np @ W_np.T
if use_relu:
Y_np = np.maximum(Y_np, 0.0)
@@ -342,38 +343,29 @@ def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu):
W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
- def check_device(device):
- ctx = tvm.context(device, 0)
- if not tvm.testing.device_enabled(device):
- print("Skip because %s is not enabled" % device)
- return
- print("Running on target: %s" % device)
- fcompute, fschedule = tvm.topi.testing.dispatch(device, _sparse_dense_implement)
- with tvm.target.Target(device):
- Y = fcompute(X, W_data, W_indices, W_indptr)
- if use_relu:
- Y = topi.nn.relu(Y)
- s = fschedule([Y])
- func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
- Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
- func(
- tvm.nd.array(X_np, ctx=ctx),
- tvm.nd.array(W_sp_np.data, ctx=ctx),
- tvm.nd.array(W_sp_np.indices, ctx=ctx),
- tvm.nd.array(W_sp_np.indptr, ctx=ctx),
- Y_tvm,
- )
- tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
-
- for device in ["llvm", "cuda"]:
- check_device(device)
+ fcompute, fschedule = tvm.topi.testing.dispatch(target, _sparse_dense_implement)
+ with tvm.target.Target(target):
+ Y = fcompute(X, W_data, W_indices, W_indptr)
+ if use_relu:
+ Y = topi.nn.relu(Y)
+ s = fschedule([Y])
+ func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
+ Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
+ func(
+ tvm.nd.array(X_np, ctx=ctx),
+ tvm.nd.array(W_sp_np.data, ctx=ctx),
+ tvm.nd.array(W_sp_np.indices, ctx=ctx),
+ tvm.nd.array(W_sp_np.indptr, ctx=ctx),
+ Y_tvm,
+ )
+ tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
-@tvm.testing.uses_gpu
-def test_sparse_dense_bsr():
+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_sparse_dense_bsr_relu(ctx, target):
M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9
- verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=True)
- verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=False)
+ verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, True, ctx, target)
+ verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, False, ctx, target)
@tvm.testing.uses_gpu
@@ -421,11 +413,69 @@ def test_sparse_dense_bsr_randomized():
check_device(device)
+@tvm.testing.requires_cuda
+def test_sparse_dense_padded_cuda():
+ M = 128
+ N = 1280
+ K = 128
+ X_np = np.random.randn(M, K).astype("float32")
+ W_sp_np = random_bsr_matrix(N, K, 1, 1, density=0.01, dtype="float32")
+ W_sp_np_padded = tvm.topi.cuda.pad_sparse_matrix(W_sp_np, 32)
+
+ W_np = W_sp_np.todense()
+ Y_np = X_np @ W_sp_np.T
+
+ W_data = te.placeholder(shape=W_sp_np_padded.data.shape, dtype=str(W_sp_np_padded.data.dtype))
+ W_indices = te.placeholder(
+ shape=W_sp_np_padded.indices.shape, dtype=str(W_sp_np_padded.indices.dtype)
+ )
+ W_indptr = te.placeholder(
+ shape=W_sp_np_padded.indptr.shape, dtype=str(W_sp_np_padded.indptr.dtype)
+ )
+ X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
+ with tvm.target.Target("cuda"):
+ ctx = tvm.context("gpu")
+ Y = topi.cuda.sparse_dense_padded(X, W_data, W_indices, W_indptr)
+ s = topi.cuda.schedule_sparse_dense_padded([Y])
+ func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
+ Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
+ func(
+ tvm.nd.array(X_np, ctx=ctx),
+ tvm.nd.array(W_sp_np_padded.data, ctx=ctx),
+ tvm.nd.array(W_sp_np_padded.indices, ctx=ctx),
+ tvm.nd.array(W_sp_np_padded.indptr, ctx=ctx),
+ Y_tvm,
+ )
+ tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)
+
+
+@tvm.testing.requires_cuda
+def test_sparse_dense_padded_alter_op():
+ with tvm.target.Target("cuda"):
+ M = 128
+ N = 16
+ K = 128
+ X_np = np.random.randn(M, K).astype("float32")
+ W_sp_np = random_bsr_matrix(N, K, 2, 2, density=0.01, dtype="float32")
+ mult = relay.op.nn.sparse_dense(
+ relay.Constant(tvm.nd.array(X_np)),
+ (
+ relay.Constant(tvm.nd.array(W_sp_np.data)),
+ relay.Constant(tvm.nd.array(W_sp_np.indices)),
+ relay.Constant(tvm.nd.array(W_sp_np.indptr)),
+ ),
+ )
+ f = relay.Function([], mult)
+ f_ = relay.transform.AlterOpLayout()(tvm.IRModule.from_expr(f))
+ assert f_["main"].body.op.name == "nn.internal.sparse_dense_padded"
+
+
if __name__ == "__main__":
test_csrmv()
test_csrmm()
test_dense()
test_sparse_dense_csr()
- test_sparse_dense_bsr()
test_sparse_dense_bsr_randomized()
test_sparse_transpose_csr()
+ test_sparse_dense_padded_cuda()
+ test_sparse_dense_padded_alter_op()
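
For completeness, a minimal sketch of the intended end-user flow (mirroring
test_sparse_dense_padded_alter_op above; shapes and density are
illustrative): call the public nn.sparse_dense and let AlterOpLayout swap in
the internal padded op under a cuda target.

    import numpy as np
    import scipy.sparse as sp
    import tvm
    from tvm import relay

    M, N, K = 64, 64, 64
    X_np = np.random.randn(M, K).astype("float32")
    W_sp = sp.random(N, K, density=0.1, format="csr", dtype="float32").tobsr(blocksize=(2, 2))

    mult = relay.op.nn.sparse_dense(
        relay.Constant(tvm.nd.array(X_np)),
        (
            relay.Constant(tvm.nd.array(W_sp.data)),
            relay.Constant(tvm.nd.array(W_sp.indices)),
            relay.Constant(tvm.nd.array(W_sp.indptr)),
        ),
    )
    mod = tvm.IRModule.from_expr(relay.Function([], mult))
    with tvm.target.Target("cuda"):
        mod = relay.transform.AlterOpLayout()(mod)
    assert mod["main"].body.op.name == "nn.internal.sparse_dense_padded"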