This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 4d57272e7f [Unity] Support model kwargs in dynamo_capture_subgraph (#14349)
4d57272e7f is described below

commit 4d57272e7f9beba8599ac0d604a5fe83cb846214
Author: Hongyi Jin <[email protected]>
AuthorDate: Tue Mar 21 01:25:44 2023 -0400

    [Unity] Support model kwargs in dynamo_capture_subgraph (#14349)
    
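    This PR enables users to pass keyword arguments for the torch model's forward function to dynamo_capture_subgraphs, which makes it more flexible. The keep_params_as_input option is consumed by dynamo_capture_subgraphs itself and is not forwarded to the model.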
---
 python/tvm/relax/frontend/torch/dynamo.py  |  4 ++--
 tests/python/relax/test_frontend_dynamo.py | 36 ++++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py
index c71c1fbc84..f48a2cde3c 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -142,7 +142,7 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) -> tvm.IRModule:
     from torch import _dynamo as dynamo  # type: ignore[import]
 
     keep_params_as_input = "keep_params_as_input" in kwargs and kwargs["keep_params_as_input"]
-
+    kwargs.pop("keep_params_as_input", None)
     mod = tvm.IRModule()
 
     def _capture(graph_module: fx.GraphModule, example_inputs):
@@ -159,7 +159,7 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) -> tvm.IRModule:
 
     dynamo.reset()
     compiled_model = torch.compile(model, backend=_capture)
-    compiled_model(*params)
+    compiled_model(*params, **kwargs)
     return mod
 
 
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index 192f8e8b10..765ca9b6f0 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -193,6 +193,42 @@ def test_subgraph_capture():
     mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
     tvm.ir.assert_structural_equal(mod, Expected2)
 
+    class Input3(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = torch.nn.Linear(100, 10)
+
+        def forward(self, x, add_one=False):
+            if add_one:
+                x = x + 1
+            return torch.nn.functional.relu(self.lin(x))
+
+    @tvm.script.ir_module
+    class Expected3:
+        @R.function
+        def subgraph_0(
+            inp_0: R.Tensor((10, 100), dtype="float32"),
+            w0: R.Tensor((10, 100), dtype="float32"),
+            w1: R.Tensor((10,), dtype="float32"),
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv0 = R.add(inp_0, R.const(1, "float32"))
+                lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None)
+                lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(lv0, lv, out_dtype="float32")
+                lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
+                lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
+                gv: R.Tensor((10, 10), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    model = Input3()
+    mod = dynamo_capture_subgraphs(model, torch.randn(10, 100), add_one=True)
+    binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()}
+    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+    expected = relax.transform.BindParams("subgraph_0", binding)(Expected3)
+    tvm.ir.assert_structural_equal(mod, expected)
+
 
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to