[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

2020-05-14 Thread GitBox


mbrookhart commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r425433993



##
File path: include/tvm/relay/dataflow_matcher.h
##
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
+#define TVM_RELAY_DATAFLOW_MATCHER_H_
+
+#include 
+#include 
+
+#include 
+#include 
+
+namespace tvm {
+namespace relay {
+
+class DFPatternCallback;
+/*!
+ * \brief Base type of all dataflow pattern callbacks.
+ * \sa DFPatternCallback
+ */
+class DFPatternCallbackNode : public Object {
+ public:
+  /*! \brief Pattern this callback matches */
+  DFPattern pattern_;
+  /*! \brief Function to call when finding a matched expression */
+  PackedFunc function_;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "DFPatternCallbackNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow pattern callbacks.
+ * \sa DFPatternCallbackNode
+ */
+class DFPatternCallback : public ObjectRef {

Review comment:
   Something got lost in a refactor. I want users to be able to write 
pattern-based passes in C++, which requires this in a header, but I don't seem 
to have the pass functions exposed. Will fix.





This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

2020-05-14 Thread GitBox


mbrookhart commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r42547



##
File path: src/relay/ir/dataflow_matcher.cc
##
@@ -0,0 +1,656 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher.cc
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#include "indexed_graph.h"
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : 
expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map> GetMemo() { return Map>(memo_); }
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) 
override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) 
override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) 
override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map, ObjectHash, ObjectEqual> memo_;
+  std::vector matched_nodes_;
+  IndexedGraph expr_graph_;
+  IndexedGraph pattern_graph_;
+  bool memoize_ = true;
+};
+
+bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
+  memo_.clear();
+  matched_nodes_.clear();
+  return VisitDFPattern(pattern, expr);
+}
+
+void DFPatternMatcher::ClearMap(size_t watermark) {
+  for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
+memo_.erase(matched_nodes_[i]);
+  }
+  matched_nodes_.erase(matched_nodes_.begin() + watermark, 
matched_nodes_.end());
+}
+
+bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& 
expr) {
+  if (memoize_ && memo_.count(pattern)) {
+CHECK_EQ(memo_[pattern].size(), 1);
+return expr.same_as(memo_[pattern][0]);
+  } else {
+auto watermark = matched_nodes_.size();
+auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
+if (out) {
+  memo_[pattern].push_back(expr);
+  matched_nodes_.push_back(pattern);
+} else {
+  ClearMap(watermark);
+}
+return out;
+  }
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& 
expr) {
+  return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, 
const Expr& expr) {
+  bool matches = false;
+  if (const auto* op_node = expr.as()) {
+Op op = GetRef(op_node);
+auto attributes = attr_pattern->attrs.as()->dict;
+for (auto kv : attributes) {
+  auto attr_name = kv.first;
+  auto attr_value = kv.second;
+  auto op_map = Op::GetAttr(attr_name);
+  if (op_map.count(op)) {
+switch (op_map[op].type_code()) {
+  case kDLInt:
+if (auto* val = kv.second.as()) {
+  matches = val->value == op_map[op].operator int64_t();
+}
+break;
+  case kDLFloat:
+if (auto* val = kv.second.as()) {
+  matches = val->value == op_map[op].operator double();
+}
+break;
+  case kTVMStr:
+if (auto* val = kv.second.as()) {
+  matches = val->value == op_map[op].operator std::string();
+}
+break;
+  default:
+

[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

2020-05-04 Thread GitBox


mbrookhart commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r419582381



##
File path: include/tvm/relay/dataflow_functor.h
##
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/dataflow_matcher.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_DATAFLOW_FUNCTOR_H_
+#define TVM_RELAY_DATAFLOW_FUNCTOR_H_
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A dynamical functor that dispatches on in the first DFPattern 
argument.
+ *
+ * \tparam FType function signiture
+ *  This type is only defined for FType with function signature R(const 
DFPattern&,
+ * Args...)
+ */
+template 
+class DFPatternFunctor;
+
+// functions to be overriden.
+#define DFPATTERN_FUNCTOR_DEFAULT \
+  { return VisitDFPatternDefault_(op, std::forward(args)...); }
+
+#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP)   
 \
+  vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... 
args) {  \
+return self->VisitDFPattern_(static_cast(n.get()), 
std::forward(args)...); \
+  });
+
+template 
+class DFPatternFunctor {
+ private:
+  using TSelf = DFPatternFunctor;
+  using FType = tvm::NodeFunctor;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~DFPatternFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const DFPattern& n, Args... args) {
+return VisitDFPattern(n, std::forward(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitDFPattern(const DFPattern& n, Args... args) {
+CHECK(n.defined());
+static FType vtable = InitVTable();
+return vtable(n, this, std::forward(args)...);
+  }
+  // Functions that can be overriden by subclass
+  virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
+Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPatternDefault_(const Object* op, Args...) {
+LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+throw;
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+FType vtable;
+// Set dispatch
+RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
+RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
+RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
+RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
+RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
+RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
+RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
+RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
+RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
+return vtable;
+  }
+};
+
+/*!
+ * \brief A simple visitor wrapper around DFPatternFunctor.
+ *  Recursively visit the content.
+ *
+ *  

[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #5231: [POC] Pattern Language, Matcher, Rewriter, and Function Paritioner

2020-04-24 Thread GitBox


mbrookhart commented on a change in pull request #5231:
URL: https://github.com/apache/incubator-tvm/pull/5231#discussion_r414721399



##
File path: python/tvm/relay/df_pattern/__init__.py
##
@@ -0,0 +1,488 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The Relay Pattern Language and tooling."""
+from tvm.relay import Expr
+from ...ir.base import Node
+from ...ir import make_node
+from ...runtime import Object
+from ... import _ffi as tvm_ffi
+from ..op import get
+from . import _ffi as ffi
+
+
+def register_df_node(type_key=None):
+"""Register a Relay node type.
+
+Parameters
+--
+type_key : str or cls
+The type key of the node.
+"""
+if not isinstance(type_key, str):
+return tvm_ffi.register_object(
+"relay.df_pattern." + type_key.__name__)(type_key)
+return tvm_ffi.register_object(type_key)

Review comment:
   @jroesch can you comment on this? This was one of your contributions to 
the Python API.





This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org