tkonolige commented on a change in pull request #8479:
URL: https://github.com/apache/tvm/pull/8479#discussion_r674157231



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -793,36 +794,51 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         # work well when these dimensions are large enough to saturate memory
         # bandwidth, but performance will be bad when these dimensions are
         # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
+
+        # For better performance, we introduce blockIdx.y to implement for-loops
+        # within one thread.
+        # Atomic_add guarantees correctness when mode=="add"
+
         max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
         tdim = min(max_threads, fused_updates_dimension)
-        ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_updates_dimension, tdim)
-        ib.scope_attr(bx, "thread_extent", bdim)
-
-        # Copy data into the output. This loop writes to the same portions of
-        # memory as the following loop, so we do not need a memory sync.
-        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
-            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+
+        with ib.new_scope():
+            bdim = ceil_div(fused_shape, tdim)
+            bx = te.thread_axis("blockIdx.x")
+            tx = te.thread_axis("threadIdx.x")
+            ib.scope_attr(bx, "thread_extent", bdim)
+            ib.scope_attr(tx, "thread_extent", tdim)
+
+            index = bx * tdim + tx
+            with ib.if_scope(index < fused_shape):
                 out[index] = data[index]
 
-        with ib.for_range(0, fused_indices_dimension) as i:
+        with ib.new_scope():
+            bdim_x = ceil_div(fused_updates_dimension, tdim)
+            bdim_y = fused_indices_dimension
+            bx = te.thread_axis("blockIdx.x")
+            by = te.thread_axis("blockIdx.y")
+            tx = te.thread_axis("threadIdx.x")
+            ib.scope_attr(bx, "thread_extent", bdim_x)
+            ib.scope_attr(by, "thread_extent", bdim_y)
+            ib.scope_attr(tx, "thread_extent", tdim)
+
             j = bx * tdim + tx
             with ib.if_scope(j < fused_updates_dimension):
                 offset = fused_updates_dimension
-                index = j  # This is x_M, .. x_{N-1} part of the index into out.
-                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
-                # of the index into out.

Review comment:
       Can you keep this comment? I believe it still holds.

##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -793,36 +794,51 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         # work well when these dimensions are large enough to saturate memory
         # bandwidth, but performance will be bad when these dimensions are
         # small.

Review comment:
       This comment is no longer valid, right?

##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -793,36 +794,51 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         # work well when these dimensions are large enough to saturate memory
         # bandwidth, but performance will be bad when these dimensions are
         # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
+
+        # For better performance, we introduce blockIdx.y to implement for-loops
+        # within one thread.
+        # Atomic_add guarantees correctness when mode=="add"

Review comment:
       ```suggestion
           # The code is parallel over the scattered indices, so we use atomic_add to guarantee correctness when mode=="add".
       ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to