This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d86d0b201c [Unity][Frontend] Attach imported model weights, deprecate
ImporterOutput (#14211)
d86d0b201c is described below
commit d86d0b201ceb163a862aa0921feab77f7ebed73b
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Mar 6 13:48:17 2023 -0500
[Unity][Frontend] Attach imported model weights, deprecate ImporterOutput
(#14211)
The class ImporterOutput introduced by #14197 turns out to introduce
inconvenience in use, as it is a wrapper of the IRModule and the
parameters, and prevents good composability when we want to combine
multiple ImporterOutputs together.
For ease of use, we would like to return an IRModule only.
Therefore, as another approach, we can attach the imported parameters as
a function attribute, and detach the parameters later.
This PR implements this and provides the param detachment function. With
only IRModule being manipulated, we can easily combine the results from
multiple imports.
---
python/tvm/relax/frontend/__init__.py | 3 +-
python/tvm/relax/frontend/common.py | 51 ++++++++++++----------
python/tvm/relax/frontend/torch/dynamo.py | 21 ++++-----
python/tvm/relax/frontend/torch/fx_translator.py | 28 ++++++------
.../python/relax/test_frontend_common.py | 30 ++++++++++---
tests/python/relax/test_frontend_dynamo.py | 4 +-
tests/python/relax/test_frontend_from_fx.py | 22 +++++-----
7 files changed, 93 insertions(+), 66 deletions(-)
diff --git a/python/tvm/relax/frontend/__init__.py
b/python/tvm/relax/frontend/__init__.py
index f3c0ed23eb..4baf3195f0 100644
--- a/python/tvm/relax/frontend/__init__.py
+++ b/python/tvm/relax/frontend/__init__.py
@@ -17,5 +17,4 @@
"""
Frontends for constructing Relax programs, with the model importers
"""
-from . import torch
-from .common import ImporterOutput
+from .common import detach_params
diff --git a/python/tvm/relax/frontend/common.py
b/python/tvm/relax/frontend/common.py
index cdb88cd12c..e4432d8c67 100644
--- a/python/tvm/relax/frontend/common.py
+++ b/python/tvm/relax/frontend/common.py
@@ -14,35 +14,42 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=invalid-name
"""Commons for Relax frontend."""
-from typing import Dict, List, Optional
+from typing import Dict, List, Tuple
import tvm
-class ImporterOutput:
- """The data structure representing the result of frontend imports.
+def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str,
List[tvm.nd.NDArray]]]:
+ """Detach the attribute "params" in the functions of the input IRModule as
+ separate dictionary of params.
- Attributes
+ Parameters
----------
mod : tvm.IRModule
- The IRModule imported from frontend.
+ The IRModule whose functions' "param" attribute is going to be
detached.
- params : Optional[Dict[str, List[tvm.nd.NDArray]]]
- The weights of the imported model, when the weights of the model are
- requested to be kept as parameters of functions in the IRModule. (e.g.,
- when the `keep_params_as_input` flag of `frontend.torch.from_fx` is
set to
- True.)
- - `params` is defined to be None when not requested.
- - The keys of `params` are the names of the Relax functions in the
IRModule.
- - Each weight tensor is in the form of TVM NDArray on device CPU.
- - The order of the returned weights is in accordance with the order of
- the kept Relax function input variables.
- """
-
- mod: tvm.IRModule
- params: Optional[Dict[str, List[tvm.nd.NDArray]]]
+ Returns
+ -------
+ detached_mod : tvm.IRModule
+ The IRModule after the detachment.
- def __init__(self, mod: tvm.IRModule, params: Optional[Dict[str,
List[tvm.nd.NDArray]]]):
- self.mod = mod
- self.params = params
+ params_dict : Dict[str, List[tvm.nd.NDArray]]
+ The detached params. The dict keys corresponds to the names of the
+ functions in the input IRModule that have attribute "params".
+ """
+ detached_mod = tvm.IRModule()
+ params_dict = dict()
+ for gv, func in mod.functions.items():
+ if "params" in func.attrs:
+ params = list(func.attrs["params"])
+ if not all([isinstance(param, tvm.nd.NDArray) for param in
params]):
+ raise ValueError(
+ 'The value "params" attribute is expected to be a list of
NDArray.'
+ )
+ params_dict[gv.name_hint] = params
+ detached_mod[gv] = func.without_attr("params")
+ else:
+ detached_mod[gv] = func
+ return detached_mod, params_dict
diff --git a/python/tvm/relax/frontend/torch/dynamo.py
b/python/tvm/relax/frontend/torch/dynamo.py
index 3f30044bb8..a7fb9bc015 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -26,7 +26,6 @@ import tvm
from tvm.relax import build as relax_build
from .fx_translator import from_fx
-from ..common import ImporterOutput
def device_from_inputs(example_inputs):
@@ -74,7 +73,7 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] =
None):
device = device_from_inputs(example_inputs)
input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in
example_inputs]
- mod = from_fx(graph_module, input_info).mod
+ mod = from_fx(graph_module, input_info)
if device.type == "cuda":
dev = tvm.cuda(device.index)
@@ -116,7 +115,7 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] =
None):
return _relax_backend
-def dynamo_capture_subgraphs(model, *params, **kwargs) -> ImporterOutput:
+def dynamo_capture_subgraphs(model, *params, **kwargs) -> tvm.IRModule:
"""Capture subgraphs of the PyTorch model using torch.compile into an
IRModule.
Parameters
@@ -133,8 +132,10 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) ->
ImporterOutput:
Returns
-------
output : ImporterOutput
- The output of translation, including the translated IRModule, and
- the weights of the input model when `keep_params_as_input` is true.
+ The output of translation, including the translated IRModule.
+ If `keep_params_as_input` is true, the functions in the IRModule have
an
+ attribute "params" that contains the weights of the input model. The
+ weights can be detached by `relax.frontend.detach_params`.
"""
import torch # type: ignore[import]
from torch import fx # type: ignore[import]
@@ -143,22 +144,18 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) ->
ImporterOutput:
keep_params_as_input = "keep_params_as_input" in kwargs and
kwargs["keep_params_as_input"]
mod = tvm.IRModule()
- params_ndarray = dict() if keep_params_as_input else None
def _capture(graph_module: fx.GraphModule, example_inputs):
assert isinstance(graph_module, torch.fx.GraphModule)
input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in
example_inputs]
- trace_output = from_fx(graph_module, input_info, keep_params_as_input)
- func_name = f"subgraph_{len(mod.get_global_vars())}"
- mod[func_name] = trace_output.mod["main"]
- if keep_params_as_input:
- params_ndarray[func_name] = trace_output.params["main"]
+ mod_ = from_fx(graph_module, input_info, keep_params_as_input)
+ mod[f"subgraph_{len(mod.get_global_vars())}"] = mod_["main"]
return graph_module.forward
dynamo.reset()
compiled_model = torch.compile(model, backend=_capture)
compiled_model(*params)
- return ImporterOutput(mod, params_ndarray)
+ return mod
@functools.lru_cache(None)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index c89b15a7d5..b4a77ccb33 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -24,8 +24,6 @@ from functools import reduce
import tvm
from tvm import relax
-from ..common import ImporterOutput
-
class TorchFXImporter:
"""An importer from PyTorch FX to Relax."""
@@ -899,7 +897,7 @@ class TorchFXImporter:
def from_fx(
self, model, input_info: List[Tuple[Tuple[int], str]],
keep_params_as_input: bool
- ) -> ImporterOutput:
+ ) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program."""
from torch import fx
@@ -918,18 +916,16 @@ class TorchFXImporter:
# Initialize the block builder with a function and a dataflow block.
func_name = "main"
self.block_builder = relax.BlockBuilder()
+ params = []
if keep_params_as_input:
- params_ = []
func_attrs = {"num_input": len(inputs)}
for name, param in model.named_parameters():
shape = param.data.shape
dtype = self._convert_data_type(str(param.data.dtype))
inputs.append(relax.Var(name, relax.TensorStructInfo(shape,
dtype)))
self.params[param] = inputs[-1]
- params_.append(tvm.nd.array(param.data.cpu().numpy()))
- params = {func_name: params_}
+ params.append(tvm.nd.array(param.data.cpu().numpy()))
else:
- params = None
func_attrs = None
with self.block_builder.function(name=func_name, params=inputs.copy(),
attrs=func_attrs):
@@ -977,12 +973,15 @@ class TorchFXImporter:
assert output is not None
self.block_builder.emit_func_output(output)
- return ImporterOutput(self.block_builder.get(), params)
+ mod = self.block_builder.get()
+ if keep_params_as_input:
+ mod["main"] = mod["main"].with_attr("params", params)
+ return mod
def from_fx(
model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input:
bool = False
-) -> ImporterOutput:
+) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program
Parameters
@@ -998,9 +997,12 @@ def from_fx(
Returns
-------
- output : ImporterOutput
- The output of translation, including the translated IRModule, and
- the weights of the input model when `keep_params_as_input` is true.
+ output : tvm.IRModule
+ The import result IRModule, with the function "main" containing the
+ translated logic.
+ If `keep_params_as_input` is true, the "main" function have an
attribute
+ "params" that contains the weights of the input model. The weights
+ can be detached by `relax.frontend.detach_params`.
Examples
--------
@@ -1043,7 +1045,7 @@ def from_fx(
raise RuntimeError("Failed to export the PyTorch model to FX.")
# Use the importer to import the PyTorch model to Relax.
- mod: tvm.IRModule = from_fx(graph_module, input_info).mod
+ mod: tvm.IRModule = from_fx(graph_module, input_info)
# Print out the imported model.
print(mod.script())
diff --git a/python/tvm/relax/frontend/__init__.py
b/tests/python/relax/test_frontend_common.py
similarity index 50%
copy from python/tvm/relax/frontend/__init__.py
copy to tests/python/relax/test_frontend_common.py
index f3c0ed23eb..39f9af1031 100644
--- a/python/tvm/relax/frontend/__init__.py
+++ b/tests/python/relax/test_frontend_common.py
@@ -14,8 +14,28 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-Frontends for constructing Relax programs, with the model importers
-"""
-from . import torch
-from .common import ImporterOutput
+import tvm
+import tvm.testing
+from tvm.relax.frontend import detach_params
+from tvm.script.parser import relax as R
+
+
+def test_detach_params():
+ @R.function
+ def func(x: R.Tensor((2, 3), "float32")):
+ return x
+
+ param = tvm.nd.empty((3,), "float32")
+ mod = tvm.IRModule({"func": func.with_attr("params", [param])})
+ detached_mod, detached_params = detach_params(mod)
+
+ tvm.ir.assert_structural_equal(detached_mod, tvm.IRModule({"func": func}))
+ assert len(detached_params) == 1
+ assert "func" in detached_params
+ assert isinstance(detached_params["func"], list)
+ assert len(detached_params["func"]) == 1
+ tvm.testing.assert_allclose(detached_params["func"][0].numpy(),
param.numpy())
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_dynamo.py
b/tests/python/relax/test_frontend_dynamo.py
index 14d1e48fb5..b47e3e22bd 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -147,7 +147,7 @@ def test_subgraph_capture():
return gv
model = Input1()
- mod = dynamo_capture_subgraphs(model, torch.randn(10, 100)).mod
+ mod = dynamo_capture_subgraphs(model, torch.randn(10, 100))
binding = {"w0": model.lin.weight.detach().numpy(), "w1":
model.lin.bias.detach().numpy()}
binding = {k: tvm.nd.array(v) for k, v in binding.items()}
expected = relax.transform.BindParams("subgraph_0", binding)(Expected1)
@@ -190,7 +190,7 @@ def test_subgraph_capture():
R.output(gv1)
return gv1
- mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10)).mod
+ mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
tvm.ir.assert_structural_equal(mod, Expected2)
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 8dfbc97d8b..e36be8c3c8 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -22,12 +22,12 @@ import tvm.testing
from tvm.script.parser import relax as R, tir as T
-def verify_model(torch_model, input_info, binding, expected,
keep_params_as_input=False):
+def verify_model(torch_model, input_info, binding, expected):
from torch import fx
from tvm.relax.frontend.torch import from_fx
graph_model = fx.symbolic_trace(torch_model)
- mod = from_fx(graph_model, input_info,
keep_params_as_input=keep_params_as_input).mod
+ mod = from_fx(graph_model, input_info)
binding = {k: tvm.nd.array(v) for k, v in binding.items()}
expected = relax.transform.BindParams("main", binding)(expected)
tvm.ir.assert_structural_equal(mod, expected)
@@ -1752,7 +1752,7 @@ def test_arange():
return torch.arange(0, 20, dtype=torch.int32)
graph_model = fx.symbolic_trace(Arange())
- mod = from_fx(graph_model, [([10, 10], "float32")]).mod
+ mod = from_fx(graph_model, [([10, 10], "float32")])
assert len(mod["main"].body.blocks) == 1
assert len(mod["main"].body.blocks[0].bindings) == 1
assert isinstance(mod["main"].body.blocks[0].bindings[0].value,
relax.Constant)
@@ -1776,7 +1776,7 @@ def test_empty():
return torch.empty((10, 10), dtype=torch.float32)
graph_model = fx.symbolic_trace(Empty())
- mod = from_fx(graph_model, [([10, 10], "float32")]).mod
+ mod = from_fx(graph_model, [([10, 10], "float32")])
assert len(mod["main"].body.blocks) == 1
assert len(mod["main"].body.blocks[0].bindings) == 1
assert isinstance(mod["main"].body.blocks[0].bindings[0].value,
relax.Constant)
@@ -1803,7 +1803,7 @@ def test_tensor():
return torch.tensor(3)
graph_model1 = fx.symbolic_trace(Empty1())
- mod1 = from_fx(graph_model1, [([10, 10], "float32")]).mod
+ mod1 = from_fx(graph_model1, [([10, 10], "float32")])
assert len(mod1["main"].body.blocks) == 1
assert len(mod1["main"].body.blocks[0].bindings) == 1
assert isinstance(mod1["main"].body.blocks[0].bindings[0].value,
relax.Constant)
@@ -1811,7 +1811,7 @@ def test_tensor():
assert mod1["main"].body.blocks[0].bindings[0].value.data.dtype ==
"float32"
graph_model2 = fx.symbolic_trace(Empty2())
- mod2 = from_fx(graph_model2, [([10, 10], "float32")]).mod
+ mod2 = from_fx(graph_model2, [([10, 10], "float32")])
assert len(mod2["main"].body.blocks) == 1
assert len(mod2["main"].body.blocks[0].bindings) == 1
assert isinstance(mod2["main"].body.blocks[0].bindings[0].value,
relax.Constant)
@@ -2173,6 +2173,7 @@ def test_keep_params():
import torch
from torch import fx
from torch.nn import Module
+ from tvm.relax.frontend import detach_params
from tvm.relax.frontend.torch import from_fx
class Conv2D1(Module):
@@ -2213,10 +2214,11 @@ def test_keep_params():
model = Conv2D1()
graph_model = fx.symbolic_trace(model)
- trace_output = from_fx(graph_model, [([1, 3, 10, 10], "float32")],
keep_params_as_input=True)
- tvm.ir.assert_structural_equal(trace_output.mod, expected1)
- func = trace_output.mod["main"]
- params = trace_output.params["main"]
+ mod = from_fx(graph_model, [([1, 3, 10, 10], "float32")],
keep_params_as_input=True)
+ mod, params = detach_params(mod)
+ tvm.ir.assert_structural_equal(mod, expected1)
+ func = mod["main"]
+ params = params["main"]
assert len(params) == len(func.params) - 1
for param_var, param_ndarray in zip(func.params[1:], params):