This is an automated email from the ASF dual-hosted git repository.
laurawly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 9cc15a4 [TOPI] Support int4/int8 conv2d tensor core with HWNC layout (#6121)
9cc15a4 is described below
commit 9cc15a40a37e940f0f48c9af42fcedaa4274c9d9
Author: GaryYuyjl <[email protected]>
AuthorDate: Tue Aug 18 13:32:30 2020 +0800
[TOPI] Support int4/int8 conv2d tensor core with HWNC layout (#6121)
* int4 tensorcore
* a draft for new int4 schedule
* update layout
* add inline option
* clean code
* increase search space
* fix kernel shape
* update intrinsic
* update intrinsic
* support int4/int8 hwnc layout
* remove useless code
* remove useless code
* remove useless code
* remove useless code
* fix int8 transpose
* fix assert
* add asf header
* CI
* CI
* CI
* fix bug
fix bug
Co-authored-by: Leyuan Wang <[email protected]>
---
python/tvm/relay/op/strategy/cuda.py | 19 +
python/tvm/topi/cuda/__init__.py | 1 +
python/tvm/topi/cuda/conv2d_alter_op.py | 30 ++
python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py | 440 +++++++++++++++++++++
src/target/source/codegen_c.cc | 8 +-
.../python/test_topi_conv2d_hwnc_tensorcore.py | 133 +++++++
6 files changed, 629 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index b2a0ff4..4b50937 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -172,6 +172,25 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
name="conv2d_nhwc_tensorcore.cuda",
plevel=20)
+ elif layout == "HWNC":
+ assert kernel_layout in ["HWOI", "HWOI16o16i", "HWOI8o32i", "HWOI32o16i"]
+ _, _, N, in_channels = get_const_tuple(data.shape)
+ pre_computed = len(kernel.shape) == 6
+ if pre_computed:
+ _, _, oc_chunk, _, oc_block_factor, _ = get_const_tuple(kernel.shape)
+ out_channels = oc_chunk * oc_block_factor
+ else:
+ _, _, out_channels, _ = get_const_tuple(kernel.shape)
+ if topi.cuda.is_shape_tensorcore_direct_qualified(
+ batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype):
+ strategy.add_implementation(
+ wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore),
+ wrap_topi_schedule(topi.cuda.schedule_conv2d_hwnc_tensorcore),
+ name="conv2d_hwnc_tensorcore_direct.cuda",
+ plevel=20)
+ else:
+ raise RuntimeError("Unsupported shape for conv2d HWNC.\
+ Need to satisfy tensor core schedule.")
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 90f4e60..ed80370 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -50,5 +50,6 @@ from .sort import *
from .conv2d_nhwc_tensorcore import *
from .conv3d_ndhwc_tensorcore import *
from .dense_tensorcore import *
+from .conv2d_hwnc_tensorcore import *
from .correlation import *
from .sparse import *
diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py
index c2a1905..f07ef98 100644
--- a/python/tvm/topi/cuda/conv2d_alter_op.py
+++ b/python/tvm/topi/cuda/conv2d_alter_op.py
@@ -171,6 +171,36 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.conv2d(*inputs, **new_attrs)
+ if topi_tmpl == "conv2d_HWNCnc_tensorcore.cuda":
+ assert data_layout == "HWNC" and kernel_layout == "HWOI"
+ assert float(tvm.gpu(0).compute_version) >= 7.5
+ H, W, N, CI = get_const_tuple(data.shape)
+ KH, KW, CO, _ = get_const_tuple(kernel.shape)
+
+ if kernel.dtype in ['int4', 'uint4'] and (CI % 32 != 0 or CO % 8 != 0) or \
+ kernel.dtype in ['int8', 'uint8'] and (CI % 16 != 0 or CO % 32 != 0):
+ return relay.nn.conv2d(*inputs, **new_attrs)
+
+ new_attrs["channels"] = CO
+ if kernel.dtype in ['int4', 'uint4']:
+ new_attrs['kernel_layout'] = 'HWOI8o32i'
+ ic_block_factor = 32
+ oc_block_factor = 8
+ else:
+ new_attrs['kernel_layout'] = 'HWOI32o16i'
+ ic_block_factor = 16
+ oc_block_factor = 32
+
+ new_kernel = te.placeholder((KH, KW, CO // oc_block_factor, CI // ic_block_factor,
+ oc_block_factor, ic_block_factor), dtype=kernel.dtype)
+
+ new_workload = autotvm.task.args_to_workload(
+ [data, new_kernel, strides, padding, dilation, out_dtype],
+ "conv2d_HWNCnc_tensorcore.cuda")
+
+ dispatch_ctx.update(target, new_workload, cfg)
+ return relay.nn.conv2d(*inputs, **new_attrs)
+
return None
@conv2d_legalize.register("cuda")
diff --git a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py
new file mode 100644
index 0000000..592613f
--- /dev/null
+++ b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py
@@ -0,0 +1,440 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-locals, too-many-function-args
+# pylint: disable=too-many-statements, unused-argument, too-many-arguments
+"""Tensorcore template for cuda backend"""
+import tvm
+from tvm import te
+from tvm import autotvm
+from tvm.topi.cuda.injective import schedule_injective_from_existing
+from ..util import get_const_tuple, traverse_inline, simplify, tag
+from ..nn.pad import pad
+from ..nn.util import get_pad_tuple
+from .tensor_intrin import intrin_wmma_load_matrix_A
+from .tensor_intrin import intrin_wmma_load_matrix_W
+from .tensor_intrin import intrin_wmma_store_matrix
+from .tensor_intrin import intrin_wmma_gemm
+
+
+def unpack_HWNCnc_to_hwnc(packed_out, out_dtype):
+ """Unpack conv2d_hwnc output from layout hwncnc to hwnc
+
+ Parameters
+ -----------
+ packed_out : tvm.te.Tensor
+ The output tensor of conv2d_hwnc.
+
+ out_dtype : str
+ The output dtype.
+
+ Returns
+ -------
+ unpacked_out : tvm.te.Tensor
+ The unpacked output tensor in hwnc layout.
+ """
+ H, W, N, O, wmma_m, wmma_n = get_const_tuple(packed_out.shape)
+
+ idxmod = tvm.tir.indexmod
+ idxdiv = tvm.tir.indexdiv
+
+ oshape = (H, W, N * wmma_m, O * wmma_n)
+ unpacked_out = \
+ te.compute(oshape,
+ lambda h, w, n, o:
+ packed_out[h, w, idxdiv(n, wmma_m), idxdiv(o, wmma_n),
+ idxmod(n, wmma_m), idxmod(o, wmma_n)]
+ .astype(out_dtype),
+ name='output_unpack',
+ tag=tag.INJECTIVE + ",unpack_hwncc")
+ return unpacked_out
+
+
+def conv2d_hwnc_tensorcore(data, kernel, strides, padding, dilation, in_dtype, out_dtype='int32'):
+ """"Compute conv2d with tensorcore for HWNC layout with int8/int4"""
+ assert data.dtype in ('int4', 'uint4', 'int8', 'uint8')
+ assert kernel.dtype in ('int4', 'uint4', 'int8', 'uint8')
+ packed_out = hwnc_tensorcore_cuda(
+ data, kernel, strides, padding, dilation, out_dtype)
+ return unpack_HWNCnc_to_hwnc(packed_out, out_dtype)
+
+
[email protected]_topi_compute("conv2d_HWNCnc_tensorcore.cuda")
+def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype='int32'):
+ """Compute declaration for tensorcore"""
+ assert isinstance(stride, int) or len(stride) == 2
+ assert isinstance(dilation, int) or len(dilation) == 2
+
+ if isinstance(stride, int):
+ stride_h = stride_w = stride
+ else:
+ stride_h, stride_w = stride
+
+ if isinstance(dilation, int):
+ dilation_h = dilation_w = dilation
+ else:
+ dilation_h, dilation_w = dilation
+
+ in_dtype = Input.dtype
+
+ if in_dtype in ['int4', 'uint4']:
+ wmma_n = wmma_m = 8
+ wmma_k = 32
+ else:
+ wmma_m = 8
+ wmma_n = 32
+ wmma_k = 16
+
+ pre_computed = len(Filter.shape) == 6
+ in_height, in_width, batch, in_channels = get_const_tuple(Input.shape)
+ if pre_computed:
+ kernel_h, kernel_w, oc_chunk, _, oc_block_factor, _\
+ = get_const_tuple(Filter.shape)
+ num_filter = oc_block_factor * oc_chunk
+ else:
+ kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape)
+
+ if in_dtype in ['int4', 'uint4']:
+ assert (batch % 8 == 0 and in_channels %
+ 32 == 0 and num_filter % 8 == 0)
+ else:
+ assert (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0), \
+ "The shape of (batch, in_channels, num_filter) "\
+ "must be multiple of (8, 16, 32) for int8, "\
+ "and (8, 32, 8) for int4"
+
+ # compute the output shape
+ dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+ dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+
+ pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+ padding, (dilated_kernel_h, dilated_kernel_w))
+
+ out_channels = num_filter
+ out_height = simplify(
+ (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
+ out_width = simplify((in_width - dilated_kernel_w +
+ pad_left + pad_right) // stride_w + 1)
+
+ cfg.add_flop(2 * batch * out_height * out_width *
+ out_channels * in_channels * kernel_h * kernel_w)
+
+ # Input feature map: (H, W, N, IC, n, ic)
+ data_shape = (in_height,
+ in_width,
+ batch // wmma_m,
+ in_channels // wmma_k,
+ wmma_m,
+ wmma_k)
+
+ # Kernel: (H, W, OC, IC, oc, ic)
+ kernel_shape = (kernel_h,
+ kernel_w,
+ out_channels // wmma_n,
+ in_channels // wmma_k,
+ wmma_n,
+ wmma_k)
+
+ # Reduction axes
+ kh = te.reduce_axis((0, kernel_h), name='kh')
+ kw = te.reduce_axis((0, kernel_w), name='kw')
+ ic = te.reduce_axis((0, in_channels // wmma_k), name='ic')
+ ii = te.reduce_axis((0, wmma_k), name='ii')
+
+ if pre_computed:
+ packed_kernel = Filter
+ else:
+ packed_kernel = te.compute(kernel_shape, lambda kh, kw, o, i, oo, ii:
+ Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii],
+ name="packed_kernel"
+ )
+
+ packed_data = te.compute(data_shape,
+ lambda h, w, n, i, nn, ii: Input[h, w, n * wmma_m + nn, i * wmma_k + ii]
+ )
+
+ pad_before = [pad_top, pad_left, 0, 0, 0, 0]
+ pad_after = [pad_down, pad_right, 0, 0, 0, 0]
+ pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")
+
+ Conv = te.compute((out_height, out_width, batch // wmma_m,
+ out_channels // wmma_n, wmma_m, wmma_n),
+ lambda h, w, n, o, nn, oo: te.sum(
+ (pad_data[h * stride_h + kh, w * stride_w + kw,
+ n, ic, nn, ii].astype('int32') *
+ packed_kernel[kh, kw, o, ic, oo, ii].astype('int32')),
+ axis=[ic, kh, kw, ii]),
+ name="Conv", tag="conv2d_HWNCnc_tensorcore")
+ return Conv
+
+
+def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
+ """Schedule tensorcore template"""
+ packed_data, packed_kernel = s[Conv].op.input_tensors
+ ic, kh, kw, ii = s[Conv].op.reduce_axis
+ pad_data = s[packed_data].op.input_tensors[0]
+
+ block_x = te.thread_axis('blockIdx.x')
+ block_y = te.thread_axis('blockIdx.y')
+ block_z = te.thread_axis('blockIdx.z')
+ thread_x = te.thread_axis('threadIdx.x')
+ thread_y = te.thread_axis('threadIdx.y')
+ thread_z = te.thread_axis('threadIdx.z')
+
+ # Designate the memory hierarchy
+ AS = s.cache_read(packed_data, 'shared', [Conv])
+ WS = s.cache_read(packed_kernel, 'shared', [Conv])
+ AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
+ WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
+ ConvF = s.cache_write(Conv, 'wmma.accumulator')
+
+ if Conv.op in s.outputs:
+ output = Conv
+ ConvS = s.cache_read(ConvF, 'shared', [Conv])
+ OL = ConvS
+ else:
+ output = s.outputs[0].output(0)
+ s[Conv].set_scope('shared')
+ OL = Conv
+
+ out_dtype = Conv.dtype
+
+ if isinstance(packed_kernel.op, te.tensor.ComputeOp) and packed_kernel.name == "packed_kernel":
+ if autotvm.GLOBAL_SCOPE.in_tuning:
+ s[packed_kernel].pragma(
+ s[packed_kernel].op.axis[0], "debug_skip_region")
+ else:
+ with tvm.target.create('cuda'):
+ schedule_injective_from_existing(s, packed_kernel)
+
+ if isinstance(pad_data.op, te.tensor.ComputeOp) and "pad" in pad_data.op.tag:
+ s[pad_data].compute_inline()
+ data = pad_data.op.input_tensors[0]
+
+ if autotvm.GLOBAL_SCOPE.in_tuning:
+ # skip this part during tuning to make recrods accurate
+ # this part will be pre-computed during NNVM's pre-compute optimization pass
+ s[pad_data].pragma(s[pad_data].op.axis[0], "debug_skip_region")
+ else:
+ data = pad_data
+ s[data].compute_inline()
+
+ data_dtype = data.dtype
+ kernel_dtype = packed_kernel.dtype
+
+ # Schedule for autotvm
+ cfg.define_knob("block_row_warps", [1, 2, 4])
+ cfg.define_knob("block_col_warps", [1, 2, 4])
+ cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16])
+ cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16])
+ cfg.define_knob("chunk", [1, 2, 4, 8])
+ cfg.define_knob("fuse_pack", [0, 1])
+ cfg.define_knob("split_block_k_nums", [1, 2, 4, 8, 16, 32])
+ cfg.define_knob("vector_ws", [1, 8])
+ cfg.define_knob("vector_as", [1, 8, 16])
+
+ block_row_warps = cfg["block_row_warps"].val
+ block_col_warps = cfg["block_col_warps"].val
+ warp_row_tiles = cfg["warp_row_tiles"].val
+ warp_col_tiles = cfg["warp_col_tiles"].val
+ chunk = cfg["chunk"].val
+ vector_as = cfg["vector_as"].val
+ vector_ws = cfg["vector_ws"].val
+ split_block_k_nums = cfg["split_block_k_nums"].val
+ fuse_pack = cfg["fuse_pack"].val
+
+ if not fuse_pack:
+ s[packed_data].compute_inline()
+ else:
+ with tvm.target.create('cuda'):
+ schedule_injective_from_existing(s, packed_data)
+
+ if data_dtype in ['int4', 'uint4']:
+ wmma_m = wmma_n = 8
+ wmma_k = 32
+ else:
+ wmma_m = 8
+ wmma_n = 32
+ wmma_k = 16
+
+ warp_size = 32
+
+ # Schedule for output
+ if len(s[output].op.axis) == 4:
+ hc, wc, nc, oc, = output.op.axis
+ nc, nnc = s[output].split(nc, factor=wmma_m)
+ oc, ooc = s[output].split(oc, factor=wmma_n)
+ else:
+ hc, wc, nc, oc, nnc, ooc = output.op.axis
+
+ kernel_scope, hc = s[output].split(hc, nparts=1)
+
+ block_k = s[output].fuse(hc, wc)
+ block_k, split_block_k = s[output].split(
+ block_k, factor=split_block_k_nums)
+ nc, nci = s[output].split(nc, factor=warp_row_tiles)
+ block_i, nc = s[output].split(nc, factor=block_row_warps)
+ oc, oci = s[output].split(oc, factor=warp_col_tiles)
+ block_j, oc = s[output].split(oc, factor=block_col_warps)
+ s[output].reorder(block_k, split_block_k, block_i,
+ block_j, nc, oc, nci, oci, nnc, ooc)
+ t = s[output].fuse(nnc, ooc)
+ _, tx = s[output].split(t, factor=warp_size)
+ s[output].bind(block_k, block_z)
+ s[output].bind(block_i, block_x)
+ s[output].bind(block_j, block_y)
+ s[output].bind(tx, thread_x)
+ s[output].bind(nc, thread_y)
+ s[output].bind(oc, thread_z)
+
+ # Schedule wmma store
+ s[OL].compute_at(s[output], block_j)
+ hc, wc, nc, oc, nnc, ooc = OL.op.axis
+ oc, oci = s[OL].split(oc, factor=warp_col_tiles)
+ _, oc = s[OL].split(oc, factor=block_col_warps)
+ nc, nci = s[OL].split(nc, factor=warp_row_tiles)
+ _, nc = s[OL].split(nc, factor=block_row_warps)
+ s[OL].reorder(nc, oc, nci, oci, nnc, ooc)
+ s[OL].bind(nc, thread_y)
+ s[OL].bind(oc, thread_z)
+
+ # Schedule local computation
+ s[ConvF].compute_at(s[OL], oc)
+ _, _, n, o, nnf, oof = ConvF.op.axis
+ ko, ki = s[ConvF].split(ic, factor=chunk)
+ s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)
+
+ cfg.define_reorder("reorder_inner", [ko, kh], policy="all")
+ cfg["reorder_inner"].apply(s, ConvF, [ko, kh])
+ cfg["reorder_inner"].apply(s, ConvF, [ki, kw])
+
+ cfg.define_knob("compute_at_AS", [0, 1, 2, 3])
+ cfg.define_knob("compute_at_WS", [0, 1, 2, 3])
+ compute_at_AS = cfg["compute_at_AS"].val
+ compute_at_WS = cfg["compute_at_WS"].val
+
+ # Move intermediate computation into each output compute tile
+ s[AF].compute_at(s[ConvF], kw)
+ s[WF].compute_at(s[ConvF], kw)
+
+ # Schedule for A's share memory
+ if compute_at_AS == 0:
+ s[AS].compute_at(s[ConvF], ki)
+ elif compute_at_AS == 1:
+ s[AS].compute_at(s[ConvF], kw)
+ elif compute_at_AS == 2:
+ s[AS].compute_at(s[ConvF], ko)
+ else:
+ s[AS].compute_at(s[ConvF], kh)
+ _, _, n, _, nn, ii = AS.op.axis
+ tx, xo = s[AS].split(n, nparts=block_row_warps)
+ ty, _ = s[AS].split(xo, nparts=block_col_warps)
+ t = s[AS].fuse(nn, ii)
+ to, ti = s[AS].split(t, nparts=warp_size)
+ ti, _t = s[AS].split(ti, factor=vector_as)
+ s[AS].bind(tx, thread_y)
+ s[AS].bind(ty, thread_z)
+ s[AS].bind(to, thread_x)
+ s[AS].vectorize(_t)
+
+ # Schedule for W's share memory
+ if compute_at_WS == 0:
+ s[WS].compute_at(s[ConvF], ki)
+ elif compute_at_WS == 1:
+ s[WS].compute_at(s[ConvF], kw)
+ elif compute_at_WS == 2:
+ s[WS].compute_at(s[ConvF], ko)
+ else:
+ s[WS].compute_at(s[ConvF], kh)
+ s[WS].compute_at(s[ConvF], kw)
+ kh, kw, ic, o, ii, oo = WS.op.axis
+ tx, xo = s[WS].split(o, nparts=block_row_warps)
+ ty, _ = s[WS].split(xo, nparts=block_col_warps)
+ t = s[WS].fuse(ii, oo)
+ to, ti = s[WS].split(t, nparts=warp_size)
+ ti, _t = s[WS].split(ti, factor=vector_ws)
+ s[WS].bind(tx, thread_y)
+ s[WS].bind(ty, thread_z)
+ s[WS].bind(to, thread_x)
+ s[WS].vectorize(ti)
+
+ # double buffer
+ cfg.define_knob('AS_double_buffer', [0, 1])
+ cfg.define_knob('WS_double_buffer', [0, 1])
+ if cfg['AS_double_buffer'].val:
+ s[AS].double_buffer()
+ if cfg['WS_double_buffer'].val:
+ s[WS].double_buffer()
+
+ # unroll
+ cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+ s[output].pragma(kernel_scope, 'auto_unroll_max_step',
+ cfg['auto_unroll_max_step'].val)
+ s[output].pragma(kernel_scope, 'unroll_explicit', False)
+
+ shape = (wmma_m, wmma_n, wmma_k)
+
+ AS_shape = (wmma_m, wmma_k)
+ AL_shape = (wmma_m, wmma_k)
+ WS_shape = (wmma_n, wmma_k)
+ WL_shape = (wmma_n, wmma_k)
+ CL_shape = (wmma_m, wmma_n)
+ CS_shape = (wmma_m, wmma_n)
+
+ AL_gemm = te.placeholder(AL_shape, name='A', dtype=data_dtype)
+ WL_gemm = te.placeholder(WL_shape, name='B', dtype=kernel_dtype)
+ k_gemm = te.reduce_axis((0, wmma_k), name="k")
+ CL_compute = te.compute(CL_shape, lambda ii, jj:
+ te.sum((AL_gemm[ii, k_gemm].astype(
+ 'int32') * WL_gemm[jj, k_gemm].astype('int32')), axis=k_gemm),
+ name='C')
+
+ AL_strides = [wmma_k, 1]
+ AS_strides = [wmma_k, 1]
+ WL_strides = [wmma_k, 1]
+ WS_strides = [wmma_k, 1]
+ CL_strides = [wmma_n, 1]
+ CS_strides = [wmma_n, 1]
+
+ s[AF].tensorize(AF.op.axis[-2],
+ intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape,
+ "row_major", AS_shape, AL_shape, data_dtype))
+
+ s[WF].tensorize(WF.op.axis[-2],
+ intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape,
+ "col_major", WS_shape, WL_shape, kernel_dtype))
+
+ s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides,
+ shape, out_dtype, CL_shape, CS_shape))
+
+ s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides,
+ WL_strides, CL_strides, shape))
+
+ return s
+
+
[email protected]_topi_schedule("conv2d_HWNCnc_tensorcore.cuda")
+def schedule_conv2d_hwnc_tensorcore(cfg, outs):
+ """TOPI schedule callback"""
+ s = te.create_schedule([x.op for x in outs])
+
+ def _callback(op):
+ if 'conv2d_HWNCnc_tensorcore' in op.tag:
+ schedule_hwnc_tensorcore_cuda(cfg, s, op.output(0))
+
+ traverse_inline(s, outs[0].op, _callback)
+ return s
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 3e6838c..2f19d6e 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -617,9 +617,13 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
CHECK(op->args.size() == 1 && l);
os << "((";
this->PrintType(l->dtype.element_of(), os);
- os << " *)" << this->GetVarID(l->buffer_var.get()) << " + ";
+ os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "
+ << "(";
this->PrintExpr(l->index, os);
- os << ')';
+ if (l->dtype.bits() == 4 || (l->dtype.bits() == 1 && l->dtype.is_int())) {
+ os << " / " << (32 / l->dtype.bits());
+ }
+ os << "))";
} else if (op->op.same_as(builtin::tvm_struct_get())) {
CHECK_EQ(op->args.size(), 3U);
os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as<IntImmNode>()->value);
diff --git a/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py b/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py
new file mode 100644
index 0000000..2c071c9
--- /dev/null
+++ b/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-locals, too-many-arguments
+"""Example code to do convolution."""
+
+import numpy as np
+import tvm
+import os
+import tvm.topi.testing
+from tvm import te, autotvm, topi
+from tvm.contrib.pickle_memoize import memoize
+from tvm.contrib import nvcc
+from tvm.topi.nn.util import get_pad_tuple
+from tvm.topi.util import get_const_tuple
+
+_conv2d_hwnc_tensorcore_implement = {
+ "cuda": (topi.cuda.conv2d_hwnc_tensorcore, topi.cuda.schedule_conv2d_hwnc_tensorcore)
+}
+
+def verify_conv2d_hwnc(batch, in_channel, in_size, num_filter, kernel, stride,
+ padding, dilation=1, devices='cuda', dtype='int4'):
+ """Test the conv2d with tensorcore for hwnc layout"""
+ pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+ padding_sum = pad_top + pad_left + pad_bottom + pad_right
+ print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
+ batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+ # choose dtype from int4, int8
+ assert dtype in ['int4', 'int8']
+
+ in_height = in_width = in_size
+
+ A = te.placeholder((in_height, in_width, batch, in_channel), name='A', dtype=dtype)
+ W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype)
+
+ a_shape = get_const_tuple(A.shape)
+ w_shape = get_const_tuple(W.shape)
+ @memoize("topi.tests.test_topi_conv2d_hwnc.verify_conv2d_hwnc")
+ def get_ref_data():
+ if dtype == 'int4':
+ a_np = np.random.randint(low=-8, high=7, size=a_shape).transpose((2, 0, 1, 3))
+ w_np = np.random.randint(low=-8, high=7, size=w_shape)
+ dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation))
+ elif dtype == 'int8':
+ a_np = np.random.randint(low=-128, high=127, size=a_shape).transpose((2, 0, 1, 3)).astype(dtype)
+ w_np = np.random.randint(low=-128, high=127, size=w_shape).astype(dtype)
+ dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation))
+
+ c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
+ return a_np, w_np, c_np
+
+ def convert_int32_into_int4(a_int32):
+ """ convert int32 values into int4
+ Parameters
+ ----------
+ a_int32 : int
+
+ Return
+ ------
+ a_int4 : int
+ """
+ I, J, K, L = a_int32.shape
+ a_int4 = np.zeros(shape=(I, J, K, L // 8), dtype=np.int32)
+ for i in range(I):
+ for j in range(J):
+ for k in range(K):
+ for l in range(L // 8):
+ for m in range(min(8, L-l*8)):
+ a_int4[i, j, k, l] = a_int4[i, j, k, l] | ((a_int32[i, j, k, l * 8 + m] & 0xf) << ((7 - m) * 4))
+ return a_int4
+
+ a_np, w_np, c_np = get_ref_data()
+ if dtype == 'int4':
+ a_np = convert_int32_into_int4(a_np)
+ w_np = convert_int32_into_int4(w_np)
+
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ if not nvcc.have_tensorcore(ctx.compute_version):
+ print("skip because gpu does not support Tensor Cores")
+ return
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ fcompute, fschedule = topi.testing.dispatch(device, _conv2d_hwnc_tensorcore_implement)
+ C = fcompute(A, W, stride, padding, dilation, dtype, 'int32')
+ s = fschedule([C])
+
+ a = tvm.nd.array(a_np.transpose((1, 2, 0, 3)), ctx)
+ w = tvm.nd.array(w_np, ctx)
+ c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+
+ func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
+ batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+ func(a, w, c)
+
+ rtol = 1e-3
+ tvm.testing.assert_allclose(c.asnumpy().transpose((2, 0, 1, 3)), c_np, rtol=rtol)
+
+ check_device(devices)
+
+
+def test_conv2d_hwnc_tensorcore():
+ """Test the conv2d with tensorcore for hwnc layout"""
+ verify_conv2d_hwnc(8, 64, 56, 64, 3, 1, 1, dtype='int8')
+ verify_conv2d_hwnc(8, 64, 56, 64, 1, 1, 0, dtype='int4')
+ verify_conv2d_hwnc(8, 64, 56, 128, 3, 2, 1)
+ verify_conv2d_hwnc(8, 64, 56, 64, 1, 2, 0)
+ verify_conv2d_hwnc(8, 128, 28, 128, 3, 1, 1)
+ verify_conv2d_hwnc(8, 128, 28, 256, 3, 2, 1)
+ verify_conv2d_hwnc(8, 128, 28, 256, 1, 2, 0)
+ verify_conv2d_hwnc(8, 256, 14, 256, 3, 1, 1)
+ verify_conv2d_hwnc(8, 256, 14, 512, 3, 2, 1)
+ verify_conv2d_hwnc(8, 256, 14, 512, 1, 2, 0)
+ verify_conv2d_hwnc(8, 512, 9, 512, 3, 1, 1)
+
+if __name__ == "__main__":
+ test_conv2d_hwnc_tensorcore()