This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e1c430c7e3 [Relay][Frontend][Torch] fix pytorch frontend linspace op (#16417)
e1c430c7e3 is described below

commit e1c430c7e3180b65e234cf39b3f1de6e71825f55
Author: TaoMiao <taom...@pku.edu.cn>
AuthorDate: Fri Jan 19 02:40:36 2024 +0800

    [Relay][Frontend][Torch] fix pytorch frontend linspace op (#16417)
    
    fix pytorch frontend linspace op
---
 python/tvm/relay/frontend/pytorch.py          | 2 +-
 tests/python/frontend/pytorch/test_forward.py | 5 +++++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 35f74544b8..8594ee0e06 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -918,7 +918,7 @@ class PyTorchOpConverter:
         # Find the spacing between values as step
         if step != 1:
             step = (stop - start) / (step - 1)
-            stop = stop + step
+            stop = stop + (step / 2)
         else:
             stop = start + step
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index bf96c21399..6d07f081e9 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3632,6 +3632,10 @@ def test_forward_linspace():
         def forward(self, *args):
             return torch.linspace(1, 2, 1, dtype=torch.int16)
 
+    class Linspace9(Module):
+        def forward(self, *args):
+            return torch.linspace(0, 8, 10)
+
     verify_model(Linspace1().float().eval())
     verify_model(Linspace2().float().eval())
     verify_model(Linspace3().float().eval())
@@ -3640,6 +3644,7 @@ def test_forward_linspace():
     verify_model(Linspace6().float().eval())
     verify_model(Linspace7().float().eval())
     verify_model(Linspace8().float().eval())
+    verify_model(Linspace9().float().eval())
 
 
 @tvm.testing.uses_gpu

Reply via email to