gemini-code-assist[bot] commented on code in PR #19497:
URL: https://github.com/apache/tvm/pull/19497#discussion_r3177762384


##########
python/tvm/topi/scatter_elements.py:
##########
@@ -103,26 +110,66 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr, 
reduce_func):
         updates = T.buffer_proxy(updates_ptr)
         out = T.buffer_proxy(out_ptr)
 
-        # Copy initial input data to output
         with IRBuilder() as ib:
             with T.seq_scope():
-                with T.parallel(0, full_range) as i:
-                    out[i] = data[i]
-
-                with T.parallel(0, ind_before_axis_range * 
ind_after_axis_range) as fused:
-                    i = fused // ind_after_axis_range
-                    j = fused % ind_after_axis_range
-                    pre_index1 = i * ind_before_axis_stride + j
-                    pre_index2 = i * before_axis_stride + j
-                    with T.serial(0, ind_axis_range) as k:
-                        # Offset along indices or updates
-                        index1 = pre_index1 + k * ind_after_axis_range
-                        # Get index and shift to positive side if need
-                        k_new = indices[index1]
-                        shifted_index = k_new + (k_new < 0) * axis_range
-                        # Offset along data
-                        index2 = pre_index2 + shifted_index * after_axis_range
-                        reduce_func(out, index2, updates[index1])
+                if is_gpu:
+                    max_threads = int(target.attrs["max_num_threads"])
+
+                    # Init
+                    nthread_bx_init = cast(ceil_div(full_range, max_threads), 
"int32")
+                    tx_init = te.thread_axis("threadIdx.x")
+                    bx_init = te.thread_axis("blockIdx.x")
+                    with T.frame_scope([
+                        T.attr(bx_init, "thread_extent", nthread_bx_init),
+                        T.attr(tx_init, "thread_extent", max_threads),
+                    ]):
+                        tid = bx_init * max_threads + tx_init
+                        with T.If(tid < full_range):
+                            with T.Then():
+                                out[tid] = data[tid]
+
+                    # Scatter
+                    nthread_bx_scat = cast(
+                        ceil_div(ind_full_range_excl_axis, max_threads), 
"int32"
+                    )
+                    tx_scat = te.thread_axis("threadIdx.x")
+                    bx_scat = te.thread_axis("blockIdx.x")
+                    with T.frame_scope([
+                        T.attr(bx_scat, "thread_extent", nthread_bx_scat),
+                        T.attr(tx_scat, "thread_extent", max_threads),
+                    ]):
+                        fused = bx_scat * max_threads + tx_scat
+                        with T.If(fused < ind_full_range_excl_axis):
+                            with T.Then():
+                                i = fused // ind_after_axis_range
+                                j = fused % ind_after_axis_range
+                                pre_index1 = i * ind_before_axis_stride + j
+                                pre_index2 = i * before_axis_stride + j
+                                with T.serial(0, ind_axis_range) as k:
+                                    index1 = pre_index1 + k * 
ind_after_axis_range
+                                    k_new = indices[index1]
+                                    shifted_index = k_new + (k_new < 0) * 
axis_range
+                                    index2 = pre_index2 + shifted_index * 
after_axis_range
+                                    reduce_func(out, index2, updates[index1])
+                else:
+                    # Copy initial input data to output
+                    with T.parallel(0, full_range) as i:
+                        out[i] = data[i]
+
+                    with T.parallel(0, ind_full_range_excl_axis) as fused:
+                        i = fused // ind_after_axis_range
+                        j = fused % ind_after_axis_range
+                        pre_index1 = i * ind_before_axis_stride + j
+                        pre_index2 = i * before_axis_stride + j
+                        with T.serial(0, ind_axis_range) as k:
+                            # Offset along indices or updates
+                            index1 = pre_index1 + k * ind_after_axis_range
+                            # Get index and shift to positive side if need
+                            k_new = indices[index1]
+                            shifted_index = k_new + (k_new < 0) * axis_range
+                            # Offset along data
+                            index2 = pre_index2 + shifted_index * 
after_axis_range
+                            reduce_func(out, index2, updates[index1])
 
             return ib.get()

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
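   The logic for initializing the output and performing the scatter operation 
is significantly duplicated between the GPU and CPU paths. The per-element 
bodies are identical; the two paths differ only in how the flat index is 
produced (`blockIdx.x * max_threads + threadIdx.x` with a bounds check on the 
GPU, a `T.parallel` loop variable on the CPU). This duplication makes the code 
harder to maintain and increases the risk of the two paths diverging if changes 
are needed in the future. Consider extracting the common logic into local helper 
functions (e.g., `_init_body` and `_scatter_body`) within `gen_ir` that take the 
flat index as an argument, to improve readability and maintainability.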



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to