This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ab01abc [BugFix][Relay] Fix type relation for batch_matmul (#8376)
ab01abc is described below
commit ab01abc22460c6d8dcd6ed75e77b7224364014d8
Author: Tianqi Zhang (张天启) <[email protected]>
AuthorDate: Fri Jul 2 14:33:06 2021 +0800
[BugFix][Relay] Fix type relation for batch_matmul (#8376)
* fix type relation for batch_matmul
* fix lint
---
src/relay/op/nn/nn.cc | 3 +--
tests/python/relay/test_op_level10.py | 37 ++++++++++++++++++++++++++++-------
2 files changed, 31 insertions(+), 9 deletions(-)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 4eaa12b..d09a849 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -974,9 +974,8 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
<< " x shape=" << x->shape << ", y shape=" << y_shape;
-
- oshape.Set(2, y_shape[1]);
}
+ oshape.Set(2, y_shape[1]);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 24f0ed6..71598e6 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -359,9 +359,11 @@ def test_batch_matmul():
verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
-def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
- x = relay.var("x", relay.TensorType(x_shape, dtype))
- y = relay.var("y", relay.TensorType((relay.Any(),) * len(y_shape), dtype))
+def verify_dynamic_batch_matmul(
+ x_shape, y_shape, out_shape, x_var_shape, y_var_shape, dtype="float32"
+):
+ x = relay.var("x", relay.TensorType(x_var_shape, dtype))
+ y = relay.var("y", relay.TensorType(y_var_shape, dtype))
z = relay.nn.batch_matmul(x, y)
func = relay.Function([x, y], z)
@@ -380,10 +382,31 @@ def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
def test_dynamic_batch_matmul():
- verify_dynamic_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16))
- verify_dynamic_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16))
- verify_dynamic_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
- verify_dynamic_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
+ verify_dynamic_batch_matmul(
+ (1, 16, 32), (1, 16, 32), (1, 16, 16), (1, 16, 32), (relay.Any(),) * 3
+ )
+ verify_dynamic_batch_matmul(
+ (5, 16, 32), (5, 16, 32), (5, 16, 16), (5, 16, 32), (relay.Any(),) * 3
+ )
+ verify_dynamic_batch_matmul(
+ (5, 16, 32), (5, 20, 32), (5, 16, 20), (5, 16, 32), (relay.Any(),) * 3
+ )
+ verify_dynamic_batch_matmul(
+ (30, 16, 32), (30, 20, 32), (30, 16, 20), (30, 16, 32), (relay.Any(),)
* 3
+ )
+
+ verify_dynamic_batch_matmul(
+ (1, 16, 32), (1, 16, 32), (1, 16, 16), (relay.Any(), 16, 32),
(relay.Any(), 16, 32)
+ )
+ verify_dynamic_batch_matmul(
+ (5, 16, 32), (5, 16, 32), (5, 16, 16), (relay.Any(), 16, 32),
(relay.Any(), 16, 32)
+ )
+ verify_dynamic_batch_matmul(
+ (5, 16, 32), (5, 20, 32), (5, 16, 20), (relay.Any(), 16, 32),
(relay.Any(), 20, 32)
+ )
+ verify_dynamic_batch_matmul(
+ (30, 16, 32), (30, 20, 32), (30, 16, 20), (relay.Any(), 16, 32),
(relay.Any(), 20, 32)
+ )
@tvm.testing.uses_gpu