This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6de551b65b [Unity][Frontend] FX translator support torch.baddbmm (#14202)
6de551b65b is described below
commit 6de551b65b220a3ed3c2c699b1f957bd1527f4b6
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 5 20:39:15 2023 -0500
[Unity][Frontend] FX translator support torch.baddbmm (#14202)
This PR adds support for translating `torch.baddbmm` into a
combination of operators (matmul, add, multiply). Unit tests
are provided accordingly.
This PR also fixes the kwarg fetching issue of `torch.interpolate`.
---
python/tvm/relax/frontend/torch/fx_translator.py | 63 +++++++++++++++++++--
tests/python/relax/test_frontend_from_fx.py | 71 +++++++++++++++++++++++-
2 files changed, 127 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index fa68b2eee3..c89b15a7d5 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -360,6 +360,28 @@ class TorchFXImporter:
matmul = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32"))
return self.block_builder.emit(relax.op.add(x, matmul))
+ def _baddbmm(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ a = self.env[node.args[1]]
+ b = self.env[node.args[2]]
+ alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1
+ beta = node.kwargs["beta"] if "beta" in node.kwargs else 1
+
+ res = None
+ if alpha != 0:
+ res = self.block_builder.emit(relax.op.matmul(a, b))
+ if alpha != 1:
+ dtype = res.struct_info.dtype
+ res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype)))
+ if beta != 0:
+ dtype = x.struct_info.dtype
+ if beta != 1:
+ bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype)))
+ else:
+ bias = x
+ res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias))
+ return res
+
########## Manipulation ##########
def _cat(self, node: fx.node.Node) -> relax.Var:
@@ -661,12 +683,40 @@ class TorchFXImporter:
# (TODO) this is a temporary implementation for interpolate that only considers NCHW layout
# it basically replicates the implementation in tvm.relay.frontend.pytorch
data = self.env[node.args[0]]
- size = node.kwargs["size"]
- scale_factor = node.kwargs["scale_factor"]
- method = node.kwargs["mode"]
- align_corners = node.kwargs["align_corners"]
- recompute_scale_factor = node.kwargs["recompute_scale_factor"]
- antialias = node.kwargs["antialias"]
+ size = (
+ node.args[1]
+ if len(node.args) > 1
+ else (node.kwargs["size"] if "size" in node.kwargs else None)
+ )
+ scale_factor = (
+ node.args[2]
+ if len(node.args) > 2
+ else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs else None)
+ )
+ method = (
+ node.args[3]
+ if len(node.args) > 3
+ else (node.kwargs["method"] if "method" in node.kwargs else "nearest")
+ )
+ align_corners = (
+ node.args[4]
+ if len(node.args) > 4
+ else (node.kwargs["align_corners"] if "align_corners" in node.kwargs else None)
+ )
+ recompute_scale_factor = (
+ node.args[5]
+ if len(node.args) > 5
+ else (
+ node.kwargs["recompute_scale_factor"]
+ if "recompute_scale_factor" in node.kwargs
+ else None
+ )
+ )
+ antialias = (
+ node.args[6]
+ if len(node.args) > 6
+ else (node.kwargs["antialias"] if "antialias" in node.kwargs else False)
+ )
assert recompute_scale_factor is None
assert antialias is False
@@ -816,6 +866,7 @@ class TorchFXImporter:
"astype": self._type,
"matmul": self._matmul,
"addmm": self._addmm,
+ "baddbmm": self._baddbmm,
"bmm": self._matmul,
"cat": self._cat,
"expand": self._expand,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 4fd7cee812..8dfbc97d8b 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -261,6 +261,75 @@ def test_bmm():
)
+@tvm.testing.requires_gpu
+def test_baddbmm():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ class BAddBMM1(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, c, x, y):
+ return torch.baddbmm(c, x, y)
+
+ class BAddBMM2(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, c, x, y):
+ return torch.baddbmm(c, x, y, alpha=2, beta=0)
+
+ @tvm.script.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((4, 128, 512), dtype="float32"),
+ inp_1: R.Tensor((4, 128, 256), dtype="float32"),
+ inp_2: R.Tensor((4, 256, 512), dtype="float32"),
+ ) -> R.Tensor((4, 128, 512), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+ lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0)
+ gv: R.Tensor((4, 128, 512), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((4, 128, 512), dtype="float32"),
+ inp_1: R.Tensor((4, 128, 256), dtype="float32"),
+ inp_2: R.Tensor((4, 256, 512), dtype="float32"),
+ ) -> R.Tensor((4, 128, 512), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+ lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
+ lv, R.const(2, "float32")
+ )
+ gv: R.Tensor((4, 128, 512), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ verify_model(
+ BAddBMM1(),
+ [((4, 128, 512), "float32"), ((4, 128, 256), "float32"), ((4, 256, 512), "float32")],
+ {},
+ Expected1,
+ )
+
+ verify_model(
+ BAddBMM2(),
+ [((4, 128, 512), "float32"), ((4, 128, 256), "float32"), ((4, 256, 512), "float32")],
+ {},
+ Expected2,
+ )
+
+
@tvm.testing.requires_gpu
def test_relu():
import torch
@@ -1488,7 +1557,7 @@ def test_interpolate():
class Interpolate(Module):
def forward(self, input):
- return torch.nn.functional.interpolate(input, size=(5, 5))
+ return torch.nn.functional.interpolate(input, (5, 5))
@tvm.script.ir_module
class expected1: