This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 67e121da89 [Relax] Tensor.split with uneven tensors (#17757)
67e121da89 is described below

commit 67e121da89f5d51f6196eb1827a6095c9b6248aa
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Mar 24 10:38:12 2025 -0400

    [Relax] Tensor.split with uneven tensors (#17757)
    
    We used to assert that, when using torch.split, the size of the
    dimension being split must be a multiple of the split size.
    PyTorch does not require this: the last tensor is simply smaller.
    This PR adds support for that behavior.
---
 include/tvm/topi/transform.h                       | 21 +++----
 .../frontend/torch/exported_program_translator.py  |  1 +
 .../tvm/relax/transform/legalize_ops/manipulate.py | 11 ----
 src/topi/transform.cc                              |  4 +-
 tests/python/relax/test_from_exported_to_cuda.py   | 66 ++++++++++++++++++++--
 .../test_transform_legalize_ops_manipulate.py      | 50 +++++++++++++---
 6 files changed, 115 insertions(+), 38 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index faacd2ce57..762148dcfa 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -575,8 +575,9 @@ inline Tensor stack(const Array<Tensor>& inputs, int axis = 
0, std::string name
  *
  * \return A Tensor whose op member is the split operation
  */
-inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int 
axis,
-                           std::string name = "T_split", std::string tag = 
kInjective) {
+inline Array<Tensor> split_indices_array(const Tensor& x, Array<PrimExpr> 
split_indices, int axis,
+                                         std::string name = "T_split",
+                                         std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(x->shape.size());
   }
@@ -968,9 +969,9 @@ inline Tensor strided_slice(const Tensor& x, const 
Array<Integer>& begin, const
  *
  * \return A Tensor whose op member is the split operation
  */
-inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int 
axis,
-                                    std::string name = "T_split_sections",
-                                    std::string tag = kInjective) {
+inline Array<Tensor> split_n_sections(const Tensor& x, int num_sections, int 
axis,
+                                      std::string name = "T_split_sections",
+                                      std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(x->shape.size());
   }
@@ -980,14 +981,8 @@ inline Array<Tensor> split_sections(const Tensor& x, int 
num_sections, int axis,
 
   ICHECK_GT(num_sections, 0) << "Slice count must be > 0";
 
-  if (auto node = src_axis_size.as<IntImmNode>()) {
-    ICHECK_EQ(node->value % num_sections, 0)
-        << "num_sections must be an integer factor of the size of axis " << 
axis << " ("
-        << node->value << ")";
-  }
-
   Array<PrimExpr> split_indices;
-  auto seg_size = indexdiv(src_axis_size, num_sections);
+  auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections);
   for (int i = 0; i < num_sections; ++i) {
     // region at index 0 is added by split()
     if (i != 0) {
@@ -995,7 +990,7 @@ inline Array<Tensor> split_sections(const Tensor& x, int 
num_sections, int axis,
     }
   }
 
-  return split(x, split_indices, axis, name, tag);
+  return split_indices_array(x, split_indices, axis, name, tag);
 }
 
 /*!
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index bc7a4c4cb0..2abc0b0248 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -329,6 +329,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "select.int": self._select,
             "slice.Tensor": self._slice,
             "split.Tensor": self._split,
+            "split_with_sizes.default": self._split,
             "squeeze.default": self._squeeze,
             "squeeze.dim": self._squeeze,
             "take.default": self._take,
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py 
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index c71a41dc1c..662d4e946b 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name
 """Default legalization function for manipulate operators."""
-import logging
 from typing import Optional
 
 import tvm
@@ -109,16 +108,6 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr:
 def _split(bb: BlockBuilder, call: Call) -> Expr:
     if isinstance(call.attrs.indices_or_sections, tir.IntImm):
         indices_or_sections = call.attrs.indices_or_sections.value
-        modulo = tvm.arith.Analyzer().simplify(
-            call.args[0].struct_info.shape.values[call.attrs.axis] % 
indices_or_sections
-        )
-        if isinstance(modulo, tir.IntImm):
-            if modulo != 0:
-                logging.info(
-                    "Split cannot be legalized by TOPI when the axis being 
split has "
-                    "length that not divisible by the input number of section."
-                )
-                return call
     else:
         indices_or_sections = call.attrs.indices_or_sections
     return bb.call_te(topi.split, call.args[0], indices_or_sections, 
call.attrs.axis)
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 2e0fde3b28..7ef63a9b3f 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -84,9 +84,9 @@ TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs 
args, TVMRetValue*
 
 TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) {
   if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
-    *rv = split_sections(args[0], args[1], args[2]);
+    *rv = split_n_sections(args[0], args[1], args[2]);
   } else {
-    *rv = split(args[0], args[1], args[2]);
+    *rv = split_indices_array(args[0], args[1], args[2]);
   }
 });
 
diff --git a/tests/python/relax/test_from_exported_to_cuda.py 
b/tests/python/relax/test_from_exported_to_cuda.py
index 6cc12370d6..c120eb8981 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+
 import tvm
 from tvm import relax
 import tvm.testing
@@ -50,10 +51,17 @@ def 
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
     gpu_out = vm["main"](gpu_data, *gpu_params)
 
-    pytorch_out = torch_module(torch_data).detach().numpy()
-    actual = gpu_out[0].numpy()
-    desired = pytorch_out
-    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, 
atol=1e-5)
+    pytorch_out = torch_module(torch_data)
+
+    if isinstance(pytorch_out, tuple):
+        for i in range(len(pytorch_out)):
+            actual = gpu_out[i].numpy()
+            desired = pytorch_out[i].detach().numpy()
+            np.testing.assert_allclose(actual=actual, desired=desired, 
rtol=1e-5, atol=1e-5)
+    else:
+        actual = gpu_out[0].numpy()
+        desired = pytorch_out.detach().numpy()
+        np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, 
atol=1e-5)
 
 
 @tvm.testing.parametrize_targets("cuda")
@@ -281,5 +289,55 @@ def test_linalg_vector_norm(target, dev):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, 
target, dev)
 
 
[email protected]_targets("cuda")
+def test_split_size(target, dev):
+    # Test split using the split_size argument such that it is not a divisor
+    # of the dimension to split (the last tensor will be smaller)
+    batch = 2
+    channels = 7
+    height, width = 2, 2
+    split_size = 3  # last tensor will have just 1 element
+    dim = 1  # split across channels
+    raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+    class SplitModelSplitSize(nn.Module):
+        def __init__(self, split_size, dim):
+            super().__init__()
+            self.split_size = split_size
+            self.dim = dim
+
+        def forward(self, x):
+            return torch.split(x, split_size_or_sections=self.split_size, 
dim=self.dim)
+
+    torch_module = SplitModelSplitSize(split_size=split_size, dim=dim).eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev)
+
+
[email protected]_targets("cuda")
+def test_split_sections_list(target, dev):
+    # Test split using a list of section sizes
+    batch = 3
+    channels = 2
+    height = 10
+    width = 5
+    sections = [3, 2, 5]
+    dim = 2  # split across height
+    raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+    class SplitModelSectionsList(nn.Module):
+        def __init__(self, split_size, dim):
+            super().__init__()
+            self.split_size = split_size
+            self.dim = dim
+
+        def forward(self, x):
+            return torch.split(x, split_size_or_sections=self.split_size, 
dim=self.dim)
+
+    torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py 
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 0565b7a579..4836ffd010 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -15,6 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import sys
+
+sys.path.append("/ssd1/htalendr/tvm/python")
+
 import tvm
 from tvm import relax
 from tvm.relax.transform import LegalizeOps
@@ -788,12 +792,42 @@ def test_split_by_indices_n_section_indivisible():
     class Split:
         @R.function
         def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 
4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), 
"float32")]):
-            gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), 
"float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, 3, axis=1)
+            gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), 
"float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, 
indices_or_sections=3, axis=1)
             return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 
4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), 
"float32")]):
+            gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 4, 4), 
"float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")])
+            return gv
+
+        @T.prim_func(private=True)
+        def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), 
T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(4), 
T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(4), 
T.int64(4)), "float32"), T_split_sections_2: T.Buffer((T.int64(2), T.int64(2), 
T.int64(4)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)):
+                with T.block("T_split_sections"):
+                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(rxplaceholder[ax0, ax1, ax2])
+                    T.writes(T_split_sections[ax0, ax1, ax2])
+                    T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, 
ax2]
+            for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)):
+                with T.block("T_split_sections_1"):
+                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(rxplaceholder[ax0, ax1 + T.int64(4), ax2])
+                    T.writes(T_split_sections_1[ax0, ax1, ax2])
+                    T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 
+ T.int64(4), ax2]
+            for i0, i1, i2 in T.grid(T.int64(2), T.int64(2), T.int64(4)):
+                with T.block("T_split_sections_2"):
+                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(rxplaceholder[ax0, ax1 + T.int64(8), ax2])
+                    T.writes(T_split_sections_2[ax0, ax1, ax2])
+                    T_split_sections_2[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 
+ T.int64(8), ax2]
+
     # fmt: on
 
     mod = LegalizeOps()(Split)
-    tvm.ir.assert_structural_equal(mod, Split)
+    tvm.ir.assert_structural_equal(mod, Expected)
 
 
 def test_split_by_indices_n_section_divisible():
@@ -850,7 +884,7 @@ def test_split_by_indices_n_section_divisible_symbolic():
         def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), 
"float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), 
R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), 
R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")):
             m = T.int64()
             n = T.int64()
-            gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 
3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), 
"float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], 
tir_vars=(n,))
+            gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3 + 3 - 
1) // 3)), "float32"), R.Tensor((m, ((((n * 3 + 3 - 1) // 3) * 2) - ((n * 3 + 3 
- 1) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3 + 3 - 1) // 3) * 
2))), "float32")], tir_vars=R.shape([n]))
             return gv
 
         @T.prim_func(private=True)
@@ -858,9 +892,9 @@ def test_split_by_indices_n_section_divisible_symbolic():
             T.func_attr({"tir.noalias": True})
             m = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * 
T.int64(3)], dtype="float32")
-            T_split_sections = T.match_buffer(var_T_split_sections, [m, n * 
T.int64(3) // T.int64(3)], dtype="float32")
-            T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n 
* T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], 
dtype="float32")
-            T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n 
* T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32")
+            T_split_sections = T.match_buffer(var_T_split_sections, [m, (n * 
T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32")
+            T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, (n 
* T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2) - (n * 
T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32")
+            T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n 
* T.int64(3) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * 
T.int64(2)], dtype="float32")
             for i0, i1 in T.grid(m, n):
                 with T.block("T_split_sections"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
@@ -870,9 +904,9 @@ def test_split_by_indices_n_section_divisible_symbolic():
             for i0, i1 in T.grid(m, n):
                 with T.block("T_split_sections_1"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[ax0, n + ax1])
+                    T.reads(rxplaceholder[ax0, ax1 + n])
                     T.writes(T_split_sections_1[ax0, ax1])
-                    T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1]
+                    T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, ax1 + n]
             for i0, i1 in T.grid(m, n):
                 with T.block("T_split_sections_2"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])

Reply via email to