This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 5e02385c0c [Unity] Implemented SameShapeConstraint for dataflow pattern matches (#15694)
5e02385c0c is described below

commit 5e02385c0c2886cdf94af5127b77d77017f0a9a4
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 13 16:57:36 2023 -0500

    [Unity] Implemented SameShapeConstraint for dataflow pattern matches (#15694)
    
    * [Unity] Implemented SameShapeConstraint for dataflow pattern matches
    
    Prior to this commit, a shape could be explicitly specified using the
    `ShapePattern`, but could not be specified relative to the shape of
    another expression.  As a result, patterns with restricted shapes
    became very difficult to express.
    
    This commit implements `SameShapeConstraint`, which can be applied
    between any patterns that participate in the match.  For example,
    it can restrict a match of `R.add(lhs, rhs)` to cases where `lhs` and
    `rhs` have the same non-broadcasted shape.  Because these constraints operate between
    patterns that do not necessarily share a consumer/producer
    relationship, they could not previously be expressed using the
    existing `PairCons` functionality, and were instead implemented in
    terms of a new `DFConstraint` base class.
    
    * Removed empty line
    
    * Update API based on PR discussion
    
    Provide an `AsPrimExpr` instead of `IsConstraintSatisfied` function for
    constraints.  A constraint can return a necessary-and-sufficient
    condition, or a necessary-but-not-sufficient condition, and the
    calling scope can decide how to interpret those results.
    
    * Lint fix
    
    * lint fixes
---
 include/tvm/relax/dataflow_pattern.h        | 112 ++++++++++++++++++-
 python/tvm/relax/dpl/pattern.py             |  34 ++++++
 src/relax/ir/dataflow_matcher.cc            | 165 +++++++++++++++++++++++++---
 src/relax/ir/dataflow_pattern.cc            |  26 ++++-
 tests/python/relax/test_dataflow_pattern.py |  83 ++++++++++++++
 5 files changed, 400 insertions(+), 20 deletions(-)

diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h
index 68cfdd83ad..933429cb9b 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -32,13 +32,20 @@
 #include <tvm/support/with.h>
 
 #include <cstdint>
+#include <functional>
 #include <map>
 #include <memory>
 #include <string>
+#include <tuple>
 #include <utility>
 #include <vector>
 
 namespace tvm {
+
+namespace arith {
+class Analyzer;
+}
+
 namespace relax {
 
 class PatternSeq;
@@ -50,6 +57,7 @@ class ShapePattern;
 class TypePattern;
 class DataTypePattern;
 class AttrPattern;
+class SameShapeConstraint;
 
 /*!
  * \brief Create used-by relationship between lhs[-1] and rhs[0], with [*lhs, 
*rhs] returned.
@@ -112,6 +120,8 @@ class DFPattern : public ObjectRef {
   TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const;
   /*! \brief Syntatic Sugar for creating a ShapePattern */
   TVM_DLL ShapePattern HasShape(const Array<PrimExpr>& shape) const;
+  /*! \brief Syntactic sugar for creating a SameShapeConstraint between this pattern and `other` */
+  TVM_DLL SameShapeConstraint HasSameShapeAs(const DFPattern& other) const;
   /*! \brief Syntatic Sugar for duplicating the current pattern */
   TVM_DLL DFPattern dup() const;
 
@@ -143,6 +153,58 @@ struct PairCons {
   }
 };
 
+/*! \brief Base class for additional (non-edge) constraints on the graph
+ *
+ * Unlike PairCons, these may relate nodes that are not directly
+ * connected by a DFPattern edge from producer to consumer.  For
+ * example, constraining the two branches of an elementwise operation
+ * to have the same shape.
+ */
+class DFConstraintNode : public Object {
+ public:
+  /*! \brief Return the patterns on which the constraint depends */
+  virtual Array<DFPattern> GetDependentPatterns() const = 0;
+
+  /*! \brief Convert the constraint to a PrimExpr
+   *
+   * If the returned boolean parameter is true, then the returned
+   * expression is a necessary-and-sufficient condition for evaluating
+   * the constraint.  In this case, the matcher may either mark the
+   * constraint as satisfied (no need to re-check later), or as failed
+   * (need to back-track).
+   *
+   * If the returned boolean parameter is false, then the returned
+   * expression is a necessary-but-not-sufficient condition for
+   * evaluating the constraint.  In this case, the matcher may start
+   * backtracking as a result of a failed condition, but may not mark
+   * the constraint as satisfied.  This typically occurs when the
+   * constraint involves a parameter that the matcher has not yet
+   * filled.
+   *
+   * \param match_state A function that can be called to check the
+   *    current state of the match.  The function takes as argument a
+   *    pattern on which the constraint depends, and returns the relax
+   *    variable matched by that pattern, or NullOpt if the pattern
+   *    has not yet been matched.
+   *
+   * \return A tuple of `PrimExpr` and `bool`.  The first element is a
+   *    necessary condition for the constraint to be satisfied.  The
+   *    second tuple element indicates whether the condition is also
+   *    sufficient for the constraint to be satisfied.
+   */
+  virtual std::tuple<PrimExpr, bool> AsPrimExpr(
+      std::function<Optional<Var>(const DFPatternNode*)> match_state) const = 0;
+
+  static constexpr const char* _type_key = "DFConstraintNode";
+  static constexpr const uint32_t _type_child_slots = 1;
+  TVM_DECLARE_BASE_OBJECT_INFO(DFConstraintNode, Object);
+};
+
+class DFConstraint : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DFConstraint, ObjectRef, DFConstraintNode);
+};
+
 /*!
  * \brief A sequence of DFPatterns that the previous DFPattern is connected to 
the next one.
  * \sa PatternSeq
@@ -190,12 +252,18 @@ class PatternContextNode : public Object {
     kMay,     /*!< No constraints */
     kMustNot, /*!< All nodes except outputs only have internal depedencies in 
the matched graph. */
   } allow_extern_use = kMay;
