This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 081c23becf [Relax] Allow PrimValue as index in relax.op.take (#16940)
081c23becf is described below
commit 081c23becf190b91a80f82cef2032cce816dc637
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Apr 28 12:49:04 2024 -0500
[Relax] Allow PrimValue as index in relax.op.take (#16940)
* [Relax] Allow PrimValue as index in relax.op.take
Prior to this commit, `relax.op.take` only allowed tensors as the
`indices` argument. This commit extends `R.take` to also allow the
index to be a `relax::PrimValue`.
* Avoid comparison between signed/unsigned
* Resolve/silence gcc warnings
---
include/tvm/relax/block_builder.h | 2 +-
include/tvm/topi/transform.h | 43 ++++--
src/relax/ir/block_builder.cc | 2 +-
src/relax/op/op_common.cc | 52 +++++--
src/relax/op/op_common.h | 21 +++
src/relax/op/tensor/index.cc | 26 +++-
tests/python/relax/test_op_index.py | 18 +++
tests/python/relax/test_op_take.py | 158 +++++++++++++++++++++
..._transform_legalize_ops_index_linear_algebra.py | 97 +++++++++++++
9 files changed, 388 insertions(+), 31 deletions(-)
diff --git a/include/tvm/relax/block_builder.h
b/include/tvm/relax/block_builder.h
index a1e5a6bc31..7ca9aab6d5 100644
--- a/include/tvm/relax/block_builder.h
+++ b/include/tvm/relax/block_builder.h
@@ -116,7 +116,7 @@ class BlockBuilderNode : public Object {
* \brief Report an error during transformation construction.
* \param diagnostic The diagnostic information.
*/
- virtual void ReportFatal(const Diagnostic& diagnostic) = 0;
+ [[noreturn]] virtual void ReportFatal(const Diagnostic& diagnostic) = 0;
//-------------------------------
// Scope management
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index a1f66a70ca..3292ce57ba 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1036,7 +1036,7 @@ inline Tensor sequence_mask(const Tensor& data, const
Tensor& valid_length, doub
*
* \return A Tensor whose op member is the take operation
*/
-inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int
axis,
+inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int
batch_dims, int axis,
std::string mode = "clip", std::string name = "T_take",
std::string tag = kInjective) {
if (axis < 0) {
@@ -1045,22 +1045,30 @@ inline Tensor take(const Tensor& a, const Tensor&
indices, int batch_dims, int a
ICHECK_GE(axis, 0) << "axis out of bounds";
ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
auto axis_dim = a->shape[axis];
- int indices_len = static_cast<int>(indices->shape.size());
+ auto indices_shape = [&]() -> Array<PrimExpr> {
+ if (auto tensor = indices.as<TensorNode>()) {
+ return tensor->shape;
+ } else {
+ return {};
+ }
+ }();
+
+ int indices_len = static_cast<int>(indices_shape.size());
int batch_dims_ = batch_dims;
if (batch_dims_ != 0) {
- ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) <<
"batch_dims out of bounds";
- ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of
bounds";
+ ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds";
+ ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds";
if (batch_dims_ < 0) {
- batch_dims_ = indices->shape.size() + batch_dims_;
+ batch_dims_ = indices_len + batch_dims_;
}
ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to
axis";
for (int i = 0; i < batch_dims_; ++i) {
auto addr1 = a->shape[i];
- auto addr2 = indices->shape[i];
+ auto addr2 = indices_shape[i];
auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to
indices.shape[" << i << "]";
@@ -1077,13 +1085,24 @@ inline Tensor take(const Tensor& a, const Tensor&
indices, int batch_dims, int a
for (int i = batch_dims_; i < axis; ++i) {
out_shape.push_back(a->shape[i]);
}
- for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size();
++i) {
- out_shape.push_back(indices->shape[i]);
+ for (int i = batch_dims_; i < indices_len; ++i) {
+ out_shape.push_back(indices_shape[i]);
}
for (size_t i = axis + 1; i < a->shape.size(); ++i) {
out_shape.push_back(a->shape[i]);
}
+ auto get_index = [&](const Array<PrimExpr>& indices_position) -> PrimExpr {
+ if (auto tensor = indices.as<Tensor>()) {
+ return tensor.value()(indices_position);
+ } else if (auto prim = indices.as<PrimExpr>()) {
+ ICHECK_EQ(indices_position.size(), 0);
+ return prim.value();
+ } else {
+ LOG(FATAL) << "Variant did not contain either allowed type";
+ }
+ };
+
if (mode == "clip") {
if (batch_dims_ == 0) {
return compute(
@@ -1097,7 +1116,7 @@ inline Tensor take(const Tensor& a, const Tensor&
indices, int batch_dims, int a
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
- auto idx = tvm::min(tvm::max(0, indices(indices_position)),
axis_dim - 1);
+ auto idx = tvm::min(tvm::max(0, get_index(indices_position)),
axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
@@ -1120,7 +1139,7 @@ inline Tensor take(const Tensor& a, const Tensor&
indices, int batch_dims, int a
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
- auto idx = tvm::min(tvm::max(0, indices(indices_position)),
axis_dim - 1);
+ auto idx = tvm::min(tvm::max(0, get_index(indices_position)),
axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len - batch_dims_; j <
out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
@@ -1141,7 +1160,7 @@ inline Tensor take(const Tensor& a, const Tensor&
indices, int batch_dims, int a
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
- real_indices.push_back(indices(indices_position));
+ real_indices.push_back(get_index(indices_position));
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
@@ -1160,7 +1179,7 @@ inline Tensor take(const Tensor& a, const Tensor&
indices, int batch_dims, int a
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
- auto idx = truncmod(truncmod(indices(indices_position), axis_dim) +
axis_dim, axis_dim);
+ auto idx = truncmod(truncmod(get_index(indices_position), axis_dim)
+ axis_dim, axis_dim);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 0c40c4e62a..e9a513c317 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -149,7 +149,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
}
}
- void ReportFatal(const Diagnostic& diagnostic) final {
+ [[noreturn]] void ReportFatal(const Diagnostic& diagnostic) final {
// TODO(relax-team): Print more context information by looking
// into the diagnostic->loc and surrounding IRModule.
// We do not materialzie DiagnosticContext to avoid double referencing to
diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index b35bd4b5a3..56bf708f5e 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -35,24 +35,48 @@ Array<Expr> GetCallArgs(const Call& call) {
return args;
}
-Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const
BlockBuilder& ctx) {
+void CheckNumArguments(const Call& call, const BlockBuilder& ctx) {
Op op = Downcast<Op>(call->op);
- int n_input = op->arguments.size();
- if (static_cast<int>(call->args.size()) != n_input) {
+ int expected_input = op->arguments.size();
+ if (static_cast<int>(call->args.size()) != expected_input) {
ctx->ReportFatal(Diagnostic::Error(call)
- << op << " op should have " << n_input << " arguments");
+ << "Operator " << op << " expects " << expected_input <<
" arguments"
+ << ", but was called with " << call->args.size() << "
arguments");
}
+}
+
+TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg,
const BlockBuilder& ctx) {
+ Op op = Downcast<Op>(call->op);
+
+ ICHECK_EQ(op->arguments.size(), call->args.size())
+ << "Failure caught by this check "
+ << "should have previously been caught by `CheckNumArguments`";
+ ICHECK_LT(i_arg, op->arguments.size());
+
+ auto arg = call->args[i_arg];
+ auto sinfo = GetStructInfo(arg);
+
+ if (auto tensor_sinfo = sinfo.as<TensorStructInfo>()) {
+ return tensor_sinfo.value();
+ } else {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Operator " << op << " requires argument " << i_arg <<
" ("
+ << op->arguments[i_arg]->name << ") to be a tensor. "
+ << "However, the argument " << arg << " is instead of
type " << sinfo);
+ // Unreachable, but [[noreturn]] attribute on virtual function
+ // `ReportFatal` is insufficient to silence -Wreturn-type, as
+ // child class might not be [[noreturn]].
+ return TensorStructInfo();
+ }
+}
+
+Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const
BlockBuilder& ctx) {
+ CheckNumArguments(call, ctx);
+
+ Op op = Downcast<Op>(call->op);
Array<TensorStructInfo> input_tensor_sinfo;
- input_tensor_sinfo.reserve(n_input);
- for (int i = 0; i < n_input; ++i) {
- const auto* sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
- if (sinfo == nullptr) {
- ctx->ReportFatal(Diagnostic::Error(call)
- << op << " requires the input " <<
op->arguments[i]->name
- << " to be Tensor. However, the given one has a "
- << call->args[i]->struct_info_->GetTypeKey());
- }
- input_tensor_sinfo.push_back(GetRef<TensorStructInfo>(sinfo));
+ for (size_t i = 0; i < call->args.size(); ++i) {
+ input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx));
}
return input_tensor_sinfo;
}
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 5e19edb47c..94474ce784 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -44,6 +44,27 @@ namespace relax {
/************ Op input struct info getter ************/
+/*!
+ * \brief Check that the operator has the expected number of arguments
+ *
+ * Verify that the number of arguments matches the expected number for
+ * the operator.
+ *
+ * \param call The context Call to the operator.
+ *
+ * \param ctx The error reporting context.
+ */
+void CheckNumArguments(const Call& call, const BlockBuilder& ctx);
+
+/*!
+ * \brief Get the tensor struct info of the operator input.
+ * \param call The context Call to the operator.
+ * \param i_arg The index of the argument to check
+ * \param ctx The error reporting context.
+ * \return The tensor struct info of the argument
+ */
+TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg,
const BlockBuilder& ctx);
+
/*!
* \brief Get the tensor struct info of the operator input.
* \param call The context Call to the operator.
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 7ab98e9468..d052c2a64f 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -44,9 +44,29 @@ Expr take(Expr x, Expr indices, Optional<Integer> axis) {
TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(take);
StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
- Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
- TensorStructInfo data_sinfo = input_sinfo[0];
- TensorStructInfo indices_sinfo = input_sinfo[1];
+ CheckNumArguments(call, ctx);
+ TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
+
+ // StructInfo inference when the index is a PrimValue is equivalent
+ // to that of a scalar (0-d) tensor.
+ TensorStructInfo indices_sinfo = [&]() {
+ auto arg = call->args[1];
+ auto sinfo = GetStructInfo(arg);
+ if (auto tensor_sinfo = sinfo.as<TensorStructInfo>()) {
+ return tensor_sinfo.value();
+ } else if (auto prim_sinfo = sinfo.as<PrimStructInfoNode>()) {
+ return TensorStructInfo(ShapeExpr(Array<PrimExpr>{}), prim_sinfo->dtype);
+ } else {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Operator " << call->op << " requires the indices
argument to be "
+ << "either a tensor or a scalar value. "
+ << "However, argument " << arg << " has struct info "
<< sinfo);
+ // Unreachable, but [[noreturn]] attribute on virtual function
+ // `ReportFatal` is insufficient to silence -Wreturn-type, as
+ // child class might not be [[noreturn]].
+ return TensorStructInfo();
+ }
+ }();
if (indices_sinfo->IsUnknownDtype()) {
// TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for
warning?
diff --git a/tests/python/relax/test_op_index.py
b/tests/python/relax/test_op_index.py
index e3c9e4a596..1455b4182a 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -194,6 +194,24 @@ def test_take_infer_struct_info():
_check_inference(bb, relax.op.take(y3, idx7),
relax.TensorStructInfo(dtype="", ndim=2))
+def test_take_infer_struct_info_scalar_tensor_index():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
+ idx = relax.Var("idx", R.Tensor([], "int64"))
+
+ _check_inference(bb, relax.op.take(x0, idx, axis=0),
relax.TensorStructInfo([10], "float32"))
+ _check_inference(bb, relax.op.take(x0, idx, axis=1),
relax.TensorStructInfo([4], "float32"))
+
+
+def test_take_infer_struct_info_prim_value_index():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
+ idx = relax.Var("idx", R.Prim("int64"))
+
+ _check_inference(bb, relax.op.take(x0, idx, axis=0),
relax.TensorStructInfo([10], "float32"))
+ _check_inference(bb, relax.op.take(x0, idx, axis=1),
relax.TensorStructInfo([4], "float32"))
+
+
def test_take_infer_struct_info_shape_symbolic():
bb = relax.BlockBuilder()
m = tir.Var("m", "int64")
diff --git a/tests/python/relax/test_op_take.py
b/tests/python/relax/test_op_take.py
new file mode 100644
index 0000000000..babf91869a
--- /dev/null
+++ b/tests/python/relax/test_op_take.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+import numpy as np
+
+axis = tvm.testing.parameter(0, 1)
+
+
[email protected]_targets("llvm")
+def test_take_scalar_tensor_as_index(target, dev, axis):
+ """The index of R.take may be a scalar tensor
+
+ Using a scalar tensor as the index reduces the dimension of the
+ output.
+
+ """
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16, 16], "float16")):
+ output = R.take(A, R.const(1), axis=axis)
+ return output
+
+ built = tvm.relax.build(Module, target=target)
+ vm = tvm.relax.VirtualMachine(built, dev)
+
+ np_input = np.random.random(size=[16, 16]).astype("float16")
+ tvm_input = tvm.nd.array(np_input, dev)
+ tvm_output = vm["main"](tvm_input)
+ np_expected = np_input.take(1, axis=axis)
+
+ tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
[email protected]_targets("llvm")
+def test_take_1d_tensor_as_index(target, dev, axis):
+ """The index of R.take may be a non-scalar tensor
+
+ In general, `R.take` outputs a tensor of dimension
+ `data.ndim + indices.ndim - 1`.
+ """
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16, 16], "float16")):
+ output = R.take(A, R.const([1]), axis=axis)
+ return output
+
+ built = tvm.relax.build(Module, target=target)
+ vm = tvm.relax.VirtualMachine(built, dev)
+
+ np_input = np.random.random(size=[16, 16]).astype("float16")
+ tvm_input = tvm.nd.array(np_input, dev)
+ tvm_output = vm["main"](tvm_input)
+ np_expected = np_input.take([1], axis=axis)
+
+ tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
[email protected]_targets("llvm")
+def test_take_2d_tensor_as_index(target, dev, axis):
+ """The index of R.take may be a 2-d tensor"""
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16, 16], "float16")):
+ output = R.take(A, R.const([[1, 3], [5, 7]]), axis=axis)
+ return output
+
+ built = tvm.relax.build(Module, target=target)
+ vm = tvm.relax.VirtualMachine(built, dev)
+
+ np_input = np.random.random(size=[16, 16]).astype("float16")
+ tvm_input = tvm.nd.array(np_input, dev)
+ tvm_output = vm["main"](tvm_input)
+ np_expected = np_input.take([[1, 3], [5, 7]], axis=axis)
+
+ tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
[email protected]_targets("llvm")
+def test_take_constant_prim_value_as_index(target, dev, axis):
+ """The index of R.take may be a R.prim_value
+
+ The `R.prim_value` produces output equivalent to a scalar
+ tensor.
+
+ """
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16, 16], "float16")):
+ output = R.take(A, R.prim_value(1), axis=axis)
+ return output
+
+ built = tvm.relax.build(Module, target=target)
+ vm = tvm.relax.VirtualMachine(built, dev)
+
+ np_input = np.random.random(size=[16, 16]).astype("float16")
+ tvm_input = tvm.nd.array(np_input, dev)
+ tvm_output = vm["main"](tvm_input)
+ np_expected = np_input.take(1, axis=axis)
+
+ tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
[email protected]_targets("llvm")
+def test_take_dynamic_prim_value_as_index(target, dev, axis):
+ """The index of R.take may be a dynamic R.prim_value
+
+ The `R.prim_value` produces output equivalent to a scalar
+ tensor.
+
+ """
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor(["n", "n"], "float16")):
+ n = T.int64()
+ output = R.take(A, R.prim_value(n - 1), axis=axis)
+ return output
+
+ built = tvm.relax.build(Module, target=target)
+ vm = tvm.relax.VirtualMachine(built, dev)
+
+ np_input = np.random.random(size=[16, 16]).astype("float16")
+ tvm_input = tvm.nd.array(np_input, dev)
+ tvm_output = vm["main"](tvm_input)
+ np_expected = np_input.take(15, axis=axis)
+
+ tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 0d1e969b35..d0aaddb1ca 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -55,6 +55,68 @@ def test_take():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_take_prim_value():
+ # fmt: off
+ @tvm.script.ir_module
+ class Take:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4), "float32"), index: R.Prim("int64")) ->
R.Tensor((2, 4), "float32"):
+ gv: R.Tensor((2, 4), "float32") = R.take(x, index, axis=1)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4), "float32"), index: R.Prim("int64")) ->
R.Tensor((2, 4), "float32"):
+ gv = R.call_tir(Expected.take, (x, index), R.Tensor((2, 4),
dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+ def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)),
"float32"), index: T.int64, T_take: T.Buffer((T.int64(2), T.int64(4)),
"float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i2 in T.grid(T.int64(2), T.int64(4)):
+ with T.block("T_take"):
+ ax0, ax2 = T.axis.remap("SS", [i0, i2])
+ T.reads(rxplaceholder[ax0, index, ax2])
+ T.writes(T_take[ax0, ax2])
+ T_take[ax0, ax2] = rxplaceholder[ax0, index, ax2]
+ # fmt: on
+
+ mod = LegalizeOps()(Take)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_take_const_prim_value():
+ # fmt: off
+ @tvm.script.ir_module
+ class Take:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 4),
"float32"):
+ gv: R.Tensor((2, 4), "float32") = R.take(x, R.prim_value(0),
axis=1)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 4),
"float32"):
+ gv = R.call_tir(Expected.take, (x,), R.Tensor((2, 4),
dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+ def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)),
"float32"), T_take: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i2 in T.grid(T.int64(2), T.int64(4)):
+ with T.block("T_take"):
+ ax0, ax2 = T.axis.remap("SS", [i0, i2])
+ T.reads(rxplaceholder[ax0, T.int64(0), ax2])
+ T.writes(T_take[ax0, ax2])
+ T_take[ax0, ax2] = rxplaceholder[ax0, T.int64(0), ax2]
+ # fmt: on
+
+ mod = LegalizeOps()(Take)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_take_symbolic():
# fmt: off
@tvm.script.ir_module
@@ -96,6 +158,41 @@ def test_take_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_take_symbolic_prim_value():
+ # fmt: off
+ @tvm.script.ir_module
+ class Take:
+ @R.function
+ def main(x: R.Tensor((2, "n", 4), "float32")) -> R.Tensor((2, 4),
"float32"):
+ n = T.int64()
+ gv: R.Tensor((2, 4), "float32") = R.take(x, R.prim_value(n-1),
axis=1)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, "n", 4), "float32")) -> R.Tensor((2, 4),
"float32"):
+ gv = R.call_tir(Expected.take, (x,), R.Tensor((2, 4),
dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+ def take(x_handle: T.handle, T_take: T.Buffer((T.int64(2),
T.int64(4)), "float32")):
+ n = T.int64()
+ rxplaceholder = T.match_buffer(x_handle, (T.int64(2), n,
T.int64(4)), "float32")
+
+ T.func_attr({"tir.noalias": True})
+ for i0, i2 in T.grid(T.int64(2), T.int64(4)):
+ with T.block("T_take"):
+ ax0, ax2 = T.axis.remap("SS", [i0, i2])
+ T.reads(rxplaceholder[ax0, n-1, ax2])
+ T.writes(T_take[ax0, ax2])
+ T_take[ax0, ax2] = rxplaceholder[ax0, n-1, ax2]
+ # fmt: on
+
+ mod = LegalizeOps()(Take)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_strided_slice():
# fmt: off
@tvm.script.ir_module