tkonolige commented on a change in pull request #8479:
URL: https://github.com/apache/tvm/pull/8479#discussion_r679338714



##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -787,44 +791,94 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         for i in data_ptr.shape:
             fused_shape *= i
 
-        # For now we avoid parallizing over dimensions indexed by `indices` as
-        # there may be repeated indices and hadling parallel accumulation can
-        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-        # work well when these dimensions are large enough to saturate memory
-        # bandwidth, but performance will be bad when these dimensions are
-        # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
         max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
         tdim = min(max_threads, fused_updates_dimension)
-        ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_updates_dimension, tdim)
-        ib.scope_attr(bx, "thread_extent", bdim)
-
-        # Copy data into the output. This loop writes to the same portions of
-        # memory as the following loop, so we do not need a memory sync.
-        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
-            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+
+        with ib.new_scope():
+            bdim = ceil_div(fused_shape, tdim)
+            bx = te.thread_axis("blockIdx.x")
+            tx = te.thread_axis("threadIdx.x")
+            ib.scope_attr(bx, "thread_extent", bdim)
+            ib.scope_attr(tx, "thread_extent", tdim)
+
+            index = bx * tdim + tx
+            with ib.if_scope(index < fused_shape):
                 out[index] = data[index]
 
-        with ib.for_range(0, fused_indices_dimension) as i:
-            j = bx * tdim + tx
-            with ib.if_scope(j < fused_updates_dimension):
-                offset = fused_updates_dimension
-                index = j  # This is x_M, .. x_{N-1} part of the index into out.
-                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
-                # of the index into out.
-                for l in reversed(range(indices_ptr.shape[0].value)):
-                    # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
-                    index += offset * indices[i + l * fused_indices_dimension]
-                    offset *= data_ptr.shape[l]
-                if mode == "update":
-                    out[index] = updates[i * fused_updates_dimension + j]
-                elif mode == "add":
-                    out[index] += updates[i * fused_updates_dimension + j]
-                else:
-                    raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
+        # For better performance, we introduce blockIdx.y to implement for-loops
+        # within one thread.
+        # The code is parallel over the scattered indices, so we use atomic_add
+        # to guarantee correctness when mode=="add"
+
+        # For now, atomic is not supported by target "vulkan", "metal", or "cuda" with "int64"
+        # So we fallback to normal algorithm, using "+=" rather than atomic_add
+
+        # TODO:

Review comment:
       Please put a username on the TODO (your username, assuming you will do this).
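
       For reference, a minimal sketch of what that could look like (the username below is only a placeholder; the wording is lifted from the TODO in the diff):

```python
# TODO(<your-github-username>): since multiple threads compete for the same
# write index, update mode is non-deterministic. We could add an
# "allow_non_deterministic" attribute (False by default) to the scatter_nd op
# and have the ONNX frontend set it to True, which would enable this code
# path for update mode as well.
```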

##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -787,44 +791,94 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         for i in data_ptr.shape:
             fused_shape *= i
 
-        # For now we avoid parallizing over dimensions indexed by `indices` as
-        # there may be repeated indices and hadling parallel accumulation can
-        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-        # work well when these dimensions are large enough to saturate memory
-        # bandwidth, but performance will be bad when these dimensions are
-        # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
         max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
         tdim = min(max_threads, fused_updates_dimension)
-        ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_updates_dimension, tdim)
-        ib.scope_attr(bx, "thread_extent", bdim)
-
-        # Copy data into the output. This loop writes to the same portions of
-        # memory as the following loop, so we do not need a memory sync.
-        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
-            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+
+        with ib.new_scope():
+            bdim = ceil_div(fused_shape, tdim)
+            bx = te.thread_axis("blockIdx.x")
+            tx = te.thread_axis("threadIdx.x")
+            ib.scope_attr(bx, "thread_extent", bdim)
+            ib.scope_attr(tx, "thread_extent", tdim)
+
+            index = bx * tdim + tx
+            with ib.if_scope(index < fused_shape):
                 out[index] = data[index]
 
-        with ib.for_range(0, fused_indices_dimension) as i:
-            j = bx * tdim + tx
-            with ib.if_scope(j < fused_updates_dimension):
-                offset = fused_updates_dimension
-                index = j  # This is x_M, .. x_{N-1} part of the index into out.
-                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
-                # of the index into out.
-                for l in reversed(range(indices_ptr.shape[0].value)):
-                    # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
-                    index += offset * indices[i + l * fused_indices_dimension]
-                    offset *= data_ptr.shape[l]
-                if mode == "update":
-                    out[index] = updates[i * fused_updates_dimension + j]
-                elif mode == "add":
-                    out[index] += updates[i * fused_updates_dimension + j]
-                else:
-                    raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
+        # For better performance, we introduce blockIdx.y to implement for-loops
+        # within one thread.
+        # The code is parallel over the scattered indices, so we use atomic_add
+        # to guarantee correctness when mode=="add"
+
+        # For now, atomic is not supported by target "vulkan", "metal", or "cuda" with "int64"
+        # So we fallback to normal algorithm, using "+=" rather than atomic_add
+
+        # TODO:
+        # Since multiple threads compete for the same write index, which leads to
+        # non-determinstic output for update mode. We could add a new attribute
+        # "allow_non_deterministic" to scatter_nd op, which is False by default.
+        # And change ONNX frontend to emit scatter_op with allow_non_deterministic = True,
+        # which will allow the new code path for update mode as well
+        with ib.new_scope():
+            if (
+                mode == "update"
+                or cur_target_kind("vulkan")
+                or cur_target_kind("metal")
+                or (updates.dtype == "int64" and mode == "add")
+            ):
+                bdim_x = ceil_div(fused_updates_dimension, tdim)
+                bx = te.thread_axis("blockIdx.x")
+                tx = te.thread_axis("threadIdx.x")
+                ib.scope_attr(bx, "thread_extent", bdim_x)
+                ib.scope_attr(tx, "thread_extent", tdim)
+                with ib.for_range(0, fused_indices_dimension) as i:
+                    j = bx * tdim + tx
+                    with ib.if_scope(j < fused_updates_dimension):
+                        offset = fused_updates_dimension
+                        index = j  # This is x_M, .. x_{N-1} part of the index into out.
+                        # Build up the
+                        # indices[0, y_0, .. y_{K-1}], ... indices[M-1, y_0, .. y_{K-1}]
+                        # part of the index into out.
+                        for l in reversed(range(indices_ptr.shape[0].value)):
+                            # indices[i * l * fused_indices_dimension] = indices[l, y_0,
+                            #                                                   ... y_{k-1}]
+                            index += offset * indices[i + l * fused_indices_dimension]
+                            offset *= data_ptr.shape[l]
+                        if mode == "update":
+                            out[index] = updates[i * fused_updates_dimension + j]
+                        elif mode == "add":
+                            out[index] += updates[i * fused_updates_dimension + j]
+                        else:
+                            raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
+            elif mode == "add":
+                bdim_x = fused_indices_dimension
+                bdim_y = ceil_div(fused_updates_dimension, tdim)
+                # In case of large input sizes, fused_indices_dimension might be too large.
+                # So it could be moved to blockIdx.x position, which holds larger scales.

Review comment:
       ```suggestion
                # In case of large input sizes, fused_indices_dimension might be too large.
                # So we use blockIdx.x because it holds larger scales.
       ```
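
       Roughly, the grid layout this comment describes would look like the sketch below. This is only an illustration, not the exact PR code: `atomic_add` is assumed to be a wrapper around the `tir.atomic_add` intrinsic as used in other topi/cuda kernels, `atomic_add_return` is a small local buffer assumed to receive its result, and the remaining names come from the diff above.

```python
bx = te.thread_axis("blockIdx.x")   # ranges over fused_indices_dimension (potentially large)
by = te.thread_axis("blockIdx.y")   # with threadIdx.x, covers fused_updates_dimension
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", fused_indices_dimension)
ib.scope_attr(by, "thread_extent", ceil_div(fused_updates_dimension, tdim))
ib.scope_attr(tx, "thread_extent", tdim)

# Local scratch for the atomic's return value (assumed pattern).
atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local")

j = by * tdim + tx
with ib.if_scope(j < fused_updates_dimension):
    # Same index computation as the fallback branch, with bx playing the role of i.
    offset = fused_updates_dimension
    index = j
    for l in reversed(range(indices_ptr.shape[0].value)):
        index += offset * indices[bx + l * fused_indices_dimension]
        offset *= data_ptr.shape[l]
    # Atomically accumulate the update into out[index].
    atomic_add_return[0] = atomic_add(
        tvm.tir.call_intrin("handle", "tir.address_of", out[index]),
        updates[bx * fused_updates_dimension + j],
    )
```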




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