+
   // src node -> <dst node, constraint type> constraints.
   // Dst nodes are kept in a vector to keep them ordered.
-  std::map<DFPattern, std::vector<std::pair<DFPattern, 
std::vector<PairCons>>>> constraints;
-  // Keep a separate vector of patterns to process constraints in a fixed 
order.
+  std::map<DFPattern, std::vector<std::pair<DFPattern, 
std::vector<PairCons>>>> edge_constraints;
+
+  // Underlying DFPattern nodes which the edge constraints may reference
+  // Kept as a separate vector of patterns to process constraints in a fixed 
order.
   std::vector<DFPattern> src_ordered;
 
+  // Non-edge constraints
+  std::vector<DFConstraint> validation_constraints;
+
   static constexpr const char* _type_key = "relax.dpl.PatternContext";
   TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object);
 };
@@ -227,7 +295,7 @@ class PatternContext : public ObjectRef {
    * \param cons The constraint type. \sa PairCons
    */
   void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) {
-    auto& pairs = (*this)->constraints[producer];
+    auto& pairs = (*this)->edge_constraints[producer];
     auto it = std::find_if(pairs.begin(), pairs.end(),
                            [consumer](auto p) { return p.first == consumer; });
     if (it == pairs.end()) {
@@ -245,6 +313,15 @@ class PatternContext : public ObjectRef {
     }
   }
 
+  /*!
+   * \brief Add a validation (non-edge) constraint to this context
+   *
+   * \param constraint The new constraint, appended to validation_constraints
+   */
+  void add_constraint(DFConstraint constraint) {
+    (*this)->validation_constraints.push_back(std::move(constraint));
+  }
+
   /*! \brief Get the constraint context object on the top of the stack */
   TVM_DLL static Optional<PatternContext> Current();
 
@@ -709,6 +786,35 @@ class ShapePattern : public DFPattern {
   TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode);
 };
 
