This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 89e3688 [CUDA]batch_matmul tensorcore schedule (#7146)
89e3688 is described below
commit 89e3688137b9d8dd0e431cbdadb84d42dae9eee3
Author: Meteorix <[email protected]>
AuthorDate: Mon Jan 11 09:51:31 2021 +0800
[CUDA]batch_matmul tensorcore schedule (#7146)
* add batch_matmul_tensorcore
* add bmm cublas autotune
* add bmm tests
* out_shape for bmm_tensorcore
* fix comments
* code format
* add todos for tensorcore datatype checking
* fix lint
* fix have_tensorcore
* add dtype check for batch_matmul_tensorcore
---
python/tvm/relay/op/strategy/cuda.py | 16 ++
python/tvm/topi/cuda/__init__.py | 1 +
python/tvm/topi/cuda/batch_matmul.py | 14 +-
python/tvm/topi/cuda/batch_matmul_tensorcore.py | 315 +++++++++++++++++++++
python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py | 1 +
python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py | 1 +
python/tvm/topi/cuda/dense_tensorcore.py | 1 +
.../python/test_topi_batch_matmul_tensorcore.py | 75 +++++
8 files changed, 422 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 37946c0..04c16dd 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -732,6 +732,22 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
+    if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
+        x, y = inputs
+        _, M, K = get_const_tuple(x.shape)
+        _, N, K = get_const_tuple(y.shape)
+        if x.dtype in ["float16", "int8", "uint8"] and (
+            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
+            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
+            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
+        ):
+            strategy.add_implementation(
+                wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore),
+                wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore),
+                name="batch_matmul_tensorcore.cuda",
+                plevel=20,
+            )
+
     return strategy
 
 
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 23c625a..42bf980 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -42,6 +42,7 @@ from .dense import *
 from .pooling import *
 from .nn import schedule_lrn
 from .batch_matmul import *
+from .batch_matmul_tensorcore import *
 from .vision import *
 from .ssd import *
 from .nms import get_valid_counts, non_max_suppression
diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py
index 8d34b29..006b866 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -21,7 +21,7 @@ from tvm import autotvm
 from tvm import te
 from tvm.contrib import cublas
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
-from .. import nn
+from .. import nn, generic
 from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor
 
 
@@ -138,7 +138,8 @@ def schedule_batch_matmul(cfg, outs):
     return s
 
 
-def batch_matmul_cublas(x, y, out_shape=None):
+@autotvm.register_topi_compute("batch_matmul_cublas.cuda")
+def batch_matmul_cublas(cfg, x, y, out_shape=None):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
@@ -158,4 +159,13 @@ def batch_matmul_cublas(x, y, out_shape=None):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
+    b, m, k = x.shape
+    b, n, k = y.shape
+    cfg.add_flop(b * m * k * n * 2)
     return cublas.batch_matmul(x, y, False, True)
