comaniac commented on a change in pull request #8480:
URL: https://github.com/apache/tvm/pull/8480#discussion_r670674834



##########
File path: python/tvm/topi/arm_cpu/conv2d_spatial_pack.py
##########
@@ -273,7 +273,8 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_
     data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0])
 
     # ==================== define configuration space ====================
-    n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
+    n = cfg.axis(N) if isinstance(N, int) else cfg.axis(1)

Review comment:
       Does this mean that `1` is used as the batch size when the batch
dimension is dynamic? If so, it would be better to add a comment saying that.

##########
File path: python/tvm/relay/op/nn/_nn.py
##########
@@ -975,23 +976,72 @@ def _conv_shape_func(dshape, kshape, strides, padding, dilation):
     return out
 
 
+@script
+def _conv_shape_func_nhwc_hwio(dshape, kshape, strides, padding, dilation):
+    """Shape function for conv*d op with nhwc & hwio layout."""
+    out = output_tensor((dshape.shape[0],), "int64")
+    out[0] = dshape[0]
+    out[dshape.shape[0] - 1] = kshape[kshape.shape[0] - 1]
+
+    for i in const_range(dshape.shape[0] - 2):
+        dilated_k = (kshape[i] - 1) * dilation[i] + 1
+        out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1
+    return out
+
+
+@script
+def _conv_shape_func_nhwc_hwoi(dshape, kshape, strides, padding, dilation):
+    """Shape function for conv*d op with nhwc & hwoi layout."""
+    out = output_tensor((dshape.shape[0],), "int64")
+    out[0] = dshape[0]
+    out[dshape.shape[0] - 1] = kshape[kshape.shape[0] - 2]
+
+    for i in const_range(dshape.shape[0] - 2):
+        dilated_k = (kshape[i] - 1) * dilation[i] + 1
+        out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1
+    return out
+
+
 def conv_shape_func(attrs, inputs, _):
-    """
-    Shape function for contrib_conv2d_NCHWc op.
-    """
+    """Shape function for conv*d op."""
     strides = get_const_tuple(attrs.strides)
     padding = get_const_tuple(attrs.padding)
     dilation = get_const_tuple(attrs.dilation)
 
-    return [
-        _conv_shape_func(
-            inputs[0],
-            inputs[1],
-            convert(strides),
-            convert(padding),
-            convert(dilation),
-        )
-    ]
+    if attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW":
+        return [
+            _conv_shape_func_nchw(
+                inputs[0],
+                inputs[1],
+                convert(strides),
+                convert(padding),
+                convert(dilation),
+            )
+        ]
+    if attrs["data_layout"] == "NHWC":
+        if attrs["kernel_layout"] == "HWIO":
+            return [
+                _conv_shape_func_nhwc_hwio(
+                    inputs[0],
+                    inputs[1],
+                    convert(strides),
+                    convert(padding),
+                    convert(dilation),
+                )
+            ]
+        if attrs["kernel_layout"] == "HWOI":
+            return [
+                _conv_shape_func_nhwc_hwoi(
+                    inputs[0],
+                    inputs[1],
+                    convert(strides),
+                    convert(padding),
+                    convert(dilation),
+                )
+            ]
+    raise ValueError(
+        "Unsupported data/kernel layout: %s, %s" % (attrs["data_layout"], attrs["kernel_layout"])
+    )

Review comment:
       ```suggestion
       shape_func = None
       if attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW":
           shape_func = _conv_shape_func_nchw
       elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO":
           shape_func = _conv_shape_func_nhwc_hwio
       elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWOI":
           shape_func = _conv_shape_func_nhwc_hwoi
       else:
           raise ValueError(
               "Unsupported data/kernel layout: %s, %s" % (attrs["data_layout"], attrs["kernel_layout"])
           )

       return [shape_func(inputs[0], inputs[1], convert(strides), convert(padding), convert(dilation))]
       ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to