This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new fcb0127925 [Unity][Fix] FX translating dtype (#14201)
fcb0127925 is described below
commit fcb01279252b7af3bc385a9f2ede15771ec6321a
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 5 07:51:22 2023 -0500
[Unity][Fix] FX translating dtype (#14201)
This PR fixes a bug in the current FX translator's handling of
dtypes.
Previously, the translator did not take cases such as
```python
dtype = x.getattr("dtype")
```
into consideration. In this case, the dtype will be an fx.Node object,
while the translator assumes that the dtype is either a string or
a torch native datatype (e.g., torch.float32).
This PR fixes this by first doing an environment table lookup for all
dtypes before converting them.
---
python/tvm/relax/frontend/torch/fx_translator.py | 15 +++++++++------
tests/python/relax/test_frontend_from_fx.py | 14 ++++++++++----
2 files changed, 19 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index a73bc9d0db..fa68b2eee3 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -18,7 +18,7 @@
# pylint: disable=invalid-name, inconsistent-return-statements,
unidiomatic-typecheck
# pylint: disable=import-outside-toplevel
"""PyTorch FX frontend of Relax."""
-from typing import Callable, Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
from functools import reduce
import tvm
@@ -61,10 +61,13 @@ class TorchFXImporter:
return attr_itr
@staticmethod
- def _convert_data_type(input_type):
+ def _convert_data_type(input_type, env: Optional[Dict] = None):
"""converts the PyTorch scalar type input_type to a TVM dtype."""
import torch # type: ignore
+ if env is not None and input_type in env:
+ input_type = env[input_type]
+
input_type = input_type.lower() if isinstance(input_type, str) else
input_type
if input_type in ["float", "float32", "torch.float32", torch.float32]:
return "float32"
@@ -247,7 +250,7 @@ class TorchFXImporter:
start_end_step[2] = 1
if "dtype" in node.kwargs:
- dtype =
TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+ dtype =
TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
elif any([isinstance(x, float) for x in start_end_step]):
dtype =
TorchFXImporter._convert_data_type(torch.get_default_dtype())
else:
@@ -256,7 +259,7 @@ class TorchFXImporter:
return relax.const(np.arange(*start_end_step, dtype=dtype))
def _empty(self, node: fx.node.Node) -> relax.Var:
- dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]))
+ dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]),
self.env)
return self.block_builder.emit(relax.op.zeros(node.args, dtype))
def _inplace_fill(self, node: fx.node.Node) -> relax.Var:
@@ -334,7 +337,7 @@ class TorchFXImporter:
def _type(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
- dtype = self._convert_data_type(node.args[1])
+ dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
return self.block_builder.emit(relax.op.astype(x, dtype))
########## Linear Algebra ##########
@@ -565,7 +568,7 @@ class TorchFXImporter:
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = self.params[module.bias]
- dtype = self._convert_data_type(str(module.running_mean.dtype))
+ dtype =
TorchFXImporter._convert_data_type(str(module.running_mean.dtype))
running_mean = relax.const(module.running_mean.cpu().detach().numpy(),
dtype)
running_var = relax.const(module.running_var.cpu().detach().numpy(),
dtype)
eps = module.eps
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index e28483dc2f..4fd7cee812 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1683,7 +1683,7 @@ def test_arange():
return torch.arange(0, 20, dtype=torch.int32)
graph_model = fx.symbolic_trace(Arange())
- mod = from_fx(graph_model, [([10, 10], "float32")])
+ mod = from_fx(graph_model, [([10, 10], "float32")]).mod
assert len(mod["main"].body.blocks) == 1
assert len(mod["main"].body.blocks[0].bindings) == 1
assert isinstance(mod["main"].body.blocks[0].bindings[0].value,
relax.Constant)
@@ -1707,7 +1707,7 @@ def test_empty():
return torch.empty((10, 10), dtype=torch.float32)
graph_model = fx.symbolic_trace(Empty())
- mod = from_fx(graph_model, [([10, 10], "float32")])
+ mod = from_fx(graph_model, [([10, 10], "float32")]).mod
assert len(mod["main"].body.blocks) == 1
assert len(mod["main"].body.blocks[0].bindings) == 1
assert isinstance(mod["main"].body.blocks[0].bindings[0].value,
relax.Constant)
@@ -1734,7 +1734,7 @@ def test_tensor():
return torch.tensor(3)
graph_model1 = fx.symbolic_trace(Empty1())
- mod1 = from_fx(graph_model1, [([10, 10], "float32")])
+ mod1 = from_fx(graph_model1, [([10, 10], "float32")]).mod
assert len(mod1["main"].body.blocks) == 1
assert len(mod1["main"].body.blocks[0].bindings) == 1
assert isinstance(mod1["main"].body.blocks[0].bindings[0].value,
relax.Constant)
@@ -1742,7 +1742,7 @@ def test_tensor():
assert mod1["main"].body.blocks[0].bindings[0].value.data.dtype ==
"float32"
graph_model2 = fx.symbolic_trace(Empty2())
- mod2 = from_fx(graph_model2, [([10, 10], "float32")])
+ mod2 = from_fx(graph_model2, [([10, 10], "float32")]).mod
assert len(mod2["main"].body.blocks) == 1
assert len(mod2["main"].body.blocks[0].bindings) == 1
assert isinstance(mod2["main"].body.blocks[0].bindings[0].value,
relax.Constant)
@@ -1968,12 +1968,18 @@ def test_datatype():
def forward(self, x):
return x.type(torch.float32)
+ # type
+ class TypeFromAttr(Module):
+ def forward(self, x):
+ return x.type(x.getattr("dtype"))
+
# astype
class AsType(Module):
def forward(self, x):
return x.astype(torch.float32)
verify_model(Type(), input_info, {}, expected1)
+ verify_model(TypeFromAttr(), input_info, {}, expected1)
verify_model(AsType(), input_info, {}, expected1)