This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c43fad1d60 [Relax] Implement StructInfoPattern for dataflow pattern 
matching (#16685)
c43fad1d60 is described below

commit c43fad1d603434d2316f3a2268e978dd06335c9a
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Mar 10 12:23:22 2024 -0500

    [Relax] Implement StructInfoPattern for dataflow pattern matching (#16685)
    
    This commit implements `StructInfoPattern`, which can be applied to
    any existing `DFPattern`, and requires the expression to have a
    specific struct info.  Any symbolic variables that occur in the struct
    info are treated as free parameters, to be defined by the match.
---
 include/tvm/relax/analysis.h                 |  22 +++
 include/tvm/relax/dataflow_pattern.h         |  27 ++++
 include/tvm/relax/dataflow_pattern_functor.h |   4 +
 python/tvm/relax/dpl/pattern.py              |  24 ++++
 python/tvm/relax/frontend/nn/core.py         |  12 +-
 src/relax/analysis/struct_info_analysis.cc   | 200 +++++++++++++++++++++++++++
 src/relax/ir/dataflow_matcher.cc             |  55 ++++++++
 src/relax/ir/dataflow_matcher_impl.h         |  15 ++
 src/relax/ir/dataflow_pattern.cc             |  22 +++
 src/relax/ir/dataflow_pattern_functor.cc     |   4 +
 tests/python/relax/test_dataflow_pattern.py  | 169 ++++++++++++++++++++++
 11 files changed, 553 insertions(+), 1 deletion(-)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 76da778ce0..0c43732813 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -249,6 +249,28 @@ TVM_DLL BaseCheckResult StructInfoBaseCheck(const 
StructInfo& base, const Struct
 TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
                       arith::Analyzer* ana = nullptr);
 
+/*!
+ * \brief Return the condition for which base is a superset of derived
+ *
+ * This function returns finer-grained conditions for kFailL2 than 
StructInfoBaseCheck
+ *
+ * If the returned expression is true, or simplifies to true, then
+ * base is a superset of derived.  If the returned expression is
+ * false, or simplifies to false, then base is not a superset of
+ * derived.
+ *
+ * If the returned expression is neither true nor false, it is an
+ * expression in terms of the symbolic variables available in `base`
+ * and `derived`.
+ *
+ * \param base The base struct info.
+ * \param derived The derived struct info.
+ * \return The condition under which base is a base of derived.
+ *
+ * \sa BaseCheckResult
+ */
+TVM_DLL PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const 
StructInfo& derived);
+
 /*!
  * \brief Unify the two struct info to their least common ancestor.
  *
diff --git a/include/tvm/relax/dataflow_pattern.h 
b/include/tvm/relax/dataflow_pattern.h
index b634b315d9..0d8e7678c2 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -54,6 +54,7 @@ class OrPattern;
 class AndPattern;
 class NotPattern;
 class ShapePattern;
+class StructInfoPattern;
 class TypePattern;
 class DataTypePattern;
 class AttrPattern;
@@ -112,6 +113,8 @@ class DFPattern : public ObjectRef {
   TVM_DLL NotPattern operator~() const;
   /*! \brief Syntatic Sugar for creating an AttrPattern */
   TVM_DLL AttrPattern HasAttr(const Map<String, ObjectRef>& attrs) const;
+  /*! \brief Syntatic Sugar for creating a StructInfoPattern */
+  TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const;
   /*! \brief Syntatic Sugar for creating a TypePattern */
   TVM_DLL TypePattern HasType(const Type& type) const;
   /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */
@@ -765,6 +768,30 @@ class TypePattern : public DFPattern {
   TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
 };
 
+/*!
+ * \brief Pattern that matches an inner pattern with a certain struct info.
+ * \sa StructInfoPattern
+ */
+class StructInfoPatternNode : public DFPatternNode {
+ public:
+  DFPattern pattern;      /*!< The inner pattern to match */
+  StructInfo struct_info; /*!< The struct info to match */
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pattern", &pattern);
+    v->Visit("struct_info", &struct_info);
+  }
+
+  static constexpr const char* _type_key = "relax.dpl.StructInfoPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StructInfoPatternNode, DFPatternNode);
+};
+
+class StructInfoPattern : public DFPattern {
+ public:
+  TVM_DLL StructInfoPattern(DFPattern pattern, StructInfo struct_info);
+  TVM_DEFINE_OBJECT_REF_METHODS(StructInfoPattern, DFPattern, 
StructInfoPatternNode);
+};
+
 /*!
  * \brief A pattern that asserting a root pattern has a certain shape.
  * \sa ShapePattern
diff --git a/include/tvm/relax/dataflow_pattern_functor.h 
b/include/tvm/relax/dataflow_pattern_functor.h
index 983881ddc9..bbdda44213 100644
--- a/include/tvm/relax/dataflow_pattern_functor.h
+++ b/include/tvm/relax/dataflow_pattern_functor.h
@@ -94,6 +94,8 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
   virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
                             Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const StructInfoPatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
@@ -129,6 +131,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(StructInfoPatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
@@ -163,6 +166,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const 
DFPattern&)> {
   void VisitDFPattern_(const ShapePatternNode* op) override;
   void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
   void VisitDFPattern_(const TuplePatternNode* op) override;
+  void VisitDFPattern_(const StructInfoPatternNode* op) override;
   void VisitDFPattern_(const TypePatternNode* op) override;
   void VisitDFPattern_(const WildcardPatternNode* op) override;
   void VisitDFPattern_(const VarPatternNode* op) override;
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 5594dea3ad..0d38b6bc87 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -122,6 +122,9 @@ class DFPattern(Node):
         attrs = make_node("DictAttrs", **attrs)
         return AttrPattern(self, attrs)
 
+    def has_struct_info(self, struct_info: "StructInfo") -> 
"StructInfoPattern":
+        return StructInfoPattern(self, struct_info)
+
     def has_type(self, ttype: tvm.ir.type.Type) -> "TypePattern":
         """
         Add a type constraint to this pattern
@@ -575,6 +578,27 @@ class WildcardPattern(DFPattern):
         self.__init_handle_by_constructor__(ffi.WildcardPattern)  # type: 
ignore
 
 
+@register_df_node
+class StructInfoPattern(DFPattern):
+    """A pattern that matches another pattern with a certain StructInfo
+
+    Parameters
+    ----------
+    pattern: tvm.relax.dpl.DFPattern
+        The input pattern that must also have the given struct info.
+
+    struct_info: tvm.relax.StructInfo
+        The struct info to match against; its symbolic vars are bound by the match.
+    """
+
+    def __init__(self, pattern: "DFPattern", struct_info: "StructInfo"):
+        self.__init_handle_by_constructor__(
+            ffi.StructInfoPattern,
+            pattern,
+            struct_info,
+        )  # type: ignore
+
+
 @register_df_node
 class TypePattern(DFPattern):
     """A pattern that matches another pattern with a certain type annotation.
diff --git a/python/tvm/relax/frontend/nn/core.py 
b/python/tvm/relax/frontend/nn/core.py
index 8eeffd8758..b7b3f411ed 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -48,7 +48,7 @@ from tvm.runtime import ndarray
 from tvm.runtime.relax_vm import VirtualMachine
 from tvm.target import Target
 
-from ... import expr as rx
+from .... import relax as rx
 from ...block_builder import BlockBuilder
 from ...struct_info import (
     ObjectStructInfo,
@@ -126,6 +126,16 @@ class Tensor(_TensorOp):
         """Construct a tensor from a scalar with dtype specified."""
         return Tensor(_expr=rx.const(data, dtype=dtype))
 
+    @staticmethod
+    def from_struct_info(struct_info: rx.TensorStructInfo, name: str = 
"tensor") -> "Tensor":
+        """Construct a nn.Tensor from relax TensorStructInfo"""
+        return Tensor(
+            _expr=rx.Var(
+                name_hint=name,
+                struct_info=struct_info,
+            )
+        )
+
     @staticmethod
     def placeholder(
         shape: Sequence[Union[int, str, tir.PrimExpr]],
diff --git a/src/relax/analysis/struct_info_analysis.cc 
b/src/relax/analysis/struct_info_analysis.cc
index b939ea712c..b1932f9b5d 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -609,6 +609,206 @@ TVM_REGISTER_GLOBAL("relax.StructInfoIsBaseOf")
       return IsBaseOf(base, derived);
     });
 
+class StructInfoBasePreconditionCollector
+    : public StructInfoFunctor<PrimExpr(const StructInfo&, const StructInfo&)> 
{
+ public:
+  explicit StructInfoBasePreconditionCollector() {}
+
+  PrimExpr VisitStructInfo(const StructInfo& lhs, const StructInfo& other) 
override {
+    if (lhs.same_as(other)) {
+      // Early bail-out if the StructInfo has reference equality.
+      return Bool(true);
+    } else {
+      return StructInfoFunctor::VisitStructInfo(lhs, other);
+    }
+  }
+
+  PrimExpr VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& 
other) final {
+    return Bool(true);
+  }
+
+  PrimExpr VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& 
other) final {
+    auto* rhs = other.as<PrimStructInfoNode>();
+    if (rhs == nullptr) {
+      return Bool(false);
+    }
+
+    if (lhs->dtype != rhs->dtype) {
+      return Bool(false);
+    }
+
+    if (lhs->value.defined() && rhs->value.defined()) {
+      return lhs->value.value() == rhs->value.value();
+    } else if (lhs->value.defined() && !rhs->value.defined()) {
+      return Bool(false);
+    } else {
+      return Bool(true);
+    }
+  }
+
+  PrimExpr VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& 
other) final {
+    auto* rhs = other.as<ShapeStructInfoNode>();
+    if (rhs == nullptr) {
+      return Bool(false);
+    }
+    // lhs have unknown ndim
+    if (lhs->IsUnknownNdim()) {
+      return Bool(true);
+    }
+
+    // ndim must match
+    if (lhs->ndim != rhs->ndim) {
+      return Bool(false);
+    }
+
+    if (lhs->values.defined() && rhs->values.defined()) {
+      return ArrayCheck(lhs->values.value(), rhs->values.value());
+    } else if (lhs->values.defined() && !rhs->values.defined()) {
+      return Bool(false);
+    } else {
+      return Bool(true);
+    }
+  }
+
+  PrimExpr VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& 
other) final {
+    auto* rhs = other.as<TensorStructInfoNode>();
+    if (rhs == nullptr) {
+      return Bool(false);
+    }
+    // dtype mismatch
+    if (!lhs->IsUnknownDtype() && lhs->dtype != rhs->dtype) {
+      return Bool(false);
+    }
+
+    // ndim mismatch
+    if (!lhs->IsUnknownNdim() && lhs->ndim != rhs->ndim) {
+      return Bool(false);
+    }
+
+    // vdevice mismatch
+    if (lhs->vdevice.defined() && !rhs->vdevice.defined()) {
+      return Bool(false);
+    }
+    if (lhs->vdevice.defined() && rhs->vdevice.defined()) {
+      VDevice lhs_vdevice = lhs->vdevice.value();
+      VDevice rhs_vdevice = rhs->vdevice.value();
+      if (lhs_vdevice->target.defined() && !rhs_vdevice->target.defined()) {
+        return Bool(false);
+      }
+      // mismatch in either the target, vdevice_id, or memory_scope
+      if ((lhs_vdevice->target.defined() && rhs_vdevice->target.defined()) &&
+          (lhs_vdevice->target != rhs_vdevice->target ||
+           lhs_vdevice->vdevice_id != rhs_vdevice->vdevice_id ||
+           lhs_vdevice->memory_scope != rhs_vdevice->memory_scope)) {
+        return Bool(false);
+      }
+    }
+
+    if (lhs->shape.same_as(rhs->shape)) {
+      return Bool(true);
+    } else if (lhs->shape.defined() && !rhs->shape.defined()) {
+      return Bool(false);
+    }
+
+    auto* lhs_shape = lhs->shape.as<ShapeExprNode>();
+    auto* rhs_shape = rhs->shape.as<ShapeExprNode>();
+    if (lhs_shape && rhs_shape) {
+      return ArrayCheck(lhs_shape->values, rhs_shape->values);
+    } else if (lhs_shape && !rhs_shape) {
+      return Bool(false);
+    }
+
+    return Bool(true);
+  }
+
+  PrimExpr VisitStructInfo_(const distributed::DTensorStructInfoNode* lhs,
+                            const StructInfo& other) final {
+    auto* rhs = other.as<distributed::DTensorStructInfoNode>();
+    if (rhs == nullptr) {
+      return Bool(false);
+    }
+
+    StructuralEqual struct_equal;
+    if (!struct_equal(lhs->device_mesh, rhs->device_mesh) ||
+        !struct_equal(lhs->placement, rhs->placement)) {
+      return Bool(false);
+    }
+
+    return this->VisitStructInfo(lhs->tensor_sinfo, rhs->tensor_sinfo);
+  }
+
+  PrimExpr VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& 
other) final {
+    auto* rhs = other.as<TupleStructInfoNode>();
+    if (rhs == nullptr) {
+      return Bool(false);
+    }
+    return ArrayCheck(lhs->fields, rhs->fields);
+  }
+
+  PrimExpr VisitStructInfo_(const FuncStructInfoNode* lhs, const StructInfo& 
other) final {
+    auto* rhs = other.as<FuncStructInfoNode>();
+    if (rhs == nullptr) {
+      return Bool(false);
+    }
+
+    // Check purity: Pure functions are a subtype of impure functions
+    if (lhs->purity && !rhs->purity) {
+      return Bool(false);
+    }
+
+    if (lhs->derive_func.defined() && 
!lhs->derive_func.same_as(rhs->derive_func)) {
+      return Bool(false);
+    }
+    if (lhs->params.defined() && !rhs->params.defined()) {
+      return Bool(false);
+    }
+
+    // If the base does not specify params, any parameter list is accepted.
+    PrimExpr param_check;
+    if (lhs->params.defined()) {
+      param_check = ArrayCheck(lhs->params.value(), rhs->params.value());
+    } else {
+      param_check = Bool(true);
+    }
+
+    // Visit the return type once; both conditions must hold for a match.
+    PrimExpr ret_check = VisitStructInfo(lhs->ret, rhs->ret);
+
+    return param_check && ret_check;
+  }
+
+ private:
+  PrimExpr ArrayCheck(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+    if (lhs.size() != rhs.size()) {
+      return Bool(false);
+    }
+
+    PrimExpr all_equal = Bool(true);
+    for (size_t i = 0; i < lhs.size(); i++) {
+      all_equal = all_equal && (lhs[i] == rhs[i]);
+    }
+    return all_equal;
+  }
+
+  PrimExpr ArrayCheck(const Array<StructInfo>& lhs, const Array<StructInfo>& 
rhs) {
+    if (lhs.size() != rhs.size()) {
+      return Bool(false);
+    }
+
+    PrimExpr all_pass = Bool(true);
+
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      all_pass = all_pass && VisitStructInfo(lhs[i], rhs[i]);
+    }
+    return all_pass;
+  }
+};
+
+PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const 
StructInfo& derived) {
+  StructInfoBasePreconditionCollector visitor;
+  return visitor(base, derived);
+}
+
 //--------------------------
 // DeriveStructInfo
 //--------------------------
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index c2515067ed..a14d43f6d3 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -43,6 +43,7 @@
 #include <utility>
 #include <vector>
 
+#include "../../arith/constraint_extract.h"
 #include "../transform/utils.h"
 #include "dataflow_matcher_impl.h"
 
@@ -85,6 +86,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& 
pattern, const Expr& expr
     ICHECK_EQ(memo_[pattern].size(), 1);
     return expr.same_as(memo_[pattern][0]);
   } else {
+    PrimExpr cached_condition = symbolic_expr_condition_;
     size_t watermark = matched_nodes_.size();
     bool out = DFPatternFunctor::VisitDFPattern(pattern, expr);
     if (out) {
@@ -92,6 +94,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& 
pattern, const Expr& expr
       matched_nodes_.push_back(pattern);
     } else {
       ClearMap(watermark);
+      symbolic_expr_condition_ = cached_condition;
     }
     return out;
   }
@@ -424,6 +427,58 @@ bool DFPatternMatcher::VisitDFPattern_(const 
UnorderedTuplePatternNode* op, cons
   return false;
 }
 
+bool DFPatternMatcher::VisitDFPattern_(const StructInfoPatternNode* op, const 
Expr& expr0) {
+  if (!VisitDFPattern(op->pattern, expr0)) {
+    return false;
+  }
+
+  auto expr = TryGetValOfVar(expr0, var2val_);
+  auto expr_struct_info = GetStructInfo(expr);
+
+  PrimExpr new_constraint = StructInfoBaseCheckPrecondition(op->struct_info, 
expr_struct_info);
+  if (auto* as_int = new_constraint.as<IntImmNode>()) {
+    return as_int->value;
+  }
+
+  symbolic_expr_condition_ = SimplifyCondition(symbolic_expr_condition_ && 
new_constraint);
+
+  if (auto* as_int = symbolic_expr_condition_.as<IntImmNode>()) {
+    return as_int->value;
+  } else {
+    return true;
+  }
+}
+
+PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) {
+  if (condition->IsInstance<IntImmNode>()) {
+    return condition;
+  }
+
+  std::vector<PrimExpr> constraints = arith::ExtractConstraints(condition, 
false);
+  if (constraints.size() == 1) {
+    return condition;
+  }
+
+  auto sort_key = [](PrimExpr expr) -> String {
+    if (const auto* equal = expr.as<tir::EQNode>()) {
+      if (const auto* var = equal->a.as<tir::VarNode>()) {
+        return var->name_hint;
+      }
+    }
+    return "";
+  };
+  std::stable_sort(
+      constraints.begin(), constraints.end(),
+      [&sort_key](const PrimExpr& a, const PrimExpr& b) { return sort_key(a) < 
sort_key(b); });
+
+  PrimExpr sorted_condition = Bool(true);
+  for (const PrimExpr& constraint : constraints) {
+    sorted_condition = sorted_condition && constraint;
+  }
+
+  return analyzer_.Simplify(sorted_condition);
+}
+
 bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& 
expr0) {
   auto expr = TryGetValOfVar(expr0, var2val_);
   auto expr_type = expr.as<ExprNode>()->checked_type();
diff --git a/src/relax/ir/dataflow_matcher_impl.h 
b/src/relax/ir/dataflow_matcher_impl.h
index 89f3d114c1..a0c35ac0de 100644
--- a/src/relax/ir/dataflow_matcher_impl.h
+++ b/src/relax/ir/dataflow_matcher_impl.h
@@ -59,6 +59,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const 
DFPattern&, const Ex
   bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) 
override;
   bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const StructInfoPatternNode* op, const Expr& expr) 
override;
   bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) 
override;
   bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
@@ -74,9 +75,23 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const 
DFPattern&, const Ex
                          const tvm::Array<Expr> fields, std::vector<int8_t>& 
match_cache,
                          std::vector<bool>& matched);
 
+  /*! \brief Simplify a boolean condition using the analyzer
+   *
+   * Matching struct info can often produce conditions that do not
+   * simplify cleanly.  For example, while the rewrite simplifier can
+   * recognize that `m==0 && m==1` can be simplified to `false`, it
+   * cannot recognize that `m==0 && n==0 && m==1` can be simplified to
+   * false.
+   *
+   * This function applies additional simplification steps to handle
+   * these cases, before delegating to `analyzer_.Simplify`.
+   */
+  PrimExpr SimplifyCondition(PrimExpr condition);
+
   std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> 
