This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b5b0337568 [Relax][PyTorch] support for index.Tensor (#17836)
b5b0337568 is described below
commit b5b0337568f4aa6912c9c48f9e03027fc140e9af
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Apr 21 10:59:51 2025 -0400
[Relax][PyTorch] support for index.Tensor (#17836)
New op for advanced indexing + unit tests
---
.../frontend/torch/base_fx_graph_translator.py | 5 +
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/manipulate.py | 63 +++++++++
.../tvm/relax/transform/legalize_ops/manipulate.py | 9 ++
python/tvm/script/ir_builder/relax/ir.py | 2 +
python/tvm/topi/transform.py | 51 ++++++++
src/relax/op/tensor/manipulate.cc | 145 +++++++++++++++++++++
src/relax/op/tensor/manipulate.h | 12 ++
tests/python/relax/test_from_exported_to_cuda.py | 141 +++++++++++++++++++-
10 files changed, 424 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 733a5d6b1a..13d13ff24c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1148,6 +1148,11 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
index = self.env[node.args[2]]
return self.block_builder.emit(relax.op.gather_elements(x, index,
axis=dim))
+ def _index_tensor(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ indices = args[1]
+ return self.block_builder.emit(relax.op.index_tensor(args[0], indices))
+
def _permute(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index af1393329e..ab55ded36c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -420,6 +420,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"flatten.using_ints": self._flatten,
"flip.default": self._flip,
"gather.default": self._gather,
+ "index.Tensor": self._index_tensor,
"narrow.default": self._narrow,
"permute.default": self._permute,
"repeat.default": self._repeat,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 3145a7c292..097313a33d 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -95,6 +95,7 @@ from .manipulate import (
flip,
gather_elements,
gather_nd,
+ index_tensor,
layout_transform,
one_hot,
permute_dims,
diff --git a/python/tvm/relax/op/manipulate.py
b/python/tvm/relax/op/manipulate.py
index 725e58bd01..a693adf432 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -532,6 +532,69 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int =
0) -> Expr:
return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore
+def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr:
+ """Advanced‑tensor indexing (NumPy/PyTorch‐style).
+
+ Given k index tensors ``indices = (I0, I1, …, Ik‑1)`` this
+ operator selects elements from ``data`` as if one had written
+ ``data[I0, I1, …, Ik‑1]`` in NumPy/PyTorch:
+
+ All index tensors must have an integer dtype.
+
+ Their shapes are broadcast together to a common shape ``B`` in
+ the usual NumPy way.
+
+ The result shape is ``B + data.shape[k:]`` (i.e. the broadcast
+ shape followed by the remaining axes of ``data`` that are *not*
+ indexed).
+
+ At compile‑time Relax checks that the number of index tensors
+ ``k`` does not exceed ``data.ndim``, that the dtypes are integer,
+ and that the shapes are consistent (broadcast‑compatible).
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input tensor to be indexed.
+
+ indices : Union[relax.Expr, List[relax.Expr]]
+ A Tuple expression containing the index tensors,
+ or a Python ``list`` / ``tuple`` that will be promoted to a
+ tuple expression automatically. Each tensor must have an
+ integer dtype.
+
+ Returns
+ -------
+ result : relax.Expr
+ The tensor obtained after advanced indexing. Its dtype equals
+ ``data.dtype``
+
+ Examples
+ --------
+ .. code-block:: python
+
+ import numpy as np
+ import tvm.relax as R
+
+ x = R.const(np.arange(9).reshape(3, 3).astype("float32"))
+ row = R.const(np.array([0, 2])) # shape (2,)
+ col = R.const(np.array([1, 0])) # shape (2,)
+
+ y = R.index_tensor(x, [row, col])
+ # y.shape == (2,) ; y == [1., 6.]
+
+ # Broadcasting: row : (2,1), col : (1,3) → B = (2,3)
+ row = R.const(np.array([[0],[1]]))
+ col = R.const(np.array([[0,1,2]]))
+ z = R.index_tensor(x, [row, col])
+ # z.shape == (2,3)
+
+ """
+ if isinstance(indices, (list, tuple)):
+ indices = RxTuple(indices)
+ return _ffi_api.index_tensor(data, indices) # type: ignore
+
+
def scatter_elements(
data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str =
"update"
):
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index a481d7af95..84baa887d9 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -49,6 +49,7 @@ register_legalize(
"relax.collapse_sum_like",
_reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True),
)
+
register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum,
"collapse_sum"))
@@ -184,6 +185,14 @@ def _gather_nd(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(te_gather_nd, call.args[0], call.args[1],
int(call.attrs.batch_dims))
+@register_legalize("relax.index_tensor")
+def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
+ t = call.args[1]
+ n_field = len(t.struct_info.fields)
+ fields = [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
+ return bb.call_te(topi.index_tensor, call.args[0], fields)
+
+
@register_legalize("relax.scatter_elements")
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 79b1884aac..22b00cd704 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -101,6 +101,7 @@ from tvm.relax.op import (
greater_equal,
hint_on_device,
image,
+ index_tensor,
invoke_closure,
invoke_pure_closure,
isfinite,
@@ -785,6 +786,7 @@ __all__ = [
"hexagon",
"hint_on_device",
"image",
+ "index_tensor",
"invoke_closure",
"invoke_pure_closure",
"isfinite",
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 37743e97a3..1ef6523059 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -1054,3 +1054,54 @@ def trilu(data, k, upper):
return tvm.tir.Select(check_position, value, tvm.tir.const(0,
data.dtype))
return te.compute(data.shape, _apply_trilu, name="trilu",
tag=topi.tag.ELEMWISE)
+
+
+def index_tensor(data, indices):
+ """Advanced‑tensor indexing (NumPy/PyTorch‐style).
+
+ Given k index tensors ``indices = (I0, I1, …, Ik‑1)`` this
+ operator selects elements from ``data`` as if one had written
+ ``data[I0, I1, …, Ik‑1]`` in NumPy/PyTorch:
+
+ * All index tensors must have an integer dtype.
+ * Their shapes are broadcast together to a common shape ``B`` in
+ the usual NumPy way.
+ * The result shape is ``B + data.shape[k:]`` (i.e. the broadcast
+ shape followed by the remaining axes of ``data`` that are *not*
+ indexed).
+ * ``k`` must not exceed ``data.ndim``; otherwise a compile‑time
+ error is raised.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ The tensor to be indexed.
+
+ indices : Sequence[tvm.te.Tensor]
+ A Python ``list`` / ``tuple`` of **k** index tensors,
+ or a `tvm.te.Tensor` tuple expression. Each tensor must have an
+ integer dtype.
+
+ Returns
+ -------
+ result : tvm.te.Tensor
+ The tensor obtained after advanced indexing. Its dtype equals
+ ``data.dtype``
+
+ Examples
+ --------
+ .. code-block:: python
+
+ x = te.placeholder((3, 3), name="x") # shape (3,3)
+ row = te.placeholder((2,), name="row", dtype="int32")
+ col = te.placeholder((2,), name="col", dtype="int32")
+
+ # Equivalent to x[row, col] in NumPy / PyTorch
+ y = topi.index_tensor(x, [row, col]) # shape (2,)
+
+ # Broadcasting example:
+ row = te.placeholder((2, 1), name="row", dtype="int32")
+ col = te.placeholder((1, 3), name="col", dtype="int32")
+ z = topi.index_tensor(x, [row, col]) # shape (2, 3)
+ """
+ return topi.adv_index(data, indices)
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index 4abfe01387..f56135a35b 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -474,6 +474,151 @@ TVM_REGISTER_OP("relax.flatten")
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.index_tensor */
+
+Expr index_tensor(Expr first, Expr tensors) {
+ static const Op& op = Op::Get("relax.index_tensor");
+ return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor);
+
+StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder&
ctx) {
+ if (call->args.size() != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call) << "Index.Tensor op should have 2
arguments");
+ }
+
+ TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
+ Array<TensorStructInfo> indices_sinfo = GetTensorStructInfoFromTuple(call,
ctx, call->args[1]);
+
+ if (indices_sinfo.empty()) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "index_tensor expects a non‑empty tuple of index
tensors");
+ }
+
+ DataType output_dtype = data_sinfo->dtype;
+ int n_indices = static_cast<int>(indices_sinfo.size());
+ Optional<VDevice> vdev = data_sinfo->vdevice;
+
+ // Indices must be integers
+ for (int i = 0; i < n_indices; ++i) {
+ const auto& s = indices_sinfo[i];
+ if (!s->IsUnknownDtype() && !s->dtype.is_int()) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "index_tensor requires every index tensor to have an
integer dtype; "
+ << "index " << i << " has dtype " << s->dtype);
+ }
+ }
+
+ // Count of indices must be less than or equal to data.ndim
+ if (!data_sinfo->IsUnknownNdim() && n_indices > data_sinfo->ndim) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "index_tensor received " << n_indices
+ << " index tensors, but data has only " <<
data_sinfo->ndim << " dimensions");
+ }
+
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ bool all_index_have_shape_value = true;
+ std::vector<Array<PrimExpr>> index_shapes;
+ int max_index_ndim = 0;
+
+ for (const auto& s : indices_sinfo) {
+ const auto* shp = s->shape.as<ShapeExprNode>();
+ if (!shp) {
+ all_index_have_shape_value = false;
+ } else {
+ index_shapes.push_back(shp->values);
+ max_index_ndim = std::max(max_index_ndim,
static_cast<int>(shp->values.size()));
+ }
+ if (!s->IsUnknownNdim()) {
+ max_index_ndim = std::max(max_index_ndim, s->ndim);
+ }
+ }
+
+ Optional<Array<PrimExpr>> broadcast_shape;
+ bool shape_unknown = !all_index_have_shape_value;
+
+ if (all_index_have_shape_value) {
+ // initialise broadcast result with 1’s
+ Array<PrimExpr> out_shape;
+ for (int i = 0; i < max_index_ndim; ++i) {
+ out_shape.push_back(IntImm(DataType::Int(64), 1));
+ }
+
+ for (const auto& ishape : index_shapes) {
+ int cur_ndim = ishape.size();
+ for (int axis = 0; axis < max_index_ndim; ++axis) {
+ int lhs_axis = max_index_ndim - 1 - axis; // aligned from right
+ int rhs_axis = cur_ndim - 1 - axis;
+ if (rhs_axis < 0) break; // shorter rank – done
+
+ PrimExpr lhs_dim = out_shape[lhs_axis];
+ PrimExpr rhs_dim = ishape[rhs_axis];
+
+ const auto* lhs_int = lhs_dim.as<IntImmNode>();
+ const auto* rhs_int = rhs_dim.as<IntImmNode>();
+
+ // Case 1: current broadcast slot is 1 -> always replace
+ if (lhs_int && lhs_int->value == 1) {
+ out_shape.Set(lhs_axis, rhs_dim);
+ continue;
+ }
+ // Case 2: rhs is 1 -> keep lhs_dim unchanged
+ if (rhs_int && rhs_int->value == 1) {
+ continue;
+ }
+ // Both are non‑one constants: must equal
+ if (lhs_int && rhs_int && lhs_int->value != rhs_int->value) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "index_tensor: cannot broadcast index shapes.
Mismatch at axis "
+ << lhs_axis << ": " << lhs_dim << " vs " <<
rhs_dim);
+ }
+ // Give up if not provably equal
+ if (!analyzer->CanProveEqual(lhs_dim, rhs_dim)) {
+ shape_unknown = true;
+ break;
+ }
+ }
+ if (shape_unknown) break;
+ }
+
+ if (!shape_unknown) broadcast_shape = out_shape;
+ }
+
+ // Count of dimensions in output
+ int out_ndim = kUnknownNDim;
+ if (!data_sinfo->IsUnknownNdim()) {
+ int tail_ndim = data_sinfo->ndim - n_indices;
+ if (broadcast_shape.defined()) {
+ out_ndim = static_cast<int>(broadcast_shape.value().size()) + tail_ndim;
+ } else if (!shape_unknown) {
+ out_ndim = max_index_ndim + tail_ndim;
+ }
+ }
+
+ // Derive output shape
+ if (broadcast_shape.defined()) {
+ const auto* data_shape_expr = data_sinfo->shape.as<ShapeExprNode>();
+ if (data_shape_expr) {
+ Array<PrimExpr> result_shape = broadcast_shape.value();
+ for (int i = n_indices; i < data_sinfo->ndim; ++i) {
+ result_shape.push_back(data_shape_expr->values[i]);
+ }
+ return TensorStructInfo(ShapeExpr(result_shape), output_dtype, vdev);
+ }
+ }
+
+ // Unknown output shape
+ return TensorStructInfo(output_dtype, out_ndim, vdev);
+}
+
+TVM_REGISTER_OP("relax.index_tensor")
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input data.")
+ .add_argument("indices", "List of Tensors", "The indices used to index.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexTensor)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.layout_transform */
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 7e5de217bc..4580f9191b 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -206,6 +206,18 @@ Expr gather_elements(Expr data, Expr indices, int axis =
0);
*/
Expr gather_nd(Expr data, Expr indices, int batch_dims = 0);
+/*!
+ * \brief NumPy/PyTorch‑style advanced indexing with tensors.
+ * \param data The input tensor.
+ * \param indices A Tuple expression (or list) containing the index tensors.
+ * \return The indexed tensor.
+ *
+ * \note When all shapes are static, Relax checks that the index shapes are
+ * broadcast-compatible. Bounds checking of the values in indices is
+ * deferred to runtime.
+ */
+Expr index_tensor(Expr data, Expr indices);
+
/*!
* \brief Scatter updates into an array according to indices.
* \param data The input tensor.
diff --git a/tests/python/relax/test_from_exported_to_cuda.py
b/tests/python/relax/test_from_exported_to_cuda.py
index e92855885e..76a4bb2039 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -63,6 +63,108 @@ def
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5,
atol=1e-5)
+@tvm.testing.parametrize_targets("cuda")
+def test_index_tensor(target, dev):
+ class IndexModel0(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x[torch.tensor([0])]
+
+ torch_module = IndexModel0().eval()
+ raw_data = np.random.rand(3, 3).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+ class IndexModel1(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x[torch.tensor([[0]])]
+
+ torch_module = IndexModel1().eval()
+ raw_data = np.random.rand(2, 3).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+ class IndexTensorModel2(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x[torch.tensor([0, 2])]
+
+ torch_module = IndexTensorModel2().eval()
+ raw_data = np.random.rand(3, 4).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+ class IndexTensorModel3(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x[[[[0, 2], [1, 3]]]]
+
+ torch_module = IndexTensorModel3().eval()
+ raw_data = np.random.rand(5, 5, 5).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+ class IndexTensorModel4(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x[[[1, 4]]]
+
+ torch_module = IndexTensorModel4().eval()
+ raw_data = np.random.rand(5, 5, 5).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+ class IndexTensorModel5(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x[[[[1, 2, 4]]]]
+
+ torch_module = IndexTensorModel5().eval()
+ raw_data = np.random.rand(5, 5, 5).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+ class IndexTensorModel6(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x[[[0, 1], [0, 1]]]
+
+ torch_module = IndexTensorModel6().eval()
+ raw_data = np.random.rand(5, 5, 5, 5).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+ class IndexTensorModel7(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x[[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 0]]]
+
+ torch_module = IndexTensorModel7().eval()
+ raw_data = np.random.rand(5, 5, 5, 5).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+ class IndexTensorModel8(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x[[[[0, 1], [2, 3]], [[2, 3], [3, 4]], [[2, 4], [1, 2]],
[[0, 4], [0, 3]]]]
+
+ torch_module = IndexTensorModel8().eval()
+ raw_data = np.random.rand(5, 5, 5, 5).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+
@tvm.testing.parametrize_targets("cuda")
def test_full(target, dev):
class FullModel(nn.Module):
@@ -73,9 +175,7 @@ def test_full(target, dev):
return torch.full((2, 3), 3.141592)
torch_module = FullModel().eval()
-
raw_data = np.random.rand(3, 3).astype("float32")
-
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
@@ -91,7 +191,6 @@ def test_full_like(target, dev):
torch_module = FullLike().eval()
raw_data = np.random.rand(2, 3).astype("float32")
-
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
@@ -105,9 +204,7 @@ def test_ones(target, dev):
return torch.ones((2, 3))
torch_module = FullModel().eval()
-
raw_data = np.random.rand(1, 1).astype("float32")
-
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
@@ -583,10 +680,42 @@ def test_sum(target, dev):
return new_vec.sum()
torch_module = SumModel().eval()
-
raw_data = np.random.rand(10, 10, 10).astype("float32")
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+@tvm.testing.parametrize_targets("cuda")
+def test_mul(target, dev):
+ class MulModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.y = torch.tensor(np.random.rand(2, 3).astype("float32"))
+
+ def forward(self, x):
+ return x.mul(self.y)
+
+ torch_module = MulModule().eval()
+ raw_data = np.random.rand(2, 3).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_concat(target, dev):
+ class ConcatFour(nn.Module):
+ def __init__(self, dim=0):
+ super(ConcatFour, self).__init__()
+ self.dim = dim
+ self.x2 = torch.randn(2, 3)
+ self.x3 = torch.randn(2, 3)
+ self.x4 = torch.randn(2, 3)
+
+ def forward(self, x):
+ return torch.cat((x, self.x2, self.x3, self.x4), dim=self.dim)
+
+ torch_module = ConcatFour().eval()
+ raw_data = np.random.rand(2, 3).astype("float32")
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+
if __name__ == "__main__":
tvm.testing.main()