This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 89e3688  [CUDA]batch_matmul tensorcore schedule (#7146)
89e3688 is described below

commit 89e3688137b9d8dd0e431cbdadb84d42dae9eee3
Author: Meteorix <[email protected]>
AuthorDate: Mon Jan 11 09:51:31 2021 +0800

    [CUDA]batch_matmul tensorcore schedule (#7146)
    
    * add batch_matmul_tensorcore
    
    * add bmm cublas autotune
    
    * add bmm tests
    
    * out_shape for bmm_tensorcore
    
    * fix comments
    
    * code format
    
    * add todos for tensorcore datatype checking
    
    * fix lint
    
    * fix have_tensorcore
    
    * add dtype check for batch_matmul_tensorcore
---
 python/tvm/relay/op/strategy/cuda.py               |  16 ++
 python/tvm/topi/cuda/__init__.py                   |   1 +
 python/tvm/topi/cuda/batch_matmul.py               |  14 +-
 python/tvm/topi/cuda/batch_matmul_tensorcore.py    | 315 +++++++++++++++++++++
 python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py     |   1 +
 python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py    |   1 +
 python/tvm/topi/cuda/dense_tensorcore.py           |   1 +
 .../python/test_topi_batch_matmul_tensorcore.py    |  75 +++++
 8 files changed, 422 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py 
b/python/tvm/relay/op/strategy/cuda.py
index 37946c0..04c16dd 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -732,6 +732,22 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, 
target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
+    if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+        x, y = inputs
+        _, M, K = get_const_tuple(x.shape)
+        _, N, K = get_const_tuple(y.shape)
+        if x.dtype in ["float16", "int8", "uint8"] and (
+            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
+            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
+            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
+        ):
+            strategy.add_implementation(
+                wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore),
+                wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore),
+                name="batch_matmul_tensorcore.cuda",
+                plevel=20,
+            )
+
     return strategy
 
 
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 23c625a..42bf980 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -42,6 +42,7 @@ from .dense import *
 from .pooling import *
 from .nn import schedule_lrn
 from .batch_matmul import *
+from .batch_matmul_tensorcore import *
 from .vision import *
 from .ssd import *
 from .nms import get_valid_counts, non_max_suppression
diff --git a/python/tvm/topi/cuda/batch_matmul.py 
b/python/tvm/topi/cuda/batch_matmul.py
index 8d34b29..006b866 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -21,7 +21,7 @@ from tvm import autotvm
 from tvm import te
 from tvm.contrib import cublas
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
-from .. import nn
+from .. import nn, generic
 from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor
 
 
@@ -138,7 +138,8 @@ def schedule_batch_matmul(cfg, outs):
     return s
 
 
-def batch_matmul_cublas(x, y, out_shape=None):
[email protected]_topi_compute("batch_matmul_cublas.cuda")
+def batch_matmul_cublas(cfg, x, y, out_shape=None):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
@@ -158,4 +159,13 @@ def batch_matmul_cublas(x, y, out_shape=None):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
+    b, m, k = x.shape
+    b, n, k = y.shape
+    cfg.add_flop(b * m * k * n * 2)
     return cublas.batch_matmul(x, y, False, True)
