Shawn-Inspur commented on a change in pull request #6121:
URL: https://github.com/apache/incubator-tvm/pull/6121#discussion_r467373678
##########
File path: topi/python/topi/cuda/conv2d_hwnc_tensorcore.py
##########
@@ -0,0 +1,402 @@
+import numpy as np
+import tvm
+from tvm import te
+from tvm import autotvm
+from ..util import get_const_tuple, traverse_inline, simplify, tag
+from ..nn.pad import pad
+from ..nn.util import get_pad_tuple
+from topi.cuda.injective import schedule_injective_from_existing
+from .tensor_intrin import intrin_wmma_load_matrix_A,
intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
+
def unpack_HWNCnc_to_hwnc(packed_out, out_dtype):
    """Unpack the blocked conv2d_hwnc output back to plain hwnc layout.

    Parameters
    -----------
    packed_out : tvm.te.Tensor
        The output tensor of conv2d_hwnc in (H, W, N, O, wmma_m, wmma_n)
        blocked layout.

    out_dtype : str
        The output dtype.

    Returns
    -------
    unpacked_out : tvm.te.Tensor
        The unpacked output tensor in hwnc layout.
    """
    height, width, n_chunk, o_chunk, wmma_m, wmma_n = \
        get_const_tuple(packed_out.shape)

    div = tvm.tir.indexdiv
    mod = tvm.tir.indexmod

    out_shape = (height, width, n_chunk * wmma_m, o_chunk * wmma_n)

    def _fetch(h, w, n, o):
        # Map flat (n, o) coordinates back into the blocked layout:
        # chunk index = coord // block, intra-block offset = coord % block.
        elem = packed_out[h, w, div(n, wmma_m), div(o, wmma_n),
                          mod(n, wmma_m), mod(o, wmma_n)]
        return elem.astype(out_dtype)

    return te.compute(out_shape, _fetch,
                      name='output_unpack',
                      tag=tag.INJECTIVE + ",unpack_hwncc")
+
def conv2d_hwnc_tensorcore(data, kernel, strides, padding, dilation, in_dtype,
                           out_dtype='int32'):
    """Compute conv2d with TensorCores for hwnc-layout integer inputs.

    Internally packs to the blocked HWNCnc layout for the TensorCore
    compute, then unpacks the result back to hwnc.

    Parameters
    -----------
    data : tvm.te.Tensor
        Input tensor in hwnc layout; dtype must be (u)int4 or (u)int8.

    kernel : tvm.te.Tensor
        Filter tensor; dtype must be (u)int4 or (u)int8.

    strides : int or tuple of int
        Spatial strides of the convolution.

    padding : int or tuple of int
        Padding sizes.

    dilation : int or tuple of int
        Dilation sizes.

    in_dtype : str
        Input data type.

    out_dtype : str
        Output data type, 'int32' by default.

    Returns
    -------
    output : tvm.te.Tensor
        Convolution result in hwnc layout.
    """
    # TensorCore integer paths only accept 4- and 8-bit integer inputs.
    assert data.dtype in ('int4', 'uint4', 'int8', 'uint8')
    assert kernel.dtype in ('int4', 'uint4', 'int8', 'uint8')
    # NOTE(review): data and kernel dtypes are not required to match here —
    # confirm mixed-precision inputs (e.g. int8 data with int4 kernel) are
    # actually intended before relaxing this further.
    packed_out = hwnc_tensorcore_cuda(data, kernel, strides, padding,
                                      dilation, out_dtype)
    return unpack_HWNCnc_to_hwnc(packed_out, out_dtype)
+
[email protected]_topi_compute("conv2d_HWNCnc_tensorcore.cuda")
+def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation,
out_dtype='int32'):
+ """Compute declaration for tensorcore"""
+ assert isinstance(stride, int) or len(stride) == 2
+ assert isinstance(dilation, int) or len(dilation) == 2
+
+ if isinstance(stride, int):
+ stride_h = stride_w = stride
+ else:
+ stride_h, stride_w = stride
+
+ if isinstance(dilation, int):
+ dilation_h = dilation_w = dilation
+ else:
+ dilation_h, dilation_w = dilation
+
+ in_dtype = Input.dtype
+
+ if in_dtype in ['int4', 'uint4']:
+ wmma_n = wmma_m = 8
+ wmma_k = 32
+ else:
+ wmma_m = 8
+ wmma_n = 32
+ wmma_k = 16
+
+ pre_computed = len(Filter.shape) == 6
+ in_height, in_width, batch, in_channels = get_const_tuple(Input.shape)
+ if pre_computed:
+ kernel_h, kernel_w, oc_chunk, ic_chunk, oc_block_factor,
ic_block_factor = get_const_tuple(Filter.shape)
+ num_filter = oc_block_factor * oc_chunk
+ else:
+ kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape)
+
+ if in_dtype in ['int4', 'uint4']:
+ assert (batch % 8 == 0 and in_channels % 32 == 0 and num_filter % 8 ==
0)
+ else:
+ assert (batch % 16 == 0 and in_channels % 16 == 0 and num_filter % 16
== 0) or \
+ (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32
== 0) or \
+ (batch % 32 == 0 and in_channels % 16 == 0 and num_filter % 8
== 0), \
Review comment:
As indicated by lines 72-74, only one shape can be supported in the
non-int4 case. However, the assertion here includes three shapes regarding
m, n, k, which is confusing.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]