This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 59a2256959 [Relax] Add gather_elements and gather_nd operators (#17523)
59a2256959 is described below

commit 59a2256959a0ea35c4380d98bb983bf4f1e9d856
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Nov 13 23:29:03 2024 +0800

    [Relax] Add gather_elements and gather_nd operators (#17523)
    
    Add gather_elements and gather_nd operators to Relax and corresponding
    ONNX frontend.
---
 include/tvm/relax/attrs/manipulate.h               |  17 +++
 python/tvm/relax/frontend/onnx/onnx_frontend.py    |  22 ++-
 python/tvm/relax/op/__init__.py                    |   2 +
 python/tvm/relax/op/manipulate.py                  |  73 +++++++++
 .../tvm/relax/transform/legalize_ops/manipulate.py |  16 ++
 python/tvm/script/ir_builder/relax/ir.py           |   4 +
 python/tvm/topi/transform.py                       |   4 +-
 src/relax/op/tensor/manipulate.cc                  | 163 +++++++++++++++++++++
 src/relax/op/tensor/manipulate.h                   |  26 ++++
 src/topi/transform.cc                              |   3 +-
 tests/python/relax/test_frontend_onnx.py           |  62 ++++++++
 tests/python/relax/test_op_manipulate.py           | 127 ++++++++++++++++
 12 files changed, 514 insertions(+), 5 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index ea41488354..e6c16d233a 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -152,6 +152,23 @@ struct FlipAttrs : public tvm::AttrsNode<FlipAttrs> {
   }
 };  // struct FlipAttrs
 
+/*! \brief Attributes used in gather_elements operators */
+struct GatherElementsAttrs : public tvm::AttrsNode<GatherElementsAttrs> {
+  Integer axis;
+
+  TVM_DECLARE_ATTRS(GatherElementsAttrs, "relax.attrs.GatherElementsAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(0).describe("The axis along which to index.");
+  }
+};  // struct GatherElementsAttrs
+
+/*! \brief Attributes used in gather_nd operators */
+struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
+  Integer batch_dims;
+  TVM_DECLARE_ATTRS(GatherNDAttrs, "relax.attrs.GatherNDAttrs") {
+    TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dims.");
+  }
+};  // struct GatherNDAttrs
+
 /*! \brief Attributes used in scatter_elements operators */
 struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
   Integer axis;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 94ccfdb23e..b64e87822a 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -781,6 +781,24 @@ class Gather(OnnxOpConverter):
         return relax.op.take(data, indices, axis)
 
 
+class GatherElements(OnnxOpConverter):
+    """Convert an onnx GatherElements node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 0)
+        return relax.op.gather_elements(inputs[0], inputs[1], axis)
+
+
+class GatherND(OnnxOpConverter):
+    """Convert an onnx GatherND node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        batch_dims = attr.get("batch_dims", 0)
+        return relax.op.gather_nd(inputs[0], inputs[1], batch_dims)
+
+
 class Scatter(OnnxOpConverter):
     """Convert an onnx Scatter node into an equivalent Relax expression."""
 
@@ -3116,8 +3134,8 @@ def _get_convert_map():
         "Squeeze": Squeeze,
         "Constant": Constant,
         "Gather": Gather,
