This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6bcd375369 [Unity] Torch-like NN module enhancement (#14499)
6bcd375369 is described below
commit 6bcd375369a714565b965af93d58878e3e84e5cf
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Apr 4 16:16:22 2023 -0400
[Unity] Torch-like NN module enhancement (#14499)
This PR
* adds `emit` to the Relax nn testing module so that arbitrary Exprs can be
emitted, instead of only having `emit_te` available, and
* fixes module parameter fetching to ignore non-parameter fields of a
module (previously, fields of unrecognized types raised a TypeError).
---
python/tvm/relax/testing/nn.py | 8 +++--
tests/python/relax/test_testing_nn.py | 60 +++++++++++++++++++++++++++++++++++
2 files changed, 65 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py
index 830ddd779f..c2acc5a229 100644
--- a/python/tvm/relax/testing/nn.py
+++ b/python/tvm/relax/testing/nn.py
@@ -26,6 +26,10 @@ import tvm
from tvm import relax, topi, tir
+def emit(expr: relax.Expr) -> relax.Var:
+ return relax.BlockBuilder.current().emit(expr)
+
+
def emit_te(func: Callable, *args: Any, **kwargs: Any) -> relax.Var:
return relax.BlockBuilder.current().emit_te(func, *args, **kwargs)
@@ -112,9 +116,7 @@ def _unpack_params(value: object) -> List[relax.Var]:
for v in value:
params += _unpack_params(v)
return params
- if value is None or isinstance(value, (int, float, str)):
- return []
- raise TypeError("not supported type when unpacking parameters: {}".format(type(value)))
+ return []
def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]:
diff --git a/tests/python/relax/test_testing_nn.py b/tests/python/relax/test_testing_nn.py
new file mode 100644
index 0000000000..65d531a4da
--- /dev/null
+++ b/tests/python/relax/test_testing_nn.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.testing import nn
+from tvm.script import ir as I, relax as R
+
+
+def test_emit():
+ class ReLU(nn.Module):
+ def forward(self, input: relax.Expr) -> relax.Var:
+ return nn.emit(relax.op.nn.relu(input))
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((32, 32), dtype="float32")) -> R.Tensor((32, 32), dtype="float32"):
+ gv: R.Tensor((32, 32), dtype="float32") = R.nn.relu(x)
+ return gv
+
+ bb = relax.BlockBuilder()
+ with bb.function("main"):
+ model = ReLU()
+ x = nn.Placeholder((32, 32), dtype="float32", name="x")
+ output = model(x)
+ params = [x] + model.parameters()
+ bb.emit_func_output(output, params)
+
+ tvm.ir.assert_structural_equal(bb.get(), Expected)
+
+
+def test_get_param():
+ class Plus1(nn.Module):
+ def __init__(self):
+ self.const_1 = relax.const(1, "float32")
+
+ def forward(self, input: relax.Expr) -> relax.Var:
+ return nn.emit(relax.op.add(input, self.const_1))
+
+ model = Plus1()
+ assert model.parameters() == []
+
+
+if __name__ == "__main__":
+ tvm.testing.main()