This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6097df5307 [ONNX][TORCH] Replace scatter op by scatter_elements
(#14019)
6097df5307 is described below
commit 6097df5307dc3f8254ae685c1bdbfe65f8934670
Author: Valery Chernov <[email protected]>
AuthorDate: Tue Feb 28 04:23:08 2023 +0400
[ONNX][TORCH] Replace scatter op by scatter_elements (#14019)
* remove scatter attr class
* update pytorch: scatter was replaced by scatter_elements
* remove scatter compute and strategy registration
* remove scatter attrs registration
* update onnx front-end: replace _op.scatter by _op.scatter_elements and add input checks
* update oneflow front-end
* update paddlepaddle front-end
* update pytorch utils
* remove front-end scatter definition
* fix scatter strategy for rocm
* small update
* remove scatter definition in back-end
* remove scatter strategy for cuda and gpu; transfer the rank-1 special case (scatter_via_sort) to scatter_elements
* fix test
* small fix
* align the PyTorch scatter converter with the torch.scatter description
* final update of scatter in the PyTorch front-end
* fix reduction attribute in cuda strategy
* use a scalar instead of a tensor in the test; update the check for dynamic dimensions
* skip the scalar-source check in the scatter tests due to an issue on the torch side
* remove scatter op implementation from topi/cuda
* remove scatter op implementation from topi; small code cleanup
---------
Co-authored-by: Valery Chernov <[email protected]>
---
include/tvm/relay/attrs/transform.h | 8 -
python/tvm/relay/frontend/oneflow.py | 2 +-
python/tvm/relay/frontend/onnx.py | 40 ++-
python/tvm/relay/frontend/paddlepaddle.py | 10 +-
python/tvm/relay/frontend/pytorch.py | 59 +++-
python/tvm/relay/frontend/pytorch_utils.py | 2 +-
python/tvm/relay/op/_transform.py | 10 -
python/tvm/relay/op/op_attrs.py | 5 -
python/tvm/relay/op/strategy/cuda.py | 30 +-
python/tvm/relay/op/strategy/generic.py | 22 +-
python/tvm/relay/op/strategy/rocm.py | 14 +-
python/tvm/relay/op/transform.py | 25 --
python/tvm/relay/transform/mixed_precision.py | 2 +
python/tvm/topi/cuda/scatter.py | 440 +-------------------------
python/tvm/topi/generic/search.py | 16 -
python/tvm/topi/scatter.py | 183 +----------
src/relay/op/tensor/transform.cc | 49 ---
tests/python/frontend/pytorch/test_forward.py | 7 +-
tests/python/relay/test_op_level3.py | 4 +-
19 files changed, 128 insertions(+), 800 deletions(-)
diff --git a/include/tvm/relay/attrs/transform.h
b/include/tvm/relay/attrs/transform.h
index 51378c8697..e1da1895d0 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -148,14 +148,6 @@ struct ReshapeLikeAttrs : public
tvm::AttrsNode<ReshapeLikeAttrs> {
}
}; // struct ReshapeLikeAttrs
-struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
- Integer axis;
-
- TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") {
- TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to
select values.");
- }
-};
-
struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
String reduction;
diff --git a/python/tvm/relay/frontend/oneflow.py
b/python/tvm/relay/frontend/oneflow.py
index ff4b5a5bcc..1aba9e6419 100644
--- a/python/tvm/relay/frontend/oneflow.py
+++ b/python/tvm/relay/frontend/oneflow.py
@@ -1227,7 +1227,7 @@ class Scatter(OneFlowOpConverter):
@classmethod
def _impl_v1(cls, inputs, attrs, params):
axis = attrs.get("axis", 0)
- return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
+ return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis)
class Unsqueeze(OneFlowOpConverter):
diff --git a/python/tvm/relay/frontend/onnx.py
b/python/tvm/relay/frontend/onnx.py
index 2a18906272..7c55bfefb7 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1967,7 +1967,7 @@ class MaxUnpool(OnnxOpConverter):
# Create a tensor of zeros then scatter our data through it.
zeros_tensor = _op.zeros(total_output_shape, data_type)
# We need to flatten all our tensors before scattering.
- flat_tensor = _op.scatter(
+ flat_tensor = _op.scatter_elements(
_op.reshape(zeros_tensor, [-1]),
_op.reshape(indices, [-1]),
_op.reshape(data, [-1]),
@@ -2734,15 +2734,15 @@ class Slice(OnnxOpConverter):
# Update the starts and ends according to axes if required.
if axes is not None:
data_shape = shape_of(inputs[0],
dtype=infer_type(ends).checked_type.dtype)
- starts = _op.scatter(
+ starts = _op.scatter_elements(
_op.const([0] * data_rank,
dtype=infer_type(starts).checked_type.dtype),
axes,
starts,
axis=0,
)
- ends = _op.scatter(data_shape, axes, ends, axis=0)
+ ends = _op.scatter_elements(data_shape, axes, ends, axis=0)
if steps is not None:
- steps = _op.scatter(
+ steps = _op.scatter_elements(
_op.const([1] * data_rank,
dtype=infer_type(steps).checked_type.dtype),
axes,
steps,
@@ -2848,9 +2848,35 @@ class Scatter(OnnxOpConverter):
"""Operator converter for Scatter."""
@classmethod
- def _impl_v9(cls, inputs, attr, params):
+ def _args_check(cls, inputs, attr):
+ assert len(inputs) == 3, "Scatter takes 3 inputs (data, indices,
updates), {} given".format(
+ len(inputs)
+ )
+ assert infer_type(inputs[1]).checked_type.dtype in ["int32", "int64"]
+
+ data_rank = len(infer_shape(inputs[0]))
+ assert data_rank > 0, "Data rank higher than 0 is expected"
+ indices_shape = infer_shape(inputs[1])
+ indices_rank = len(indices_shape)
+ assert indices_rank == data_rank, "Indices rank is not the same as
data one"
+ updates_shape = infer_shape(inputs[2])
+ updates_rank = len(updates_shape)
+ assert updates_rank == data_rank, "Updates rank is not the same as
data one"
+
+ for i in range(data_rank):
+ assert (
+ indices_shape[i] == updates_shape[i]
+ ), "Indices dimension size should be the same as updates one"
+
axis = attr.get("axis", 0)
- return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
+ assert -data_rank <= axis < data_rank, "Axis is out of bounds"
+
+ return axis
+
+ @classmethod
+ def _impl_v9(cls, inputs, attr, params):
+ axis = cls._args_check(inputs, attr)
+ return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis)
class ScatterElements(OnnxOpConverter):
@@ -4991,7 +5017,7 @@ class ATen(OnnxOpConverter):
else:
mode = "add"
index_tensor = _op.stack(indices, axis=0)
- return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)
+ return _op.scatter_nd(in_tensor, index_tensor, values, mode)
@classmethod
def _reshape(cls, inputs, attr, params):
diff --git a/python/tvm/relay/frontend/paddlepaddle.py
b/python/tvm/relay/frontend/paddlepaddle.py
index e688369a07..78895e4b49 100755
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -1741,10 +1741,10 @@ def convert_scatter(g, op, block):
index = _op.transform.broadcast_to(index, shape)
if overwrite:
- out = _op.scatter(x, index, updates, axis=0)
+ out = _op.scatter_elements(x, index, updates, axis=0)
else:
out = _op.scatter_elements(_op.zeros_like(x), index, updates, axis=0,
reduction="add")
- out += _op.scatter(x, index, _op.zeros_like(updates), axis=0)
+ out += _op.scatter_elements(x, index, _op.zeros_like(updates), axis=0)
g.add_node(op.output("Out")[0], out)
@@ -1826,7 +1826,7 @@ def convert_slice(g, op, block):
if len(axes) < dims:
if isinstance(starts, _expr.Expr):
- starts = _op.scatter(
+ starts = _op.scatter_elements(
_op.const([0] * dims,
dtype=infer_type(starts).checked_type.dtype),
indices,
starts,
@@ -1857,7 +1857,7 @@ def convert_slice(g, op, block):
if len(axes) < dims:
if isinstance(ends, _expr.Expr):
- ends = _op.scatter(
+ ends = _op.scatter_elements(
_expr.const(
np.array([np.iinfo(np.int32).max] * dims),
dtype=infer_type(ends).checked_type.dtype,
@@ -1892,7 +1892,7 @@ def convert_slice(g, op, block):
if len(axes) < dims:
if isinstance(strides, _expr.Expr):
- strides = _op.scatter(
+ strides = _op.scatter_elements(
_expr.const(
np.array([1] * dims),
dtype=infer_type(strides).checked_type.dtype,
diff --git a/python/tvm/relay/frontend/pytorch.py
b/python/tvm/relay/frontend/pytorch.py
index 0dc9ffef6f..3cdfc5cb4e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -496,7 +496,7 @@ class PyTorchOpConverter:
end[dim] = target_end
else:
target_end = _expr.const(target_end)
- end = _op.scatter(
+ end = _op.scatter_elements(
end,
_op.expand_dims(_expr.const(dim), axis=0),
_op.expand_dims(target_end, axis=0),
@@ -508,7 +508,7 @@ class PyTorchOpConverter:
ttype = self.infer_type(target_end).dtype
if str(ttype) != axis_dtype:
target_end = _op.cast(target_end, axis_dtype)
- end = _op.scatter(
+ end = _op.scatter_elements(
end,
_op.expand_dims(_expr.const(dim), axis=0),
_op.expand_dims(target_end, axis=0),
@@ -2554,11 +2554,62 @@ class PyTorchOpConverter:
return self.nonzero(inputs, input_types, is_numpy_style=False)
def scatter(self, inputs, input_types):
+ assert len(inputs) == 4 or len(inputs) == 5, (
+ "scatter takes 4 or 5 inputs: data, dim, index, src, reduce
(optional), "
+ + "but {} given".format(len(inputs))
+ )
data = inputs[0]
axis = int(inputs[1])
index = inputs[2]
src = inputs[3]
- return _op.transform.scatter(data, index, src, axis)
+ if len(inputs) == 5:
+ reduce = inputs[4]
+ else:
+ reduce = "update"
+
+ data_shape = self.infer_shape(data)
+ data_rank = len(data_shape)
+ index_shape = self.infer_shape(index)
+ index_rank = len(index_shape)
+ # When index is empty, the operation returns data unchanged
+ if self.is_empty_shape(index_shape):
+ return data
+
+ if np.isscalar(src):
+ assert self.infer_type(src).dtype == "float", "Scalar source can
be float only"
+ src = _op.broadcast_to_like(src, data_shape)
+ src_shape = data_shape
+ else:
+ src_shape = self.infer_shape(src)
+ src_rank = len(src_shape)
+ assert data_rank == index_rank, "Index rank is not the same as data
rank"
+ assert data_rank == src_rank, "Src rank is not the same as data rank"
+
+ assert 0 <= axis < data_rank, "Dim is out of bounds"
+
+ for i in range(data_rank):
+ index_dim = index_shape[i]
+ src_dim = src_shape[i]
+ data_dim = data_shape[i]
+ # Skip check for dynamic dimensions
+ if not any([isinstance(index_dim, tvm.tir.Any),
isinstance(src_dim, tvm.tir.Any)]):
+ assert index_dim <= src_dim, "Index dim size should be less
than src one"
+ if i != axis and not any(
+ [isinstance(index_dim, tvm.tir.Any), isinstance(data_dim,
tvm.tir.Any)]
+ ):
+ assert index_dim <= data_dim, "Index dim size should be less
than data one"
+
+ if reduce is None:
+ reduce = "update"
+ elif reduce == "multiply":
+ reduce = "mul"
+ assert reduce in [
+ "update",
+ "add",
+ "mul",
+ ], 'reduce arg is expected from "add", "multiply" or None'
+
+ return _op.scatter_elements(data, index, src, axis, reduce)
def index_put(self, inputs, input_types):
in_tensor = inputs[0]
@@ -2571,7 +2622,7 @@ class PyTorchOpConverter:
mode = "add"
# Combine array of index tensors into one index tensor with shape (N,_)
index_tensor = _op.stack(indices, axis=0)
- return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)
+ return _op.scatter_nd(in_tensor, index_tensor, values, mode)
def scalar_tensor(self, inputs, input_types):
data = inputs[0]
diff --git a/python/tvm/relay/frontend/pytorch_utils.py
b/python/tvm/relay/frontend/pytorch_utils.py
index da4c9e039e..7de1248bda 100644
--- a/python/tvm/relay/frontend/pytorch_utils.py
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -331,7 +331,7 @@ def scatter_roi_align_result_pattern(levels,
roi_align_results, num_scales):
scatter_indices = is_op("repeat")(scatter_indices)
scatter_indices = is_op("repeat")(scatter_indices)
- scatter_res = is_op("scatter")(scatter_res, scatter_indices,
roi_align_results[i])
+ scatter_res = is_op("scatter_elements")(scatter_res, scatter_indices,
roi_align_results[i])
return is_op("reshape")(scatter_res)
diff --git a/python/tvm/relay/op/_transform.py
b/python/tvm/relay/op/_transform.py
index 12450dc809..11a10f7ee4 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -104,15 +104,6 @@ _reg.register_pattern("meta_schedule_layout_transform",
OpPattern.INJECTIVE)
# argwhere
_reg.register_strategy("argwhere", strategy.argwhere_strategy)
-# scatter
-@_reg.register_compute("scatter")
-def compute_scatter(attrs, inputs, output_type):
- """Compute definition of scatter"""
- return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]
-
-
-_reg.register_strategy("scatter", strategy.scatter_strategy)
-
# sparse_fill_empty_rows
@_reg.register_compute("sparse_fill_empty_rows")
def compute_sparse_fill_empty_rows(attrs, inputs, output_type):
@@ -677,7 +668,6 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
return ValueError("Does not support rank higher than 5 in argwhere")
-_reg.register_shape_func("scatter", False, elemwise_shape_func)
_reg.register_shape_func("scatter_elements", False, elemwise_shape_func)
_reg.register_shape_func("scatter_nd", False, elemwise_shape_func)
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 0214ae8a46..4e9a9a4707 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -529,11 +529,6 @@ class RequantizeAttrs(Attrs):
"""Attributes used in requantize operators"""
-@tvm._ffi.register_object("relay.attrs.ScatterAttrs")
-class ScatterAttrs(Attrs):
- """Attributes used in scatter operators"""
-
-
@tvm._ffi.register_object("relay.attrs.SequenceMaskAttrs")
class SequenceMaskAttrs(Attrs):
"""Attributes used in sequence_mask operators"""
diff --git a/python/tvm/relay/op/strategy/cuda.py
b/python/tvm/relay/op/strategy/cuda.py
index e0229a615d..c6ea692a8d 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1062,23 +1062,23 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs,
out_type, target):
return strategy
-@scatter_strategy.register(["cuda", "gpu"])
-def scatter_cuda(attrs, inputs, out_type, target):
- """scatter cuda strategy"""
+@scatter_elements_strategy.register(["cuda", "gpu"])
+def scatter_elements_cuda(attrs, inputs, out_type, target):
+ """scatter elements cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
- wrap_compute_scatter(topi.cuda.scatter),
- wrap_topi_schedule(topi.cuda.schedule_scatter),
- name="scatter.cuda",
+ wrap_compute_scatter_elements(topi.cuda.scatter_elements),
+ wrap_topi_schedule(topi.cuda.schedule_extern),
+ name="scatter_elements.cuda",
plevel=10,
)
rank = len(inputs[0].shape)
- with SpecializedCondition(rank == 1):
+ with SpecializedCondition(rank == 1 and attrs.reduction == "update"):
if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
strategy.add_implementation(
- wrap_compute_scatter(topi.cuda.scatter_via_sort),
+ wrap_compute_scatter_elements(topi.cuda.scatter_via_sort),
wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
name="scatter_via_sort.cuda",
plevel=9, # use the sequential version by default
@@ -1086,20 +1086,6 @@ def scatter_cuda(attrs, inputs, out_type, target):
return strategy
-@scatter_elements_strategy.register(["cuda", "gpu"])
-def scatter_elements_cuda(attrs, inputs, out_type, target):
- """scatter elements cuda strategy"""
- strategy = _op.OpStrategy()
- strategy.add_implementation(
- wrap_compute_scatter_elements(topi.cuda.scatter_elements),
- wrap_topi_schedule(topi.cuda.schedule_extern),
- name="scatter_elements.cuda",
- plevel=10,
- )
- # TODO(vvchernov): There is possible specification for rank=1 as for
scatter
- return strategy
-
-
@scatter_nd_strategy.register(["cuda", "gpu"])
def scatter_nd_cuda(attrs, inputs, out_type, target):
"""scatter_nd cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py
b/python/tvm/relay/op/strategy/generic.py
index 4641fb18f7..b08d92a3cc 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1548,27 +1548,6 @@ def proposal_strategy(attrs, inputs, out_type, target):
return strategy
-# scatter
-@override_native_generic_func("scatter_strategy")
-def scatter_strategy(attrs, outs, out_type, target):
- strategy = _op.OpStrategy()
- strategy.add_implementation(
- wrap_compute_scatter(topi.scatter),
- wrap_topi_schedule(topi.generic.schedule_scatter),
- name="scatter.generic",
- )
- return strategy
-
-
-def wrap_compute_scatter(topi_compute):
- """Wrap scatter topi compute"""
-
- def _compute_scatter(attrs, inputs, _):
- return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)]
-
- return _compute_scatter
-
-
# scatter_elements
@override_native_generic_func("scatter_elements_strategy")
def scatter_elements_strategy(attrs, inputs, out_type, target):
@@ -1579,6 +1558,7 @@ def scatter_elements_strategy(attrs, inputs, out_type,
target):
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_elements.generic",
)
+ # TODO(vvchernov): implement specialized case (rank=1,
reduction="update"), see cuda strategy
return strategy
diff --git a/python/tvm/relay/op/strategy/rocm.py
b/python/tvm/relay/op/strategy/rocm.py
index 89cac0db4a..d80f347975 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -105,23 +105,23 @@ def argsort_strategy_cuda(attrs, inputs, out_type,
target):
return strategy
-@scatter_strategy.register(["rocm"])
-def scatter_cuda(attrs, inputs, out_type, target):
+@scatter_elements_strategy.register(["rocm"])
+def scatter_elements_cuda(attrs, inputs, out_type, target):
"""scatter rocm strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
- wrap_compute_scatter(topi.cuda.scatter),
- wrap_topi_schedule(topi.cuda.schedule_scatter),
- name="scatter.rocm",
+ wrap_compute_scatter_elements(topi.cuda.scatter_elements),
+ wrap_topi_schedule(topi.cuda.schedule_extern),
+ name="scatter_elements.rocm",
plevel=10,
)
rank = len(inputs[0].shape)
- with SpecializedCondition(rank == 1):
+ with SpecializedCondition(rank == 1 and attrs.reduction == "update"):
if can_use_rocthrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
strategy.add_implementation(
- wrap_compute_scatter(topi.cuda.scatter_via_sort),
+ wrap_compute_scatter_elements(topi.cuda.scatter_via_sort),
wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
name="scatter_via_sort.rocm",
plevel=9, # use the sequential version by default
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 6718347b31..f2d066d17e 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -353,31 +353,6 @@ def argwhere(condition):
return _make.argwhere(condition)
-def scatter(data, indices, updates, axis):
- """Update data at positions defined by indices with values in updates.
-
- Parameters
- ----------
- data : relay.Expr
- The input data to the operator.
-
- indices : relay.Expr
- The index locations to update.
-
- updates : relay.Expr
- The values to update.
-
- axis : int
- The axis to scatter on.
-
- Returns
- -------
- ret : relay.Expr
- The computed result.
- """
- return _make.scatter(data, indices, updates, axis)
-
-
def scatter_elements(data, indices, updates, axis=0, reduction="update"):
"""Scatter elements with updating data by reduction of values in updates
at positions defined by indices.
diff --git a/python/tvm/relay/transform/mixed_precision.py
b/python/tvm/relay/transform/mixed_precision.py
index 5018ba9ba9..f6bb8b8150 100644
--- a/python/tvm/relay/transform/mixed_precision.py
+++ b/python/tvm/relay/transform/mixed_precision.py
@@ -63,6 +63,8 @@ DEFAULT_FOLLOW_LIST = [
"tile",
"dyn.tile",
"scatter",
+ "scatter_elements",
+ "scatter_nd",
"full",
"dyn.full",
"nn.depth_to_space",
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index c88c3086f3..39ef5a5a42 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -14,446 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, no-member, too-many-locals,
too-many-arguments, too-many-statements, singleton-comparison, unused-argument
-"""Scatter operator """
+# pylint: disable=invalid-name
+"""Scatter operators"""
import tvm
from tvm import te, tir, autotvm
from ..scatter import _verify_scatter_nd_inputs
from ..generic import schedule_extern
from .nms import atomic_add
from .sort import stable_sort_by_key_thrust
-from ..utils import prod, ceil_div
-
-
-def _memcpy_ir(ib, out_ptr, data_ptr, shape):
- fused = prod(shape)
- with ib.new_scope():
- num_thread =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
- num_blocks = ceil_div(fused, num_thread)
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(bx, "thread_extent", num_blocks)
- tx = te.thread_axis("threadIdx.x")
- ib.scope_attr(tx, "thread_extent", num_thread)
- tid = bx * num_thread + tx
-
- with ib.if_scope(tid < fused):
- out_ptr[tid] = data_ptr[tid]
-
-
-def gen_ir_1d(data, indices, updates, axis, out, update_func):
- """Generate scatter ir for 1d inputs
-
- Parameters
- ----------
- data : tir.Tensor
- The input data to the operator.
-
- indices : tir.Tensor
- The index locations to update.
-
- updates : tir.Tensor
- The values to update.
-
- axis : int
- The axis to scatter on
-
- out : tir.Tensor
- The output tensor.
-
- update_func: function
- The function to be applied to a destination and the corresponding
update.
-
- Returns
- -------
- ret : tir
- The computational ir.
- """
- assert axis == 0
- n = data.shape[0]
-
- ib = tvm.tir.ir_builder.create()
-
- out_ptr = ib.buffer_ptr(out)
- data_ptr = ib.buffer_ptr(data)
-
- _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
-
- indices_ptr = ib.buffer_ptr(indices)
- updates_ptr = ib.buffer_ptr(updates)
-
- ni = indices.shape[0]
-
- with ib.new_scope():
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(bx, "thread_extent", 1)
- with ib.for_range(0, ni, name="i") as i:
- index = indices_ptr[i]
- with ib.if_scope(index < 0):
- update_func(out_ptr, index + n, updates_ptr[i])
- with ib.else_scope():
- update_func(out_ptr, index, updates_ptr[i])
-
- return ib.get()
-
-
-def gen_ir_2d(data, indices, updates, axis, out, update_func):
- """Generate scatter ir for 2d inputs
-
- Parameters
- ----------
- data : tir.Tensor
- The input data to the operator.
-
- indices : tir.Tensor
- The index locations to update.
-
- updates : tir.Tensor
- The values to update.
-
- axis : int
- The axis to scatter on
-
- out : tir.Tensor
- The output tensor.
-
- update_func: function
- The function to be applied to a destination and the corresponding
update
-
- Returns
- -------
- ret : tir
- The computational ir.
- """
- n = data.shape[0]
- c = data.shape[1]
-
- ib = tvm.tir.ir_builder.create()
-
- out_ptr = ib.buffer_ptr(out)
- data_ptr = ib.buffer_ptr(data)
-
- _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
-
- indices_ptr = ib.buffer_ptr(indices)
- updates_ptr = ib.buffer_ptr(updates)
-
- ni = indices.shape[0]
- ci = indices.shape[1]
-
- if axis == 0:
- with ib.new_scope():
- j = te.thread_axis("blockIdx.x")
- ib.scope_attr(j, "thread_extent", ci)
- with ib.for_range(0, ni, name="i") as i:
- idx = i * ci + j
- index = indices_ptr[idx]
- with ib.if_scope(index < 0):
- update_func(out_ptr, (index + n) * c + j, updates_ptr[idx])
- with ib.else_scope():
- update_func(out_ptr, index * c + j, updates_ptr[idx])
- else:
- with ib.new_scope():
- i = te.thread_axis("blockIdx.x")
- ib.scope_attr(i, "thread_extent", ni)
- with ib.for_range(0, ci, name="j") as j:
- idx = i * ci + j
- index = indices_ptr[idx]
- with ib.if_scope(index < 0):
- update_func(out_ptr, i * c + (index + c), updates_ptr[idx])
- with ib.else_scope():
- update_func(out_ptr, i * c + index, updates_ptr[idx])
- return ib.get()
-
-
-def gen_ir_3d(data, indices, updates, axis, out, update_func):
- """Generate scatter ir for 3d inputs
-
- Parameters
- ----------
- data : tir.Tensor
- The input data to the operator.
-
- indices : tir.Tensor
- The index locations to update.
-
- updates : tir.Tensor
- The values to update.
-
- axis : int
- The axis to scatter on
-
- out : tir.Tensor
- The output tensor.
-
- update_func: function
- The function to be applied to a destination and the corresponding
update
-
- Returns
- -------
- ret : tir
- The computational ir.
- """
- warp_size = tvm.target.Target.current(False).thread_warp_size
-
- n = data.shape[0]
- c = data.shape[1]
- h = data.shape[2]
-
- ib = tvm.tir.ir_builder.create()
-
- out_ptr = ib.buffer_ptr(out)
- data_ptr = ib.buffer_ptr(data)
-
- _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
-
- indices_ptr = ib.buffer_ptr(indices)
- updates_ptr = ib.buffer_ptr(updates)
- ni = indices.shape[0]
- ci = indices.shape[1]
- hi = indices.shape[2]
-
- if axis == 0:
- with ib.new_scope():
- j = te.thread_axis("blockIdx.x")
- ib.scope_attr(j, "thread_extent", ci)
- tx = te.thread_axis("threadIdx.x")
- ib.scope_attr(tx, "thread_extent", warp_size)
- with ib.for_range(0, ni, name="i") as i:
- with ib.for_range(0, ceil_div(hi, warp_size), name="k") as k_:
- k = k_ * warp_size + tx
- with ib.if_scope(k < hi):
- idx = (i * ci + j) * hi + k
- index = indices_ptr[idx]
- with ib.if_scope(index < 0):
- update_func(out_ptr, ((index + n) * c + j) * h +
k, updates_ptr[idx])
- with ib.else_scope():
- update_func(out_ptr, (index * c + j) * h + k,
updates_ptr[idx])
- elif axis == 1:
- with ib.new_scope():
- i = te.thread_axis("blockIdx.x")
- ib.scope_attr(i, "thread_extent", ni)
- tx = te.thread_axis("threadIdx.x")
- ib.scope_attr(tx, "thread_extent", warp_size)
- with ib.for_range(0, ci, name="j") as j:
- with ib.for_range(0, ceil_div(hi, warp_size), name="k") as k_:
- k = k_ * warp_size + tx
- with ib.if_scope(k < hi):
- idx = (i * ci + j) * hi + k
- index = indices_ptr[idx]
- with ib.if_scope(index < 0):
- update_func(out_ptr, (i * c + (index + c)) * h +
k, updates_ptr[idx])
- with ib.else_scope():
- update_func(out_ptr, (i * c + index) * h + k,
updates_ptr[idx])
- else:
- with ib.new_scope():
- i = te.thread_axis("blockIdx.x")
- ib.scope_attr(i, "thread_extent", ni)
- j = te.thread_axis("blockIdx.y")
- ib.scope_attr(j, "thread_extent", ci)
- with ib.for_range(0, hi, name="k") as k:
- idx = (i * ci + j) * hi + k
- index = indices_ptr[idx]
- with ib.if_scope(index < 0):
- update_func(out_ptr, (i * c + j) * h + (index + h),
updates_ptr[idx])
- with ib.else_scope():
- update_func(out_ptr, (i * c + j) * h + index,
updates_ptr[idx])
- return ib.get()
-
-
-def gen_ir_4d(data, indices, updates, axis, out, update_func):
- """Generate scatter ir for 4d inputs
-
- Parameters
- ----------
- data : tir.Tensor
- The input data to the operator.
-
- indices : tir.Tensor
- The index locations to update.
-
- updates : tir.Tensor
- The values to update.
-
- axis : int
- The axis to scatter on
-
- out : tir.Tensor
- The output tensor.
-
- update_func: function
- The function to be applied to a destination and the corresponding
update
-
- Returns
- -------
- ret : tir
- The computational ir.
- """
- warp_size = tvm.target.Target.current(False).thread_warp_size
-
- n = data.shape[0]
- c = data.shape[1]
- h = data.shape[2]
- w = data.shape[3]
-
- ib = tvm.tir.ir_builder.create()
-
- out_ptr = ib.buffer_ptr(out)
- data_ptr = ib.buffer_ptr(data)
- _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
-
- indices_ptr = ib.buffer_ptr(indices)
- updates_ptr = ib.buffer_ptr(updates)
- ni = indices.shape[0]
- ci = indices.shape[1]
- hi = indices.shape[2]
- wi = indices.shape[3]
-
- if axis == 0:
- with ib.new_scope():
- j = te.thread_axis("blockIdx.y")
- ib.scope_attr(j, "thread_extent", ci)
- k = te.thread_axis("blockIdx.z")
- ib.scope_attr(k, "thread_extent", hi)
- tx = te.thread_axis("threadIdx.x")
- ib.scope_attr(tx, "thread_extent", warp_size)
- with ib.for_range(0, ni, name="i") as i:
- with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_:
- l = l_ * warp_size + tx
- with ib.if_scope(l < wi):
- idx = ((i * ci + j) * hi + k) * wi + l
- index = indices_ptr[idx]
- with ib.if_scope(index < 0):
- update_func(
- out_ptr, (((index + n) * c + j) * h + k) * w +
l, updates_ptr[idx]
- )
- with ib.else_scope():
- update_func(
- out_ptr, ((index * c + j) * h + k) * w + l,
updates_ptr[idx]
- )
- elif axis == 1:
- with ib.new_scope():
- i = te.thread_axis("blockIdx.x")
- ib.scope_attr(i, "thread_extent", ni)
- k = te.thread_axis("blockIdx.z")
- ib.scope_attr(k, "thread_extent", hi)
- tx = te.thread_axis("threadIdx.x")
- ib.scope_attr(tx, "thread_extent", warp_size)
- with ib.for_range(0, ci, name="j") as j:
- with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_:
- l = l_ * warp_size + tx
- with ib.if_scope(l < wi):
- idx = ((i * ci + j) * hi + k) * wi + l
- index = indices_ptr[idx]
- with ib.if_scope(index < 0):
- update_func(
- out_ptr, ((i * c + (index + c)) * h + k) * w +
l, updates_ptr[idx]
- )
- with ib.else_scope():
- update_func(
- out_ptr, ((i * c + index) * h + k) * w + l,
updates_ptr[idx]
- )
- elif axis == 2:
- with ib.new_scope():
- i = te.thread_axis("blockIdx.x")
- ib.scope_attr(i, "thread_extent", ni)
- j = te.thread_axis("blockIdx.y")
- ib.scope_attr(j, "thread_extent", ci)
- tx = te.thread_axis("threadIdx.x")
- ib.scope_attr(tx, "thread_extent", warp_size)
- with ib.for_range(0, hi, name="k") as k:
- with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_:
- l = l_ * warp_size + tx
- with ib.if_scope(l < wi):
- idx = ((i * ci + j) * hi + k) * wi + l
- index = indices_ptr[idx]
- with ib.if_scope(index < 0):
- update_func(
- out_ptr, ((i * c + j) * h + (index + h)) * w +
l, updates_ptr[idx]
- )
- with ib.else_scope():
- update_func(
- out_ptr, ((i * c + j) * h + index) * w + l,
updates_ptr[idx]
- )
- else:
- with ib.new_scope():
- i = te.thread_axis("blockIdx.x")
- ib.scope_attr(i, "thread_extent", ni)
- j = te.thread_axis("blockIdx.y")
- ib.scope_attr(j, "thread_extent", ci)
- k = te.thread_axis("blockIdx.z")
- ib.scope_attr(k, "thread_extent", hi)
- with ib.for_range(0, wi, name="l") as l:
- idx = ((i * ci + j) * hi + k) * wi + l
- index = indices_ptr[idx]
- with ib.if_scope(index < 0):
- update_func(out_ptr, ((i * c + j) * h + k) * w + (index +
w), updates_ptr[idx])
- with ib.else_scope():
- update_func(out_ptr, ((i * c + j) * h + k) * w + index,
updates_ptr[idx])
- return ib.get()
-
-
[email protected]_topi_compute("scatter.cuda")
-def scatter(cfg, data, indices, updates, axis=0):
- """Update data at positions defined by indices with values in updates
-
- Parameters
- ----------
- data : relay.Expr
- The input data to the operator.
-
- indices : relay.Expr
- The index locations to update.
-
- updates : relay.Expr
- The values to update.
-
- axis : int
- The axis to scatter on
-
- Returns
- -------
- ret : relay.Expr
- The computed result.
- """
- if axis < 0:
- axis += len(data.shape)
- assert axis >= 0
- assert axis < len(data.shape)
-
- rank = len(data.shape)
- assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions"
-
- ir_funcs = {
- 1: gen_ir_1d,
- 2: gen_ir_2d,
- 3: gen_ir_3d,
- 4: gen_ir_4d,
- }
-
- def update_func(dst_ptr, dst_index, update):
- dst_ptr[dst_index] = update
-
- out_shape = data.shape
- out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
-
- cfg.add_flop(1) # A dummy value to satisfy AutoTVM
-
- out = te.extern(
- [out_shape],
- [data, indices, updates],
- lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis,
outs[0], update_func),
- dtype=data.dtype,
- out_buffers=[out_buf],
- name="scatter_gpu",
- tag="scatter_gpu",
- )
-
- return out
-
-
[email protected]_topi_schedule("scatter.cuda")
-def schedule_scatter(_, outs):
- return schedule_extern(outs)
+from ..utils import ceil_div
def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out):
@@ -540,7 +109,7 @@ def gen_scatter_1d_thrust(data, indices_sorted,
updates_sorted, out):
@autotvm.register_topi_compute("scatter_via_sort.cuda")
-def scatter_via_sort(cfg, data, indices, updates, axis=0):
+def scatter_via_sort(cfg, data, indices, updates, axis=0, reduction="add"):
"""Update data at positions defined by indices with values in updates
Parameters
@@ -562,6 +131,7 @@ def scatter_via_sort(cfg, data, indices, updates, axis=0):
ret : relay.Expr
The computed result.
"""
+ assert reduction == "add"
if axis < 0:
axis += len(data.shape)
assert axis == 0 and len(data.shape) == 1, "sorting based scatter only
supported for 1d input"
diff --git a/python/tvm/topi/generic/search.py
b/python/tvm/topi/generic/search.py
index 826194e75c..9a80e678c2 100644
--- a/python/tvm/topi/generic/search.py
+++ b/python/tvm/topi/generic/search.py
@@ -36,22 +36,6 @@ def schedule_argwhere(outs):
return _default_schedule(outs, False)
-def schedule_scatter(outs):
- """Schedule for scatter operator.
-
- Parameters
- ----------
- outs: Array of Tensor
- The computation graph description of scatter.
-
- Returns
- -------
- s: Schedule
- The computation schedule for the op.
- """
- return _default_schedule(outs, False)
-
-
def schedule_sparse_fill_empty_rows(outs):
return _default_schedule(outs, False)
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index 45629c005f..799b3d1673 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -14,191 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
-"""Scatter operator"""
+# pylint: disable=invalid-name
+"""ScatterND operator"""
from tvm import te, tir # hide redefinition of min and max
from tvm.tir import expr
[email protected]
-def _scatter_1d(data, indices, updates):
- out = output_tensor(data.shape, data.dtype)
- for i in range(data.shape[0]):
- out[i] = data[i]
- for i in range(indices.shape[0]):
- out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] =
updates[i]
- return out
-
-
[email protected]
-def _scatter_2d(data, indices, updates, axis):
- out = output_tensor(data.shape, data.dtype)
- for i in range(data.shape[0]):
- for j in range(data.shape[1]):
- out[i, j] = data[i, j]
- if axis == 0:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- out[
- indices[i, j] if indices[i, j] >= 0 else indices[i, j] +
data.shape[axis], j
- ] = updates[i, j]
- else:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- out[
- i, indices[i, j] if indices[i, j] >= 0 else indices[i, j]
+ data.shape[axis]
- ] = updates[i, j]
-
- return out
-
-
[email protected]
-def _scatter_3d(data, indices, updates, axis):
- out = output_tensor(data.shape, data.dtype)
- for i in range(data.shape[0]):
- for j in range(data.shape[1]):
- for k in range(data.shape[2]):
- out[i, j, k] = data[i, j, k]
- if axis == 0:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- for k in range(indices.shape[2]):
- out[
- indices[i, j, k]
- if indices[i, j, k] >= 0
- else indices[i, j, k] + data.shape[axis],
- j,
- k,
- ] = updates[i, j, k]
- elif axis == 1:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- for k in range(indices.shape[2]):
- out[
- i,
- indices[i, j, k]
- if indices[i, j, k] >= 0
- else indices[i, j, k] + data.shape[axis],
- k,
- ] = updates[i, j, k]
- else:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- for k in range(indices.shape[2]):
- out[
- i,
- j,
- indices[i, j, k]
- if indices[i, j, k] >= 0
- else indices[i, j, k] + data.shape[axis],
- ] = updates[i, j, k]
-
- return out
-
-
[email protected]
-def _scatter_4d(data, indices, updates, axis):
- out = output_tensor(data.shape, data.dtype)
- for i in range(data.shape[0]):
- for j in range(data.shape[1]):
- for k in range(data.shape[2]):
- for l in range(data.shape[3]):
- out[i, j, k, l] = data[i, j, k, l]
-
- if axis == 0:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- for k in range(indices.shape[2]):
- for l in range(indices.shape[3]):
- out[
- indices[i, j, k, l]
- if indices[i, j, k, l] >= 0
- else indices[i, j, k, l] + data.shape[axis],
- j,
- k,
- l,
- ] = updates[i, j, k, l]
- elif axis == 1:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- for k in range(indices.shape[2]):
- for l in range(indices.shape[3]):
- out[
- i,
- indices[i, j, k, l]
- if indices[i, j, k, l] >= 0
- else indices[i, j, k, l] + data.shape[axis],
- k,
- l,
- ] = updates[i, j, k, l]
- elif axis == 2:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- for k in range(indices.shape[2]):
- for l in range(indices.shape[3]):
- out[
- i,
- j,
- indices[i, j, k, l]
- if indices[i, j, k, l] >= 0
- else indices[i, j, k, l] + data.shape[axis],
- l,
- ] = updates[i, j, k, l]
- else:
- for i in range(indices.shape[0]):
- for j in range(indices.shape[1]):
- for k in range(indices.shape[2]):
- for l in range(indices.shape[3]):
- out[
- i,
- j,
- k,
- indices[i, j, k, l]
- if indices[i, j, k, l] >= 0
- else indices[i, j, k, l] + data.shape[axis],
- ] = updates[i, j, k, l]
-
- return out
-
-
-def scatter(data, indices, updates, axis=0):
- """Update data at positions defined by indices with values in updates
-
- Parameters
- ----------
- data : relay.Expr
- The input data to the operator.
-
- indices : relay.Expr
- The index locations to update.
-
- updates : relay.Expr
- The values to update.
-
- axis : int
- The axis to scatter on
-
- Returns
- -------
- ret : relay.Expr
- The computed result.
- """
- if axis < 0:
- axis += len(data.shape)
- assert axis >= 0
- assert axis < len(data.shape)
-
- if len(data.shape) == 1:
- return _scatter_1d(data, indices, updates)
- if len(data.shape) == 2:
- return _scatter_2d(data, indices, updates, axis)
- if len(data.shape) == 3:
- return _scatter_3d(data, indices, updates, axis)
- if len(data.shape) == 4:
- return _scatter_4d(data, indices, updates, axis)
- raise ValueError("scatter only support for 1-4 dimensions")
-
-
def _verify_scatter_nd_inputs(data, indices, updates):
mdim = int(indices.shape[0])
assert mdim <= len(data.shape), (
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 907141c9cb..1bae1a4d9a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1095,54 +1095,6 @@ non-zero)doc" TVM_ADD_FILELINE)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);
-// Scatter
-TVM_REGISTER_NODE_TYPE(ScatterAttrs);
-
-// Scatter
-bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
- const TypeReporter& reporter) {
- ICHECK_EQ(num_inputs, 3);
- ICHECK_EQ(types.size(), 4);
- auto data = types[0].as<TensorTypeNode>();
- if (data == nullptr) {
- return false;
- }
- auto indices = types[1].as<TensorTypeNode>();
- if (indices == nullptr) {
- return false;
- }
- auto updates = types[2].as<TensorTypeNode>();
- if (updates == nullptr) {
- return false;
- }
- ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
- << "indices of scatter must be tensor of integer";
- const auto param = attrs.as<ScatterAttrs>();
- ICHECK(param != nullptr);
- reporter->Assign(types[3], TensorType(data->shape, data->dtype));
- return true;
-}
-
-TVM_REGISTER_GLOBAL("relay.op._make.scatter")
- .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
- auto attrs = make_object<ScatterAttrs>();
- attrs->axis = std::move(axis);
- static const Op& op = Op::Get("scatter");
- return Call(op, {data, indices, updates}, Attrs(attrs), {});
- });
-
-RELAY_REGISTER_OP("scatter")
- .describe(
- R"doc(Update data at positions defined by indices with values in
updates)doc" TVM_ADD_FILELINE)
- .set_num_inputs(3)
- .add_argument("data", "Tensor", "The input data tensor.")
- .add_argument("indices", "Tensor", "The indices location tensor.")
- .add_argument("updates", "Tensor", "The values to update the input with.")
- .add_type_rel("Scatter", ScatterRel)
- .set_attr<TOpIsStateful>("TOpIsStateful", false)
- .set_attr<TOpPattern>("TOpPattern", kOpaque)
- .set_support_level(10);
-
// scatter_elements operator
TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
@@ -1168,7 +1120,6 @@ bool ScatterElementsRel(const Array<Type>& types, int
num_inputs, const Attrs& a
<< "ScatterElements: expect updates type to be TensorType but got " <<
types[2];
return false;
}
- // TODO(vvchernov): ONNX requires int32 and int64
ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
<< "ScatterElements: indices must be a tensor of integers.";
diff --git a/tests/python/frontend/pytorch/test_forward.py
b/tests/python/frontend/pytorch/test_forward.py
index 2401e98bce..807c44a364 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4232,12 +4232,17 @@ def test_forward_scatter():
verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src],
targets)
verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src],
targets)
- # Check empty indices for scatter_add
+ # Check empty indices
in_data = torch.zeros(2, 4)
in_index = torch.empty((0,))
in_src = torch.rand(2, 1)
+ verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src],
targets)
verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src],
targets)
+ # Check scalar source
+ # TODO(vvchernov): Scalar source is supported on TVM side, but torch
fails with
+ # input Tuple(Tensor, Tensor, float). What does scalar mean for torch in
this case?
+
def test_forward_scatter_reduce():
"""test_forward_scatter_reduce"""
diff --git a/tests/python/relay/test_op_level3.py
b/tests/python/relay/test_op_level3.py
index f18e935b57..493bf00fc6 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1002,7 +1002,7 @@ def test_scatter(target, dev, executor_kind):
d = relay.var("d", relay.TensorType(dshape, "float32"))
i = relay.var("i", relay.TensorType(ishape, indices_dtype))
u = relay.var("u", relay.TensorType(ishape, "float32"))
- z = relay.op.scatter(d, i, u, axis)
+ z = relay.op.scatter_elements(d, i, u, axis)
func = relay.Function([d, i, u], z)
@@ -1055,7 +1055,7 @@ class TestDynamicScatter:
d = relay.var("d", relay.TensorType([relay.Any() for i in
range(len(dshape))], "float32"))
i = relay.var("i", relay.TensorType([relay.Any() for i in
range(len(ishape))], "int64"))
u = relay.var("u", relay.TensorType([relay.Any() for i in
range(len(ishape))], "float32"))
- z = relay.op.scatter(d, i, u, axis)
+ z = relay.op.scatter_elements(d, i, u, axis)
func = relay.Function([d, i, u], z)