This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 41ba01ddd9 [Unity] Relax op: index (#13987)
41ba01ddd9 is described below

commit 41ba01ddd9d6bfa68d59956c8c5e990fb9d7c2fa
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Feb 14 14:02:20 2023 -0500

    [Unity] Relax op: index (#13987)
    
    This PR is about the high-level tensor computation operators in Relax.
    
    This PR includes the tensor indexing operators.
---
 include/tvm/relax/attrs/index.h                    |  62 +++
 python/tvm/relax/op/__init__.py                    |   2 +
 python/tvm/relax/op/index.py                       |  90 ++++
 python/tvm/relax/op/{__init__.py => op_attrs.py}   |  20 +-
 python/tvm/script/ir_builder/relax/ir.py           |   4 +
 src/relax/op/tensor/index.cc                       | 195 +++++++
 src/relax/op/tensor/index.h                        |  65 +++
 tests/python/relax/test_op_index.py                | 593 +++++++++++++++++++++
 .../python/relax/test_tvmscript_parser_op_index.py |  82 +++
 9 files changed, 1105 insertions(+), 8 deletions(-)

diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h
new file mode 100644
index 0000000000..c95395a803
--- /dev/null
+++ b/include/tvm/relax/attrs/index.h
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/attrs/index.h
+ * \brief Attributes for indexing operators.
+ */
+#ifndef TVM_RELAX_ATTRS_INDEX_H_
+#define TVM_RELAX_ATTRS_INDEX_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Attributes used in take operator */
+struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
+  Optional<Integer> axis;
+
+  TVM_DECLARE_ATTRS(TakeAttrs, "relax.attrs.TakeAttrs") {
+    TVM_ATTR_FIELD(axis).describe("The axis over which to select values.");
+  }
+};  // struct TakeAttrs
+
+/*! \brief Attributes used in strided_slice operator */
+struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
+  Array<Integer> axes;
+  Array<PrimExpr> begin;
+  Array<PrimExpr> end;
+  Optional<Array<PrimExpr>> strides;
+
+  TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") {
+    TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied.");
+    TVM_ATTR_FIELD(begin).describe("The indices to begin with in the slicing, inclusive.");
+    TVM_ATTR_FIELD(end).describe("The indices indicating end of the slice, exclusive.");
+    TVM_ATTR_FIELD(strides).describe(
+        "Specifies the stride values, it can be negative in that case, the input tensor will be "
+        "reversed in that particular axis. If not specified, it by default is an list of ones of "
+        "the same length as `axes`.");
+  }
+};  // struct StridedSliceAttrs
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_ATTRS_INDEX_H_
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 9a131cdf95..3393a5dcae 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -20,6 +20,8 @@
 # Operators
 from .base import *
 from .binary import *
+from .index import *
 from .manipulate import *
+from .op_attrs import *
 from . import builtin
 from . import memory
diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py
new file mode 100644
index 0000000000..2a7afa5ba0
--- /dev/null
+++ b/python/tvm/relax/op/index.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Indexing operators."""
+from typing import List, Optional, Union
+
+from tvm.ir.expr import PrimExpr
+
+from . import _ffi_api
+from ..expr import Expr
+
+PrimExprLike = Union[int, PrimExpr]
+
+
+def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
+    """Take elements from a tensor along an axis.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The source tensor.
+
+    indices : relax.Expr
+        The indices of the values to extract.
+        It is required to be a one-dimensional tensor which has integer dtype.
+
+    axis : Optional[int]
+        The axis over which to select values.
+        If it is None, the input tensor is required to be one-dimensional.
+
+    Returns
+    -------
+    ret : relax.Expr
+        The taken result.
+    """
+    return _ffi_api.take(x, indices, axis)  # type: ignore
+
+
+def strided_slice(
+    x: Expr,
+    axes: List[int],
+    begin: List[PrimExprLike],
+    end: List[PrimExprLike],
+    strides: Optional[List[PrimExprLike]] = None,
+) -> Expr:
+    """Strided slice of a tensor.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The source tensor to be sliced.
+
+    axes : List[int]
+        Axes along which slicing is applied.
+
+    begin : List[PrimExprLike]
+        The indices to begin with in the slicing, inclusive.
+
+    end : List[PrimExprLike]
+        The indices indicating end of the slice, exclusive.
+
+    strides : Optional[List[PrimExprLike]]
+        Specifies the stride values, it can be negative in that case,
+        the input tensor will be reversed in that particular axis.
+        If not specified, it by default is a list of ones of the same length as `axes`.
+
+    Returns
+    -------
+    ret : relax.Expr
+        The sliced result.
+
+    Note
+    ----
+    strided_slice requires the input `begin`, `end` and `strides` to have the
+    same length as `axes`.
+    """
+    return _ffi_api.strided_slice(x, axes, begin, end, strides)  # type: ignore
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/op_attrs.py
similarity index 68%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/op_attrs.py
index 9a131cdf95..44cb2cf3a5 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -14,12 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""The attributes node used for Relax operators"""
+from tvm.ir import Attrs
+import tvm._ffi
 