+/*!
+ * \brief A constraint asserting that multiple patterns have the same shape
+ * \sa SameShapeConstraint
+ */
+class SameShapeConstraintNode : public DFConstraintNode {
+ public:
+  Array<DFPattern> args; /*!< The patterns with matching shapes */
+
+  Array<DFPattern> GetDependentPatterns() const override { return args; }
+
+  std::tuple<PrimExpr, bool> AsPrimExpr(
+      std::function<Optional<Var>(const DFPatternNode*)> match_state) const override;
+
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("args", &args); }
+
+  static constexpr const char* _type_key = "relax.dpl.SameShapeConstraint";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SameShapeConstraintNode, DFConstraintNode);
+};
+
+/*!
+ * \brief Managed reference to SameShapeConstraintNode.
+ * \sa SameShapeConstraintNode
+ */
+class SameShapeConstraint : public DFConstraint {
+ public:
+  TVM_DLL SameShapeConstraint(Array<DFPattern> args);
+  TVM_DEFINE_OBJECT_REF_METHODS(SameShapeConstraint, DFConstraint, SameShapeConstraintNode);
+};
+
 /*!
  * \brief A pattern that asserting a root pattern has a certain data type.
  * \sa DataTypePattern
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 79883b9161..b72cb73b5f 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -289,6 +289,26 @@ class DFPattern(Node):
         for v in args:
             self ^ v
 
+    def same_shape_as(self, *args: "DFPattern") -> "SameShapeConstraint":
+        """
+        Require the current pattern to have the same shape as other patterns
+
+        Parameters
+        ----------
+        args : DFPattern
+            The other patterns whose shape must match this pattern
+
+        Returns
+        -------
+        result: SameShapeConstraint
+            The constraint; it is also added to the current PatternContext, if any
+        """
+        return SameShapeConstraint(self, *args)
+
+
+class DFConstraint(Node):
+    """Base class of all non-edge constraints between patterns (e.g. SameShapeConstraint)."""
+
 
 @register_df_node
 class ExprPattern(DFPattern):
@@ -606,6 +626,20 @@ class ShapePattern(DFPattern):
         self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape)  
# type: ignore
 
 
+@register_df_node
+class SameShapeConstraint(DFConstraint):
+    """A constraint that requires a set of patterns to have the same shape
+
+    Parameters
+    ----------
+    args: DFPattern
+        A set of patterns which must all provide the same shape.
+    """
+
+    def __init__(self, *args: DFPattern):
+        self.__init_handle_by_constructor__(ffi.SameShapeConstraint, args)  # type: ignore
+
+
 @register_df_node
 class PrimArrPattern(DFPattern):
     """
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 41647e261d..ab2ad4fa36 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -22,6 +22,7 @@
  * \brief The dataflow pattern matcher for Relax.
  */
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/node/structural_equal.h>
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/dataflow_matcher.h>
@@ -443,6 +444,69 @@ bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& e
   return false;
 }
 
+std::tuple<PrimExpr, bool> SameShapeConstraintNode::AsPrimExpr(
+    std::function<Optional<Var>(const DFPatternNode*)> match_state) const {
+  Optional<Array<PrimExpr>> expected_shape;
+  bool all_shapes_defined = true;
+
+  // The expression that must be true in order for all matched shapes to be equal.
+  PrimExpr all_dimensions_equal = Bool(true);
+
+  for (const auto& arg : args) {
+    if (auto opt_var = match_state(arg.get())) {
+      auto var = opt_var.value();
+      auto opt_var_shape = [&]() -> Optional<Array<PrimExpr>> {
+        auto sinfo = GetStructInfo(var);
+        if (auto tensor = sinfo.as<TensorStructInfoNode>()) {
+          return tensor->GetShape();
+        } else if (auto shape_expr = sinfo.as<ShapeStructInfoNode>()) {
+          return shape_expr->values;
+        } else {
+          return NullOpt;
+        }
+      }();
+
+      if (!opt_var_shape.defined()) {
+        // The pattern has matched to something without a known shape.
+        // Therefore, it cannot be shown to have the same shape as anything else.
+        return {PrimExpr(Bool(false)), true};
+      }
+      auto var_shape = opt_var_shape.value();
+
+      if (expected_shape.defined()) {
+        auto prev_shape = expected_shape.value();
+        if (prev_shape.size() == var_shape.size()) {
+          // The dimensionalities match, so build up the expression
+          // that must be true for the shapes to be equivalent.
+          for (size_t i = 0; i < prev_shape.size(); i++) {
+            all_dimensions_equal = all_dimensions_equal && (var_shape[i] == prev_shape[i]);
+          }
+
+        } else {
+          // The shapes have different dimensionality.  No need to
+          // perform potentially-expensive simplifications, because
+          // the dimensions do not match.
+          return {PrimExpr(Bool(false)), true};
+        }
+
+      } else {
+        // This is the first pattern with a known match.  Store the
+        // shape so it can be compared against later shapes.
+        expected_shape = var_shape;
+      }
+
+    } else {
+      // Missing an argument, so the returned condition can only be
+      // necessary, never sufficient.  Keep checking the remaining
+      // arguments instead of returning early, because a definite
+      // "false" from the shapes already matched is more useful.
+      all_shapes_defined = false;
+    }
+  }
+
+  return {all_dimensions_equal, all_shapes_defined};
+}
+
 bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const 
