This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 67e121da89 [Relax] Tensor.split with uneven tensors (#17757)
67e121da89 is described below
commit 67e121da89f5d51f6196eb1827a6095c9b6248aa
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Mar 24 10:38:12 2025 -0400
[Relax] Tensor.split with uneven tensors (#17757)
We used to have an assertion that, when using torch.split,
the size of the dimension being split must be a multiple of the split size.
PyTorch does not require this: the last tensor is simply smaller.
This PR adds that capability.
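As a quick illustration of the PyTorch behavior this change mirrors (a minimal
sketch, not part of the commit):

    import torch

    x = torch.arange(7)               # axis length 7; split_size 3 is not a divisor
    chunks = torch.split(x, 3)
    print([c.shape for c in chunks])  # [torch.Size([3]), torch.Size([3]), torch.Size([1])]

The TOPI change below follows the same convention when R.split is given an
integer number of sections: each section is ceil(axis_size / num_sections)
long and the last section takes whatever remains (e.g. an axis of length 10
split into 3 sections yields sections of size 4, 4, and 2).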
---
include/tvm/topi/transform.h | 21 +++----
.../frontend/torch/exported_program_translator.py | 1 +
.../tvm/relax/transform/legalize_ops/manipulate.py | 11 ----
src/topi/transform.cc | 4 +-
tests/python/relax/test_from_exported_to_cuda.py | 66 ++++++++++++++++++++--
.../test_transform_legalize_ops_manipulate.py | 50 +++++++++++++---
6 files changed, 115 insertions(+), 38 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index faacd2ce57..762148dcfa 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -575,8 +575,9 @@ inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name
*
* \return A Tensor whose op member is the split operation
*/
-inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
- std::string name = "T_split", std::string tag = kInjective) {
+inline Array<Tensor> split_indices_array(const Tensor& x, Array<PrimExpr> split_indices, int axis,
+ std::string name = "T_split",
+ std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(x->shape.size());
}
@@ -968,9 +969,9 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const
*
* \return A Tensor whose op member is the split operation
*/
-inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
- std::string name = "T_split_sections",
- std::string tag = kInjective) {
+inline Array<Tensor> split_n_sections(const Tensor& x, int num_sections, int axis,
+ std::string name = "T_split_sections",
+ std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(x->shape.size());
}
@@ -980,14 +981,8 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
ICHECK_GT(num_sections, 0) << "Slice count must be > 0";
- if (auto node = src_axis_size.as<IntImmNode>()) {
- ICHECK_EQ(node->value % num_sections, 0)
- << "num_sections must be an integer factor of the size of axis " <<
axis << " ("
- << node->value << ")";
- }
-
Array<PrimExpr> split_indices;
- auto seg_size = indexdiv(src_axis_size, num_sections);
+ auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections);
for (int i = 0; i < num_sections; ++i) {
// region at index 0 is added by split()
if (i != 0) {
@@ -995,7 +990,7 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
}
}
- return split(x, split_indices, axis, name, tag);
+ return split_indices_array(x, split_indices, axis, name, tag);
}
/*!
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index bc7a4c4cb0..2abc0b0248 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -329,6 +329,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"select.int": self._select,
"slice.Tensor": self._slice,
"split.Tensor": self._split,
+ "split_with_sizes.default": self._split,
"squeeze.default": self._squeeze,
"squeeze.dim": self._squeeze,
"take.default": self._take,
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index c71a41dc1c..662d4e946b 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -16,7 +16,6 @@
# under the License.
# pylint: disable=invalid-name
"""Default legalization function for manipulate operators."""
-import logging
from typing import Optional
import tvm
@@ -109,16 +108,6 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr:
def _split(bb: BlockBuilder, call: Call) -> Expr:
if isinstance(call.attrs.indices_or_sections, tir.IntImm):
indices_or_sections = call.attrs.indices_or_sections.value
- modulo = tvm.arith.Analyzer().simplify(
- call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections
- )
- if isinstance(modulo, tir.IntImm):
- if modulo != 0:
- logging.info(
- "Split cannot be legalized by TOPI when the axis being
split has "
- "length that not divisible by the input number of section."
- )
- return call
else:
indices_or_sections = call.attrs.indices_or_sections
return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis)
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 2e0fde3b28..7ef63a9b3f 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -84,9 +84,9 @@ TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue*
TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) {
if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
- *rv = split_sections(args[0], args[1], args[2]);
+ *rv = split_n_sections(args[0], args[1], args[2]);
} else {
- *rv = split(args[0], args[1], args[2]);
+ *rv = split_indices_array(args[0], args[1], args[2]);
}
});
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 6cc12370d6..c120eb8981 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+
import tvm
from tvm import relax
import tvm.testing
@@ -50,10 +51,17 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
gpu_out = vm["main"](gpu_data, *gpu_params)
- pytorch_out = torch_module(torch_data).detach().numpy()
- actual = gpu_out[0].numpy()
- desired = pytorch_out
- np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
+ pytorch_out = torch_module(torch_data)
+
+ if isinstance(pytorch_out, tuple):
+ for i in range(len(pytorch_out)):
+ actual = gpu_out[i].numpy()
+ desired = pytorch_out[i].detach().numpy()
+ np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
+ else:
+ actual = gpu_out[0].numpy()
+ desired = pytorch_out.detach().numpy()
+ np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
@tvm.testing.parametrize_targets("cuda")
@@ -281,5 +289,55 @@ def test_linalg_vector_norm(target, dev):
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
+@tvm.testing.parametrize_targets("cuda")
+def test_split_size(target, dev):
+ # Test split using the split_size argument such that it is not a divisor
+ # of the dimension to split (the last tensor will be smaller)
+ batch = 2
+ channels = 7
+ height, width = 2, 2
+ split_size = 3 # last tensor will have just 1 element
+ dim = 1 # split across channels
+ raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+ class SplitModelSplitSize(nn.Module):
+ def __init__(self, split_size, dim):
+ super().__init__()
+ self.split_size = split_size
+ self.dim = dim
+
+ def forward(self, x):
+ return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)
+
+ torch_module = SplitModelSplitSize(split_size=split_size, dim=dim).eval()
+
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_split_sections_list(target, dev):
+ # Test split using a list of section sizes
+ batch = 3
+ channels = 2
+ height = 10
+ width = 5
+ sections = [3, 2, 5]
+ dim = 2 # split across height
+ raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+ class SplitModelSectionsList(nn.Module):
+ def __init__(self, split_size, dim):
+ super().__init__()
+ self.split_size = split_size
+ self.dim = dim
+
+ def forward(self, x):
+ return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)
+
+ torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
+
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 0565b7a579..4836ffd010 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -15,6 +15,10 @@
# specific language governing permissions and limitations
# under the License.
+import sys
+
+sys.path.append("/ssd1/htalendr/tvm/python")
+
import tvm
from tvm import relax
from tvm.relax.transform import LegalizeOps
@@ -788,12 +792,42 @@ def test_split_by_indices_n_section_indivisible():
class Split:
@R.function
def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]):
- gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, 3, axis=1)
+ gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, indices_or_sections=3, axis=1)
return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]):
+ gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")])
+ return gv
+
+ @T.prim_func(private=True)
+ def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_2: T.Buffer((T.int64(2), T.int64(2), T.int64(4)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)):
+ with T.block("T_split_sections"):
+ ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(rxplaceholder[ax0, ax1, ax2])
+ T.writes(T_split_sections[ax0, ax1, ax2])
+ T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2]
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)):
+ with T.block("T_split_sections_1"):
+ ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(rxplaceholder[ax0, ax1 + T.int64(4), ax2])
+ T.writes(T_split_sections_1[ax0, ax1, ax2])
+ T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(4), ax2]
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(2), T.int64(4)):
+ with T.block("T_split_sections_2"):
+ ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(rxplaceholder[ax0, ax1 + T.int64(8), ax2])
+ T.writes(T_split_sections_2[ax0, ax1, ax2])
+ T_split_sections_2[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(8), ax2]
+
# fmt: on
mod = LegalizeOps()(Split)
- tvm.ir.assert_structural_equal(mod, Split)
+ tvm.ir.assert_structural_equal(mod, Expected)
def test_split_by_indices_n_section_divisible():
@@ -850,7 +884,7 @@ def test_split_by_indices_n_section_divisible_symbolic():
def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")):
m = T.int64()
n = T.int64()
- gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,))
+ gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3 + 3 - 1) // 3)), "float32"), R.Tensor((m, ((((n * 3 + 3 - 1) // 3) * 2) - ((n * 3 + 3 - 1) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3 + 3 - 1) // 3) * 2))), "float32")], tir_vars=R.shape([n]))
return gv
@T.prim_func(private=True)
@@ -858,9 +892,9 @@ def test_split_by_indices_n_section_divisible_symbolic():
T.func_attr({"tir.noalias": True})
m = T.int64()
rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32")
- T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32")
- T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32")
- T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32")
+ T_split_sections = T.match_buffer(var_T_split_sections, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32")
+ T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32")
+ T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2)], dtype="float32")
for i0, i1 in T.grid(m, n):
with T.block("T_split_sections"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
@@ -870,9 +904,9 @@ def test_split_by_indices_n_section_divisible_symbolic():
for i0, i1 in T.grid(m, n):
with T.block("T_split_sections_1"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[ax0, n + ax1])
+ T.reads(rxplaceholder[ax0, ax1 + n])
T.writes(T_split_sections_1[ax0, ax1])
- T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1]
+ T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, ax1 + n]
for i0, i1 in T.grid(m, n):
with T.block("T_split_sections_2"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])