This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3cb57d20a0 [TVMScript] Printer Frame (#12366)
3cb57d20a0 is described below

commit 3cb57d20a0ad3ceffdddc060d1188f4921eb8722
Author: Lite Ye <[email protected]>
AuthorDate: Thu Aug 11 23:22:37 2022 -0400

    [TVMScript] Printer Frame (#12366)
    
    This PR:
    
    - Implement Frame for the TVMScript Unified Printer
    
    Compared to the prototype version, this:
    
    - Removes the dependency of VarTable (SymbolTable) from Frame
    - Adds a callback array to the Frame base class so that VarTable can add callbacks to clean up variables when the Frame goes out of scope
    
    Tracking issue: https://github.com/apache/tvm/issues/11912
---
 include/tvm/script/printer/frame.h                 | 140 +++++++++++++++++++++
 python/tvm/script/printer/frame.py                 |  81 ++++++++++++
 src/script/printer/frame.cc                        |  50 ++++++++
 .../unittest/test_tvmscript_printer_frame.py       |  60 +++++++++
 4 files changed, 331 insertions(+)

diff --git a/include/tvm/script/printer/frame.h b/include/tvm/script/printer/frame.h
new file mode 100644
index 0000000000..407ad16007
--- /dev/null
+++ b/include/tvm/script/printer/frame.h
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_FRAME_H_
+#define TVM_SCRIPT_PRINTER_FRAME_H_
+
+#include <tvm/node/node.h>
+#include <tvm/script/printer/doc.h>
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Frame is the core data structure for semantic information
+ * when printing IR graph into TVMScript code.
+ */
+class FrameNode : public Object {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  virtual ~FrameNode() = default;
+
+  /*!
+   * \brief Add a callback function to be called the next time this frame exits.
+   * \param cb The callback function. It should have signature void().
+   */
+  template <typename TCallback>
+  void AddExitCallback(TCallback&& cb) {
+    callbacks_.emplace_back(std::forward<TCallback>(cb));
+  }
+
+  /*! \brief Method that's called when Frame enters the scope. */
+  virtual void EnterWithScope() {}
+
+  /*!
+   * \brief Method that's called when Frame exits the scope. Registered exit callbacks run
+   * once, in registration order. The list is detached first, so a callback that registers
+   * another callback cannot invalidate the iteration; the new one waits for the next exit.
+   */
+  virtual void ExitWithScope() {
+    std::vector<std::function<void()>> callbacks = std::exchange(callbacks_, {});
+    for (const std::function<void()>& callback : callbacks) {
+      callback();
+    }
+  }
+
+  static constexpr const char* _type_key = "script.printer.Frame";
+  TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object);
+
+ private:
+  std::vector<std::function<void()>> callbacks_;
+};
+
+/*!
+ * \brief Reference type of FrameNode; only subclasses may default-construct it.
+ */
+class Frame : public ObjectRef {
+ protected:
+  Frame() = default;
+
+ public:
+  virtual ~Frame() = default;
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode);
+};
+
+/*! \brief MetadataFrame contains information like constant parameter arrays. */
+class MetadataFrameNode : public FrameNode {
+ public:
+  /*! \brief Objects held by this frame; reflected under the key "metadata". */
+  Array<ObjectRef> metadata;
+
+  /*! \brief Visit the base-class attributes first, then `metadata`. */
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    FrameNode::VisitAttrs(v);
+    v->Visit("metadata", &metadata);
+  }
+
+  static constexpr const char* _type_key = "script.printer.MetadataFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MetadataFrameNode, FrameNode);
+};
+
+/*!
+ * \brief Reference type of MetadataFrameNode.
+ */
+class MetadataFrame : public Frame {
+ public:
+  MetadataFrame();  // Creates a frame backed by a newly allocated MetadataFrameNode.
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetadataFrame, Frame, MetadataFrameNode);
+};
+
+/*!
+ * \brief VarDefFrame contains information about the free variables that need to be defined
+ * at the beginning of the printed snippet, held as the StmtDoc list `stmts`.
+ */
+class VarDefFrameNode : public FrameNode {
+ public:
+  Array<StmtDoc> stmts;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    FrameNode::VisitAttrs(v);
+    v->Visit("stmts", &stmts);
+  }
+
+  static constexpr const char* _type_key = "script.printer.VarDefFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(VarDefFrameNode, FrameNode);
+};
+
+/*!
+ * \brief Reference type of VarDefFrameNode.
+ */
+class VarDefFrame : public Frame {
+ public:
+  VarDefFrame();  // Creates a frame backed by a newly allocated VarDefFrameNode.
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarDefFrame, Frame, VarDefFrameNode);
+};
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_PRINTER_FRAME_H_
diff --git a/python/tvm/script/printer/frame.py b/python/tvm/script/printer/frame.py
new file mode 100644
index 0000000000..c967382b8b
--- /dev/null
+++ b/python/tvm/script/printer/frame.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Frame is the core data structure for semantic information when printing
+IR graph into TVMScript code.
+"""
+
+from typing import Callable, Sequence
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.script.printer.doc import StmtDoc
+
+from . import _ffi_api
+
+
+class Frame(Object):
+    """
+    Frame is the core data structure for semantic information
+    when printing IR graph into TVMScript code.
+
+    Frame base class manages a list of callbacks to be executed
+    when frame goes out of scope.
+    """
+
+    def add_exit_callback(self, callback: Callable[[], None]) -> None:
+        """
+        Adds a callback function to be executed when frame goes out of scope.
+
+        Parameters
+        ----------
+        callback : Callable[[], None]
+            The callback function.
+        """
+        _ffi_api.FrameAddExitCallback(self, callback)  # type: ignore # 
pylint: disable=no-member
+
+    def __enter__(self):
+        _ffi_api.FrameEnterWithScope(self)  # type: ignore # pylint: 
disable=no-member
+        return self
+
+    def __exit__(self, *exception_info):
+        _ffi_api.FrameExitWithScope(self)  # type: ignore # pylint: 
disable=no-member
+
+
+@register_object("script.printer.MetadataFrame")
+class MetadataFrame(Frame):
+    """
+    MetadataFrame contains information like contant parameter array.
+    """
+
+    metadata: Sequence[Object]
+
+    def __init__(self):
+        self.__init_handle_by_constructor__(_ffi_api.MetadataFrame)  # type: 
ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.VarDefFrame")
+class VarDefFrame(Frame):
+    """
+    VarDefFrame contains information about the free variables that needs to
+    be defined at the beginning of the printed snippet.
+    """
+
+    stmts: Sequence[StmtDoc]
+
+    def __init__(self):
+        self.__init_handle_by_constructor__(_ffi_api.VarDefFrame)  # type: 
ignore # pylint: disable=no-member
diff --git a/src/script/printer/frame.cc b/src/script/printer/frame.cc
new file mode 100644
index 0000000000..b342c7c886
--- /dev/null
+++ b/src/script/printer/frame.cc
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/script/printer/frame.h>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+MetadataFrame::MetadataFrame() : MetadataFrame(make_object<MetadataFrameNode>()) {}
+
+VarDefFrame::VarDefFrame() : VarDefFrame(make_object<VarDefFrameNode>()) {}
+
+TVM_REGISTER_NODE_TYPE(FrameNode);
+TVM_REGISTER_GLOBAL("script.printer.FrameAddExitCallback")
+    .set_body_typed([](Frame frame, runtime::TypedPackedFunc<void()> callback) {
+      frame->AddExitCallback(callback);  // Stored until the frame's next ExitWithScope.
+    });
+TVM_REGISTER_GLOBAL("script.printer.FrameEnterWithScope")
+    .set_body_method<Frame>(&FrameNode::EnterWithScope);
+TVM_REGISTER_GLOBAL("script.printer.FrameExitWithScope")
+    .set_body_method<Frame>(&FrameNode::ExitWithScope);
+
+TVM_REGISTER_NODE_TYPE(MetadataFrameNode);
+TVM_REGISTER_GLOBAL("script.printer.MetadataFrame").set_body_typed([]() {
+  return MetadataFrame();
+});
+
+TVM_REGISTER_NODE_TYPE(VarDefFrameNode);
+TVM_REGISTER_GLOBAL("script.printer.VarDefFrame").set_body_typed([]() { return VarDefFrame(); });
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_printer_frame.py b/tests/python/unittest/test_tvmscript_printer_frame.py
new file mode 100644
index 0000000000..bd98d64456
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_printer_frame.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from tvm.script.printer.frame import MetadataFrame
+
+
+def test_frame_add_callback():
+    frame = MetadataFrame()
+
+    flag = 0
+
+    def callback1():
+        nonlocal flag
+        flag += 1
+
+    def callback2():
+        nonlocal flag
+        flag += 5
+
+    frame.add_exit_callback(callback1)
+    with frame:
+        frame.add_exit_callback(callback2)
+        assert flag == 0
+
+    assert flag == 6
+
+
+def test_frame_clear_callbacks_after_exit():
+    frame = MetadataFrame()
+
+    flag = 0
+
+    def callback():
+        nonlocal flag
+        flag += 1
+
+    frame.add_exit_callback(callback)
+
+    with frame:
+        pass
+
+    assert flag == 1
+
+    with frame:
+        pass
+
+    assert flag == 1

Reply via email to