This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 38c3ad4cad [TVMScript] Use default fallback for un-registered type (#14874)
38c3ad4cad is described below

commit 38c3ad4cad5dc7268aadc278ad2c695347d154f3
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed May 17 23:03:22 2023 -0500

    [TVMScript] Use default fallback for un-registered type (#14874)
    
    Previously, even if an object was supported by TVMScript, printing that
    object could fail if it contained references to unregistered
    objects.  In particular, the attributes of `PrimFunc` and
    `AttrStmt` may contain any `ObjectRef`, and are frequently used for
    external bookkeeping.  This commit adds a default printer for
    unregistered types in TVMScript, which emits such objects as
    metadata.  This makes the dedicated metadata printers for
    `relay::Executor`, `relay::Runtime`, and
    `relay::backend::FunctionInfo` redundant, so they are removed.
---
 include/tvm/script/printer/ir_docsifier_functor.h | 33 ++++++++++++++++
 src/script/printer/ir/relay.cc                    | 48 -----------------------
 src/script/printer/ir_docsifier.cc                |  5 +++
 3 files changed, 38 insertions(+), 48 deletions(-)

diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h
index 54810fd55a..e63c00f68a 100644
--- a/include/tvm/script/printer/ir_docsifier_functor.h
+++ b/include/tvm/script/printer/ir_docsifier_functor.h
@@ -23,6 +23,7 @@
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/packed_func.h>
 
+#include <optional>
 #include <string>
 #include <type_traits>
 #include <unordered_map>
@@ -69,6 +70,10 @@ class IRDocsifierFunctor {
     if ((pf = LookupDispatchTable("", type_index)) != nullptr) {
       return (*pf)(obj, args...);
     }
+    if ((pf = LookupFallback()) != nullptr) {
+      return (*pf)(obj, args...);
+    }
+
     LOG(WARNING) << "ObjectFunctor calls un-registered function on type: "
                  << runtime::Object::TypeIndex2Key(type_index) << " (token: " 
<< token << ")"
                  << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << 
obj;
@@ -100,6 +105,14 @@ class IRDocsifierFunctor {
     return *this;
   }
 
+  TSelf& set_fallback(runtime::PackedFunc f) {
+    ICHECK(!dispatch_fallback_.has_value()) << "Fallback is already defined";
+    dispatch_fallback_ = f;
+    return *this;
+  }
+
+  void remove_fallback() { dispatch_fallback_ = std::nullopt; }
+
   /*!
    * \brief Set the dispatch function
    * \param token The dispatch token.
@@ -112,6 +125,13 @@ class IRDocsifierFunctor {
                         runtime::TypedPackedFunc<R(TObjectRef, Args...)>(f));
   }
 
+  template <typename TCallable,
+            typename = std::enable_if_t<IsDispatchFunction<ObjectRef, 
TCallable>::value>>
+  TSelf& set_fallback(TCallable f) {
+    runtime::PackedFunc func = runtime::TypedPackedFunc<R(ObjectRef, 
Args...)>(f);
+    return set_fallback(func);
+  }
+
   /*!
    * \brief Remove dispatch function
    * \param token The dispatch token.
@@ -151,6 +171,18 @@ class IRDocsifierFunctor {
       return nullptr;
     }
   }
+
+  /*!
+   * \brief Look up the fallback to be used if no handler is registered
+   */
+  const runtime::PackedFunc* LookupFallback() const {
+    if (dispatch_fallback_.has_value()) {
+      return &*dispatch_fallback_;
+    } else {
+      return nullptr;
+    }
+  }
+
   /*
    * This type alias and the following free functions are created to reduce 
the binary bloat
    * from template and also hide implementation details from this header
@@ -158,6 +190,7 @@ class IRDocsifierFunctor {
   using DispatchTable = std::unordered_map<std::string, 
std::vector<runtime::PackedFunc>>;
   /*! \brief The dispatch table. */
   DispatchTable dispatch_table_;
+  std::optional<runtime::PackedFunc> dispatch_fallback_;
 };
 
 }  // namespace printer
diff --git a/src/script/printer/ir/relay.cc b/src/script/printer/ir/relay.cc
deleted file mode 100644
index 574c07e32a..0000000000
--- a/src/script/printer/ir/relay.cc
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include <tvm/relay/executor.h>
-#include <tvm/relay/runtime.h>
-
-#include "../../../relay/backend/utils.h"
-#include "./utils.h"
-
-namespace tvm {
-namespace script {
-namespace printer {
-
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-    .set_dispatch<relay::Executor>("", [](relay::Executor ty, ObjectPath p, 
IRDocsifier d) -> Doc {
-      return d->AddMetadata(ty);
-    });
-
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-    .set_dispatch<relay::Runtime>("", [](relay::Runtime ty, ObjectPath p, 
IRDocsifier d) -> Doc {
-      return d->AddMetadata(ty);
-    });
-
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-    .set_dispatch<relay::backend::FunctionInfo>("",
-                                                
[](relay::backend::FunctionInfo ty, ObjectPath p,
-                                                   IRDocsifier d) -> Doc {
-                                                  return d->AddMetadata(ty);
-                                                });
-
-}  // namespace printer
-}  // namespace script
-}  // namespace tvm
diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc
index fd5003073a..62084d17be 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -160,6 +160,11 @@ IRDocsifier::FType& IRDocsifier::vtable() {
 TVM_REGISTER_NODE_TYPE(FrameNode);
 TVM_REGISTER_NODE_TYPE(IRDocsifierNode);
 
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_fallback([](ObjectRef obj, ObjectPath p, IRDocsifier d) -> Doc {
+      return d->AddMetadata(obj);
+    });
+
 }  // namespace printer
 }  // namespace script
 }  // namespace tvm

Reply via email to