This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 406427ce3c [Unity][Frontend] Annotate number of non-static input of FX function (#14067)
406427ce3c is described below

commit 406427ce3c1328024122de0f0a9385a51afe2c82
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Feb 21 10:18:09 2023 -0800

    [Unity][Frontend] Annotate number of non-static input of FX function (#14067)
---
 python/tvm/relax/frontend/torch/fx_translator.py | 30 +++++++++++---
 tests/python/relax/test_frontend_from_fx.py      | 51 +++++++++++++++++++++++-
 2 files changed, 73 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a762b0a0fb..4acad61855 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -36,7 +36,7 @@ class TorchFXImporter:
         from torch import fx
 
         self.env: Dict[fx.node.Node, relax.Expr] = {}
-        self.params: Dict[torch.Tensor, relax.Constant] = {}
+        self.params: Dict[torch.Tensor, relax.Expr] = {}
         self.named_modules: Dict[str, torch.Module] = None
         self.block_builder: relax.BlockBuilder = None
         self.create_convert_map()
@@ -675,7 +675,9 @@ class TorchFXImporter:
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
         }
 
-    def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
+    def from_fx(
+        self, model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool
+    ) -> tvm.IRModule:
         """Convert a PyTorch FX GraphModule to a Relax program."""
         from torch import fx
 
@@ -693,7 +695,17 @@ class TorchFXImporter:
 
         # Initialize the block builder with a function and a dataflow block.
         self.block_builder = relax.BlockBuilder()
-        with self.block_builder.function(name="main", params=inputs.copy()):
+        if keep_params_as_input:
+            func_attrs = {"num_input": len(inputs)}
+            for name, param in model.named_parameters():
+                shape = param.data.shape
+                dtype = self._convert_data_type(str(param.data.dtype))
+                inputs.append(relax.Var(name, relax.TensorStructInfo(shape, dtype)))
+                self.params[param] = inputs[-1]
+        else:
+            func_attrs = None
+
+        with self.block_builder.function(name="main", params=inputs.copy(), attrs=func_attrs):
             output = None
             with self.block_builder.dataflow():
                 # Translate model parameters.
@@ -701,7 +713,8 @@ class TorchFXImporter:
                     shape = param.data.shape
                     dtype = self._convert_data_type(str(param.data.dtype))
                     if dtype in ("float32", "float16"):
-                        self.params[param] = relax.const(param.data.cpu().numpy(), dtype)
+                        if not keep_params_as_input:
+                            self.params[param] = relax.const(param.data.cpu().numpy(), dtype)
                     else:
                         raise ValueError("Unsupported data type for model parameters: %s" % dtype)
                 # Translate the model.
@@ -740,7 +753,9 @@ class TorchFXImporter:
         return self.block_builder.get()
 
 
-def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
+def from_fx(
+    model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool = False
+) -> tvm.IRModule:
     """Convert a PyTorch FX GraphModule to a Relax program
 
     Parameters
@@ -751,6 +766,9 @@ def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
     input_info : List[Tuple[Tuple[int], str]]
         A list of shapes and data types of input tensors.
 
+    keep_params_as_input : bool
+        Whether to keep model parameters as input variables.
+
     Returns
     -------
     module : tvm.IRModule
@@ -814,4 +832,4 @@ def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
     to print out the tabular representation of the PyTorch module, and then
     check the placeholder rows in the beginning of the tabular.
     """
-    return TorchFXImporter().from_fx(model, input_info)
+    return TorchFXImporter().from_fx(model, input_info, keep_params_as_input)
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 9b35d34bd3..24ed9946a3 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -22,12 +22,12 @@ import tvm.testing
 from tvm.script.parser import relax as R, tir as T
 
 
-def verify_model(torch_model, input_info, binding, expected):
+def verify_model(torch_model, input_info, binding, expected, keep_params_as_input=False):
     from torch import fx
     from tvm.relax.frontend.torch import from_fx
 
     graph_model = fx.symbolic_trace(torch_model)
-    mod = from_fx(graph_model, input_info)
+    mod = from_fx(graph_model, input_info, keep_params_as_input=keep_params_as_input)
     binding = {k: tvm.nd.array(v) for k, v in binding.items()}
     expected = relax.transform.BindParams("main", binding)(expected)
     tvm.ir.assert_structural_equal(mod, expected)
@@ -786,6 +786,7 @@ def test_binary():
 
     input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
     input_info2 = [([1, 3, 10, 10], "float32")]
+
     # Add
     class Add1(Module):
         def forward(self, lhs, rhs):
@@ -1725,5 +1726,51 @@ def test_view():
     verify_model(View(), input_info, {}, expected1)
 
 
+@tvm.testing.requires_gpu
+def test_keep_params():
+    import torch
+    from torch.nn import Module
+
+    class Conv2D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
+            w2: R.Tensor((6,), dtype="float32"),
+        ) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
+                    input_1,
+                    w1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1])
+                lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
+                gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    model = Conv2D1()
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(model, input_info, {}, expected1, keep_params_as_input=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to