This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 406427ce3c [Unity][Frontend] Annotate number of non-static input of FX
function (#14067)
406427ce3c is described below
commit 406427ce3c1328024122de0f0a9385a51afe2c82
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Feb 21 10:18:09 2023 -0800
[Unity][Frontend] Annotate number of non-static input of FX function
(#14067)
---
python/tvm/relax/frontend/torch/fx_translator.py | 30 +++++++++++---
tests/python/relax/test_frontend_from_fx.py | 51 +++++++++++++++++++++++-
2 files changed, 73 insertions(+), 8 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a762b0a0fb..4acad61855 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -36,7 +36,7 @@ class TorchFXImporter:
from torch import fx
self.env: Dict[fx.node.Node, relax.Expr] = {}
- self.params: Dict[torch.Tensor, relax.Constant] = {}
+ self.params: Dict[torch.Tensor, relax.Expr] = {}
self.named_modules: Dict[str, torch.Module] = None
self.block_builder: relax.BlockBuilder = None
self.create_convert_map()
@@ -675,7 +675,9 @@ class TorchFXImporter:
"adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
}
- def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) ->
tvm.IRModule:
+ def from_fx(
+ self, model, input_info: List[Tuple[Tuple[int], str]],
keep_params_as_input: bool
+ ) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program."""
from torch import fx
@@ -693,7 +695,17 @@ class TorchFXImporter:
# Initialize the block builder with a function and a dataflow block.
self.block_builder = relax.BlockBuilder()
- with self.block_builder.function(name="main", params=inputs.copy()):
+ if keep_params_as_input:
+ func_attrs = {"num_input": len(inputs)}
+ for name, param in model.named_parameters():
+ shape = param.data.shape
+ dtype = self._convert_data_type(str(param.data.dtype))
+ inputs.append(relax.Var(name, relax.TensorStructInfo(shape,
dtype)))
+ self.params[param] = inputs[-1]
+ else:
+ func_attrs = None
+
+ with self.block_builder.function(name="main", params=inputs.copy(),
attrs=func_attrs):
output = None
with self.block_builder.dataflow():
# Translate model parameters.
@@ -701,7 +713,8 @@ class TorchFXImporter:
shape = param.data.shape
dtype = self._convert_data_type(str(param.data.dtype))
if dtype in ("float32", "float16"):
- self.params[param] =
relax.const(param.data.cpu().numpy(), dtype)
+ if not keep_params_as_input:
+ self.params[param] =
relax.const(param.data.cpu().numpy(), dtype)
else:
raise ValueError("Unsupported data type for model
parameters: %s" % dtype)
# Translate the model.
@@ -740,7 +753,9 @@ class TorchFXImporter:
return self.block_builder.get()
-def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
+def from_fx(
+ model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input:
bool = False
+) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program
Parameters
@@ -751,6 +766,9 @@ def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
input_info : List[Tuple[Tuple[int], str]]
A list of shapes and data types of input tensors.
+ keep_params_as_input : bool
+ Whether to keep model parameters as input variables.
+
Returns
-------
module : tvm.IRModule
@@ -814,4 +832,4 @@ def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
to print out the tabular representation of the PyTorch module, and then
check the placeholder rows in the beginning of the tabular.
"""
- return TorchFXImporter().from_fx(model, input_info)
+ return TorchFXImporter().from_fx(model, input_info, keep_params_as_input)
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 9b35d34bd3..24ed9946a3 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -22,12 +22,12 @@ import tvm.testing
from tvm.script.parser import relax as R, tir as T
-def verify_model(torch_model, input_info, binding, expected):
+def verify_model(torch_model, input_info, binding, expected,
keep_params_as_input=False):
from torch import fx
from tvm.relax.frontend.torch import from_fx
graph_model = fx.symbolic_trace(torch_model)
- mod = from_fx(graph_model, input_info)
+ mod = from_fx(graph_model, input_info,
keep_params_as_input=keep_params_as_input)
binding = {k: tvm.nd.array(v) for k, v in binding.items()}
expected = relax.transform.BindParams("main", binding)(expected)
tvm.ir.assert_structural_equal(mod, expected)
@@ -786,6 +786,7 @@ def test_binary():
input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
input_info2 = [([1, 3, 10, 10], "float32")]
+
# Add
class Add1(Module):
def forward(self, lhs, rhs):
@@ -1725,5 +1726,51 @@ def test_view():
verify_model(View(), input_info, {}, expected1)
+@tvm.testing.requires_gpu
+def test_keep_params():
+ import torch
+ from torch.nn import Module
+
+ class Conv2D1(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
+ w2: R.Tensor((6,), dtype="float32"),
+ ) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
+ input_1,
+ w1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2,
[1, 6, 1, 1])
+ lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
+ gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3
+ R.output(gv)
+ return gv
+
+ model = Conv2D1()
+ input_info = [([1, 3, 10, 10], "float32")]
+ verify_model(model, input_info, {}, expected1, keep_params_as_input=True)
+
+
if __name__ == "__main__":
tvm.testing.main()