This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 95cbdaad77 [Relax] Allow ingesting tensor.chunk() from exported torch program (#17758)
95cbdaad77 is described below

commit 95cbdaad77f3ff19c01b2cd23517c2902226ab58
Author: Hugo Latendresse <[email protected]>
AuthorDate: Thu Mar 27 13:02:54 2025 -0400

    [Relax] Allow ingesting tensor.chunk() from exported torch program (#17758)
    
    Front-end support for torch.chunk (https://pytorch.org/docs/stable/generated/torch.chunk.html)
---
 .../frontend/torch/base_fx_graph_translator.py     | 11 +++
 .../frontend/torch/exported_program_translator.py  |  1 +
 tests/python/relax/test_from_exported_to_cuda.py   | 95 +++++++++++++++++++---
 3 files changed, 95 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index ecd8665b43..71554a8a5b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -956,6 +956,17 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
         return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
 
+    def _chunk(self, node: fx.Node) -> relax.Var:
+        """Convert torch.chunk: split `dim` into pieces of ceil(size / chunks) elements."""
+        x, chunks = self.env[node.args[0]], node.args[1]
+        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+        dim_size = int(self.shape_of(x)[dim])
+        chunk_size = -(-dim_size // chunks)  # ceil division; torch may yield < `chunks` pieces
+        split_points = list(range(chunk_size, dim_size, chunk_size))
+        return self.block_builder.emit(
+            relax.op.split(x=x, indices_or_sections=split_points, axis=dim)
+        )
+
     def _cumsum(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 37caae6c98..4319fbebe7 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -359,6 +359,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
             "cat.default": self._cat,
+            "chunk.default": self._chunk,
             "clamp.Tensor": self._clamp,
             "concat.default": self._cat,
             "copy_.default": self._copy_,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index f7501dd3b5..64babdc43a 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-
 import tvm
 from tvm import relax
 import tvm.testing
@@ -332,6 +331,30 @@ def test_split_size(target, dev):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_split_sections_list(target, dev):
+    # Test split using a list of section sizes
+    batch = 3
+    channels = 2
+    height = 10
+    width = 5
+    sections = [3, 2, 5]
+    dim = 2  # split across height
+    raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+    class SplitModelSectionsList(nn.Module):
+        def __init__(self, split_size, dim):
+            super().__init__()
+            self.split_size = split_size
+            self.dim = dim
+
+        def forward(self, x):
+            return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)
+
+    torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_batch_norm0(target, dev):
     # Eval, no momentum, no affine, no running stats
@@ -373,26 +396,74 @@ def test_batch_norm3(target, dev):
 
 
 @tvm.testing.parametrize_targets("cuda")
-def test_split_sections_list(target, dev):
-    # Test split using a list of section sizes
-    batch = 3
+def test_chunk_even(target, dev):
+    # Chunks is a divisor of the dimension size
+    batch = 6
     channels = 2
-    height = 10
+    height = 3
+    width = 4
+    chunks = 3
+    dim = 0
+    raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+    class ChunkModel(nn.Module):
+        def __init__(self, chunks, dim):
+            super().__init__()
+            self.chunks = chunks
+            self.dim = dim
+
+        def forward(self, x):
+            return x.chunk(self.chunks, dim=self.dim)
+
+    torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_chunk_uneven(target, dev):
+    # Chunks is not a divisor of the dimension size
+    batch = 2
+    channels = 5
+    height = 4
     width = 5
-    sections = [3, 2, 5]
-    dim = 2  # split across height
+    chunks = 2
+    dim = 1
     raw_data = np.random.rand(batch, channels, height, width).astype("float32")
 
-    class SplitModelSectionsList(nn.Module):
-        def __init__(self, split_size, dim):
+    class ChunkModel(nn.Module):
+        def __init__(self, chunks, dim):
             super().__init__()
-            self.split_size = split_size
+            self.chunks = chunks
             self.dim = dim
 
         def forward(self, x):
-            return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)
+            return x.chunk(self.chunks, dim=self.dim)
 
-    torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
+    torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_chunk_too_many(target, dev):
+    # If user asks for more chunks than the size of the dim, pytorch simply splits in sections of size 1
+    batch = 1
+    channels = 3
+    height = 2
+    width = 2
+    chunks = 99
+    dim = 1
+    raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+    class ChunkModel(nn.Module):
+        def __init__(self, chunks, dim):
+            super().__init__()
+            self.chunks = chunks
+            self.dim = dim
+
+        def forward(self, x):
+            return x.chunk(self.chunks, dim=self.dim)
+
+    torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 

Reply via email to