This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 3912a64 Add operation scatter_add to relay, based on scatter
implementation. (#6030)
3912a64 is described below
commit 3912a64e43412704fcd0c170e94a49dcc6e62f71
Author: notoraptor <[email protected]>
AuthorDate: Wed Jul 15 22:21:28 2020 -0400
Add operation scatter_add to relay, based on scatter implementation. (#6030)
---
include/tvm/relay/attrs/transform.h | 8 ++
python/tvm/relay/op/_transform.py | 9 ++
python/tvm/relay/op/strategy/generic.py | 7 ++
python/tvm/relay/op/transform.py | 24 +++++
src/relay/op/tensor/transform.cc | 49 ++++++++++
tests/python/relay/test_op_level3.py | 45 +++++++++
topi/python/topi/__init__.py | 1 +
topi/python/topi/generic/search.py | 16 ++++
topi/python/topi/scatter_add.py | 165 ++++++++++++++++++++++++++++++++
9 files changed, 324 insertions(+)
diff --git a/include/tvm/relay/attrs/transform.h
b/include/tvm/relay/attrs/transform.h
index b0c8108..eb73427 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -101,6 +101,14 @@ struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
}
};
+/*! \brief Attributes used in the scatter_add operator. */
+struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
+  Integer axis;
+
+  TVM_DECLARE_ATTRS(ScatterAddAttrs, "relay.attrs.ScatterAddAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(0).describe(
+        "The axis along which updates are added into data.");
+  }
+};
+
struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
Integer axis;
diff --git a/python/tvm/relay/op/_transform.py
b/python/tvm/relay/op/_transform.py
index dc12658..a2c374d 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -101,6 +101,14 @@ def compute_scatter(attrs, inputs, output_type):
_reg.register_schedule("scatter", strategy.schedule_scatter)
+# scatter_add
+@_reg.register_compute("scatter_add")
+def compute_scatter_add(attrs, inputs, output_type):
+ """Compute definition of scatter_add"""
+ return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)]
+
+_reg.register_schedule("scatter_add", strategy.schedule_scatter_add)
+
#####################
# Shape functions #
#####################
@@ -396,6 +404,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
return ValueError("Does not support rank higher than 5 in argwhere")
_reg.register_shape_func("scatter", False, elemwise_shape_func)
+_reg.register_shape_func("scatter_add", False, elemwise_shape_func)
@script
def _layout_transform_shape_func(data_shape,
diff --git a/python/tvm/relay/op/strategy/generic.py
b/python/tvm/relay/op/strategy/generic.py
index db0577c..62c2948 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -842,6 +842,13 @@ def schedule_scatter(attrs, outs, target):
with target:
return topi.generic.schedule_scatter(outs)
+# scatter_add
+@generic_func
+def schedule_scatter_add(attrs, outs, target):
+ """schedule scatter_add"""
+ with target:
+ return topi.generic.schedule_scatter_add(outs)
+
# bitserial_conv2d
def wrap_compute_bitserial_conv2d(topi_compute):
"""wrap bitserial_conv2d topi compute"""
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index ae10dd5..6f23af2 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -275,6 +275,30 @@ def scatter(data, indices, updates, axis):
"""
return _make.scatter(data, indices, updates, axis)
+def scatter_add(data, indices, updates, axis):
+ """Update data by adding values in updates at positions defined by indices
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input data to the operator.
+
+ indices : relay.Expr
+ The index locations to update.
+
+ updates : relay.Expr
+ The values to add.
+
+ axis : int
+ The axis to scatter_add on
+
+ Returns
+ -------
+ ret : relay.Expr
+ The computed result.
+ """
+ return _make.scatter_add(data, indices, updates, axis)
+
def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like`
operation reshapes
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index cc1150c..1b07253 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -806,6 +806,55 @@ RELAY_REGISTER_OP("scatter")
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);
+// Scatter_add
+TVM_REGISTER_NODE_TYPE(ScatterAddAttrs);
+
+// Type relation for scatter_add.
+// types = [data, indices, updates, output]. The output takes the shape and
+// dtype of `data`. `indices` must be an integer tensor; `indices` and `updates`
+// must have the same rank as `data`, and `updates` must share the dtype of
+// `data` because its values are accumulated into the output. `axis` must lie in
+// [-ndim, ndim).
+bool ScatterAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 3);
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+  const auto* indices = types[1].as<TensorTypeNode>();
+  if (indices == nullptr) {
+    return false;
+  }
+  const auto* updates = types[2].as<TensorTypeNode>();
+  if (updates == nullptr) {
+    return false;
+  }
+  CHECK(indices->dtype.is_int()) << "indices of scatter_add must be tensor of integer";
+  const auto* param = attrs.as<ScatterAddAttrs>();
+  CHECK(param != nullptr);
+
+  // Reject malformed calls at type inference instead of deep inside lowering.
+  const int64_t ndim = static_cast<int64_t>(data->shape.size());
+  CHECK_EQ(static_cast<int64_t>(indices->shape.size()), ndim)
+      << "indices of scatter_add must have the same rank as data";
+  CHECK_EQ(static_cast<int64_t>(updates->shape.size()), ndim)
+      << "updates of scatter_add must have the same rank as data";
+  CHECK(updates->dtype == data->dtype)
+      << "updates of scatter_add must have the same dtype as data, but got "
+      << updates->dtype << " vs " << data->dtype;
+  const int64_t axis = param->axis;
+  CHECK(-ndim <= axis && axis < ndim)
+      << "axis " << axis << " is out of range for scatter_add on a rank-" << ndim
+      << " tensor";
+
+  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
+  return true;
+}
+
+// FFI constructor behind the Python `_make.scatter_add`: stores `axis` in a
+// ScatterAddAttrs node and builds a call to the scatter_add op.
+TVM_REGISTER_GLOBAL("relay.op._make.scatter_add")
+    .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
+      auto attrs = make_object<ScatterAddAttrs>();
+      attrs->axis = axis;
+      static const Op& op = Op::Get("scatter_add");
+      return Call(op, {data, indices, updates}, Attrs(attrs), {});
+    });
+
+// Registers scatter_add: three tensor inputs, typed by ScatterAddRel, opaque to
+// fusion, support level 10.
+RELAY_REGISTER_OP("scatter_add")
+    .describe(
+        R"doc(Update data by adding values in updates at positions defined by indices)doc" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "The input data tensor.")
+    .add_argument("indices", "Tensor", "The indices location tensor.")
+    .add_argument("updates", "Tensor", "The values to add to the input.")
+    .add_type_rel("ScatterAdd", ScatterAddRel)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_support_level(10);
+
// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);
diff --git a/tests/python/relay/test_op_level3.py
b/tests/python/relay/test_op_level3.py
index 115900f..0445c98 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -811,6 +811,51 @@ def test_scatter():
verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)
+def test_scatter_add():
+
+ def ref_scatter_add(data, indices, updates, axis=0):
+ output = np.copy(data)
+ for index in np.ndindex(*indices.shape):
+ new_index = list(index)
+ new_index[axis] = indices[index]
+ output[tuple(new_index)] += updates[index]
+ return output
+
+ def verify_scatter_add(dshape, ishape, axis=0):
+ d = relay.var("d", relay.TensorType(dshape, "float32"))
+ i = relay.var("i", relay.TensorType(ishape, "int64"))
+ u = relay.var("u", relay.TensorType(ishape, "float32"))
+ z = relay.op.scatter_add(d, i, u, axis)
+
+ func = relay.Function([d, i, u], z)
+
+ data_np = np.random.uniform(size=dshape).astype("float32")
+ updates_np = np.random.uniform(size=ishape).astype("float32")
+ indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1,
ishape).astype("int64")
+
+ ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
+ # TODO(mbrookhart): expand testing when adding more backend schedules
+ for target, ctx in [("llvm", tvm.cpu())]:
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
+ tvm.testing.assert_allclose(
+ op_res.asnumpy(), ref_res, rtol=1e-5)
+
+ verify_scatter_add((10, ), (10, ), 0)
+ verify_scatter_add((10, 5), (10, 5), -2)
+ verify_scatter_add((10, 5), (10, 5), -1)
+ verify_scatter_add((10, 5), (3, 5), 0)
+ verify_scatter_add((12, 4), (7, 2), 1)
+ verify_scatter_add((2, 3, 4), (1, 3, 4), 0)
+ verify_scatter_add((2, 3, 4), (2, 1, 4), 1)
+ verify_scatter_add((2, 3, 4), (2, 3, 1), 2)
+ verify_scatter_add((2, 3, 4, 5), (1, 3, 4, 5), 0)
+ verify_scatter_add((6, 3, 4, 5), (2, 3, 4, 5), 1)
+ verify_scatter_add((2, 3, 8, 5), (2, 3, 1, 1), 2)
+ verify_scatter_add((16, 16, 4, 5), (16, 16, 4, 5), 3)
+
+
def test_gather():
def verify_gather(data, axis, indices, ref_res):
data = np.asarray(data, dtype='float32')
diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py
index 56c3a74..f308aa6 100644
--- a/topi/python/topi/__init__.py
+++ b/topi/python/topi/__init__.py
@@ -40,6 +40,7 @@ from .transform import *
from .broadcast import *
from .sort import *
from .scatter import *
+from .scatter_add import *
from .argwhere import *
from . import generic
from . import nn
diff --git a/topi/python/topi/generic/search.py
b/topi/python/topi/generic/search.py
index 895dadb..b3c8772 100644
--- a/topi/python/topi/generic/search.py
+++ b/topi/python/topi/generic/search.py
@@ -50,3 +50,19 @@ def schedule_scatter(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)
+
+
+def schedule_scatter_add(outs):
+ """Schedule for scatter_add operator.
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of scatter_add.
+
+ Returns
+ -------
+ s: Schedule
+ The computation schedule for the op.
+ """
+ return _default_schedule(outs, False)
diff --git a/topi/python/topi/scatter_add.py b/topi/python/topi/scatter_add.py
new file mode 100644
index 0000000..046972b
--- /dev/null
+++ b/topi/python/topi/scatter_add.py
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""Scatter Add operator"""
+from tvm.te import hybrid
+
+
[email protected]
+def _scatter_add_1d(data, indices, updates):
+ out = output_tensor(data.shape, data.dtype)
+ for i in range(data.shape[0]):
+ out[i] = data[i]
+ for i in range(indices.shape[0]):
+ out[indices[i] if indices[i] >= 0 else indices[i] +
+ data.shape[0]] += updates[i]
+ return out
+
+
[email protected]
+def _scatter_add_2d(data, indices, updates, axis):
+ out = output_tensor(data.shape, data.dtype)
+ for i in const_range(data.shape[0]):
+ for j in const_range(data.shape[1]):
+ out[i, j] = data[i, j]
+ if axis == 0:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ out[indices[i, j] if indices[i, j] >=
+ 0 else indices[i, j] + data.shape[axis], j] += updates[i,
j]
+ else:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ out[i, indices[i, j] if indices[i, j] >=
+ 0 else indices[i, j] + data.shape[axis]] += updates[i, j]
+
+ return out
+
+
[email protected]
+def _scatter_add_3d(data, indices, updates, axis):
+ out = output_tensor(data.shape, data.dtype)
+ for i in const_range(data.shape[0]):
+ for j in const_range(data.shape[1]):
+ for k in const_range(data.shape[2]):
+ out[i, j, k] = data[i, j, k]
+ if axis == 0:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ out[indices[i, j, k] if indices[i, j, k] >=
+ 0 else indices[i, j, k] + data.shape[axis], j, k] +=
updates[i, j, k]
+ elif axis == 1:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ out[i, indices[i, j, k] if indices[i, j, k] >=
+ 0 else indices[i, j, k] + data.shape[axis], k] +=
updates[i, j, k]
+ else:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ out[i, j, indices[i, j, k] if indices[i, j, k] >=
+ 0 else indices[i, j, k] + data.shape[axis]] +=
updates[i, j, k]
+
+ return out
+
+
[email protected]
+def _scatter_add_4d(data, indices, updates, axis):
+ out = output_tensor(data.shape, data.dtype)
+ for i in const_range(data.shape[0]):
+ for j in const_range(data.shape[1]):
+ for k in const_range(data.shape[2]):
+ for l in const_range(data.shape[3]):
+ out[i, j, k, l] = data[i, j, k, l]
+
+ if axis == 0:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ for l in const_range(indices.shape[3]):
+ out[indices[i, j, k, l] if indices[i, j, k, l] >=
+ 0 else indices[i, j, k, l] + data.shape[axis],
+ j, k, l] += updates[i, j, k, l]
+ elif axis == 1:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ for l in const_range(indices.shape[3]):
+ out[i,
+ indices[i, j, k, l] if indices[i, j, k, l] >=
+ 0 else indices[i, j, k, l] + data.shape[axis],
+ k, l] += updates[i, j, k, l]
+ elif axis == 2:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ for l in const_range(indices.shape[3]):
+ out[i, j,
+ indices[i, j, k, l] if indices[i, j, k, l] >=
+ 0 else indices[i, j, k, l] + data.shape[axis],
+ l] += updates[i, j, k, l]
+ else:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ for l in const_range(indices.shape[3]):
+ out[i, j, k,
+ indices[i, j, k, l] if indices[i, j, k, l] >=
+ 0 else indices[i, j, k, l] + data.shape[axis]
+ ] += updates[i, j, k, l]
+
+ return out
+
+
+def scatter_add(data, indices, updates, axis=0):
+ """Update data by adding values in updates at positions defined by indices
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input data to the operator.
+
+ indices : relay.Expr
+ The index locations to update.
+
+ updates : relay.Expr
+ The values to update.
+
+ axis : int
+ The axis to scatter_add on
+
+ Returns
+ -------
+ ret : relay.Expr
+ The computed result.
+ """
+ if axis < 0:
+ axis += len(data.shape)
+ assert axis >= 0
+ assert axis < len(data.shape)
+
+ if len(data.shape) == 1:
+ return _scatter_add_1d(data, indices, updates)
+ if len(data.shape) == 2:
+ return _scatter_add_2d(data, indices, updates, axis)
+ if len(data.shape) == 3:
+ return _scatter_add_3d(data, indices, updates, axis)
+ if len(data.shape) == 4:
+ return _scatter_add_4d(data, indices, updates, axis)
+ raise ValueError("scatter_add only support for 1-4 dimensions")