This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e86a470ce0 [Relay] Enhance type infer for dynamic shape (#14601)
e86a470ce0 is described below

commit e86a470ce091aeca2908d354363f650766a5c0f6
Author: shengxinhu <[email protected]>
AuthorDate: Mon Apr 17 10:44:04 2023 +0800

    [Relay] Enhance type infer for dynamic shape (#14601)
    
    * [Relay] Enhance type infer for dynamic shape
    
    Let type inference unify PrimExprs that contain relay.Any() anywhere
    in the expression tree, such as tir.IndexMod(relay.Any(), 5).
    
    * fix a bug
    
    * fix lint
---
 src/relay/analysis/type_solver.cc     | 21 ++++++++++++++++++++-
 src/relay/analysis/type_solver.h      |  1 +
 tests/python/relay/test_type_infer.py |  8 ++++++++
 3 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc
index b6af977071..79b340390b 100644
--- a/src/relay/analysis/type_solver.cc
+++ b/src/relay/analysis/type_solver.cc
@@ -25,6 +25,7 @@
 
 #include <tvm/ir/type_functor.h>
 #include <tvm/node/structural_equal.h>
+#include <tvm/tir/expr_functor.h>
 #include <tvm/tir/op.h>
 
 #include <memory>
@@ -76,6 +77,30 @@ class TypeSolver::Reporter : public TypeReporterNode {
   TypeSolver* solver_;
 };
 
+/*!
+ * \brief Detects whether a PrimExpr contains an Any node at any depth,
+ *  e.g. indexmod(Any(), 5), not only when the expression itself is Any().
+ */
+class TypeSolver::AnyChecker : public tir::ExprVisitor {
+ public:
+  void VisitExpr(const PrimExpr& expr) final {
+    // Stop descending once an Any has been found; the answer cannot change.
+    if (!found_) tir::ExprVisitor::VisitExpr(expr);
+  }
+
+  void VisitExpr_(const AnyNode* op) final { found_ = true; }
+
+  /*! \return true iff \p expr contains an Any node. Safe to call repeatedly. */
+  bool Check(const PrimExpr& expr) {
+    found_ = false;
+    VisitExpr(expr);
+    return found_;
+  }
+
+ private:
+  bool found_{false};
+};
+
 class TypeSolver::OccursChecker : public TypeVisitor {
  public:
   explicit OccursChecker(TypeSolver* solver, TypeNode* var)
@@ -146,6 +160,12 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     }
   }
 
+  // Returns true if `expr` contains an Any node anywhere, not only at the root.
+  static bool HasAny(const PrimExpr& expr) {
+    AnyChecker ac;
+    return ac.Check(expr);
+  }
+
   // Checks whether lhs (taken to be a type var) occurs in t, meaning
   // there is a recursive equality constraint, which should be rejected.
   // N.b.: A tautology like ?a = ?a is okay and should be checked for
@@ -186,7 +205,9 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     if (ulhs.same_as(urhs)) {
       return ulhs;
     }
-    if (ulhs.as<AnyNode>() || urhs.as<AnyNode>()) {
+    // A dimension that is, or is computed from, Any (e.g. indexmod(Any(), 5))
+    // is dynamic, so the unified dimension is Any as well.
+    if (HasAny(ulhs) || HasAny(urhs)) {
       return Any();
     }
 
diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h
index 7940e347b3..5d32afab64 100644
--- a/src/relay/analysis/type_solver.h
+++ b/src/relay/analysis/type_solver.h
@@ -97,6 +97,7 @@ class TypeSolver {
   void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); }
 
  private:
+  class AnyChecker;  // Finds Any nodes nested anywhere inside a PrimExpr.
   class OccursChecker;
   class Unifier;
   class Resolver;
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index 1874555702..7fbb656b36 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -417,6 +417,17 @@ def test_dynamic_function():
     mod = transform.InferType()(mod)
     assert mod["main"].params[0].checked_type == s_tt
 
+    # Grouped conv2d on an all-Any input must type-check (#14601).
+    data = relay.var(
+        "data", shape=(relay.Any(), relay.Any(), relay.Any(), relay.Any()), dtype="float32"
+    )
+    weight = relay.const(np.full((16, 16, 3, 3), 0.25), dtype="float32")
+    x = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=16, groups=2)
+    mod = tvm.IRModule.from_expr(x)
+    mod = transform.InferType()(mod)
+    out_shape = mod["main"].body.checked_type.shape
+    assert int(out_shape[1]) == 16
+
 
 def test_custom_op_infer():
     """Tests infer type for custom_op"""

Reply via email to