Meteorix commented on a change in pull request #7147:
URL: https://github.com/apache/tvm/pull/7147#discussion_r564190200



##########
File path: python/tvm/topi/cuda/conv2d_alter_op.py
##########
@@ -345,4 +347,51 @@ def _conv2d_legalize(attrs, inputs, arg_types):
             else:
                 out = relay.nn.conv2d(data, kernel, **new_attrs)
             return out
+    elif data_dtype in ["float16"]:  # todo: support int8/int4
+        if data_layout == "NHWC" and kernel_layout == "HWIO":
+            batch = data_tensor.shape[0].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[3].value
+
+            if (
+                (batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 
== 0)
+                or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 
16 == 0)
+                or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 
8 == 0)
+            ):
+                # no need to pad
+                return None
+
+            (db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, 
out_channel)
+
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops 
%s", extra_flops)
+                return None
+
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", 
extra_flops)
+
+            # Pad batch size
+            if db != 0:
+                data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), 
(0, 0)))
+
+            # Pad input channel
+            if di != 0:
+                data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), 
(0, di)))
+                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 
di), (0, 0)))
+
+            # Pad output channel
+            if do != 0:
+                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 
0), (0, do)))
+
+            if do != 0:
+                new_out_channel = out_channel + do
+                new_attrs["channels"] = new_out_channel
+                out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
+            else:
+                out = relay.nn.conv2d(data, kernel, **new_attrs)

Review comment:
       good point




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to