This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 59a2256959 [Relax] Add gather_elements and gather_nd operators (#17523)
59a2256959 is described below
commit 59a2256959a0ea35c4380d98bb983bf4f1e9d856
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Nov 13 23:29:03 2024 +0800
[Relax] Add gather_elements and gather_nd operators (#17523)
Add gather_elements and gather_nd operators to Relax, along with the
corresponding ONNX frontend converters.
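For reference, a minimal sketch of the new Python-level surface (variable names are
illustrative; the shapes follow the docstring examples added below):

    from tvm import relax
    from tvm.script import relax as R

    data = relax.Var("data", R.Tensor((2, 2), "float32"))
    idx = relax.Var("idx", R.Tensor((2, 2), "int64"))
    y0 = relax.op.gather_elements(data, idx, axis=1)  # inferred shape: (2, 2)
    y1 = relax.op.gather_nd(data, idx, batch_dims=0)  # inferred shape: (2,)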
---
include/tvm/relax/attrs/manipulate.h | 17 +++
python/tvm/relax/frontend/onnx/onnx_frontend.py | 22 ++-
python/tvm/relax/op/__init__.py | 2 +
python/tvm/relax/op/manipulate.py | 73 +++++++++
.../tvm/relax/transform/legalize_ops/manipulate.py | 16 ++
python/tvm/script/ir_builder/relax/ir.py | 4 +
python/tvm/topi/transform.py | 4 +-
src/relax/op/tensor/manipulate.cc | 163 +++++++++++++++++++++
src/relax/op/tensor/manipulate.h | 26 ++++
src/topi/transform.cc | 3 +-
tests/python/relax/test_frontend_onnx.py | 62 ++++++++
tests/python/relax/test_op_manipulate.py | 127 ++++++++++++++++
12 files changed, 514 insertions(+), 5 deletions(-)
diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index ea41488354..e6c16d233a 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -152,6 +152,23 @@ struct FlipAttrs : public tvm::AttrsNode<FlipAttrs> {
}
}; // struct FlipAttrs
+/*! \brief Attributes used in gather_elements operators */
+struct GatherElementsAttrs : public tvm::AttrsNode<GatherElementsAttrs> {
+ Integer axis;
+
+ TVM_DECLARE_ATTRS(GatherElementsAttrs, "relax.attrs.GatherElementsAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis along which to index.");
+ }
+}; // struct GatherElementsAttrs
+
+/*! \brief Attributes used in gather_nd operators */
+struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
+ Integer batch_dims;
+ TVM_DECLARE_ATTRS(GatherNDAttrs, "relax.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dims.");
+ }
+}; // struct GatherNDAttrs
+
/*! \brief Attributes used in scatter_elements operators */
struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 94ccfdb23e..b64e87822a 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -781,6 +781,24 @@ class Gather(OnnxOpConverter):
return relax.op.take(data, indices, axis)
+class GatherElements(OnnxOpConverter):
+ """Convert an onnx GatherElements node into an equivalent Relax
expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", 0)
+ return relax.op.gather_elements(inputs[0], inputs[1], axis)
+
+
+class GatherND(OnnxOpConverter):
+ """Convert an onnx GatherND node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ batch_dims = attr.get("batch_dims", 0)
+ return relax.op.gather_nd(inputs[0], inputs[1], batch_dims)
+
+
class Scatter(OnnxOpConverter):
"""Convert an onnx Scatter node into an equivalent Relax expression."""
@@ -3116,8 +3134,8 @@ def _get_convert_map():
"Squeeze": Squeeze,
"Constant": Constant,
"Gather": Gather,
- # "GatherElements": GatherElements,
- # "GatherND": GatherND,
+ "GatherElements": GatherElements,
+ "GatherND": GatherND,
"Scatter": Scatter,
"ScatterElements": ScatterElements,
"ScatterND": ScatterND,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 1603ea2f0f..97f18a2396 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -92,6 +92,8 @@ from .manipulate import (
expand_dims,
flatten,
flip,
+ gather_elements,
+ gather_nd,
layout_transform,
one_hot,
permute_dims,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 3210cc8216..0f6e537ab3 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -435,6 +435,79 @@ def flip(data, axis):
return _ffi_api.flip(data, axis) # type: ignore
+def gather_elements(data: Expr, indices: Expr, axis: int = 0) -> Expr:
+ """Gather elements from data according to indices along the specified axis.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ indices : relax.Expr
+ The indices tensor, must have integer type.
+
+ axis : int
+ The axis along which to index. Default is 0.
+
+ Returns
+ -------
+ ret : relax.Expr
+ The computed result.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ data = [[1, 2], [3, 4]]
+ indices = [[0, 0], [1, 0]]
+ axis = 1
+ output = [[1, 1], [4, 3]]
+
+ data = [[1, 2, 3], [4, 5, 6]]
+ indices = [[1, 1, 1]]
+ axis = 0
+ output = [[4, 5, 6]]
+ """
+ return _ffi_api.gather_elements(data, indices, axis) # type: ignore
+
+
+def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr:
+ """Update data at positions defined by indices with values in updates.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ indices : relax.Expr
+ The indices tensor, must have integer type.
+
+ batch_dims : int
+ The number of batch dimensions. Default is 0.
+
+ Returns
+ -------
+ ret : relax.Expr
+ The computed result.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ batch_dims = 0
+ data = [[0,1],[2,3]] # data_shape = [2, 2]
+ indices = [[0,0],[1,1]] # indices_shape = [2, 2]
+ output = [0,3] # output_shape = [2]
+
+ batch_dims = 1
+ data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2]
+ indices = [[1],[0]] # indices_shape = [2, 1]
+ output = [[2,3],[4,5]] # output_shape = [2, 2]
+
+ """
+ return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore
+
+
def scatter_elements(
data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
):
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 163085a07c..55bc2772bc 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -156,6 +156,22 @@ def _flip(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.flip, call.args[0], int(call.attrs.axis))
+@register_legalize("relax.gather_elements")
+def _gather_elements(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(topi.gather, call.args[0], int(call.attrs.axis), call.args[1])
+
+
+@register_legalize("relax.gather_nd")
+def _gather_nd(bb: BlockBuilder, call: Call) -> Expr:
+ def te_gather_nd(data, indices, batch_dims):
+ indices_ndim = len(indices.shape)
+ axes = [indices_ndim - 1] + list(range(indices_ndim - 1))
+ indices = topi.transpose(indices, axes)
+ return topi.gather_nd(data, indices, batch_dims)
+
+ return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims))
+
+
@register_legalize("relax.scatter_elements")
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
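The transpose inside te_gather_nd above is needed because topi.gather_nd expects the
index depth as the leading axis of indices, while the Relax (ONNX-style) convention
keeps it as the trailing axis. A small sketch at the TOPI level, with assumed shapes:

    from tvm import te, topi

    # Relax/ONNX layout: data (2, 2, 2), indices (2, 1), batch_dims=1.
    # After moving the trailing index axis to the front, TOPI sees indices of shape (1, 2).
    data = te.placeholder((2, 2, 2), name="data", dtype="float32")
    indices = te.placeholder((1, 2), name="indices", dtype="int64")
    out = topi.gather_nd(data, indices, batch_dims=1)  # out.shape == (2, 2)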
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 049345fcb1..ddc534cf60 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -94,6 +94,8 @@ from tvm.relax.op import (
floor_mod,
full,
full_like,
+ gather_elements,
+ gather_nd,
grad,
greater,
greater_equal,
@@ -772,6 +774,8 @@ __all__ = [
"func_ret_struct_info",
"func_ret_value",
"function",
+ "gather_elements",
+ "gather_nd",
"gpu",
"grad",
"greater",
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 686311fbee..2844825a4a 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -528,7 +528,7 @@ def gather(data, axis, indices):
return cpp.gather(data, axis, indices)
-def gather_nd(a, indices):
+def gather_nd(a, indices, batch_dims=0):
"""Gather elements from a n-dimension array..
Parameters
@@ -543,7 +543,7 @@ def gather_nd(a, indices):
-------
ret : tvm.te.Tensor
"""
- return cpp.gather_nd(a, indices)
+ return cpp.gather_nd(a, indices, batch_dims)
def matmul(a, b, transp_a=False, transp_b=False):
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index ba44341302..f64b3ec4f9 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1418,6 +1418,169 @@ TVM_REGISTER_OP("relax.flip")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlip)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.gather_elements */
+TVM_REGISTER_NODE_TYPE(GatherElementsAttrs);
+
+Expr gather_elements(Expr data, Expr indices, int axis) {
+ auto attrs = make_object<GatherElementsAttrs>();
+ attrs->axis = Integer(axis);
+ static const Op& op = Op::Get("relax.gather_elements");
+ return Call(op, {data, indices}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.gather_elements").set_body_typed(gather_elements);
+
+StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& ctx) {
+ const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+ const auto* attrs = call->attrs.as<GatherElementsAttrs>();
+
+ if (data_sinfo == nullptr) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "GatherElements requires the input data to be a Tensor. However, the given one is "
+ << call->args[0]->struct_info_->GetTypeKey());
+ }
+ if (indices_sinfo == nullptr) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "GatherElements requires the input indices to be a Tensor. However, the given one is "
+ << call->args[1]->struct_info_->GetTypeKey());
+ }
+
+ if (!indices_sinfo->IsUnknownDtype() && !indices_sinfo->dtype.is_int()) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "GatherElements requires the input indices to have an integer dtype. However, the "
+ << "given indices dtype is " << indices_sinfo->dtype);
+ }
+
+ if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+ }
+
+ int axis = attrs->axis.IntValue();
+ if (axis < -data_sinfo->ndim || axis >= data_sinfo->ndim) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "GatherElements requires axis to be within the input dimension range ["
+ << -data_sinfo->ndim << ", " << data_sinfo->ndim - 1 << "]. However, the "
+ << "given axis is " << axis);
+ }
+
+ if (data_sinfo->ndim != indices_sinfo->ndim) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "GatherElements requires data and indices to have the same rank. However, "
+ << "data rank is " << data_sinfo->ndim << " while indices rank is "
+ << indices_sinfo->ndim);
+ }
+ if (indices_sinfo->shape.defined()) {
+ return TensorStructInfo(indices_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice);
+ }
+ return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.gather_elements")
+ .set_attrs_type<GatherElementsAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("indices", "Tensor", "The indices tensor.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoGatherElements)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.gather_nd */
+TVM_REGISTER_NODE_TYPE(GatherNDAttrs);
+
+Expr gather_nd(Expr data, Expr indices, int batch_dims) {
+ auto attrs = make_object<GatherNDAttrs>();
+ attrs->batch_dims = Integer(batch_dims);
+ static const Op& op = Op::Get("relax.gather_nd");
+ return Call(op, {data, indices}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.gather_nd").set_body_typed(gather_nd);
+
+StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) {
+ const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+ const auto* attrs = call->attrs.as<GatherNDAttrs>();
+
+ if (data_sinfo == nullptr) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "GatherND requires the input data to be a Tensor. However, the given one is "
+ << call->args[0]->struct_info_->GetTypeKey());
+ }
+ if (indices_sinfo == nullptr) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "GatherND requires the input indices to be a Tensor. However, the given one is "
+ << call->args[1]->struct_info_->GetTypeKey());
+ }
+ ICHECK_GE(attrs->batch_dims.IntValue(), 0);
+ int batch_dims = attrs->batch_dims.IntValue();
+ int input_dims = data_sinfo->ndim;
+ if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype != DataType::Int(64)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "GatherND requires the input indices to have int64 dtype. However, the "
+ << "given indices dtype is " << indices_sinfo->dtype);
+ }
+
+ if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+ }
+
+ if (batch_dims < 0 || batch_dims > data_sinfo->ndim) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "GatherND batch_dims must be in range [0, data.ndim]. However, got batch_dims="
+ << batch_dims << ", data.ndim=" << input_dims);
+ }
+
+ if (batch_dims > indices_sinfo->ndim - 1) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "GatherND batch_dims cannot exceed indices.ndim-1. However, got batch_dims="
+ << batch_dims << ", indices.ndim=" << indices_sinfo->ndim);
+ }
+
+ // Check if indices shape is known
+ const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
+ const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+ if (!indices_shape || !indices_shape->values.back()->IsInstance<IntImmNode>()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+ }
+ int l = indices_shape->values.back().as<IntImmNode>()->value;
+ int output_ndim = indices_sinfo->ndim + input_dims - l - 1 - batch_dims;
+ if (!data_shape) {
+ return TensorStructInfo(data_sinfo->dtype, output_ndim, data_sinfo->vdevice);
+ }
+
+ // In this condition, all input shapes are known
+ Array<PrimExpr> out_shape;
+ if (l > input_dims - batch_dims) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "GatherND requires the last dimension of indices to be less than or "
+ "equal to the rank of data minus batch_dims. However, the given shapes are "
+ << "indices: " << ShapeExpr(indices_shape->values) << ", data: "
+ << ShapeExpr(data_shape->values) << ", with batch_dims=" << batch_dims);
+ }
+ for (int i = 0; i < indices_sinfo->ndim - 1; ++i) {
+ out_shape.push_back(indices_shape->values[i]);
+ }
+ for (int i = batch_dims + l; i < input_dims; ++i) {
+ out_shape.push_back(data_shape->values[i]);
+ }
+ ICHECK_EQ(out_shape.size(), output_ndim);
+ return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.gather_nd")
+ .set_attrs_type<GatherNDAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("indices", "Tensor", "The indices tensor.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherND)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.scatter_elements */
TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
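For reference, the rank rule used by InferStructInfoGatherND can be checked by hand;
a quick sketch whose numbers match one of the struct-info tests added below:

    # output_ndim = indices.ndim + data.ndim - l - 1 - batch_dims, where l = indices.shape[-1]
    data_ndim, indices_ndim, l, batch_dims = 3, 2, 2, 0  # e.g. data (2, 3, 4), indices (2, 2)
    output_ndim = indices_ndim + data_ndim - l - 1 - batch_dims
    assert output_ndim == 2  # matches TensorStructInfo((2, 4), "float32") below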
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 010ceb663e..1a0c7ddbc7 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -174,6 +174,32 @@ Expr tile(Expr data, Array<Integer> repeats);
*/
Expr flip(Expr data, Integer axis);
+/*!
+ * \brief Gather elements from a tensor using indices.
+ * \param data The input tensor.
+ * \param indices The indices tensor, must have integer type.
+ * \param axis The axis along which to index. Default is 0.
+ * \return The computed result.
+ *
+ * \note The shape of indices must match the shape of data, except at dimension axis,
+ *       where it can be any non-zero size. The output has the same shape as indices.
+ */
+Expr gather_elements(Expr data, Expr indices, int axis = 0);
+
+/*!
+ * \brief Gather values from a tensor using N-dimensional indices.
+ * \param data The input tensor.
+ * \param indices The indices tensor, must have integer type.
+ * \param batch_dims The number of batch dimensions. Default is 0.
+ * \return The computed result.
+ *
+ * \note For batch_dims > 0, the first batch_dims dimensions of data and indices must be equal.
+ *       The last dimension of indices indicates the depth of each index vector.
+ *       The output shape is indices.shape[:-1] + data.shape[batch_dims + indices.shape[-1]:]
+ */
+Expr gather_nd(Expr data, Expr indices, int batch_dims = 0);
+
/*!
* \brief Scatter updates into an array according to indices.
* \param data The input tensor.
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index d844739568..2e0fde3b28 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -139,7 +139,8 @@ TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) {
});
TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = gather_nd(args[0], args[1]);
+ int batch_dims = args[2];
+ *rv = gather_nd(args[0], args[1], batch_dims);
});
TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) {
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index a4a4f78bd3..89f08e5af9 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -544,6 +544,68 @@ def test_gather():
_verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1)
[email protected](
+ "data_shape, indices_shape, axis",
+ [
+ ([3, 4, 5], [1, 4, 5], 0),
+ ([3, 4, 5], [3, 2, 5], 1),
+ ([3, 4, 5], [3, 4, 2], 2),
+ ],
+)
+def test_gather_elements(data_shape, indices_shape, axis):
+ gather_elements_node = helper.make_node("GatherElements", ["data", "indices"], ["y"], axis=axis)
+
+ graph = helper.make_graph(
+ [gather_elements_node],
+ "gather_elements_test",
+ inputs=[
+ helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
+ helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, indices_shape)],
+ )
+
+ model = helper.make_model(graph, producer_name="gather_elements_test")
+ input_values = {
+ "data": np.random.randn(*data_shape).astype("float32"),
+ "indices": np.random.randint(0, data_shape[axis], indices_shape).astype("int64"),
+ }
+ check_correctness(model, inputs=input_values)
+
+
[email protected](
+ "data_shape, indices_shape, batch_dims",
+ [
+ ([2, 2], [2, 2], 0),
+ ([2, 2], [2, 1], 0),
+ ([2, 2, 2], [1], 0),
+ ([2, 2, 2], [2, 2], 0),
+ ([2, 2, 2], [2, 1, 2], 0),
+ ([2, 2, 2], [2, 2], 1),
+ ([2, 2, 2], [2, 1], 1),
+ ],
+)
+def test_gather_nd(data_shape, indices_shape, batch_dims):
+ gather_nd_node = helper.make_node("GatherND", ["data", "indices"], ["y"], batch_dims=batch_dims)
+
+ graph = helper.make_graph(
+ [gather_nd_node],
+ "gather_nd_test",
+ inputs=[
+ helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
+ helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, None)],
+ )
+
+ model = helper.make_model(graph, producer_name="gather_nd_test")
+ input_values = {
+ "data": np.random.randn(*data_shape).astype("float32"),
+ "indices": np.random.randint(0, 2, indices_shape).astype("int64"),
+ }
+ check_correctness(model, inputs=input_values)
+
+
@pytest.mark.parametrize("axis", [0, 1, 2])
@pytest.mark.parametrize(("name", "opset"), [("Scatter", 10),
("ScatterElements", 11)])
def test_scatter(axis: int, name: str, opset: int):
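When debugging cases like these, a plain numpy reference for ONNX-style GatherND with
batch_dims is handy. A hypothetical helper (not part of this commit) whose outputs the
parametrized cases above can be compared against:

    import numpy as np

    def gather_nd_ref(data, indices, batch_dims=0):
        # Iterate over the shared batch prefix, then gather one index vector at a time.
        batch_shape = data.shape[:batch_dims]
        out = []
        for b in np.ndindex(*batch_shape):
            d, idx = data[b], indices[b]
            flat = idx.reshape(-1, idx.shape[-1])
            g = np.stack([d[tuple(i)] for i in flat])
            out.append(g.reshape(idx.shape[:-1] + g.shape[1:]))
        if not batch_shape:
            return out[0]
        return np.stack(out).reshape(batch_shape + out[0].shape)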
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index f6aefc8591..23ab6780cf 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -3205,6 +3205,133 @@ def test_flip_infer_struct_info_wrong_inputs():
bb.normalize(relax.op.flip(x0, axis=3))
+def test_gather_elements_infer_struct_info():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
+ i0 = relax.Var("i", R.Tensor((2, 3, 4), "int64"))
+ i1 = relax.Var("i", R.Tensor((2, 3, 4)))
+ i2 = relax.Var("i", R.Tensor("int64", ndim=3))
+ i3 = relax.Var("i", R.Tensor(ndim=3))
+ i4 = relax.Var("i", R.Tensor((2, 3, 4), "int64", vdev0))
+
+ _check_inference(
+ bb, relax.op.gather_elements(x0, i0, axis=1), relax.TensorStructInfo((2, 3, 4), "float32")
+ )
+ _check_inference(
+ bb,
+ relax.op.gather_elements(x3, i4, axis=1),
+ relax.TensorStructInfo((2, 3, 4), "float32", vdev0),
+ )
+ _check_inference(
+ bb,
+ relax.op.gather_elements(x1, i0, axis=1),
+ relax.TensorStructInfo((2, 3, 4), dtype="float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.gather_elements(x2, i0, axis=0),
+ relax.TensorStructInfo(dtype="float32", ndim=-1),
+ )
+ _check_inference(
+ bb, relax.op.gather_elements(x0, i1, axis=1), relax.TensorStructInfo((2, 3, 4), "float32")
+ )
+ _check_inference(
+ bb,
+ relax.op.gather_elements(x1, i2, axis=1),
+ relax.TensorStructInfo(dtype="float32", ndim=3),
+ )
+ _check_inference(
+ bb,
+ relax.op.gather_elements(x2, i3, axis=0),
+ relax.TensorStructInfo(dtype="float32", ndim=-1),
+ )
+
+
+def test_gather_elements_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ x = relax.Var("x", R.Tensor((a, b), "float32"))
+ i = relax.Var("i", R.Tensor((a, b), "int64"))
+
+ _check_inference(
+ bb, relax.op.gather_elements(x, i, axis=1), relax.TensorStructInfo((a, b), "float32")
+ )
+
+
+def test_gather_elements_infer_struct_info_wrong_inputs():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 3), "float32"))
+ i0 = relax.Var("i", R.Tensor((2, 3, 4), "int64"))
+ i1 = relax.Var("i", R.Tensor((2, 3), "int64"))
+ i2 = relax.Var("i", R.Tensor((2, 3, 4), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.gather_elements(x0, i0, axis=3))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.gather_elements(x0, i1))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.gather_elements(x1, i0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.gather_elements(x0, i2))
+
+
+def test_gather_nd_infer_struct_info():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
+ i0 = relax.Var("i", R.Tensor((2, 2), "int64"))
+ i1 = relax.Var("i", R.Tensor((2, 2)))
+ i2 = relax.Var("i", R.Tensor("int64", ndim=2))
+ i3 = relax.Var("i", R.Tensor(ndim=2))
+ i4 = relax.Var("i", R.Tensor((2, 2), "int64", vdev0))
+
+ _check_inference(bb, relax.op.gather_nd(x0, i0), relax.TensorStructInfo((2, 4), "float32"))
+ _check_inference(
+ bb, relax.op.gather_nd(x3, i4), relax.TensorStructInfo((2, 4), "float32", vdev0)
+ )
+ _check_inference(
+ bb, relax.op.gather_nd(x1, i0), relax.TensorStructInfo(dtype="float32", ndim=2)
+ )
+ _check_inference(
+ bb, relax.op.gather_nd(x2, i0), relax.TensorStructInfo(dtype="float32", ndim=-1)
+ )
+ _check_inference(bb, relax.op.gather_nd(x0, i1), relax.TensorStructInfo((2, 4), "float32"))
+ _check_inference(bb, relax.op.gather_nd(x1, i2), relax.TensorStructInfo(dtype="float32"))
+ _check_inference(bb, relax.op.gather_nd(x2, i3), relax.TensorStructInfo(dtype="float32"))
+
+
+def test_gather_nd_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ c = tir.Var("c", "int64")
+ x = relax.Var("x", R.Tensor((a, b, c), "float32"))
+ i = relax.Var("i", R.Tensor((2, 2), "int64"))
+
+ _check_inference(bb, relax.op.gather_nd(x, i), relax.TensorStructInfo((2, c), "float32"))
+
+
+def test_gather_nd_infer_struct_info_wrong_inputs():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+ i0 = relax.Var("i", R.Tensor((2, 4), "int64")) # indices too long
+ i1 = relax.Var("i", R.Tensor((2, 2), "float32")) # wrong dtype
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.gather_nd(x0, i0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.gather_nd(x0, i1))
+
+
def test_scatter_elements_infer_struct_info():
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")