This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 684689e924 [Bugfix] [Relay] fix a bug of  printing dataflow pattern 
(#15350)
684689e924 is described below

commit 684689e9247630c5cce66363b9f877284d94523a
Author: 电线杆 <[email protected]>
AuthorDate: Mon Jul 24 09:29:18 2023 +0800

    [Bugfix] [Relay] fix a bug of  printing dataflow pattern (#15350)
    
    * fix bug of dataflow pattern print
    
    * fix linting problems
    
    * Improve documentation. Minor modification of the printed text.
---
 include/tvm/relay/dataflow_pattern.h |  41 +++++
 src/relay/ir/dataflow_pattern.cc     | 304 +++++++++++++++++++++++++++--------
 2 files changed, 280 insertions(+), 65 deletions(-)

diff --git a/include/tvm/relay/dataflow_pattern.h 
b/include/tvm/relay/dataflow_pattern.h
index 46abee5d44..8c30a0df9f 100644
--- a/include/tvm/relay/dataflow_pattern.h
+++ b/include/tvm/relay/dataflow_pattern.h
@@ -28,6 +28,8 @@
 #include <tvm/relay/type.h>
 
 #include <string>
+#include <unordered_map>
+#include <utility>
 #include <vector>
 
 namespace tvm {
@@ -537,6 +539,45 @@ DFPattern IsTuple(const Array<DFPattern>& fields);
 /*! \brief Syntatic Sugar for creating a TupleGetItemPattern*/
 DFPattern IsTupleGetItem(const DFPattern tuple, int index = -1);
 
+/*! \brief A printer class to print pattern. */
+class DFPatternPrinter : public ReprPrinter {
+ public:
+  std::stringstream string_stream{};
+
+  std::unordered_map<DFPattern, std::pair<size_t, std::string>, ObjectPtrHash, 
ObjectPtrEqual>
+      memo_{};
+  /*! \brief Subpatterns that are encountered more than once during printing. 
If a subpattern has
+   * already been printed, only the pattern ID will be printed in the next 
encounter of the same pattern.
+   * This avoids printing a subpattern infinitely many times if the considered 
pattern involves
+   * recursion.*/
+  std::vector<DFPattern> auxiliary_patterns{};
+
+  DFPatternPrinter(std::ostream& stream)  // NOLINT(*)
+      : ReprPrinter(stream) {}
+  TVM_DLL void Print(const ObjectRef& node);
+  using FType = NodeFunctor<void(const ObjectRef&, DFPatternPrinter*)>;
+  TVM_DLL static FType& vtable();
+};
+
+inline std::ostream& operator<<(std::ostream& os,
+                                const DFPattern& n) {  // NOLINT(*)
+  std::stringstream string_stream{}, tmp_stream{};
+  DFPatternPrinter printer{tmp_stream};
+  printer.Print(n);
+  string_stream << "Main pattern:" << std::endl;
+  string_stream << printer.string_stream.str();
+  string_stream << std::endl;
+  string_stream << "Auxiliary patterns:";
+  for (const DFPattern& pat : printer.auxiliary_patterns) {
+    string_stream << std::endl;
+    string_stream << printer.memo_[pat].second;
+  }
+  os << string_stream.str();
+  return os;
+}
+
+String PrettyPrint(const DFPattern& pattern);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_DATAFLOW_PATTERN_H_
diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc
index 1f5dba6aca..c141ca51ef 100644
--- a/src/relay/ir/dataflow_pattern.cc
+++ b/src/relay/ir/dataflow_pattern.cc
@@ -27,6 +27,39 @@
 namespace tvm {
 namespace relay {
 
+DFPatternPrinter::FType& DFPatternPrinter::vtable() {
+  static FType inst;
+  return inst;
+}
+
+String PrettyPrint(const DFPattern& pattern) {
+  std::stringstream string_stream{};
+  string_stream << pattern;
+  return string_stream.str();
+}
+
+// Prints `node` into string_stream; a pattern that was already visited prints as an id reference.
+void DFPatternPrinter::Print(const ObjectRef& node) {
+  string_stream.str("");
+  if (!node.defined()) {  // must precede the ICHECK below, which fails on an undefined node
+    string_stream << "(nullptr)";
+    return;
+  }
+  ICHECK(node.as<DFPatternNode>());
+  DFPattern pat = Downcast<DFPattern>(node);
+  static const FType& f = vtable();
+  if (memo_.find(pat) != memo_.end()) {
+    string_stream << "(invoke pattern id " << memo_[pat].first << ")";
+    auxiliary_patterns.push_back(pat);
+  } else if (f.can_dispatch(node)) {
+    memo_.insert({pat, {memo_.size(), ""}});  // register the id before recursing into children
+    f(node, this);
+    memo_[pat].second = string_stream.str();
+  } else {  // default value, output type key and addr.
+    string_stream << node->GetTypeKey() << "(" << node.get() << ")";
+  }
+}
+
 ExprPattern::ExprPattern(Expr expr) {
   ObjectPtr<ExprPatternNode> n = make_object<ExprPatternNode>();
   n->expr = std::move(expr);
@@ -39,10 +72,11 @@ 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ExprPattern").set_body_typed([](Expr
   return ExprPattern(e);
 });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<ExprPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const ExprPatternNode*>(ref.get());
-      p->Print(node->expr);
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<ExprPatternNode>([](const ObjectRef& ref, DFPatternPrinter* 
p) {
+      ExprPattern pattern = Downcast<ExprPattern>(ref);
+      p->string_stream.str("");
+      p->string_stream << pattern->expr;
     });
 
 VarPattern::VarPattern(String name_hint) {
@@ -57,10 +91,11 @@ 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern").set_body_typed([](Strin
   return VarPattern(name_hint);
 });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<VarPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const VarPatternNode*>(ref.get());
-      p->stream << "VarPattern(" << node->name_hint() << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<VarPatternNode>([](const ObjectRef& ref, DFPatternPrinter* 
p) {
+      VarPattern pattern = Downcast<VarPattern>(ref);
+      p->string_stream.str("");
+      p->string_stream << "VarPattern(" << pattern->name_hint() << ")";
     });
 
 TVM_REGISTER_NODE_TYPE(ConstantPatternNode);
@@ -70,9 +105,10 @@ 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ConstantPattern").set_body_typed([](
   return c;
 });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<ConstantPatternNode>([](const ObjectRef& ref, ReprPrinter* 
p) {
-      p->stream << "ConstantPattern()";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<ConstantPatternNode>([](const ObjectRef& ref, 
DFPatternPrinter* p) {
+      p->string_stream.str("");
+      p->string_stream << "ConstantPattern()";
     });
 
 CallPattern::CallPattern(DFPattern op, Array<DFPattern> args) {
@@ -87,10 +123,29 @@ TVM_REGISTER_NODE_TYPE(CallPatternNode);
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern")
     .set_body_typed([](DFPattern op, Array<DFPattern> args) { return 
CallPattern(op, args); });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<CallPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const CallPatternNode*>(ref.get());
-      p->stream << "CallPatternNode(" << node->op << ", " << node->args << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<CallPatternNode>([](const ObjectRef& ref, DFPatternPrinter* 
p) {
+      CallPattern pattern = Downcast<CallPattern>(ref);
+
+      p->Print(pattern->op);
+      std::string op_pattern_string{p->string_stream.str()};
+
+      std::vector<std::string> args_pattern_string{};
+      for (const DFPattern& arg : pattern->args) {
+        p->Print(arg);
+        args_pattern_string.push_back(p->string_stream.str());
+      }
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+      p->string_stream << "CallPatternNode(" << op_pattern_string << ", [";
+      for (size_t i = 0; i < args_pattern_string.size(); ++i) {
+        if (i != 0) {
+          p->string_stream << ", ";
+        }
+        p->string_stream << args_pattern_string[i];
+      }
+      p->string_stream << "])";
     });
 
 FunctionPattern::FunctionPattern(Array<DFPattern> params, DFPattern body) {
@@ -106,10 +161,31 @@ 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.FunctionPattern")
       return FunctionPattern(params, body);
     });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<FunctionPatternNode>([](const ObjectRef& ref, ReprPrinter* 
p) {
-      auto* node = static_cast<const FunctionPatternNode*>(ref.get());
-      p->stream << "FunctionPatternNode(" << node->params << ", " << 
node->body << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<FunctionPatternNode>([](const ObjectRef& ref, 
DFPatternPrinter* p) {
+      FunctionPattern pattern = Downcast<FunctionPattern>(ref);
+
+      std::vector<std::string> params_pattern_string{};
+      for (const DFPattern& param : pattern->params) {
+        p->Print(param);
+        params_pattern_string.push_back(p->string_stream.str());
+      }
+
+      p->Print(pattern->body);
+      std::string body_pattern_string{p->string_stream.str()};
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+
+      p->string_stream << "FunctionPatternNode([";
+      for (size_t i = 0; i < params_pattern_string.size(); ++i) {
+        if (i != 0) {
+          p->string_stream << ", ";
+        }
+        p->string_stream << params_pattern_string[i];
+      }
+      p->string_stream << "]";
+      p->string_stream << ", " << body_pattern_string << ")";
     });
 
 LetPattern::LetPattern(DFPattern var, DFPattern value, DFPattern body) {
@@ -127,11 +203,23 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.LetPattern")
       return LetPattern(var, value, body);
     });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<LetPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const LetPatternNode*>(ref.get());
-      p->stream << "LetPatternNode(" << node->var << ", " << node->value << ", 
" << node->body
-                << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<LetPatternNode>([](const ObjectRef& ref, DFPatternPrinter* 
p) {
+      LetPattern pattern = Downcast<LetPattern>(ref);
+
+      p->Print(pattern->var);
+      std::string var_pattern_string{p->string_stream.str()};
+
+      p->Print(pattern->value);
+      std::string value_pattern_string{p->string_stream.str()};
+
+      p->Print(pattern->body);
+      std::string body_pattern_string{p->string_stream.str()};
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+      p->string_stream << "LetPatternNode(" << var_pattern_string << ", " << 
value_pattern_string
+                       << ", " << body_pattern_string << ")";
     });
 
 IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern 
false_branch) {
@@ -149,11 +237,23 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.IfPattern")
       return IfPattern(cond, true_branch, false_branch);
     });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<IfPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const IfPatternNode*>(ref.get());
-      p->stream << "IfPattern(" << node->cond << ", " << node->true_branch << 
", "
-                << node->false_branch << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<IfPatternNode>([](const ObjectRef& ref, DFPatternPrinter* p) 
{
+      IfPattern pattern = Downcast<IfPattern>(ref);
+
+      p->Print(pattern->cond);
+      std::string cond_pattern_string{p->string_stream.str()};
+
+      p->Print(pattern->true_branch);
+      std::string true_branch_pattern_string{p->string_stream.str()};
+
+      p->Print(pattern->false_branch);
+      std::string false_branch_pattern_string{p->string_stream.str()};
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+      p->string_stream << "IfPattern(" << cond_pattern_string << ", " << 
true_branch_pattern_string
+                       << ", " << false_branch_pattern_string << ")";
     });
 
 TuplePattern::TuplePattern(tvm::Array<DFPattern> fields) {
@@ -167,10 +267,28 @@ TVM_REGISTER_NODE_TYPE(TuplePatternNode);
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TuplePattern")
     .set_body_typed([](tvm::Array<DFPattern> fields) { return 
TuplePattern(fields); });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<TuplePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const TuplePatternNode*>(ref.get());
-      p->stream << "TuplePattern(" << node->fields << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<TuplePatternNode>([](const ObjectRef& ref, DFPatternPrinter* 
p) {
+      TuplePattern pattern = Downcast<TuplePattern>(ref);
+
+      std::vector<std::string> fields_pattern_string{};
+      for (const DFPattern& field : pattern->fields) {
+        p->Print(field);
+        fields_pattern_string.push_back(p->string_stream.str());
+      }
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+      p->string_stream << "TuplePattern(";
+      p->string_stream << "[";
+      for (size_t i = 0; i < fields_pattern_string.size(); ++i) {
+        if (i != 0) {
+          p->string_stream << ", ";
+        }
+        p->string_stream << fields_pattern_string[i];
+      }
+      p->string_stream << "]";
+      p->string_stream << ")";
     });
 
 TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) {
@@ -185,10 +303,17 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode);
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TupleGetItemPattern")
     .set_body_typed([](DFPattern tuple, int index) { return 
TupleGetItemPattern(tuple, index); });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<TupleGetItemPatternNode>([](const ObjectRef& ref, 
ReprPrinter* p) {
-      auto* node = static_cast<const TupleGetItemPatternNode*>(ref.get());
-      p->stream << "TupleGetItemPatternNode(" << node->tuple << ", " << 
node->index << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<TupleGetItemPatternNode>([](const ObjectRef& ref, 
DFPatternPrinter* p) {
+      TupleGetItemPattern pattern = Downcast<TupleGetItemPattern>(ref);
+
+      p->Print(pattern->tuple);
+      std::string tuple_pattern_string{p->string_stream.str()};
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+      p->string_stream << "TupleGetItemPatternNode(";
+      p->string_stream << tuple_pattern_string << ", " << pattern->index << 
")";
     });
 
 AltPattern::AltPattern(DFPattern left, DFPattern right) {
@@ -203,10 +328,20 @@ TVM_REGISTER_NODE_TYPE(AltPatternNode);
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AltPattern")
     .set_body_typed([](DFPattern left, DFPattern right) { return 
AltPattern(left, right); });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<AltPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const AltPatternNode*>(ref.get());
-      p->stream << "AltPattern(" << node->left << " | " << node->right << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<AltPatternNode>([](const ObjectRef& ref, DFPatternPrinter* 
p) {
+      AltPattern pattern = Downcast<AltPattern>(ref);
+
+      p->Print(pattern->left);
+      std::string left_pattern_string{p->string_stream.str()};
+
+      p->Print(pattern->right);
+      std::string right_pattern_string{p->string_stream.str()};
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+      p->string_stream << "AltPattern(" << left_pattern_string << " | " << 
right_pattern_string
+                       << ")";
     });
 
 TVM_REGISTER_NODE_TYPE(WildcardPatternNode);
@@ -216,9 +351,10 @@ 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([](
   return w;
 });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<WildcardPatternNode>([](const ObjectRef& ref, ReprPrinter* 
p) {
-      p->stream << "*";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<WildcardPatternNode>([](const ObjectRef& ref, 
DFPatternPrinter* p) {
+      p->string_stream.str("");
+      p->string_stream << "*";
     });
 
 TypePattern::TypePattern(DFPattern pattern, Type type) {
@@ -233,10 +369,17 @@ TVM_REGISTER_NODE_TYPE(TypePatternNode);
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TypePattern")
     .set_body_typed([](DFPattern pattern, Type type) { return 
TypePattern(pattern, type); });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<TypePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const TypePatternNode*>(ref.get());
-      p->stream << "TypePattern(" << node->pattern << " has type " << 
node->type << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<TypePatternNode>([](const ObjectRef& ref, DFPatternPrinter* 
p) {
+      TypePattern pattern = Downcast<TypePattern>(ref);
+
+      p->Print(pattern->pattern);
+      std::string pattern_pattern_string{p->string_stream.str()};
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+      p->string_stream << "TypePattern(" << pattern_pattern_string << " has 
type " << pattern->type
+                       << ")";
     });
 
 ShapePattern::ShapePattern(DFPattern pattern, Array<PrimExpr> shape) {
@@ -253,10 +396,15 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ShapePattern")
       return ShapePattern(pattern, shape);
     });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<ShapePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const ShapePatternNode*>(ref.get());
-      p->stream << "ShapePattern(" << node->pattern << " has shape " << 
node->shape << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<ShapePatternNode>([](const ObjectRef& ref, DFPatternPrinter* p) {
+      ShapePattern pattern = Downcast<ShapePattern>(ref);
+      // Print the inner pattern first and save its text: string_stream is reused for this node.
+      p->Print(pattern->pattern);
+      std::string pattern_pattern_string{p->string_stream.str()};
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ShapePattern("
+                       << pattern_pattern_string << " has shape " << pattern->shape << ")";
     });
 
 DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) {
@@ -273,10 +421,17 @@ 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DataTypePattern")
       return DataTypePattern(pattern, dtype);
     });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<DataTypePatternNode>([](const ObjectRef& ref, ReprPrinter* 
p) {
-      auto* node = static_cast<const DataTypePatternNode*>(ref.get());
-      p->stream << "TypePattern(" << node->pattern << " has dtype " << 
node->dtype << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<DataTypePatternNode>([](const ObjectRef& ref, 
DFPatternPrinter* p) {
+      DataTypePattern pattern = Downcast<DataTypePattern>(ref);
+
+      p->Print(pattern->pattern);
+      std::string pattern_pattern_string{p->string_stream.str()};
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+      p->string_stream << "DataTypePattern(" << pattern_pattern_string << " 
has dtype "
+                       << pattern->dtype << ")";
     });
 
 AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) {
@@ -291,10 +446,17 @@ TVM_REGISTER_NODE_TYPE(AttrPatternNode);
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AttrPattern")
     .set_body_typed([](DFPattern pattern, DictAttrs attrs) { return 
AttrPattern(pattern, attrs); });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<AttrPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const AttrPatternNode*>(ref.get());
-      p->stream << "AttrPattern(" << node->pattern << " has attributes " << 
node->attrs << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<AttrPatternNode>([](const ObjectRef& ref, DFPatternPrinter* 
p) {
+      AttrPattern pattern = Downcast<AttrPattern>(ref);
+
+      p->Print(pattern->pattern);
+      std::string pattern_pattern_string{p->string_stream.str()};
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+      p->string_stream << "AttrPattern(" << pattern_pattern_string << " has 
attributes "
+                       << pattern->attrs << ")";
     });
 
 DominatorPattern::DominatorPattern(DFPattern parent, DFPattern path, DFPattern 
child) {
@@ -313,11 +475,23 @@ 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DominatorPattern")
       return DominatorPattern(parent, path, child);
     });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<DominatorPatternNode>([](const ObjectRef& ref, ReprPrinter* 
p) {
-      auto* node = static_cast<const DominatorPatternNode*>(ref.get());
-      p->stream << "DominatorPattern(" << node->parent << ", " << node->path 
<< ", " << node->child
-                << ")";
+TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
+    .set_dispatch<DominatorPatternNode>([](const ObjectRef& ref, 
DFPatternPrinter* p) {
+      DominatorPattern pattern = Downcast<DominatorPattern>(ref);
+
+      p->Print(pattern->parent);
+      std::string parent_pattern_string{p->string_stream.str()};
+
+      p->Print(pattern->path);
+      std::string path_pattern_string{p->string_stream.str()};
+
+      p->Print(pattern->child);
+      std::string child_pattern_string{p->string_stream.str()};
+
+      p->string_stream.str("");
+      p->string_stream << "(id " << p->memo_[pattern].first << "): ";
+      p->string_stream << "DominatorPattern(" << parent_pattern_string << ", "
+                       << path_pattern_string << ", " << child_pattern_string 
<< ")";
     });
 
 // Syntatic Sugar

Reply via email to