This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 41879b2  [FIX] Bug fix for batch_matmul parameters mismatch (#8785)
41879b2 is described below

commit 41879b2552364f094492470a77a3ec0866b30eae
Author: Chenfan <[email protected]>
AuthorDate: Thu Aug 19 11:15:34 2021 +0800

    [FIX] Bug fix for batch_matmul parameters mismatch (#8785)
---
 python/tvm/topi/cuda/batch_matmul.py            | 13 ++++++++++++-
 python/tvm/topi/cuda/batch_matmul_tensorcore.py |  9 +++++++--
 python/tvm/topi/rocm/batch_matmul.py            |  7 +++++--
 3 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py
index 3fc8a58..bd556d2 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -237,7 +237,9 @@ def schedule_batch_matmul_cublas(_, outs):
 
 
 @autotvm.register_topi_compute("batch_matmul_int8.cuda")
-def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
+def batch_matmul_int8(
+    cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
     """Batch Matmul operator for int8 on CUDA.
 
     Parameters
@@ -258,11 +260,20 @@ def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None):
     out_dtype : Optional[str]
         Specifies the output data type for mixed precision batch matmul.
 
+    transpose_a : Optional[bool] = False
+        Whether the first tensor is in transposed format.
+
+    transpose_b : Optional[bool] = True
+        Whether the second tensor is in transposed format.
+
     Returns
     -------
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
+    del out_shape
+    # TODO(jcf94): Deal with different transpose combinations
+    assert not transpose_a and transpose_b
     if out_dtype is None:
         out_dtype = x.dtype
 
diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
index a56d3c3..5324302 100644
--- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -29,9 +29,14 @@ from .tensor_intrin import (
 
 
 @autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
-def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None):
+def batch_matmul_tensorcore(
+    cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
     """batch matmul tensorcore operator on cuda"""
-    # todo: deal with out_shape for broadcast, liuxin.ai
+    # TODO(jcf94): Deal with different transpose combinations
+    assert not transpose_a and transpose_b
+    # TODO(liuxin.ai): Deal with out_shape for broadcast
+    del out_shape
     return batch_matmul_tensorcore_cuda(x, y, out_dtype)
 
 
diff --git a/python/tvm/topi/rocm/batch_matmul.py b/python/tvm/topi/rocm/batch_matmul.py
index 7f35f4b..53b51ee 100644
--- a/python/tvm/topi/rocm/batch_matmul.py
+++ b/python/tvm/topi/rocm/batch_matmul.py
@@ -23,7 +23,9 @@ from ..utils import get_const_tuple
 
 
 @autotvm.register_topi_compute("batch_matmul_rocblas.rocm")
-def batch_matmul_rocblas(cfg, x, y, out_shape=None):
+def batch_matmul_rocblas(
+    cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
     """Computes matrix multiplication of `x` and `y` via rocblas when
     `x` and `y` are batched matrices.
 
@@ -40,12 +42,13 @@ def batch_matmul_rocblas(cfg, x, y, out_shape=None):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
+    del out_dtype
     batch, M, K = get_const_tuple(x.shape)
     _, N, _ = get_const_tuple(y.shape)
     if out_shape is not None:
         assert out_shape[0] == batch, "Input and output batch sizes must match"
         assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape"
-    result = rocblas.batch_matmul(x, y, False, True)
+    result = rocblas.batch_matmul(x, y, transpose_a, transpose_b)
     cfg.add_flop(batch * M * N * K * 2)
     return result
 

Reply via email to