This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6bcd375369 [Unity] Torch-like NN module enhancement (#14499)
6bcd375369 is described below

commit 6bcd375369a714565b965af93d58878e3e84e5cf
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Apr 4 16:16:22 2023 -0400

    [Unity] Torch-like NN module enhancement (#14499)
    
    This PR
    * adds `emit` to the Relax nn module so that arbitrary Exprs can be
    emitted, instead of only having `emit_te` available,
    * fixes module parameter fetching to ignore non-parameter fields of a
    module (previously, a TypeError was raised for unrecognized
    fields).
---
 python/tvm/relax/testing/nn.py        |  8 +++--
 tests/python/relax/test_testing_nn.py | 60 +++++++++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py
index 830ddd779f..c2acc5a229 100644
--- a/python/tvm/relax/testing/nn.py
+++ b/python/tvm/relax/testing/nn.py
@@ -26,6 +26,10 @@ import tvm
 from tvm import relax, topi, tir
 
 
+def emit(expr: relax.Expr) -> relax.Var:
+    return relax.BlockBuilder.current().emit(expr)
+
+
 def emit_te(func: Callable, *args: Any, **kwargs: Any) -> relax.Var:
     return relax.BlockBuilder.current().emit_te(func, *args, **kwargs)
 
@@ -112,9 +116,7 @@ def _unpack_params(value: object) -> List[relax.Var]:
         for v in value:
             params += _unpack_params(v)
         return params
-    if value is None or isinstance(value, (int, float, str)):
-        return []
-    raise TypeError("not supported type when unpacking parameters: {}".format(type(value)))
+    return []
 
 
 def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]:
diff --git a/tests/python/relax/test_testing_nn.py b/tests/python/relax/test_testing_nn.py
new file mode 100644
index 0000000000..65d531a4da
--- /dev/null
+++ b/tests/python/relax/test_testing_nn.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.testing import nn
+from tvm.script import ir as I, relax as R
+
+
+def test_emit():
+    class ReLU(nn.Module):
+        def forward(self, input: relax.Expr) -> relax.Var:
+            return nn.emit(relax.op.nn.relu(input))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((32, 32), dtype="float32")) -> R.Tensor((32, 32), dtype="float32"):
+            gv: R.Tensor((32, 32), dtype="float32") = R.nn.relu(x)
+            return gv
+
+    bb = relax.BlockBuilder()
+    with bb.function("main"):
+        model = ReLU()
+        x = nn.Placeholder((32, 32), dtype="float32", name="x")
+        output = model(x)
+        params = [x] + model.parameters()
+        bb.emit_func_output(output, params)
+
+    tvm.ir.assert_structural_equal(bb.get(), Expected)
+
+
+def test_get_param():
+    class Plus1(nn.Module):
+        def __init__(self):
+            self.const_1 = relax.const(1, "float32")
+
+        def forward(self, input: relax.Expr) -> relax.Var:
+            return nn.emit(relax.op.add(input, self.const_1))
+
+    model = Plus1()
+    assert model.parameters() == []
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to