+
+
+@autotvm.register_topi_schedule("batch_matmul_cublas.cuda")
+def schedule_batch_matmul_cublas(_, outs):
+    """Schedule batch_matmul operator using CUBLAS"""
+    return generic.schedule_extern(outs)
diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
new file mode 100644
index 0000000..59b92ec
--- /dev/null
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -0,0 +1,315 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
+"""cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
+from tvm import te
+from ..utils import traverse_inline, get_const_tuple
+from .tensor_intrin import (
+    intrin_wmma_load_matrix_A,
+    intrin_wmma_load_matrix_W,
+    intrin_wmma_store_matrix,
+    intrin_wmma_gemm,
+)
+
+
+@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
+def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
+    """batch matmul tensorcore operator on cuda"""
+    # todo: deal with out_shape for broadcast, liuxin.ai
+    return batch_matmul_tensorcore_cuda(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
+def schedule_batch_matmul_tensorcore(cfg, outs):
+    """Schedule for batch_matmul operator using Tensorcore
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of batch_matmul
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule(cfg, s, C):
+        A, B = s[C].op.input_tensors
+        batch, m_dim, k_dim = get_const_tuple(A.shape)
+        batch, n_dim, k_dim = get_const_tuple(B.shape)
+        out_dtype = C.dtype
+        # inline astype fp16
+        s[A].compute_inline()
+        s[B].compute_inline()
+
+        # Explicit memory access
+        AS = s.cache_read(A, "shared", [C])
+        BS = s.cache_read(B, "shared", [C])
+        AF = s.cache_read(AS, "wmma.matrix_a", [C])
+        BF = s.cache_read(BS, "wmma.matrix_b", [C])
+        CF = s.cache_write(C, "wmma.accumulator")
+        CS = s.cache_read(CF, "shared", [C])
+
+        # fallback support
+        target = tvm.target.Target.current()
+        if cfg.is_fallback:
+            ref_log = autotvm.tophub.load_reference_log(
+                target.kind.name, target.model, "batch_matmul_tensorcore.cuda"
+            )
+            cfg.fallback_with_reference_log(ref_log)
+
+        # Deal with op fusion, such as bias/relu and slice after padding
+        if C.op not in s.outputs and "injective" in s.outputs[0].tag:
+            s[C].compute_inline()
+            C = s.outputs[0].output(0)
+
+        # create tuning space
+        cfg.define_knob("block_row_warps", [1, 2, 4])
+        cfg.define_knob("block_col_warps", [1, 2, 4])
+        cfg.define_knob("warp_row_tiles", [1, 2, 4])
+        cfg.define_knob("warp_col_tiles", [1, 2, 4])
+        cfg.define_knob("chunk", [1, 2, 4, 8])
+        cfg.define_knob("offset", [0, 8])
+        cfg.define_knob("offsetCS", [0, 8])
+        cfg.define_knob("vec", [1, 2, 4, 8])
+
+        # Ensure that the default parameters are applicable when autotvm is not in use
+        if m_dim % 32 == 0 and n_dim % 8 == 0:
+            cfg.define_knob("wmma_m", [32, 16, 8])
+        elif m_dim % 16 == 0 and n_dim % 16 == 0:
+            cfg.define_knob("wmma_m", [16, 8, 32])
+        elif m_dim % 8 == 0 and n_dim % 32 == 0:
+            cfg.define_knob("wmma_m", [8, 16, 32])
+
+        warp_size = 32
+        wmma_k = 16
+        block_row_warps = cfg["block_row_warps"].val
+        block_col_warps = cfg["block_col_warps"].val
+        warp_row_tiles = cfg["warp_row_tiles"].val
+        warp_col_tiles = cfg["warp_col_tiles"].val
+        chunk = cfg["chunk"].val
+        offset = cfg["offset"].val
+        offsetCS = cfg["offsetCS"].val
+        wmma_m = cfg["wmma_m"].val
+        vec = cfg["vec"].val
+
+        if wmma_m == 16:
+            wmma_n = 16
+        elif wmma_m == 8:
+            wmma_n = 32
+        elif wmma_m == 32:
+            wmma_n = 8
+
+        # Define the stride of intrin functions
+        AS_align = chunk * wmma_k + offset
+        BS_align = chunk * wmma_k + offset
+        CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+        AS_stride = [AS_align, 1]
+        BS_stride = [BS_align, 1]
+        AF_stride = [wmma_k, 1]
+        BF_stride = [wmma_k, 1]
+        CF_stride = [warp_col_tiles * wmma_n, 1]
+        CS_stride = [CS_align, 1]
+
+        block_x = te.thread_axis("blockIdx.x")
+        block_y = te.thread_axis("blockIdx.y")
+        block_z = te.thread_axis("blockIdx.z")
+        thread_x = te.thread_axis("threadIdx.x")
+        thread_y = te.thread_axis("threadIdx.y")
+        thread_z = te.thread_axis("threadIdx.z")
+
+        # Schedule for dense computation
+        block_factor_m = wmma_m * warp_row_tiles * block_row_warps
+        block_factor_n = wmma_n * warp_col_tiles * block_col_warps
+        b, m, n = C.op.axis
+        block_i, bc = s[C].split(m, factor=block_factor_m)
+        block_j, oc = s[C].split(n, factor=block_factor_n)
+        s[C].reorder(b, block_i, block_j, bc, oc)
+        t = s[C].fuse(bc, oc)
+        t, vi = s[C].split(t, factor=vec)
+        t, tx = s[C].split(t, factor=warp_size)
+        t, ty = s[C].split(t, factor=block_row_warps)
+        t, tz = s[C].split(t, factor=block_col_warps)
+        s[C].bind(block_i, block_x)
+        s[C].bind(block_j, block_y)
+        s[C].bind(b, block_z)
+        s[C].bind(tz, thread_z)
+        s[C].bind(ty, thread_y)
+        s[C].bind(tx, thread_x)
+        s[C].vectorize(vi)
+
+        # Schedule for wmma store
+        s[CS].compute_at(s[C], block_j)
+        bs, bb, oo = CS.op.axis
+        s[CS].storage_align(bb, CS_align - 1, CS_align)
+        bb, bbi = s[CS].split(bb, factor=wmma_m)
+        oo, ooi = s[CS].split(oo, factor=wmma_n)
+        bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+        oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+        s[CS].reorder(bs, bb, oo, bbii, ooii, bbi, ooi)
+
+        # Schedule for wmma computation
+        s[CF].compute_at(s[CS], oo)
+        bs, warp_i, warp_j = CF.op.axis
+        warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+        warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+        (k,) = CF.op.reduce_axis
+        k, _k = s[CF].split(k, factor=wmma_k)
+        ko, ki = s[CF].split(k, factor=chunk)
+        s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+        # Schedule for wmma_matrix_a load
+        s[AF].compute_at(s[CF], ki)
+        bs, b, i = AF.op.axis
+        b, b_ii = s[AF].split(b, factor=wmma_m)
+        i, i_jj = s[AF].split(i, factor=wmma_k)
+        s[AF].reorder(bs, b, i, b_ii, i_jj)
+
+        # Schedule for wmma_matrix_b load
+        s[BF].compute_at(s[CF], ki)
+        bs, o, i = BF.op.axis
+        o, o_ii = s[BF].split(o, factor=wmma_n)
+        i, i_ii = s[BF].split(i, factor=wmma_k)
+        s[BF].reorder(bs, o, i, o_ii, i_ii)
+
+        # Schedule for A's(B's) shared memory load
+        def shared_shedule(stage, strides):
+            s[stage].compute_at(s[CF], ko)
+            bs, xo, yo = stage.op.axis
+            s[stage].storage_align(xo, strides - 1, strides)
+            t = s[stage].fuse(xo, yo)
+            t, vi = s[stage].split(t, factor=vec)
+            t, tx = s[stage].split(t, factor=warp_size)
+            t, ty = s[stage].split(t, factor=block_row_warps)
+            _, tz = s[stage].split(t, factor=block_col_warps)
+            s[stage].bind(ty, thread_y)
+            s[stage].bind(tz, thread_z)
+            s[stage].bind(tx, thread_x)
+            s[stage].vectorize(vi)
+
+        shared_shedule(AS, AS_align)
+        shared_shedule(BS, BS_align)
+
+        shape = (wmma_m, wmma_n, wmma_k)
+        # TODO: add checking here, datatype casting may cause precision loss
+        in_dtype = "float16"
+        AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
+        BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
+        k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
+        CL_compute = te.compute(
+            (wmma_m, wmma_n),
+            lambda ii, jj: te.sum(
+                AL_gemm[ii, k_gemm].astype(out_dtype) * BL_gemm[jj, k_gemm].astype(out_dtype),
+                axis=k_gemm,
+            ),
+            name="CL_compute",
+        )
+
+        # lower the computation loops down to TensorCore hardware intrinsics
+        # by mapping the dense tensorcore to tensor intrinsics
+        s[AF].tensorize(
+            b_ii,
+            intrin_wmma_load_matrix_A(
+                AF_stride,
+                AS_stride,
+                shape,
+                "row_major",
+                (wmma_m, wmma_k),
+                (wmma_m, wmma_k),
+                "float16",
+            ),
+        )
+        s[BF].tensorize(
+            o_ii,
+            intrin_wmma_load_matrix_W(
+                BF_stride,
+                BS_stride,
+                shape,
+                "col_major",
+                (wmma_n, wmma_k),
+                (wmma_n, wmma_k),
+                "float16",
+            ),
+        )
+        s[CF].tensorize(
+            _ii,
+            intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape),
+        )
+        s[CS].tensorize(
+            bbi,
+            intrin_wmma_store_matrix(
+                CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)
+            ),
+        )
+
+    def _callback(op):
+        if "batch_matmul_tensorcore" in op.tag:
+            _schedule(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def batch_matmul_tensorcore_cuda(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.te.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
+    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
+    x_shape = get_const_tuple(x.shape)
+    y_shape = get_const_tuple(y.shape)
+    assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
+    assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
+    batch, M, K = x.shape
+    N = y.shape[1]
+    out_dtype = x.dtype
+
+    assert (
+        (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
+        or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
+        or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
+    ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)"
+
+    x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype("float16"))
+    y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype("float16"))
+
+    k = te.reduce_axis((0, K), name="k")
+    return te.compute(
+        (batch, M, N),
+        lambda b, i, j: te.sum(
+            x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), axis=k
+        ),
+        tag="batch_matmul_tensorcore",
+    )
diff --git a/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py b/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
index f665cc7..76f082f 100644
--- a/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
+++ b/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
@@ -72,6 +72,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp
     ry = te.reduce_axis((0, kernel_h), name="ry")
     rx = te.reduce_axis((0, kernel_w), name="rx")
     # convert data type of input feature maps and weights