Expr& expr0) {
   auto expr = TryGetValOfVar(expr0, var2val_);
   if (const ShapeExprNode* shape_expr = expr.as<ShapeExprNode>())
@@ -579,9 +643,12 @@ struct MatchState {
     match_r_p[r] = p;
   }
 
+  void add(const DFConstraintNode* constraint) { validated_constraints_.insert(constraint); }
+
   void add(MatchState&& other) {
     match_p_r.merge(std::move(other.match_p_r));
     match_r_p.merge(std::move(other.match_r_p));
+    validated_constraints_.merge(std::move(other.validated_constraints_));
   }
 
   const VarNode* matched(const PNode* p) const {
@@ -601,9 +668,14 @@ struct MatchState {
   const VarNode* matched(const PNode& p) const { return matched(&p); }
   const DFPatternNode* matched(const RNode& r) const { return matched(&r); }
 
+  bool is_validated(const DFConstraintNode* constraint) const {
+    return validated_constraints_.count(constraint);  // recorded via add(), directly or by merge
+  }
+
  private:
   std::unordered_map<const PNode*, const RNode*> match_p_r;
   std::unordered_map<const RNode*, const PNode*> match_r_p;
+  std::unordered_set<const DFConstraintNode*> validated_constraints_;
 };
 
 /**
@@ -663,11 +735,64 @@ static std::optional<MatchState> TryMatch(const PNode& p, const RNode& r,
   return new_match;
 }
 
+static std::optional<MatchState> TryValidate(
+    const MatchState& current_match,
+    const std::unordered_map<const DFPatternNode*, PNode>& pattern2node,
+    const std::vector<DFConstraint>& validation_constraints, arith::Analyzer* analyzer) {
+  MatchState new_match;
+
+  std::function<Optional<Var>(const DFPatternNode*)> query_match_state =
+      [&pattern2node, &current_match](const DFPatternNode* pattern) -> Optional<Var> {
+    auto it = pattern2node.find(pattern);
+    ICHECK(it != pattern2node.end())
+        << "DFConstraint attempted to access DFPattern " << GetRef<DFPattern>(pattern)
+        << ", which does not appear in the PatternContext";
+    const auto& p_node = it->second;
+    if (auto ptr = current_match.matched(p_node)) {
+      return GetRef<Var>(ptr);
+    } else {
+      return NullOpt;
+    }
+  };
+
+  for (const auto& constraint : validation_constraints) {
+    if (!current_match.is_validated(constraint.get())) {
+      auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state);
+
+      necessary_condition = analyzer->Simplify(necessary_condition);
+      const auto* known = tir::as_const_int(necessary_condition);
+
+      if (known && *known && is_sufficient) {
+        // The condition passes, and the expression provided is both
+        // necessary and sufficient for the constraint to pass.  Mark
+        // the constraint as passing, to avoid re-checking it unless
+        // we backtrack.
+        new_match.add(constraint.get());
+      } else if (known && !*known) {
+        // The condition fails.  Even if additional information would
+        // be required to pass a constraint, it may bail out early as
+        // a failure (e.g. shape mismatch in the first two items out
+        // of N shapes that must all match).
+        return std::nullopt;
+      } else if (is_sufficient) {
+        // The condition is necessary and sufficient, but depends on
+        // dynamic parameters that could not be proven equal.  In the
+        // future, this may be exposed to the user as a condition for
+        // optimization, or combined with conditions from other constraints.
+        return std::nullopt;
+      }
+    }
+  }
+
+  return new_match;
+}
+
 static std::optional<MatchState> MatchTree(
     const MatchState& current_match, size_t current_root_idx,
     const std::unordered_map<const DFPatternNode*, PNode>& pattern2node,
     const std::unordered_map<const VarNode*, RNode>& var2node, 
DFPatternMatcher* matcher,
-    const std::vector<DFPattern>& roots, const MatcherUseDefAnalysis& 
ud_analysis) {
+    const std::vector<DFPattern>& roots, const std::vector<DFConstraint>& 
validation_constraints,
+    const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) {
   auto get_next_root = [&](size_t root_idx) -> const PNode* {
     // Look for the next unmatched root node.
     for (; root_idx < roots.size(); ++root_idx) {
@@ -692,12 +817,17 @@ static std::optional<MatchState> MatchTree(
     const RNode& r_node = var2node.at(var);
     if (new_match.matched(r_node)) continue;
     if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) 
{
-      // Recursivly try to match the next subtree.
+      // Recursively try to match the next subtree.
       new_match.add(std::move(*match));
-      if (auto match_rec = MatchTree(new_match, current_root_idx + 1, 
pattern2node, var2node,
-                                     matcher, roots, ud_analysis)) {
-        new_match.add(std::move(*match_rec));
-        return new_match;
+      if (auto validation =
+              TryValidate(new_match, pattern2node, validation_constraints, 
analyzer)) {
+        new_match.add(std::move(*validation));
+        if (auto match_rec =
+                MatchTree(new_match, current_root_idx + 1, pattern2node, 
var2node, matcher, roots,
+                          validation_constraints, ud_analysis, analyzer)) {
+          new_match.add(std::move(*match_rec));
+          return new_match;
+        }
       }
       // Recursive matching has failed, backtrack.
       new_match = current_match;
@@ -734,11 +864,11 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl
   }
 
   std::unordered_map<const DFPatternNode*, PNode> pattern2node;
-  pattern2node.reserve(ctx->constraints.size());
+  pattern2node.reserve(ctx->edge_constraints.size());
 
   for (const auto& def_pattern : ctx->src_ordered) {
     PNode& def_node = pattern2node[def_pattern.get()];
-    const auto& uses = ctx->constraints.at(def_pattern);
+    const auto& uses = ctx->edge_constraints.at(def_pattern);
     def_node.ptr = def_pattern.get();
     def_node.children.reserve(uses.size());
     for (const auto& [use_pattern, cons] : uses) {
@@ -760,16 +890,19 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl
     return NullOpt;
   }
 
-  if (auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, 
ud_analysis)) {
-    Map<DFPattern, Var> ret;
-    for (const auto& [pat, p_node] : pattern2node) {
-      ICHECK(match->matched(p_node));
-      ret.Set(GetRef<DFPattern>(pat), GetRef<Var>(match->matched(p_node)));
-    }
-    return ret;
+  arith::Analyzer analyzer;
+  auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots,
+                         ctx->validation_constraints, ud_analysis, &analyzer);
+  if (!match) {
+    return NullOpt;
   }
 
-  return NullOpt;
+  Map<DFPattern, Var> ret;
+  for (const auto& [pat, p_node] : pattern2node) {
+    ICHECK(match->matched(p_node));
+    ret.Set(GetRef<DFPattern>(pat), GetRef<Var>(match->matched(p_node)));
+  }
+  return ret;
 }
 
 Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const 
DataflowBlock& dfb) {
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index cd1376303c..faa890a12c 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -276,6 +276,34 @@ RELAX_PATTERN_PRINTER_DEF(ShapePatternNode, [](auto p, auto node) {
   p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")";
 });
 
+TVM_REGISTER_NODE_TYPE(SameShapeConstraintNode);
+SameShapeConstraint::SameShapeConstraint(Array<DFPattern> args) {
+  ObjectPtr<SameShapeConstraintNode> n = make_object<SameShapeConstraintNode>();
+  n->args = std::move(args);
+  data_ = std::move(n);
+
+  // If a PatternContext is current, register this constraint with it.
+  if (auto ctx = PatternContext::Current()) {
+    ctx.value().add_constraint(*this);
+  }
+}
+SameShapeConstraint DFPattern::HasSameShapeAs(const DFPattern& other) const {
+  return SameShapeConstraint(Array<DFPattern>{*this, other});
+}
+TVM_REGISTER_GLOBAL("relax.dpl.SameShapeConstraint").set_body_typed([](Array<DFPattern> args) {
+  return SameShapeConstraint(args);
+});
+RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) {
+  p->stream << "SameShapeConstraint(";
+  for (size_t i = 0; i < node->args.size(); i++) {
+    if (i) {
+      p->stream << ", ";
+    }
+    p->stream << node->args[i];
+  }
+  p->stream << ")";
+});
+
 TVM_REGISTER_NODE_TYPE(DataTypePatternNode);
 DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) {
   ObjectPtr<DataTypePatternNode> n = make_object<DataTypePatternNode>();
@@ -405,7 +429,9 @@ PatternContext::PatternContext(bool incremental) {
     ICHECK(!pattern_ctx_stack().empty())
         << "Incremental context needs to be built inside a existing context.";
     n->allow_extern_use = pattern_ctx_stack().top()->allow_extern_use;
-    n->constraints = pattern_ctx_stack().top()->constraints;
+    n->edge_constraints = pattern_ctx_stack().top()->edge_constraints;
+    // Inherit non-edge constraints too, so they remain enforced in the incremental context.
+    n->validation_constraints = pattern_ctx_stack().top()->validation_constraints;
     n->src_ordered = pattern_ctx_stack().top()->src_ordered;
   }
 
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 3444eff79b..7b68655a2f 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1407,5 +1407,88 @@ def test_rewrite_without_trivial_binding():
     tvm.ir.assert_structural_equal(after, expected)
 
 
+same_shape_func_type = tvm.testing.parameter(
+    "same_static_shape",
+    "same_dynamic_shape",
+    "different_static_shape",
+    "different_dynamic_shape",
+)
+
+
+def test_same_shape_pattern(same_shape_func_type):
+    if same_shape_func_type == "same_static_shape":
+
+        @R.function(private=True)
+        def func(
+            a: R.Tensor((1024, 128), "float32"),
+            b: R.Tensor((1024, 128), "float32"),
+        ) -> R.Tensor:
+            with R.dataflow():
+                c = R.multiply(a, R.const(2.0))
+                d = R.add(b, c)
+                out = d
+                R.output(out)
+            return out
+
+    elif same_shape_func_type == "same_dynamic_shape":
+
+        @R.function(private=True)
+        def func(
+            a: R.Tensor(("n", 128), "float32"),
+            b: R.Tensor(("n", 128), "float32"),
+        ) -> R.Tensor:
+            with R.dataflow():
+                c = R.multiply(a, R.const(2.0))
+                d = R.add(b, c)
+                out = d
+                R.output(out)
+            return out
+
+    elif same_shape_func_type == "different_static_shape":
+
+        @R.function(private=True)
+        def func(
+            a: R.Tensor((1024, 128), "float32"),
+            b: R.Tensor((1, 128), "float32"),
+        ) -> R.Tensor:
+            with R.dataflow():
+                c = R.multiply(a, R.const(2.0))
+                d = R.add(b, c)
+                out = d
+                R.output(out)
+            return out
+
+    elif same_shape_func_type == "different_dynamic_shape":
+
+        @R.function(private=True)
+        def func(
+            a: R.Tensor(("n", 128), "float32"),
+            b: R.Tensor(("m", 128), "float32"),
+        ) -> R.Tensor:
+            with R.dataflow():
+                c = R.multiply(a, R.const(2.0))
+                d = R.add(b, c)
+                out = d
+                R.output(out)
+            return out
+
+    else:
+        raise ValueError(f"Unknown value of same_shape_func_type={same_shape_func_type}")
+
+    with PatternContext() as ctx:
+        pat_lhs = wildcard()
+        pat_rhs = wildcard()
+        pat_sum = is_op("relax.add")(pat_lhs, pat_rhs)
+        pat_lhs.same_shape_as(pat_rhs)
+
+    block = func.body.blocks[0]
+    match = ctx.match_dfb(block)
+
+    if "same" in same_shape_func_type:
+        assert match
+    else:
+        assert match is None
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to