This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 5e02385c0c [Unity] Implemented SameShapeConstraint for dataflow
pattern matches (#15694)
5e02385c0c is described below
commit 5e02385c0c2886cdf94af5127b77d77017f0a9a4
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 13 16:57:36 2023 -0500
[Unity] Implemented SameShapeConstraint for dataflow pattern matches
(#15694)
* [Unity] Implemented SameShapeConstraint for dataflow pattern matches
Prior to this commit, a shape could be explicitly specified using the
`ShapePattern`, but could not be specified relative to the shape of
another expression. As a result, patterns with restricted shapes
became very difficult to express.
This commit implements `SameShapeConstraint`, which can be applied
between any patterns that participate in the match. For example, it can
restrict a match of `R.add(lhs, rhs)` to cases where `lhs` and `rhs` have
the same non-broadcasted shape. Because these constraints operate between
patterns that do not necessarily share a consumer/producer
relationship, they could not previously be expressed using the
existing `PairCons` functionality, and were instead implemented in
terms of a new `DFConstraint` base class.
* Removed empty line
* Update API based on PR discussion
Provide an `AsPrimExpr` function instead of an `IsConstraintSatisfied` function for
constraints. A constraint can return a necessary-and-sufficient
condition, or a necessary-but-not-sufficient condition, and the
calling scope can decide how to interpret those results.
* Lint fix
* lint fixes
---
include/tvm/relax/dataflow_pattern.h | 112 ++++++++++++++++++-
python/tvm/relax/dpl/pattern.py | 34 ++++++
src/relax/ir/dataflow_matcher.cc | 165 +++++++++++++++++++++++++---
src/relax/ir/dataflow_pattern.cc | 26 ++++-
tests/python/relax/test_dataflow_pattern.py | 83 ++++++++++++++
5 files changed, 400 insertions(+), 20 deletions(-)
diff --git a/include/tvm/relax/dataflow_pattern.h
b/include/tvm/relax/dataflow_pattern.h
index 68cfdd83ad..933429cb9b 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -32,13 +32,20 @@
#include <tvm/support/with.h>
#include <cstdint>
+#include <functional>
#include <map>
#include <memory>
#include <string>
+#include <tuple>
#include <utility>
#include <vector>
namespace tvm {
+
+namespace arith {
+class Analyzer;
+}
+
namespace relax {
class PatternSeq;
@@ -50,6 +57,7 @@ class ShapePattern;
class TypePattern;
class DataTypePattern;
class AttrPattern;
+class SameShapeConstraint;
/*!
* \brief Create used-by relationship between lhs[-1] and rhs[0], with [*lhs,
*rhs] returned.
@@ -112,6 +120,8 @@ class DFPattern : public ObjectRef {
TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const;
/*! \brief Syntatic Sugar for creating a ShapePattern */
TVM_DLL ShapePattern HasShape(const Array<PrimExpr>& shape) const;
+ /*! \brief Syntactic sugar for creating a SameShapeConstraint */
+ TVM_DLL SameShapeConstraint HasSameShapeAs(const DFPattern& other) const;
/*! \brief Syntatic Sugar for duplicating the current pattern */
TVM_DLL DFPattern dup() const;
@@ -143,6 +153,58 @@ struct PairCons {
}
};
+/*! \brief Additional constraints on the graph
+ *
+ * Unlike PairCons, these may relate nodes that are not directly
+ * connected by a DFPattern edge from producer to consumer. For
+ * example, constraining the two branches of an elementwise operation
+ * to have the same shape.
+ */
+class DFConstraintNode : public Object {
+ public:
+ /*! \brief Return the patterns on which the constraint depends */
+ virtual Array<DFPattern> GetDependentPatterns() const = 0;
+
+ /*! \brief Convert the constraint to a PrimExpr
+ *
+ * If the returned boolean parameter is true, then the returned
+ * expression is a necessary-and-sufficient condition for evaluating
+ * the constraint. In this case, the matcher may either mark the
+ * constraint as satisfied (no need to re-check later), or as failed
+ * (need to back-track).
+ *
+ * If the returned boolean parameter is false, then the returned
+ * expression is a necessary-but-not-sufficient condition for
+ * evaluating the constraint. In this case, the matcher may start
+ * backtracking as a result of a failed condition, but may not mark
+ * the constraint as satisfied. This typically occurs when the
+ * constraint involves a parameter that the matcher has not yet
+ * filled.
+ *
+ * \param match_state A function that can be called to check the
+ * current state of the match. The function takes as argument a
+ * pattern on which the constraint depends, and returns the relax
+ * variable matched by that pattern, or NullOpt if the pattern
+ * has not yet been matched.
+ *
+ * \return A tuple of `PrimExpr` and `bool`. The first element is a
+ * necessary condition for the constraint to be satisfied. The
+ * second tuple element indicates whether the condition is also
+ * sufficient for the constraint to be satisfied.
+ */
+ virtual std::tuple<PrimExpr, bool> AsPrimExpr(
+ std::function<Optional<Var>(const DFPatternNode*)> match_state) const =
0;
+
+ static constexpr const char* _type_key = "DFConstraintNode";
+ static constexpr const uint32_t _type_child_slots = 1;
+ TVM_DECLARE_BASE_OBJECT_INFO(DFConstraintNode, Object);
+};
+
+class DFConstraint : public ObjectRef {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(DFConstraint, ObjectRef, DFConstraintNode);
+};
+
/*!
* \brief A sequence of DFPatterns that the previous DFPattern is connected to
the next one.
* \sa PatternSeq
@@ -190,12 +252,18 @@ class PatternContextNode : public Object {
kMay, /*!< No constraints */
kMustNot, /*!< All nodes except outputs only have internal depedencies in
the matched graph. */
} allow_extern_use = kMay;
+
// src node -> <dst node, constraint type> constraints.
// Dst nodes are kept in a vector to keep them ordered.
- std::map<DFPattern, std::vector<std::pair<DFPattern,
std::vector<PairCons>>>> constraints;
- // Keep a separate vector of patterns to process constraints in a fixed
order.
+ std::map<DFPattern, std::vector<std::pair<DFPattern,
std::vector<PairCons>>>> edge_constraints;
+
+ // Underlying DFPattern nodes which the edge constraints may reference
+ // Kept as a separate vector of patterns to process constraints in a fixed
order.
std::vector<DFPattern> src_ordered;
+ // Non-edge constraints
+ std::vector<DFConstraint> validation_constraints;
+
static constexpr const char* _type_key = "relax.dpl.PatternContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object);
};
@@ -227,7 +295,7 @@ class PatternContext : public ObjectRef {
* \param cons The constraint type. \sa PairCons
*/
void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) {
- auto& pairs = (*this)->constraints[producer];
+ auto& pairs = (*this)->edge_constraints[producer];
auto it = std::find_if(pairs.begin(), pairs.end(),
[consumer](auto p) { return p.first == consumer; });
if (it == pairs.end()) {
@@ -245,6 +313,15 @@ class PatternContext : public ObjectRef {
}
}
+ /*!
+ * \brief Add a validation constraint
+ *
+ * \param constraint The new constraint
+ */
+ void add_constraint(DFConstraint constraint) {
+ (*this)->validation_constraints.push_back(constraint);
+ }
+
/*! \brief Get the constraint context object on the top of the stack */
TVM_DLL static Optional<PatternContext> Current();
@@ -709,6 +786,35 @@ class ShapePattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode);
};
+/*!
+ * \brief A constraint asserting that multiple patterns have the same shape
+ * \sa SameShapeConstraint
+ */
+class SameShapeConstraintNode : public DFConstraintNode {
+ public:
+ Array<DFPattern> args; /*!< The patterns with matching shapes */
+
+ Array<DFPattern> GetDependentPatterns() const override { return args; }
+
+ std::tuple<PrimExpr, bool> AsPrimExpr(
+ std::function<Optional<Var>(const DFPatternNode*)> match_state) const
override;
+
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("args", &args); }
+
+ static constexpr const char* _type_key = "relax.dpl.SameShapeConstraint";
+ TVM_DECLARE_FINAL_OBJECT_INFO(SameShapeConstraintNode, DFConstraintNode);
+};
+
+/*!
+ * \brief Managed reference to SameShapeConstraintNode.
+ * \sa SameShapeConstraintNode
+ */
+class SameShapeConstraint : public DFConstraint {
+ public:
+ TVM_DLL SameShapeConstraint(Array<DFPattern> args);
+ TVM_DEFINE_OBJECT_REF_METHODS(SameShapeConstraint, DFConstraint,
SameShapeConstraintNode);
+};
+
/*!
* \brief A pattern that asserting a root pattern has a certain data type.
* \sa DataTypePattern
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 79883b9161..b72cb73b5f 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -289,6 +289,26 @@ class DFPattern(Node):
for v in args:
self ^ v
+ def same_shape_as(self, *args: "DFPattern") -> "SameShapeConstraint":
+ """
+ Constrain the current pattern to have the same shape as other patterns
+
+ Parameters
+ ----------
+ args : DFPattern
+ The other patterns whose matched shapes must equal this one's
+
+ Returns
+ -------
+ result: SameShapeConstraint
+ The constraint, registered with the current PatternContext if any
+ """
+ return SameShapeConstraint(self, *args)
+
+
+class DFConstraint(Node):
+ """Base class of all constraints."""
+
@register_df_node
class ExprPattern(DFPattern):
@@ -606,6 +626,20 @@ class ShapePattern(DFPattern):
self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape)
# type: ignore
+@register_df_node
+class SameShapeConstraint(DFConstraint):
+ """A constraint that requires a set of patterns to have the same shape
+
+ Parameters
+ ----------
+ args: List[DFPattern]
+ A set of patterns which must all provide the same shape.
+ """
+
+ def __init__(self, *args: DFPattern):
+ self.__init_handle_by_constructor__(ffi.SameShapeConstraint, args) #
type: ignore
+
+
@register_df_node
class PrimArrPattern(DFPattern):
"""
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 41647e261d..ab2ad4fa36 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -22,6 +22,7 @@
* \brief The dataflow pattern matcher for Relax.
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/dataflow_matcher.h>
@@ -443,6 +444,69 @@ bool DFPatternMatcher::VisitDFPattern_(const
ShapePatternNode* op, const Expr& e
return false;
}
+std::tuple<PrimExpr, bool> SameShapeConstraintNode::AsPrimExpr(
+ std::function<Optional<Var>(const DFPatternNode*)> match_state) const {
+ Optional<Array<PrimExpr>> expected_shape;
+ bool all_shapes_defined = true;
+
+ // Conjunction of per-dimension equalities across all matched shapes
+ PrimExpr all_dimensions_equal = Bool(true);
+
+ for (const auto& arg : args) {
+ if (auto opt_var = match_state(arg.get())) {
+ auto var = opt_var.value();
+ auto opt_var_shape = [&]() -> Optional<Array<PrimExpr>> {
+ auto sinfo = GetStructInfo(var);
+ if (auto tensor = sinfo.as<TensorStructInfoNode>()) {
+ return tensor->GetShape();
+ } else if (auto shape_expr = sinfo.as<ShapeStructInfoNode>()) {
+ return shape_expr->values;
+ } else {
+ return NullOpt;
+ }
+ }();
+
+ if (!opt_var_shape.defined()) {
+ // The pattern has matched to something without a known shape.
+ // Therefore, it cannot be proven to share a shape with anything.
+ return {PrimExpr(Bool(false)), true};
+ }
+ auto var_shape = opt_var_shape.value();
+
+ if (expected_shape.defined()) {
+ auto prev_shape = expected_shape.value();
+ if (prev_shape.size() == var_shape.size()) {
+ // The dimensionalities match, so build up the expression
+ // that must be true for the shapes to be equivalent.
+ for (size_t i = 0; i < prev_shape.size(); i++) {
+ all_dimensions_equal = all_dimensions_equal && (var_shape[i] ==
prev_shape[i]);
+ }
+
+ } else {
+ // The shapes have different dimensionality. No need to
+ // perform potentially-expensive simplifications, because
+ // the dimensions do not match.
+ return {PrimExpr(Bool(false)), true};
+ }
+
+ } else {
+ // This is the first pattern with a known match. Store the
+ // shape so it can be compared against later shapes.
+ expected_shape = var_shape;
+ }
+
+ } else {
+ // This argument has not been matched yet, so the result can at
+ // most be a necessary (not sufficient) condition. Keep checking
+ // the remaining arguments instead of returning early, so that a
+ // definite mismatch among them can still be reported as false.
+ all_shapes_defined = false;
+ }
+ }
+
+ return {all_dimensions_equal, all_shapes_defined};
+}
+
bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const
Expr& expr0) {
auto expr = TryGetValOfVar(expr0, var2val_);
if (const ShapeExprNode* shape_expr = expr.as<ShapeExprNode>())
@@ -579,9 +643,12 @@ struct MatchState {
match_r_p[r] = p;
}
+ void add(const DFConstraintNode* constraint) {
validated_constraints_.insert(constraint); }
+
void add(MatchState&& other) {
match_p_r.merge(std::move(other.match_p_r));
match_r_p.merge(std::move(other.match_r_p));
+ validated_constraints_.merge(other.validated_constraints_);
}
const VarNode* matched(const PNode* p) const {
@@ -601,9 +668,14 @@ struct MatchState {
const VarNode* matched(const PNode& p) const { return matched(&p); }
const DFPatternNode* matched(const RNode& r) const { return matched(&r); }
+ bool is_validated(const DFConstraintNode* constraint) const {
+ return validated_constraints_.count(constraint);
+ }
+
private:
std::unordered_map<const PNode*, const RNode*> match_p_r;
std::unordered_map<const RNode*, const PNode*> match_r_p;
+ std::unordered_set<const DFConstraintNode*> validated_constraints_;
};
/**
@@ -663,11 +735,64 @@ static std::optional<MatchState> TryMatch(const PNode& p,
const RNode& r,
return new_match;
}
+static std::optional<MatchState> TryValidate(
+ const MatchState& current_match,
+ const std::unordered_map<const DFPatternNode*, PNode>& pattern2node,
+ const std::vector<DFConstraint>& validation_constraints, arith::Analyzer* analyzer) {
+ MatchState new_match;
+
+ std::function<Optional<Var>(const DFPatternNode*)> query_match_state =
+ [&pattern2node, &current_match](const DFPatternNode* pattern) -> Optional<Var> {
+ auto it = pattern2node.find(pattern);
+ ICHECK(it != pattern2node.end())
+ << "DFConstraint attempted to access DFPattern " << GetRef<DFPattern>(pattern)
+ << ", which does not appear in the PatternContext";
+ const auto& p_node = it->second;
+ if (auto ptr = current_match.matched(p_node)) {
+ return GetRef<Var>(ptr);
+ } else {
+ return NullOpt;
+ }
+ };
+
+ for (const auto& constraint : validation_constraints) {
+ if (!current_match.is_validated(constraint.get())) {
+ auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state);
+
+ necessary_condition = analyzer->Simplify(necessary_condition);
+ const auto* known = tir::as_const_int(necessary_condition);
+
+ if (known && *known && is_sufficient) {
+ // The condition passes, and the expression provided is both
+ // necessary and sufficient for the constraint to pass. Mark
+ // the constraint as passing, to avoid re-checking it unless
+ // we backtrack.
+ new_match.add(constraint.get());
+ } else if (known && !*known) {
+ // The condition fails. Even if additional information would
+ // be required to pass a constraint, it may bail out early as
+ // a failure (e.g. shape mismatch in the first two items out
+ // of N shapes that must all match).
+ return std::nullopt;
+ } else if (is_sufficient) {
+ // The condition depends on dynamic parameters. In the
+ // future, this may be exposed to the user as a condition for
+ // optimization, or can be combined with the conditions
+ // provided from other constraints.
+ return std::nullopt;
+ }
+ }
+ }
+
+ return new_match;
+}
+
static std::optional<MatchState> MatchTree(
const MatchState& current_match, size_t current_root_idx,
const std::unordered_map<const DFPatternNode*, PNode>& pattern2node,
const std::unordered_map<const VarNode*, RNode>& var2node,
DFPatternMatcher* matcher,
- const std::vector<DFPattern>& roots, const MatcherUseDefAnalysis&
ud_analysis) {
+ const std::vector<DFPattern>& roots, const std::vector<DFConstraint>&
validation_constraints,
+ const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) {
auto get_next_root = [&](size_t root_idx) -> const PNode* {
// Look for the next unmatched root node.
for (; root_idx < roots.size(); ++root_idx) {
@@ -692,12 +817,17 @@ static std::optional<MatchState> MatchTree(
const RNode& r_node = var2node.at(var);
if (new_match.matched(r_node)) continue;
if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis))
{
- // Recursivly try to match the next subtree.
+ // Recursively try to match the next subtree.
new_match.add(std::move(*match));
- if (auto match_rec = MatchTree(new_match, current_root_idx + 1,
pattern2node, var2node,
- matcher, roots, ud_analysis)) {
- new_match.add(std::move(*match_rec));
- return new_match;
+ if (auto validation =
+ TryValidate(new_match, pattern2node, validation_constraints,
analyzer)) {
+ new_match.add(std::move(*validation));
+ if (auto match_rec =
+ MatchTree(new_match, current_root_idx + 1, pattern2node,
var2node, matcher, roots,
+ validation_constraints, ud_analysis, analyzer)) {
+ new_match.add(std::move(*match_rec));
+ return new_match;
+ }
}
// Recursive matching has failed, backtrack.
new_match = current_match;
@@ -734,11 +864,11 @@ Optional<Map<DFPattern, Var>> MatchGraph(const
PatternContext& ctx, const Datafl
}
std::unordered_map<const DFPatternNode*, PNode> pattern2node;
- pattern2node.reserve(ctx->constraints.size());
+ pattern2node.reserve(ctx->edge_constraints.size());
for (const auto& def_pattern : ctx->src_ordered) {
PNode& def_node = pattern2node[def_pattern.get()];
- const auto& uses = ctx->constraints.at(def_pattern);
+ const auto& uses = ctx->edge_constraints.at(def_pattern);
def_node.ptr = def_pattern.get();
def_node.children.reserve(uses.size());
for (const auto& [use_pattern, cons] : uses) {
@@ -760,16 +890,19 @@ Optional<Map<DFPattern, Var>> MatchGraph(const
PatternContext& ctx, const Datafl
return NullOpt;
}
- if (auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots,
ud_analysis)) {
- Map<DFPattern, Var> ret;
- for (const auto& [pat, p_node] : pattern2node) {
- ICHECK(match->matched(p_node));
- ret.Set(GetRef<DFPattern>(pat), GetRef<Var>(match->matched(p_node)));
- }
- return ret;
+ arith::Analyzer analyzer;
+ auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots,
+ ctx->validation_constraints, ud_analysis, &analyzer);
+ if (!match) {
+ return NullOpt;
}
- return NullOpt;
+ Map<DFPattern, Var> ret;
+ for (const auto& [pat, p_node] : pattern2node) {
+ ICHECK(match->matched(p_node));
+ ret.Set(GetRef<DFPattern>(pat), GetRef<Var>(match->matched(p_node)));
+ }
+ return ret;
}
Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const
DataflowBlock& dfb) {
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index cd1376303c..faa890a12c 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -276,6 +276,33 @@ RELAX_PATTERN_PRINTER_DEF(ShapePatternNode, [](auto p, auto node) {
p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")";
});
+TVM_REGISTER_NODE_TYPE(SameShapeConstraintNode);
+SameShapeConstraint::SameShapeConstraint(Array<DFPattern> args) {
+ ObjectPtr<SameShapeConstraintNode> n = make_object<SameShapeConstraintNode>();
+ n->args = std::move(args);
+ data_ = std::move(n);
+
+ if (auto ctx = PatternContext::Current()) {
+ ctx.value().add_constraint(*this);
+ }
+}
+SameShapeConstraint DFPattern::HasSameShapeAs(const DFPattern& other) const {
+ return SameShapeConstraint(Array<DFPattern>{*this, other});
+}
+TVM_REGISTER_GLOBAL("relax.dpl.SameShapeConstraint").set_body_typed([](Array<DFPattern> args) {
+ return SameShapeConstraint(args);
+});
+RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) {
+ p->stream << "SameShapeConstraint(";
+ for (size_t i = 0; i < node->args.size(); i++) {
+ if (i) {
+ p->stream << ", ";
+ }
+ p->stream << node->args[i];
+ }
+ p->stream << ")";
+});
+
TVM_REGISTER_NODE_TYPE(DataTypePatternNode);
DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) {
ObjectPtr<DataTypePatternNode> n = make_object<DataTypePatternNode>();
@@ -405,7 +432,8 @@ PatternContext::PatternContext(bool incremental) {
ICHECK(!pattern_ctx_stack().empty())
<< "Incremental context needs to be built inside a existing context.";
n->allow_extern_use = pattern_ctx_stack().top()->allow_extern_use;
- n->constraints = pattern_ctx_stack().top()->constraints;
+ n->edge_constraints = pattern_ctx_stack().top()->edge_constraints;
+ n->validation_constraints = pattern_ctx_stack().top()->validation_constraints;
n->src_ordered = pattern_ctx_stack().top()->src_ordered;
}
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index 3444eff79b..7b68655a2f 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1407,5 +1407,88 @@ def test_rewrite_without_trivial_binding():
tvm.ir.assert_structural_equal(after, expected)
+same_shape_func_type = tvm.testing.parameter(
+ "same_static_shape",
+ "same_dynamic_shape",
+ "different_static_shape",
+ "different_dynamic_shape",
+)
+
+
+def test_same_shape_pattern(same_shape_func_type):
+ if same_shape_func_type == "same_static_shape":
+
+ @R.function(private=True)
+ def func(
+ a: R.Tensor((1024, 128), "float32"),
+ b: R.Tensor((1024, 128), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ c = R.multiply(a, R.const(2.0))
+ d = R.add(b, c)
+ out = d
+ R.output(out)
+ return out
+
+ elif same_shape_func_type == "same_dynamic_shape":
+
+ @R.function(private=True)
+ def func(
+ a: R.Tensor(("n", 128), "float32"),
+ b: R.Tensor(("n", 128), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ c = R.multiply(a, R.const(2.0))
+ d = R.add(b, c)
+ out = d
+ R.output(out)
+ return out
+
+ elif same_shape_func_type == "different_static_shape":
+
+ @R.function(private=True)
+ def func(
+ a: R.Tensor((1024, 128), "float32"),
+ b: R.Tensor((1, 128), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ c = R.multiply(a, R.const(2.0))
+ d = R.add(b, c)
+ out = d
+ R.output(out)
+ return out
+
+ elif same_shape_func_type == "different_dynamic_shape":
+
+ @R.function(private=True)
+ def func(
+ a: R.Tensor(("n", 128), "float32"),
+ b: R.Tensor(("m", 128), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ c = R.multiply(a, R.const(2.0))
+ d = R.add(b, c)
+ out = d
+ R.output(out)
+ return out
+
+ else:
+ raise ValueError(f"Unknown value of
same_shape_func_type={same_shape_func_type}")
+
+ with PatternContext() as ctx:
+ pat_lhs = wildcard()
+ pat_rhs = wildcard()
+ pat_sum = is_op("relax.add")(pat_lhs, pat_rhs)
+ pat_lhs.same_shape_as(pat_rhs)
+
+ block = func.body.blocks[0]
+ match = ctx.match_dfb(block)
+
+ if "same" in same_shape_func_type:
+ assert match
+ else:
+ assert match is None
+
+
if __name__ == "__main__":
tvm.testing.main()