This is an automated email from the ASF dual-hosted git repository.
tkonolige pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 48842d78e7 [Fix,TOPI] Consolidate generic and x86 scatter nd (#13755)
48842d78e7 is described below
commit 48842d78e7b78dcc75883da37005f68c04c4f9a6
Author: Tristan Konolige <[email protected]>
AuthorDate: Wed Jan 11 12:47:24 2023 -0800
[Fix,TOPI] Consolidate generic and x86 scatter nd (#13755)
The generic scatter_nd implementation was almost identical to the x86 one and
was not tested. The two implementations are now merged into a single one.
---
python/tvm/relay/op/strategy/x86.py | 2 +-
python/tvm/topi/scatter.py | 55 ++++++------
python/tvm/topi/x86/__init__.py | 1 -
python/tvm/topi/x86/scatter.py | 119 --------------------------
tests/python/topi/python/test_topi_scatter.py | 4 -
5 files changed, 26 insertions(+), 155 deletions(-)
diff --git a/python/tvm/relay/op/strategy/x86.py
b/python/tvm/relay/op/strategy/x86.py
index d0ad377203..fa002737a7 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -768,7 +768,7 @@ def scatter_nd_strategy_cpu(attrs, inputs, out_type,
target):
"""scatter_nd x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
- wrap_compute_scatter_nd(topi.x86.scatter_nd),
+ wrap_compute_scatter_nd(topi.scatter_nd),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_nd.x86",
plevel=10,
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index afb0d6633a..e0578aab41 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -16,8 +16,8 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Scatter operator"""
-from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate,
expr
from ..te import extern, hybrid
+from ..tir import decl_buffer, expr, ir_builder
@hybrid.script
@@ -268,6 +268,7 @@ def scatter_nd(data, indices, updates, mode):
_verify_scatter_nd_inputs(data, indices, updates)
def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
+ # pylint: disable=invalid-name
ib = ir_builder.create()
data = ib.buffer_ptr(data_ptr)
@@ -275,56 +276,50 @@ def scatter_nd(data, indices, updates, mode):
updates = ib.buffer_ptr(updates_ptr)
out = ib.buffer_ptr(out_ptr)
- fused_shape = 1
- for i in data.shape:
- fused_shape *= i
- with ib.for_range(0, fused_shape) as i:
- out[i] = data[i]
-
# We combine all the indices dimensions but the first one into a single
# dimension so we can iterate it in single loop instead of an arbitrary
- # number of loops. We do the same thing for all the data dimensions.
+ # number of loops. We do the same thing for all the update dimensions.
fused_indices_dimension = 1
for i in indices_ptr.shape[1:]:
fused_indices_dimension *= i
- fused_data_dimension = 1
- for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
- fused_data_dimension *= i
+ fused_updates_dimension = 1
+ for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+ fused_updates_dimension *= i
+
+ fused_shape = 1
+ for i in data_ptr.shape:
+ fused_shape *= i
+
+ with ib.for_range(0, fused_shape) as i:
+ out[i] = data[i]
- with ib.for_range(0, fused_indices_dimension, name="i") as i:
- with ib.for_range(0, fused_data_dimension, name="j") as j:
- offset = fused_data_dimension
+ with ib.for_range(0, fused_indices_dimension) as i:
+ with ib.for_range(0, fused_updates_dimension, kind="parallel") as
j:
+ offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into
out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1,
y_0, .. y_{K-1}] part
# of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i * l * fused_indices_dimension] = indices[l,
y_0, ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
- ib.emit(
- AssertStmt(
- indices[i + l * fused_indices_dimension] <
shape[l],
- StringImm("index out of bounds"),
- Evaluate(0),
- )
- )
- offset *= shape[l]
- if mode == "add":
- out[index] += updates[i * fused_data_dimension + j]
- elif mode == "update":
- out[index] = updates[i * fused_data_dimension + j]
+ offset *= data_ptr.shape[l]
+ if mode == "update":
+ out[index] = updates[i * fused_updates_dimension + j]
+ elif mode == "add":
+ out[index] += updates[i * fused_updates_dimension + j]
else:
raise NotImplementedError("scatter_nd mode not in [update,
add]:", mode)
return ib.get()
- out_buf = decl_buffer(shape, data.dtype, "out_buf")
+ out_buf = decl_buffer(data.shape, data.dtype, "out_buf")
return extern(
- [shape],
+ [data.shape],
[data, indices, updates],
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
- name="scatter_nd_generic",
- tag="scatter_nd_generic",
+ name="scatter_nd.generic",
+ tag="scatter_nd.generic",
)
diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py
index d075090f01..a54b156380 100644
--- a/python/tvm/topi/x86/__init__.py
+++ b/python/tvm/topi/x86/__init__.py
@@ -40,7 +40,6 @@ from .conv3d_transpose import *
from .sparse import *
from .conv2d_alter_op import *
from .dense_alter_op import *
-from .scatter import *
from .group_conv2d import *
from .math_alter_op import *
from .concat import *
diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py
deleted file mode 100644
index 5eb5e6e99b..0000000000
--- a/python/tvm/topi/x86/scatter.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Scatter operators for x86"""
-import tvm
-from tvm import te
-from ..scatter import _verify_scatter_nd_inputs
-
-
-def scatter_nd(data, indices, updates, mode):
- """Scatter elements from a n-dimension array.
-
- Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices
with shape
- (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1,
..., X_{N-1}),
- scatter_nd computes
-
- .. code-block::
-
- output[indices[0, y_0, ..., y_{K-1}],
- ...,
- indices[M-1, y_0, ..., y_{K-1}],
- x_M,
- ...,
- x_{N-1}
- ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
-
- where the update function f is determinted by the mode.
-
- Parameters
- ----------
- data : tvm.te.Tensor
- The source array.
-
- indices : tvm.te.Tensor
- The indices of the values to extract.
-
- updates : tvm.te.Tensor
- The updates to apply at the Indices
-
- mode : string
- The update mode for the algorithm, either "update" or "add"
- If update, the update values will replace the input data
- If add, the update values will be added to the input data
-
- Returns
- -------
- ret : tvm.te.Tensor
- """
- _verify_scatter_nd_inputs(data, indices, updates)
-
- def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
- # pylint: disable=invalid-name
- ib = tvm.tir.ir_builder.create()
-
- data = ib.buffer_ptr(data_ptr)
- indices = ib.buffer_ptr(indices_ptr)
- updates = ib.buffer_ptr(updates_ptr)
- out = ib.buffer_ptr(out_ptr)
-
- # We combine all the indices dimensions but the first one into a single
- # dimension so we can iterate it in single loop instead of an arbitrary
- # number of loops. We do the same thing for all the update dimensions.
- fused_indices_dimension = 1
- for i in indices_ptr.shape[1:]:
- fused_indices_dimension *= i
-
- fused_updates_dimension = 1
- for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
- fused_updates_dimension *= i
-
- fused_shape = 1
- for i in data_ptr.shape:
- fused_shape *= i
-
- with ib.for_range(0, fused_shape) as i:
- out[i] = data[i]
-
- with ib.for_range(0, fused_indices_dimension) as i:
- with ib.for_range(0, fused_updates_dimension, kind="parallel") as
j:
- offset = fused_updates_dimension
- index = j # This is x_M, .. x_{N-1} part of the index into
out.
- # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1,
y_0, .. y_{K-1}] part
- # of the index into out.
- for l in reversed(range(indices_ptr.shape[0].value)):
- # indices[i * l * fused_indices_dimension] = indices[l,
y_0, ... y_{k-1}]
- index += offset * indices[i + l * fused_indices_dimension]
- offset *= data_ptr.shape[l]
- if mode == "update":
- out[index] = updates[i * fused_updates_dimension + j]
- elif mode == "add":
- out[index] += updates[i * fused_updates_dimension + j]
- else:
- raise NotImplementedError("scatter_nd mode not in [update,
add]:", mode)
-
- return ib.get()
-
- out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf")
- return te.extern(
- [data.shape],
- [data, indices, updates],
- lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
- dtype=data.dtype,
- out_buffers=[out_buf],
- name="scatter_nd_x86",
- tag="scatter_nd_x86",
- )
diff --git a/tests/python/topi/python/test_topi_scatter.py
b/tests/python/topi/python/test_topi_scatter.py
index 648ef62a04..025e44889d 100644
--- a/tests/python/topi/python/test_topi_scatter.py
+++ b/tests/python/topi/python/test_topi_scatter.py
@@ -33,10 +33,6 @@ def test_scatter_nd(dev, target):
lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode),
topi.generic.schedule_extern,
),
- "cpu": (
- lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode),
- topi.generic.schedule_extern,
- ),
}
fcompute, fschedule = tvm.topi.testing.dispatch(target,
implementations)
tvm.topi.testing.compare_numpy_tvm(