Wheest commented on a change in pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#discussion_r538449021



##########
File path: python/tvm/topi/sparse/csrmm.py
##########
@@ -121,3 +121,46 @@ def csrmm(a, b, c=None):
         2-D with shape [m, n]
     """
     return csrmm_default(a.data, a.indices, a.indptr, b, c)
+
+
+def batch_csrmm(data, indices, indptr, dense, oshape):
+    # pylint: disable=invalid-name
+    assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
+        and len(dense.shape) == 3, "only support 2-dim csrmm"
+    assert indptr.dtype == 'int32', f"CSR indptr must be integers, but is {indptr.dtype}"
+    assert indices.dtype == 'int32', f"CSR indices must be integers, but is {indices.dtype}"
+
+    assert isinstance(dense, te.tensor.Tensor), \
+        "dense matrix is assumed to be tvm.te.Tensor, but dense is `%s`" % (type(dense))
+
+    M = simplify(indptr.shape[0]-1)
+    batches, _, N = dense.shape
+    def csrmm_default_ir(data, indices, indptr, dense, out):
+        """define ir for csrmm"""
+        irb = tvm.tir.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        indices_ptr = irb.buffer_ptr(indices)
+        indptr_ptr = irb.buffer_ptr(indptr)
+        dense_ptr = irb.buffer_ptr(dense)
+        out_ptr = irb.buffer_ptr(out)
+        M = simplify(indptr.shape[0]-1)
+        batches, _, N = dense.shape
+        with irb.for_range(0, batches, name='batch') as batch:
+            with irb.for_range(0, N, for_type="vectorize", name='n') as n:
+                with irb.for_range(0, M, for_type="parallel", name='row') as row:
+                    dot = irb.allocate('float32', (1,), name='dot', scope='local')
+                    out_ptr[(batch*N*M) + (row*N+n)] = 0.

Review comment:
       +1 on multidimensional access, will update my code when I go back to it.
   
   
   
   >     1. You probably want to implement block sparsity as CSR is just a special case of BSR.
   
   I'll need to read more about block sparsity, as it's not a format I understand yet. I'll take a look; it would be interesting if we get two formats for the price of one.
   
   >     2. There is already existing code to convert a dense matrix to a sparse matrix. Do we need another version for conv2d?
   
   The existing `csrmm` function does not support batches afaik (it computes `(NxK) x (KxM)`, rather than the `(NxK) x (BxKxM)` we need, where `B` is the number of batches). Ideally a single sparse matmul would be good; in theory we could make `B` calls to the standard `csrmm` function. Practically I'd need to think about what that would look like from an implementation perspective, since the size data is stored in the TVM tensors and is used in `csrmm` to generate the function.
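   The "B calls to the standard `csrmm`" idea can be sketched outside TVM with SciPy (shapes and names here are illustrative, not the PR's API): a batched CSR x dense product is semantically just the same 2-D sparse matmul applied to each batch slice.

   ```python
   import numpy as np
   from scipy import sparse

   # Illustrative sizes: B batches, sparse A is (M x K), dense input is (B x K x N)
   B_, M, K, N = 2, 5, 4, 3
   A = sparse.random(M, K, density=0.3, format="csr", random_state=0)
   rng = np.random.default_rng(0)
   dense = rng.standard_normal((B_, K, N))

   # Batched csrmm expressed as B independent 2-D csrmm calls
   out = np.stack([A.dot(dense[b]) for b in range(B_)])
   ```

   In TVM the difficulty is that the loop over `b` has to live inside the generated IR (as in the `batch_csrmm` draft above), since the shapes are baked into the tensors at build time rather than looped over at the Python level.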
   
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]

