This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6097df5307 [ONNX][TORCH] Replace scatter op by scatter_elements (#14019)
6097df5307 is described below

commit 6097df5307dc3f8254ae685c1bdbfe65f8934670
Author: Valery Chernov <[email protected]>
AuthorDate: Tue Feb 28 04:23:08 2023 +0400

    [ONNX][TORCH] Replace scatter op by scatter_elements (#14019)
    
    * remove scatter attr class
    
    * update pytorch: scatter was replaced by scatter_elements
    
    * remove scatter compute and strategy registration
    
    * remove scatter attrs registration
    
    * update onnx front-end: replace _op.scatter by _op.scatter_elements, add checks
    
    * update oneflow front-end
    
    * update paddlepaddle front-end
    
    * update pytorch utils
    
    * remove front-end scatter definition
    
    * fix scatter strategy for rocm
    
    * small update
    
    * remove scatter definition in back-end
    
    * remove scatter strategy for cuda, gpu. transfer special case to scatter_elements
    
    * fix test
    
    * small fix
    
    * upstream scatter with torch description
    
    * last upstream of scatter in pytorch front-end
    
    * fix reduction attribute in cuda strategy
    
    * set scalar to test instead of tensor. update check for dynamic dim
    
    * skip scalar source check in tests for scatter due to issue on torch side
    
    * remove scatter op implementation from topi/cuda
    
    * remove scatter op implementation from topi. small clean code
    
    ---------
    
    Co-authored-by: Valery Chernov <[email protected]>
---
 include/tvm/relay/attrs/transform.h           |   8 -
 python/tvm/relay/frontend/oneflow.py          |   2 +-
 python/tvm/relay/frontend/onnx.py             |  40 ++-
 python/tvm/relay/frontend/paddlepaddle.py     |  10 +-
 python/tvm/relay/frontend/pytorch.py          |  59 +++-
 python/tvm/relay/frontend/pytorch_utils.py    |   2 +-
 python/tvm/relay/op/_transform.py             |  10 -
 python/tvm/relay/op/op_attrs.py               |   5 -
 python/tvm/relay/op/strategy/cuda.py          |  30 +-
 python/tvm/relay/op/strategy/generic.py       |  22 +-
 python/tvm/relay/op/strategy/rocm.py          |  14 +-
 python/tvm/relay/op/transform.py              |  25 --
 python/tvm/relay/transform/mixed_precision.py |   2 +
 python/tvm/topi/cuda/scatter.py               | 440 +-------------------------
 python/tvm/topi/generic/search.py             |  16 -
 python/tvm/topi/scatter.py                    | 183 +----------
 src/relay/op/tensor/transform.cc              |  49 ---
 tests/python/frontend/pytorch/test_forward.py |   7 +-
 tests/python/relay/test_op_level3.py          |   4 +-
 19 files changed, 128 insertions(+), 800 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h 
b/include/tvm/relay/attrs/transform.h
index 51378c8697..e1da1895d0 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -148,14 +148,6 @@ struct ReshapeLikeAttrs : public 
tvm::AttrsNode<ReshapeLikeAttrs> {
   }
 };  // struct ReshapeLikeAttrs
 
-struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
-  Integer axis;
-
-  TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") {
-    TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to 
select values.");
-  }
-};
-
 struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
   Integer axis;
   String reduction;
diff --git a/python/tvm/relay/frontend/oneflow.py 
b/python/tvm/relay/frontend/oneflow.py
index ff4b5a5bcc..1aba9e6419 100644
--- a/python/tvm/relay/frontend/oneflow.py
+++ b/python/tvm/relay/frontend/oneflow.py
@@ -1227,7 +1227,7 @@ class Scatter(OneFlowOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attrs, params):
         axis = attrs.get("axis", 0)
-        return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
+        return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis)
 
 
 class Unsqueeze(OneFlowOpConverter):
diff --git a/python/tvm/relay/frontend/onnx.py 
b/python/tvm/relay/frontend/onnx.py
index 2a18906272..7c55bfefb7 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1967,7 +1967,7 @@ class MaxUnpool(OnnxOpConverter):
         # Create a tensor of zeros then scatter our data through it.
         zeros_tensor = _op.zeros(total_output_shape, data_type)
         # We need to flatten all our tensors before scattering.
