This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d33a332283 [TVMScript] Printer VarTable (#12336)
d33a332283 is described below

commit d33a332283e7349088ba7aa57e56534518bb088b
Author: Lite Ye <[email protected]>
AuthorDate: Sat Aug 13 16:04:39 2022 -0400

    [TVMScript] Printer VarTable (#12336)
    
    This PR:
    
    - Adds VarTable for the new TVMScript Printer
    
    Compared to the prototype version, this:
    
    - Removes unnecessary public methods.
      - GetObjectName
      - GetUniqueName
    - Adds a Frame parameter to the `Define` methods. VarTable adds a callback to the Frame that removes the variable when the Frame exits.
    - Changes DocFactory from `ExprDoc(ObjectPath)` to `ExprDoc()` to simplify variable definition.
    
    Tracking issue: https://github.com/apache/tvm/issues/11912
---
 include/tvm/script/printer/var_table.h             | 144 +++++++++++++++++++
 python/tvm/script/printer/var_table.py             | 118 +++++++++++++++
 src/script/printer/var_table.cc                    | 108 ++++++++++++++
 tests/cpp/tvmscript_printer_var_table_test.cc      | 158 +++++++++++++++++++++
 .../unittest/test_tvmscript_printer_var_table.py   |  89 ++++++++++++
 5 files changed, 617 insertions(+)

diff --git a/include/tvm/script/printer/var_table.h b/include/tvm/script/printer/var_table.h
new file mode 100644
index 0000000000..9300a976c5
--- /dev/null
+++ b/include/tvm/script/printer/var_table.h
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_VAR_TABLE_H_
+#define TVM_SCRIPT_PRINTER_VAR_TABLE_H_
+
+#include <tvm/node/node.h>
+#include <tvm/node/object_path.h>
+#include <tvm/script/printer/doc.h>
+#include <tvm/script/printer/frame.h>
+#include <tvm/script/printer/traced_object.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Variable Table manages mapping from variable object to ExprDoc during
+ * the process of printing TVMScript.
+ *
+ * The value type of this map is ExprDoc rather than IdDoc or String. It's
+ * because variables can be implicitly defined. For example in TIR buffer (tir::Buffer),
+ * `buf->data` is a variable, while its representation in TVMScript should be an
+ * expression `x.data`, where `x` is the variable for the buffer itself.
+ *
+ * Every definition is tied to a Frame: the table registers an exit callback on that
+ * Frame which removes the variable (and releases its name for reuse) when the Frame exits.
+ *
+ * NOTE(review): the exit callback stores a raw pointer to this node, so the table is
+ * expected to outlive every Frame passed to `Define`/`DefineByDoc`.
+ */
+class VarTableNode : public Object {
+ public:
+  /*! \brief No attributes are exposed for reflection. */
+  void VisitAttrs(AttrVisitor*) {}
+
+  /*!
+   * \brief Define variable by name.
+   * \param obj The variable object.
+   * \param name_hint The hint for variable name.
+   * \param object_path The object path recorded in the `source_paths` of the returned IdDoc.
+   * \param frame The frame that this variable is defined in.
+   *
+   * \return The id doc for this variable.
+   *
+   * This function will rename the variable to avoid name conflict with other variables
+   * in the table (by appending `_1`, `_2`, ... to `name_hint`). Defining the same `obj`
+   * twice is an error.
+   */
+  IdDoc Define(const ObjectRef& obj, const String& name_hint, const ObjectPath& object_path,
+               const Frame& frame);
+
+  /*!
+   * \brief Define variable by name.
+   * \param obj The variable object.
+   * \param name_hint The hint for variable name, traced together with its object path.
+   * \param frame The frame that this variable is defined in.
+   *
+   * \return The id doc for this variable.
+   *
+   * This is a shortcut version of `Define` which accepts a traced string; the path of
+   * `name_hint` is used as the object path of the returned IdDoc.
+   */
+  IdDoc Define(const ObjectRef& obj, const TracedObject<String>& name_hint, const Frame& frame) {
+    return Define(obj, name_hint.Get(), name_hint.GetPath(), frame);
+  }
+
+  /*! \brief Factory that builds a fresh ExprDoc for a variable on every call. */
+  using DocFactory = std::function<ExprDoc()>;
+
+  /*!
+   * \brief Define variable by doc factory.
+   * \param obj The variable object.
+   * \param doc_factory The function to return an ExprDoc object for this variable.
+   * \param frame The frame that this variable is defined in.
+   *
+   * This function is a special form of `Define`. Variable is mapped to ExprDoc rather
+   * than IdDoc. It's useful when a variable is implicitly defined without a name, like
+   * the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc("<buffer_name>"), "data")`.
+   *
+   * This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to
+   * return a new Doc object every time it's called, as the returned doc will have
+   * different `source_paths`. Currently there isn't a good way to deep copy a TVMObject
+   * so VarTable needs to call a factory function to get a freshly-constructed Doc object
+   * every time GetVarDoc is called.
+   *
+   * It is an error to define the same `obj` twice, or to pass a factory whose doc is an
+   * IdDoc (use `Define` for that case). The factory is invoked once here to perform that
+   * check.
+   */
+  void DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame);
+
+  /*!
+   * \brief Get the doc for variable.
+   * \param obj The variable object.
+   * \param object_path The object path appended to the `source_paths` of the returned doc.
+   *
+   * \return A freshly-built doc for the variable, if it exists in the table. Otherwise it
+   * returns NullOpt.
+   */
+  Optional<ExprDoc> GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const;
+
+  /*!
+   * \brief Check if a variable exists in the table.
+   * \param obj The variable object.
+   *
+   * \return a boolean for whether variable exists.
+   */
+  bool IsVarDefined(const ObjectRef& obj) const;
+
+  static constexpr const char* _type_key = "script.printer.VarTable";
+  TVM_DECLARE_FINAL_OBJECT_INFO(VarTableNode, Object);
+
+ private:
+  /*!
+   * \brief Remove a variable and release its name (if it has one).
+   * \param obj The variable object; it must currently be in the table.
+   *
+   * Invoked from the Frame exit callbacks registered by `Define` and `DefineByDoc`.
+   */
+  void RemoveVar(const ObjectRef& obj);
+
+  /*! \brief Bookkeeping kept for every defined variable. */
+  struct VariableInfo {
+    /*! \brief Builds the doc returned by GetVarDoc. */
+    DocFactory doc_factory;
+    /*! \brief The unique name reserved by `Define`; NullOpt for `DefineByDoc` variables. */
+    Optional<String> name;
+  };
+  /*! \brief Map from variable object (keyed by object identity) to its bookkeeping. */
+  std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
+  /*! \brief Names currently reserved by variables defined through `Define`. */
+  std::unordered_set<String> defined_names;
+};
+
+/*!
+ * \brief Reference type of VarTableNode.
+ *
+ * The handle is declared non-nullable and mutable, so `Define`/`DefineByDoc` can be
+ * called directly through it (e.g. `vars->Define(...)`).
+ */
+class VarTable : public ObjectRef {
+ public:
+  /*!
+   * \brief Create an empty VarTable.
+   */
+  VarTable();
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarTable, ObjectRef, VarTableNode);
+};
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_PRINTER_VAR_TABLE_H_
diff --git a/python/tvm/script/printer/var_table.py b/python/tvm/script/printer/var_table.py
new file mode 100644
index 0000000000..ea1fa41b32
--- /dev/null
+++ b/python/tvm/script/printer/var_table.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Functions to print doc into text format"""
+
+from typing import Callable, Optional
+
+from tvm._ffi import register_object
+from tvm.runtime import Object, ObjectPath
+
+from . import _ffi_api
+from .doc import ExprDoc, IdDoc
+from .frame import Frame
+
+
+@register_object("script.printer.VarTable")
+class VarTable(Object):
+    """
+    Variable Table manages mapping from variable object to ExprDoc during
+    the process of printing TVMScript.
+    """
+
+    def __init__(self):
+        """
+        Create an empty VarTable.
+        """
+        self.__init_handle_by_constructor__(_ffi_api.VarTable)  # type: ignore 
# pylint: disable=no-member
+
+    def define(self, obj: Object, name_hint: str, object_path: ObjectPath, 
frame: Frame) -> IdDoc:
+        """
+        Define a variable by name.
+
+        Parameters
+        ----------
+        obj : Object
+            The variable object.
+        name_hint : str
+            The hint for variable name.
+        object_path : ObjectPath
+            The object path to be associated with the returned ExprDoc.
+        frame : Frame
+            Then frame that this variable is defined in.
+
+        Returns
+        -------
+        doc : IdDoc
+            The doc for this variable.
+        """
+        return _ffi_api.VarTableDefine(self, obj, name_hint, object_path, 
frame)  # type: ignore # pylint: disable=no-member
+
+    def define_by_doc(self, obj: Object, doc_factory: Callable[[], ExprDoc], 
frame: Frame) -> None:
+        """
+        Define a variable by ExprDoc.
+
+        Parameters
+        ----------
+        obj : Object
+            The variable object.
+        doc_factory : Callable[[], ExprDoc]
+            The hint for variable name.
+        frame : Frame
+            Then frame that this variable is defined in.
+
+        Returns
+        -------
+        None
+        """
+        _ffi_api.VarTableDefineByDoc(self, obj, doc_factory, frame)  # type: 
ignore # pylint: disable=no-member
+
+    def get_var_doc(self, obj: Object, object_path: ObjectPath) -> 
Optional[ExprDoc]:
+        """
+        Get the doc for a variable.
+
+        Parameters
+        ----------
+        obj : Object
+            The variable object.
+        object_path : ObjectPath
+            The object path to be associated with the returned ExprDoc.
+
+        Returns
+        -------
+        doc : ExprDoc
+            The doc for this variable.
+        """
+        return _ffi_api.VarTableGetVarDoc(self, obj, object_path)  # type: 
ignore # pylint: disable=no-member
+
+    def is_var_defined(self, obj: Object) -> bool:
+        """
+        Check whether a variable is defined.
+
+        Parameters
+        ----------
+        obj : Object
+            The variable object.
+
+        Returns
+        -------
+        is_defined : bool
+            Whether the variable is defined.
+        """
+        return _ffi_api.VarTableIsVarDefined(self, obj)  # type: ignore # 
pylint: disable=no-member
+
+    def __contains__(self, obj: Object) -> bool:
+        return self.is_var_defined(obj)
diff --git a/src/script/printer/var_table.cc b/src/script/printer/var_table.cc
new file mode 100644
index 0000000000..49ba93f9bc
--- /dev/null
+++ b/src/script/printer/var_table.cc
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/node/object_path.h>
+#include <tvm/runtime/container/optional.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/script/printer/var_table.h>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Reserve a name in `defined_names` that does not clash with any name already there.
+ * \param name_hint The preferred name; used as-is when it is not taken yet.
+ * \param defined_names The set of reserved names; the returned name is inserted into it.
+ * \return `name_hint` if it was free, otherwise the first free name among
+ *         `name_hint_1`, `name_hint_2`, ...
+ *
+ * Internal linkage: this helper is an implementation detail of this translation unit.
+ */
+static String GenerateUniqueName(const String& name_hint,
+                                 std::unordered_set<String>* defined_names) {
+  String name = name_hint;
+  // `insert(...).second` is false while the candidate is already reserved.
+  for (int i = 1; !defined_names->insert(name).second; ++i) {
+    name = name_hint + "_" + std::to_string(i);
+  }
+  return name;
+}
+
+/*
+ * Define `obj` under a unique name derived from `name_hint` and return its IdDoc,
+ * tagged with `object_path`. The variable is removed again when `frame` exits.
+ */
+IdDoc VarTableNode::Define(const ObjectRef& obj, const String& name_hint,
+                           const ObjectPath& object_path, const Frame& frame) {
+  // Reject duplicates *before* reserving a name: otherwise a failed check would leave the
+  // freshly generated name stuck in `defined_names`, and later definitions using the same
+  // hint would be renamed for no reason.
+  ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
+
+  String name = GenerateUniqueName(name_hint, &this->defined_names);
+  DocFactory doc_factory = [name]() { return IdDoc(name); };
+  obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}});
+
+  IdDoc def_doc(name);
+  def_doc->source_paths.push_back(object_path);
+
+  // NOTE(review): captures a raw `this`; the table must outlive `frame`.
+  frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
+
+  return def_doc;
+}
+
+/*
+ * Map `obj` to the doc produced by `doc_factory` (a non-IdDoc ExprDoc). No name is
+ * reserved. The variable is removed again when `frame` exits.
+ */
+void VarTableNode::DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame) {
+  ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
+  // Calling an empty std::function would throw std::bad_function_call with no context.
+  ICHECK(doc_factory != nullptr)
+      << "VarTableNode::DefineByDoc requires a non-empty doc factory.";
+
+  // Invoke the factory once up front: variables mapped to a plain IdDoc must go through
+  // `Define`, which also reserves the name.
+  ICHECK(!doc_factory()->IsInstance<IdDocNode>())
+      << "VarTableNode::DefineByDoc cannot be used for variable that's mapped to IdDoc.";
+
+  obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}});
+
+  // NOTE(review): captures a raw `this`; the table must outlive `frame`.
+  frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
+}
+
+/*
+ * Look up `obj` and, when it is defined, build a new doc for it whose source paths
+ * record `object_path`. Returns NullOpt for unknown variables.
+ */
+Optional<ExprDoc> VarTableNode::GetVarDoc(const ObjectRef& obj,
+                                          const ObjectPath& object_path) const {
+  const auto found = obj2info.find(obj);
+  if (found != obj2info.end()) {
+    // The factory is expected to hand back a fresh object (see the DocFactory contract),
+    // so appending this lookup's path does not affect docs returned earlier.
+    ExprDoc var_doc = found->second.doc_factory();
+    var_doc->source_paths.push_back(object_path);
+    return var_doc;
+  }
+  return NullOpt;
+}
+
+/* A variable counts as defined for as long as it has an entry in `obj2info`. */
+bool VarTableNode::IsVarDefined(const ObjectRef& obj) const {
+  return obj2info.find(obj) != obj2info.end();
+}
+
+/*
+ * Drop `obj` from the table (it must be present) and, if it was defined by name,
+ * release that name so it can be reused.
+ */
+void VarTableNode::RemoveVar(const ObjectRef& obj) {
+  auto entry = obj2info.find(obj);
+  ICHECK(entry != obj2info.end()) << "No such object: " << obj;
+
+  const Optional<String>& reserved_name = entry->second.name;
+  if (reserved_name.defined()) {
+    defined_names.erase(reserved_name.value());
+  }
+  obj2info.erase(entry);
+}
+
+VarTable::VarTable() {
+  ObjectPtr<VarTableNode> node = make_object<VarTableNode>();
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(VarTableNode);
+
+// FFI entry points used by python/tvm/script/printer/var_table.py.
+TVM_REGISTER_GLOBAL("script.printer.VarTable").set_body_typed([]() { return VarTable(); });
+TVM_REGISTER_GLOBAL("script.printer.VarTableDefine")
+    .set_body_typed([](VarTable var_table, const ObjectRef& obj, const String& name_hint,
+                       const ObjectPath& object_path, const Frame& frame) {
+      // Calls the (obj, String, ObjectPath, Frame) overload of Define.
+      return var_table->Define(obj, name_hint, object_path, frame);
+    });
+TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc")
+    .set_body_typed([](VarTable var_table, const ObjectRef& obj, runtime::PackedFunc factory,
+                       Frame frame) {
+      // Adapt the packed (e.g. Python) callable to the C++ DocFactory signature.
+      VarTableNode::DocFactory doc_factory = [f = std::move(factory)]() -> ExprDoc {
+        return f();
+      };
+      var_table->DefineByDoc(obj, std::move(doc_factory), frame);
+    });
+TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc")
+    .set_body_typed([](VarTable var_table, const ObjectRef& obj, const ObjectPath& object_path) {
+      return var_table->GetVarDoc(obj, object_path);
+    });
+TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined")
+    .set_body_typed([](VarTable var_table, const ObjectRef& obj) {
+      return var_table->IsVarDefined(obj);
+    });
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
diff --git a/tests/cpp/tvmscript_printer_var_table_test.cc b/tests/cpp/tvmscript_printer_var_table_test.cc
new file mode 100644
index 0000000000..b447c81ac0
--- /dev/null
+++ b/tests/cpp/tvmscript_printer_var_table_test.cc
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <tvm/node/object_path.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/script/printer/frame.h>
+#include <tvm/script/printer/var_table.h>
+#include <tvm/tir/var.h>
+
+using namespace tvm;
+using namespace tvm::script::printer;
+
+TEST(PrinterVarTableTest, Define) {
+  // A freshly defined variable keeps its name hint, and lookups return an IdDoc with it.
+  VarTable vars;
+  MetadataFrame frame;
+  tir::Var x("x");
+  ObjectPath object_path = ObjectPath::Root();
+
+  IdDoc doc = vars->Define(x, "x", object_path, frame);
+  EXPECT_EQ(doc->name, "x");
+
+  Optional<ExprDoc> looked_up = vars->GetVarDoc(x, object_path);
+  ASSERT_TRUE(looked_up.defined());
+  EXPECT_EQ(Downcast<IdDoc>(looked_up.value())->name, "x");
+}
+
+TEST(PrinterVarTableTest, DefineByDoc) {
+  // A variable defined through a factory is rendered with the factory's doc.
+  VarTable vars;
+  MetadataFrame frame;
+  tir::Var x("x");
+  ObjectPath object_path = ObjectPath::Root();
+
+  auto doc_factory = []() { return LiteralDoc::Str("x"); };
+  vars->DefineByDoc(x, doc_factory, frame);
+
+  Optional<ExprDoc> doc = vars->GetVarDoc(x, object_path);
+  ASSERT_TRUE(doc.defined());
+  EXPECT_EQ(Downcast<String>(Downcast<LiteralDoc>(doc.value())->value), "x");
+}
+
+TEST(PrinterVarTableTest, GetVarDocWithUnknownVariable) {
+  // Looking up a variable that was never defined yields NullOpt.
+  VarTable vars;
+  MetadataFrame frame;
+  tir::Var x("x");
+  tir::Var y("y");
+  ObjectPath object_path = ObjectPath::Root();
+
+  vars->Define(x, "x", object_path, frame);
+  EXPECT_FALSE(vars->GetVarDoc(y, object_path).defined());
+}
+
+TEST(PrinterVarTableTest, GetVarDocWithObjectPath) {
+  // The definition doc carries the definition path; each lookup carries the lookup path.
+  VarTable vars;
+  MetadataFrame frame;
+  tir::Var x("x");
+  ObjectPath object_path = ObjectPath::Root();
+  ObjectPath second_object_path = ObjectPath::Root()->Attr("x");
+
+  IdDoc doc = vars->Define(x, "x", object_path, frame);
+  // Check the size first so that indexing below cannot go out of bounds.
+  ASSERT_EQ(doc->source_paths.size(), 1U);
+  EXPECT_EQ(doc->source_paths[0], object_path);
+
+  Doc second_doc = vars->GetVarDoc(x, second_object_path).value();
+  ASSERT_EQ(second_doc->source_paths.size(), 1U);
+  EXPECT_EQ(second_doc->source_paths[0], second_object_path);
+}
+
+TEST(PrinterVarTableTest, IsVarDefined) {
+  // Only variables that were defined are reported as defined.
+  VarTable vars;
+  MetadataFrame frame;
+  tir::Var x("x");
+  tir::Var y("y");
+  ObjectPath object_path = ObjectPath::Root();
+
+  vars->Define(x, "x", object_path, frame);
+  EXPECT_TRUE(vars->IsVarDefined(x));
+  EXPECT_FALSE(vars->IsVarDefined(y));
+}
+
+TEST(PrinterVarTableTest, VarRemovedAfterFrameOutOfScope) {
+  // Exiting the frame a variable was defined in removes it from the table.
+  VarTable vars;
+  MetadataFrame frame;
+  tir::Var x("x");
+  ObjectPath object_path = ObjectPath::Root();
+
+  vars->Define(x, "x", object_path, frame);
+  ASSERT_TRUE(vars->IsVarDefined(x));
+
+  frame->ExitWithScope();
+  EXPECT_FALSE(vars->IsVarDefined(x));
+  EXPECT_FALSE(vars->GetVarDoc(x, object_path).defined());
+}
+
+TEST(PrinterVarTableTest, DefineDuplicateName) {
+  // A second variable with the same name hint is renamed with a numeric suffix.
+  VarTable vars;
+  MetadataFrame frame;
+  tir::Var x("x");
+  tir::Var y("y");
+  ObjectPath object_path = ObjectPath::Root();
+
+  IdDoc x_doc = vars->Define(x, "x", object_path, frame);
+  IdDoc y_doc = vars->Define(y, "x", object_path, frame);
+
+  EXPECT_NE(x_doc->name, y_doc->name);
+  EXPECT_EQ(x_doc->name, "x");
+  EXPECT_EQ(y_doc->name, "x_1");
+}
+
+TEST(PrinterVarTableTest, DefineDuplicateVariable) {
+  // Defining the same object twice is rejected.
+  VarTable vars;
+  MetadataFrame frame;
+  tir::Var x("x");
+  ObjectPath object_path = ObjectPath::Root();
+
+  vars->Define(x, "x", object_path, frame);
+  EXPECT_ANY_THROW(vars->Define(x, "x", object_path, frame));
+}
+
+TEST(PrinterVarTableTest, DefineByDocWithIdDoc) {
+  VarTable vars;
+  MetadataFrame frame;
+  tir::Var x("x");
+
+  // User has to use `Define` if variable needs to be mapped to IdDoc
+  EXPECT_ANY_THROW(vars->DefineByDoc(x, []() { return IdDoc("x"); }, frame));
+}
diff --git a/tests/python/unittest/test_tvmscript_printer_var_table.py b/tests/python/unittest/test_tvmscript_printer_var_table.py
new file mode 100644
index 0000000000..eab63a08dd
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_printer_var_table.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This file tests the FFI binding of script.printer.VarTable.
+These only make sure parameter can be passed to the C++ functions
+correctly. The test for the functionality of VarTable is in C++.
+"""
+
+from tvm.runtime import ObjectPath
+from tvm.script.printer.doc import LiteralDoc
+from tvm.script.printer.frame import VarDefFrame
+from tvm.script.printer.var_table import VarTable
+from tvm.tir import Var
+
+
+def test_define():
+    var_table = VarTable()
+    var_name = "a"
+    var_obj = Var(var_name, dtype="int32")
+    object_path = ObjectPath.root().attr("a")
+    frame = VarDefFrame()
+
+    id_doc = var_table.define(var_obj, var_name, object_path, frame)
+
+    assert id_doc.name == "a"
+    assert list(id_doc.source_paths) == [object_path]
+
+    id_doc = var_table.get_var_doc(var_obj, object_path)
+
+    assert id_doc.name == "a"
+    assert list(id_doc.source_paths) == [object_path]
+
+
+def test_define_by_doc():
+    var_table = VarTable()
+    var_name = "a"
+    var_obj = Var(var_name, dtype="int32")
+    object_path = ObjectPath.root().attr("a")
+    frame = VarDefFrame()
+
+    var_table.define_by_doc(var_obj, lambda: LiteralDoc(var_name), frame)
+
+    var_doc = var_table.get_var_doc(var_obj, object_path)
+
+    assert isinstance(var_doc, LiteralDoc)
+    assert var_doc.value == var_name
+    assert list(var_doc.source_paths) == [object_path]
+
+
+def test_is_var_defined():
+    var_table = VarTable()
+    a = Var("a", dtype="int32")
+    object_path = ObjectPath.root().attr("a")
+    frame = VarDefFrame()
+
+    var_table.define(a, "a", object_path, frame)
+
+    assert var_table.is_var_defined(a)
+    assert a in var_table
+
+
+def test_var_out_of_scope():
+    var_table = VarTable()
+    var_name = "a"
+    var_obj = Var(var_name, dtype="int32")
+    object_path = ObjectPath.root().attr("a")
+    frame = VarDefFrame()
+
+    var_table.define(var_obj, var_name, object_path, frame)
+
+    with frame:
+        assert var_obj in var_table
+
+    assert var_obj not in var_table
+    assert var_table.get_var_doc(var_obj, object_path) is None

Reply via email to