This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ac3078c7f5 [Unity][FX] Add fx support for functional linear and conv2d 
(#15566)
ac3078c7f5 is described below

commit ac3078c7f553398424d4677175de9e7147f12a2a
Author: Josh Fromm <[email protected]>
AuthorDate: Thu Aug 17 12:30:35 2023 -0700

    [Unity][FX] Add fx support for functional linear and conv2d (#15566)
    
    * Add fx support for functional linear and conv2d
    
    * Use param in getattr where possible
    
    * Sneak in one fix to Slice with primvalues
    
    * Fix lint
    
    * Format
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py  |  5 +-
 python/tvm/relax/frontend/torch/fx_translator.py | 83 +++++++++++++++++++-----
 tests/python/relax/test_frontend_from_fx.py      | 26 ++++++++
 3 files changed, 97 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 152db73c51..af7decbc13 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -987,7 +987,10 @@ class Slice(OnnxOpConverter):
         steps = get_constant(inputs[4], params)
         if not all(
             [
-                (isinstance(param, (relax.Constant, relax.ShapeExpr)) or param 
is None)
+                (
+                    isinstance(param, (relax.Constant, relax.ShapeExpr, 
relax.PrimValue))
+                    or param is None
+                )
                 for param in [starts, ends, axes, steps]
             ]
         ):
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 85764565e2..be95a4880b 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -42,8 +42,7 @@ class TorchFXImporter:
         self.create_convert_map()
 
     ########## Utilities ##########
-    @staticmethod
-    def _fetch_attr(model, target: str):
+    def _fetch_attr(self, model, target: str):
         import torch  # type: ignore
 
         target_atoms = target.split(".")
@@ -55,6 +54,10 @@ class TorchFXImporter:
                 )
             attr_itr = getattr(attr_itr, atom)
         if isinstance(attr_itr, torch.Tensor):
+            # It's possible for the resulting tensor to be a parameter.
+            # If so, return the parameter instead.
+            if attr_itr in self.params:
+                return self.params[attr_itr]
             return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr)
         return attr_itr
 
@@ -662,6 +665,13 @@ class TorchFXImporter:
         bias = None if module.bias is None else self.params[module.bias]
         return self.block_builder.emit(relax.op.linear(x, weight, bias, 
"float32"))
 
+    def _linear_functional(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        return self.block_builder.emit(relax.op.linear(x, weight, bias, 
"float32"))
+
     def _conv1d(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -690,32 +700,34 @@ class TorchFXImporter:
 
         return self.block_builder.emit(relax.op.add(conv1d, bias))
 
-    def _conv2d(self, node: fx.node.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        module = self.named_modules[node.target]
-        weight = self.params[module.weight]
-
+    def _conv2d_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ):
         conv2d = self.block_builder.emit(
             relax.op.nn.conv2d(
                 x,
                 weight,
-                strides=module.stride,
-                padding=module.padding,
-                dilation=module.dilation,
-                groups=module.groups,
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
                 data_layout="NCHW",
                 kernel_layout="OIHW",
                 out_dtype="float32",
             )
         )
 
-        if module.bias is None:
+        if bias is None:
             return conv2d
-
-        bias = self.params[module.bias]
         assert len(self.shape_of(bias)) == 1
         bias = relax.op.reshape(bias, (1, -1, 1, 1))
-
         return self.block_builder.emit(relax.op.add(conv2d, bias))
 
     def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
@@ -774,6 +786,43 @@ class TorchFXImporter:
 
         return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
 
+    def _conv2d(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+        bias = None
+        if module.bias is not None:
+            bias = self.params[module.bias]
+
+        return self._conv2d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=module.stride,
+            padding=module.padding,
+            dilation=module.dilation,
+            groups=module.groups,
+        )
+
+    def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        groups = args[6] if len(args) > 6 else 1
+        return self._conv2d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
     def _max_pool2d(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         if node.target in self.named_modules:
@@ -1272,6 +1321,8 @@ class TorchFXImporter:
             "type": self._type,
             "astype": self._type,
             "matmul": self._matmul,
+            "conv2d": self._conv2d_functional,
+            "linear": self._linear_functional,
             "addmm": self._addmm,
             "baddbmm": self._baddbmm,
             "bmm": self._matmul,
@@ -1395,7 +1446,7 @@ class TorchFXImporter:
                             output = self.block_builder.emit_output(args[0])
                         break
                     elif node.op == "get_attr":
-                        self.env[node] = TorchFXImporter._fetch_attr(model, 
node.target)
+                        self.env[node] = self._fetch_attr(model, node.target)
                     elif node.op == "call_module":
                         module = self.named_modules[node.target]
                         assert (
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index fe4bcbcf5a..2b95d3897d 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -206,6 +206,15 @@ def test_conv2d():
         def forward(self, input):
             return self.conv(input)
 
+    class Conv2D1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[6, 3, 7, 7])
+            self.bias = torch.randn(size=[6])
+
+        def forward(self, input):
+            return torch.nn.functional.conv2d(input, self.weight, self.bias)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -271,6 +280,10 @@ def test_conv2d():
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
     verify_model(model, input_info, binding, expected1)
 
+    model = Conv2D1Func()
+    binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()}
+    verify_model(model, input_info, binding, expected1)
+
     model = Conv2D2()
     binding = {"w1": model.conv.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected2)
@@ -365,6 +378,15 @@ def test_linear():
         def forward(self, input):
             return self.linear(input)
 
+    class Dense1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[7, 10])
+            self.bias = torch.randn(size=[7])
+
+        def forward(self, input):
+            return torch.nn.functional.linear(input, self.weight, self.bias)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -415,6 +437,10 @@ def test_linear():
     binding = {"w1": model.linear.weight.detach().numpy(), "w2": 
model.linear.bias.detach().numpy()}
     verify_model(model, input_info, binding, expected1)
 
+    model = Dense1Func()
+    binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()}
+    verify_model(model, input_info, binding, expected1)
+
     model = Dense2()
     binding = {"w1": model.linear.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected2)

Reply via email to