This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e86a470ce0 [Relay] Enhance type infer for dynamic shape (#14601)
e86a470ce0 is described below
commit e86a470ce091aeca2908d354363f650766a5c0f6
Author: shengxinhu <[email protected]>
AuthorDate: Mon Apr 17 10:44:04 2023 +0800
[Relay] Enhance type infer for dynamic shape (#14601)
* [Relay] Enhance type infer for dynamic shape
Support type inference so that PrimExprs such as
tir.IndexMod(relay.Any(), 5) can be unified
* fix a bug
* fix lint
---
src/relay/analysis/type_solver.cc | 21 ++++++++++++++++++++-
src/relay/analysis/type_solver.h | 1 +
tests/python/relay/test_type_infer.py | 8 ++++++++
3 files changed, 29 insertions(+), 1 deletion(-)
diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc
index b6af977071..79b340390b 100644
--- a/src/relay/analysis/type_solver.cc
+++ b/src/relay/analysis/type_solver.cc
@@ -25,6 +25,7 @@
#include <tvm/ir/type_functor.h>
#include <tvm/node/structural_equal.h>
+#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>
#include <memory>
@@ -76,6 +77,19 @@ class TypeSolver::Reporter : public TypeReporterNode {
TypeSolver* solver_;
};
+class TypeSolver::AnyChecker : public tir::ExprVisitor {
+ public:
+ void VisitExpr_(const AnyNode* op) final { found_ = true; }
+
+ bool Check(const PrimExpr& expr) {
+ tir::ExprVisitor::VisitExpr(expr);
+ return found_;
+ }
+
+ private:
+ bool found_{false};
+};
+
class TypeSolver::OccursChecker : public TypeVisitor {
public:
explicit OccursChecker(TypeSolver* solver, TypeNode* var)
@@ -146,6 +160,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
}
}
+ bool HasAny(const PrimExpr& expr) {
+ AnyChecker ac;
+ return ac.Check(expr);
+ }
+
// Checks whether lhs (taken to be a type var) occurs in t, meaning
// there is a recursive equality constraint, which should be rejected.
// N.b.: A tautology like ?a = ?a is okay and should be checked for
@@ -186,7 +205,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
if (ulhs.same_as(urhs)) {
return ulhs;
}
- if (ulhs.as<AnyNode>() || urhs.as<AnyNode>()) {
+ if (HasAny(ulhs) || HasAny(urhs)) {
return Any();
}
diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h
index 7940e347b3..5d32afab64 100644
--- a/src/relay/analysis/type_solver.h
+++ b/src/relay/analysis/type_solver.h
@@ -97,6 +97,7 @@ class TypeSolver {
void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); }
private:
+ class AnyChecker;
class OccursChecker;
class Unifier;
class Resolver;
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index 1874555702..7fbb656b36 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -417,6 +417,14 @@ def test_dynamic_function():
mod = transform.InferType()(mod)
assert mod["main"].params[0].checked_type == s_tt
+ data = relay.var(
+ "data", shape=(relay.Any(), relay.Any(), relay.Any(), relay.Any()), dtype="float32"
+ )
+ weigth = relay.const(np.full((16, 16, 3, 3), 0.25), dtype="float32")
+ x = relay.nn.conv2d(data, weigth, kernel_size=(3, 3), channels=16, groups=2)
+ mod = tvm.IRModule.from_expr(x)
+ mod = transform.InferType()(mod)
+
def test_custom_op_infer():
"""Tests infer type for custom_op"""