-        flat_tensor = _op.scatter(
+        flat_tensor = _op.scatter_elements(
             _op.reshape(zeros_tensor, [-1]),
             _op.reshape(indices, [-1]),
             _op.reshape(data, [-1]),
@@ -2734,15 +2734,15 @@ class Slice(OnnxOpConverter):
         # Update the starts and ends according to axes if required.
         if axes is not None:
             data_shape = shape_of(inputs[0], 
dtype=infer_type(ends).checked_type.dtype)
-            starts = _op.scatter(
+            starts = _op.scatter_elements(
                 _op.const([0] * data_rank, 
dtype=infer_type(starts).checked_type.dtype),
                 axes,
                 starts,
                 axis=0,
             )
-            ends = _op.scatter(data_shape, axes, ends, axis=0)
+            ends = _op.scatter_elements(data_shape, axes, ends, axis=0)
             if steps is not None:
-                steps = _op.scatter(
+                steps = _op.scatter_elements(
                     _op.const([1] * data_rank, 
dtype=infer_type(steps).checked_type.dtype),
                     axes,
                     steps,
@@ -2848,9 +2848,35 @@ class Scatter(OnnxOpConverter):
     """Operator converter for Scatter."""
 
     @classmethod
-    def _impl_v9(cls, inputs, attr, params):
+    def _args_check(cls, inputs, attr):
+        assert len(inputs) == 3, "Scatter takes 3 inputs (data, indices, 
updates), {} given".format(
+            len(inputs)
+        )
+        assert infer_type(inputs[1]).checked_type.dtype in ["int32", "int64"]
+
+        data_rank = len(infer_shape(inputs[0]))
+        assert data_rank > 0, "Data rank higher than 0 is expected"
+        indices_shape = infer_shape(inputs[1])
+        indices_rank = len(indices_shape)
+        assert indices_rank == data_rank, "Indices rank is not the same as 
data one"
+        updates_shape = infer_shape(inputs[2])
+        updates_rank = len(updates_shape)
+        assert updates_rank == data_rank, "Updates rank is not the same as 
data one"
+
+        for i in range(data_rank):
+            assert (
+                indices_shape[i] == updates_shape[i]
+            ), "Indices dimension size should be the same as updates one"
+
         axis = attr.get("axis", 0)
-        return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
+        assert -data_rank <= axis < data_rank, "Axis is out of bounds"
+
+        return axis
+
+    @classmethod
+    def _impl_v9(cls, inputs, attr, params):
+        axis = cls._args_check(inputs, attr)
+        return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis)
 
 
 class ScatterElements(OnnxOpConverter):
@@ -4991,7 +5017,7 @@ class ATen(OnnxOpConverter):
         else:
             mode = "add"
         index_tensor = _op.stack(indices, axis=0)
-        return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)
+        return _op.scatter_nd(in_tensor, index_tensor, values, mode)
 
     @classmethod
     def _reshape(cls, inputs, attr, params):
diff --git a/python/tvm/relay/frontend/paddlepaddle.py 
b/python/tvm/relay/frontend/paddlepaddle.py
index e688369a07..78895e4b49 100755
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -1741,10 +1741,10 @@ def convert_scatter(g, op, block):
     index = _op.transform.broadcast_to(index, shape)
 
     if overwrite:
-        out = _op.scatter(x, index, updates, axis=0)
+        out = _op.scatter_elements(x, index, updates, axis=0)
     else:
         out = _op.scatter_elements(_op.zeros_like(x), index, updates, axis=0, 
reduction="add")
-        out += _op.scatter(x, index, _op.zeros_like(updates), axis=0)
+        out += _op.scatter_elements(x, index, _op.zeros_like(updates), axis=0)
     g.add_node(op.output("Out")[0], out)
 
 
@@ -1826,7 +1826,7 @@ def convert_slice(g, op, block):
 
     if len(axes) < dims:
         if isinstance(starts, _expr.Expr):
-            starts = _op.scatter(
+            starts = _op.scatter_elements(
                 _op.const([0] * dims, 
dtype=infer_type(starts).checked_type.dtype),
                 indices,
                 starts,
@@ -1857,7 +1857,7 @@ def convert_slice(g, op, block):
 
     if len(axes) < dims:
         if isinstance(ends, _expr.Expr):
-            ends = _op.scatter(
+            ends = _op.scatter_elements(
                 _expr.const(
                     np.array([np.iinfo(np.int32).max] * dims),
                     dtype=infer_type(ends).checked_type.dtype,
@@ -1892,7 +1892,7 @@ def convert_slice(g, op, block):
 
     if len(axes) < dims:
         if isinstance(strides, _expr.Expr):
-            strides = _op.scatter(
+            strides = _op.scatter_elements(
                 _expr.const(
                     np.array([1] * dims),
                     dtype=infer_type(strides).checked_type.dtype,
diff --git a/python/tvm/relay/frontend/pytorch.py 
b/python/tvm/relay/frontend/pytorch.py
index 0dc9ffef6f..3cdfc5cb4e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -496,7 +496,7 @@ class PyTorchOpConverter:
                     end[dim] = target_end
                 else:
                     target_end = _expr.const(target_end)
-                    end = _op.scatter(
+                    end = _op.scatter_elements(
                         end,
                         _op.expand_dims(_expr.const(dim), axis=0),
                         _op.expand_dims(target_end, axis=0),
@@ -508,7 +508,7 @@ class PyTorchOpConverter:
                 ttype = self.infer_type(target_end).dtype
                 if str(ttype) != axis_dtype:
                     target_end = _op.cast(target_end, axis_dtype)
-                end = _op.scatter(
+                end = _op.scatter_elements(
                     end,
                     _op.expand_dims(_expr.const(dim), axis=0),
                     _op.expand_dims(target_end, axis=0),
@@ -2554,11 +2554,62 @@ class PyTorchOpConverter:
         return self.nonzero(inputs, input_types, is_numpy_style=False)
 
     def scatter(self, inputs, input_types):
+        assert len(inputs) == 4 or len(inputs) == 5, (
+            "scatter takes 4 or 5 inputs: data, dim, index, src, reduce 
(optional), "
+            + "but {} given".format(len(inputs))
+        )
         data = inputs[0]
         axis = int(inputs[1])
         index = inputs[2]
         src = inputs[3]
-        return _op.transform.scatter(data, index, src, axis)
+        if len(inputs) == 5:
+            reduce = inputs[4]
+        else:
+            reduce = "update"
+
+        data_shape = self.infer_shape(data)
+        data_rank = len(data_shape)
+        index_shape = self.infer_shape(index)
+        index_rank = len(index_shape)
+        # When index is empty, the operation returns data unchanged
+        if self.is_empty_shape(index_shape):
+            return data
+
+        if np.isscalar(src):
+            assert self.infer_type(src).dtype == "float", "Scalar source can 
be float only"
+            src = _op.broadcast_to_like(src, data_shape)
+            src_shape = data_shape
+        else:
+            src_shape = self.infer_shape(src)
+        src_rank = len(src_shape)
+        assert data_rank == index_rank, "Index rank is not the same as data 
rank"
+        assert data_rank == src_rank, "Src rank is not the same as data rank"
+
+        assert 0 <= axis < data_rank, "Dim is out of bounds"
+
+        for i in range(data_rank):
+            index_dim = index_shape[i]
+            src_dim = src_shape[i]
+            data_dim = data_shape[i]
+            # Skip check for dynamic dimensions
+            if not any([isinstance(index_dim, tvm.tir.Any), 
isinstance(src_dim, tvm.tir.Any)]):
+                assert index_dim <= src_dim, "Index dim size should be less 
than src one"
+            if i != axis and not any(
+                [isinstance(index_dim, tvm.tir.Any), isinstance(data_dim, 
tvm.tir.Any)]
+            ):
+                assert index_dim <= data_dim, "Index dim size should be less 
than data one"
+
+        if reduce is None:
+            reduce = "update"
+        elif reduce == "multiply":
+            reduce = "mul"
+        assert reduce in [
+            "update",
+            "add",
+            "mul",
+        ], 'reduce arg is expected from "add", "multiply" or None'
+
+        return _op.scatter_elements(data, index, src, axis, reduce)
 
     def index_put(self, inputs, input_types):
         in_tensor = inputs[0]
@@ -2571,7 +2622,7 @@ class PyTorchOpConverter:
             mode = "add"
         # Combine array of index tensors into one index tensor with shape (N,_)
         index_tensor = _op.stack(indices, axis=0)
-        return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)
+        return _op.scatter_nd(in_tensor, index_tensor, values, mode)
 
     def scalar_tensor(self, inputs, input_types):
         data = inputs[0]
diff --git a/python/tvm/relay/frontend/pytorch_utils.py 
b/python/tvm/relay/frontend/pytorch_utils.py
index da4c9e039e..7de1248bda 100644
--- a/python/tvm/relay/frontend/pytorch_utils.py
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -331,7 +331,7 @@ def scatter_roi_align_result_pattern(levels, 
roi_align_results, num_scales):
         scatter_indices = is_op("repeat")(scatter_indices)
         scatter_indices = is_op("repeat")(scatter_indices)
 
-        scatter_res = is_op("scatter")(scatter_res, scatter_indices, 
roi_align_results[i])
+        scatter_res = is_op("scatter_elements")(scatter_res, scatter_indices, 
roi_align_results[i])
 
     return is_op("reshape")(scatter_res)
 
diff --git a/python/tvm/relay/op/_transform.py 
b/python/tvm/relay/op/_transform.py
index 12450dc809..11a10f7ee4 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -104,15 +104,6 @@ _reg.register_pattern("meta_schedule_layout_transform", 
OpPattern.INJECTIVE)
 # argwhere
 _reg.register_strategy("argwhere", strategy.argwhere_strategy)
 
-# scatter
-@_reg.register_compute("scatter")
-def compute_scatter(attrs, inputs, output_type):
-    """Compute definition of scatter"""
-    return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]
-
-
-_reg.register_strategy("scatter", strategy.scatter_strategy)
-
 # sparse_fill_empty_rows
 @_reg.register_compute("sparse_fill_empty_rows")
 def compute_sparse_fill_empty_rows(attrs, inputs, output_type):
@@ -677,7 +668,6 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
     return ValueError("Does not support rank higher than 5 in argwhere")
 
 
-_reg.register_shape_func("scatter", False, elemwise_shape_func)
 _reg.register_shape_func("scatter_elements", False, elemwise_shape_func)
 _reg.register_shape_func("scatter_nd", False, elemwise_shape_func)
 
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 0214ae8a46..4e9a9a4707 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -529,11 +529,6 @@ class RequantizeAttrs(Attrs):
     """Attributes used in requantize operators"""
 
 
-@tvm._ffi.register_object("relay.attrs.ScatterAttrs")
-class ScatterAttrs(Attrs):
-    """Attributes used in scatter operators"""
-
-
 @tvm._ffi.register_object("relay.attrs.SequenceMaskAttrs")
 class SequenceMaskAttrs(Attrs):
     """Attributes used in sequence_mask operators"""
diff --git a/python/tvm/relay/op/strategy/cuda.py 
b/python/tvm/relay/op/strategy/cuda.py
index e0229a615d..c6ea692a8d 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1062,23 +1062,23 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, 
out_type, target):
     return strategy
 
 
-@scatter_strategy.register(["cuda", "gpu"])
-def scatter_cuda(attrs, inputs, out_type, target):
-    """scatter cuda strategy"""
+@scatter_elements_strategy.register(["cuda", "gpu"])
+def scatter_elements_cuda(attrs, inputs, out_type, target):
+    """scatter elements cuda strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_scatter(topi.cuda.scatter),
-        wrap_topi_schedule(topi.cuda.schedule_scatter),
-        name="scatter.cuda",
+        wrap_compute_scatter_elements(topi.cuda.scatter_elements),
+        wrap_topi_schedule(topi.cuda.schedule_extern),
+        name="scatter_elements.cuda",
         plevel=10,
     )
 
     rank = len(inputs[0].shape)
 
-    with SpecializedCondition(rank == 1):
+    with SpecializedCondition(rank == 1 and attrs.reduction == "update"):
         if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
             strategy.add_implementation(
-                wrap_compute_scatter(topi.cuda.scatter_via_sort),
+                wrap_compute_scatter_elements(topi.cuda.scatter_via_sort),
                 wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
                 name="scatter_via_sort.cuda",
                 plevel=9,  # use the sequential version by default
@@ -1086,20 +1086,6 @@ def scatter_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
-@scatter_elements_strategy.register(["cuda", "gpu"])
-def scatter_elements_cuda(attrs, inputs, out_type, target):
-    """scatter elements cuda strategy"""
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_scatter_elements(topi.cuda.scatter_elements),
-        wrap_topi_schedule(topi.cuda.schedule_extern),
-        name="scatter_elements.cuda",
-        plevel=10,
-    )
-    # TODO(vvchernov): There is possible specification for rank=1 as for 
scatter
-    return strategy
-
-
 @scatter_nd_strategy.register(["cuda", "gpu"])
 def scatter_nd_cuda(attrs, inputs, out_type, target):
     """scatter_nd cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py 
b/python/tvm/relay/op/strategy/generic.py
index 4641fb18f7..b08d92a3cc 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1548,27 +1548,6 @@ def proposal_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
-# scatter
-@override_native_generic_func("scatter_strategy")
-def scatter_strategy(attrs, outs, out_type, target):
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_scatter(topi.scatter),
-        wrap_topi_schedule(topi.generic.schedule_scatter),
-        name="scatter.generic",
-    )
-    return strategy
-
-
-def wrap_compute_scatter(topi_compute):
-    """Wrap scatter topi compute"""
-
-    def _compute_scatter(attrs, inputs, _):
-        return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)]
-
-    return _compute_scatter
-
-
 # scatter_elements
 @override_native_generic_func("scatter_elements_strategy")
 def scatter_elements_strategy(attrs, inputs, out_type, target):
@@ -1579,6 +1558,7 @@ def scatter_elements_strategy(attrs, inputs, out_type, 
target):
         wrap_topi_schedule(topi.generic.schedule_extern),
         name="scatter_elements.generic",
     )
+    # TODO(vvchernov): implement specialized case (rank=1, 
reduction="update"), see cuda strategy
     return strategy
 
 
diff --git a/python/tvm/relay/op/strategy/rocm.py 
b/python/tvm/relay/op/strategy/rocm.py
index 89cac0db4a..d80f347975 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -105,23 +105,23 @@ def argsort_strategy_cuda(attrs, inputs, out_type, 
target):
     return strategy
 
 
-@scatter_strategy.register(["rocm"])
-def scatter_cuda(attrs, inputs, out_type, target):
+@scatter_elements_strategy.register(["rocm"])
+def scatter_elements_cuda(attrs, inputs, out_type, target):
     """scatter rocm strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_scatter(topi.cuda.scatter),
-        wrap_topi_schedule(topi.cuda.schedule_scatter),
-        name="scatter.rocm",
+        wrap_compute_scatter_elements(topi.cuda.scatter_elements),
+        wrap_topi_schedule(topi.cuda.schedule_extern),
+        name="scatter_elements.rocm",
         plevel=10,
     )
 
     rank = len(inputs[0].shape)
 
-    with SpecializedCondition(rank == 1):
+    with SpecializedCondition(rank == 1 and attrs.reduction == "update"):
         if can_use_rocthrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
             strategy.add_implementation(
-                wrap_compute_scatter(topi.cuda.scatter_via_sort),
+                wrap_compute_scatter_elements(topi.cuda.scatter_via_sort),
                 wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
                 name="scatter_via_sort.rocm",
                 plevel=9,  # use the sequential version by default
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 6718347b31..f2d066d17e 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -353,31 +353,6 @@ def argwhere(condition):
     return _make.argwhere(condition)
 
 
-def scatter(data, indices, updates, axis):
-    """Update data at positions defined by indices with values in updates.
-
-    Parameters
-    ----------
-    data : relay.Expr
-        The input data to the operator.
-
-    indices : relay.Expr
-        The index locations to update.
-
-    updates : relay.Expr
-        The values to update.
-
-    axis : int
-        The axis to scatter on.
-
-    Returns
-    -------
-    ret : relay.Expr
-        The computed result.
-    """
-    return _make.scatter(data, indices, updates, axis)
-
-
 def scatter_elements(data, indices, updates, axis=0, reduction="update"):
     """Scatter elements with updating data by reduction of values in updates
     at positions defined by indices.
diff --git a/python/tvm/relay/transform/mixed_precision.py 
b/python/tvm/relay/transform/mixed_precision.py
index 5018ba9ba9..f6bb8b8150 100644
--- a/python/tvm/relay/transform/mixed_precision.py
+++ b/python/tvm/relay/transform/mixed_precision.py
@@ -63,6 +63,8 @@ DEFAULT_FOLLOW_LIST = [
     "tile",
     "dyn.tile",
     "scatter",
+    "scatter_elements",
+    "scatter_nd",
     "full",
     "dyn.full",
     "nn.depth_to_space",
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index c88c3086f3..39ef5a5a42 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -14,446 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, no-member, too-many-locals, 
too-many-arguments, too-many-statements, singleton-comparison, unused-argument
-"""Scatter operator """
+# pylint: disable=invalid-name
+"""Scatter operators"""
 import tvm
 from tvm import te, tir, autotvm
 from ..scatter import _verify_scatter_nd_inputs
 from ..generic import schedule_extern
 from .nms import atomic_add
 from .sort import stable_sort_by_key_thrust
-from ..utils import prod, ceil_div
-
-
-def _memcpy_ir(ib, out_ptr, data_ptr, shape):
-    fused = prod(shape)
-    with ib.new_scope():
-        num_thread = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        num_blocks = ceil_div(fused, num_thread)
-        bx = te.thread_axis("blockIdx.x")
-        ib.scope_attr(bx, "thread_extent", num_blocks)
-        tx = te.thread_axis("threadIdx.x")
-        ib.scope_attr(tx, "thread_extent", num_thread)
-        tid = bx * num_thread + tx
-
-        with ib.if_scope(tid < fused):
-            out_ptr[tid] = data_ptr[tid]
-
-
-def gen_ir_1d(data, indices, updates, axis, out, update_func):
-    """Generate scatter ir for 1d inputs
-
-    Parameters
-    ----------
-    data : tir.Tensor
-        The input data to the operator.
-
-    indices : tir.Tensor
-        The index locations to update.
-
-    updates : tir.Tensor
-        The values to update.
-
-    axis : int
-        The axis to scatter on
-
-    out : tir.Tensor
-        The output tensor.
-
-    update_func: function
-        The function to be applied to a destination and the corresponding 
update.
-
-    Returns
-    -------
-    ret : tir
-        The computational ir.
-    """
-    assert axis == 0
-    n = data.shape[0]
-
-    ib = tvm.tir.ir_builder.create()
-
-    out_ptr = ib.buffer_ptr(out)
-    data_ptr = ib.buffer_ptr(data)
-
-    _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
-
-    indices_ptr = ib.buffer_ptr(indices)
-    updates_ptr = ib.buffer_ptr(updates)
-
-    ni = indices.shape[0]
-
-    with ib.new_scope():
-        bx = te.thread_axis("blockIdx.x")
-        ib.scope_attr(bx, "thread_extent", 1)
-        with ib.for_range(0, ni, name="i") as i:
-            index = indices_ptr[i]
-            with ib.if_scope(index < 0):
-                update_func(out_ptr, index + n, updates_ptr[i])
-            with ib.else_scope():
-                update_func(out_ptr, index, updates_ptr[i])
-
-    return ib.get()
-
-
-def gen_ir_2d(data, indices, updates, axis, out, update_func):
-    """Generate scatter ir for 2d inputs
-
-    Parameters
-    ----------
-    data : tir.Tensor
-        The input data to the operator.
-
-    indices : tir.Tensor
-        The index locations to update.
-
-    updates : tir.Tensor
-        The values to update.
-
-    axis : int
-        The axis to scatter on
-
-    out : tir.Tensor
-        The output tensor.
-
-    update_func: function
-        The function to be applied to a destination and the corresponding 
update
-
-    Returns
-    -------
-    ret : tir
-        The computational ir.
-    """
-    n = data.shape[0]
-    c = data.shape[1]
-
-    ib = tvm.tir.ir_builder.create()
-
-    out_ptr = ib.buffer_ptr(out)
-    data_ptr = ib.buffer_ptr(data)
-
-    _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
-
-    indices_ptr = ib.buffer_ptr(indices)
-    updates_ptr = ib.buffer_ptr(updates)
-
-    ni = indices.shape[0]
-    ci = indices.shape[1]
-
-    if axis == 0:
-        with ib.new_scope():
-            j = te.thread_axis("blockIdx.x")
-            ib.scope_attr(j, "thread_extent", ci)
-            with ib.for_range(0, ni, name="i") as i:
-                idx = i * ci + j
-                index = indices_ptr[idx]
-                with ib.if_scope(index < 0):
-                    update_func(out_ptr, (index + n) * c + j, updates_ptr[idx])
-                with ib.else_scope():
-                    update_func(out_ptr, index * c + j, updates_ptr[idx])
-    else:
-        with ib.new_scope():
-            i = te.thread_axis("blockIdx.x")
-            ib.scope_attr(i, "thread_extent", ni)
-            with ib.for_range(0, ci, name="j") as j:
-                idx = i * ci + j
-                index = indices_ptr[idx]
-                with ib.if_scope(index < 0):
-                    update_func(out_ptr, i * c + (index + c), updates_ptr[idx])
-                with ib.else_scope():
-                    update_func(out_ptr, i * c + index, updates_ptr[idx])
-    return ib.get()
-
-
-def gen_ir_3d(data, indices, updates, axis, out, update_func):
-    """Generate scatter ir for 3d inputs
-
-    Parameters
-    ----------
-    data : tir.Tensor
-        The input data to the operator.
-
-    indices : tir.Tensor
-        The index locations to update.
-
-    updates : tir.Tensor
-        The values to update.
-
-    axis : int
-        The axis to scatter on
-
-    out : tir.Tensor
-        The output tensor.
-
-    update_func: function
-        The function to be applied to a destination and the corresponding 
update
-
-    Returns
-    -------
-    ret : tir
-        The computational ir.
-    """
-    warp_size = tvm.target.Target.current(False).thread_warp_size
-
-    n = data.shape[0]
-    c = data.shape[1]
-    h = data.shape[2]
-
-    ib = tvm.tir.ir_builder.create()
-
-    out_ptr = ib.buffer_ptr(out)
-    data_ptr = ib.buffer_ptr(data)
-
-    _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
-
-    indices_ptr = ib.buffer_ptr(indices)
-    updates_ptr = ib.buffer_ptr(updates)
-    ni = indices.shape[0]
-    ci = indices.shape[1]
-    hi = indices.shape[2]
-
-    if axis == 0:
-        with ib.new_scope():
-            j = te.thread_axis("blockIdx.x")
-            ib.scope_attr(j, "thread_extent", ci)
-            tx = te.thread_axis("threadIdx.x")
-            ib.scope_attr(tx, "thread_extent", warp_size)
-            with ib.for_range(0, ni, name="i") as i:
-                with ib.for_range(0, ceil_div(hi, warp_size), name="k") as k_:
-                    k = k_ * warp_size + tx
-                    with ib.if_scope(k < hi):
-                        idx = (i * ci + j) * hi + k
-                        index = indices_ptr[idx]
-                        with ib.if_scope(index < 0):
-                            update_func(out_ptr, ((index + n) * c + j) * h + 
k, updates_ptr[idx])
-                        with ib.else_scope():
-                            update_func(out_ptr, (index * c + j) * h + k, 
updates_ptr[idx])
-    elif axis == 1:
-        with ib.new_scope():
-            i = te.thread_axis("blockIdx.x")
-            ib.scope_attr(i, "thread_extent", ni)
-            tx = te.thread_axis("threadIdx.x")
-            ib.scope_attr(tx, "thread_extent", warp_size)
-            with ib.for_range(0, ci, name="j") as j:
-                with ib.for_range(0, ceil_div(hi, warp_size), name="k") as k_:
-                    k = k_ * warp_size + tx
-                    with ib.if_scope(k < hi):
-                        idx = (i * ci + j) * hi + k
-                        index = indices_ptr[idx]
-                        with ib.if_scope(index < 0):
-                            update_func(out_ptr, (i * c + (index + c)) * h + 
k, updates_ptr[idx])
-                        with ib.else_scope():
-                            update_func(out_ptr, (i * c + index) * h + k, 
updates_ptr[idx])
-    else:
-        with ib.new_scope():
-            i = te.thread_axis("blockIdx.x")
-            ib.scope_attr(i, "thread_extent", ni)
-            j = te.thread_axis("blockIdx.y")
-            ib.scope_attr(j, "thread_extent", ci)
-            with ib.for_range(0, hi, name="k") as k:
-                idx = (i * ci + j) * hi + k
-                index = indices_ptr[idx]
-                with ib.if_scope(index < 0):
-                    update_func(out_ptr, (i * c + j) * h + (index + h), 
updates_ptr[idx])
-                with ib.else_scope():
-                    update_func(out_ptr, (i * c + j) * h + index, 
updates_ptr[idx])
-    return ib.get()
-
-
-def gen_ir_4d(data, indices, updates, axis, out, update_func):
-    """Generate scatter ir for 4d inputs
-
-    Parameters
-    ----------
-    data : tir.Tensor
-        The input data to the operator.
-
-    indices : tir.Tensor
-        The index locations to update.
-
-    updates : tir.Tensor
-        The values to update.
-
-    axis : int
-        The axis to scatter on
-
-    out : tir.Tensor
-        The output tensor.
-
-    update_func: function
-        The function to be applied to a destination and the corresponding 
update
-
-    Returns
-    -------
-    ret : tir
-        The computational ir.
-    """
-    warp_size = tvm.target.Target.current(False).thread_warp_size
-
-    n = data.shape[0]
-    c = data.shape[1]
-    h = data.shape[2]
-    w = data.shape[3]
-
-    ib = tvm.tir.ir_builder.create()
-
-    out_ptr = ib.buffer_ptr(out)
-    data_ptr = ib.buffer_ptr(data)
-    _memcpy_ir(ib, out_ptr, data_ptr, data.shape)
-
-    indices_ptr = ib.buffer_ptr(indices)
-    updates_ptr = ib.buffer_ptr(updates)
-    ni = indices.shape[0]
-    ci = indices.shape[1]
-    hi = indices.shape[2]
-    wi = indices.shape[3]
-
-    if axis == 0:
-        with ib.new_scope():
-            j = te.thread_axis("blockIdx.y")
-            ib.scope_attr(j, "thread_extent", ci)
-            k = te.thread_axis("blockIdx.z")
-            ib.scope_attr(k, "thread_extent", hi)
-            tx = te.thread_axis("threadIdx.x")
-            ib.scope_attr(tx, "thread_extent", warp_size)
-            with ib.for_range(0, ni, name="i") as i:
-                with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_:
-                    l = l_ * warp_size + tx
-                    with ib.if_scope(l < wi):
-                        idx = ((i * ci + j) * hi + k) * wi + l
-                        index = indices_ptr[idx]
-                        with ib.if_scope(index < 0):
-                            update_func(
-                                out_ptr, (((index + n) * c + j) * h + k) * w + 
l, updates_ptr[idx]
-                            )
-                        with ib.else_scope():
-                            update_func(
-                                out_ptr, ((index * c + j) * h + k) * w + l, 
updates_ptr[idx]
-                            )
-    elif axis == 1:
-        with ib.new_scope():
-            i = te.thread_axis("blockIdx.x")
-            ib.scope_attr(i, "thread_extent", ni)
-            k = te.thread_axis("blockIdx.z")
-            ib.scope_attr(k, "thread_extent", hi)
-            tx = te.thread_axis("threadIdx.x")
-            ib.scope_attr(tx, "thread_extent", warp_size)
-            with ib.for_range(0, ci, name="j") as j:
-                with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_:
-                    l = l_ * warp_size + tx
-                    with ib.if_scope(l < wi):
-                        idx = ((i * ci + j) * hi + k) * wi + l
-                        index = indices_ptr[idx]
-                        with ib.if_scope(index < 0):
-                            update_func(
-                                out_ptr, ((i * c + (index + c)) * h + k) * w + 
l, updates_ptr[idx]
-                            )
-                        with ib.else_scope():
-                            update_func(
-                                out_ptr, ((i * c + index) * h + k) * w + l, 
updates_ptr[idx]
-                            )
-    elif axis == 2:
-        with ib.new_scope():
-            i = te.thread_axis("blockIdx.x")
-            ib.scope_attr(i, "thread_extent", ni)
-            j = te.thread_axis("blockIdx.y")
-            ib.scope_attr(j, "thread_extent", ci)
-            tx = te.thread_axis("threadIdx.x")
-            ib.scope_attr(tx, "thread_extent", warp_size)
-            with ib.for_range(0, hi, name="k") as k:
-                with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_:
-                    l = l_ * warp_size + tx
-                    with ib.if_scope(l < wi):
-                        idx = ((i * ci + j) * hi + k) * wi + l
-                        index = indices_ptr[idx]
-                        with ib.if_scope(index < 0):
-                            update_func(
-                                out_ptr, ((i * c + j) * h + (index + h)) * w + 
l, updates_ptr[idx]
-                            )
-                        with ib.else_scope():
-                            update_func(
-                                out_ptr, ((i * c + j) * h + index) * w + l, 
updates_ptr[idx]
-                            )
-    else:
-        with ib.new_scope():
-            i = te.thread_axis("blockIdx.x")
-            ib.scope_attr(i, "thread_extent", ni)
-            j = te.thread_axis("blockIdx.y")
-            ib.scope_attr(j, "thread_extent", ci)
-            k = te.thread_axis("blockIdx.z")
-            ib.scope_attr(k, "thread_extent", hi)
-            with ib.for_range(0, wi, name="l") as l:
-                idx = ((i * ci + j) * hi + k) * wi + l
-                index = indices_ptr[idx]
-                with ib.if_scope(index < 0):
-                    update_func(out_ptr, ((i * c + j) * h + k) * w + (index + 
w), updates_ptr[idx])
-                with ib.else_scope():
-                    update_func(out_ptr, ((i * c + j) * h + k) * w + index, 
updates_ptr[idx])
-    return ib.get()
-
-
[email protected]_topi_compute("scatter.cuda")
-def scatter(cfg, data, indices, updates, axis=0):
-    """Update data at positions defined by indices with values in updates
-
-    Parameters
-    ----------
-    data : relay.Expr
-        The input data to the operator.
-
-    indices : relay.Expr
-        The index locations to update.
-
-    updates : relay.Expr
-        The values to update.
-
-    axis : int
-        The axis to scatter on
-
-    Returns
-    -------
-    ret : relay.Expr
-        The computed result.
-    """
-    if axis < 0:
-        axis += len(data.shape)
-    assert axis >= 0
-    assert axis < len(data.shape)
-
-    rank = len(data.shape)
-    assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions"
-
-    ir_funcs = {
-        1: gen_ir_1d,
-        2: gen_ir_2d,
-        3: gen_ir_3d,
-        4: gen_ir_4d,
-    }
-
-    def update_func(dst_ptr, dst_index, update):
-        dst_ptr[dst_index] = update
-
-    out_shape = data.shape
-    out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
-
-    cfg.add_flop(1)  # A dummy value to satisfy AutoTVM
-
-    out = te.extern(
-        [out_shape],
-        [data, indices, updates],
-        lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, 
outs[0], update_func),
-        dtype=data.dtype,
-        out_buffers=[out_buf],
-        name="scatter_gpu",
-        tag="scatter_gpu",
-    )
-
-    return out
-
-
[email protected]_topi_schedule("scatter.cuda")
-def schedule_scatter(_, outs):
-    return schedule_extern(outs)
+from ..utils import ceil_div
 
 
 def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out):
@@ -540,7 +109,7 @@ def gen_scatter_1d_thrust(data, indices_sorted, 
updates_sorted, out):
 
 
 @autotvm.register_topi_compute("scatter_via_sort.cuda")
-def scatter_via_sort(cfg, data, indices, updates, axis=0):
+def scatter_via_sort(cfg, data, indices, updates, axis=0, reduction="add"):
     """Update data at positions defined by indices with values in updates
 
     Parameters
@@ -562,6 +131,7 @@ def scatter_via_sort(cfg, data, indices, updates, axis=0):
     ret : relay.Expr
         The computed result.
     """
+    assert reduction == "add"
     if axis < 0:
         axis += len(data.shape)
     assert axis == 0 and len(data.shape) == 1, "sorting based scatter only 
supported for 1d input"
diff --git a/python/tvm/topi/generic/search.py 
b/python/tvm/topi/generic/search.py
index 826194e75c..9a80e678c2 100644
--- a/python/tvm/topi/generic/search.py
+++ b/python/tvm/topi/generic/search.py
@@ -36,22 +36,6 @@ def schedule_argwhere(outs):
     return _default_schedule(outs, False)
 
 
-def schedule_scatter(outs):
-    """Schedule for scatter operator.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-      The computation graph description of scatter.
-
-    Returns
-    -------
-    s: Schedule
-      The computation schedule for the op.
-    """
-    return _default_schedule(outs, False)
-
-
 def schedule_sparse_fill_empty_rows(outs):
     return _default_schedule(outs, False)
 
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index 45629c005f..799b3d1673 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -14,191 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
-"""Scatter operator"""
+# pylint: disable=invalid-name
+"""ScatterND operator"""
 from tvm import te, tir  # hide redefinition of min and max
 from tvm.tir import expr
 
 
[email protected]
-def _scatter_1d(data, indices, updates):
-    out = output_tensor(data.shape, data.dtype)
-    for i in range(data.shape[0]):
-        out[i] = data[i]
-    for i in range(indices.shape[0]):
-        out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] = 
updates[i]
-    return out
-
-
[email protected]
-def _scatter_2d(data, indices, updates, axis):
-    out = output_tensor(data.shape, data.dtype)
-    for i in range(data.shape[0]):
-        for j in range(data.shape[1]):
-            out[i, j] = data[i, j]
-    if axis == 0:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                out[
-                    indices[i, j] if indices[i, j] >= 0 else indices[i, j] + 
data.shape[axis], j
-                ] = updates[i, j]
-    else:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                out[
-                    i, indices[i, j] if indices[i, j] >= 0 else indices[i, j] 
+ data.shape[axis]
-                ] = updates[i, j]
-
-    return out
-
-
[email protected]
-def _scatter_3d(data, indices, updates, axis):
-    out = output_tensor(data.shape, data.dtype)
-    for i in range(data.shape[0]):
-        for j in range(data.shape[1]):
-            for k in range(data.shape[2]):
-                out[i, j, k] = data[i, j, k]
-    if axis == 0:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    out[
-                        indices[i, j, k]
-                        if indices[i, j, k] >= 0
-                        else indices[i, j, k] + data.shape[axis],
-                        j,
-                        k,
-                    ] = updates[i, j, k]
-    elif axis == 1:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    out[
-                        i,
-                        indices[i, j, k]
-                        if indices[i, j, k] >= 0
-                        else indices[i, j, k] + data.shape[axis],
-                        k,
-                    ] = updates[i, j, k]
-    else:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    out[
-                        i,
-                        j,
-                        indices[i, j, k]
-                        if indices[i, j, k] >= 0
-                        else indices[i, j, k] + data.shape[axis],
-                    ] = updates[i, j, k]
-
-    return out
-
-
[email protected]
-def _scatter_4d(data, indices, updates, axis):
-    out = output_tensor(data.shape, data.dtype)
-    for i in range(data.shape[0]):
-        for j in range(data.shape[1]):
-            for k in range(data.shape[2]):
-                for l in range(data.shape[3]):
-                    out[i, j, k, l] = data[i, j, k, l]
-
-    if axis == 0:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    for l in range(indices.shape[3]):
-                        out[
-                            indices[i, j, k, l]
-                            if indices[i, j, k, l] >= 0
-                            else indices[i, j, k, l] + data.shape[axis],
-                            j,
-                            k,
-                            l,
-                        ] = updates[i, j, k, l]
-    elif axis == 1:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    for l in range(indices.shape[3]):
-                        out[
-                            i,
-                            indices[i, j, k, l]
-                            if indices[i, j, k, l] >= 0
-                            else indices[i, j, k, l] + data.shape[axis],
-                            k,
-                            l,
-                        ] = updates[i, j, k, l]
-    elif axis == 2:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    for l in range(indices.shape[3]):
-                        out[
-                            i,
-                            j,
-                            indices[i, j, k, l]
-                            if indices[i, j, k, l] >= 0
-                            else indices[i, j, k, l] + data.shape[axis],
-                            l,
-                        ] = updates[i, j, k, l]
-    else:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    for l in range(indices.shape[3]):
-                        out[
-                            i,
-                            j,
-                            k,
-                            indices[i, j, k, l]
-                            if indices[i, j, k, l] >= 0
-                            else indices[i, j, k, l] + data.shape[axis],
-                        ] = updates[i, j, k, l]
-
-    return out
-
-
-def scatter(data, indices, updates, axis=0):
-    """Update data at positions defined by indices with values in updates
-
-    Parameters
-    ----------
-    data : relay.Expr
-        The input data to the operator.
-
-    indices : relay.Expr
-        The index locations to update.
-
-    updates : relay.Expr
-        The values to update.
-
-    axis : int
-        The axis to scatter on
-
-    Returns
-    -------
-    ret : relay.Expr
-        The computed result.
-    """
-    if axis < 0:
-        axis += len(data.shape)
-    assert axis >= 0
-    assert axis < len(data.shape)
-
-    if len(data.shape) == 1:
-        return _scatter_1d(data, indices, updates)
-    if len(data.shape) == 2:
-        return _scatter_2d(data, indices, updates, axis)
-    if len(data.shape) == 3:
-        return _scatter_3d(data, indices, updates, axis)
-    if len(data.shape) == 4:
-        return _scatter_4d(data, indices, updates, axis)
-    raise ValueError("scatter only support for 1-4 dimensions")
-
-
 def _verify_scatter_nd_inputs(data, indices, updates):
     mdim = int(indices.shape[0])
     assert mdim <= len(data.shape), (
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 907141c9cb..1bae1a4d9a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1095,54 +1095,6 @@ non-zero)doc" TVM_ADD_FILELINE)
     .set_attr<TOpPattern>("TOpPattern", kOpaque)
     .set_support_level(10);
 
-// Scatter
-TVM_REGISTER_NODE_TYPE(ScatterAttrs);
-
-// Scatter
-bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                const TypeReporter& reporter) {
-  ICHECK_EQ(num_inputs, 3);
-  ICHECK_EQ(types.size(), 4);
-  auto data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) {
-    return false;
-  }
-  auto indices = types[1].as<TensorTypeNode>();
-  if (indices == nullptr) {
-    return false;
-  }
-  auto updates = types[2].as<TensorTypeNode>();
-  if (updates == nullptr) {
-    return false;
-  }
-  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
-      << "indices of scatter must be tensor of integer";
-  const auto param = attrs.as<ScatterAttrs>();
-  ICHECK(param != nullptr);
-  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
-  return true;
-}
-
-TVM_REGISTER_GLOBAL("relay.op._make.scatter")
-    .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
-      auto attrs = make_object<ScatterAttrs>();
-      attrs->axis = std::move(axis);
-      static const Op& op = Op::Get("scatter");
-      return Call(op, {data, indices, updates}, Attrs(attrs), {});
-    });
-
-RELAY_REGISTER_OP("scatter")
-    .describe(
-        R"doc(Update data at positions defined by indices with values in 
updates)doc" TVM_ADD_FILELINE)
-    .set_num_inputs(3)
-    .add_argument("data", "Tensor", "The input data tensor.")
-    .add_argument("indices", "Tensor", "The indices location tensor.")
-    .add_argument("updates", "Tensor", "The values to update the input with.")
-    .add_type_rel("Scatter", ScatterRel)
-    .set_attr<TOpIsStateful>("TOpIsStateful", false)
-    .set_attr<TOpPattern>("TOpPattern", kOpaque)
-    .set_support_level(10);
-
 // scatter_elements operator
 TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
 
