This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ac3078c7f5 [Unity][FX] Add fx support for functional linear and conv2d
(#15566)
ac3078c7f5 is described below
commit ac3078c7f553398424d4677175de9e7147f12a2a
Author: Josh Fromm <[email protected]>
AuthorDate: Thu Aug 17 12:30:35 2023 -0700
[Unity][FX] Add fx support for functional linear and conv2d (#15566)
* Add fx support for functional linear and conv2d
* Use param in getattr where possible
* Sneak in one fix to Slice with primvalues
* Fix lint
* Format
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 5 +-
python/tvm/relax/frontend/torch/fx_translator.py | 83 +++++++++++++++++++-----
tests/python/relax/test_frontend_from_fx.py | 26 ++++++++
3 files changed, 97 insertions(+), 17 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 152db73c51..af7decbc13 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -987,7 +987,10 @@ class Slice(OnnxOpConverter):
steps = get_constant(inputs[4], params)
if not all(
[
- (isinstance(param, (relax.Constant, relax.ShapeExpr)) or param
is None)
+ (
+ isinstance(param, (relax.Constant, relax.ShapeExpr,
relax.PrimValue))
+ or param is None
+ )
for param in [starts, ends, axes, steps]
]
):
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 85764565e2..be95a4880b 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -42,8 +42,7 @@ class TorchFXImporter:
self.create_convert_map()
########## Utilities ##########
- @staticmethod
- def _fetch_attr(model, target: str):
+ def _fetch_attr(self, model, target: str):
import torch # type: ignore
target_atoms = target.split(".")
@@ -55,6 +54,10 @@ class TorchFXImporter:
)
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.Tensor):
+ # It's possible for the resulting tensor to be a parameter.
+ # If so, return the parameter instead.
+ if attr_itr in self.params:
+ return self.params[attr_itr]
return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr)
return attr_itr
@@ -662,6 +665,13 @@ class TorchFXImporter:
bias = None if module.bias is None else self.params[module.bias]
return self.block_builder.emit(relax.op.linear(x, weight, bias,
"float32"))
+ def _linear_functional(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+ return self.block_builder.emit(relax.op.linear(x, weight, bias,
"float32"))
+
def _conv1d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -690,32 +700,34 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.add(conv1d, bias))
- def _conv2d(self, node: fx.node.Node) -> relax.Var:
- x = self.env[node.args[0]]
- module = self.named_modules[node.target]
- weight = self.params[module.weight]
-
+ def _conv2d_impl(
+ self,
+ x: relax.Expr,
+ weight: relax.Expr,
+ bias: Optional[relax.Expr],
+ strides: Optional[Tuple],
+ padding: Optional[Tuple],
+ dilation: Optional[Tuple],
+ groups: Optional[Tuple],
+ ):
conv2d = self.block_builder.emit(
relax.op.nn.conv2d(
x,
weight,
- strides=module.stride,
- padding=module.padding,
- dilation=module.dilation,
- groups=module.groups,
+ strides=strides,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="float32",
)
)
- if module.bias is None:
+ if bias is None:
return conv2d
-
- bias = self.params[module.bias]
assert len(self.shape_of(bias)) == 1
bias = relax.op.reshape(bias, (1, -1, 1, 1))
-
return self.block_builder.emit(relax.op.add(conv2d, bias))
def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
@@ -774,6 +786,43 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
+ def _conv2d(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ weight = self.params[module.weight]
+ bias = None
+ if module.bias is not None:
+ bias = self.params[module.bias]
+
+ return self._conv2d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=module.stride,
+ padding=module.padding,
+ dilation=module.dilation,
+ groups=module.groups,
+ )
+
+ def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+ stride = args[3] if len(args) > 3 else 1
+ padding = args[4] if len(args) > 4 else 0
+ dilation = args[5] if len(args) > 5 else 1
+ groups = args[6] if len(args) > 6 else 1
+ return self._conv2d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+
def _max_pool2d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
if node.target in self.named_modules:
@@ -1272,6 +1321,8 @@ class TorchFXImporter:
"type": self._type,
"astype": self._type,
"matmul": self._matmul,
+ "conv2d": self._conv2d_functional,
+ "linear": self._linear_functional,
"addmm": self._addmm,
"baddbmm": self._baddbmm,
"bmm": self._matmul,
@@ -1395,7 +1446,7 @@ class TorchFXImporter:
output = self.block_builder.emit_output(args[0])
break
elif node.op == "get_attr":
- self.env[node] = TorchFXImporter._fetch_attr(model,
node.target)
+ self.env[node] = self._fetch_attr(model, node.target)
elif node.op == "call_module":
module = self.named_modules[node.target]
assert (
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index fe4bcbcf5a..2b95d3897d 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -206,6 +206,15 @@ def test_conv2d():
def forward(self, input):
return self.conv(input)
+ class Conv2D1Func(Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.randn(size=[6, 3, 7, 7])
+ self.bias = torch.randn(size=[6])
+
+ def forward(self, input):
+ return torch.nn.functional.conv2d(input, self.weight, self.bias)
+
@tvm.script.ir_module
class expected1:
@R.function
@@ -271,6 +280,10 @@ def test_conv2d():
binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
verify_model(model, input_info, binding, expected1)
+ model = Conv2D1Func()
+ binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()}
+ verify_model(model, input_info, binding, expected1)
+
model = Conv2D2()
binding = {"w1": model.conv.weight.detach().numpy()}
verify_model(model, input_info, binding, expected2)
@@ -365,6 +378,15 @@ def test_linear():
def forward(self, input):
return self.linear(input)
+ class Dense1Func(Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.randn(size=[7, 10])
+ self.bias = torch.randn(size=[7])
+
+ def forward(self, input):
+ return torch.nn.functional.linear(input, self.weight, self.bias)
+
@tvm.script.ir_module
class expected1:
@R.function
@@ -415,6 +437,10 @@ def test_linear():
binding = {"w1": model.linear.weight.detach().numpy(), "w2":
model.linear.bias.detach().numpy()}
verify_model(model, input_info, binding, expected1)
+ model = Dense1Func()
+ binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()}
+ verify_model(model, input_info, binding, expected1)
+
model = Dense2()
binding = {"w1": model.linear.weight.detach().numpy()}
verify_model(model, input_info, binding, expected2)