+
+
[email protected]_topi_schedule("batch_matmul_cublas.cuda")
+def schedule_batch_matmul_cublas(_, outs):
+    """Schedule batch_matmul operator using CUBLAS"""
+    return generic.schedule_extern(outs)
diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py 
b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
new file mode 100644
index 0000000..59b92ec
--- /dev/null
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -0,0 +1,315 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
+"""cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
+from tvm import te
+from ..utils import traverse_inline, get_const_tuple
+from .tensor_intrin import (
+    intrin_wmma_load_matrix_A,
+    intrin_wmma_load_matrix_W,
+    intrin_wmma_store_matrix,
+    intrin_wmma_gemm,
+)
+
+
[email protected]_topi_compute("batch_matmul_tensorcore.cuda")
+def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
+    """batch matmul tensorcore operator on cuda"""
+    # TODO: deal with out_shape for broadcast, liuxin.ai
+    return batch_matmul_tensorcore_cuda(x, y)
+
+
[email protected]_topi_schedule("batch_matmul_tensorcore.cuda")
+def schedule_batch_matmul_tensorcore(cfg, outs):
+    """Schedule for batch_matmul operator using Tensorcore
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of batch_matmul
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule(cfg, s, C):
+        A, B = s[C].op.input_tensors
+        batch, m_dim, k_dim = get_const_tuple(A.shape)
+        batch, n_dim, k_dim = get_const_tuple(B.shape)
+        out_dtype = C.dtype
+        # inline astype fp16
+        s[A].compute_inline()
+        s[B].compute_inline()
+
+        # Explicit memory access
+        AS = s.cache_read(A, "shared", [C])
+        BS = s.cache_read(B, "shared", [C])
+        AF = s.cache_read(AS, "wmma.matrix_a", [C])
+        BF = s.cache_read(BS, "wmma.matrix_b", [C])
+        CF = s.cache_write(C, "wmma.accumulator")
+        CS = s.cache_read(CF, "shared", [C])
+
+        # fallback support
+        target = tvm.target.Target.current()
+        if cfg.is_fallback:
+            ref_log = autotvm.tophub.load_reference_log(
+                target.kind.name, target.model, "batch_matmul_tensorcore.cuda"
+            )
+            cfg.fallback_with_reference_log(ref_log)
+
+        # Deal with op fusion, such as bias/relu and slice after padding
+        if C.op not in s.outputs and "injective" in s.outputs[0].tag:
+            s[C].compute_inline()
+            C = s.outputs[0].output(0)
+
+        # create tuning space
+        cfg.define_knob("block_row_warps", [1, 2, 4])
+        cfg.define_knob("block_col_warps", [1, 2, 4])
+        cfg.define_knob("warp_row_tiles", [1, 2, 4])
+        cfg.define_knob("warp_col_tiles", [1, 2, 4])
+        cfg.define_knob("chunk", [1, 2, 4, 8])
+        cfg.define_knob("offset", [0, 8])
+        cfg.define_knob("offsetCS", [0, 8])
+        cfg.define_knob("vec", [1, 2, 4, 8])
+
+        # Ensure that the default parameters are applicable when autotvm is 
not in use
+        if m_dim % 32 == 0 and n_dim % 8 == 0:
+            cfg.define_knob("wmma_m", [32, 16, 8])
+        elif m_dim % 16 == 0 and n_dim % 16 == 0:
+            cfg.define_knob("wmma_m", [16, 8, 32])
+        elif m_dim % 8 == 0 and n_dim % 32 == 0:
+            cfg.define_knob("wmma_m", [8, 16, 32])
+
+        warp_size = 32
+        wmma_k = 16
+        block_row_warps = cfg["block_row_warps"].val
+        block_col_warps = cfg["block_col_warps"].val
+        warp_row_tiles = cfg["warp_row_tiles"].val
+        warp_col_tiles = cfg["warp_col_tiles"].val
+        chunk = cfg["chunk"].val
+        offset = cfg["offset"].val
+        offsetCS = cfg["offsetCS"].val
+        wmma_m = cfg["wmma_m"].val
+        vec = cfg["vec"].val
+
+        if wmma_m == 16:
+            wmma_n = 16
+        elif wmma_m == 8:
+            wmma_n = 32
+        elif wmma_m == 32:
+            wmma_n = 8
+
+        # Define the stride of intrin functions
+        AS_align = chunk * wmma_k + offset
+        BS_align = chunk * wmma_k + offset
+        CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+        AS_stride = [AS_align, 1]
+        BS_stride = [BS_align, 1]
+        AF_stride = [wmma_k, 1]
+        BF_stride = [wmma_k, 1]
+        CF_stride = [warp_col_tiles * wmma_n, 1]
+        CS_stride = [CS_align, 1]
+
+        block_x = te.thread_axis("blockIdx.x")
+        block_y = te.thread_axis("blockIdx.y")
+        block_z = te.thread_axis("blockIdx.z")
+        thread_x = te.thread_axis("threadIdx.x")
+        thread_y = te.thread_axis("threadIdx.y")
+        thread_z = te.thread_axis("threadIdx.z")
+
+        # Schedule for batch_matmul computation
+        block_factor_m = wmma_m * warp_row_tiles * block_row_warps
+        block_factor_n = wmma_n * warp_col_tiles * block_col_warps
+        b, m, n = C.op.axis
+        block_i, bc = s[C].split(m, factor=block_factor_m)
+        block_j, oc = s[C].split(n, factor=block_factor_n)
+        s[C].reorder(b, block_i, block_j, bc, oc)
+        t = s[C].fuse(bc, oc)
+        t, vi = s[C].split(t, factor=vec)
+        t, tx = s[C].split(t, factor=warp_size)
+        t, ty = s[C].split(t, factor=block_row_warps)
+        t, tz = s[C].split(t, factor=block_col_warps)
+        s[C].bind(block_i, block_x)
+        s[C].bind(block_j, block_y)
+        s[C].bind(b, block_z)
+        s[C].bind(tz, thread_z)
+        s[C].bind(ty, thread_y)
+        s[C].bind(tx, thread_x)
+        s[C].vectorize(vi)
+
+        # Schedule for wmma store
+        s[CS].compute_at(s[C], block_j)
+        bs, bb, oo = CS.op.axis
+        s[CS].storage_align(bb, CS_align - 1, CS_align)
+        bb, bbi = s[CS].split(bb, factor=wmma_m)
+        oo, ooi = s[CS].split(oo, factor=wmma_n)
+        bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+        oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+        s[CS].reorder(bs, bb, oo, bbii, ooii, bbi, ooi)
+
+        # Schedule for wmma computation
+        s[CF].compute_at(s[CS], oo)
+        bs, warp_i, warp_j = CF.op.axis
+        warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+        warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+        (k,) = CF.op.reduce_axis
+        k, _k = s[CF].split(k, factor=wmma_k)
+        ko, ki = s[CF].split(k, factor=chunk)
+        s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+        # Schedule for wmma_matrix_a load
+        s[AF].compute_at(s[CF], ki)
+        bs, b, i = AF.op.axis
+        b, b_ii = s[AF].split(b, factor=wmma_m)
+        i, i_jj = s[AF].split(i, factor=wmma_k)
+        s[AF].reorder(bs, b, i, b_ii, i_jj)
+
+        # Schedule for wmma_matrix_b load
+        s[BF].compute_at(s[CF], ki)
+        bs, o, i = BF.op.axis
+        o, o_ii = s[BF].split(o, factor=wmma_n)
+        i, i_ii = s[BF].split(i, factor=wmma_k)
+        s[BF].reorder(bs, o, i, o_ii, i_ii)
+
+        # Schedule for A's(B's) shared memory load
+        def shared_schedule(stage, strides):
+            s[stage].compute_at(s[CF], ko)
+            bs, xo, yo = stage.op.axis
+            s[stage].storage_align(xo, strides - 1, strides)
+            t = s[stage].fuse(xo, yo)
+            t, vi = s[stage].split(t, factor=vec)
+            t, tx = s[stage].split(t, factor=warp_size)
+            t, ty = s[stage].split(t, factor=block_row_warps)
+            _, tz = s[stage].split(t, factor=block_col_warps)
+            s[stage].bind(ty, thread_y)
+            s[stage].bind(tz, thread_z)
+            s[stage].bind(tx, thread_x)
+            s[stage].vectorize(vi)
+
+        shared_schedule(AS, AS_align)
+        shared_schedule(BS, BS_align)
+
+        shape = (wmma_m, wmma_n, wmma_k)
+        # TODO: add checking here, datatype casting may cause precision loss
+        in_dtype = "float16"
+        AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", 
dtype=in_dtype)
+        BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", 
dtype=in_dtype)
+        k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
+        CL_compute = te.compute(
+            (wmma_m, wmma_n),
+            lambda ii, jj: te.sum(
+                AL_gemm[ii, k_gemm].astype(out_dtype) * BL_gemm[jj, 
k_gemm].astype(out_dtype),
+                axis=k_gemm,
+            ),
+            name="CL_compute",
+        )
+
+        # lower the computation loops down to TensorCore hardware intrinsics
+        # by mapping the dense tensorcore to tensor intrinsics
+        s[AF].tensorize(
+            b_ii,
+            intrin_wmma_load_matrix_A(
+                AF_stride,
+                AS_stride,
+                shape,
+                "row_major",
+                (wmma_m, wmma_k),
+                (wmma_m, wmma_k),
+                "float16",
+            ),
+        )
+        s[BF].tensorize(
+            o_ii,
+            intrin_wmma_load_matrix_W(
+                BF_stride,
+                BS_stride,
+                shape,
+                "col_major",
+                (wmma_n, wmma_k),
+                (wmma_n, wmma_k),
+                "float16",
+            ),
+        )
+        s[CF].tensorize(
+            _ii,
+            intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, 
BF_stride, CF_stride, shape),
+        )
+        s[CS].tensorize(
+            bbi,
+            intrin_wmma_store_matrix(
+                CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), 
(wmma_m, wmma_n)
+            ),
+        )
+
+    def _callback(op):
+        if "batch_matmul_tensorcore" in op.tag:
+            _schedule(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def batch_matmul_tensorcore_cuda(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.te.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
+    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim 
batch_matmul"
+    x_shape = get_const_tuple(x.shape)
+    y_shape = get_const_tuple(y.shape)
+    assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
+    assert x_shape[2] == y_shape[2], "shapes of x and y are inconsistent"
+    batch, M, K = x.shape
+    N = y.shape[1]
+    out_dtype = x.dtype
+
+    assert (
+        (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
+        or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
+        or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
+    ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) 
or (8, 16, 32)"
+
+    x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, 
k].astype("float16"))
+    y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, 
k].astype("float16"))
+
+    k = te.reduce_axis((0, K), name="k")
+    return te.compute(
+        (batch, M, N),
+        lambda b, i, j: te.sum(
+            x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), 
axis=k
+        ),
+        tag="batch_matmul_tensorcore",
+    )
diff --git a/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py 
b/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
index f665cc7..76f082f 100644
--- a/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
+++ b/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
@@ -72,6 +72,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, 
dilation, out_dtyp
     ry = te.reduce_axis((0, kernel_h), name="ry")
     rx = te.reduce_axis((0, kernel_w), name="rx")
     # convert data type of input feature maps and weights
