This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 4d57272e7f [Unity] Support model kwargs in dynamo_capture_subgraph
(#14349)
4d57272e7f is described below
commit 4d57272e7f9beba8599ac0d604a5fe83cb846214
Author: Hongyi Jin <[email protected]>
AuthorDate: Tue Mar 21 01:25:44 2023 -0400
[Unity] Support model kwargs in dynamo_capture_subgraph (#14349)
This PR enables the user to pass the torch model's forward-function kwargs
to dynamo_capture_subgraphs, which makes it more flexible.
---
python/tvm/relax/frontend/torch/dynamo.py | 4 ++--
tests/python/relax/test_frontend_dynamo.py | 36 ++++++++++++++++++++++++++++++
2 files changed, 38 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/dynamo.py
b/python/tvm/relax/frontend/torch/dynamo.py
index c71c1fbc84..f48a2cde3c 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -142,7 +142,7 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) ->
tvm.IRModule:
from torch import _dynamo as dynamo # type: ignore[import]
keep_params_as_input = "keep_params_as_input" in kwargs and
kwargs["keep_params_as_input"]
-
+ kwargs.pop("keep_params_as_input", None)
mod = tvm.IRModule()
def _capture(graph_module: fx.GraphModule, example_inputs):
@@ -159,7 +159,7 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) ->
tvm.IRModule:
dynamo.reset()
compiled_model = torch.compile(model, backend=_capture)
- compiled_model(*params)
+ compiled_model(*params, **kwargs)
return mod
diff --git a/tests/python/relax/test_frontend_dynamo.py
b/tests/python/relax/test_frontend_dynamo.py
index 192f8e8b10..765ca9b6f0 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -193,6 +193,42 @@ def test_subgraph_capture():
mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
tvm.ir.assert_structural_equal(mod, Expected2)
+ class Input3(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.lin = torch.nn.Linear(100, 10)
+
+ def forward(self, x, add_one=False):
+ if add_one:
+ x = x + 1
+ return torch.nn.functional.relu(self.lin(x))
+
+ @tvm.script.ir_module
+ class Expected3:
+ @R.function
+ def subgraph_0(
+ inp_0: R.Tensor((10, 100), dtype="float32"),
+ w0: R.Tensor((10, 100), dtype="float32"),
+ w1: R.Tensor((10,), dtype="float32"),
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv0 = R.add(inp_0, R.const(1, "float32"))
+ lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0,
axes=None)
+ lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(lv0, lv,
out_dtype="float32")
+ lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
+ lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
+ gv: R.Tensor((10, 10), dtype="float32") = lv3
+ R.output(gv)
+ return gv
+
+ model = Input3()
+ mod = dynamo_capture_subgraphs(model, torch.randn(10, 100), add_one=True)
+ binding = {"w0": model.lin.weight.detach().numpy(), "w1":
model.lin.bias.detach().numpy()}
+ binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+ expected = relax.transform.BindParams("subgraph_0", binding)(Expected3)
+ tvm.ir.assert_structural_equal(mod, expected)
+
if __name__ == "__main__":
tvm.testing.main()