This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6de551b65b [Unity][Frontend] FX translator support torch.baddbmm 
(#14202)
6de551b65b is described below

commit 6de551b65b220a3ed3c2c699b1f957bd1527f4b6
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 5 20:39:15 2023 -0500

    [Unity][Frontend] FX translator support torch.baddbmm (#14202)
    
    This PR brings the support of translating `torch.baddbmm` into a
    combination of operators (matmul, add, multiply). Unit tests
    are provided accordingly.
    
    This PR also fixes the kwarg fetching issue of `torch.interpolate`.
---
 python/tvm/relax/frontend/torch/fx_translator.py | 63 +++++++++++++++++++--
 tests/python/relax/test_frontend_from_fx.py      | 71 +++++++++++++++++++++++-
 2 files changed, 127 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index fa68b2eee3..c89b15a7d5 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -360,6 +360,28 @@ class TorchFXImporter:
         matmul = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, 
out_dtype="float32"))
         return self.block_builder.emit(relax.op.add(x, matmul))
 
+    def _baddbmm(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        a = self.env[node.args[1]]
+        b = self.env[node.args[2]]
+        alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1
+        beta = node.kwargs["beta"] if "beta" in node.kwargs else 1
+
+        res = None
+        if alpha != 0:
+            res = self.block_builder.emit(relax.op.matmul(a, b))
+            if alpha != 1:
+                dtype = res.struct_info.dtype
+                res = self.block_builder.emit(relax.op.multiply(res, 
relax.const(alpha, dtype)))
+        if beta != 0:
+            dtype = x.struct_info.dtype
+            if beta != 1:
+                bias = self.block_builder.emit(relax.op.multiply(x, 
relax.const(beta, dtype)))
+            else:
+                bias = x
+            res = bias if res is None else 
self.block_builder.emit(relax.op.add(res, bias))
+        return res
+
     ########## Manipulation ##########
 
     def _cat(self, node: fx.node.Node) -> relax.Var:
@@ -661,12 +683,40 @@ class TorchFXImporter:
         # (TODO) this is a temporary implementation for interpolate that only 
considers NCHW layout
         # it basically replicates the implementation in 
tvm.relay.frontend.pytorch
         data = self.env[node.args[0]]
-        size = node.kwargs["size"]
-        scale_factor = node.kwargs["scale_factor"]
-        method = node.kwargs["mode"]
-        align_corners = node.kwargs["align_corners"]
-        recompute_scale_factor = node.kwargs["recompute_scale_factor"]
-        antialias = node.kwargs["antialias"]
+        size = (
+            node.args[1]
+            if len(node.args) > 1
+            else (node.kwargs["size"] if "size" in node.kwargs else None)
+        )
+        scale_factor = (
+            node.args[2]
+            if len(node.args) > 2
+            else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs 
else None)
+        )
+        method = (
+            node.args[3]
+            if len(node.args) > 3
+            else (node.kwargs["method"] if "method" in node.kwargs else 
"nearest")
+        )
+        align_corners = (
+            node.args[4]
+            if len(node.args) > 4
+            else (node.kwargs["align_corners"] if "align_corners" in 
node.kwargs else None)
+        )
+        recompute_scale_factor = (
+            node.args[5]
+            if len(node.args) > 5
+            else (
+                node.kwargs["recompute_scale_factor"]
+                if "recompute_scale_factor" in node.kwargs
+                else None
+            )
+        )
+        antialias = (
+            node.args[6]
+            if len(node.args) > 6
+            else (node.kwargs["antialias"] if "antialias" in node.kwargs else 
False)
+        )
 
         assert recompute_scale_factor is None
         assert antialias is False
@@ -816,6 +866,7 @@ class TorchFXImporter:
             "astype": self._type,
             "matmul": self._matmul,
             "addmm": self._addmm,
+            "baddbmm": self._baddbmm,
             "bmm": self._matmul,
             "cat": self._cat,
             "expand": self._expand,
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 4fd7cee812..8dfbc97d8b 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -261,6 +261,75 @@ def test_bmm():
     )
 
 
[email protected]_gpu
+def test_baddbmm():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class BAddBMM1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, c, x, y):
+            return torch.baddbmm(c, x, y)
+
+    class BAddBMM2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, c, x, y):
+            return torch.baddbmm(c, x, y, alpha=2, beta=0)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((4, 128, 512), dtype="float32"),
+            inp_1: R.Tensor((4, 128, 256), dtype="float32"),
+            inp_2: R.Tensor((4, 256, 512), dtype="float32"),
+        ) -> R.Tensor((4, 128, 512), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, 
inp_2)
+                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, 
inp_0)
+                gv: R.Tensor((4, 128, 512), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((4, 128, 512), dtype="float32"),
+            inp_1: R.Tensor((4, 128, 256), dtype="float32"),
+            inp_2: R.Tensor((4, 256, 512), dtype="float32"),
+        ) -> R.Tensor((4, 128, 512), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, 
inp_2)
+                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
+                    lv, R.const(2, "float32")
+                )
+                gv: R.Tensor((4, 128, 512), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(
+        BAddBMM1(),
+        [((4, 128, 512), "float32"), ((4, 128, 256), "float32"), ((4, 256, 
512), "float32")],
+        {},
+        Expected1,
+    )
+
+    verify_model(
+        BAddBMM2(),
+        [((4, 128, 512), "float32"), ((4, 128, 256), "float32"), ((4, 256, 
512), "float32")],
+        {},
+        Expected2,
+    )
+
+
 @tvm.testing.requires_gpu
 def test_relu():
     import torch
@@ -1488,7 +1557,7 @@ def test_interpolate():
 
     class Interpolate(Module):
         def forward(self, input):
-            return torch.nn.functional.interpolate(input, size=(5, 5))
+            return torch.nn.functional.interpolate(input, (5, 5))
 
     @tvm.script.ir_module
     class expected1:

Reply via email to