-        # "GatherElements": GatherElements,
-        # "GatherND": GatherND,
+        "GatherElements": GatherElements,
+        "GatherND": GatherND,
         "Scatter": Scatter,
         "ScatterElements": ScatterElements,
         "ScatterND": ScatterND,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 1603ea2f0f..97f18a2396 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -92,6 +92,8 @@ from .manipulate import (
     expand_dims,
     flatten,
     flip,
+    gather_elements,
+    gather_nd,
     layout_transform,
     one_hot,
     permute_dims,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 3210cc8216..0f6e537ab3 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -435,6 +435,79 @@ def flip(data, axis):
     return _ffi_api.flip(data, axis)  # type: ignore
 
 
+def gather_elements(data: Expr, indices: Expr, axis: int = 0) -> Expr:
+    """Gather elements from data according to indices along the specified axis.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    indices : relax.Expr
+        The indices tensor, must have integer type and the same rank as data.
+
+    axis : int
+        The axis along which to index. Default is 0.
+
+    Returns
+    -------
+    ret : relax.Expr
+        The computed result, which has the same shape as indices.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        data = [[1, 2], [3, 4]]
+        indices = [[0, 0], [1, 0]]
+        axis = 1
+        output = [[1, 1], [4, 3]]
+
+        data = [[1, 2, 3], [4, 5, 6]]
+        indices = [[1, 1, 1]]
+        axis = 0
+        output = [[4, 5, 6]]
+    """
+    return _ffi_api.gather_elements(data, indices, axis)  # type: ignore
+
+
+def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr:
+    """Gather values from data using N-dimensional indices.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    indices : relax.Expr
+        The indices tensor, must have int64 type.
+
+    batch_dims : int
+        The number of batch dimensions. Default is 0.
+
+    Returns
+    -------
+    ret : relax.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        batch_dims = 0
+        data    = [[0,1],[2,3]]   # data_shape    = [2, 2]
+        indices = [[0,0],[1,1]]   # indices_shape = [2, 2]
+        output  = [0,3]           # output_shape  = [2]
+
+        batch_dims = 1
+        data    = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape    = [2, 2, 2]
+        indices = [[1],[0]]                     # indices_shape = [2, 1]
+        output  = [[2,3],[4,5]]                 # output_shape  = [2, 2]
+
+    """
+    return _ffi_api.gather_nd(data, indices, batch_dims)  # type: ignore
+
+
 def scatter_elements(
     data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
 ):
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 163085a07c..55bc2772bc 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -156,6 +156,22 @@ def _flip(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.flip, call.args[0], int(call.attrs.axis))
 
 
+@register_legalize("relax.gather_elements")
+def _gather_elements(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(topi.gather, call.args[0], int(call.attrs.axis), call.args[1])
+
+
+@register_legalize("relax.gather_nd")
+def _gather_nd(bb: BlockBuilder, call: Call) -> Expr:
+    def te_gather_nd(data, indices, batch_dims):
+        indices_ndim = len(indices.shape)
+        axes = [indices_ndim - 1] + list(range(indices_ndim - 1))
+        indices = topi.transpose(indices, axes)
+        return topi.gather_nd(data, indices, batch_dims)
+
+    return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims))
+
+
 @register_legalize("relax.scatter_elements")
 def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 049345fcb1..ddc534cf60 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -94,6 +94,8 @@ from tvm.relax.op import (
     floor_mod,
     full,
     full_like,
+    gather_elements,
+    gather_nd,
     grad,
     greater,
     greater_equal,
@@ -772,6 +774,8 @@ __all__ = [
     "func_ret_struct_info",
     "func_ret_value",
     "function",
+    "gather_elements",
+    "gather_nd",
     "gpu",
     "grad",
     "greater",
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 686311fbee..2844825a4a 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -528,7 +528,7 @@ def gather(data, axis, indices):
     return cpp.gather(data, axis, indices)
 
 
-def gather_nd(a, indices):
+def gather_nd(a, indices, batch_dims=0):
     """Gather elements from a n-dimension array..
 
     Parameters
@@ -543,7 +543,7 @@ def gather_nd(a, indices):
     -------
     ret : tvm.te.Tensor
     """
-    return cpp.gather_nd(a, indices)
+    return cpp.gather_nd(a, indices, batch_dims)
 
 
 def matmul(a, b, transp_a=False, transp_b=False):
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index ba44341302..f64b3ec4f9 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1418,6 +1418,169 @@ TVM_REGISTER_OP("relax.flip")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlip)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.gather_elements */
+TVM_REGISTER_NODE_TYPE(GatherElementsAttrs);
+
+Expr gather_elements(Expr data, Expr indices, int axis) {
+  auto attrs = make_object<GatherElementsAttrs>();
+  attrs->axis = Integer(axis);
+  static const Op& op = Op::Get("relax.gather_elements");
+  return Call(op, {data, indices}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.gather_elements").set_body_typed(gather_elements);
+
+StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& ctx) {
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const auto* attrs = call->attrs.as<GatherElementsAttrs>();
+
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "GatherElements requires the input data to be a Tensor. However, the given one is "
+        << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (indices_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "GatherElements requires the input indices to be a Tensor. However, the given one is "
+        << call->args[1]->struct_info_->GetTypeKey());
+  }
+
+  if (!indices_sinfo->IsUnknownDtype() && !indices_sinfo->dtype.is_int()) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "GatherElements requires the input indices to have an integer dtype. However, the "
+        << "given indices dtype is " << indices_sinfo->dtype);
+  }
+
+  if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+  }
+
+  int axis = attrs->axis.IntValue();
+  if (axis < -data_sinfo->ndim || axis >= data_sinfo->ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GatherElements requires axis to be within the input dimension range ["
+                     << -data_sinfo->ndim << ", " << data_sinfo->ndim - 1 << "]. However, the "
+                     << "given axis is " << axis);
+  }
+
+  if (data_sinfo->ndim != indices_sinfo->ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GatherElements requires data and indices to have the same rank. However, "
+                     << "data rank is " << data_sinfo->ndim << " while indices rank is "
+                     << indices_sinfo->ndim);
+  }
+  if (indices_sinfo->shape.defined()) {
+    return TensorStructInfo(indices_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice);
+  }
+  return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.gather_elements")
+    .set_attrs_type<GatherElementsAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("indices", "Tensor", "The indices tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherElements)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.gather_nd */
+TVM_REGISTER_NODE_TYPE(GatherNDAttrs);
+
+Expr gather_nd(Expr data, Expr indices, int batch_dims) {
+  auto attrs = make_object<GatherNDAttrs>();
+  attrs->batch_dims = Integer(batch_dims);
+  static const Op& op = Op::Get("relax.gather_nd");
+  return Call(op, {data, indices}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.gather_nd").set_body_typed(gather_nd);
+
+StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) {
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const auto* attrs = call->attrs.as<GatherNDAttrs>();
+
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "GatherND requires the input data to be a Tensor. However, the given one is "
+        << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (indices_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "GatherND requires the input indices to be a Tensor. However, the given one is "
+        << call->args[1]->struct_info_->GetTypeKey());
+  }
+  ICHECK_GE(attrs->batch_dims.IntValue(), 0);
+  int batch_dims = attrs->batch_dims.IntValue();
+  int input_dims = data_sinfo->ndim;
+  if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype != DataType::Int(64)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GatherND requires the input indices to have int64 dtype. However, the "
+                     << "given indices dtype is " << indices_sinfo->dtype);
+  }
+
+  if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+  }
+
+  if (batch_dims < 0 || batch_dims > data_sinfo->ndim) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "GatherND batch_dims must be in range [0, data.ndim]. However, got batch_dims="
+        << batch_dims << ", data.ndim=" << input_dims);
+  }
+
+  if (batch_dims > indices_sinfo->ndim - 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GatherND batch_dims cannot exceed indices.ndim-1. However, got batch_dims="
+                     << batch_dims << ", indices.ndim=" << indices_sinfo->ndim);
+  }
+
+  // The output rank depends on indices.shape[-1], which must be a known constant
+  const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (!indices_shape || !indices_shape->values.back()->IsInstance<IntImmNode>()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+  }
+  int l = indices_shape->values.back().as<IntImmNode>()->value;
+  if (l > input_dims - batch_dims) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GatherND requires the last dimension of indices to be less than or "
+                        "equal to the rank of data minus batch_dims. However, the given shapes are "
+                     << "indices: " << ShapeExpr(indices_shape->values) << ", data rank: "
+                     << input_dims << ", with batch_dims=" << batch_dims);
+  }
+  int output_ndim = indices_sinfo->ndim + input_dims - l - 1 - batch_dims;
+  if (!data_shape) {
+    return TensorStructInfo(data_sinfo->dtype, output_ndim, data_sinfo->vdevice);
+  }
+
+  // At this point both the data and indices shapes are known
+  Array<PrimExpr> out_shape;
+  for (int i = 0; i < indices_sinfo->ndim - 1; ++i) {
+    out_shape.push_back(indices_shape->values[i]);
+  }
+  for (int i = batch_dims + l; i < input_dims; ++i) {
+    out_shape.push_back(data_shape->values[i]);
+  }
+  ICHECK_EQ(static_cast<int>(out_shape.size()), output_ndim);
+  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.gather_nd")
+    .set_attrs_type<GatherNDAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("indices", "Tensor", "The indices tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherND)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.scatter_elements */
 TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
 
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 010ceb663e..1a0c7ddbc7 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -174,6 +174,32 @@ Expr tile(Expr data, Array<Integer> repeats);
  */
 Expr flip(Expr data, Integer axis);
 
+/*!
+ * \brief Gather elements from a tensor using indices.
+ * \param data The input tensor.
+ * \param indices The indices tensor, must have integer type.
+ * \param axis The axis along which to index. Default is 0.
+ * \return The computed result.
+ *
+ * \note The indices tensor must have the same rank as data; its values select positions
+ *       along dimension axis. The output will have the same shape as indices.
+ */
+Expr gather_elements(Expr data, Expr indices, int axis = 0);
+
+/*!
+ * \brief Gather values from a tensor using N-dimensional indices.
+ * \param data The input tensor.
+ * \param indices The indices tensor, must have int64 type.
+ * \param batch_dims The number of batch dimensions. Default is 0.
+ * \return The computed result.
+ *
+ * \note For batch_dims > 0, the first batch_dims dimensions of data and indices must be equal.
+ *       The last dimension of indices indicates the depth of each index vector.
+ *       The output shape is indices.shape[:-1] + data.shape[batch_dims +
+ *       indices.shape[-1]:]
+ */
+Expr gather_nd(Expr data, Expr indices, int batch_dims = 0);
+
 /*!
  * \brief Scatter updates into an array according to indices.
  * \param data The input tensor.
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index d844739568..2e0fde3b28 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -139,7 +139,8 @@ TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) {
 });
 
 TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = gather_nd(args[0], args[1]);
+  int batch_dims = args.size() > 2 ? static_cast<int>(args[2]) : 0;
+  *rv = gather_nd(args[0], args[1], batch_dims);
 });
 
 TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) {
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index a4a4f78bd3..89f08e5af9 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -544,6 +544,68 @@ def test_gather():
     _verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1)
 
 
+@pytest.mark.parametrize(
+    "data_shape, indices_shape, axis",
+    [
+        ([3, 4, 5], [1, 4, 5], 0),
+        ([3, 4, 5], [3, 2, 5], 1),
+        ([3, 4, 5], [3, 4, 2], 2),
+    ],
+)
+def test_gather_elements(data_shape, indices_shape, axis):
+    gather_elements_node = helper.make_node("GatherElements", ["data", "indices"], ["y"], axis=axis)
+
+    graph = helper.make_graph(
+        [gather_elements_node],
+        "gather_elements_test",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
+            helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, indices_shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="gather_elements_test")
+    input_values = {
+        "data": np.random.randn(*data_shape).astype("float32"),
+        "indices": np.random.randint(0, data_shape[axis], indices_shape).astype("int64"),
+    }
+    check_correctness(model, inputs=input_values)
+
+
+@pytest.mark.parametrize(
+    "data_shape, indices_shape, batch_dims",
+    [
+        ([2, 2], [2, 2], 0),
+        ([2, 2], [2, 1], 0),
+        ([2, 2, 2], [1], 0),
+        ([2, 2, 2], [2, 2], 0),
+        ([2, 2, 2], [2, 1, 2], 0),
+        ([2, 2, 2], [2, 2], 1),
+        ([2, 2, 2], [2, 1], 1),
+    ],
+)
+def test_gather_nd(data_shape, indices_shape, batch_dims):
+    gather_nd_node = helper.make_node("GatherND", ["data", "indices"], ["y"], batch_dims=batch_dims)
+
+    graph = helper.make_graph(
+        [gather_nd_node],
+        "gather_nd_test",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
+            helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, None)],
+    )
+
+    model = helper.make_model(graph, producer_name="gather_nd_test")
+    input_values = {
+        "data": np.random.randn(*data_shape).astype("float32"),
+        "indices": np.random.randint(0, min(data_shape), indices_shape).astype("int64"),
+    }
+    check_correctness(model, inputs=input_values)
+
+
 @pytest.mark.parametrize("axis", [0, 1, 2])
 @pytest.mark.parametrize(("name", "opset"), [("Scatter", 10), ("ScatterElements", 11)])
 def test_scatter(axis: int, name: str, opset: int):
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index f6aefc8591..23ab6780cf 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -3205,6 +3205,133 @@ def test_flip_infer_struct_info_wrong_inputs():
         bb.normalize(relax.op.flip(x0, axis=3))
 
 
+def test_gather_elements_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
+    i0 = relax.Var("i", R.Tensor((2, 3, 4), "int64"))
+    i1 = relax.Var("i", R.Tensor((2, 3, 4)))
+    i2 = relax.Var("i", R.Tensor("int64", ndim=3))
+    i3 = relax.Var("i", R.Tensor(ndim=3))
+    i4 = relax.Var("i", R.Tensor((2, 3, 4), "int64", vdev0))
+
+    _check_inference(
+        bb, relax.op.gather_elements(x0, i0, axis=1), relax.TensorStructInfo((2, 3, 4), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.gather_elements(x3, i4, axis=1),
+        relax.TensorStructInfo((2, 3, 4), "float32", vdev0),
+    )
+    _check_inference(
+        bb,
+        relax.op.gather_elements(x1, i0, axis=1),
+        relax.TensorStructInfo((2, 3, 4), dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.gather_elements(x2, i0, axis=0),
+        relax.TensorStructInfo(dtype="float32", ndim=-1),
+    )
+    _check_inference(
+        bb, relax.op.gather_elements(x0, i1, axis=1), relax.TensorStructInfo((2, 3, 4), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.gather_elements(x1, i2, axis=1),
+        relax.TensorStructInfo(dtype="float32", ndim=3),
+    )
+    _check_inference(
+        bb,
+        relax.op.gather_elements(x2, i3, axis=0),
+        relax.TensorStructInfo(dtype="float32", ndim=-1),
+    )
+
+
+def test_gather_elements_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x = relax.Var("x", R.Tensor((a, b), "float32"))
+    i = relax.Var("i", R.Tensor((a, b), "int64"))
+
+    _check_inference(
+        bb, relax.op.gather_elements(x, i, axis=1), relax.TensorStructInfo((a, b), "float32")
+    )
+
+
+def test_gather_elements_infer_struct_info_wrong_inputs():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    i0 = relax.Var("i", R.Tensor((2, 3, 4), "int64"))
+    i1 = relax.Var("i", R.Tensor((2, 3), "int64"))
+    i2 = relax.Var("i", R.Tensor((2, 3, 4), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.gather_elements(x0, i0, axis=3))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.gather_elements(x0, i1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.gather_elements(x1, i0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.gather_elements(x0, i2))
+
+
+def test_gather_nd_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
+    i0 = relax.Var("i", R.Tensor((2, 2), "int64"))
+    i1 = relax.Var("i", R.Tensor((2, 2)))
+    i2 = relax.Var("i", R.Tensor("int64", ndim=2))
+    i3 = relax.Var("i", R.Tensor(ndim=2))
+    i4 = relax.Var("i", R.Tensor((2, 2), "int64", vdev0))
+
+    _check_inference(bb, relax.op.gather_nd(x0, i0), relax.TensorStructInfo((2, 4), "float32"))
+    _check_inference(
+        bb, relax.op.gather_nd(x3, i4), relax.TensorStructInfo((2, 4), "float32", vdev0)
+    )
+    _check_inference(
+        bb, relax.op.gather_nd(x1, i0), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.gather_nd(x2, i0), relax.TensorStructInfo(dtype="float32", ndim=-1)
+    )
+    _check_inference(bb, relax.op.gather_nd(x0, i1), relax.TensorStructInfo((2, 4), "float32"))
+    _check_inference(bb, relax.op.gather_nd(x1, i2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.gather_nd(x2, i3), relax.TensorStructInfo(dtype="float32"))
+
+
+def test_gather_nd_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    c = tir.Var("c", "int64")
+    x = relax.Var("x", R.Tensor((a, b, c), "float32"))
+    i = relax.Var("i", R.Tensor((2, 2), "int64"))
+
+    _check_inference(bb, relax.op.gather_nd(x, i), relax.TensorStructInfo((2, c), "float32"))
+
+
+def test_gather_nd_infer_struct_info_wrong_inputs():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    i0 = relax.Var("i", R.Tensor((2, 4), "int64"))  # indices too long
+    i1 = relax.Var("i", R.Tensor((2, 2), "float32"))  # wrong dtype
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.gather_nd(x0, i0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.gather_nd(x0, i1))
+
+
 def test_scatter_elements_infer_struct_info():
     bb = relax.BlockBuilder()
     vdev0 = VDevice("llvm")

Reply via email to