memo_;
   var2val_t var2val_;
   std::vector<DFPattern> matched_nodes_;
+  PrimExpr symbolic_expr_condition_{Bool(true)};
   arith::Analyzer analyzer_;
   bool memoize_ = true;
 };
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index ca81b91012..220f4e0ab5 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -259,6 +259,22 @@ RELAX_PATTERN_PRINTER_DEF(TypePatternNode, [](auto p, auto 
node) {
   p->stream << "TypePattern(" << node->pattern << " has type " << node->type 
<< ")";
 });
 
+TVM_REGISTER_NODE_TYPE(StructInfoPatternNode);
+StructInfoPattern::StructInfoPattern(DFPattern pattern, StructInfo 
struct_info) {
+  ObjectPtr<StructInfoPatternNode> n = make_object<StructInfoPatternNode>();
+  n->pattern = std::move(pattern);
+  n->struct_info = std::move(struct_info);
+  data_ = std::move(n);
+}
+TVM_REGISTER_GLOBAL("relax.dpl.StructInfoPattern")
+    .set_body_typed([](DFPattern pattern, StructInfo struct_info) {
+      return StructInfoPattern(pattern, struct_info);
+    });
+RELAX_PATTERN_PRINTER_DEF(StructInfoPatternNode, [](auto p, auto node) {
+  p->stream << "StructInfoPattern(" << node->pattern << " has relax StructInfo 
"
+            << node->struct_info << ")";
+});
+
 TVM_REGISTER_NODE_TYPE(ShapePatternNode);
 ShapePattern::ShapePattern(DFPattern pattern, Array<PrimExpr> shape) {
   ObjectPtr<ShapePatternNode> n = make_object<ShapePatternNode>();
@@ -371,6 +387,9 @@ class DFPatternDuplicator : public 
DFPatternFunctor<DFPattern(const DFPattern&)>
   DFPattern VisitDFPattern_(const ShapePatternNode* op) override {
     return ShapePattern(op->pattern, op->shape);
   }
+  DFPattern VisitDFPattern_(const StructInfoPatternNode* op) override {
+    return StructInfoPattern(op->pattern, op->struct_info);
+  }
   DFPattern VisitDFPattern_(const TypePatternNode* op) override {
     return TypePattern(op->pattern, op->type);
   }
@@ -398,6 +417,9 @@ NotPattern DFPattern::operator~() const { return 
NotPattern(*this); }
 AttrPattern DFPattern::HasAttr(const Map<String, ObjectRef>& attrs) const {
   return AttrPattern(*this, DictAttrs(attrs));
 }
+StructInfoPattern DFPattern::HasStructInfo(const StructInfo& struct_info) 
const {
+  return StructInfoPattern(*this, struct_info);
+}
 TypePattern DFPattern::HasType(const Type& type) const { return 
TypePattern(*this, type); }
 DataTypePattern DFPattern::HasDtype(const DataType& dtype) const {
   return DataTypePattern(*this, dtype);
diff --git a/src/relax/ir/dataflow_pattern_functor.cc 
b/src/relax/ir/dataflow_pattern_functor.cc
index 37a98f28be..655fa2eea1 100644
--- a/src/relax/ir/dataflow_pattern_functor.cc
+++ b/src/relax/ir/dataflow_pattern_functor.cc
@@ -98,6 +98,10 @@ void DFPatternVisitor::VisitDFPattern_(const 
UnorderedTuplePatternNode* op) {
 
 void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { 
VisitDFPattern(op->pattern); }
 
+void DFPatternVisitor::VisitDFPattern_(const StructInfoPatternNode* op) {
+  VisitDFPattern(op->pattern);
+}
+
 // leaf nodes.
 void DFPatternVisitor::VisitDFPattern_(const PrimArrPatternNode* op) {}
 void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
diff --git a/tests/python/relax/test_dataflow_pattern.py 
b/tests/python/relax/test_dataflow_pattern.py
index cf2a0cde84..a717e3da04 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1720,5 +1720,174 @@ def test_iterative_rewrite_with_removed_intermediates():
     tvm.ir.assert_structural_equal(expected, after)
 
 
+def test_wildcard_with_struct_info_updates_when_matching():
+    """A DFPattern may be restricted to a specific StructInfo"""
+
+    pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+    pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+    pat = is_op("relax.add")(pat_lhs, pat_rhs)
+
+    def rewriter(expr, matches):
+        lhs = matches[pat_lhs]
+        rhs = matches[pat_rhs]
+        return rx.op.multiply(lhs, rhs)
+
+    @R.function(private=True)
+    def before():
+        with R.dataflow():
+            A = R.zeros([2, 3], "int32")
+            B = R.ones([2, 3], "int32")
+            C = R.add(A, B)
+
+            R.output(C)
+        return C
+
+    @R.function(private=True)
+    def expected():
+        with R.dataflow():
+            A = R.zeros([2, 3], "int32")
+            B = R.ones([2, 3], "int32")
+            C = R.multiply(A, B)
+
+            R.output(C)
+        return C
+
+    after = rewrite_call(pat, rewriter, before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_wildcard_with_struct_info_is_no_op_when_not_matching():
+    """StructInfoPattern requires the provided StructInfo
+
+    Here, the pattern would match, except that the function has
+    `R.Tensor([16,32])`, and the pattern requires `R.Tensor([2,3])`.
+    """
+
+    pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+    pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+    pat = is_op("relax.add")(pat_lhs, pat_rhs)
+
+    def rewriter(expr, matches):
+        lhs = matches[pat_lhs]
+        rhs = matches[pat_rhs]
+        return rx.op.multiply(lhs, rhs)
+
+    @R.function(private=True)
+    def before():
+        with R.dataflow():
+            # This R.add has a different shape than the pattern, and will
+            # not be updated.
+            A = R.zeros([16, 32], "int32")
+            B = R.ones([16, 32], "int32")
+            C = R.add(A, B)
+
+            R.output(C)
+        return C
+
+    expected = before
+
+    after = rewrite_call(pat, rewriter, before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_wildcard_struct_info_for_unknown_dtype():
+    """TensorStructInfo with unknown dtype allows any dtype"""
+
+    pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+    pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+    pat = is_op("relax.add")(pat_lhs, pat_rhs)
+
+    def rewriter(expr, matches):
+        lhs = matches[pat_lhs]
+        rhs = matches[pat_rhs]
+        return rx.op.multiply(lhs, rhs)
+
+    @R.function(private=True)
+    def before():
+        with R.dataflow():
+            A = R.zeros([2, 3], "int32")
+            B = R.ones([2, 3], "int32")
+            C = R.add(A, B)
+
+            D = R.zeros([2, 3], "float32")
+            E = R.ones([2, 3], "float32")
+            F = R.add(D, E)
+
+            output = (C, F)
+            R.output(output)
+        return output
+
+    @R.function(private=True)
+    def expected():
+        with R.dataflow():
+            A = R.zeros([2, 3], "int32")
+            B = R.ones([2, 3], "int32")
+            C = R.multiply(A, B)
+
+            D = R.zeros([2, 3], "float32")
+            E = R.ones([2, 3], "float32")
+            F = R.multiply(D, E)
+
+            output = (C, F)
+            R.output(output)
+        return output
+
+    after = rewrite_call(pat, rewriter, before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_wildcard_struct_info_with_symbolic_vars():
+    """StructInfoPattern may define symbolic vars
+
+    This test finds an elementwise `R.add`, while ignoring a
+    broadcasted `R.add`.
+    """
+
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+
+    pat_lhs = wildcard().has_struct_info(R.Tensor([m, n]))
+    pat_rhs = wildcard().has_struct_info(R.Tensor([m, n]))
+    pat = is_op("relax.add")(pat_lhs, pat_rhs)
+
+    def rewriter(expr, matches):
+        lhs = matches[pat_lhs]
+        rhs = matches[pat_rhs]
+        return rx.op.multiply(lhs, rhs)
+
+    @R.function(private=True)
+    def before():
+        with R.dataflow():
+            A = R.zeros([64, 128], "int32")
+            B = R.ones([64, 128], "int32")
+            C = R.add(A, B)
+
+            D = R.zeros([64, 128], "float32")
+            E = R.ones([1, 128], "float32")
+            F = R.add(D, E)
+
+            output = (C, F)
+            R.output(output)
+        return output
+
+    @R.function(private=True)
+    def expected():
+        with R.dataflow():
+            A = R.zeros([64, 128], "int32")
+            B = R.ones([64, 128], "int32")
+            C = R.multiply(A, B)
+
+            D = R.zeros([64, 128], "float32")
+            E = R.ones([1, 128], "float32")
+            F = R.add(D, E)
+
+            output = (C, F)
+            R.output(output)
+        return output
+
+    after = rewrite_call(pat, rewriter, before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to