This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 081c23becf [Relax] Allow PrimValue as index in relax.op.take (#16940)
081c23becf is described below

commit 081c23becf190b91a80f82cef2032cce816dc637
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Apr 28 12:49:04 2024 -0500

    [Relax] Allow PrimValue as index in relax.op.take (#16940)
    
    * [Relax] Allow PrimValue as index in relax.op.take
    
    Prior to this commit, the `relax.op.take` only allowed tensors as the
    `indices` argument.  This commit extends `R.take` to also allow the
    index to be a `relax::PrimValue`.
    
    * Avoid comparison between signed/unsigned
    
    * Resolve/silence gcc warnings
---
 include/tvm/relax/block_builder.h                  |   2 +-
 include/tvm/topi/transform.h                       |  43 ++++--
 src/relax/ir/block_builder.cc                      |   2 +-
 src/relax/op/op_common.cc                          |  52 +++++--
 src/relax/op/op_common.h                           |  21 +++
 src/relax/op/tensor/index.cc                       |  26 +++-
 tests/python/relax/test_op_index.py                |  18 +++
 tests/python/relax/test_op_take.py                 | 158 +++++++++++++++++++++
 ..._transform_legalize_ops_index_linear_algebra.py |  97 +++++++++++++
 9 files changed, 388 insertions(+), 31 deletions(-)

diff --git a/include/tvm/relax/block_builder.h 
b/include/tvm/relax/block_builder.h
index a1e5a6bc31..7ca9aab6d5 100644
--- a/include/tvm/relax/block_builder.h
+++ b/include/tvm/relax/block_builder.h
@@ -116,7 +116,7 @@ class BlockBuilderNode : public Object {
    * \brief Report an error during transformation construction.
    * \param diagnostic The diagnostic information.
    */
-  virtual void ReportFatal(const Diagnostic& diagnostic) = 0;
+  [[noreturn]] virtual void ReportFatal(const Diagnostic& diagnostic) = 0;
 
   //-------------------------------
   // Scope management
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index a1f66a70ca..3292ce57ba 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1036,7 +1036,7 @@ inline Tensor sequence_mask(const Tensor& data, const 
Tensor& valid_length, doub
  *
  * \return A Tensor whose op member is the take operation
  */
-inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int 
axis,
+inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int 
batch_dims, int axis,
                    std::string mode = "clip", std::string name = "T_take",
                    std::string tag = kInjective) {
   if (axis < 0) {
@@ -1045,22 +1045,30 @@ inline Tensor take(const Tensor& a, const Tensor& 
indices, int batch_dims, int a
   ICHECK_GE(axis, 0) << "axis out of bounds";
   ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
   auto axis_dim = a->shape[axis];
-  int indices_len = static_cast<int>(indices->shape.size());
+  auto indices_shape = [&]() -> Array<PrimExpr> {
+    if (auto tensor = indices.as<TensorNode>()) {
+      return tensor->shape;
+    } else {
+      return {};
+    }
+  }();
+
+  int indices_len = static_cast<int>(indices_shape.size());
 
   int batch_dims_ = batch_dims;
   if (batch_dims_ != 0) {
-    ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << 
"batch_dims out of bounds";
-    ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of 
bounds";
+    ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds";
+    ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds";
 
     if (batch_dims_ < 0) {
-      batch_dims_ = indices->shape.size() + batch_dims_;
+      batch_dims_ = indices_len + batch_dims_;
     }
 
     ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
     ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to 
axis";
     for (int i = 0; i < batch_dims_; ++i) {
       auto addr1 = a->shape[i];
-      auto addr2 = indices->shape[i];
+      auto addr2 = indices_shape[i];
       auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
       auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
       ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to 
indices.shape[" << i << "]";
@@ -1077,13 +1085,24 @@ inline Tensor take(const Tensor& a, const Tensor& 
indices, int batch_dims, int a
   for (int i = batch_dims_; i < axis; ++i) {
     out_shape.push_back(a->shape[i]);
   }
-  for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); 
++i) {
-    out_shape.push_back(indices->shape[i]);
+  for (int i = batch_dims_; i < indices_len; ++i) {
+    out_shape.push_back(indices_shape[i]);
   }
   for (size_t i = axis + 1; i < a->shape.size(); ++i) {
     out_shape.push_back(a->shape[i]);
   }
 
+  auto get_index = [&](const Array<PrimExpr>& indices_position) -> PrimExpr {
+    if (auto tensor = indices.as<Tensor>()) {
+      return tensor.value()(indices_position);
+    } else if (auto prim = indices.as<PrimExpr>()) {
+      ICHECK_EQ(indices_position.size(), 0);
+      return prim.value();
+    } else {
+      LOG(FATAL) << "Variant did not contain either allowed type";
+    }
+  };
+
   if (mode == "clip") {
     if (batch_dims_ == 0) {
       return compute(
@@ -1097,7 +1116,7 @@ inline Tensor take(const Tensor& a, const Tensor& 
indices, int batch_dims, int a
             for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
               real_indices.push_back(out_index[j]);
             }
-            auto idx = tvm::min(tvm::max(0, indices(indices_position)), 
axis_dim - 1);
+            auto idx = tvm::min(tvm::max(0, get_index(indices_position)), 
axis_dim - 1);
             real_indices.push_back(idx);
             for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
               real_indices.push_back(out_index[j]);
@@ -1120,7 +1139,7 @@ inline Tensor take(const Tensor& a, const Tensor& 
indices, int batch_dims, int a
             for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
               real_indices.push_back(out_index[j]);
             }
-            auto idx = tvm::min(tvm::max(0, indices(indices_position)), 
axis_dim - 1);
+            auto idx = tvm::min(tvm::max(0, get_index(indices_position)), 
axis_dim - 1);
             real_indices.push_back(idx);
             for (size_t j = axis + indices_len - batch_dims_; j < 
out_index.size(); ++j) {
               real_indices.push_back(out_index[j]);
@@ -1141,7 +1160,7 @@ inline Tensor take(const Tensor& a, const Tensor& 
indices, int batch_dims, int a
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
-          real_indices.push_back(indices(indices_position));
+          real_indices.push_back(get_index(indices_position));
           for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
             real_indices.push_back(out_index[j]);
           }
@@ -1160,7 +1179,7 @@ inline Tensor take(const Tensor& a, const Tensor& 
indices, int batch_dims, int a
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
-          auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + 
axis_dim, axis_dim);
+          auto idx = truncmod(truncmod(get_index(indices_position), axis_dim) 
+ axis_dim, axis_dim);
           real_indices.push_back(idx);
           for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
             real_indices.push_back(out_index[j]);
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 0c40c4e62a..e9a513c317 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -149,7 +149,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
     }
   }
 
-  void ReportFatal(const Diagnostic& diagnostic) final {
+  [[noreturn]] void ReportFatal(const Diagnostic& diagnostic) final {
     // TODO(relax-team): Print more context information by looking
     // into the diagnostic->loc and surrounding IRModule.
     // We do not materialzie DiagnosticContext to avoid double referencing to
diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index b35bd4b5a3..56bf708f5e 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -35,24 +35,48 @@ Array<Expr> GetCallArgs(const Call& call) {
   return args;
 }
 
-Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const 
BlockBuilder& ctx) {
+void CheckNumArguments(const Call& call, const BlockBuilder& ctx) {
   Op op = Downcast<Op>(call->op);
-  int n_input = op->arguments.size();
-  if (static_cast<int>(call->args.size()) != n_input) {
+  int expected_input = op->arguments.size();
+  if (static_cast<int>(call->args.size()) != expected_input) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << op << " op should have " << n_input << " arguments");
+                     << "Operator " << op << " expects " << expected_input << 
" arguments"
+                     << ", but was called with " << call->args.size() << " 
arguments");
   }
+}
+
+TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, 
const BlockBuilder& ctx) {
+  Op op = Downcast<Op>(call->op);
+
+  ICHECK_EQ(op->arguments.size(), call->args.size())
+      << "Failure caught by this check "
+      << "should have previously been caught by `CheckNumArguments`";
+  ICHECK_LT(i_arg, op->arguments.size());
+
+  auto arg = call->args[i_arg];
+  auto sinfo = GetStructInfo(arg);
+
+  if (auto tensor_sinfo = sinfo.as<TensorStructInfo>()) {
+    return tensor_sinfo.value();
+  } else {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Operator " << op << " requires argument " << i_arg << 
" ("
+                     << op->arguments[i_arg]->name << ") to be a tensor.  "
+                     << "However, the argument " << arg << " is instead of 
type " << sinfo);
+    // Unreachable, but [[noreturn]] attribute on virtual function
+    // `ReportFatal` is insufficient to silence -Wreturn-type, as
+    // child class might not be [[noreturn]].
+    return TensorStructInfo();
+  }
+}
+
+Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const 
BlockBuilder& ctx) {
+  CheckNumArguments(call, ctx);
+
+  Op op = Downcast<Op>(call->op);
   Array<TensorStructInfo> input_tensor_sinfo;
-  input_tensor_sinfo.reserve(n_input);
-  for (int i = 0; i < n_input; ++i) {
-    const auto* sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
-    if (sinfo == nullptr) {
-      ctx->ReportFatal(Diagnostic::Error(call)
-                       << op << " requires the input " << 
op->arguments[i]->name
-                       << " to be Tensor. However, the given one has a "
-                       << call->args[i]->struct_info_->GetTypeKey());
-    }
-    input_tensor_sinfo.push_back(GetRef<TensorStructInfo>(sinfo));
+  for (size_t i = 0; i < call->args.size(); ++i) {
+    input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx));
   }
   return input_tensor_sinfo;
 }
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 5e19edb47c..94474ce784 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -44,6 +44,27 @@ namespace relax {
 
 /************ Op input struct info getter ************/
 
+/*!
+ * \brief Check that the operator has the expected number of arguments
+ *
+ * Verify that the number of arguments matches the expected number for
+ * the operator.
+ *
+ * \param call The context Call to the operator.
+ *
+ * \param ctx The error reporting context.
+ */
+void CheckNumArguments(const Call& call, const BlockBuilder& ctx);
+
+/*!
+ * \brief Get the tensor struct info of the operator input.
+ * \param call The context Call to the operator.
+ * \param i_arg The index of the argument to check
+ * \param ctx The error reporting context.
+ * \return The tensor struct info of the argument
+ */
+TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, 
const BlockBuilder& ctx);
+
 /*!
  * \brief Get the tensor struct info of the operator input.
  * \param call The context Call to the operator.
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 7ab98e9468..d052c2a64f 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -44,9 +44,29 @@ Expr take(Expr x, Expr indices, Optional<Integer> axis) {
 TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(take);
 
 StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
-  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
-  TensorStructInfo data_sinfo = input_sinfo[0];
-  TensorStructInfo indices_sinfo = input_sinfo[1];
+  CheckNumArguments(call, ctx);
+  TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
+
+  // StructInfo inference when the index is a PrimValue is equivalent
+  // to that of a scalar (0-d) tensor.
+  TensorStructInfo indices_sinfo = [&]() {
+    auto arg = call->args[1];
+    auto sinfo = GetStructInfo(arg);
+    if (auto tensor_sinfo = sinfo.as<TensorStructInfo>()) {
+      return tensor_sinfo.value();
+    } else if (auto prim_sinfo = sinfo.as<PrimStructInfoNode>()) {
+      return TensorStructInfo(ShapeExpr(Array<PrimExpr>{}), prim_sinfo->dtype);
+    } else {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "Operator " << call->op << " requires the indices 
argument to be "
+                       << "either a tensor or a scalar value.  "
+                       << "However, argument " << arg << " has struct info " 
<< sinfo);
+      // Unreachable, but [[noreturn]] attribute on virtual function
+      // `ReportFatal` is insufficient to silence -Wreturn-type, as
+      // child class might not be [[noreturn]].
+      return TensorStructInfo();
+    }
+  }();
 
   if (indices_sinfo->IsUnknownDtype()) {
     // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for 
warning?
diff --git a/tests/python/relax/test_op_index.py 
b/tests/python/relax/test_op_index.py
index e3c9e4a596..1455b4182a 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -194,6 +194,24 @@ def test_take_infer_struct_info():
     _check_inference(bb, relax.op.take(y3, idx7), 
relax.TensorStructInfo(dtype="", ndim=2))
 
 
+def test_take_infer_struct_info_scalar_tensor_index():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
+    idx = relax.Var("idx", R.Tensor([], "int64"))
+
+    _check_inference(bb, relax.op.take(x0, idx, axis=0), 
relax.TensorStructInfo([10], "float32"))
+    _check_inference(bb, relax.op.take(x0, idx, axis=1), 
relax.TensorStructInfo([4], "float32"))
+
+
+def test_take_infer_struct_info_prim_value_index():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
+    idx = relax.Var("idx", R.Prim("int64"))
+
+    _check_inference(bb, relax.op.take(x0, idx, axis=0), 
relax.TensorStructInfo([10], "float32"))
+    _check_inference(bb, relax.op.take(x0, idx, axis=1), 
relax.TensorStructInfo([4], "float32"))
+
+
 def test_take_infer_struct_info_shape_symbolic():
     bb = relax.BlockBuilder()
     m = tir.Var("m", "int64")
diff --git a/tests/python/relax/test_op_take.py 
b/tests/python/relax/test_op_take.py
new file mode 100644
index 0000000000..babf91869a
--- /dev/null
+++ b/tests/python/relax/test_op_take.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+import numpy as np
+
+axis = tvm.testing.parameter(0, 1)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_scalar_tensor_as_index(target, dev, axis):
+    """The index of R.take may be a scalar tensor
+
+    Using a scalar tensor as the index reduces the dimension of the
+    output.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16, 16], "float16")):
+            output = R.take(A, R.const(1), axis=axis)
+            return output
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[16, 16]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.take(1, axis=axis)
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_1d_tensor_as_index(target, dev, axis):
+    """The index of R.take may be a non-scalar tensor
+
+    In general, `R.take` outputs a tensor of dimension
+    `data.ndim + indices.ndim - 1`.
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16, 16], "float16")):
+            output = R.take(A, R.const([1]), axis=axis)
+            return output
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[16, 16]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.take([1], axis=axis)
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_2d_tensor_as_index(target, dev, axis):
+    """The index of R.take may be a 2-d tensor"""
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16, 16], "float16")):
+            output = R.take(A, R.const([[1, 3], [5, 7]]), axis=axis)
+            return output
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[16, 16]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.take([[1, 3], [5, 7]], axis=axis)
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_constant_prim_value_as_index(target, dev, axis):
+    """The index of R.take may be a R.prim_value
+
+    The `R.prim_value` produces output equivalent to a scalar
+    tensor.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16, 16], "float16")):
+            output = R.take(A, R.prim_value(1), axis=axis)
+            return output
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[16, 16]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.take(1, axis=axis)
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_dynamic_prim_value_as_index(target, dev, axis):
+    """The index of R.take may be a dynamic R.prim_value
+
+    The `R.prim_value` produces output equivalent to a scalar
+    tensor.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor(["n", "n"], "float16")):
+            n = T.int64()
+            output = R.take(A, R.prim_value(n - 1), axis=axis)
+            return output
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[16, 16]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.take(15, axis=axis)
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git 
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py 
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 0d1e969b35..d0aaddb1ca 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -55,6 +55,68 @@ def test_take():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_take_prim_value():
+    # fmt: off
+    @tvm.script.ir_module
+    class Take:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4), "float32"), index: R.Prim("int64")) -> 
R.Tensor((2, 4), "float32"):
+            gv: R.Tensor((2, 4), "float32") = R.take(x, index, axis=1)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4), "float32"), index: R.Prim("int64")) -> 
R.Tensor((2, 4), "float32"):
+            gv = R.call_tir(Expected.take, (x, index), R.Tensor((2, 4), 
dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), 
"float32"), index: T.int64, T_take: T.Buffer((T.int64(2), T.int64(4)), 
"float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i2 in T.grid(T.int64(2), T.int64(4)):
+                with T.block("T_take"):
+                    ax0, ax2 = T.axis.remap("SS", [i0, i2])
+                    T.reads(rxplaceholder[ax0, index, ax2])
+                    T.writes(T_take[ax0, ax2])
+                    T_take[ax0, ax2] = rxplaceholder[ax0, index, ax2]
+    # fmt: on
+
+    mod = LegalizeOps()(Take)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_take_const_prim_value():
+    # fmt: off
+    @tvm.script.ir_module
+    class Take:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 4), 
"float32"):
+            gv: R.Tensor((2, 4), "float32") = R.take(x, R.prim_value(0), 
axis=1)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 4), 
"float32"):
+            gv = R.call_tir(Expected.take, (x,), R.Tensor((2, 4), 
dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), 
"float32"), T_take: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i2 in T.grid(T.int64(2), T.int64(4)):
+                with T.block("T_take"):
+                    ax0, ax2 = T.axis.remap("SS", [i0, i2])
+                    T.reads(rxplaceholder[ax0, T.int64(0), ax2])
+                    T.writes(T_take[ax0, ax2])
+                    T_take[ax0, ax2] = rxplaceholder[ax0, T.int64(0), ax2]
+    # fmt: on
+
+    mod = LegalizeOps()(Take)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_take_symbolic():
     # fmt: off
     @tvm.script.ir_module
@@ -96,6 +158,41 @@ def test_take_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_take_symbolic_prim_value():
+    # fmt: off
+    @tvm.script.ir_module
+    class Take:
+        @R.function
+        def main(x: R.Tensor((2, "n", 4), "float32")) -> R.Tensor((2, 4), 
"float32"):
+            n = T.int64()
+            gv: R.Tensor((2, 4), "float32") = R.take(x, R.prim_value(n-1), 
axis=1)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, "n", 4), "float32")) -> R.Tensor((2, 4), 
"float32"):
+            gv = R.call_tir(Expected.take, (x,), R.Tensor((2, 4), 
dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def take(x_handle: T.handle, T_take: T.Buffer((T.int64(2), 
T.int64(4)), "float32")):
+            n = T.int64()
+            rxplaceholder = T.match_buffer(x_handle, (T.int64(2), n, 
T.int64(4)), "float32")
+
+            T.func_attr({"tir.noalias": True})
+            for i0, i2 in T.grid(T.int64(2), T.int64(4)):
+                with T.block("T_take"):
+                    ax0, ax2 = T.axis.remap("SS", [i0, i2])
+                    T.reads(rxplaceholder[ax0, n-1, ax2])
+                    T.writes(T_take[ax0, ax2])
+                    T_take[ax0, ax2] = rxplaceholder[ax0, n-1, ax2]
+    # fmt: on
+
+    mod = LegalizeOps()(Take)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_strided_slice():
     # fmt: off
     @tvm.script.ir_module

Reply via email to