This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c43fad1d60 [Relax] Implement StructInfoPattern for dataflow pattern
matching (#16685)
c43fad1d60 is described below
commit c43fad1d603434d2316f3a2268e978dd06335c9a
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Mar 10 12:23:22 2024 -0500
[Relax] Implement StructInfoPattern for dataflow pattern matching (#16685)
This commit implements `StructInfoPattern`, which can be applied to
any existing `DFPattern`, and requires the expression to have a
specific struct info. Any symbolic variables that occur in the struct
info are treated as free parameters, to be defined by the match.
---
include/tvm/relax/analysis.h | 22 +++
include/tvm/relax/dataflow_pattern.h | 27 ++++
include/tvm/relax/dataflow_pattern_functor.h | 4 +
python/tvm/relax/dpl/pattern.py | 24 ++++
python/tvm/relax/frontend/nn/core.py | 12 +-
src/relax/analysis/struct_info_analysis.cc | 200 +++++++++++++++++++++++++++
src/relax/ir/dataflow_matcher.cc | 55 ++++++++
src/relax/ir/dataflow_matcher_impl.h | 15 ++
src/relax/ir/dataflow_pattern.cc | 22 +++
src/relax/ir/dataflow_pattern_functor.cc | 4 +
tests/python/relax/test_dataflow_pattern.py | 169 ++++++++++++++++++++++
11 files changed, 553 insertions(+), 1 deletion(-)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 76da778ce0..0c43732813 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -249,6 +249,28 @@ TVM_DLL BaseCheckResult StructInfoBaseCheck(const
StructInfo& base, const Struct
TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
arith::Analyzer* ana = nullptr);
+/*!
+ * \brief Return the condition for which base is a superset of derived
+ *
+ * This function returns finer-grained conditions for kFailL2 than
StructInfoBaseCheck
+ *
+ * If the returned expression is true, or simplifies to true, then
+ * base is a superset of derived. If the returned expression is
+ * false, or simplifies to false, then base is not a superset of
+ * derived.
+ *
+ * If the returned expression is neither true nor false, it is an
+ * expression in terms of the symbolic variables available in `base`
+ * and `derived`.
+ *
+ * \param base The base struct info.
+ * \param derived The derived struct info.
+ * \return The condition under which base is a superset of derived.
+ *
+ * \sa BaseCheckResult
+ */
+TVM_DLL PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const
StructInfo& derived);
+
/*!
* \brief Unify the two struct info to their least common ancestor.
*
diff --git a/include/tvm/relax/dataflow_pattern.h
b/include/tvm/relax/dataflow_pattern.h
index b634b315d9..0d8e7678c2 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -54,6 +54,7 @@ class OrPattern;
class AndPattern;
class NotPattern;
class ShapePattern;
+class StructInfoPattern;
class TypePattern;
class DataTypePattern;
class AttrPattern;
@@ -112,6 +113,8 @@ class DFPattern : public ObjectRef {
TVM_DLL NotPattern operator~() const;
/*! \brief Syntatic Sugar for creating an AttrPattern */
TVM_DLL AttrPattern HasAttr(const Map<String, ObjectRef>& attrs) const;
+ /*! \brief Syntatic Sugar for creating a StructInfoPattern */
+ TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const;
/*! \brief Syntatic Sugar for creating a TypePattern */
TVM_DLL TypePattern HasType(const Type& type) const;
/*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */
@@ -765,6 +768,30 @@ class TypePattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
};
+/*!
+ * \brief Pattern for matching a certain struct info.
+ * \sa StructInfoPattern
+ */
+class StructInfoPatternNode : public DFPatternNode {
+ public:
+ DFPattern pattern; /*!< The pattern to match */
+ StructInfo struct_info; /*!< The struct info to match */
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("pattern", &pattern);
+ v->Visit("struct_info", &struct_info);
+ }
+
+ static constexpr const char* _type_key = "relax.dpl.StructInfoPattern";
+ TVM_DECLARE_FINAL_OBJECT_INFO(StructInfoPatternNode, DFPatternNode);
+};
+
+class StructInfoPattern : public DFPattern {
+ public:
+ TVM_DLL StructInfoPattern(DFPattern pattern, StructInfo struct_info);
+ TVM_DEFINE_OBJECT_REF_METHODS(StructInfoPattern, DFPattern,
StructInfoPatternNode);
+};
+
/*!
* \brief A pattern that asserting a root pattern has a certain shape.
* \sa ShapePattern
diff --git a/include/tvm/relax/dataflow_pattern_functor.h
b/include/tvm/relax/dataflow_pattern_functor.h
index 983881ddc9..bbdda44213 100644
--- a/include/tvm/relax/dataflow_pattern_functor.h
+++ b/include/tvm/relax/dataflow_pattern_functor.h
@@ -94,6 +94,8 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitDFPattern_(const StructInfoPatternNode* op,
+ Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
@@ -129,6 +131,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAX_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
+ RELAX_DFPATTERN_FUNCTOR_DISPATCH(StructInfoPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
@@ -163,6 +166,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const
DFPattern&)> {
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
+ void VisitDFPattern_(const StructInfoPatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 5594dea3ad..0d38b6bc87 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -122,6 +122,9 @@ class DFPattern(Node):
attrs = make_node("DictAttrs", **attrs)
return AttrPattern(self, attrs)
+ def has_struct_info(self, struct_info: "StructInfo") ->
"StructInfoPattern":
+ return StructInfoPattern(self, struct_info)
+
def has_type(self, ttype: tvm.ir.type.Type) -> "TypePattern":
"""
Add a type constraint to this pattern
@@ -575,6 +578,27 @@ class WildcardPattern(DFPattern):
self.__init_handle_by_constructor__(ffi.WildcardPattern) # type:
ignore
+@register_df_node
+class StructInfoPattern(DFPattern):
+ """A pattern that matches another pattern with a certain StructInfo
+
+ Parameters
+ ----------
+ pattern: tvm.relax.dpl.DFPattern
+ The input pattern that needs the struct info constraint.
+
+ struct_info: tvm.relax.StructInfo
+ The struct info to match against
+ """
+
+ def __init__(self, pattern: "DFPattern", struct_info: "StructInfo"):
+ self.__init_handle_by_constructor__(
+ ffi.StructInfoPattern,
+ pattern,
+ struct_info,
+ ) # type: ignore
+
+
@register_df_node
class TypePattern(DFPattern):
"""A pattern that matches another pattern with a certain type annotation.
diff --git a/python/tvm/relax/frontend/nn/core.py
b/python/tvm/relax/frontend/nn/core.py
index 8eeffd8758..b7b3f411ed 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -48,7 +48,7 @@ from tvm.runtime import ndarray
from tvm.runtime.relax_vm import VirtualMachine
from tvm.target import Target
-from ... import expr as rx
+from .... import relax as rx
from ...block_builder import BlockBuilder
from ...struct_info import (
ObjectStructInfo,
@@ -126,6 +126,16 @@ class Tensor(_TensorOp):
"""Construct a tensor from a scalar with dtype specified."""
return Tensor(_expr=rx.const(data, dtype=dtype))
+ @staticmethod
+ def from_struct_info(struct_info: rx.TensorStructInfo, name: str =
"tensor") -> "Tensor":
+ """Construct a nn.Tensor from relax TensorStructInfo"""
+ return Tensor(
+ _expr=rx.Var(
+ name_hint=name,
+ struct_info=struct_info,
+ )
+ )
+
@staticmethod
def placeholder(
shape: Sequence[Union[int, str, tir.PrimExpr]],
diff --git a/src/relax/analysis/struct_info_analysis.cc
b/src/relax/analysis/struct_info_analysis.cc
index b939ea712c..b1932f9b5d 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -609,6 +609,206 @@ TVM_REGISTER_GLOBAL("relax.StructInfoIsBaseOf")
return IsBaseOf(base, derived);
});
+class StructInfoBasePreconditionCollector
+ : public StructInfoFunctor<PrimExpr(const StructInfo&, const StructInfo&)>
{
+ public:
+ explicit StructInfoBasePreconditionCollector() {}
+
+ PrimExpr VisitStructInfo(const StructInfo& lhs, const StructInfo& other)
override {
+ if (lhs.same_as(other)) {
+ // Early bail-out if the StructInfo has reference equality.
+ return Bool(true);
+ } else {
+ return StructInfoFunctor::VisitStructInfo(lhs, other);
+ }
+ }
+
+ PrimExpr VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo&
other) final {
+ return Bool(true);
+ }
+
+ PrimExpr VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo&
other) final {
+ auto* rhs = other.as<PrimStructInfoNode>();
+ if (rhs == nullptr) {
+ return Bool(false);
+ }
+
+ if (lhs->dtype != rhs->dtype) {
+ return Bool(false);
+ }
+
+ if (lhs->value.defined() && rhs->value.defined()) {
+ return lhs->value.value() == rhs->value.value();
+ } else if (lhs->value.defined() && !rhs->value.defined()) {
+ return Bool(false);
+ } else {
+ return Bool(true);
+ }
+ }
+
+ PrimExpr VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo&
other) final {
+ auto* rhs = other.as<ShapeStructInfoNode>();
+ if (rhs == nullptr) {
+ return Bool(false);
+ }
+ // lhs have unknown ndim
+ if (lhs->IsUnknownNdim()) {
+ return Bool(true);
+ }
+
+ // ndim must match
+ if (lhs->ndim != rhs->ndim) {
+ return Bool(false);
+ }
+
+ if (lhs->values.defined() && rhs->values.defined()) {
+ return ArrayCheck(lhs->values.value(), rhs->values.value());
+ } else if (lhs->values.defined() && !rhs->values.defined()) {
+ return Bool(false);
+ } else {
+ return Bool(true);
+ }
+ }
+
+ PrimExpr VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo&
other) final {
+ auto* rhs = other.as<TensorStructInfoNode>();
+ if (rhs == nullptr) {
+ return Bool(false);
+ }
+ // dtype mismatch
+ if (!lhs->IsUnknownDtype() && lhs->dtype != rhs->dtype) {
+ return Bool(false);
+ }
+
+ // ndim mismatch
+ if (!lhs->IsUnknownNdim() && lhs->ndim != rhs->ndim) {
+ return Bool(false);
+ }
+
+ // vdevice mismatch
+ if (lhs->vdevice.defined() && !rhs->vdevice.defined()) {
+ return Bool(false);
+ }
+ if (lhs->vdevice.defined() && rhs->vdevice.defined()) {
+ VDevice lhs_vdevice = lhs->vdevice.value();
+ VDevice rhs_vdevice = rhs->vdevice.value();
+ if (lhs_vdevice->target.defined() && !rhs_vdevice->target.defined()) {
+ return Bool(false);
+ }
+ // mismatch in either the target, vdevice_id, or memory_scope
+ if ((lhs_vdevice->target.defined() && rhs_vdevice->target.defined()) &&
+ (lhs_vdevice->target != rhs_vdevice->target ||
+ lhs_vdevice->vdevice_id != rhs_vdevice->vdevice_id ||
+ lhs_vdevice->memory_scope != rhs_vdevice->memory_scope)) {
+ return Bool(false);
+ }
+ }
+
+ if (lhs->shape.same_as(rhs->shape)) {
+ return Bool(true);
+ } else if (lhs->shape.defined() && !rhs->shape.defined()) {
+ return Bool(false);
+ }
+
+ auto* lhs_shape = lhs->shape.as<ShapeExprNode>();
+ auto* rhs_shape = rhs->shape.as<ShapeExprNode>();
+ if (lhs_shape && rhs_shape) {
+ return ArrayCheck(lhs_shape->values, rhs_shape->values);
+ } else if (lhs_shape && !rhs_shape) {
+ return Bool(false);
+ }
+
+ return Bool(true);
+ }
+
+ PrimExpr VisitStructInfo_(const distributed::DTensorStructInfoNode* lhs,
+ const StructInfo& other) final {
+ auto* rhs = other.as<distributed::DTensorStructInfoNode>();
+ if (rhs == nullptr) {
+ return Bool(false);
+ }
+
+ StructuralEqual struct_equal;
+ if (!struct_equal(lhs->device_mesh, rhs->device_mesh) ||
+ !struct_equal(lhs->placement, rhs->placement)) {
+ return Bool(false);
+ }
+
+ return this->VisitStructInfo(lhs->tensor_sinfo, rhs->tensor_sinfo);
+ }
+
+ PrimExpr VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo&
other) final {
+ auto* rhs = other.as<TupleStructInfoNode>();
+ if (rhs == nullptr) {
+ return Bool(false);
+ }
+ return ArrayCheck(lhs->fields, rhs->fields);
+ }
+
+ PrimExpr VisitStructInfo_(const FuncStructInfoNode* lhs, const StructInfo&
other) override {
+ auto* rhs = other.as<FuncStructInfoNode>();
+ if (rhs == nullptr) {
+ return Bool(false);
+ }
+
+ // Check purity: Pure functions are a subtype of impure functions
+ if (lhs->purity && !rhs->purity) {
+ return Bool(false);
+ }
+
+ if (lhs->derive_func.defined() &&
!lhs->derive_func.same_as(rhs->derive_func)) {
+ return Bool(false);
+ }
+ if (lhs->params.defined() && !rhs->params.defined()) {
+ return Bool(false);
+ }
+
+ PrimExpr all_match = VisitStructInfo(lhs->ret, rhs->ret);
+
+ PrimExpr param_check;
+ if (lhs->params.defined()) {
+ param_check = ArrayCheck(lhs->params.value(), rhs->params.value());
+ } else {
+ param_check = Bool(true);
+ }
+
+ PrimExpr ret_check = VisitStructInfo(lhs->ret, rhs->ret);
+
+ return param_check && ret_check;
+ }
+
+ private:
+ PrimExpr ArrayCheck(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+ if (lhs.size() != rhs.size()) {
+ return Bool(false);
+ }
+
+ PrimExpr all_equal = Bool(true);
+ for (size_t i = 0; i < lhs.size(); i++) {
+ all_equal = all_equal && (lhs[i] == rhs[i]);
+ }
+ return all_equal;
+ }
+
+ PrimExpr ArrayCheck(const Array<StructInfo>& lhs, const Array<StructInfo>&
rhs) {
+ if (lhs.size() != rhs.size()) {
+ return Bool(false);
+ }
+
+ PrimExpr all_pass = Bool(true);
+
+ for (size_t i = 0; i < lhs.size(); ++i) {
+ all_pass = all_pass && VisitStructInfo(lhs[i], rhs[i]);
+ }
+ return all_pass;
+ }
+};
+
+PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const
StructInfo& derived) {
+ StructInfoBasePreconditionCollector visitor;
+ return visitor(base, derived);
+}
+
//--------------------------
// DeriveStructInfo
//--------------------------
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index c2515067ed..a14d43f6d3 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -43,6 +43,7 @@
#include <utility>
#include <vector>
+#include "../../arith/constraint_extract.h"
#include "../transform/utils.h"
#include "dataflow_matcher_impl.h"
@@ -85,6 +86,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern&
pattern, const Expr& expr
ICHECK_EQ(memo_[pattern].size(), 1);
return expr.same_as(memo_[pattern][0]);
} else {
+ PrimExpr cached_condition = symbolic_expr_condition_;
size_t watermark = matched_nodes_.size();
bool out = DFPatternFunctor::VisitDFPattern(pattern, expr);
if (out) {
@@ -92,6 +94,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern&
pattern, const Expr& expr
matched_nodes_.push_back(pattern);
} else {
ClearMap(watermark);
+ symbolic_expr_condition_ = cached_condition;
}
return out;
}
@@ -424,6 +427,58 @@ bool DFPatternMatcher::VisitDFPattern_(const
UnorderedTuplePatternNode* op, cons
return false;
}
+bool DFPatternMatcher::VisitDFPattern_(const StructInfoPatternNode* op, const
Expr& expr0) {
+ if (!VisitDFPattern(op->pattern, expr0)) {
+ return false;
+ }
+
+ auto expr = TryGetValOfVar(expr0, var2val_);
+ auto expr_struct_info = GetStructInfo(expr);
+
+ PrimExpr new_constraint = StructInfoBaseCheckPrecondition(op->struct_info,
expr_struct_info);
+ if (auto* as_int = new_constraint.as<IntImmNode>()) {
+ return as_int->value;
+ }
+
+ symbolic_expr_condition_ = SimplifyCondition(symbolic_expr_condition_ &&
new_constraint);
+
+ if (auto* as_int = symbolic_expr_condition_.as<IntImmNode>()) {
+ return as_int->value;
+ } else {
+ return true;
+ }
+}
+
+PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) {
+ if (condition->IsInstance<IntImmNode>()) {
+ return condition;
+ }
+
+ std::vector<PrimExpr> constraints = arith::ExtractConstraints(condition,
false);
+ if (constraints.size() == 1) {
+ return condition;
+ }
+
+ auto sort_key = [](PrimExpr expr) -> String {
+ if (const auto* equal = expr.as<tir::EQNode>()) {
+ if (const auto* var = equal->a.as<tir::VarNode>()) {
+ return var->name_hint;
+ }
+ }
+ return "";
+ };
+ std::stable_sort(
+ constraints.begin(), constraints.end(),
+ [&sort_key](const PrimExpr& a, const PrimExpr& b) { return sort_key(a) <
sort_key(b); });
+
+ PrimExpr sorted_condition = Bool(true);
+ for (const PrimExpr& constraint : constraints) {
+ sorted_condition = sorted_condition && constraint;
+ }
+
+ return analyzer_.Simplify(sorted_condition);
+}
+
bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr&
expr0) {
auto expr = TryGetValOfVar(expr0, var2val_);
auto expr_type = expr.as<ExprNode>()->checked_type();
diff --git a/src/relax/ir/dataflow_matcher_impl.h
b/src/relax/ir/dataflow_matcher_impl.h
index 89f3d114c1..a0c35ac0de 100644
--- a/src/relax/ir/dataflow_matcher_impl.h
+++ b/src/relax/ir/dataflow_matcher_impl.h
@@ -59,6 +59,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const
DFPattern&, const Ex
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr)
override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const StructInfoPatternNode* op, const Expr& expr)
override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr)
override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
@@ -74,9 +75,23 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const
DFPattern&, const Ex
const tvm::Array<Expr> fields, std::vector<int8_t>&
match_cache,
std::vector<bool>& matched);
+ /* \brief Simplify a boolean condition using the analyzer
+ *
+ * Matching struct info can often produce conditions that do not
+ * simplify cleanly. For example, while the rewrite simplifier can
+ * recognize that `m==0 && m==1` can be simplified to `false`, it
+ * cannot recognize that `m==0 && n==0 && m==1` can be simplified to
+ * false.
+ *
+ * This function applies additional simplification steps to handle
+ * these cases, before delegating to `analyzer_.Simplify`.
+ */
+ PrimExpr SimplifyCondition(PrimExpr condition);
+
std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual>
memo_;
var2val_t var2val_;
std::vector<DFPattern> matched_nodes_;
+ PrimExpr symbolic_expr_condition_{Bool(true)};
arith::Analyzer analyzer_;
bool memoize_ = true;
};
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index ca81b91012..220f4e0ab5 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -259,6 +259,22 @@ RELAX_PATTERN_PRINTER_DEF(TypePatternNode, [](auto p, auto
node) {
p->stream << "TypePattern(" << node->pattern << " has type " << node->type
<< ")";
});
+TVM_REGISTER_NODE_TYPE(StructInfoPatternNode);
+StructInfoPattern::StructInfoPattern(DFPattern pattern, StructInfo
struct_info) {
+ ObjectPtr<StructInfoPatternNode> n = make_object<StructInfoPatternNode>();
+ n->pattern = std::move(pattern);
+ n->struct_info = std::move(struct_info);
+ data_ = std::move(n);
+}
+TVM_REGISTER_GLOBAL("relax.dpl.StructInfoPattern")
+ .set_body_typed([](DFPattern pattern, StructInfo struct_info) {
+ return StructInfoPattern(pattern, struct_info);
+ });
+RELAX_PATTERN_PRINTER_DEF(StructInfoPatternNode, [](auto p, auto node) {
+ p->stream << "StructInfoPattern(" << node->pattern << " has relax StructInfo
"
+ << node->struct_info << ")";
+});
+
TVM_REGISTER_NODE_TYPE(ShapePatternNode);
ShapePattern::ShapePattern(DFPattern pattern, Array<PrimExpr> shape) {
ObjectPtr<ShapePatternNode> n = make_object<ShapePatternNode>();
@@ -371,6 +387,9 @@ class DFPatternDuplicator : public
DFPatternFunctor<DFPattern(const DFPattern&)>
DFPattern VisitDFPattern_(const ShapePatternNode* op) override {
return ShapePattern(op->pattern, op->shape);
}
+ DFPattern VisitDFPattern_(const StructInfoPatternNode* op) override {
+ return StructInfoPattern(op->pattern, op->struct_info);
+ }
DFPattern VisitDFPattern_(const TypePatternNode* op) override {
return TypePattern(op->pattern, op->type);
}
@@ -398,6 +417,9 @@ NotPattern DFPattern::operator~() const { return
NotPattern(*this); }
AttrPattern DFPattern::HasAttr(const Map<String, ObjectRef>& attrs) const {
return AttrPattern(*this, DictAttrs(attrs));
}
+StructInfoPattern DFPattern::HasStructInfo(const StructInfo& struct_info)
const {
+ return StructInfoPattern(*this, struct_info);
+}
TypePattern DFPattern::HasType(const Type& type) const { return
TypePattern(*this, type); }
DataTypePattern DFPattern::HasDtype(const DataType& dtype) const {
return DataTypePattern(*this, dtype);
diff --git a/src/relax/ir/dataflow_pattern_functor.cc
b/src/relax/ir/dataflow_pattern_functor.cc
index 37a98f28be..655fa2eea1 100644
--- a/src/relax/ir/dataflow_pattern_functor.cc
+++ b/src/relax/ir/dataflow_pattern_functor.cc
@@ -98,6 +98,10 @@ void DFPatternVisitor::VisitDFPattern_(const
UnorderedTuplePatternNode* op) {
void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) {
VisitDFPattern(op->pattern); }
+void DFPatternVisitor::VisitDFPattern_(const StructInfoPatternNode* op) {
+ VisitDFPattern(op->pattern);
+}
+
// leaf nodes.
void DFPatternVisitor::VisitDFPattern_(const PrimArrPatternNode* op) {}
void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index cf2a0cde84..a717e3da04 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1720,5 +1720,174 @@ def test_iterative_rewrite_with_removed_intermediates():
tvm.ir.assert_structural_equal(expected, after)
+def test_wildcard_with_struct_info_updates_when_matching():
+ """A DFPattern may be restricted to a specific StructInfo"""
+
+ pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+ pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+ pat = is_op("relax.add")(pat_lhs, pat_rhs)
+
+ def rewriter(expr, matches):
+ lhs = matches[pat_lhs]
+ rhs = matches[pat_rhs]
+ return rx.op.multiply(lhs, rhs)
+
+ @R.function(private=True)
+ def before():
+ with R.dataflow():
+ A = R.zeros([2, 3], "int32")
+ B = R.ones([2, 3], "int32")
+ C = R.add(A, B)
+
+ R.output(C)
+ return C
+
+ @R.function(private=True)
+ def expected():
+ with R.dataflow():
+ A = R.zeros([2, 3], "int32")
+ B = R.ones([2, 3], "int32")
+ C = R.multiply(A, B)
+
+ R.output(C)
+ return C
+
+ after = rewrite_call(pat, rewriter, before)
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_wildcard_with_struct_info_is_no_op_when_not_matching():
+ """StructInfoPattern requires the StructInfo provided
+
+ Here, the pattern would match, except that the function has
+ `R.Tensor([16,32])`, and the pattern requires `R.Tensor([2,3])`.
+ """
+
+ pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+ pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+ pat = is_op("relax.add")(pat_lhs, pat_rhs)
+
+ def rewriter(expr, matches):
+ lhs = matches[pat_lhs]
+ rhs = matches[pat_rhs]
+ return rx.op.multiply(lhs, rhs)
+
+ @R.function(private=True)
+ def before():
+ with R.dataflow():
+ # This R.add does not have the shape required by the
+ # pattern, and will not be updated.
+ A = R.zeros([16, 32], "int32")
+ B = R.ones([16, 32], "int32")
+ C = R.add(A, B)
+
+ R.output(C)
+ return C
+
+ expected = before
+
+ after = rewrite_call(pat, rewriter, before)
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_wildcard_struct_info_for_unknown_dtype():
+ """TensorStructInfo with unknown dtype allows any dtype"""
+
+ pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+ pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
+ pat = is_op("relax.add")(pat_lhs, pat_rhs)
+
+ def rewriter(expr, matches):
+ lhs = matches[pat_lhs]
+ rhs = matches[pat_rhs]
+ return rx.op.multiply(lhs, rhs)
+
+ @R.function(private=True)
+ def before():
+ with R.dataflow():
+ A = R.zeros([2, 3], "int32")
+ B = R.ones([2, 3], "int32")
+ C = R.add(A, B)
+
+ D = R.zeros([2, 3], "float32")
+ E = R.ones([2, 3], "float32")
+ F = R.add(D, E)
+
+ output = (C, F)
+ R.output(output)
+ return output
+
+ @R.function(private=True)
+ def expected():
+ with R.dataflow():
+ A = R.zeros([2, 3], "int32")
+ B = R.ones([2, 3], "int32")
+ C = R.multiply(A, B)
+
+ D = R.zeros([2, 3], "float32")
+ E = R.ones([2, 3], "float32")
+ F = R.multiply(D, E)
+
+ output = (C, F)
+ R.output(output)
+ return output
+
+ after = rewrite_call(pat, rewriter, before)
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_wildcard_struct_info_with_symbolic_vars():
+ """StructInfoPattern may define symbolic vars
+
+ This test finds an elementwise `R.add`, while ignoring a
+ broadcasted `R.add`.
+ """
+
+ m = tir.Var("m", "int64")
+ n = tir.Var("n", "int64")
+
+ pat_lhs = wildcard().has_struct_info(R.Tensor([m, n]))
+ pat_rhs = wildcard().has_struct_info(R.Tensor([m, n]))
+ pat = is_op("relax.add")(pat_lhs, pat_rhs)
+
+ def rewriter(expr, matches):
+ lhs = matches[pat_lhs]
+ rhs = matches[pat_rhs]
+ return rx.op.multiply(lhs, rhs)
+
+ @R.function(private=True)
+ def before():
+ with R.dataflow():
+ A = R.zeros([64, 128], "int32")
+ B = R.ones([64, 128], "int32")
+ C = R.add(A, B)
+
+ D = R.zeros([64, 128], "float32")
+ E = R.ones([1, 128], "float32")
+ F = R.add(D, E)
+
+ output = (C, F)
+ R.output(output)
+ return output
+
+ @R.function(private=True)
+ def expected():
+ with R.dataflow():
+ A = R.zeros([64, 128], "int32")
+ B = R.ones([64, 128], "int32")
+ C = R.multiply(A, B)
+
+ D = R.zeros([64, 128], "float32")
+ E = R.ones([1, 128], "float32")
+ F = R.add(D, E)
+
+ output = (C, F)
+ R.output(output)
+ return output
+
+ after = rewrite_call(pat, rewriter, before)
+ tvm.ir.assert_structural_equal(expected, after)
+
+
if __name__ == "__main__":
tvm.testing.main()