+    # TODO: add checking here, datatype casting may cause precision loss
     TransPaddedInput = te.compute(
         PaddedInput.shape, lambda n, h, w, c: PaddedInput[n, h, w, c].astype("float16")
     )
diff --git a/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py b/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
index a5c4e81..efb2574 100644
--- a/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
+++ b/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
@@ -75,6 +75,7 @@ def ndhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dty
     ry = te.reduce_axis((0, kernel_h), name="ry")
     rx = te.reduce_axis((0, kernel_w), name="rx")
     # convert data type of input feature maps and weights
+    # TODO: add checking here, datatype casting may cause precision loss
     TransPaddedInput = te.compute(
         PaddedInput.shape, lambda n, d, h, w, c: PaddedInput[n, d, h, w, c].astype("float16")
     )
diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py
index a59ebd73..430f804 100644
--- a/python/tvm/topi/cuda/dense_tensorcore.py
+++ b/python/tvm/topi/cuda/dense_tensorcore.py
@@ -245,6 +245,7 @@ def _schedule_dense_tensorcore(cfg, s, C):
     shared_shedule(BS, BS_align)
 
     shape = (wmma_m, wmma_n, wmma_k)
+    # TODO: add checking here, datatype casting may cause precision loss
     in_dtype = "float16"
     AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
     BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
