This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 3912a64  Add operation scatter_add to relay, based on scatter implementation. (#6030)
3912a64 is described below

commit 3912a64e43412704fcd0c170e94a49dcc6e62f71
Author: notoraptor <notorap...@users.noreply.github.com>
AuthorDate: Wed Jul 15 22:21:28 2020 -0400

    Add operation scatter_add to relay, based on scatter implementation. (#6030)
---
 include/tvm/relay/attrs/transform.h     |   8 ++
 python/tvm/relay/op/_transform.py       |   9 ++
 python/tvm/relay/op/strategy/generic.py |   7 ++
 python/tvm/relay/op/transform.py        |  24 +++++
 src/relay/op/tensor/transform.cc        |  49 ++++++++++
 tests/python/relay/test_op_level3.py    |  45 +++++++++
 topi/python/topi/__init__.py            |   1 +
 topi/python/topi/generic/search.py      |  16 ++++
 topi/python/topi/scatter_add.py         | 165 ++++++++++++++++++++++++++++++++
 9 files changed, 324 insertions(+)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index b0c8108..eb73427 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -101,6 +101,14 @@ struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
   }
 };
 
+/*! \brief Attributes used in the scatter_add operator. */
+struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
+  Integer axis;
+
+  TVM_DECLARE_ATTRS(ScatterAddAttrs, "relay.attrs.ScatterAddAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(0).describe(
+        "The axis along which updates are scatter-added into data.");
+  }
+};
+
 struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
   Integer axis;
 
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index dc12658..a2c374d 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -101,6 +101,14 @@ def compute_scatter(attrs, inputs, output_type):
 
 _reg.register_schedule("scatter", strategy.schedule_scatter)
 
+# scatter_add
+@_reg.register_compute("scatter_add")
+def compute_scatter_add(attrs, inputs, output_type):
+    """Compute definition of scatter_add.
+
+    Takes (data, indices, updates) from ``inputs`` and forwards them, together
+    with ``attrs.axis``, to ``topi.scatter_add``; the single result tensor is
+    returned wrapped in a list.
+    """
+    data, indices, updates = inputs[0], inputs[1], inputs[2]
+    result = topi.scatter_add(data, indices, updates, attrs.axis)
+    return [result]
+
+_reg.register_schedule("scatter_add", strategy.schedule_scatter_add)
+
 #####################
 #  Shape functions  #
 #####################
@@ -396,6 +404,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
     return ValueError("Does not support rank higher than 5 in argwhere")
 
 _reg.register_shape_func("scatter", False, elemwise_shape_func)
