tkonolige commented on a change in pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#discussion_r537807196



##########
File path: python/tvm/topi/sparse/csrmm.py
##########
@@ -121,3 +121,46 @@ def csrmm(a, b, c=None):
         2-D with shape [m, n]
     """
     return csrmm_default(a.data, a.indices, a.indptr, b, c)
+
+
+def batch_csrmm(data, indices, indptr, dense, oshape):
+    # pylint: disable=invalid-name
+    assert len(data.shape) == 1 and len(indices.shape) == 1 and 
len(indptr.shape) == 1 \
+        and len(dense.shape) == 3, "only support 2-dim csrmm"
+    assert indptr.dtype == 'int32', f"CSR indptr must be integers, but is 
{indptr.dtype}"
+    assert indices.dtype == 'int32', f"CSR indices must be integers, but is 
{indices.dtype}"
+
+    assert isinstance(dense, te.tensor.Tensor), \
+        "dense matrix is assumed to be tvm.te.Tensor, but dense is `%s`" % 
(type(dense))
+
+    M = simplify(indptr.shape[0]-1)
+    batches, _, N = dense.shape
+    def csrmm_default_ir(data, indices, indptr, dense, out):
+        """define ir for csrmm"""
+        irb = tvm.tir.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        indices_ptr = irb.buffer_ptr(indices)
+        indptr_ptr = irb.buffer_ptr(indptr)
+        dense_ptr = irb.buffer_ptr(dense)
+        out_ptr = irb.buffer_ptr(out)
+        M = simplify(indptr.shape[0]-1)
+        batches, _, N = dense.shape
+        with irb.for_range(0, batches, name='batch') as batch:
+            with irb.for_range(0, N, for_type="vectorize", name='n') as n:
+                with irb.for_range(0, M, for_type="parallel", name='row') as 
row:
+                    dot = irb.allocate('float32', (1,), name='dot', 
scope='local')
+                    out_ptr[(batch*N*M) + (row*N+n)] = 0.

Review comment:
       `ir_builder` supports multidimensional indexing (e.g. `out_ptr[batch, row, n]`), which would make this code cleaner than computing flat offsets like `(batch*N*M) + (row*N+n)` by hand.

##########
File path: python/tvm/topi/sparse/csrmm.py
##########
@@ -121,3 +121,46 @@ def csrmm(a, b, c=None):
         2-D with shape [m, n]
     """
     return csrmm_default(a.data, a.indices, a.indptr, b, c)
+
+
+def batch_csrmm(data, indices, indptr, dense, oshape):
+    # pylint: disable=invalid-name
+    assert len(data.shape) == 1 and len(indices.shape) == 1 and 
len(indptr.shape) == 1 \
+        and len(dense.shape) == 3, "only support 2-dim csrmm"
+    assert indptr.dtype == 'int32', f"CSR indptr must be integers, but is 
{indptr.dtype}"
+    assert indices.dtype == 'int32', f"CSR indices must be integers, but is 
{indices.dtype}"
+
+    assert isinstance(dense, te.tensor.Tensor), \
+        "dense matrix is assumed to be tvm.te.Tensor, but dense is `%s`" % 
(type(dense))
+
+    M = simplify(indptr.shape[0]-1)
+    batches, _, N = dense.shape
+    def csrmm_default_ir(data, indices, indptr, dense, out):
+        """define ir for csrmm"""
+        irb = tvm.tir.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        indices_ptr = irb.buffer_ptr(indices)
+        indptr_ptr = irb.buffer_ptr(indptr)
+        dense_ptr = irb.buffer_ptr(dense)
+        out_ptr = irb.buffer_ptr(out)
+        M = simplify(indptr.shape[0]-1)
+        batches, _, N = dense.shape
+        with irb.for_range(0, batches, name='batch') as batch:
+            with irb.for_range(0, N, for_type="vectorize", name='n') as n:
+                with irb.for_range(0, M, for_type="parallel", name='row') as 
row:
+                    dot = irb.allocate('float32', (1,), name='dot', 
scope='local')
+                    out_ptr[(batch*N*M) + (row*N+n)] = 0.
+                    dot[0] = 0.
+                    row_start = indptr_ptr[row]
+                    row_end = indptr_ptr[row+1]
+                    row_elems = row_end-row_start
+                    with irb.for_range(0, row_elems, name='idx') as idx:
+                        elem = row_start+idx
+                        dot[0] += data_ptr[elem] * 
dense_ptr[indices_ptr[elem]*N+n]
+                    out_ptr[(batch*N*M) + row*N+n] += dot[0]
+        return irb.get()
+    matmul = te.extern(oshape, [data, indices, indptr, dense],
+                       lambda ins, outs: csrmm_default_ir(ins[0], ins[1], 
ins[2], ins[3], outs[0]),
+                       tag="csrmm", dtype='float32', name='out')

Review comment:
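       I think we should support dtypes other than float32. Rather than hard-coding `'float32'` (here in `te.extern` and in the `dot` allocation), consider deriving the dtype from the inputs, e.g. `data.dtype`.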




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to