masahi commented on a change in pull request #8479:
URL: https://github.com/apache/tvm/pull/8479#discussion_r678960028



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -789,42 +790,80 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
 
         # For now we avoid parallizing over dimensions indexed by `indices` as
         # there may be repeated indices and hadling parallel accumulation can
-        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-        # work well when these dimensions are large enough to saturate memory
-        # bandwidth, but performance will be bad when these dimensions are
-        # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
+        # be hard. So we parallelize over X_M .. X_{N-1} instead.
+
+        # For better performance, we introduce blockIdx.y to implement 
for-loops
+        # within one thread.
+        # The code is parallel over the scattered indices, so we use atomic_add
+        # to guarantee correctness when mode=="add"
+
         max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
         tdim = min(max_threads, fused_updates_dimension)
-        ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_updates_dimension, tdim)
-        ib.scope_attr(bx, "thread_extent", bdim)
-
-        # Copy data into the output. This loop writes to the same portions of
-        # memory as the following loop, so we do not need a memory sync.
-        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), 
name="i") as i:
-            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+
+        with ib.new_scope():
+            bdim = ceil_div(fused_shape, tdim)
+            bx = te.thread_axis("blockIdx.x")
+            tx = te.thread_axis("threadIdx.x")
+            ib.scope_attr(bx, "thread_extent", bdim)
+            ib.scope_attr(tx, "thread_extent", tdim)
+
+            index = bx * tdim + tx
+            with ib.if_scope(index < fused_shape):
                 out[index] = data[index]
 
-        with ib.for_range(0, fused_indices_dimension) as i:
-            j = bx * tdim + tx
-            with ib.if_scope(j < fused_updates_dimension):
-                offset = fused_updates_dimension
-                index = j  # This is x_M, .. x_{N-1} part of the index into 
out.
-                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, 
y_0, .. y_{K-1}] part
-                # of the index into out.
-                for l in reversed(range(indices_ptr.shape[0].value)):
-                    # indices[i * l * fused_indices_dimension] = indices[l, 
y_0, ... y_{k-1}]
-                    index += offset * indices[i + l * fused_indices_dimension]
-                    offset *= data_ptr.shape[l]
-                if mode == "update":
-                    out[index] = updates[i * fused_updates_dimension + j]
-                elif mode == "add":
-                    out[index] += updates[i * fused_updates_dimension + j]
-                else:
-                    raise NotImplementedError("scatter_nd mode not in [update, 
add]:", mode)
+        with ib.new_scope():
+            if updates.dtype == "int64" and mode == "add":
+                bdim_x = ceil_div(fused_updates_dimension, tdim)
+                bx = te.thread_axis("blockIdx.x")
+                tx = te.thread_axis("threadIdx.x")
+                ib.scope_attr(bx, "thread_extent", bdim_x)
+                ib.scope_attr(tx, "thread_extent", tdim)
+                with ib.for_range(0, fused_indices_dimension) as i:
+                    j = bx * tdim + tx
+                    with ib.if_scope(j < fused_updates_dimension):
+                        offset = fused_updates_dimension
+                        index = j  # This is x_M, .. x_{N-1} part of the index 
into out.
+                        # Build up the
+                        # indices[0, y_0, .. y_{K-1}], ... indices[M-1, y_0, 
.. y_{K-1}]
+                        # part of the index into out.
+                        for l in reversed(range(indices_ptr.shape[0].value)):
+                            # indices[i * l * fused_indices_dimension] = 
indices[l, y_0,
+                            #                                                  
 ... y_{k-1}]
+                            index += offset * indices[i + l * 
fused_indices_dimension]
+                            offset *= data_ptr.shape[l]
+                        out[index] += updates[i * fused_updates_dimension + j]
+            else:
+                bdim_x = ceil_div(fused_updates_dimension, tdim)
+                bdim_y = fused_indices_dimension
+                # In case of large input sizes, bim_y might be too large.
+                # So it could be moved to blockIdx.x position, which holds 
larger scales.
+                bx = te.thread_axis("blockIdx.y")
+                by = te.thread_axis("blockIdx.x")
+                tx = te.thread_axis("threadIdx.x")
+                ib.scope_attr(bx, "thread_extent", bdim_x)
+                ib.scope_attr(by, "thread_extent", bdim_y)
+                ib.scope_attr(tx, "thread_extent", tdim)
+
+                j = bx * tdim + tx
+                with ib.if_scope(j < fused_updates_dimension):
+                    offset = fused_updates_dimension
+                    index = j  # This is x_M, .. x_{N-1} part of the index 
into out.
+                    # Build up the indices[0, y_0, .. y_{K-1}], .. 
indices[M-1, y_0, .. y_{K-1}]
+                    # part of the index into out.
+                    up_index = by * fused_updates_dimension + j
+                    for l in reversed(range(indices_ptr.shape[0].value)):
+                        # indices[by * l * fused_indices_dimension] = 
indices[l, y_0, ... y_{k-1}]
+                        index += offset * indices[by + l * 
fused_indices_dimension]
+                        offset *= data_ptr.shape[l]
+                    if mode == "update":
+                        out[index] = updates[up_index]

Review comment:
       Yes, that sounds good; we can discuss it with more people then. This has
been on my mind for a while, since both our `scatter` and `scatter_nd` ops
sacrifice performance for deterministic output, while all other frameworks make
the opposite choice (they say the output is undefined when indices are not unique).




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to