+# Same shape function registration as "scatter"; the scatter_add type relation
+# likewise gives the output the shape of `data`.
+_reg.register_shape_func("scatter_add", False, elemwise_shape_func)
 
 @script
 def _layout_transform_shape_func(data_shape,
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index db0577c..62c2948 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -842,6 +842,13 @@ def schedule_scatter(attrs, outs, target):
     with target:
         return topi.generic.schedule_scatter(outs)
 
+# scatter_add
+@generic_func
+def schedule_scatter_add(attrs, outs, target):
+    """Generic scatter_add schedule.
+
+    Builds ``topi.generic.schedule_scatter_add(outs)`` while ``target`` is the
+    active target and returns the resulting schedule.
+    """
+    with target:
+        sched = topi.generic.schedule_scatter_add(outs)
+    return sched
+
 # bitserial_conv2d
 def wrap_compute_bitserial_conv2d(topi_compute):
     """wrap bitserial_conv2d topi compute"""
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index ae10dd5..6f23af2 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -275,6 +275,30 @@ def scatter(data, indices, updates, axis):
     """
     return _make.scatter(data, indices, updates, axis)
 
+def scatter_add(data, indices, updates, axis=0):
+    """Update data by adding values in updates at positions defined by indices.
+
+    For every position ``p`` of ``indices``, ``updates[p]`` is added to the
+    element of ``data`` whose coordinates equal ``p`` except that coordinate
+    ``axis`` is replaced by ``indices[p]``. Repeated indices accumulate, and
+    negative indices count back from the end of ``axis``.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    indices : relay.Expr
+        The index locations to update. Must be an integer tensor with the
+        same rank as ``data``.
+
+    updates : relay.Expr
+        The values to add. Expected to have the same shape as ``indices``.
+
+    axis : int, optional
+        The axis to scatter_add on. Negative values count from the last
+        dimension. Defaults to 0, the same default as ``ScatterAddAttrs``.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result, with the same shape and dtype as ``data``.
+    """
+    return _make.scatter_add(data, indices, updates, axis)
+
 def reshape_like(data, shape_like):
     """Reshapes the input array by the size of another array.
     For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` 
operation reshapes
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index cc1150c..1b07253 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -806,6 +806,55 @@ RELAY_REGISTER_OP("scatter")
     .set_attr<TOpPattern>("TOpPattern", kOpaque)
     .set_support_level(10);
 
+// ScatterAdd
+TVM_REGISTER_NODE_TYPE(ScatterAddAttrs);
+
+/*!
+ * \brief Type relation for scatter_add.
+ *
+ * types = [data, indices, updates, result]. Requires an integer `indices` tensor
+ * with the same rank as `data`, and `updates` with the same rank as `indices`.
+ * The result has the shape and dtype of `data`.
+ */
+bool ScatterAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 3);
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+  const auto* indices = types[1].as<TensorTypeNode>();
+  if (indices == nullptr) {
+    return false;
+  }
+  const auto* updates = types[2].as<TensorTypeNode>();
+  if (updates == nullptr) {
+    return false;
+  }
+  CHECK(indices->dtype.is_int()) << "indices of scatter_add must be tensor of integer";
+  // The topi scatter_add kernels index `indices` and `updates` with as many
+  // coordinates as `data` has dimensions, so reject rank mismatches here.
+  CHECK_EQ(indices->shape.size(), data->shape.size())
+      << "indices of scatter_add must have the same rank as data";
+  CHECK_EQ(updates->shape.size(), indices->shape.size())
+      << "updates of scatter_add must have the same rank as indices";
+  const auto* param = attrs.as<ScatterAddAttrs>();
+  CHECK(param != nullptr);
+  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
+  return true;
+}
+
+// FFI constructor behind the Python `relay.op.scatter_add` API: packs `axis`
+// into ScatterAddAttrs and builds a call to the registered scatter_add op.
+TVM_REGISTER_GLOBAL("relay.op._make.scatter_add")
+    .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
+      auto attrs = make_object<ScatterAddAttrs>();
+      attrs->axis = axis;
+      static const Op& op = Op::Get("scatter_add");
+      return Call(op, {data, indices, updates}, Attrs(attrs), {});
+    });
+
+// Registers scatter_add: three tensor inputs, type-checked by ScatterAddRel.
+RELAY_REGISTER_OP("scatter_add")
+    .describe(
+        R"doc(Update data by adding values in updates at positions defined by indices)doc" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "The input data tensor.")
+    .add_argument("indices", "Tensor", "The indices location tensor.")
+    .add_argument("updates", "Tensor", "The values to add to the input data.")
+    .add_type_rel("ScatterAdd", ScatterAddRel)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_support_level(10);
+
 // Take
 TVM_REGISTER_NODE_TYPE(TakeAttrs);
 
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 115900f..0445c98 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -811,6 +811,51 @@ def test_scatter():
     verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)
 
 
+def test_scatter_add():
+
+    def ref_scatter_add(data, indices, updates, axis=0):
+        """NumPy reference: add updates[p] into data at p with p[axis] = indices[p]."""
+        output = np.copy(data)
+        for index in np.ndindex(*indices.shape):
+            new_index = list(index)
+            new_index[axis] = indices[index]
+            output[tuple(new_index)] += updates[index]
+        return output
+
+    def verify_scatter_add(dshape, ishape, axis=0):
+        d = relay.var("d", relay.TensorType(dshape, "float32"))
+        i = relay.var("i", relay.TensorType(ishape, "int64"))
+        u = relay.var("u", relay.TensorType(ishape, "float32"))
+        z = relay.op.scatter_add(d, i, u, axis)
+
+        func = relay.Function([d, i, u], z)
+
+        data_np = np.random.uniform(size=dshape).astype("float32")
+        updates_np = np.random.uniform(size=ishape).astype("float32")
+        # np.random.randint samples the half-open interval [low, high), so
+        # high=dshape[axis] makes every valid index reachable: the negative
+        # (wrapped) ones and the last element dshape[axis] - 1.
+        indices_np = np.random.randint(-dshape[axis], dshape[axis], ishape).astype("int64")
+
+        ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
+        # TODO(mbrookhart): expand testing when adding more backend schedules
+        for target, ctx in [("llvm", tvm.cpu())]:
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res, rtol=1e-5)
+
+    verify_scatter_add((10, ), (10, ), 0)
+    verify_scatter_add((10, 5), (10, 5), -2)
+    verify_scatter_add((10, 5), (10, 5), -1)
+    verify_scatter_add((10, 5), (3, 5), 0)
+    verify_scatter_add((12, 4), (7, 2), 1)
+    verify_scatter_add((2, 3, 4), (1, 3, 4), 0)
+    verify_scatter_add((2, 3, 4), (2, 1, 4), 1)
+    verify_scatter_add((2, 3, 4), (2, 3, 1), 2)
+    verify_scatter_add((2, 3, 4, 5), (1, 3, 4, 5), 0)
+    verify_scatter_add((6, 3, 4, 5), (2, 3, 4, 5), 1)
+    verify_scatter_add((2, 3, 8, 5), (2, 3, 1, 1), 2)
+    verify_scatter_add((16, 16, 4, 5), (16, 16, 4, 5), 3)
+
 def test_gather():
     def verify_gather(data, axis, indices, ref_res):
         data = np.asarray(data, dtype='float32')
diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py
index 56c3a74..f308aa6 100644
--- a/topi/python/topi/__init__.py
+++ b/topi/python/topi/__init__.py
@@ -40,6 +40,7 @@ from .transform import *
 from .broadcast import *
 from .sort import *
 from .scatter import *
+from .scatter_add import *
 from .argwhere import *
 from . import generic
 from . import nn
diff --git a/topi/python/topi/generic/search.py b/topi/python/topi/generic/search.py
index 895dadb..b3c8772 100644
--- a/topi/python/topi/generic/search.py
+++ b/topi/python/topi/generic/search.py
@@ -50,3 +50,19 @@ def schedule_scatter(outs):
       The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+def schedule_scatter_add(outs):
+    """Build the schedule for the scatter_add operator.
+
+    Uses the generic default schedule, configured exactly as for scatter.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+      The computation graph description of scatter_add.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    sched = _default_schedule(outs, False)
+    return sched
diff --git a/topi/python/topi/scatter_add.py b/topi/python/topi/scatter_add.py
new file mode 100644
index 0000000..046972b
--- /dev/null
+++ b/topi/python/topi/scatter_add.py
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""Scatter Add operator"""
+from tvm.te import hybrid
+
+
+@hybrid.script
+def _scatter_add_1d(data, indices, updates):
+    # Hybrid-script kernel for 1-D data: copy `data` into `out`, then add each
+    # updates[i] at position indices[i]. A negative index is wrapped by adding
+    # data.shape[0]; repeated indices accumulate because of `+=`.
+    # NOTE(review): indices are not bounds-checked here.
+    out = output_tensor(data.shape, data.dtype)
+    for i in range(data.shape[0]):
+        out[i] = data[i]
+    for i in range(indices.shape[0]):
+        out[indices[i] if indices[i] >= 0 else indices[i] +
+            data.shape[0]] += updates[i]
+    return out
+
+
+@hybrid.script
+def _scatter_add_2d(data, indices, updates, axis):
+    # Hybrid-script kernel for 2-D data: copy `data` into `out`, then for every
+    # position (i, j) of `indices` add updates[i, j] into `out` at (i, j) with
+    # coordinate `axis` replaced by indices[i, j]. Negative indices wrap by
+    # data.shape[axis]; repeated indices accumulate.
+    # All loops are serial `range` loops: `const_range` loops are expanded at
+    # compile time and need constant extents, which bloats the generated code
+    # for large tensors and rules out symbolic shapes.
+    out = output_tensor(data.shape, data.dtype)
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            out[i, j] = data[i, j]
+    if axis == 0:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                out[indices[i, j] if indices[i, j] >=
+                    0 else indices[i, j] + data.shape[axis], j] += updates[i, j]
+    else:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                out[i, indices[i, j] if indices[i, j] >=
+                    0 else indices[i, j] + data.shape[axis]] += updates[i, j]
+
+    return out
+
+
+@hybrid.script
+def _scatter_add_3d(data, indices, updates, axis):
+    # Hybrid-script kernel for 3-D data: copy `data` into `out`, then for every
+    # position (i, j, k) of `indices` add updates[i, j, k] into `out` at
+    # (i, j, k) with coordinate `axis` replaced by indices[i, j, k]. Negative
+    # indices wrap by data.shape[axis]; repeated indices accumulate.
+    # All loops are serial `range` loops: `const_range` loops are expanded at
+    # compile time and need constant extents, which bloats the generated code
+    # for large tensors and rules out symbolic shapes.
+    out = output_tensor(data.shape, data.dtype)
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            for k in range(data.shape[2]):
+                out[i, j, k] = data[i, j, k]
+    if axis == 0:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in range(indices.shape[2]):
+                    out[indices[i, j, k] if indices[i, j, k] >=
+                        0 else indices[i, j, k] + data.shape[axis],
+                        j, k] += updates[i, j, k]
+    elif axis == 1:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in range(indices.shape[2]):
+                    out[i,
+                        indices[i, j, k] if indices[i, j, k] >=
+                        0 else indices[i, j, k] + data.shape[axis],
+                        k] += updates[i, j, k]
+    else:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in range(indices.shape[2]):
+                    out[i, j,
+                        indices[i, j, k] if indices[i, j, k] >=
+                        0 else indices[i, j, k] + data.shape[axis]
+                        ] += updates[i, j, k]
+
+    return out
+
+
+@hybrid.script
+def _scatter_add_4d(data, indices, updates, axis):
+    # Hybrid-script kernel for 4-D data: copy `data` into `out`, then for every
+    # position (i, j, k, l) of `indices` add updates[i, j, k, l] into `out` at
+    # (i, j, k, l) with coordinate `axis` replaced by indices[i, j, k, l].
+    # Negative indices wrap by data.shape[axis]; repeated indices accumulate.
+    # All loops are serial `range` loops: `const_range` loops are expanded at
+    # compile time and need constant extents, which bloats the generated code
+    # for large tensors and rules out symbolic shapes.
+    out = output_tensor(data.shape, data.dtype)
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            for k in range(data.shape[2]):
+                for l in range(data.shape[3]):
+                    out[i, j, k, l] = data[i, j, k, l]
+
+    if axis == 0:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in range(indices.shape[2]):
+                    for l in range(indices.shape[3]):
+                        out[indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis],
+                            j, k, l] += updates[i, j, k, l]
+    elif axis == 1:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in range(indices.shape[2]):
+                    for l in range(indices.shape[3]):
+                        out[i,
+                            indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis],
+                            k, l] += updates[i, j, k, l]
+    elif axis == 2:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in range(indices.shape[2]):
+                    for l in range(indices.shape[3]):
+                        out[i, j,
+                            indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis],
+                            l] += updates[i, j, k, l]
+    else:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in range(indices.shape[2]):
+                    for l in range(indices.shape[3]):
+                        out[i, j, k,
+                            indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis]
+                            ] += updates[i, j, k, l]
+
+    return out
+
+
+def scatter_add(data, indices, updates, axis=0):
+    """Update data by adding values in updates at positions defined by indices.
+
+    Dispatches on the rank of ``data`` to a hybrid-script kernel; ranks 1 to 4
+    are supported.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator.
+
+    indices : tvm.te.Tensor
+        The index locations to update; negative entries count from the end
+        of ``axis``.
+
+    updates : tvm.te.Tensor
+        The values to add.
+
+    axis : int
+        The axis to scatter_add on; negative values count from the last
+        dimension.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The computed result.
+    """
+    ndim = len(data.shape)
+    if axis < 0:
+        axis += ndim
+    assert 0 <= axis < ndim
+
+    if ndim == 1:
+        return _scatter_add_1d(data, indices, updates)
+    kernels = {2: _scatter_add_2d, 3: _scatter_add_3d, 4: _scatter_add_4d}
+    if ndim not in kernels:
+        raise ValueError("scatter_add only support for 1-4 dimensions")
+    return kernels[ndim](data, indices, updates, axis)

Reply via email to