tkonolige commented on a change in pull request #7927:
URL: https://github.com/apache/tvm/pull/7927#discussion_r620690363
##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -723,7 +723,7 @@ def update_func(dst_ptr, dst_index, update):
return out
-def scatter_nd(data, indices, shape):
+def scatter_nd(data, indices, updates, mode):
"""Scatter elements from a n-dimension array.
Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with
shape
Review comment:
Can you update this docstring to match the new signature (the `updates` and `mode` parameters replace `shape`)?
##########
File path: python/tvm/topi/x86/scatter.py
##########
@@ -46,62 +46,69 @@ def scatter_nd(data, indices, shape):
indices : tvm.te.Tensor
The indices of the values to extract.
- shape : Sequence[int]
- The output shape. This must be specified because it cannot be inferred.
+ updates : tvm.te.Tensor
+ The updates to apply at the Indices
+
+ mode : string
+ The update mode for the algorith, either "update" or "add"
Review comment:
```suggestion
The update mode for the algorithm, either "update" or "add"
```
##########
File path: python/tvm/topi/scatter.py
##########
@@ -248,29 +248,31 @@ def scatter_nd(data, indices, shape):
indices : tvm.te.Tensor
The indices of the values to extract.
- shape : Sequence[int]
- The output shape. This must be specified because it cannot be inferred.
+ updates : tvm.te.Tensor
+ The updates to apply at the Indices
+
+ mode : string
+ The update mode for the algorith, either "update" or "add"
Review comment:
What do the different modes do?
##########
File path: python/tvm/topi/cuda/scatter.py
##########
@@ -789,38 +790,41 @@ def gen_ir(data_ptr, indices_ptr, out_ptr):
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
- tdim = min(max_threads, fused_data_dimension)
+ tdim = min(max_threads, fused_updates_dimension)
ib.scope_attr(tx, "thread_extent", tdim)
- bdim = ceil_div(fused_data_dimension, tdim)
+ bdim = ceil_div(fused_updates_dimension, tdim)
ib.scope_attr(bx, "thread_extent", bdim)
- # zero data
- # TODO(tkonolige): could we use topi.full to zero it instead?
with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
- index = i * fused_data_dimension + bx * tdim + tx
+ index = i * fused_updates_dimension + bx * tdim + tx
with ib.if_scope(index < fused_shape):
- out[index] = tvm.tir.Cast(data_ptr.dtype, 0)
+ out[index] = data[index]
with ib.for_range(0, fused_indices_dimension) as i:
j = bx * tdim + tx
- with ib.if_scope(j < fused_data_dimension):
- offset = fused_data_dimension
+ with ib.if_scope(j < fused_updates_dimension):
+ offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into
out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1,
y_0, .. y_{K-1}] part
# of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i * l * fused_indices_dimension] = indices[l,
y_0, ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
- offset *= shape[l]
- out[index] += data[i * fused_data_dimension + j]
+ offset *= data_ptr.shape[l]
+ if mode == "update":
+ out[index] = updates[i * fused_updates_dimension + j]
+ elif mode == "add":
+ out[index] += updates[i * fused_updates_dimension + j]
+ else:
+ raise NotImplementedError("scatter_nd mode not
supported:", mode)
Review comment:
Please add the supported modes to the error message.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org