This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ab01abc  [BugFix][Relay] Fix type relation for batch_matmul (#8376)
ab01abc is described below

commit ab01abc22460c6d8dcd6ed75e77b7224364014d8
Author: Tianqi Zhang (张天启) <[email protected]>
AuthorDate: Fri Jul 2 14:33:06 2021 +0800

    [BugFix][Relay] Fix type relation for batch_matmul (#8376)
    
    * fix type relation for batch_matmul
    
    * fix lint
---
 src/relay/op/nn/nn.cc                 |  3 +--
 tests/python/relay/test_op_level10.py | 37 ++++++++++++++++++++++++++++-------
 2 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 4eaa12b..d09a849 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -974,9 +974,8 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
     ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2]))
         << "BatchDot: shapes of x and y is inconsistent, "
         << " x shape=" << x->shape << ", y shape=" << y_shape;
-
-    oshape.Set(2, y_shape[1]);
   }
+  oshape.Set(2, y_shape[1]);
 
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 24f0ed6..71598e6 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -359,9 +359,11 @@ def test_batch_matmul():
     verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
 
 
-def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
-    x = relay.var("x", relay.TensorType(x_shape, dtype))
-    y = relay.var("y", relay.TensorType((relay.Any(),) * len(y_shape), dtype))
+def verify_dynamic_batch_matmul(
+    x_shape, y_shape, out_shape, x_var_shape, y_var_shape, dtype="float32"
+):
+    x = relay.var("x", relay.TensorType(x_var_shape, dtype))
+    y = relay.var("y", relay.TensorType(y_var_shape, dtype))
     z = relay.nn.batch_matmul(x, y)
 
     func = relay.Function([x, y], z)
@@ -380,10 +382,31 @@ def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
 # TODO(mbrookhart): enable once VM supports heterogenous execution
 # @tvm.testing.uses_gpu
 def test_dynamic_batch_matmul():
-    verify_dynamic_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16))
-    verify_dynamic_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16))
-    verify_dynamic_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
-    verify_dynamic_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
+    verify_dynamic_batch_matmul(
+        (1, 16, 32), (1, 16, 32), (1, 16, 16), (1, 16, 32), (relay.Any(),) * 3
+    )
+    verify_dynamic_batch_matmul(
+        (5, 16, 32), (5, 16, 32), (5, 16, 16), (5, 16, 32), (relay.Any(),) * 3
+    )
+    verify_dynamic_batch_matmul(
+        (5, 16, 32), (5, 20, 32), (5, 16, 20), (5, 16, 32), (relay.Any(),) * 3
+    )
+    verify_dynamic_batch_matmul(
+        (30, 16, 32), (30, 20, 32), (30, 16, 20), (30, 16, 32), (relay.Any(),) * 3
+    )
+
+    verify_dynamic_batch_matmul(
+        (1, 16, 32), (1, 16, 32), (1, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32)
+    )
+    verify_dynamic_batch_matmul(
+        (5, 16, 32), (5, 16, 32), (5, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32)
+    )
+    verify_dynamic_batch_matmul(
+        (5, 16, 32), (5, 20, 32), (5, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32)
+    )
+    verify_dynamic_batch_matmul(
+        (30, 16, 32), (30, 20, 32), (30, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32)
+    )
 
 
 @tvm.testing.uses_gpu

Reply via email to