This is an automated email from the ASF dual-hosted git repository.

tkonolige pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 48842d78e7 [Fix,TOPI] Consolidate generic and x86 scatter nd (#13755)
48842d78e7 is described below

commit 48842d78e7b78dcc75883da37005f68c04c4f9a6
Author: Tristan Konolige <[email protected]>
AuthorDate: Wed Jan 11 12:47:24 2023 -0800

    [Fix,TOPI] Consolidate generic and x86 scatter nd (#13755)
    
    The generic scatter nd was almost identical to the x86 one and was not
    tested. They are now one and the same.
---
 python/tvm/relay/op/strategy/x86.py           |   2 +-
 python/tvm/topi/scatter.py                    |  55 ++++++------
 python/tvm/topi/x86/__init__.py               |   1 -
 python/tvm/topi/x86/scatter.py                | 119 --------------------------
 tests/python/topi/python/test_topi_scatter.py |   4 -
 5 files changed, 26 insertions(+), 155 deletions(-)

diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index d0ad377203..fa002737a7 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -768,7 +768,7 @@ def scatter_nd_strategy_cpu(attrs, inputs, out_type, target):
     """scatter_nd x86 strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_scatter_nd(topi.x86.scatter_nd),
+        wrap_compute_scatter_nd(topi.scatter_nd),
         wrap_topi_schedule(topi.generic.schedule_extern),
         name="scatter_nd.x86",
         plevel=10,
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index afb0d6633a..e0578aab41 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -16,8 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, 
expr
 from ..te import extern, hybrid
+from ..tir import decl_buffer, expr, ir_builder
 
 
 @hybrid.script
@@ -268,6 +268,7 @@ def scatter_nd(data, indices, updates, mode):
     _verify_scatter_nd_inputs(data, indices, updates)
 
     def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
+        # pylint: disable=invalid-name
         ib = ir_builder.create()
 
         data = ib.buffer_ptr(data_ptr)
@@ -275,56 +276,50 @@ def scatter_nd(data, indices, updates, mode):
         updates = ib.buffer_ptr(updates_ptr)
         out = ib.buffer_ptr(out_ptr)
 
-        fused_shape = 1
-        for i in data.shape:
-            fused_shape *= i
-        with ib.for_range(0, fused_shape) as i:
-            out[i] = data[i]
-
         # We combine all the indices dimensions but the first one into a single
         # dimension so we can iterate it in single loop instead of an arbitrary
-        # number of loops. We do the same thing for all the data dimensions.
+        # number of loops. We do the same thing for all the update dimensions.
         fused_indices_dimension = 1
         for i in indices_ptr.shape[1:]:
             fused_indices_dimension *= i
 
-        fused_data_dimension = 1
-        for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
-            fused_data_dimension *= i
+        fused_updates_dimension = 1
+        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_updates_dimension *= i
+
+        fused_shape = 1
+        for i in data_ptr.shape:
+            fused_shape *= i
+
+        with ib.for_range(0, fused_shape) as i:
+            out[i] = data[i]
 
-        with ib.for_range(0, fused_indices_dimension, name="i") as i:
-            with ib.for_range(0, fused_data_dimension, name="j") as j:
-                offset = fused_data_dimension
+        with ib.for_range(0, fused_indices_dimension) as i:
+            with ib.for_range(0, fused_updates_dimension, kind="parallel") as 
j:
+                offset = fused_updates_dimension
                 index = j  # This is x_M, .. x_{N-1} part of the index into 
out.
                 # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, 
y_0, .. y_{K-1}] part
                 # of the index into out.
                 for l in reversed(range(indices_ptr.shape[0].value)):
                     # indices[i * l * fused_indices_dimension] = indices[l, 
y_0, ... y_{k-1}]
                     index += offset * indices[i + l * fused_indices_dimension]
-                    ib.emit(
-                        AssertStmt(
-                            indices[i + l * fused_indices_dimension] < 
shape[l],
-                            StringImm("index out of bounds"),
-                            Evaluate(0),
-                        )
-                    )
-                    offset *= shape[l]
-                if mode == "add":
-                    out[index] += updates[i * fused_data_dimension + j]
-                elif mode == "update":
-                    out[index] = updates[i * fused_data_dimension + j]
+                    offset *= data_ptr.shape[l]
+                if mode == "update":
+                    out[index] = updates[i * fused_updates_dimension + j]
+                elif mode == "add":
+                    out[index] += updates[i * fused_updates_dimension + j]
                 else:
                     raise NotImplementedError("scatter_nd mode not in [update, 
add]:", mode)
 
         return ib.get()
 
-    out_buf = decl_buffer(shape, data.dtype, "out_buf")
+    out_buf = decl_buffer(data.shape, data.dtype, "out_buf")
     return extern(
-        [shape],
+        [data.shape],
         [data, indices, updates],
         lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
         dtype=data.dtype,
         out_buffers=[out_buf],
-        name="scatter_nd_generic",
-        tag="scatter_nd_generic",
+        name="scatter_nd.generic",
+        tag="scatter_nd.generic",
     )
diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py
index d075090f01..a54b156380 100644
--- a/python/tvm/topi/x86/__init__.py
+++ b/python/tvm/topi/x86/__init__.py
@@ -40,7 +40,6 @@ from .conv3d_transpose import *
 from .sparse import *
 from .conv2d_alter_op import *
 from .dense_alter_op import *
-from .scatter import *
 from .group_conv2d import *
 from .math_alter_op import *
 from .concat import *
diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py
deleted file mode 100644
index 5eb5e6e99b..0000000000
--- a/python/tvm/topi/x86/scatter.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Scatter operators for x86"""
-import tvm
-from tvm import te
-from ..scatter import _verify_scatter_nd_inputs
-
-
-def scatter_nd(data, indices, updates, mode):
-    """Scatter elements from a n-dimension array.
-
-    Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices 
with shape
-    (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, 
..., X_{N-1}),
-    scatter_nd computes
-
-    .. code-block::
-
-        output[indices[0, y_0, ..., y_{K-1}],
-               ...,
-               indices[M-1, y_0, ..., y_{K-1}],
-               x_M,
-               ...,
-               x_{N-1}
-              ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
-
-    where the update function f is determinted by the mode.
-
-    Parameters
-    ----------
-    data : tvm.te.Tensor
-        The source array.
-
-    indices : tvm.te.Tensor
-        The indices of the values to extract.
-
-    updates : tvm.te.Tensor
-        The updates to apply at the Indices
-
-    mode : string
-        The update mode for the algorithm, either "update" or "add"
-        If update, the update values will replace the input data
-        If add, the update values will be added to the input data
-
-    Returns
-    -------
-    ret : tvm.te.Tensor
-    """
-    _verify_scatter_nd_inputs(data, indices, updates)
-
-    def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
-        # pylint: disable=invalid-name
-        ib = tvm.tir.ir_builder.create()
-
-        data = ib.buffer_ptr(data_ptr)
-        indices = ib.buffer_ptr(indices_ptr)
-        updates = ib.buffer_ptr(updates_ptr)
-        out = ib.buffer_ptr(out_ptr)
-
-        # We combine all the indices dimensions but the first one into a single
-        # dimension so we can iterate it in single loop instead of an arbitrary
-        # number of loops. We do the same thing for all the update dimensions.
-        fused_indices_dimension = 1
-        for i in indices_ptr.shape[1:]:
-            fused_indices_dimension *= i
-
-        fused_updates_dimension = 1
-        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
-            fused_updates_dimension *= i
-
-        fused_shape = 1
-        for i in data_ptr.shape:
-            fused_shape *= i
-
-        with ib.for_range(0, fused_shape) as i:
-            out[i] = data[i]
-
-        with ib.for_range(0, fused_indices_dimension) as i:
-            with ib.for_range(0, fused_updates_dimension, kind="parallel") as 
j:
-                offset = fused_updates_dimension
-                index = j  # This is x_M, .. x_{N-1} part of the index into 
out.
-                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, 
y_0, .. y_{K-1}] part
-                # of the index into out.
-                for l in reversed(range(indices_ptr.shape[0].value)):
-                    # indices[i * l * fused_indices_dimension] = indices[l, 
y_0, ... y_{k-1}]
-                    index += offset * indices[i + l * fused_indices_dimension]
-                    offset *= data_ptr.shape[l]
-                if mode == "update":
-                    out[index] = updates[i * fused_updates_dimension + j]
-                elif mode == "add":
-                    out[index] += updates[i * fused_updates_dimension + j]
-                else:
-                    raise NotImplementedError("scatter_nd mode not in [update, 
add]:", mode)
-
-        return ib.get()
-
-    out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf")
-    return te.extern(
-        [data.shape],
-        [data, indices, updates],
-        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
-        dtype=data.dtype,
-        out_buffers=[out_buf],
-        name="scatter_nd_x86",
-        tag="scatter_nd_x86",
-    )
diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py
index 648ef62a04..025e44889d 100644
--- a/tests/python/topi/python/test_topi_scatter.py
+++ b/tests/python/topi/python/test_topi_scatter.py
@@ -33,10 +33,6 @@ def test_scatter_nd(dev, target):
                 lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode),
                 topi.generic.schedule_extern,
             ),
-            "cpu": (
-                lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode),
-                topi.generic.schedule_extern,
-            ),
         }
         fcompute, fschedule = tvm.topi.testing.dispatch(target, 
implementations)
         tvm.topi.testing.compare_numpy_tvm(

Reply via email to