+    # TODO: add checking here, datatype casting may cause precision loss
     TransPaddedInput = te.compute(
         PaddedInput.shape, lambda n, h, w, c: PaddedInput[n, h, w, 
c].astype("float16")
     )
diff --git a/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py 
b/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
index a5c4e81..efb2574 100644
--- a/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
+++ b/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
@@ -75,6 +75,7 @@ def ndhwc_tensorcore_cuda(cfg, Input, Filter, stride, 
padding, dilation, out_dty
     ry = te.reduce_axis((0, kernel_h), name="ry")
     rx = te.reduce_axis((0, kernel_w), name="rx")
     # convert data type of input feature maps and weights
+    # TODO: add checking here, datatype casting may cause precision loss
     TransPaddedInput = te.compute(
         PaddedInput.shape, lambda n, d, h, w, c: PaddedInput[n, d, h, w, 
c].astype("float16")
     )
diff --git a/python/tvm/topi/cuda/dense_tensorcore.py 
b/python/tvm/topi/cuda/dense_tensorcore.py
index a59ebd73..430f804 100644
--- a/python/tvm/topi/cuda/dense_tensorcore.py
+++ b/python/tvm/topi/cuda/dense_tensorcore.py
@@ -245,6 +245,7 @@ def _schedule_dense_tensorcore(cfg, s, C):
     shared_shedule(BS, BS_align)
 
     shape = (wmma_m, wmma_n, wmma_k)
+    # TODO: add checking here, datatype casting may cause precision loss
     in_dtype = "float16"
     AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
     BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
diff --git a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py 
b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
new file mode 100644
index 0000000..60f4bef3
--- /dev/null
+++ b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for batch_matmul operator"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import topi
+import tvm.topi.testing
+from tvm.topi.utils import get_const_tuple
+from tvm.contrib.pickle_memoize import memoize
+
+import tvm.testing
+
+_batch_matmul_implement = {
+    "gpu": (topi.cuda.batch_matmul_tensorcore, 
topi.cuda.schedule_batch_matmul_tensorcore),
+}
+
+
+def verify_batch_matmul(x_batch, y_batch, M, N, K):
+    x = te.placeholder((x_batch, M, K), name="x")
+    y = te.placeholder((y_batch, N, K), name="y")
+    dtype = x.dtype
+
+    # use memoize to pickle the test data for next time use
+    @memoize("topi.tests.test_topi_batch_matmul_tensorcore")
+    def get_ref_data():
+        a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
+        b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
+        c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
+        return (a_np, b_np, c_np)
+
+    # get the test data
+    a_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        print("Running on target: %s" % device)
+        with tvm.target.Target(device):
+            fcompute, fschedule = tvm.topi.testing.dispatch(device, 
_batch_matmul_implement)
+            out = fcompute(x, y)
+            s = fschedule([out])
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=dtype), 
ctx)
+        f = tvm.build(s, [x, y, out], device, name="dense")
+        f(a, b, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)
+
+    check_device("cuda")
+
+
[email protected]_gpu
+def test_batch_matmul():
+    verify_batch_matmul(1, 1, 16, 16, 32)
+    verify_batch_matmul(5, 5, 16, 16, 32)
+    verify_batch_matmul(5, 5, 16, 32, 32)
+    verify_batch_matmul(30, 30, 16, 32, 32)
+
+
+if __name__ == "__main__":
+    test_batch_matmul()

Reply via email to