This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1ad1994f5f [Fix][TVMScript] Fix index of metadata in printed script (#14130)
1ad1994f5f is described below

commit 1ad1994f5f33d0420df61eb80520cc8a2730e74d
Author: Yixin Dong <[email protected]>
AuthorDate: Sun Feb 26 07:24:30 2023 +0800

    [Fix][TVMScript] Fix index of metadata in printed script (#14130)
    
    Currently, if the same metadata object (e.g. a multi-line `tir.StringImm`) is referenced more than once in an IRModule, each reference gets a different index into the metadata array. For example, this code
    
    ```
    str_imm = T.StringImm("aaa\nbbb\n")
    @I.ir_module
    class Module:
        @T.prim_func
        def foo() -> None:
            A = str_imm
            B = str_imm
    
        @T.prim_func
        def foo1() -> None:
            A = str_imm
    Module.show()
    ```
    
    where `str_imm` is referenced three times, generates the following output:
    
    ```
    @I.ir_module
    class Module:
        @T.prim_func
        def foo():
            A: T.handle = metadata["tir.StringImm"][0]
            B: T.handle = metadata["tir.StringImm"][1]
            T.evaluate(0)
    
        @T.prim_func
        def foo1():
            A: T.handle = metadata["tir.StringImm"][2]
            T.evaluate(0)
    ```
    
    Each reference gets a different metadata index, even though all three refer to the same object.
    
    This PR fixes this problem by detecting duplicate items in `IRDocsifierNode::AddMetadata` and reusing the index of the existing entry.
---
 src/script/printer/ir_docsifier.cc                 | 10 ++---
 .../unittest/test_tvmscript_printer_metadata.py    | 47 ++++++++++++++++++++++
 2 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc
index 936534480f..fd5003073a 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -56,11 +56,11 @@ ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) {
   ICHECK(obj.defined()) << "TypeError: Cannot add nullptr to metadata";
   String key = obj->GetTypeKey();
   Array<ObjectRef>& array = metadata[key];
-  int index = array.size();
-  array.push_back(obj);
-  return IdDoc("metadata")               //
-      [{LiteralDoc::Str(key, NullOpt)}]  //
-      [{LiteralDoc::Int(index, NullOpt)}];
+  int index = std::find(array.begin(), array.end(), obj) - array.begin();
+  if (index == static_cast<int>(array.size())) {
+    array.push_back(obj);
+  }
+  return IdDoc("metadata")[{LiteralDoc::Str(key, NullOpt)}][{LiteralDoc::Int(index, NullOpt)}];
 }
 
 bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); }
diff --git a/tests/python/unittest/test_tvmscript_printer_metadata.py b/tests/python/unittest/test_tvmscript_printer_metadata.py
new file mode 100644
index 0000000000..a57f4c71f7
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_printer_metadata.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm.testing
+from tvm.script.parser import ir as I
+from tvm.script.parser import tir as T
+
+
+def test_str_metadata():
+    # This test is to check we reuse the existing metadata element for the same tir.StringImm
+    # So metadata["tir.StringImm"][0] will occur in the printed script for three times
+    str_imm = T.StringImm("aaa\nbbb\n")
+
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def foo() -> None:
+            A = str_imm
+            B = str_imm
+
+        @T.prim_func
+        def foo1() -> None:
+            A = str_imm
+
+    printed_str = Module.script(verbose_expr=True)
+    assert (
+        printed_str.count('metadata["tir.StringImm"][0]') == 3
+        and printed_str.count('metadata["tir.StringImm"][1]') == 0
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to