This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dcf5edb889 [Relax][Pytorch] support for arange in exported programs translator (#17802)
dcf5edb889 is described below

commit dcf5edb889098183fce0923afc1d08d1910b1906
Author: Hugo Latendresse <[email protected]>
AuthorDate: Wed Apr 2 18:08:28 2025 -0400

    [Relax][Pytorch] support for arange in exported programs translator (#17802)
    
    * cherry pick from arange branch
    
    * cleanup + unit test
---
 .../frontend/torch/exported_program_translator.py  |  4 ++-
 tests/python/relax/test_from_exported_to_cuda.py   | 33 ++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 62e98b88ed..b35cf4ce20 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -397,9 +397,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "reshape.default": self._reshape,
             # tensor creation
             "_to_copy.default": self._to_copy,
+            "arange.default": self._arange,
+            "arange.start": self._arange,
+            "arange.start_step": self._arange,
             "detach.default": self._detach,
             "detach_.default": self._detach,
-            "arange.start": self._arange,
             "contiguous.default": lambda node: self.env[node.args[0]],  # no-op
             "clone.default": lambda node: self.env[node.args[0]],
             "empty.memory_format": self._empty,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 19b8f80a23..56ee527caf 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -467,6 +467,39 @@ def test_chunk_too_many(target, dev):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_arange(target, dev):
+    # arange.default
+    raw_data = np.array([0, 0, 0, 0, 0])
+
+    class ArangeDefaultModel(nn.Module):
+        def forward(self, x):
+            return x + torch.arange(5)
+
+    torch_module = ArangeDefaultModel().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    # arange.start
+    raw_data = np.array([0, 0, 0])
+
+    class ArangeStartModel(nn.Module):
+        def forward(self, x):
+            return x + torch.arange(1, 4)
+
+    torch_module = ArangeStartModel().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    # arange.start_step
+    raw_data = np.array([0.0, 0.0, 0.0], dtype=np.float32)
+
+    class ArangeStartStopModel(nn.Module):
+        def forward(self, x):
+            return x + torch.arange(1, 2.5, 0.5, dtype=torch.float32)
+
+    torch_module = ArangeStartStopModel().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_index_select(target, dev):
     class IndexSelectModel(nn.Module):

Reply via email to