This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 22ad0dd Add rocblas_sgemm_strided_batched impl. (#6579)
22ad0dd is described below
commit 22ad0dd4cc21f44a605ec2552945bd5f068a6959
Author: Chris Sullivan <[email protected]>
AuthorDate: Tue Sep 29 02:59:54 2020 -0700
Add rocblas_sgemm_strided_batched impl. (#6579)
---
include/tvm/topi/contrib/rocblas.h | 23 ++++++++++++
python/tvm/contrib/rocblas.py | 33 +++++++++++++++++
python/tvm/relay/op/strategy/rocm.py | 21 +++++++++++
python/tvm/topi/rocm/__init__.py | 1 +
python/tvm/topi/rocm/batch_matmul.py | 56 +++++++++++++++++++++++++++++
src/runtime/contrib/rocblas/rocblas.cc | 59 +++++++++++++++++++++++++-----
tests/python/contrib/test_rocblas.py | 66 ++++++++++++++++++++++++++++++++--
7 files changed, 249 insertions(+), 10 deletions(-)
diff --git a/include/tvm/topi/contrib/rocblas.h
b/include/tvm/topi/contrib/rocblas.h
index a4fa26f..4f0b887 100644
--- a/include/tvm/topi/contrib/rocblas.h
+++ b/include/tvm/topi/contrib/rocblas.h
@@ -54,6 +54,30 @@ inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa,
},
"C", "", {})[0];
}
+/*!
+ * \brief Create an op that multiplies batches of matrices lhs and rhs with rocBLAS
+ *
+ * \param lhs The left batched operand, (batch_size, M, K), or (batch_size, K, M) if transa
+ * \param rhs The right batched operand, (batch_size, K, N), or (batch_size, N, K) if transb
+ * \param transa Whether to transpose each matrix of lhs
+ * \param transb Whether to transpose each matrix of rhs
+ *
+ * \return The output tensor of shape (batch_size, M, N)
+ */
+inline Tensor rocblas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa,
+                                   bool transb) {
+  // Output is (batch, rows of op(lhs), cols of op(rhs)); batch is taken from lhs.
+  auto batch = lhs->shape[0];
+  auto rows = transa ? lhs->shape[2] : lhs->shape[1];
+  auto cols = transb ? rhs->shape[1] : rhs->shape[2];
+
+  // Lower to a packed call into the rocBLAS runtime with arguments (A, B, C, transa, transb).
+  auto fextern = [&](Array<Buffer> ins, Array<Buffer> outs) {
+    return call_packed({StringImm("tvm.contrib.rocblas.batch_matmul"), pack_buffer(ins[0]),
+                        pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
+  };
+  return make_extern({{batch, rows, cols}}, {lhs->dtype}, {lhs, rhs}, fextern, "C", "", {})[0];
+}
} // namespace contrib
} // namespace topi
diff --git a/python/tvm/contrib/rocblas.py b/python/tvm/contrib/rocblas.py
index 03ea2b5..70791dc 100644
--- a/python/tvm/contrib/rocblas.py
+++ b/python/tvm/contrib/rocblas.py
@@ -48,3 +48,36 @@ def matmul(lhs, rhs, transa=False, transb=False):
),
name="C",
)
+
+
+def batch_matmul(lhs, rhs, transa=False, transb=False):
+ """Create an extern op that compute matrix mult of A and rhs with rocBLAS
+
+ Parameters
+ ----------
+ lhs : Tensor
+ The left batched matrix operand
+ rhs : Tensor
+ The right batched matrix operand
+ transa : bool
+ Whether transpose lhs
+ transb : bool
+ Whether transpose rhs
+
+ Returns
+ -------
+ C : Tensor
+ The result tensor.
+ """
+ batch_size = lhs.shape[0]
+ assert batch_size == rhs.shape[0]
+ n = lhs.shape[2] if transa else lhs.shape[1]
+ m = rhs.shape[1] if transb else rhs.shape[2]
+ return te.extern(
+ (batch_size, n, m),
+ [lhs, rhs],
+ lambda ins, outs: tvm.tir.call_packed(
+ "tvm.contrib.rocblas.batch_matmul", ins[0], ins[1], outs[0],
transa, transb
+ ),
+ name="C",
+ )
diff --git a/python/tvm/relay/op/strategy/rocm.py
b/python/tvm/relay/op/strategy/rocm.py
index 2410260..f52bbc3 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -160,3 +160,24 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
plevel=15,
)
return strategy
+
+
+@batch_matmul_strategy.register("rocm")
+def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
+ """Batch matmul strategy for ROCM"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_batch_matmul(topi.cuda.batch_matmul),
+ wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
+ name="batch_matmul.cuda",
+ plevel=10,
+ )
+ if target.kind.name == "rocm" and "rocblas" in target.libs:
+ assert out_type.dtype == inputs[0].dtype, "Mixed precision not
supported."
+ strategy.add_implementation(
+ wrap_compute_batch_matmul(topi.rocm.batch_matmul_rocblas),
+ wrap_topi_schedule(topi.rocm.schedule_batch_matmul_rocblas),
+ name="batch_matmul_rocblas.rocm",
+ plevel=12,
+ )
+ return strategy
diff --git a/python/tvm/topi/rocm/__init__.py b/python/tvm/topi/rocm/__init__.py
index 4efdab4..1ea4c79 100644
--- a/python/tvm/topi/rocm/__init__.py
+++ b/python/tvm/topi/rocm/__init__.py
@@ -19,6 +19,7 @@
"""rocm specific declaration and schedules."""
from __future__ import absolute_import as _abs
+from .batch_matmul import *
from .conv2d import *
from .dense import *
from .nn import *
diff --git a/python/tvm/topi/rocm/batch_matmul.py
b/python/tvm/topi/rocm/batch_matmul.py
new file mode 100644
index 0000000..fa4dd45
--- /dev/null
+++ b/python/tvm/topi/rocm/batch_matmul.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, unused-argument
+"""Schedule for batch_matmul operator"""
+from tvm import autotvm
+from tvm.contrib import rocblas
+from .. import generic
+from ..util import get_const_tuple
+
+
[email protected]_topi_compute("batch_matmul_rocblas.rocm")
+def batch_matmul_rocblas(cfg, x, y, out_shape=None):
+ """Computes matrix multiplication of `x` and `y` via rocblas when
+ `x` and `y` are batched matrices.
+
+ Parameters
+ ----------
+ cfg : ConfigSpace
+ Autotvm tuning space config file
+ x : tvm.te.Tensor
+ 3-D with shape [batch, M, K]
+ y : tvm.te.Tensor
+ 3-D with shape [batch, N, K]
+ Returns
+ -------
+ output : tvm.te.Tensor
+ 3-D with shape [batch, M, N]
+ """
+ batch, M, K = get_const_tuple(x.shape)
+ _, N, _ = get_const_tuple(y.shape)
+ if out_shape is not None:
+ assert out_shape[0] == batch, "Input and output batch sizes must match"
+ assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape"
+ result = rocblas.batch_matmul(x, y, False, True)
+ cfg.add_flop(batch * M * N * K * 2)
+ return result
+
+
[email protected]_topi_schedule("batch_matmul_rocblas.rocm")
+def schedule_batch_matmul_rocblas(_, outs):
+ """Schedule for batch_matmul operator with rocm cblas"""
+ return generic.schedule_extern(outs)
diff --git a/src/runtime/contrib/rocblas/rocblas.cc
b/src/runtime/contrib/rocblas/rocblas.cc
index 0e6f4bd..bca00a5 100644
--- a/src/runtime/contrib/rocblas/rocblas.cc
+++ b/src/runtime/contrib/rocblas/rocblas.cc
@@ -70,18 +70,73 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul").set_body([](TVMArgs args, TVMR
CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
float alpha = 1.0;
float beta = 0.0;
-  float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
-  float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+  float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+  float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
float* C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
-  CHECK_ROCBLAS_ERROR(
-      rocblas_sgemm(handle, transb ? rocblas_operation_transpose : rocblas_operation_none,
-                    transa ? rocblas_operation_transpose : rocblas_operation_none,
-                    transb ? B->shape[0] : B->shape[1], transa ? A->shape[1] : A->shape[0],
-                    transb ? B->shape[1] : B->shape[0], &alpha, A_ptr, B->shape[1], B_ptr,
-                    A->shape[1], &beta, C_ptr, C->shape[1]));
+  rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none;
+  rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none;
+  size_t N = transb ? B->shape[0] : B->shape[1];
+  size_t M = transa ? A->shape[1] : A->shape[0];
+  size_t K = transb ? B->shape[1] : B->shape[0];
+  size_t lda = transa ? M : K;
+  size_t ldb = transb ? K : N;
+  size_t ldc = N;
+
+  CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb,
+                                    A_ptr, lda, &beta, C_ptr, ldc));
CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      DLTensor* A = args[0];
+      DLTensor* B = args[1];
+      DLTensor* C = args[2];
+      bool transa = args[3];
+      bool transb = args[4];
+      // Operands are 3-D float32 tensors. The per-batch strides passed to rocBLAS below are
+      // derived from the shapes, so all three tensors are assumed to be compact.
+      CHECK_EQ(A->ndim, 3);
+      CHECK_EQ(B->ndim, 3);
+      CHECK_EQ(C->ndim, 3);
+      CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+      CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+      CHECK(TypeMatch(C->dtype, kDLFloat, 32));
+
+      int64_t batch_size = C->shape[0];
+      int64_t N = transb ? B->shape[1] : B->shape[2];
+      int64_t M = transa ? A->shape[2] : A->shape[1];
+      int64_t K = transb ? B->shape[2] : B->shape[1];
+      int64_t A_K = transa ? A->shape[1] : A->shape[2];
+      // Validate shapes before creating the handle: the strided call indexes every operand
+      // with these sizes, so a mismatch would address memory outside the operand buffers.
+      CHECK_EQ(A->shape[0], batch_size) << "batch size of A does not match the output";
+      CHECK_EQ(B->shape[0], batch_size) << "batch size of B does not match the output";
+      CHECK_EQ(A_K, K) << "reduction dimensions of A and B do not match";
+      CHECK_EQ(C->shape[1], M) << "output rows do not match A";
+      CHECK_EQ(C->shape[2], N) << "output columns do not match B";
+
+      rocblas_handle handle;
+      CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
+      float alpha = 1.0;
+      float beta = 0.0;
+      float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+      float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
+      float* C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
+
+      rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none;
+      rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none;
+      int64_t lda = transa ? M : K;
+      int64_t ldb = transb ? K : N;
+      int64_t ldc = N;
+
+      // rocBLAS is column-major, so compute C^T = op(B)^T * op(A)^T by passing B first.
+      CHECK_ROCBLAS_ERROR(rocblas_sgemm_strided_batched(
+          handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb, K * N, A_ptr, lda, M * K,
+          &beta, C_ptr, ldc, M * N, batch_size));
+      CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
+    });
} // namespace contrib
} // namespace tvm
diff --git a/tests/python/contrib/test_rocblas.py
b/tests/python/contrib/test_rocblas.py
index 9b8bacb..6f1783d 100644
--- a/tests/python/contrib/test_rocblas.py
+++ b/tests/python/contrib/test_rocblas.py
@@ -18,11 +18,13 @@ import tvm
import tvm.testing
from tvm import te
import numpy as np
+import tvm.topi.testing
+import tvm.testing
from tvm.contrib import rocblas
@tvm.testing.requires_rocm
-def test_matmul_add():
+def test_matmul():
n = 1024
l = 128
m = 235
@@ -46,5 +48,65 @@ def test_matmul_add():
verify()
+def verify_batch_matmul(batch, m, k, n, lib, transa=False, transb=False,
dtype="float32"):
+ ashape = (batch, k, m) if transa else (batch, m, k)
+ bshape = (batch, n, k) if transb else (batch, k, n)
+ A = te.placeholder(ashape, name="A", dtype=dtype)
+ B = te.placeholder(bshape, name="B", dtype=dtype)
+ C = lib.batch_matmul(A, B, transa, transb)
+ s = te.create_schedule(C.op)
+
+ def get_numpy(a, b, transa, transb):
+ if transa:
+ a = a.transpose(0, 2, 1)
+ if not transb:
+ b = b.transpose(0, 2, 1)
+ return tvm.topi.testing.batch_matmul(a, b)
+
+ def verify(target="rocm"):
+ if not tvm.testing.device_enabled(target):
+ print("skip because %s is not enabled..." % target)
+ return
+ if not tvm.get_global_func(lib.__name__ + ".batch_matmul", True):
+ print("skip because extern function is not available")
+ return
+ ctx = tvm.rocm(0)
+ f = tvm.build(s, [A, B, C], target)
+ a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+ b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
+ c = tvm.nd.array(np.zeros((batch, m, n), dtype=C.dtype), ctx)
+ f(a, b, c)
+ tvm.testing.assert_allclose(
+ c.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb),
rtol=1e-5
+ )
+
+ verify()
+
+
[email protected]_rocm
+def test_batch_matmul():
+ verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=False)
+ verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=True)
+ verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=False)
+ verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=True)
+ verify_batch_matmul(128, 512, 512, 64, rocblas, transa=False, transb=False)
+ verify_batch_matmul(128, 512, 512, 64, rocblas, transa=False, transb=True)
+ verify_batch_matmul(128, 512, 512, 64, rocblas, transa=True, transb=False)
+ verify_batch_matmul(128, 512, 512, 64, rocblas, transa=True, transb=True)
+ verify_batch_matmul(128, 512, 64, 512, rocblas, transa=False, transb=False)
+ verify_batch_matmul(128, 512, 64, 512, rocblas, transa=False, transb=True)
+ verify_batch_matmul(128, 512, 64, 512, rocblas, transa=True, transb=False)
+ verify_batch_matmul(128, 512, 64, 512, rocblas, transa=True, transb=True)
+ verify_batch_matmul(128, 64, 128, 128, rocblas, transa=False, transb=False)
+ verify_batch_matmul(128, 64, 128, 128, rocblas, transa=False, transb=True)
+ verify_batch_matmul(128, 64, 128, 128, rocblas, transa=True, transb=False)
+ verify_batch_matmul(128, 64, 128, 128, rocblas, transa=True, transb=True)
+ verify_batch_matmul(128, 128, 128, 64, rocblas, transa=False, transb=False)
+ verify_batch_matmul(128, 128, 128, 64, rocblas, transa=False, transb=True)
+ verify_batch_matmul(128, 128, 128, 64, rocblas, transa=True, transb=False)
+ verify_batch_matmul(128, 128, 128, 64, rocblas, transa=True, transb=True)
+
+
if __name__ == "__main__":
- test_matmul_add()
+ test_matmul()
+ test_batch_matmul()