This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1ad1994f5f [Fix][TVMScript] Fix index of metadata in printed script
(#14130)
1ad1994f5f is described below
commit 1ad1994f5f33d0420df61eb80520cc8a2730e74d
Author: Yixin Dong <[email protected]>
AuthorDate: Sun Feb 26 07:24:30 2023 +0800
[Fix][TVMScript] Fix index of metadata in printed script (#14130)
Currently, if the same metadata object (e.g. a multi-line `tir.StringImm`)
is referenced more than once in an IRModule, each reference gets a
different index into the metadata array. For example, this code
```
str_imm = T.StringImm("aaa\nbbb\n")
@I.ir_module
class Module:
    @T.prim_func
    def foo() -> None:
        A = str_imm
        B = str_imm
    @T.prim_func
    def foo1() -> None:
        A = str_imm
Module.show()
```
where `str_imm` is referenced three times, generates the following output:
```
@I.ir_module
class Module:
    @T.prim_func
    def foo():
        A: T.handle = metadata["tir.StringImm"][0]
        B: T.handle = metadata["tir.StringImm"][1]
        T.evaluate(0)
    @T.prim_func
    def foo1():
        A: T.handle = metadata["tir.StringImm"][2]
        T.evaluate(0)
```
Each reference gets a different metadata index, even though all three refer to the same object.
This PR fixes the problem by detecting duplicate items in
`IRDocsifierNode::AddMetadata` and reusing the index of the existing entry.
---
src/script/printer/ir_docsifier.cc | 10 ++---
.../unittest/test_tvmscript_printer_metadata.py | 47 ++++++++++++++++++++++
2 files changed, 52 insertions(+), 5 deletions(-)
diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc
index 936534480f..fd5003073a 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -56,11 +56,11 @@ ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) {
   ICHECK(obj.defined()) << "TypeError: Cannot add nullptr to metadata";
   String key = obj->GetTypeKey();
   Array<ObjectRef>& array = metadata[key];
-  int index = array.size();
-  array.push_back(obj);
-  return IdDoc("metadata")                   //
-      [{LiteralDoc::Str(key, NullOpt)}]      //
-      [{LiteralDoc::Int(index, NullOpt)}];
+  int index = static_cast<int>(std::find(array.begin(), array.end(), obj) - array.begin());
+  if (index == static_cast<int>(array.size())) {  // first occurrence of `obj`: add a new slot
+    array.push_back(obj);
+  }  // otherwise reuse the existing slot, so every reference prints the same index
+  return IdDoc("metadata")[{LiteralDoc::Str(key, NullOpt)}][{LiteralDoc::Int(index, NullOpt)}];
 }

 bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); }
diff --git a/tests/python/unittest/test_tvmscript_printer_metadata.py b/tests/python/unittest/test_tvmscript_printer_metadata.py
new file mode 100644
index 0000000000..a57f4c71f7
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_printer_metadata.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm.testing
+from tvm.script.parser import ir as I
+from tvm.script.parser import tir as T
+
+
+def test_str_metadata():
+    # The same tir.StringImm object is referenced three times below, so the printer
+    # must reuse a single metadata slot: metadata["tir.StringImm"][0] should occur
+    # three times in the printed script.
+    str_imm = T.StringImm("aaa\nbbb\n")
+
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def foo() -> None:
+            A = str_imm
+            B = str_imm
+
+        @T.prim_func
+        def foo1() -> None:
+            A = str_imm
+
+    printed_str = Module.script(verbose_expr=True)
+    # A second slot (index 1) would mean the repeated reference was added again.
+    assert printed_str.count('metadata["tir.StringImm"][0]') == 3, printed_str
+    assert printed_str.count('metadata["tir.StringImm"][1]') == 0, printed_str
+
+
+if __name__ == "__main__":
+    tvm.testing.main()