gemini-code-assist[bot] commented on code in PR #18716:
URL: https://github.com/apache/tvm/pull/18716#discussion_r2771442569
##########
tests/python/codegen/test_target_codegen_vulkan.py:
##########
@@ -211,188 +211,136 @@ def test_vulkan_while_if(target, dev):
target = tvm.target.Target(target)
n = 1
dtype = "int32"
- A = te.placeholder((n,), name="A", dtype=dtype)
-
- def do_compute(A, B, n):
- ib = tvm.tir.ir_builder.create()
- A = ib.buffer_ptr(A)
- B = ib.buffer_ptr(B)
- if "gpu" in target.keys:
- ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)
+ def get_module(is_gpu):
+ if is_gpu:
- iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
- iterations[0] = 0
- B[0] = 0
+ @T.prim_func
+ def while_if_gpu(A: T.Buffer((1,), "int32"), B: T.Buffer((1,), "int32")):
+ for bx in T.thread_binding(1, thread="blockIdx.x"):
+ iterations = T.decl_buffer((1,), "int32", scope="local")
+ iterations[0] = 0
+ B[0] = 0
+ while iterations[0] < T.if_then_else(A[0] > 0, 10, 20):
+ iterations[0] = iterations[0] + 1
+ B[0] = B[0] + iterations[0]
- loop_condition = iterations[0] < tvm.tir.if_then_else(A[0] > 0, 10, 20)
- with ib.while_loop(loop_condition):
- iterations[0] += 1
- B[0] += iterations[0]
+ return tvm.IRModule.from_expr(while_if_gpu.with_attr("target", target))
+ else:
- return ib.get()
+ @T.prim_func
+ def while_if_cpu(A: T.Buffer((1,), "int32"), B: T.Buffer((1,), "int32")):
+ iterations = T.decl_buffer((1,), "int32", scope="local")
+ iterations[0] = 0
+ B[0] = 0
+ while iterations[0] < T.if_then_else(A[0] > 0, 10, 20):
+ iterations[0] = iterations[0] + 1
+ B[0] = B[0] + iterations[0]
- B = te.extern(
- A.shape,
- [A],
- lambda ins, outs: do_compute(ins[0], outs[0], n),
- dtype=dtype,
- )
-
- # Create IRModule
- mod = tvm.IRModule.from_expr(te.create_prim_func([A, B]))
- sch = tvm.s_tir.Schedule(mod)
+ return tvm.IRModule.from_expr(while_if_cpu.with_attr("target", target))
- # Build
- func = tvm.compile(sch.mod, target=target)
+ mod = get_module("gpu" in target.keys)
+ compiled_func = tvm.compile(mod, target=target)
- a = tvm.runtime.tensor(np.array([5], dtype=A.dtype), dev)
- b = tvm.runtime.tensor(np.zeros(n, dtype=A.dtype), dev)
- func(a, b)
+ a = tvm.runtime.tensor(np.array([5], dtype=dtype), dev)
+ b = tvm.runtime.tensor(np.zeros(n, dtype=dtype), dev)
+ compiled_func(a, b)
tvm.testing.assert_allclose(b.numpy(), [55])
- a = tvm.runtime.tensor(np.array([-5], dtype=A.dtype), dev)
- b = tvm.runtime.tensor(np.zeros(n, dtype=A.dtype), dev)
- func(a, b)
+ a = tvm.runtime.tensor(np.array([-5], dtype=dtype), dev)
+ b = tvm.runtime.tensor(np.zeros(n, dtype=dtype), dev)
+ compiled_func(a, b)
tvm.testing.assert_allclose(b.numpy(), [210])
@tvm.testing.exclude_targets("llvm")
def test_vulkan_local_threadidx(target, dev):
target = tvm.target.Target(target)
n = 32
- A = te.placeholder((n,), name="A", dtype="int32")
-
- def do_compute(A, B, n):
- ib = tvm.tir.ir_builder.create()
- A = ib.buffer_ptr(A)
- B = ib.buffer_ptr(B)
-
- tx = te.thread_axis("threadIdx.x")
- with ib.for_range(0, 1):
- ib.scope_attr(tx, "thread_extent", 16)
- B[tx + 0] = A[tx + 0]
-
- with ib.for_range(0, 1):
- ib.scope_attr(tx, "thread_extent", 16)
- B[tx + 16] = A[tx + 16]
-
- return ib.get()
-
- B = te.extern(
- A.shape,
- [A],
- lambda ins, outs: do_compute(ins[0], outs[0], n),
- dtype="int32",
- )
-
- # Create IRModule
- mod = tvm.IRModule.from_expr(te.create_prim_func([A, B]))
- sch = tvm.s_tir.Schedule(mod)
-
- # Build
- func = tvm.compile(sch.mod, target=target)
-
- n = 32
- a_np = np.arange(n).astype(dtype=A.dtype)
+ @T.prim_func
+ def local_threadidx_func(A: T.Buffer((32,), "int32"), B: T.Buffer((32,), "int32")):
+ # First block with thread extent 16
+ for _ in range(1):
+ for tx in T.thread_binding(16, thread="threadIdx.x"):
+ B[tx + 0] = A[tx + 0]
+ # Second block with thread extent 16
+ for _ in range(1):
+ for tx in T.thread_binding(16, thread="threadIdx.x"):
+ B[tx + 16] = A[tx + 16]
+
+ mod = tvm.IRModule.from_expr(local_threadidx_func)
+ func = tvm.compile(mod, target=target)
+
+ a_np = np.arange(n).astype(dtype="int32")
b_np = np.zeros((n,), dtype="int32")
a = tvm.runtime.tensor(a_np, dev)
b = tvm.runtime.tensor(b_np, dev)
func(a, b)
tvm.testing.assert_allclose(b.numpy(), a_np)
-class TestVectorizedIndices:
- load_type, store_type = tvm.testing.parameters(
- # Load N values, write to N locations.
- # Vectorized copy.
- ("ramp", "ramp"),
- # Load 1 value, write to N locations.
- # Scalar load, vectorized store.
- #
- # Most TVM operations (e.g. schedule[tensor].vectorize(axis)) have
- # the broadcast outside of the index, but it is semantically okay
- # for the broadcast to be inside the index, and it shows up with
- # some optimizations.
- ("broadcast", "ramp"),
- # Load 1 values, write to 1 location.
- # Broadcasting on both sides should be equivalent to a scalar copy.
- ("broadcast", "broadcast"),
- # Loads N values, write to 1 location.
- # Disabled as it would have unclear semantics.
- # ("ramp","broadcoast"),
- )
- indirect_indices = tvm.testing.parameter(True, False, ids=["reorder", "no_reorder"])
-
- @tvm.testing.fixture
- def ref_data(self, load_type, store_type, indirect_indices):
- n = 4
-
- index_map = {
- "ramp": np.arange(n),
- "broadcast": np.zeros(n, dtype="int32"),
- }
-
- a_np = np.random.randint(np.iinfo("int32").max, size=n).astype("int32")
- b_np = np.zeros(shape=n, dtype=a_np.dtype)
- reorder_np = np.arange(n, dtype="int32")[::-1]
-
- load_index = index_map[load_type]
- store_index = index_map[store_type]
-
- if indirect_indices:
- load_index = reorder_np[load_index]
-
- b_np[store_index] = a_np[load_index]
-
- return a_np, reorder_np, b_np
-
- @tvm.testing.fixture
- def mod(self, target, load_type, store_type, indirect_indices):
- target = tvm.target.Target(target)
-
- n = 4
- dtype = "int32"
- A = te.placeholder((n,), dtype=dtype, name="A")
- R = te.placeholder((n,), dtype=dtype, name="R")
-
- def do_compute(ins, outs):
- ib = tvm.tir.ir_builder.create()
- A, R = map(ib.buffer_ptr, ins)
- B = ib.buffer_ptr(outs[0])
-
- if "gpu" in target.keys:
- ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)
-
- index_map = {
- "ramp": tvm.tir.Ramp(0, 1, 4),
- "broadcast": tvm.tir.Broadcast(0, 4),
- }
+@tvm.testing.parametrize_targets("vulkan -from_device=0")
+def test_vectorized_index_ramp(target, dev):
+ """Test vectorized copy with ramp indices (load N values, write to N
locations)"""
+ n = 4
+ ramp_index = tvm.tir.Ramp(0, 1, 4)
- load_index = index_map[load_type]
- store_index = index_map[store_type]
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def main(var_A: T.handle, var_B: T.handle):
+ T.func_attr({"tir.noalias": True})
+ A = T.match_buffer(var_A, (n,), "int32", offset_factor=1)
+ B = T.match_buffer(var_B, (n,), "int32", offset_factor=1)
+ with T.sblock("compute"):
+ T.reads()
+ T.writes()
+ bx = T.launch_thread("blockIdx.x", 1)
+ B[ramp_index] = A[ramp_index]
- if indirect_indices:
- load_index = R[load_index]
+ f = tvm.compile(Module, target=target)
- B[store_index] = A[load_index]
+ a_np = np.random.randint(np.iinfo("int32").max, size=n).astype("int32")
+ b_np = np.zeros(n, dtype="int32")
- return ib.get()
+ a = tvm.runtime.tensor(a_np, dev)
+ b = tvm.runtime.tensor(b_np, dev)
+ f(a, b)
+ tvm.testing.assert_allclose(b.numpy(), a_np)
- B = te.extern(A.shape, [A, R], do_compute, dtype="int32")
- return tvm.IRModule.from_expr(te.create_prim_func([A, R, B]))
+@tvm.testing.parametrize_targets("vulkan -from_device=0")
+def test_vectorized_index_broadcast(target, dev):
+ """Test broadcast index (load 1 value, write to N locations)"""
+ n = 4
+ broadcast_index = tvm.tir.Broadcast(0, 4)
+ ramp_index = tvm.tir.Ramp(0, 1, 4)
- def test_ramp_broadcast_index(self, target, dev, mod, ref_data):
- f = tvm.compile(mod, target=target)
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def main(var_A: T.handle, var_B: T.handle):
+ T.func_attr({"tir.noalias": True})
+ A = T.match_buffer(var_A, (n,), "int32", offset_factor=1)
+ B = T.match_buffer(var_B, (n,), "int32", offset_factor=1)
+ with T.sblock("compute"):
+ T.reads()
+ T.writes()
+ bx = T.launch_thread("blockIdx.x", 1)
+ # Load from broadcast index (single element), store to ramp index
+ B[ramp_index] = A[broadcast_index]
+
+ f = tvm.compile(Module, target=target)
+
+ a_np = np.random.randint(np.iinfo("int32").max, size=n).astype("int32")
+ b_np = np.zeros(n, dtype="int32")
- a_np, reorder_np, b_np = ref_data
- a = tvm.runtime.tensor(a_np, dev)
- r = tvm.runtime.tensor(reorder_np, dev)
- b = tvm.runtime.tensor(np.zeros(shape=b_np.shape, dtype="int32"), dev)
- f(a, r, b)
- tvm.testing.assert_allclose(b.numpy(), b_np)
+ a = tvm.runtime.tensor(a_np, dev)
+ b = tvm.runtime.tensor(b_np, dev)
+ f(a, b)
+ # All elements of b should be a[0] (broadcast load)
+ tvm.testing.assert_allclose(b.numpy(), np.full(n, a_np[0]))
Review Comment:

The refactoring from the parameterized `TestVectorizedIndices` class to two
separate tests seems to have reduced test coverage. The original tests covered
various combinations of `load_type`, `store_type`, and `indirect_indices`. The
new tests only cover a subset of these cases (e.g., indirect indexing via a
reordering buffer seems to be missing). Was this intentional? If not, it would
be good to restore the lost coverage.
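For reference, a rough sketch of a third test that would restore the indirect ("reorder") case, written in the same TVMScript style as the two new tests. The test name and the nested indirect-load expression are illustrative only and have not been validated against the parser:

```python
@tvm.testing.parametrize_targets("vulkan -from_device=0")
def test_vectorized_index_indirect(target, dev):
    """Vectorized copy where the load index goes through a reorder buffer R"""
    n = 4
    ramp_index = tvm.tir.Ramp(0, 1, 4)

    @I.ir_module
    class Module:
        @T.prim_func
        def main(var_A: T.handle, var_R: T.handle, var_B: T.handle):
            T.func_attr({"tir.noalias": True})
            A = T.match_buffer(var_A, (n,), "int32", offset_factor=1)
            R = T.match_buffer(var_R, (n,), "int32", offset_factor=1)
            B = T.match_buffer(var_B, (n,), "int32", offset_factor=1)
            with T.sblock("compute"):
                T.reads()
                T.writes()
                bx = T.launch_thread("blockIdx.x", 1)
                # Indirect load: gather A through the reorder indices held in R
                B[ramp_index] = A[R[ramp_index]]

    f = tvm.compile(Module, target=target)

    a_np = np.random.randint(np.iinfo("int32").max, size=n).astype("int32")
    reorder_np = np.arange(n, dtype="int32")[::-1]
    a = tvm.runtime.tensor(a_np, dev)
    r = tvm.runtime.tensor(reorder_np, dev)
    b = tvm.runtime.tensor(np.zeros(n, dtype="int32"), dev)
    f(a, r, b)
    # Expected result matches what the removed fixture computed for this case
    tvm.testing.assert_allclose(b.numpy(), a_np[reorder_np])
```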
##########
python/tvm/relax/transform/legalize_ops/grad.py:
##########
@@ -160,53 +163,60 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) -> Expr:
def te_take_backward(output_grad, x, indices):
def gen_ir(output_grad_ptr, x_ptr, indices_ptr, out_ptr):
# pylint: disable=invalid-name
- ib = tir.ir_builder.create()
-
- output_grad = ib.buffer_ptr(output_grad_ptr)
- indices = ib.buffer_ptr(indices_ptr)
- out = ib.buffer_ptr(out_ptr)
+ # Use buffer_proxy for flat indexing on multi-dimensional buffers
+ out = buffer_proxy(out_ptr)
+ grad = buffer_proxy(output_grad_ptr)
+ idx = buffer_proxy(indices_ptr)
fused_shape = 1
for i in x_ptr.shape:
fused_shape *= i
- with ib.for_range(0, fused_shape) as i:
- out[i] = tir.const(0, dtype=x_ptr.dtype)
+ # Build init loop (zero-fill output buffer)
+ with IRBuilder() as ib:
+ with T.serial(fused_shape) as i:
+ out[i] = tir.const(0, dtype=x_ptr.dtype)
+ init_stmt = ib.get()
assert len(indices_ptr.shape) == 1 # indices in take must be 1-dim Tensor
indices_len = indices_ptr.shape[0]
- if axis is not None:
- fused_output_grad_shape_pre = 1
- fused_output_grad_shape_nxt = 1
- for i in range(len(output_grad_ptr.shape)):
- if i < axis:
- fused_output_grad_shape_pre *= output_grad_ptr.shape[i]
- elif i > axis:
- fused_output_grad_shape_nxt *= output_grad_ptr.shape[i]
-
- x_axis_len = x_ptr.shape[axis]
-
- with ib.for_range(
- 0, fused_output_grad_shape_pre * fused_output_grad_shape_nxt, "parallel"
- ) as fused:
- i = fused // fused_output_grad_shape_nxt
- j = fused % fused_output_grad_shape_nxt
- with ib.for_range(0, indices_len, "serial") as l:
- out[
- i * fused_output_grad_shape_nxt * x_axis_len
- + indices[l] * fused_output_grad_shape_nxt
- + j
- ] += output_grad[
- i * fused_output_grad_shape_nxt * indices_len
- + l * fused_output_grad_shape_nxt
- + j
- ]
- else:
- with ib.for_range(0, indices_len, "serial") as l:
- out[indices[l]] += output_grad[l]
-
- return ib.get()
+ # Build accumulation loop
+ with IRBuilder() as ib:
+ if axis is not None:
+ fused_output_grad_shape_pre = 1
+ fused_output_grad_shape_nxt = 1
+ for i in range(len(output_grad_ptr.shape)):
+ if i < axis:
+ fused_output_grad_shape_pre *= output_grad_ptr.shape[i]
+ elif i > axis:
+ fused_output_grad_shape_nxt *= output_grad_ptr.shape[i]
+
+ x_axis_len = x_ptr.shape[axis]
+
+ with T.serial(
+ fused_output_grad_shape_pre * fused_output_grad_shape_nxt
+ ) as fused:
+ i = fused // fused_output_grad_shape_nxt
+ j = fused % fused_output_grad_shape_nxt
+ with T.serial(indices_len) as loop_l:
+ out_idx = (
+ i * fused_output_grad_shape_nxt * x_axis_len
+ + idx[loop_l] * fused_output_grad_shape_nxt
+ + j
+ )
+ grad_idx = (
+ i * fused_output_grad_shape_nxt * indices_len
+ + loop_l * fused_output_grad_shape_nxt
+ + j
+ )
+ out[out_idx] = out[out_idx] + grad[grad_idx]
+ else:
+ with T.serial(indices_len) as loop_l:
+ out[idx[loop_l]] = out[idx[loop_l]] + grad[loop_l]
+ accum_stmt = ib.get()
+
+ return tir.SeqStmt([init_stmt, accum_stmt])
Review Comment:

Instead of using two separate `IRBuilder` contexts and combining them with
`tir.SeqStmt`, you could use a single `IRBuilder` with `T.seq_scope()`. This
utility is introduced in this PR for this exact purpose and would make the code
more concise and idiomatic. Note that the `assert` and `indices_len`
calculation would need to be moved before the `IRBuilder` block.
```python
with IRBuilder() as ib:
    with T.seq_scope():
        # Build init loop (zero-fill output buffer)
        with T.serial(fused_shape) as i:
            out[i] = tir.const(0, dtype=x_ptr.dtype)
        # Build accumulation loop
        if axis is not None:
            fused_output_grad_shape_pre = 1
            fused_output_grad_shape_nxt = 1
            for i in range(len(output_grad_ptr.shape)):
                if i < axis:
                    fused_output_grad_shape_pre *= output_grad_ptr.shape[i]
                elif i > axis:
                    fused_output_grad_shape_nxt *= output_grad_ptr.shape[i]
            x_axis_len = x_ptr.shape[axis]
            with T.serial(
                fused_output_grad_shape_pre * fused_output_grad_shape_nxt
            ) as fused:
                i = fused // fused_output_grad_shape_nxt
                j = fused % fused_output_grad_shape_nxt
                with T.serial(indices_len) as loop_l:
                    out_idx = (
                        i * fused_output_grad_shape_nxt * x_axis_len
                        + idx[loop_l] * fused_output_grad_shape_nxt
                        + j
                    )
                    grad_idx = (
                        i * fused_output_grad_shape_nxt * indices_len
                        + loop_l * fused_output_grad_shape_nxt
                        + j
                    )
                    out[out_idx] = out[out_idx] + grad[grad_idx]
        else:
            with T.serial(indices_len) as loop_l:
                out[idx[loop_l]] = out[idx[loop_l]] + grad[loop_l]
return ib.get()
```
##########
python/tvm/topi/scatter.py:
##########
@@ -115,35 +115,37 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
for i in data_ptr.shape:
fused_shape *= i
- with ib.for_range(0, fused_shape) as i:
- out[i] = data[i]
-
- with ib.for_range(0, fused_indices_dimension) as i:
- with ib.for_range(0, fused_updates_dimension, kind="parallel") as j:
- offset = fused_updates_dimension
- index = j # This is x_M, .. x_{N-1} part of the index into out.
- # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
- # of the index into out.
- for l in reversed(range(indices_ptr.shape[0].value)):
- # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
- index += offset * indices[i + l * fused_indices_dimension]
- offset *= data_ptr.shape[l]
- if mode == "update":
- out[index] = updates[i * fused_updates_dimension + j]
- elif mode == "add":
- out[index] += updates[i * fused_updates_dimension + j]
- elif mode == "mul":
- out[index] *= updates[i * fused_updates_dimension + j]
- elif mode == "min":
- out[index] = tir.min(out[index], updates[i * fused_updates_dimension + j])
- elif mode == "max":
- out[index] = tir.max(out[index], updates[i * fused_updates_dimension + j])
- else:
- raise NotImplementedError(
- "scatter_nd mode not in [update, add, mul, min, max]:", mode
- )
-
- return ib.get()
+ with IRBuilder() as ib1:
+ with T.serial(0, fused_shape) as i:
+ out[i] = data[i]
+
+ with IRBuilder() as ib2:
+ with T.serial(0, fused_indices_dimension) as i:
+ with T.parallel(0, fused_updates_dimension) as j:
+ offset = fused_updates_dimension
+ index = j # This is x_M, .. x_{N-1} part of the index into out.
+ # Build up the indices[0, y_0, ..], .. indices[M-1, y_0, ..] part
+ # of the index into out.
+ for l in reversed(range(indices_ptr.shape[0].value)):
+ # indices[l, y_0, ... y_{k-1}]
+ index += offset * indices[i + l * fused_indices_dimension]
+ offset *= data_ptr.shape[l]
+ if mode == "update":
+ out[index] = updates[i * fused_updates_dimension + j]
+ elif mode == "add":
+ out[index] += updates[i * fused_updates_dimension + j]
+ elif mode == "mul":
+ out[index] *= updates[i * fused_updates_dimension + j]
+ elif mode == "min":
+ out[index] = tir.min(out[index], updates[i * fused_updates_dimension + j])
+ elif mode == "max":
+ out[index] = tir.max(out[index], updates[i * fused_updates_dimension + j])
+ else:
+ raise NotImplementedError(
+ "scatter_nd mode not in [update, add, mul, min, max]:", mode
+ )
+
+ return tir.SeqStmt([ib1.get(), ib2.get()])
Review Comment:

You can use `T.seq_scope()` here to avoid using two separate `IRBuilder`
blocks and `tir.SeqStmt`. This would make the code more consistent with other
parts of the codebase that use this new utility.
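Roughly, a sketch of the shape this could take (assuming `T.seq_scope()` sequences consecutive statements the same way as in the grad.py suggestion above; the inner body stays exactly as written):

```python
with IRBuilder() as ib:
    with T.seq_scope():
        # Copy initial input data to output
        with T.serial(0, fused_shape) as i:
            out[i] = data[i]
        # Scatter the updates
        with T.serial(0, fused_indices_dimension) as i:
            with T.parallel(0, fused_updates_dimension) as j:
                ...  # same index computation and mode handling as above
return ib.get()
```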
##########
python/tvm/topi/gpu/sort.py:
##########
@@ -260,223 +317,341 @@ def serial_merge(
step_count,
first,
last,
+ i_buf,
+ j_buf,
):
- target = tvm.target.Target.current()
- is_webgpu = "webgpu" in str(target)
- target_dtype = "int32" if is_webgpu else "int64"
- i = ib.allocate(target_dtype, (1,), name="i", scope="local")
- j = ib.allocate(target_dtype, (1,), name="j", scope="local")
- i_val = aStart + first
- j_val = bStart + diag - last
- if is_webgpu:
- i[0] = cast(i_val, target_dtype)
- j[0] = cast(j_val, target_dtype)
- else:
- i[0] = i_val
- j[0] = j_val
-
- with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) as count:
- i_idx = base_idx + i[0]
- j_idx = base_idx + j[0]
+ i_val = aStart + first[0]
+ j_val = bStart + diag - last[0]
+ i_buf[0] = i_val
+ j_buf[0] = j_val
+
+ with T.serial(0, tvm.te.min(aCount + bCount - diag, step_count)) as count:
+ i_idx = base_idx + i_buf[0]
+ j_idx = base_idx + j_buf[0]
k_idx = base_idx + (kStart + diag + count)
- def assign_i():
- """assign i value to current output"""
- dest[k_idx] = source[i_idx]
- if values is not None:
- dest_idx[k_idx] = source_idx[i_idx]
- i[0] += 1
-
- def assign_j():
- """assign j value to current output"""
- dest[k_idx] = source[j_idx]
- if values is not None:
- dest_idx[k_idx] = source_idx[j_idx]
- j[0] += 1
-
- ## if both of the iterators are in range
- with ib.if_scope(tvm.tir.all(i[0] < aStart + aCount, j[0] < bStart + bCount)):
- # compare them and insert whichever is next into the output
- with ib.if_scope(compare(source[i_idx], source[j_idx])):
- assign_i()
- with ib.else_scope():
- assign_j()
- # otherwise, simply copy the remainder of the valid iterator to the output
- with ib.else_scope():
- with ib.if_scope(i[0] < aStart + aCount):
- assign_i()
- with ib.else_scope():
- assign_j()
-
- target = tvm.target.Target.current()
- target_dtype = "int32" if "webgpu" in str(target) else "int64"
- with ib.for_range(0, cast(upper_lim - lower_lim, target_dtype), dtype=target_dtype) as l2_width:
- width = 2 << (l2_width + lower_lim)
- # Define and launch the cuda kernel
- with ib.new_scope():
- target = tvm.target.Target.current()
- if "vulkan" in str(target):
- # Vulkan can't handle dynamic nthread, so we thread slightly differently
- # for vulkan. We don't do this generally because it causes a 15% perf
- # regression on other platforms
- ntx = max_threads
- nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
- nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
- tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
- else:
- ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")
- nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
- nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
- tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
- by, bz = by % nthread_by, by // nthread_by
-
- def mergepath(
- source,
- dest,
- source_idx,
- dest_idx,
- aCount,
- bCount,
- aStart,
- bStart,
- kStart,
- step_count,
- even,
- ):
- # pylint: disable=arguments-out-of-order
- def merge(source, dest, source_idx, dest_idx):
- diag = tx * step_count
- first, last = get_merge_begin(
+ with T.If(tvm.tir.all(i_buf[0] < aStart + aCount, j_buf[0] < bStart + bCount)):
+ with T.Then():
+ with T.If(compare(source[i_idx], source[j_idx])):
+ with T.Then():
+ dest[k_idx] = source[i_idx]
+ if values is not None:
+ dest_idx[k_idx] = source_idx[i_idx]
+ i_buf[0] = i_buf[0] + 1
+ with T.Else():
+ dest[k_idx] = source[j_idx]
+ if values is not None:
+ dest_idx[k_idx] = source_idx[j_idx]
+ j_buf[0] = j_buf[0] + 1
+ with T.Else():
+ with T.If(i_buf[0] < aStart + aCount):
+ with T.Then():
+ dest[k_idx] = source[i_idx]
+ if values is not None:
+ dest_idx[k_idx] = source_idx[i_idx]
+ i_buf[0] = i_buf[0] + 1
+ with T.Else():
+ dest[k_idx] = source[j_idx]
+ if values is not None:
+ dest_idx[k_idx] = source_idx[j_idx]
+ j_buf[0] = j_buf[0] + 1
+
+ def mergepath(
+ source,
+ dest,
+ source_idx,
+ dest_idx,
+ base_idx,
+ aCount,
+ bCount,
+ aStart,
+ bStart,
+ kStart,
+ tx,
+ step_count,
+ even,
+ ):
+ with T.frame_scope(
+ [
+ T.allocate([1], "int64", scope="local"), # first
+ T.allocate([1], "int64", scope="local"), # last
+ T.allocate([1], "int64", scope="local"), # i_buf
+ T.allocate([1], "int64", scope="local"), # j_buf
+ ]
+ ) as (first_ptr, last_ptr, i_ptr, j_ptr):
Review Comment:

The previous implementation had special handling for the `webgpu` target,
using `int32` for local buffers like `first` and `last`. The refactored code
seems to hardcode `int64`. This might cause issues on WebGPU where 64-bit
integers are not well supported. Was this change intentional?
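If the old behaviour should be preserved, a minimal sketch that reuses the `Target.current()` check from the removed code together with the `T.allocate`/`T.frame_scope` calls shown above:

```python
# WebGPU lacks reliable 64-bit integer support, so fall back to int32 there
target = tvm.target.Target.current()
local_dtype = "int32" if "webgpu" in str(target) else "int64"
with T.frame_scope(
    [
        T.allocate([1], local_dtype, scope="local"),  # first
        T.allocate([1], local_dtype, scope="local"),  # last
        T.allocate([1], local_dtype, scope="local"),  # i_buf
        T.allocate([1], local_dtype, scope="local"),  # j_buf
    ]
) as (first_ptr, last_ptr, i_ptr, j_ptr):
    ...
```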
##########
python/tvm/topi/index_put.py:
##########
@@ -85,56 +87,56 @@ def index_put(data, indices, values, accumulate=False):
index_len *= dim
def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):
- ir_builder = tir.ir_builder.create()
-
- data = ir_builder.buffer_ptr(data_ptr)
- indices = [ir_builder.buffer_ptr(idx) for idx in index_ptrs]
- values = ir_builder.buffer_ptr(values_ptr)
- out = ir_builder.buffer_ptr(out_ptr)
-
- with ir_builder.for_range(0, full_range, "i", kind="parallel") as i:
- out[i] = data[i]
-
- with ir_builder.for_range(0, index_len, "k", kind="parallel") as k:
- # Decompose k into multi-dimensional broadcast index
- k_temp = k
- broadcast_indices = []
- for i in range(broadcast_ndim - 1, -1, -1):
- broadcast_indices.insert(0, k_temp % broadcast_shape[i])
- k_temp = k_temp // broadcast_shape[i]
-
- flat_index = 0
- stride = 1
- for dim in range(len(shape) - 1, -1, -1):
- # Get the index for this dimension using broadcasting
- idx_shape = index_shapes[dim]
- idx_ndim = len(idx_shape)
-
- # Compute the linear index into this index tensor
- idx_offset = 0
- idx_stride = 1
+ data = T.buffer_proxy(data_ptr)
+ indices = [T.buffer_proxy(idx) for idx in index_ptrs]
+ values = T.buffer_proxy(values_ptr)
+ out = T.buffer_proxy(out_ptr)
+
+ with IRBuilder() as ib1:
+ with T.parallel(0, full_range) as i:
+ out[i] = data[i]
+
+ with IRBuilder() as ib2:
+ with T.parallel(0, index_len) as k:
+ # Decompose k into multi-dimensional broadcast index
+ k_temp = k
+ broadcast_indices = []
for i in range(broadcast_ndim - 1, -1, -1):
- # Right-align the index shape with broadcast shape
- dim_idx = idx_ndim - broadcast_ndim + i
- if dim_idx >= 0:
- dim_size = idx_shape[dim_idx]
- # Use broadcasting: if size is 1, use index 0
- # otherwise use broadcast_indices[i]
- if utils.equal_const_int(dim_size, 1):
- idx_in_dim = 0
- else:
- idx_in_dim = broadcast_indices[i]
- idx_offset += idx_in_dim * idx_stride
- idx_stride *= dim_size
-
- idx_val = indices[dim][idx_offset]
- shifted_idx = idx_val + (idx_val < 0) * shape[dim]
- flat_index += shifted_idx * stride
- stride *= shape[dim]
-
- reduce_func(out, flat_index, values[k])
-
- return ir_builder.get()
+ broadcast_indices.insert(0, k_temp % broadcast_shape[i])
+ k_temp = k_temp // broadcast_shape[i]
+
+ flat_index = 0
+ stride = 1
+ for dim in range(len(shape) - 1, -1, -1):
+ # Get the index for this dimension using broadcasting
+ idx_shape = index_shapes[dim]
+ idx_ndim = len(idx_shape)
+
+ # Compute the linear index into this index tensor
+ idx_offset = 0
+ idx_stride = 1
+ for i in range(broadcast_ndim - 1, -1, -1):
+ # Right-align the index shape with broadcast shape
+ dim_idx = idx_ndim - broadcast_ndim + i
+ if dim_idx >= 0:
+ dim_size = idx_shape[dim_idx]
+ # Use broadcasting: if size is 1, use index 0
+ # otherwise use broadcast_indices[i]
+ if utils.equal_const_int(dim_size, 1):
+ idx_in_dim = 0
+ else:
+ idx_in_dim = broadcast_indices[i]
+ idx_offset += idx_in_dim * idx_stride
+ idx_stride *= dim_size
+
+ idx_val = indices[dim][idx_offset]
+ shifted_idx = idx_val + (idx_val < 0) * shape[dim]
+ flat_index += shifted_idx * stride
+ stride *= shape[dim]
+
+ reduce_func(out, flat_index, values[k])
+
+ return tir.SeqStmt([ib1.get(), ib2.get()])
Review Comment:

Similar to other files in this PR, you can use `T.seq_scope()` to combine
the two consecutive loops within a single `IRBuilder` context, instead of
creating two `IRBuilder` instances and combining them with `tir.SeqStmt`.
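A compact skeleton, under the same assumption about `T.seq_scope()`:

```python
with IRBuilder() as ib:
    with T.seq_scope():
        with T.parallel(0, full_range) as i:
            out[i] = data[i]
        with T.parallel(0, index_len) as k:
            ...  # same broadcast-index decomposition and reduce_func call as above
return ib.get()
```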
##########
python/tvm/topi/scatter_elements.py:
##########
@@ -95,35 +97,33 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr, reduce_func):
# pylint: disable=invalid-name
- ib = tir.ir_builder.create()
-
- data = ib.buffer_ptr(data_ptr)
- indices = ib.buffer_ptr(indices_ptr)
- updates = ib.buffer_ptr(updates_ptr)
- out = ib.buffer_ptr(out_ptr)
+ data = T.buffer_proxy(data_ptr)
+ indices = T.buffer_proxy(indices_ptr)
+ updates = T.buffer_proxy(updates_ptr)
+ out = T.buffer_proxy(out_ptr)
# Copy initial input data to output
- with ib.for_range(0, full_range, "i", kind="parallel") as i:
- out[i] = data[i]
-
- with ib.for_range(
- 0, ind_before_axis_range * ind_after_axis_range, "fused",
kind="parallel"
- ) as fused:
- i = fused // ind_after_axis_range
- j = fused % ind_after_axis_range
- pre_index1 = i * ind_before_axis_stride + j
- pre_index2 = i * before_axis_stride + j
- with ib.for_range(0, ind_axis_range, "k") as k:
- # Offset along indices or updates
- index1 = pre_index1 + k * ind_after_axis_range
- # Get index and shift to positive side if need
- k_new = indices[index1]
- shifted_index = k_new + (k_new < 0) * axis_range
- # Offset along data
- index2 = pre_index2 + shifted_index * after_axis_range
- reduce_func(out, index2, updates[index1])
-
- return ib.get()
+ with IRBuilder() as ib1:
+ with T.parallel(0, full_range) as i:
+ out[i] = data[i]
+
+ with IRBuilder() as ib2:
+ with T.parallel(0, ind_before_axis_range * ind_after_axis_range) as fused:
+ i = fused // ind_after_axis_range
+ j = fused % ind_after_axis_range
+ pre_index1 = i * ind_before_axis_stride + j
+ pre_index2 = i * before_axis_stride + j
+ with T.serial(0, ind_axis_range) as k:
+ # Offset along indices or updates
+ index1 = pre_index1 + k * ind_after_axis_range
+ # Get index and shift to positive side if need
+ k_new = indices[index1]
+ shifted_index = k_new + (k_new < 0) * axis_range
+ # Offset along data
+ index2 = pre_index2 + shifted_index * after_axis_range
+ reduce_func(out, index2, updates[index1])
+
+ return tir.SeqStmt([ib1.get(), ib2.get()])
Review Comment:

This is another place where `T.seq_scope()` could be used to simplify the
code by combining the two `IRBuilder` blocks into one.