-# Operators
-from .base import *
-from .binary import *
-from .manipulate import *
-from . import builtin
-from . import memory
+
+@tvm._ffi.register_object("relax.attrs.TakeAttrs")
+class TakeAttrs(Attrs):
+    """Attributes used in take operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.StridedSliceAttrs")
+class StridedSliceAttrs(Attrs):
+    """Attributes used in strided_slice operator"""
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 0e6595cb45..75a00ea049 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -42,6 +42,8 @@ from tvm.relax.op import (
     print,
     reshape,
     shape_of,
+    strided_slice,
+    take,
 )
 from tvm.relax.struct_info import StructInfo
 from tvm.relax.utils import args_converter
@@ -427,5 +429,7 @@ __all__ = [
     "shape",
     "shape_of",
     "str",
+    "strided_slice",
+    "take",
     "tuple",
 ]
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
new file mode 100644
index 0000000000..246abef908
--- /dev/null
+++ b/src/relax/op/tensor/index.cc
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index.cc
+ * \brief indexing operators.
+ */
+
+#include "index.h"
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/* relax.take */
+TVM_REGISTER_NODE_TYPE(TakeAttrs);
+
+Expr take(Expr x, Expr indices, Optional<Integer> axis) {
+  ObjectPtr<TakeAttrs> attrs = make_object<TakeAttrs>();
+  attrs->axis = std::move(axis);
+
+  static const Op& op = Op::Get("relax.take");
+  return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(take);
+
+StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo data_sinfo = input_sinfo[0];
+  TensorStructInfo indices_sinfo = input_sinfo[1];
+  if (indices_sinfo->ndim != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Take op requires the input indices to be 1-dimensional tensor. However, "
+                        "the given indices ndim is "
+                     << indices_sinfo->ndim);
+  } else if (!indices_sinfo->IsUnknownDtype() &&
+             !(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Take op requires the input indices to have integer dtype. However, the "
+                        "given indices dtype is "
+                     << indices_sinfo->dtype);
+  }
+
+  const auto* attrs = call->attrs.as<TakeAttrs>();
+  if (!attrs->axis.defined() && data_sinfo->ndim != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Take op expects the input data to be 1-dimensional tensor when the axis "
+                        "is not specified. However, the given data tensor has ndim "
+                     << data_sinfo->ndim);
+  }
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+  }
+
+  int axis = attrs->axis.defined()
+                 ? NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value)
+                 : 0;
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr || indices_shape == nullptr) {
+    return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
+  }
+
+  Array<PrimExpr> output_shape = data_shape->values;
+  output_shape.Set(axis, indices_shape->values[0]);
+  return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.take")
+    .set_attrs_type<TakeAttrs>()
+    .set_num_inputs(2)
+    .add_argument("x", "Tensor", "The source tensor.")
+    .add_argument("indices", "Tensor", "The indices of the values to extract.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTake);
+
+/* relax.strided_slice */
+TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
+
+Expr strided_slice(Expr x,                 //
+                   Array<Integer> axes,    //
+                   Array<PrimExpr> begin,  //
+                   Array<PrimExpr> end,    //
+                   Optional<Array<PrimExpr>> strides) {
+  int n_axis = axes.size();
+  CHECK_EQ(static_cast<int>(begin.size()), n_axis)
+      << "StridedSlice requires the number of begin indices to equal the number of axes.";
+  CHECK_EQ(static_cast<int>(end.size()), n_axis)
+      << "StridedSlice requires the number of end indices to equal the number of axes.";
+  if (strides.defined()) {
+    CHECK_EQ(static_cast<int>(strides.value().size()), n_axis)
+        << "StridedSlice requires the number of strides to equal the number of axes.";
+  }
+
+  // TODO(relax-team): We are going to support dynamic strided slice, where
+  // begin/end/stride can be not static at compile time. Therefore, begin/end/stride
+  // should not be part of StridedSliceAttrs, as we only allow static values to
+  // reside in attributes. However, using ShapeExpr to represent these
+  // arrays is not conceptually right, because they are not describing a
+  // concrete shape. The proper way to support dynamic strided slice is to use
+  // Tuple of PrimValue to represent begin/end/stride. Since at this moment
+  // we have no support for PrimValue, we store begin/end/stride as attribute
+  // fields as a workaround.
+  // Will switch to Tuple of PrimValue after introducing PrimValue.
+  auto f_convert_to_int64 = [](const PrimExpr& value) {
+    if (value->IsInstance<IntImmNode>()) {
+      return cast(DataType::Int(64), value);
+    }
+    CHECK(value.dtype() == DataType::Int(64)) << "strided_slice expects the input begin/end/stride "
+                                                 "values to be all int64. However, the given "
+                                              << value << " has dtype " << value->dtype;
+    return value;
+  };
+
+  ObjectPtr<StridedSliceAttrs> attrs = make_object<StridedSliceAttrs>();
+  attrs->axes = std::move(axes);
+  attrs->begin = begin.Map(f_convert_to_int64);
+  attrs->end = end.Map(f_convert_to_int64);
+  attrs->strides = strides.defined() ? strides.value().Map(f_convert_to_int64) : strides;
+
+  static const Op& op = Op::Get("relax.strided_slice");
+  return Call(op, {std::move(x)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);
+
+StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<StridedSliceAttrs>();
+  if (attrs->axes.empty()) {
+    return data_sinfo;
+  }
+
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+  }
+
+  std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes);
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr) {
+    return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
+  }
+
+  int n_axis = axes.size();
+  Array<PrimExpr> strides = attrs->strides.defined()
+                                ? attrs->strides.value()
+                                : Array<PrimExpr>(n_axis, IntImm(DataType::Int(64), 1));
+  std::vector<int> int_strides;
+  int_strides.reserve(n_axis);
+  // Only do output shape inference when all the begin/end/stride values are integers.
+  for (int i = 0; i < n_axis; ++i) {
+    const auto* int_begin = attrs->begin[i].as<IntImmNode>();
+    const auto* int_end = attrs->end[i].as<IntImmNode>();
+    const auto* int_stride = strides[i].as<IntImmNode>();
+    if (!int_begin || !int_end || !int_stride) {
+      return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
+    }
+    int_strides.push_back(int_stride->value);
+  }
+
+  Array<PrimExpr> output_shape = data_shape->values;
+  for (int i = 0; i < n_axis; ++i) {
+    PrimExpr len = int_strides[i] < 0 ? ceildiv(attrs->begin[i] - attrs->end[i], -int_strides[i])
+                                      : ceildiv(attrs->end[i] - attrs->begin[i], int_strides[i]);
+    output_shape.Set(axes[i], len);
+  }
+  return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.strided_slice")
+    .set_attrs_type<StridedSliceAttrs>()
+    .set_num_inputs(1)
+    .add_argument("x", "Tensor", "The source tensor to be sliced.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoStridedSlice);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h
new file mode 100644
index 0000000000..6944493a0f
--- /dev/null
+++ b/src/relax/op/tensor/index.h
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index.h
+ * \brief The functions to make Relax tensor indexing operator calls.
+ */
+#ifndef TVM_RELAX_OP_TENSOR_INDEX_H_
+#define TVM_RELAX_OP_TENSOR_INDEX_H_
+
+#include <tvm/relax/attrs/index.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Take elements from a tensor along an axis.
+ * \param x The source tensor.
+ * \param indices The indices of the values to extract.
+ * It is required to be a one-dimensional tensor which has integer dtype.
+ * \param axis The axis over which to select values.
+ * If it is `NullOpt`, the input tensor is required to be one-dimensional.
+ * \return The taken result.
+ */
+Expr take(Expr x, Expr indices, Optional<Integer> axis);
+
+/*!
+ * \brief Strided slice of a tensor.
+ * \param x The source tensor to be sliced.
+ * \param axes Axes along which slicing is applied.
+ * \param begin The indices to begin with in the slicing, inclusive.
+ * \param end The indices indicating end of the slice, exclusive.
+ * \param strides Specifies the stride values, it can be negative in that case,
+ * the input tensor will be reversed in that particular axis.
+ * If it is `NullOpt`, it by default is a list of ones of the same length as `axes`.
+ * \return The sliced result
+ */
+Expr strided_slice(Expr x,                 //
+                   Array<Integer> axes,    //
+                   Array<PrimExpr> begin,  //
+                   Array<PrimExpr> end,    //
+                   Optional<Array<PrimExpr>> strides);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_TENSOR_INDEX_H_
diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py
new file mode 100644
index 0000000000..77a04b1a1a
--- /dev/null
+++ b/tests/python/relax/test_op_index.py
@@ -0,0 +1,593 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
+def test_op_correctness():
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    idx = relax.Var("idx", R.Tensor((2,), "int64"))
+    assert relax.op.take(x, idx, axis=1).op == Op.get("relax.take")
+    assert relax.op.strided_slice(x, axes=[0], begin=[0], end=[2]).op == Op.get(
+        "relax.strided_slice"
+    )
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
+    ret = bb.normalize(call)
+    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_take_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=2))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((4, 10)))
+    x4 = relax.Var("x", R.Tensor(ndim=2))
+    x5 = relax.Var("x", R.Tensor())
+    y0 = relax.Var("y", R.Tensor((10,), "float32"))
+    y1 = relax.Var("y", R.Tensor("float32", ndim=1))
+    y2 = relax.Var("y", R.Tensor((10,)))
+    y3 = relax.Var("y", R.Tensor(ndim=1))
+    idx0 = relax.Var("idx", R.Tensor((6,), "int64"))
+    idx1 = relax.Var("idx", R.Tensor("int64", ndim=1))
+    idx2 = relax.Var("idx", R.Tensor((6,)))
+    idx3 = relax.Var("idx", R.Tensor(ndim=1))
+
+    _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float32"))
+    _check_inference(
+        bb, relax.op.take(x0, idx0, axis=-1), relax.TensorStructInfo((4, 6), "float32")
+    )
+    _check_inference(
+        bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.take(x3, idx0, axis=1), relax.TensorStructInfo((4, 6), dtype=""))
+    _check_inference(bb, relax.op.take(x4, idx0, axis=1), relax.TensorStructInfo(dtype="", ndim=2))
+    _check_inference(bb, relax.op.take(x5, idx0, axis=1), relax.TensorStructInfo(dtype=""))
+    _check_inference(
+        bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.take(x3, idx1, axis=1), relax.TensorStructInfo(dtype="", ndim=2))
+    _check_inference(bb, relax.op.take(x4, idx1, axis=1), relax.TensorStructInfo(dtype="", ndim=2))
+    _check_inference(bb, relax.op.take(x5, idx1, axis=1), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((4, 6), "float32"))
+    _check_inference(
+        bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.take(x3, idx2, axis=1), relax.TensorStructInfo((4, 6), dtype=""))
+    _check_inference(bb, relax.op.take(x4, idx2, axis=1), relax.TensorStructInfo(dtype="", ndim=2))
+    _check_inference(bb, relax.op.take(x5, idx2, axis=1), relax.TensorStructInfo(dtype=""))
+    _check_inference(
+        bb, relax.op.take(x0, idx3, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.take(x1, idx3, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(bb, relax.op.take(x2, idx3, axis=1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.take(x3, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2))
+    _check_inference(bb, relax.op.take(x4, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2))
+    _check_inference(bb, relax.op.take(x5, idx3, axis=1), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((6,), "float32"))
+    _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo(dtype="float32", ndim=1))
+    _check_inference(bb, relax.op.take(y2, idx0), relax.TensorStructInfo((6,), dtype=""))
+    _check_inference(bb, relax.op.take(y3, idx0), relax.TensorStructInfo(dtype="", ndim=1))
+    _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo(dtype="float32", ndim=1))
+    _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo(dtype="float32", ndim=1))
+    _check_inference(bb, relax.op.take(y2, idx1), relax.TensorStructInfo(dtype="", ndim=1))
+    _check_inference(bb, relax.op.take(y3, idx1), relax.TensorStructInfo(dtype="", ndim=1))
+    _check_inference(bb, relax.op.take(y0, idx2), relax.TensorStructInfo((6,), "float32"))
+    _check_inference(bb, relax.op.take(y1, idx2), relax.TensorStructInfo(dtype="float32", ndim=1))
+    _check_inference(bb, relax.op.take(y2, idx2), relax.TensorStructInfo((6,), dtype=""))
+    _check_inference(bb, relax.op.take(y3, idx2), relax.TensorStructInfo(dtype="", ndim=1))
+    _check_inference(bb, relax.op.take(y0, idx3), relax.TensorStructInfo(dtype="float32", ndim=1))
+    _check_inference(bb, relax.op.take(y1, idx3), relax.TensorStructInfo(dtype="float32", ndim=1))
+    _check_inference(bb, relax.op.take(y2, idx3), relax.TensorStructInfo(dtype="", ndim=1))
+    _check_inference(bb, relax.op.take(y3, idx3), relax.TensorStructInfo(dtype="", ndim=1))
+
+
+def test_take_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    i = tir.Var("i", "int64")
+    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
+    x1 = relax.Var("x", R.Tensor((m, n)))
+    y0 = relax.Var("y", R.Tensor((n,), "float32"))
+    y1 = relax.Var("y", R.Tensor((n,)))
+    idx0 = relax.Var("idx", R.Tensor((i,), "int64"))
+    idx1 = relax.Var(
+        "idx",
+        R.Tensor(
+            (i,),
+        ),
+    )
+
+    _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((m, i), "float32"))
+    _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((m, i), dtype=""))
+    _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((m, i), "float32"))
+    _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((m, i), dtype=""))
+    _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((i,), "float32"))
+    _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo((i,), dtype=""))
+    _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo((i,), "float32"))
+    _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo((i,), dtype=""))
+
+
+def test_take_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    sx0 = relax.Var("sx", relax.ShapeStructInfo((4, 10)))
+    sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=2))
+    sx2 = relax.Var("sx", relax.ShapeStructInfo())
+    sidx0 = relax.Var("sidx", relax.ShapeStructInfo((6,)))
+    sidx1 = relax.Var("sidx", relax.ShapeStructInfo(ndim=1))
+    x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32"))
+    x3 = relax.Var("x", R.Tensor((4, 10), "float32"))
+    idx0 = relax.Var("idx", relax.TensorStructInfo(sidx0, "int64"))
+    idx1 = relax.Var("idx", relax.TensorStructInfo(sidx1, "int64"))
+    idx2 = relax.Var("idx", R.Tensor((6,), "int64"))
+
+    _check_inference(
+        bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(
+        bb, relax.op.take(x3, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.take(x3, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+
+
+def test_take_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((4, 10), "float16"))
+    x1 = relax.Var("x", R.Tensor((4, 10), "int16"))
+    x2 = relax.Var("x", R.Tensor((4, 10), "int32"))
+    idx0 = relax.Var("idx", R.Tensor((6,), "int32"))
+    idx1 = relax.Var("idx", R.Tensor((6,), "int8"))
+    idx2 = relax.Var("idx", R.Tensor((6,), "uint32"))
+
+    _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float16"))
+    _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((4, 6), "int16"))
+    _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo((4, 6), "int32"))
+    _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((4, 6), "float16"))
+    _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((4, 6), "int16"))
+    _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo((4, 6), "int32"))
+    _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((4, 6), "float16"))
+    _check_inference(bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo((4, 6), "int16"))
+    _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo((4, 6), "int32"))
+
+
+def test_take_infer_struct_info_indices_not_one_dimensional():
+    bb = relax.BlockBuilder()
+    sidx0 = relax.Var("sidx", relax.ShapeStructInfo((6, 6)))
+    sidx1 = relax.Var("sidx", relax.ShapeStructInfo(()))
+    sidx2 = relax.Var("sidx", relax.ShapeStructInfo(ndim=2))
+    sidx3 = relax.Var("sidx", relax.ShapeStructInfo(ndim=0))
+    sidx4 = relax.Var("sidx", relax.ShapeStructInfo())
+    x = relax.Var("x", R.Tensor((4, 10), "float32"))
+    idx0 = relax.Var("idx", R.Tensor((6, 6), "int64"))
+    idx1 = relax.Var("idx", R.Tensor((), "int64"))
+    idx2 = relax.Var("idx", R.Tensor("int64", ndim=2))
+    idx3 = relax.Var("idx", R.Tensor("int64", ndim=0))
+    idx4 = relax.Var("idx", R.Tensor("int64"))
+    idx5 = relax.Var("idx", relax.TensorStructInfo(sidx0, "int64"))
+    idx6 = relax.Var("idx", relax.TensorStructInfo(sidx1, "int64"))
+    idx7 = relax.Var("idx", relax.TensorStructInfo(sidx2, "int64"))
+    idx8 = relax.Var("idx", relax.TensorStructInfo(sidx3, "int64"))
+    idx9 = relax.Var("idx", relax.TensorStructInfo(sidx4, "int64"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx0, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx1, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx2, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx3, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx4, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx5, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx6, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx7, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx8, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx9, axis=1))
+
+
+def test_take_infer_struct_info_indices_not_integer_dtype():
+    """take rejects floating-point indices.
+
+    The indices are deliberately 1-D, so the failure comes from the dtype check
+    and not from the separate requirement that indices be one-dimensional.
+    """
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((4, 10), "float32"))
+    idx0 = relax.Var("idx", R.Tensor((6,), "float32"))
+    idx1 = relax.Var("idx", R.Tensor((6,), "float64"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx0, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.take(x, idx1, axis=1))
+
+
+def test_take_infer_struct_info_multi_dimensional_without_axis():
+    """Omitting axis fails when the data tensor is 2-D or of unknown rank."""
+    bb = relax.BlockBuilder()
+    data_vars = [
+        relax.Var("x", R.Tensor((4, 10), "float32")),
+        relax.Var("x", R.Tensor("float32", ndim=2)),
+        relax.Var("x", R.Tensor("float32")),
+    ]
+    index_vars = [
+        relax.Var("idx", R.Tensor((6,), "int64")),
+        relax.Var("idx", R.Tensor("int64", ndim=1)),
+    ]
+
+    for idx in index_vars:
+        for x in data_vars:
+            with pytest.raises(TVMError):
+                bb.normalize(relax.op.take(x, idx))
+
+
+def test_take_infer_struct_info_axis_out_of_range():
+    """Axes just outside the valid range of a 2-D tensor (-3 and 2) are rejected."""
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((4, 10), "float32"))
+    idx = relax.Var("idx", R.Tensor((6,), "int64"))
+
+    for bad_axis in (-3, 2):
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.take(x, idx, axis=bad_axis))
+
+
+def test_take_infer_struct_info_wrong_input_type():
+    """Both take operands must be tensors; shape-typed operands are rejected."""
+    bb = relax.BlockBuilder()
+    shape_x = relax.Var("x", relax.ShapeStructInfo((4, 10)))
+    tensor_x = relax.Var("x", R.Tensor((4, 10), "float32"))
+    shape_idx = relax.Var("idx", relax.ShapeStructInfo((6,)))
+    tensor_idx = relax.Var("idx", R.Tensor((6,), "int64"))
+
+    for data, indices in ((shape_x, tensor_idx), (tensor_x, shape_idx)):
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.take(data, indices, axis=1))
+
+
+def test_strided_slice_infer_struct_info():
+    """strided_slice infers the sliced shape and keeps dtype/ndim from its input."""
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((8, 9, 10, 10)))
+    x4 = relax.Var("x", R.Tensor(ndim=4))
+    x5 = relax.Var("x", R.Tensor())
+
+    # axis 0: 1,3,5,7 -> 4; axis 1: [0, 9) -> 9; axis 3: 8,5,2 (stride -3) -> 3.
+    slice_kwargs = dict(axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3])
+    expectations = [
+        (x0, relax.TensorStructInfo((4, 9, 10, 3), "float32")),
+        (x1, relax.TensorStructInfo(dtype="float32", ndim=4)),
+        (x2, relax.TensorStructInfo(dtype="float32")),
+        (x3, relax.TensorStructInfo((4, 9, 10, 3), dtype="")),
+        (x4, relax.TensorStructInfo(dtype="", ndim=4)),
+        (x5, relax.TensorStructInfo(dtype="")),
+    ]
+    for x, expected in expectations:
+        _check_inference(bb, relax.op.strided_slice(x, **slice_kwargs), expected)
+
+    # The same slice written with negative axes gives the same result.
+    _check_inference(
+        bb,
+        relax.op.strided_slice(
+            x0, axes=[-1, -3, -4], begin=[8, 0, 1], end=[0, 9, 8], strides=[-3, 1, 2]
+        ),
+        relax.TensorStructInfo((4, 9, 10, 3), "float32"),
+    )
+    # With strides omitted: axis 1 [1, 8) -> 7, axis 2 [0, 9) -> 9.
+    _check_inference(
+        bb,
+        relax.op.strided_slice(x0, axes=[1, 2], begin=[1, 0], end=[8, 9]),
+        relax.TensorStructInfo((8, 7, 9, 10), "float32"),
+    )
+
+
+def test_strided_slice_infer_struct_info_shape_symbolic():
+    """Constant bounds on a symbolic axis give a constant extent; other dims stay symbolic."""
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x_typed = relax.Var("x", R.Tensor((m, n), "float32"))
+    x_untyped = relax.Var("x", R.Tensor((m, n)))
+
+    for x, dtype in ((x_typed, "float32"), (x_untyped, "")):
+        # [1, 3) with unit stride -> 2 rows.
+        _check_inference(
+            bb,
+            relax.op.strided_slice(x, axes=[0], begin=[1], end=[3]),
+            relax.TensorStructInfo((2, n), dtype=dtype),
+        )
+        # [1, 8) with stride 3 -> rows 1, 4, 7.
+        _check_inference(
+            bb,
+            relax.op.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3]),
+            relax.TensorStructInfo((3, n), dtype=dtype),
+        )
+
+
+def test_strided_slice_infer_struct_info_shape_var():
+    """When the input shape comes from a shape variable, only dtype and ndim are inferred."""
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo((8, 10)))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+
+    for dtype in ("float32", ""):
+        # ndim=-1 means the rank is unknown, matching s2.
+        for shape_var, expected_ndim in ((s0, 2), (s1, 2), (s2, -1)):
+            x = relax.Var("x", relax.TensorStructInfo(shape_var, dtype=dtype))
+            _check_inference(
+                bb,
+                relax.op.strided_slice(x, axes=[0], begin=[0], end=[8]),
+                relax.TensorStructInfo(dtype=dtype, ndim=expected_ndim),
+            )
+
+
+def test_strided_slice_infer_struct_info_more_input_dtype():
+    """The output dtype matches the input for float16, int32 and int64 tensors."""
+    bb = relax.BlockBuilder()
+    for dtype in ("float16", "int32", "int64"):
+        x = relax.Var("x", R.Tensor((8, 9), dtype))
+        _check_inference(
+            bb,
+            relax.op.strided_slice(x, axes=[0], begin=[0], end=[8]),
+            relax.TensorStructInfo((8, 9), dtype),
+        )
+
+
+def test_strided_slice_infer_struct_info_symbolic_begin_end_strides():
+    """A symbolic begin, end or stride leaves only dtype and ndim inferable."""
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    x = relax.Var("x", R.Tensor((8, 9), "float32"))
+    expected = relax.TensorStructInfo(dtype="float32", ndim=2)
+
+    symbolic_cases = [
+        dict(begin=[a], end=[8]),
+        dict(begin=[0], end=[a]),
+        dict(begin=[0], end=[8], strides=[a]),
+    ]
+    for kwargs in symbolic_cases:
+        _check_inference(bb, relax.op.strided_slice(x, axes=[0], **kwargs), expected)
+
+
+def test_strided_slice_infer_struct_info_no_axis():
+    """With empty axes, strided_slice returns the input's struct info unchanged."""
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    s0 = relax.Var("s", relax.ShapeStructInfo((m, n)))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+
+    cases = [
+        (R.Tensor((m, n), "float32"), relax.TensorStructInfo((m, n), "float32")),
+        (R.Tensor(dtype="float32", ndim=2), relax.TensorStructInfo(dtype="float32", ndim=2)),
+        (R.Tensor(dtype="float32"), relax.TensorStructInfo(dtype="float32")),
+        (relax.TensorStructInfo(s0, "float32"), relax.TensorStructInfo(s0, "float32")),
+        (relax.TensorStructInfo(s1, "float32"), relax.TensorStructInfo(s1, "float32")),
+        (relax.TensorStructInfo(s2, "float32"), relax.TensorStructInfo(s2, "float32")),
+    ]
+    for input_sinfo, expected in cases:
+        x = relax.Var("x", input_sinfo)
+        _check_inference(bb, relax.op.strided_slice(x, axes=[], begin=[], end=[]), expected)
+
+
+def test_strided_slice_begin_end_strides_int64():
+    """Python-int begin/end/strides are stored with int64 dtype in the call attrs."""
+    x = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
+    strided_slice = relax.op.strided_slice(
+        x, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3]
+    )
+
+    attrs = strided_slice.attrs
+    for values in (attrs.begin, attrs.end, attrs.strides):
+        for value in values:
+            assert value.dtype == "int64"
+
+
+def test_strided_slice_inconsistent_axes_begin_end_strides_length():
+    """Building the call fails when axes/begin/end/strides lengths disagree."""
+    x = relax.Var("x", R.Tensor((8, 9), "float32"))
+    mismatched = [
+        dict(begin=[], end=[9]),
+        dict(begin=[0], end=[]),
+        dict(begin=[0], end=[9], strides=[]),
+    ]
+    for kwargs in mismatched:
+        # The error is raised at construction time, before any normalization.
+        with pytest.raises(TVMError):
+            relax.op.strided_slice(x, axes=[1], **kwargs)
+
+
+def test_strided_slice_infer_struct_info_repetitive_axes():
+    """A repeated axis is rejected, whether repeated literally or via its negative alias."""
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((8, 9), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.strided_slice(x, axes=[0, 0], begin=[0, 0], end=[8, 8]))
+    with pytest.raises(TVMError):
+        # For a 2-D tensor, axis -2 is the same as axis 0.
+        bb.normalize(relax.op.strided_slice(x, axes=[0, -2], begin=[0, 0], end=[8, 8]))
+
+
+def test_strided_slice_infer_struct_info_axis_out_of_range():
+    """Axes just outside the valid range of a 2-D tensor (2 and -3) are rejected."""
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((8, 9), "float32"))
+
+    for bad_axis in (2, -3):
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.strided_slice(x, axes=[bad_axis], begin=[0], end=[8]))
+
+
+def test_strided_slice_infer_struct_info_wrong_input_type():
+    """strided_slice rejects shape- and function-typed inputs."""
+    bb = relax.BlockBuilder()
+    non_tensor_inputs = [
+        relax.Var("x", relax.ShapeStructInfo((8, 9))),
+        relax.Var("x", relax.FuncStructInfo([], R.Tensor((8, 9), "float32"))),
+    ]
+    for x in non_tensor_inputs:
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.strided_slice(x, axes=[0], begin=[0], end=[8]))
+
+
+if __name__ == "__main__":
+    # When this file is executed as a script, hand control to tvm.testing.main().
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_index.py b/tests/python/relax/test_tvmscript_parser_op_index.py
new file mode 100644
index 0000000000..b271d1a7f3
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_index.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
+def _check(
+    parsed: Union[relax.Function, IRModule],
+    expect: Optional[Union[relax.Function, IRModule]],
+):
+    """Round-trip ``parsed`` through TVMScript and optionally compare it to ``expect``.
+
+    ``parsed`` is printed with metadata, re-parsed, and must be structurally
+    equal to the re-parsed result. If ``expect`` is given, ``parsed`` must also
+    be structurally equal to it.
+    """
+    test = parsed.script(show_meta=True)
+    roundtrip_mod = tvm.script.from_source(test)
+    tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
+    # ``expect`` is Optional: test for None explicitly rather than relying on
+    # the truthiness of a TVM object.
+    if expect is not None:
+        tvm.ir.assert_structural_equal(parsed, expect)
+
+
+def test_take():
+    """R.take in TVMScript matches the function built with relax.op.take."""
+
+    @R.function
+    def foo(
+        x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((3,), "int64")
+    ) -> R.Tensor((2, 3, 3), "float32"):
+        gv: R.Tensor((2, 3, 3), "float32") = R.take(x, indices, axis=2)
+        return gv
+
+    bb = relax.BlockBuilder()
+    data = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    index_var = relax.Var("indices", R.Tensor((3,), "int64"))
+    with bb.function("foo", [data, index_var]):
+        out = bb.emit(relax.op.take(data, index_var, axis=2))
+        bb.emit_func_output(out)
+    expected = bb.get()["foo"]
+
+    _check(foo, expected)
+
+
+def test_strided_slice():
+    """R.strided_slice (including a negative axis) matches the BlockBuilder-built function."""
+
+    @R.function
+    def foo(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor((4, 9, 10, 3), "float32"):
+        gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(
+            x,
+            axes=[0, 1, -1],
+            begin=[1, 0, 8],
+            end=[8, 9, 0],
+            strides=[2, 1, -3],
+        )
+        return gv
+
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
+    with bb.function("foo", [x]):
+        gv = bb.emit(
+            relax.op.strided_slice(
+                x, axes=[0, 1, -1], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3]
+            )
+        )
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+if __name__ == "__main__":
+    # When this file is executed as a script, hand control to tvm.testing.main().
+    tvm.testing.main()


Reply via email to