tkonolige commented on a change in pull request #8479:
URL: https://github.com/apache/tvm/pull/8479#discussion_r670646959



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -787,42 +787,45 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         for i in data_ptr.shape:
             fused_shape *= i
 
-        # For now we avoid parallizing over dimensions indexed by `indices` as
-        # there may be repeated indices and hadling parallel accumulation can
-        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-        # work well when these dimensions are large enough to saturate memory
-        # bandwidth, but performance will be bad when these dimensions are
-        # small.

Review comment:
       Can you add a comment explaining how we parallelize here (we are 
thread-parallel over all of the update dimensions, and each block handles one 
set of indices)?

##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -787,42 +787,45 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         for i in data_ptr.shape:
             fused_shape *= i
 
-        # For now we avoid parallizing over dimensions indexed by `indices` as
-        # there may be repeated indices and hadling parallel accumulation can
-        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-        # work well when these dimensions are large enough to saturate memory
-        # bandwidth, but performance will be bad when these dimensions are
-        # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
-        max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        tdim = min(max_threads, fused_updates_dimension)
-        ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_updates_dimension, tdim)
-        ib.scope_attr(bx, "thread_extent", bdim)
-
-        # Copy data into the output. This loop writes to the same portions of
-        # memory as the following loop, so we do not need a memory sync.
-        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), 
name="i") as i:
-            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+        # Init output tensor.
+        with ib.new_scope():
+            bidx = te.thread_axis("blockIdx.x")
+            tidx = te.thread_axis("threadIdx.x")
+            gridDim = 1
+            for i in data_ptr.shape[:-1]:
+                gridDim *= i
+            blockDim = data_ptr.shape[-1]
+
+            ib.scope_attr(bidx, "thread_extent", gridDim)
+            ib.scope_attr(tidx, "thread_extent", blockDim)

Review comment:
       In some cases this dimension will be very small. Can you instead split 
the full (fused) shape by `max_num_threads`?

##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -787,42 +787,45 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         for i in data_ptr.shape:
             fused_shape *= i
 
-        # For now we avoid parallizing over dimensions indexed by `indices` as
-        # there may be repeated indices and hadling parallel accumulation can
-        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-        # work well when these dimensions are large enough to saturate memory
-        # bandwidth, but performance will be bad when these dimensions are
-        # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
-        max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        tdim = min(max_threads, fused_updates_dimension)
-        ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_updates_dimension, tdim)
-        ib.scope_attr(bx, "thread_extent", bdim)
-
-        # Copy data into the output. This loop writes to the same portions of
-        # memory as the following loop, so we do not need a memory sync.
-        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), 
name="i") as i:
-            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+        # Init output tensor.
+        with ib.new_scope():
+            bidx = te.thread_axis("blockIdx.x")
+            tidx = te.thread_axis("threadIdx.x")
+            gridDim = 1
+            for i in data_ptr.shape[:-1]:
+                gridDim *= i
+            blockDim = data_ptr.shape[-1]
+
+            ib.scope_attr(bidx, "thread_extent", gridDim)
+            ib.scope_attr(tidx, "thread_extent", blockDim)
+            index = bidx * blockDim + tidx
+            with ib.if_scope(index < fused_shape):
                 out[index] = data[index]
 
-        with ib.for_range(0, fused_indices_dimension) as i:
-            j = bx * tdim + tx
+        # Update output tensor by given values.
+        with ib.new_scope():
+            bidx = te.thread_axis("blockIdx.x")
+            tidx = te.thread_axis("threadIdx.x")
+            gridDim = fused_indices_dimension  # 32 * 600 = 19200
+            blockDim = fused_updates_dimension
+            ib.scope_attr(bidx, "thread_extent", gridDim)
+            ib.scope_attr(tidx, "thread_extent", blockDim)
+
+            j = tidx
             with ib.if_scope(j < fused_updates_dimension):
                 offset = fused_updates_dimension
-                index = j  # This is x_M, .. x_{N-1} part of the index into 
out.
-                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, 
y_0, .. y_{K-1}] part
-                # of the index into out.
-                for l in reversed(range(indices_ptr.shape[0].value)):
+                findex = j

Review comment:
       You've set `j = tidx` and then only use it in one spot. Why not just use 
`tidx` everywhere?

##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -787,42 +787,45 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         for i in data_ptr.shape:
             fused_shape *= i
 
-        # For now we avoid parallizing over dimensions indexed by `indices` as
-        # there may be repeated indices and hadling parallel accumulation can
-        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-        # work well when these dimensions are large enough to saturate memory
-        # bandwidth, but performance will be bad when these dimensions are
-        # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
-        max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        tdim = min(max_threads, fused_updates_dimension)
-        ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_updates_dimension, tdim)
-        ib.scope_attr(bx, "thread_extent", bdim)
-
-        # Copy data into the output. This loop writes to the same portions of
-        # memory as the following loop, so we do not need a memory sync.
-        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), 
name="i") as i:
-            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+        # Init output tensor.
+        with ib.new_scope():
+            bidx = te.thread_axis("blockIdx.x")
+            tidx = te.thread_axis("threadIdx.x")
+            gridDim = 1
+            for i in data_ptr.shape[:-1]:
+                gridDim *= i
+            blockDim = data_ptr.shape[-1]
+
+            ib.scope_attr(bidx, "thread_extent", gridDim)
+            ib.scope_attr(tidx, "thread_extent", blockDim)
+            index = bidx * blockDim + tidx
+            with ib.if_scope(index < fused_shape):
                 out[index] = data[index]
 
-        with ib.for_range(0, fused_indices_dimension) as i:
-            j = bx * tdim + tx
+        # Update output tensor by given values.
+        with ib.new_scope():
+            bidx = te.thread_axis("blockIdx.x")
+            tidx = te.thread_axis("threadIdx.x")
+            gridDim = fused_indices_dimension  # 32 * 600 = 19200

Review comment:
       Please remove this comment.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to