This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 95cbdaad77 [Relax] Allow ingesting tensor.chunk() from exported torch program (#17758)
95cbdaad77 is described below
commit 95cbdaad77f3ff19c01b2cd23517c2902226ab58
Author: Hugo Latendresse <[email protected]>
AuthorDate: Thu Mar 27 13:02:54 2025 -0400
[Relax] Allow ingesting tensor.chunk() from exported torch program (#17758)
Front-end support for torch.chunk
https://pytorch.org/docs/stable/generated/torch.chunk.html
---
.../frontend/torch/base_fx_graph_translator.py | 11 +++
.../frontend/torch/exported_program_translator.py | 1 +
tests/python/relax/test_from_exported_to_cuda.py | 95 +++++++++++++++++++---
3 files changed, 95 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index ecd8665b43..71554a8a5b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -956,6 +956,17 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
+ def _chunk(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ chunks = node.args[1]
+ dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+ x_shape = self.shape_of(x)
+ max_chunks = x_shape[dim].value
+ n_sections = min(chunks, max_chunks)
+ return self.block_builder.emit(
+ relax.op.split(x=x, indices_or_sections=n_sections, axis=dim)
+ )
+
def _cumsum(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 37caae6c98..4319fbebe7 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -359,6 +359,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"argmin.default": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
"cat.default": self._cat,
+ "chunk.default": self._chunk,
"clamp.Tensor": self._clamp,
"concat.default": self._cat,
"copy_.default": self._copy_,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index f7501dd3b5..64babdc43a 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-
import tvm
from tvm import relax
import tvm.testing
@@ -332,6 +331,30 @@ def test_split_size(target, dev):
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
[email protected]_targets("cuda")
+def test_split_sections_list(target, dev):
+ # Test split using a list of section sizes
+ batch = 3
+ channels = 2
+ height = 10
+ width = 5
+ sections = [3, 2, 5]
+ dim = 2 # split across height
+ raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+ class SplitModelSectionsList(nn.Module):
+ def __init__(self, split_size, dim):
+ super().__init__()
+ self.split_size = split_size
+ self.dim = dim
+
+ def forward(self, x):
+ return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)
+
+ torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+
@tvm.testing.parametrize_targets("cuda")
def test_batch_norm0(target, dev):
# Eval, no momentum, no affine, no running stats
@@ -373,26 +396,74 @@ def test_batch_norm3(target, dev):
@tvm.testing.parametrize_targets("cuda")
-def test_split_sections_list(target, dev):
- # Test split using a list of section sizes
- batch = 3
+def test_chunk_even(target, dev):
+ # Chunks is a divisor of the dimension size
+ batch = 6
channels = 2
- height = 10
+ height = 3
+ width = 4
+ chunks = 3
+ dim = 0
+ raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+ class ChunkModel(nn.Module):
+ def __init__(self, chunks, dim):
+ super().__init__()
+ self.chunks = chunks
+ self.dim = dim
+
+ def forward(self, x):
+ return x.chunk(self.chunks, dim=self.dim)
+
+ torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_chunk_uneven(target, dev):
+ # Chunks is not a divisor of the dimension size
+ batch = 2
+ channels = 5
+ height = 4
width = 5
- sections = [3, 2, 5]
- dim = 2 # split across height
+ chunks = 2
+ dim = 1
raw_data = np.random.rand(batch, channels, height, width).astype("float32")
- class SplitModelSectionsList(nn.Module):
- def __init__(self, split_size, dim):
+ class ChunkModel(nn.Module):
+ def __init__(self, chunks, dim):
super().__init__()
- self.split_size = split_size
+ self.chunks = chunks
self.dim = dim
def forward(self, x):
- return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)
+ return x.chunk(self.chunks, dim=self.dim)
- torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
+ torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_chunk_too_many(target, dev):
+ # If user asks for more chunks than the size of the dim, pytorch simply splits in sections of size 1
+ batch = 1
+ channels = 3
+ height = 2
+ width = 2
+ chunks = 99
+ dim = 1
+ raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+ class ChunkModel(nn.Module):
+ def __init__(self, chunks, dim):
+ super().__init__()
+ self.chunks = chunks
+ self.dim = dim
+
+ def forward(self, x):
+ return x.chunk(self.chunks, dim=self.dim)
+
+ torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)