This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 85976ea1a9 [Unity][Frontend] FX translator returning weights with
`keep_params_as_input` (#14197)
85976ea1a9 is described below
commit 85976ea1a9b5b237215be9f26cc143038778d095
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 4 22:59:53 2023 -0500
[Unity][Frontend] FX translator returning weights with
`keep_params_as_input` (#14197)
PR #14067 introduced the flag `keep_params_as_input` to the FX
translator, so that the model weights can be handled outside of the
translated Relax function.
This PR takes a further step by returning the model weights as
NDArrays when the flag `keep_params_as_input` is true. With this PR, the
translator can now return the weights upon request; without this, the
weights of the given PyTorch model would be lost after the import.
---
python/tvm/relax/frontend/__init__.py | 2 +
python/tvm/relax/frontend/common.py | 48 ++++++++++++++++++++++++
python/tvm/relax/frontend/torch/dynamo.py | 28 ++++++++++----
python/tvm/relax/frontend/torch/fx_translator.py | 22 +++++++----
tests/python/relax/test_frontend_dynamo.py | 4 +-
tests/python/relax/test_frontend_from_fx.py | 19 ++++++++--
6 files changed, 103 insertions(+), 20 deletions(-)
diff --git a/python/tvm/relax/frontend/__init__.py
b/python/tvm/relax/frontend/__init__.py
index 6c9c188aaa..f3c0ed23eb 100644
--- a/python/tvm/relax/frontend/__init__.py
+++ b/python/tvm/relax/frontend/__init__.py
@@ -17,3 +17,5 @@
"""
Frontends for constructing Relax programs, with the model importers
"""
+from . import torch
+from .common import ImporterOutput
diff --git a/python/tvm/relax/frontend/common.py
b/python/tvm/relax/frontend/common.py
new file mode 100644
index 0000000000..cdb88cd12c
--- /dev/null
+++ b/python/tvm/relax/frontend/common.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Commons for Relax frontend."""
+from typing import Dict, List, Optional
+
+import tvm
+
+
+class ImporterOutput:
+ """The data structure representing the result of frontend imports.
+
+ Attributes
+ ----------
+ mod : tvm.IRModule
+ The IRModule imported by the frontend.
+
+ params : Optional[Dict[str, List[tvm.nd.NDArray]]]
+ The weights of the imported model, when the weights of the model are
+ requested to be kept as parameters of functions in the IRModule. (e.g.,
+ when the `keep_params_as_input` flag of `frontend.torch.from_fx` is set to
+ True.)
+ - `params` is None when the weights are not requested.
+ - The keys of `params` are the names of the Relax functions in the IRModule.
+ - Each weight tensor is in the form of TVM NDArray on device CPU.
+ - The order of the returned weights is in accordance with the order of
+ the kept Relax function input variables.
+ """
+
+ mod: tvm.IRModule
+ params: Optional[Dict[str, List[tvm.nd.NDArray]]]
+
+ def __init__(self, mod: tvm.IRModule, params: Optional[Dict[str,
List[tvm.nd.NDArray]]]):
+ self.mod = mod
+ self.params = params
diff --git a/python/tvm/relax/frontend/torch/dynamo.py
b/python/tvm/relax/frontend/torch/dynamo.py
index 589c6be3b5..3f30044bb8 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -24,7 +24,9 @@ from typing import Optional
import tvm
from tvm.relax import build as relax_build
-from tvm.relax.frontend.torch.fx_translator import from_fx
+
+from .fx_translator import from_fx
+from ..common import ImporterOutput
def device_from_inputs(example_inputs):
@@ -72,7 +74,7 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] =
None):
device = device_from_inputs(example_inputs)
input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in
example_inputs]
- mod = from_fx(graph_module, input_info)
+ mod = from_fx(graph_module, input_info).mod
if device.type == "cuda":
dev = tvm.cuda(device.index)
@@ -114,7 +116,7 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] =
None):
return _relax_backend
-def dynamo_capture_subgraphs(model, *params) -> tvm.ir.IRModule:
+def dynamo_capture_subgraphs(model, *params, **kwargs) -> ImporterOutput:
"""Capture subgraphs of the PyTorch model using torch.compile into an
IRModule.
Parameters
@@ -125,28 +127,38 @@ def dynamo_capture_subgraphs(model, *params) ->
tvm.ir.IRModule:
params : List[torch.Tensor]
The parameters of the PyTorch model.
+ keep_params_as_input : bool
+ Whether to keep model parameters as input variables of the captured
Relax functions.
+
Returns
-------
- mod : tvm.ir.IRModule
- The IRModule that contains captured subgraphs.
+ output : ImporterOutput
+ The output of translation, including the translated IRModule, and
+ the weights of the input model when `keep_params_as_input` is true.
"""
import torch # type: ignore[import]
from torch import fx # type: ignore[import]
from torch import _dynamo as dynamo # type: ignore[import]
+ keep_params_as_input = "keep_params_as_input" in kwargs and
kwargs["keep_params_as_input"]
+
mod = tvm.IRModule()
+ params_ndarray = dict() if keep_params_as_input else None
def _capture(graph_module: fx.GraphModule, example_inputs):
assert isinstance(graph_module, torch.fx.GraphModule)
input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in
example_inputs]
- subgraph = from_fx(graph_module, input_info)
- mod["subgraph_" + str(len(mod.get_global_vars()))] = subgraph["main"]
+ trace_output = from_fx(graph_module, input_info, keep_params_as_input)
+ func_name = f"subgraph_{len(mod.get_global_vars())}"
+ mod[func_name] = trace_output.mod["main"]
+ if keep_params_as_input:
+ params_ndarray[func_name] = trace_output.params["main"]
return graph_module.forward
dynamo.reset()
compiled_model = torch.compile(model, backend=_capture)
compiled_model(*params)
- return mod
+ return ImporterOutput(mod, params_ndarray)
@functools.lru_cache(None)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index b580e1679b..a73bc9d0db 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -24,6 +24,8 @@ from functools import reduce
import tvm
from tvm import relax
+from ..common import ImporterOutput
+
class TorchFXImporter:
"""An importer from PyTorch FX to Relax."""
@@ -843,7 +845,7 @@ class TorchFXImporter:
def from_fx(
self, model, input_info: List[Tuple[Tuple[int], str]],
keep_params_as_input: bool
- ) -> tvm.IRModule:
+ ) -> ImporterOutput:
"""Convert a PyTorch FX GraphModule to a Relax program."""
from torch import fx
@@ -860,18 +862,23 @@ class TorchFXImporter:
)
# Initialize the block builder with a function and a dataflow block.
+ func_name = "main"
self.block_builder = relax.BlockBuilder()
if keep_params_as_input:
+ params_ = []
func_attrs = {"num_input": len(inputs)}
for name, param in model.named_parameters():
shape = param.data.shape
dtype = self._convert_data_type(str(param.data.dtype))
inputs.append(relax.Var(name, relax.TensorStructInfo(shape,
dtype)))
self.params[param] = inputs[-1]
+ params_.append(tvm.nd.array(param.data.cpu().numpy()))
+ params = {func_name: params_}
else:
+ params = None
func_attrs = None
- with self.block_builder.function(name="main", params=inputs.copy(),
attrs=func_attrs):
+ with self.block_builder.function(name=func_name, params=inputs.copy(),
attrs=func_attrs):
output = None
with self.block_builder.dataflow():
# Translate model parameters.
@@ -916,12 +923,12 @@ class TorchFXImporter:
assert output is not None
self.block_builder.emit_func_output(output)
- return self.block_builder.get()
+ return ImporterOutput(self.block_builder.get(), params)
def from_fx(
model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input:
bool = False
-) -> tvm.IRModule:
+) -> ImporterOutput:
"""Convert a PyTorch FX GraphModule to a Relax program
Parameters
@@ -937,8 +944,9 @@ def from_fx(
Returns
-------
- module : tvm.IRModule
- The converted Relax program.
+ output : ImporterOutput
+ The output of translation, including the translated IRModule, and
+ the weights of the input model when `keep_params_as_input` is true.
Examples
--------
@@ -981,7 +989,7 @@ def from_fx(
raise RuntimeError("Failed to export the PyTorch model to FX.")
# Use the importer to import the PyTorch model to Relax.
- mod: tvm.IRModule = from_fx(graph_module, input_info)
+ mod: tvm.IRModule = from_fx(graph_module, input_info).mod
# Print out the imported model.
print(mod.script())
diff --git a/tests/python/relax/test_frontend_dynamo.py
b/tests/python/relax/test_frontend_dynamo.py
index b47e3e22bd..14d1e48fb5 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -147,7 +147,7 @@ def test_subgraph_capture():
return gv
model = Input1()
- mod = dynamo_capture_subgraphs(model, torch.randn(10, 100))
+ mod = dynamo_capture_subgraphs(model, torch.randn(10, 100)).mod
binding = {"w0": model.lin.weight.detach().numpy(), "w1":
model.lin.bias.detach().numpy()}
binding = {k: tvm.nd.array(v) for k, v in binding.items()}
expected = relax.transform.BindParams("subgraph_0", binding)(Expected1)
@@ -190,7 +190,7 @@ def test_subgraph_capture():
R.output(gv1)
return gv1
- mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
+ mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10)).mod
tvm.ir.assert_structural_equal(mod, Expected2)
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 9ab0b3304c..e28483dc2f 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -27,7 +27,7 @@ def verify_model(torch_model, input_info, binding, expected,
keep_params_as_inpu
from tvm.relax.frontend.torch import from_fx
graph_model = fx.symbolic_trace(torch_model)
- mod = from_fx(graph_model, input_info,
keep_params_as_input=keep_params_as_input)
+ mod = from_fx(graph_model, input_info,
keep_params_as_input=keep_params_as_input).mod
binding = {k: tvm.nd.array(v) for k, v in binding.items()}
expected = relax.transform.BindParams("main", binding)(expected)
tvm.ir.assert_structural_equal(mod, expected)
@@ -2096,7 +2096,9 @@ def test_view():
@tvm.testing.requires_gpu
def test_keep_params():
import torch
+ from torch import fx
from torch.nn import Module
+ from tvm.relax.frontend.torch import from_fx
class Conv2D1(Module):
def __init__(self):
@@ -2135,8 +2137,19 @@ def test_keep_params():
return gv
model = Conv2D1()
- input_info = [([1, 3, 10, 10], "float32")]
- verify_model(model, input_info, {}, expected1, keep_params_as_input=True)
+ graph_model = fx.symbolic_trace(model)
+ trace_output = from_fx(graph_model, [([1, 3, 10, 10], "float32")],
keep_params_as_input=True)
+ tvm.ir.assert_structural_equal(trace_output.mod, expected1)
+ func = trace_output.mod["main"]
+ params = trace_output.params["main"]
+
+ assert len(params) == len(func.params) - 1
+ for param_var, param_ndarray in zip(func.params[1:], params):
+ assert tuple(x.value for x in param_var.struct_info.shape.values) ==
param_ndarray.shape
+ assert param_var.struct_info.dtype == param_ndarray.dtype
+
+ tvm.testing.assert_allclose(params[0].numpy(),
model.conv.weight.detach().numpy())
+ tvm.testing.assert_allclose(params[1].numpy(),
model.conv.bias.detach().numpy())
@tvm.testing.requires_gpu