This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new fcb0127925 [Unity][Fix] FX translating dtype (#14201)
fcb0127925 is described below

commit fcb01279252b7af3bc385a9f2ede15771ec6321a
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 5 07:51:22 2023 -0500

    [Unity][Fix] FX translating dtype (#14201)
    
    This PR fixes a bug in how the current FX translator handles
    dtype.
    
    Previously, the translator did not take the following case
    ```python
    dtype = x.getattr("dtype")
    ```
    into consideration. In this case, the dtype is an fx.Node object,
    while the translator assumed that the dtype was either a string or
    a native torch datatype (e.g., torch.float32).
    
    This PR fixes this by first looking up every dtype in the environment
    table before converting it.
---
 python/tvm/relax/frontend/torch/fx_translator.py | 15 +++++++++------
 tests/python/relax/test_frontend_from_fx.py      | 14 ++++++++++----
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index a73bc9d0db..fa68b2eee3 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -18,7 +18,7 @@
 # pylint: disable=invalid-name, inconsistent-return-statements, 
unidiomatic-typecheck
 # pylint: disable=import-outside-toplevel
 """PyTorch FX frontend of Relax."""
-from typing import Callable, Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 from functools import reduce
 
 import tvm
@@ -61,10 +61,13 @@ class TorchFXImporter:
         return attr_itr
 
     @staticmethod
-    def _convert_data_type(input_type):
+    def _convert_data_type(input_type, env: Optional[Dict] = None):
         """converts the PyTorch scalar type input_type to a TVM dtype."""
         import torch  # type: ignore
 
+        if env is not None and input_type in env:
+            input_type = env[input_type]
+
         input_type = input_type.lower() if isinstance(input_type, str) else 
input_type
         if input_type in ["float", "float32", "torch.float32", torch.float32]:
             return "float32"
@@ -247,7 +250,7 @@ class TorchFXImporter:
             start_end_step[2] = 1
 
         if "dtype" in node.kwargs:
-            dtype = 
TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+            dtype = 
TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
         elif any([isinstance(x, float) for x in start_end_step]):
             dtype = 
TorchFXImporter._convert_data_type(torch.get_default_dtype())
         else:
@@ -256,7 +259,7 @@ class TorchFXImporter:
         return relax.const(np.arange(*start_end_step, dtype=dtype))
 
     def _empty(self, node: fx.node.Node) -> relax.Var:
-        dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+        dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), 
self.env)
         return self.block_builder.emit(relax.op.zeros(node.args, dtype))
 
     def _inplace_fill(self, node: fx.node.Node) -> relax.Var:
@@ -334,7 +337,7 @@ class TorchFXImporter:
 
     def _type(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        dtype = self._convert_data_type(node.args[1])
+        dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
         return self.block_builder.emit(relax.op.astype(x, dtype))
 
     ########## Linear Algebra ##########
@@ -565,7 +568,7 @@ class TorchFXImporter:
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
         bias = self.params[module.bias]
-        dtype = self._convert_data_type(str(module.running_mean.dtype))
+        dtype = 
TorchFXImporter._convert_data_type(str(module.running_mean.dtype))
         running_mean = relax.const(module.running_mean.cpu().detach().numpy(), 
dtype)
         running_var = relax.const(module.running_var.cpu().detach().numpy(), 
dtype)
         eps = module.eps
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index e28483dc2f..4fd7cee812 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1683,7 +1683,7 @@ def test_arange():
             return torch.arange(0, 20, dtype=torch.int32)
 
     graph_model = fx.symbolic_trace(Arange())
-    mod = from_fx(graph_model, [([10, 10], "float32")])
+    mod = from_fx(graph_model, [([10, 10], "float32")]).mod
     assert len(mod["main"].body.blocks) == 1
     assert len(mod["main"].body.blocks[0].bindings) == 1
     assert isinstance(mod["main"].body.blocks[0].bindings[0].value, 
relax.Constant)
@@ -1707,7 +1707,7 @@ def test_empty():
             return torch.empty((10, 10), dtype=torch.float32)
 
     graph_model = fx.symbolic_trace(Empty())
-    mod = from_fx(graph_model, [([10, 10], "float32")])
+    mod = from_fx(graph_model, [([10, 10], "float32")]).mod
     assert len(mod["main"].body.blocks) == 1
     assert len(mod["main"].body.blocks[0].bindings) == 1
     assert isinstance(mod["main"].body.blocks[0].bindings[0].value, 
relax.Constant)
@@ -1734,7 +1734,7 @@ def test_tensor():
             return torch.tensor(3)
 
     graph_model1 = fx.symbolic_trace(Empty1())
-    mod1 = from_fx(graph_model1, [([10, 10], "float32")])
+    mod1 = from_fx(graph_model1, [([10, 10], "float32")]).mod
     assert len(mod1["main"].body.blocks) == 1
     assert len(mod1["main"].body.blocks[0].bindings) == 1
     assert isinstance(mod1["main"].body.blocks[0].bindings[0].value, 
relax.Constant)
@@ -1742,7 +1742,7 @@ def test_tensor():
     assert mod1["main"].body.blocks[0].bindings[0].value.data.dtype == 
"float32"
 
     graph_model2 = fx.symbolic_trace(Empty2())
-    mod2 = from_fx(graph_model2, [([10, 10], "float32")])
+    mod2 = from_fx(graph_model2, [([10, 10], "float32")]).mod
     assert len(mod2["main"].body.blocks) == 1
     assert len(mod2["main"].body.blocks[0].bindings) == 1
     assert isinstance(mod2["main"].body.blocks[0].bindings[0].value, 
relax.Constant)
@@ -1968,12 +1968,18 @@ def test_datatype():
         def forward(self, x):
             return x.type(torch.float32)
 
+    # type
+    class TypeFromAttr(Module):
+        def forward(self, x):
+            return x.type(x.getattr("dtype"))
+
     # astype
     class AsType(Module):
         def forward(self, x):
             return x.astype(torch.float32)
 
     verify_model(Type(), input_info, {}, expected1)
+    verify_model(TypeFromAttr(), input_info, {}, expected1)
     verify_model(AsType(), input_info, {}, expected1)
 
 

Reply via email to