This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 001d5ec90c [Relax][PyTorch][Docs] Use `torch.export` instead of 
`fx.symbolic_trace` for tutorial (#17436)
001d5ec90c is described below

commit 001d5ec90c2821b16f9d4edd913dfeff03c027a3
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Tue Oct 8 09:57:27 2024 +0900

    [Relax][PyTorch][Docs] Use `torch.export` instead of `fx.symbolic_trace` 
for tutorial (#17436)
    
    * use torch.export
    
    * in order to make interface consistent, user inputs should be placed first
    
    * chore
---
 docs/get_started/tutorials/ir_module.py            | 15 +++--
 docs/how_to/tutorials/e2e_opt_model.py             | 18 +++---
 .../frontend/torch/exported_program_translator.py  | 71 +++++++++++-----------
 .../relax/test_frontend_from_exported_program.py   |  4 +-
 4 files changed, 56 insertions(+), 52 deletions(-)

diff --git a/docs/get_started/tutorials/ir_module.py 
b/docs/get_started/tutorials/ir_module.py
index f813333baf..0a825c3da7 100644
--- a/docs/get_started/tutorials/ir_module.py
+++ b/docs/get_started/tutorials/ir_module.py
@@ -40,8 +40,9 @@ from tvm import relax
 # below.
 
 import torch
-from torch import fx, nn
-from tvm.relax.frontend.torch import from_fx
+from torch import nn
+from torch.export import export
+from tvm.relax.frontend.torch import from_exported_program
 
 ######################################################################
 # Import from existing models
@@ -67,13 +68,15 @@ class TorchModel(nn.Module):
         return x
 
 
-# Give the input shape and data type
-input_info = [((1, 784), "float32")]
+# Give an example argument to torch.export
+example_args = (torch.randn(1, 784, dtype=torch.float32),)
 
 # Convert the model to IRModule
 with torch.no_grad():
-    torch_fx_model = fx.symbolic_trace(TorchModel())
-    mod_from_torch = from_fx(torch_fx_model, input_info, 
keep_params_as_input=True)
+    exported_program = export(TorchModel().eval(), example_args)
+    mod_from_torch = from_exported_program(
+        exported_program, keep_params_as_input=True, 
unwrap_unit_return_tuple=True
+    )
 
 mod_from_torch, params_from_torch = 
relax.frontend.detach_params(mod_from_torch)
 # Print the IRModule
diff --git a/docs/how_to/tutorials/e2e_opt_model.py 
b/docs/how_to/tutorials/e2e_opt_model.py
index 5c11439e16..532fb89fd3 100644
--- a/docs/how_to/tutorials/e2e_opt_model.py
+++ b/docs/how_to/tutorials/e2e_opt_model.py
@@ -34,10 +34,10 @@ Please note that default end-to-end optimization may not 
suit complex models.
 import os
 import numpy as np
 import torch
-from torch import fx
+from torch.export import export
 from torchvision.models.resnet import ResNet18_Weights, resnet18
 
-torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
+torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
 
 ######################################################################
 # Review Overall Flow
@@ -63,21 +63,19 @@ torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
 # Convert the model to IRModule
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # Next step, we convert the model to an IRModule using the Relax frontend for 
PyTorch for further
-# optimization. Besides the model, we also need to provide the input shape and 
data type.
+# optimization.
 
 import tvm
 from tvm import relax
-from tvm.relax.frontend.torch import from_fx
+from tvm.relax.frontend.torch import from_exported_program
 
-torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
-
-# Give the input shape and data type
-input_info = [((1, 3, 224, 224), "float32")]
+# Give an example argument to torch.export
+example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)
 
 # Convert the model to IRModule
 with torch.no_grad():
-    torch_fx_model = fx.symbolic_trace(torch_model)
-    mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
+    exported_program = export(torch_model, example_args)
+    mod = from_exported_program(exported_program, keep_params_as_input=True)
 
 mod, params = relax.frontend.detach_params(mod)
 mod.show()
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1401a0bcef..7bcd20c462 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -34,37 +34,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
     from torch import fx
 
-    def create_input_vars(
-        self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[List[relax.Var], List[relax.Var]]:
-        """Create relax input vars."""
-        parameters_buffers_constants = []
-        user_inputs = []
-        for spec in exported_program.graph_signature.input_specs:
-            name_hint = spec.arg.name
-            if spec.kind is 
torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
-                shape = exported_program.tensor_constants[spec.target].shape
-                torch_dtype = 
exported_program.tensor_constants[spec.target].dtype
-            elif spec.kind is 
torch.export.graph_signature.InputKind.USER_INPUT:
-                for node in 
exported_program.graph.find_nodes(op="placeholder", target=spec.target):
-                    if node.name == name_hint:
-                        shape = node.meta["tensor_meta"].shape
-                        torch_dtype = node.meta["tensor_meta"].dtype
-                        break
-            else:
-                # PARAMETER or BUFFER
-                shape = exported_program.state_dict[spec.target].shape
-                torch_dtype = exported_program.state_dict[spec.target].dtype
-
-            dtype = self._convert_data_type(torch_dtype)
-            relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, 
dtype))
-            if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
-                user_inputs.append(relax_var)
-            else:
-                parameters_buffers_constants.append(relax_var)
-
-        return parameters_buffers_constants, user_inputs
-
     ########## Unary Ops ##########
 
     def _hardtanh(self, node: fx.Node) -> relax.Expr:
@@ -178,6 +147,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         stride = [node.args[4] if len(node.args) > 4 else 1]
         return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, 
end, stride))
 
+    ########## Others ##########
+
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
@@ -293,6 +264,37 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "getitem": self._getitem,
         }
 
+    def create_input_vars(
+        self, exported_program: torch.export.ExportedProgram
+    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
+        """Create relax input vars."""
+        parameters_buffers_constants = OrderedDict()
+        user_inputs = OrderedDict()
+        for spec in exported_program.graph_signature.input_specs:
+            name_hint = spec.arg.name
+            if spec.kind is 
torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
+                shape = exported_program.tensor_constants[spec.target].shape
+                torch_dtype = 
exported_program.tensor_constants[spec.target].dtype
+            elif spec.kind is 
torch.export.graph_signature.InputKind.USER_INPUT:
+                for node in 
exported_program.graph.find_nodes(op="placeholder", target=spec.target):
+                    if node.name == name_hint:
+                        shape = node.meta["tensor_meta"].shape
+                        torch_dtype = node.meta["tensor_meta"].dtype
+                        break
+            else:
+                # PARAMETER or BUFFER
+                shape = exported_program.state_dict[spec.target].shape
+                torch_dtype = exported_program.state_dict[spec.target].dtype
+
+            dtype = self._convert_data_type(torch_dtype)
+            relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, 
dtype))
+            if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
+                user_inputs[name_hint] = relax_var
+            else:
+                parameters_buffers_constants[name_hint] = relax_var
+
+        return parameters_buffers_constants, user_inputs
+
     def from_exported_program(
         self,
         exported_program: torch.export.ExportedProgram,
@@ -305,7 +307,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
         # Create input variables.
         parameter_buffer_constant_vars, user_input_vars = 
self.create_input_vars(exported_program)
-        inputs_vars = parameter_buffer_constant_vars + user_input_vars
+        inputs_vars = user_input_vars.copy()
+        inputs_vars.update(parameter_buffer_constant_vars)
 
         # Initialize the block builder with a function and a dataflow block.
         self.block_builder = relax.BlockBuilder()
@@ -314,7 +317,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
         nodes: List[fx.Node] = exported_program.graph.nodes
         with self.block_builder.function(
-            name=func_name, params=inputs_vars.copy(), attrs=func_attrs
+            name=func_name, params=list(inputs_vars.values()).copy(), 
attrs=func_attrs
         ):
             output = None
             with self.block_builder.dataflow():
@@ -325,7 +328,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
                             # Ignore sym input
                             continue
 
-                        self.env[node] = inputs_vars.pop(0)
+                        self.env[node] = inputs_vars[node.name]
                     elif node.op == "output":
                         args = self.retrieve_args(node)
                         assert len(args) == 1
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 65890ff697..0d8425fc7f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3550,9 +3550,9 @@ def test_keep_params():
     class expected1:
         @R.function
         def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
             conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"),
             conv_bias: R.Tensor((6,), dtype="float32"),
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
             R.func_attr({"num_input": 1})
             # block 0
@@ -3586,7 +3586,7 @@ def test_keep_params():
     params = params["main"]
 
     assert len(params) == len(func.params) - 1
-    for param_var, param_ndarray in zip(func.params[:-1], params):
+    for param_var, param_ndarray in zip(func.params[1:], params):
         assert tuple(x.value for x in param_var.struct_info.shape.values) == 
param_ndarray.shape
         assert param_var.struct_info.dtype == param_ndarray.dtype
 

Reply via email to