diff --git a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
new file mode 100644
index 0000000..60f4bef3
--- /dev/null
+++ b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for batch_matmul operator"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import topi
+import tvm.topi.testing
+from tvm.topi.utils import get_const_tuple
+from tvm.contrib.pickle_memoize import memoize
+
+import tvm.testing
+
+_batch_matmul_implement = {
+    "gpu": (topi.cuda.batch_matmul_tensorcore, topi.cuda.schedule_batch_matmul_tensorcore),
+}
+
+
+def verify_batch_matmul(x_batch, y_batch, M, N, K):
+    x = te.placeholder((x_batch, M, K), name="x")
+    y = te.placeholder((y_batch, N, K), name="y")
+    dtype = x.dtype
+
+    # use memoize to pickle the test data for next time use
+    @memoize("topi.tests.test_topi_batch_matmul_tensorcore")
+    def get_ref_data():
+        a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
+        b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
+        c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
+        return (a_np, b_np, c_np)
+
+    # get the test data
+    a_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        print("Running on target: %s" % device)
+        with tvm.target.Target(device):
+            fcompute, fschedule = tvm.topi.testing.dispatch(device, _batch_matmul_implement)
+            out = fcompute(x, y)
+            s = fschedule([out])
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=dtype), ctx)
+        f = tvm.build(s, [x, y, out], device, name="dense")
+        f(a, b, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)
+
+    check_device("cuda")
+
+
+@tvm.testing.requires_gpu
+def test_batch_matmul():
+    verify_batch_matmul(1, 1, 16, 16, 32)
+    verify_batch_matmul(5, 5, 16, 16, 32)
+    verify_batch_matmul(5, 5, 16, 32, 32)
+    verify_batch_matmul(30, 30, 16, 32, 32)
+
+
+if __name__ == "__main__":
+    test_batch_matmul()