masahi commented on code in PR #15011:
URL: https://github.com/apache/tvm/pull/15011#discussion_r1257571717


##########
python/tvm/topi/cuda/conv2d.py:
##########
@@ -148,3 +154,352 @@ def conv2d_backward_weight_cudnn(
         conv_dtype=conv_dtype,
         groups=groups,
     )
+
+
[email protected]_topi_compute("conv2d_nchw_mma.cuda")
+def conv2d_nchw_mma(cfg, data, kernel, strides, padding, dilation, 
out_dtype="float32"):
+    """Compute conv2d nchw using im2col"""
+    assert data.dtype == "float16"
+    out_channels, in_channels, kernel_h, kernel_w = 
get_const_tuple(kernel.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+    assert dilation_h == 1 and dilation_w == 1
+
+    if isinstance(strides, int):
+        stride_h = stride_w = strides
+    else:
+        stride_h, stride_w = strides
+
+    batch_size, _, P, Q = get_output_shape(
+        data, kernel, stride_h, stride_w, dilation_h, dilation_w, padding
+    )
+    assert batch_size == 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, 
kernel_w))
+    pad_before = [0, 0, pad_top, pad_left]
+    pad_after = [0, 0, pad_down, pad_right]
+
+    if all([v == 0 for v in pad_before]) and all([v == 0 for v in pad_after]):
+        pad_data = data
+    else:
+        pad_data = pad(data, pad_before, pad_after, name="pad_data")
+
+    M = out_channels
+    K = in_channels * kernel_h * kernel_w
+    N = batch_size * P * Q
+
+    if kernel_h * kernel_w == 1:
+        ck = te.reduce_axis((0, K), name="k")
+        A = reshape(kernel, (M, K))
+        B = reshape(pad_data, (K, N))
+        C = te.compute(
+            (batch_size, out_channels, P, Q),
+            lambda b, o, h, w: te.sum(
+                A[o, ck].astype(out_dtype) * B[ck, h * Q + 
w].astype(out_dtype),
+                axis=[ck],
+            ),
+            name="conv2d_nchw_mma",
+            attrs={
+                "schedule_rule": "conv2d_nchw_mma",
+            },
+        )
+    else:
+        # Convert the kernel of (O,I,H,W) to (N,K) format i.e (OC,IC*KH*KW)
+        A = te.compute(
+            (M, K),
+            lambda x, y: kernel[
+                x, (y // (kernel_h * kernel_w)), (y // kernel_w) % kernel_h, y 
% kernel_w
+            ],
+            name="T_reshape",
+        )
+
+        # Convert the data of (N,C,H,W) to (K,M) format i.e (IC*KH*KW,OH*OW)
+        B = te.compute(
+            (K, N),
+            lambda y, x: pad_data[
+                0,
+                y // (kernel_h * kernel_w),
+                stride_h * (x // Q) + ((y // kernel_w) % kernel_h),
+                stride_w * (x % Q) + y % kernel_w,
+            ],
+            name="T_reshape_1",
+        )
+
+        # Apply GEMM operation. The result will be of (N,O,H,W) format
+        ck = te.reduce_axis((0, K), name="k")
+        C = te.compute(
+            (batch_size, out_channels, P, Q),
+            lambda b, o, h, w: te.sum(
+                A[o, ck].astype(out_dtype) * B[ck, h * Q + 
w].astype(out_dtype),
+                axis=[ck],
+            ),
+            name="conv2d_nchw_mma",
+            attrs={
+                "schedule_rule": "conv2d_nchw_mma",
+            },
+        )
+    return C
+
+
+def schedule_rule_conv2d_nchw_mma(sch: Schedule, block: BlockRV):

Review Comment:
   Rather than implementing a manual schedule like this, I recommend trying to 
get auto tensorization for tensor core working for vulkan. I expect the 
resulting schedule to be very similar.   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to