This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2a9709c90b [Unity][Frontend] FX exp and strided_slice fix (#14338)
2a9709c90b is described below

commit 2a9709c90beaf816607402b91b3e016b553375b3
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Mar 20 10:05:44 2023 -0400

    [Unity][Frontend] FX exp and strided_slice fix (#14338)
    
    * Add support for `exp` to the FX translator.
    * Previously, the FX translator handled `None` in torch tensor slicing
    (e.g., `x[:, None, None]`) incorrectly; this PR fixes that issue.
    In such an index, `None` inserts a new dimension of size 1 rather than
    consuming an input dimension, but the previous implementation mistakenly
    advanced the input-dimension counter whenever it saw `None`. This
    eventually pushed the counter out of range.
---
 python/tvm/relax/frontend/torch/fx_translator.py |  7 +++-
 tests/python/relax/test_frontend_from_fx.py      | 48 ++++++++++++++++++++++--
 2 files changed, 49 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 0bd987cf2d..a2e2afe668 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -136,6 +136,9 @@ class TorchFXImporter:
     def _cos(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.cos(self.env[node.args[0]]))
 
+    def _exp(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.exp(self.env[node.args[0]]))
+
     def _sin(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))
 
@@ -858,8 +861,7 @@ class TorchFXImporter:
                     axes.append(i)
                     i = i + 1
                 elif index is None:
-                    expand_dim.append(i)
-                    i = i + 1
+                    expand_dim.append(len(axes) + len(expand_dim))
                 else:
                     raise ValueError("Unsupported index type: " + 
str(type(index)))
             while i < len(shape):
@@ -903,6 +905,7 @@ class TorchFXImporter:
             nn.modules.sparse.Embedding: self._embedding,
             # call_function and call_method
             "cos": self._cos,
+            "exp": self._exp,
             "sin": self._sin,
             "add": self._add,
             "floordiv": self._floordiv,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 31b43070cb..2e69795d51 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -19,7 +19,7 @@ import pytest
 import tvm
 from tvm import relax
 import tvm.testing
-from tvm.script.parser import relax as R, tir as T
+from tvm.script.parser import ir as I, relax as R, tir as T
 
 
 def verify_model(torch_model, input_info, binding, expected):
@@ -1372,8 +1372,6 @@ def test_getitem():
     torch.set_grad_enabled(False)
     torch.random.manual_seed(0)
 
-    input_info = [([1, 3, 10, 10], "float32")]
-
     class Slice1(Module):
         def forward(self, x):
             return x[0, 1::2, :, :3]
@@ -1398,7 +1396,29 @@ def test_getitem():
                 R.output(gv)
             return gv
 
-    verify_model(Slice1(), input_info, {}, expected1)
+    class Slice2(Module):
+        def forward(self, x):
+            return x[:, None, None, :, None]
+
+    @I.ir_module
+    class expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((8, 16), dtype="float32")
+        ) -> R.Tensor((8, 1, 1, 16, 1), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice(
+                    inp_0, axes=[0, 1], begin=[0, 0], end=[8, 16], strides=[1, 
1]
+                )
+                lv1: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.reshape(
+                    lv, R.shape([8, 1, 1, 16, 1])
+                )
+                gv: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(Slice1(), [([1, 3, 10, 10], "float32")], {}, expected1)
+    verify_model(Slice2(), [([8, 16], "float32")], {}, expected2)
 
 
 @tvm.testing.requires_gpu
@@ -1451,6 +1471,26 @@ def test_unary():
 
     verify_model(Cos(), input_info, {}, expected2)
 
+    # exp
+    class Exp(Module):
+        def forward(self, input):
+            return torch.exp(input)
+
+    @tvm.script.ir_module
+    class expected_exp:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Exp(), input_info, {}, expected_exp)
+
     # sqrt
     class Sqrt(Module):
         def forward(self, input):

Reply via email to