This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3cb57d20a0 [TVMScript] Printer Frame (#12366)
3cb57d20a0 is described below
commit 3cb57d20a0ad3ceffdddc060d1188f4921eb8722
Author: Lite Ye <[email protected]>
AuthorDate: Thu Aug 11 23:22:37 2022 -0400
[TVMScript] Printer Frame (#12366)
This PR:
- Implement Frame for the TVMScript Unified Printer
Compared to the prototype version, this:
- Removes the dependency of VarTable (SymbolTable) from Frame
- Adds a callback array to the Frame base class so that VarTable can add
callbacks to clean up variables when Frame goes out of scope
Tracking issue: https://github.com/apache/tvm/issues/11912
---
include/tvm/script/printer/frame.h | 140 +++++++++++++++++++++
python/tvm/script/printer/frame.py | 81 ++++++++++++
src/script/printer/frame.cc | 50 ++++++++
.../unittest/test_tvmscript_printer_frame.py | 60 +++++++++
4 files changed, 331 insertions(+)
diff --git a/include/tvm/script/printer/frame.h
b/include/tvm/script/printer/frame.h
new file mode 100644
index 0000000000..407ad16007
--- /dev/null
+++ b/include/tvm/script/printer/frame.h
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_FRAME_H_
+#define TVM_SCRIPT_PRINTER_FRAME_H_
+
+#include <tvm/node/node.h>
+#include <tvm/script/printer/doc.h>
+
+#include <functional>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Frame is the core data structure for semantic information
+ * when printing IR graph into TVMScript code.
+ *
+ * The base class keeps a list of exit callbacks. ExitWithScope() runs them in
+ * registration order and returns with no callbacks left registered.
+ */
+class FrameNode : public Object {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  virtual ~FrameNode() = default;
+
+  /*!
+   * \brief Add a callback function to be called when this frame exits.
+   * \param cb The callback function. It should have signature void().
+   */
+  template <typename TCallback>
+  void AddExitCallback(TCallback&& cb) {
+    callbacks_.emplace_back(std::forward<TCallback>(cb));
+  }
+
+  /*!
+   * \brief Method that's called when Frame enters the scope.
+   */
+  virtual void EnterWithScope() {}
+
+  /*!
+   * \brief Method that's called when Frame exits the scope.
+   *
+   * Runs every registered exit callback once, in registration order, and
+   * returns with no callbacks left registered.
+   *
+   * \note Pending callbacks are moved out of `callbacks_` before any of them is
+   * invoked. Iterating `callbacks_` directly would be undefined behavior if a
+   * callback called AddExitCallback() on this frame: the push could reallocate
+   * the vector underneath the running loop. Callbacks registered while exiting
+   * are run by a later pass of the loop below. If a callback throws, the rest
+   * of its pass is discarded, and callbacks that already ran are not re-run by
+   * a later exit.
+   */
+  virtual void ExitWithScope() {
+    while (!callbacks_.empty()) {
+      std::vector<std::function<void()>> pending;
+      pending.swap(callbacks_);
+      for (const std::function<void()>& callback : pending) {
+        callback();
+      }
+    }
+  }
+
+  static constexpr const char* _type_key = "script.printer.Frame";
+  TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object);
+
+ private:
+  /*! \brief Callbacks invoked, in order, by ExitWithScope(). */
+  std::vector<std::function<void()>> callbacks_;
+};
+
+/*!
+ * \brief Managed reference to FrameNode.
+ *
+ * The default constructor is protected, so Frame handles are created through
+ * subclasses such as MetadataFrame and VarDefFrame.
+ */
+class Frame : public ObjectRef {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode);
+  virtual ~Frame() = default;
+
+ protected:
+  Frame() = default;
+};
+
+/*!
+ * \brief MetadataFrame contains information like constant parameter arrays.
+ */
+class MetadataFrameNode : public FrameNode {
+ public:
+  static constexpr const char* _type_key = "script.printer.MetadataFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MetadataFrameNode, FrameNode);
+
+  /*! \brief Metadata objects held by this frame (e.g. constant parameter arrays). */
+  Array<ObjectRef> metadata;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // Base-class attributes first, then this node's own field.
+    FrameNode::VisitAttrs(v);
+    v->Visit("metadata", &metadata);
+  }
+};
+
+/*!
+ * \brief Managed reference to MetadataFrameNode.
+ */
+class MetadataFrame : public Frame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetadataFrame, Frame, MetadataFrameNode);
+
+  /*! \brief Create a handle backed by a newly allocated MetadataFrameNode. */
+  MetadataFrame();
+};
+
+/*!
+ * \brief VarDefFrame contains information about the free variables that need to be
+ * defined at the beginning of the printed snippet.
+ */
+class VarDefFrameNode : public FrameNode {
+ public:
+  static constexpr const char* _type_key = "script.printer.VarDefFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(VarDefFrameNode, FrameNode);
+
+  /*! \brief Statement docs held by this frame; presumably the definitions of those free variables. */
+  Array<StmtDoc> stmts;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // Base-class attributes first, then this node's own field.
+    FrameNode::VisitAttrs(v);
+    v->Visit("stmts", &stmts);
+  }
+};
+
+/*!
+ * \brief Managed reference to VarDefFrameNode.
+ */
+class VarDefFrame : public Frame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarDefFrame, Frame, VarDefFrameNode);
+
+  /*! \brief Create a handle backed by a newly allocated VarDefFrameNode. */
+  VarDefFrame();
+};
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
+
+#endif // TVM_SCRIPT_PRINTER_FRAME_H_
diff --git a/python/tvm/script/printer/frame.py
b/python/tvm/script/printer/frame.py
new file mode 100644
index 0000000000..c967382b8b
--- /dev/null
+++ b/python/tvm/script/printer/frame.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Frame is the core data structure for semantic information when printing
+IR graph into TVMScript code.
+"""
+
+from typing import Callable, Sequence
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.script.printer.doc import StmtDoc
+
+from . import _ffi_api
+
+
+class Frame(Object):
+ """
+ Frame is the core data structure for semantic information
+ when printing IR graph into TVMScript code.
+
+ Frame base class manages a list of callbacks to be executed
+ when frame goes out of scope.
+ """
+
+ def add_exit_callback(self, callback: Callable[[], None]) -> None:
+ """
+ Adds a callback function to be executed when frame goes out of scope.
+
+ Parameters
+ ----------
+ callback : Callable[[], None]
+ The callback function.
+ """
+ _ffi_api.FrameAddExitCallback(self, callback) # type: ignore #
pylint: disable=no-member
+
+ def __enter__(self):
+ _ffi_api.FrameEnterWithScope(self) # type: ignore # pylint:
disable=no-member
+ return self
+
+ def __exit__(self, *exception_info):
+ _ffi_api.FrameExitWithScope(self) # type: ignore # pylint:
disable=no-member
+
+
+@register_object("script.printer.MetadataFrame")
+class MetadataFrame(Frame):
+ """
+ MetadataFrame contains information like contant parameter array.
+ """
+
+ metadata: Sequence[Object]
+
+ def __init__(self):
+ self.__init_handle_by_constructor__(_ffi_api.MetadataFrame) # type:
ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.VarDefFrame")
+class VarDefFrame(Frame):
+ """
+ VarDefFrame contains information about the free variables that needs to
+ be defined at the beginning of the printed snippet.
+ """
+
+ stmts: Sequence[StmtDoc]
+
+ def __init__(self):
+ self.__init_handle_by_constructor__(_ffi_api.VarDefFrame) # type:
ignore # pylint: disable=no-member
diff --git a/src/script/printer/frame.cc b/src/script/printer/frame.cc
new file mode 100644
index 0000000000..b342c7c886
--- /dev/null
+++ b/src/script/printer/frame.cc
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/script/printer/frame.h>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+// Each default constructor allocates a fresh node and stores it in the
+// (initially empty) handle inherited from ObjectRef.
+MetadataFrame::MetadataFrame() { data_ = make_object<MetadataFrameNode>(); }
+
+VarDefFrame::VarDefFrame() { data_ = make_object<VarDefFrameNode>(); }
+
+// Reflection / FFI registrations backing python/tvm/script/printer/frame.py.
+TVM_REGISTER_NODE_TYPE(FrameNode);
+
+TVM_REGISTER_GLOBAL("script.printer.FrameAddExitCallback")
+    .set_body_typed([](Frame frame, runtime::TypedPackedFunc<void()> callback) {
+      // `callback` is already a by-value copy owned by this lambda; move it into
+      // the frame instead of copying it (and bumping its refcount) again.
+      frame->AddExitCallback(std::move(callback));
+    });
+
+TVM_REGISTER_GLOBAL("script.printer.FrameEnterWithScope")
+    .set_body_method<Frame>(&FrameNode::EnterWithScope);
+
+TVM_REGISTER_GLOBAL("script.printer.FrameExitWithScope")
+    .set_body_method<Frame>(&FrameNode::ExitWithScope);
+
+TVM_REGISTER_NODE_TYPE(MetadataFrameNode);
+TVM_REGISTER_GLOBAL("script.printer.MetadataFrame").set_body_typed([]() {
+  return MetadataFrame();
+});
+
+TVM_REGISTER_NODE_TYPE(VarDefFrameNode);
+TVM_REGISTER_GLOBAL("script.printer.VarDefFrame").set_body_typed([]() {
+  return VarDefFrame();
+});
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_printer_frame.py
b/tests/python/unittest/test_tvmscript_printer_frame.py
new file mode 100644
index 0000000000..bd98d64456
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_printer_frame.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from tvm.script.printer.frame import MetadataFrame
+
+
+def test_frame_add_callback():
+ frame = MetadataFrame()
+
+ flag = 0
+
+ def callback1():
+ nonlocal flag
+ flag += 1
+
+ def callback2():
+ nonlocal flag
+ flag += 5
+
+ frame.add_exit_callback(callback1)
+ with frame:
+ frame.add_exit_callback(callback2)
+ assert flag == 0
+
+ assert flag == 6
+
+
+def test_frame_clear_callbacks_after_exit():
+ frame = MetadataFrame()
+
+ flag = 0
+
+ def callback():
+ nonlocal flag
+ flag += 1
+
+ frame.add_exit_callback(callback)
+
+ with frame:
+ pass
+
+ assert flag == 1
+
+ with frame:
+ pass
+
+ assert flag == 1