@@ -1168,7 +1120,6 @@ bool ScatterElementsRel(const Array<Type>& types, int 
num_inputs, const Attrs& a
         << "ScatterElements: expect updates type to be TensorType but got " << 
types[2];
     return false;
   }
-  // TODO(vvchernov): ONNX requires int32 and int64
   ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
       << "ScatterElements: indices must be a tensor of integers.";
 
diff --git a/tests/python/frontend/pytorch/test_forward.py 
b/tests/python/frontend/pytorch/test_forward.py
index 2401e98bce..807c44a364 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4232,12 +4232,17 @@ def test_forward_scatter():
     verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], 
targets)
     verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], 
targets)
 
-    # Check empty indices for scatter_add
+    # Check empty indices
     in_data = torch.zeros(2, 4)
     in_index = torch.empty((0,))
     in_src = torch.rand(2, 1)
+    verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], 
targets)
     verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], 
targets)
 
+    # Check scalar source
+    # TODO(vvchernov): Scalar source is supported on TVM side, but torch 
fails with
+    # input Tuple(Tensor, Tensor, float). What does scalar mean for torch in 
this case?
+
 
 def test_forward_scatter_reduce():
     """test_forward_scatter_reduce"""
diff --git a/tests/python/relay/test_op_level3.py 
b/tests/python/relay/test_op_level3.py
index f18e935b57..493bf00fc6 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1002,7 +1002,7 @@ def test_scatter(target, dev, executor_kind):
         d = relay.var("d", relay.TensorType(dshape, "float32"))
         i = relay.var("i", relay.TensorType(ishape, indices_dtype))
         u = relay.var("u", relay.TensorType(ishape, "float32"))
-        z = relay.op.scatter(d, i, u, axis)
+        z = relay.op.scatter_elements(d, i, u, axis)
 
         func = relay.Function([d, i, u], z)
 
@@ -1055,7 +1055,7 @@ class TestDynamicScatter:
         d = relay.var("d", relay.TensorType([relay.Any() for i in 
range(len(dshape))], "float32"))
         i = relay.var("i", relay.TensorType([relay.Any() for i in 
range(len(ishape))], "int64"))
         u = relay.var("u", relay.TensorType([relay.Any() for i in 
range(len(ishape))], "float32"))
-        z = relay.op.scatter(d, i, u, axis)
+        z = relay.op.scatter_elements(d, i, u, axis)
 
         func = relay.Function([d, i, u], z)
 

Reply via email to