This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f3501d85b7 [Unity][Fix][Op] Add groups to conv1d (#15457)
f3501d85b7 is described below

commit f3501d85b7ccae9e1ee3d390a40c667220dafae6
Author: Lesheng Jin <[email protected]>
AuthorDate: Wed Aug 2 16:44:26 2023 -0700

    [Unity][Fix][Op] Add groups to conv1d (#15457)
    
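    Pass the `groups` attribute through when legalizing `relax.nn.conv1d`, and add a `groups` parameter (default 1) to `topi.nn.conv1d` that is forwarded to `topi.nn.conv` instead of the previously hard-coded 1. Also apply `simplify` to the entire output-dimension expression in `topi.nn.conv`, rather than only to the numerator before the floor-division by the stride.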
---
 python/tvm/relax/transform/legalize_ops/nn.py      |  1 +
 python/tvm/topi/nn/conv1d.py                       |  5 ++++-
 python/tvm/topi/nn/conv2d.py                       |  2 +-
 .../python/relax/test_transform_legalize_ops_nn.py | 26 +++++++++++-----------
 4 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index e4e608e769..562b497cb2 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -60,6 +60,7 @@ def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr:
         strides=call.attrs.strides,
         padding=call.attrs.padding,
         dilation=call.attrs.dilation,
+        groups=call.attrs.groups,
         data_layout=call.attrs.data_layout,
         kernel_layout=call.attrs.kernel_layout,
         out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None,
diff --git a/python/tvm/topi/nn/conv1d.py b/python/tvm/topi/nn/conv1d.py
index ee388b4297..af856c2ed5 100644
--- a/python/tvm/topi/nn/conv1d.py
+++ b/python/tvm/topi/nn/conv1d.py
@@ -25,6 +25,7 @@ def conv1d(
     strides=1,
     padding="VALID",
     dilation=1,
+    groups=1,
     data_layout="NCW",
     kernel_layout="",
     out_dtype=None,
@@ -60,7 +61,9 @@ def conv1d(
     out_dtype : str
         The output data type. If None then output is same type as input.
     """
-    return conv(data, kernel, strides, padding, dilation, 1, data_layout, 
kernel_layout, out_dtype)
+    return conv(
+        data, kernel, strides, padding, dilation, groups, data_layout, 
kernel_layout, out_dtype
+    )
 
 
 def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, 
out_dtype=None):
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index f70d749e0f..0bb298ecce 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -885,7 +885,7 @@ def conv(
     # compute the output shape
     out_channel = num_filter
     out_dimensions = [
-        simplify(d - (k - 1) * dil - 1 + pb + pe) // stride + 1
+        simplify((d - (k - 1) * dil - 1 + pb + pe) // stride + 1)
         for d, k, dil, pb, pe, stride in zip(
             dimensions, kernel_dimensions, dilations, pad_begin, pad_end, 
strides
         )
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 0737b2784c..c2acd52105 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -44,23 +44,23 @@ def test_conv1d():
             return gv
 
         @T.prim_func(private=True)
-        def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), 
T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), 
T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), 
T.int64(13)), "float32")):
+        def conv1d(A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), 
"float32"), B: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), 
group_conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")):
             T.func_attr({"tir.noalias": True})
             pad_temp = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(30)))
             for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(30)):
                 with T.block("pad_temp"):
                     v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
-                    T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)])
+                    T.reads(A[v_i0, v_i1, v_i2 - T.int64(1)])
                     T.writes(pad_temp[v_i0, v_i1, v_i2])
-                    pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= 
v_i2 and v_i2 < T.int64(29), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)], 
T.float32(0))
-            for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), 
T.int64(13), T.int64(128), T.int64(3)):
-                with T.block("conv1d_ncw"):
+                    pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= 
v_i2 and v_i2 < T.int64(29), A[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0))
+            for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), 
T.int64(13), T.int64(16), T.int64(3)):
+                with T.block("group_conv1d_ncw"):
                     v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, 
ff, yy, rc, ry])
-                    T.reads(pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * 
T.int64(2)], rxplaceholder_1[v_ff, v_rc, v_ry])
-                    T.writes(conv1d_ncw[v_nn, v_ff, v_yy])
+                    T.reads(pad_temp[v_nn, v_ff // T.int64(8) * T.int64(16) + 
v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], B[v_ff, v_rc, v_ry])
+                    T.writes(group_conv1d_ncw[v_nn, v_ff, v_yy])
                     with T.init():
-                        conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
-                    conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, 
v_yy] + pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * 
rxplaceholder_1[v_ff, v_rc, v_ry]
+                        group_conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
+                    group_conv1d_ncw[v_nn, v_ff, v_yy] = 
group_conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_ff // T.int64(8) * 
T.int64(16) + v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * B[v_ff, v_rc, v_ry]
     # fmt: on
 
     mod = LegalizeOps()(Conv1d)
@@ -171,7 +171,7 @@ def test_conv1d_symbolic():
             w = T.int64()
             kw = T.int64()
             c = T.int64()
-            gv = R.call_tir(Expected.conv1d, (x, kernel), 
out_sinfo=R.Tensor((n, f, w - kw + 1), dtype="float32"))
+            gv = R.call_tir(Expected.conv1d, (x, kernel), 
out_sinfo=R.Tensor((n, f, w + 1 - kw), dtype="float32"))
             return gv
 
         @T.prim_func(private=True)
@@ -181,7 +181,7 @@ def test_conv1d_symbolic():
             rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, w))
             f, kw = T.int64(), T.int64()
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kw))
-            conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w - kw + 
T.int64(1)))
+            conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w + T.int64(1) 
- kw))
             # with T.block("root"):
             pad_temp = T.alloc_buffer((n, c, w))
             for i0, i1, i2 in T.grid(n, c, w):
@@ -349,7 +349,7 @@ def test_conv2d_symbolic():
             kh = T.int64()
             w = T.int64()
             kw = T.int64()
-            gv = R.call_tir(Expected.conv2d, (x, kernel), R.Tensor((n, f, ((h 
- kh) + 1), ((w - kw) + 1)), dtype="float32"))
+            gv = R.call_tir(Expected.conv2d, (x, kernel), R.Tensor((n, f, h + 
1 - kh, w + 1 - kw), dtype="float32"))
             return gv
 
         @T.prim_func(private=True)
@@ -364,7 +364,7 @@ def test_conv2d_symbolic():
             w = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w], 
dtype="float32")
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [f, c, kh, 
kw], dtype="float32")
-            conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h - kh + 
T.int64(1), w - kw + T.int64(1)], dtype="float32")
+            conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h + 
T.int64(1) - kh, w + T.int64(1) - kw], dtype="float32")
             pad_temp = T.alloc_buffer([n, c, h, w], dtype="float32")
             for i0, i1, i2, i3 in T.grid(n, c, h, w):
                 with T.block("pad_temp"):

Reply via email to