This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f3501d85b7 [Unity][Fix][Op] Add groups to conv1d (#15457)
f3501d85b7 is described below
commit f3501d85b7ccae9e1ee3d390a40c667220dafae6
Author: Lesheng Jin <[email protected]>
AuthorDate: Wed Aug 2 16:44:26 2023 -0700
[Unity][Fix][Op] Add groups to conv1d (#15457)
upd
---
python/tvm/relax/transform/legalize_ops/nn.py | 1 +
python/tvm/topi/nn/conv1d.py | 5 ++++-
python/tvm/topi/nn/conv2d.py | 2 +-
.../python/relax/test_transform_legalize_ops_nn.py | 26 +++++++++++-----------
4 files changed, 19 insertions(+), 15 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index e4e608e769..562b497cb2 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -60,6 +60,7 @@ def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr:
strides=call.attrs.strides,
padding=call.attrs.padding,
dilation=call.attrs.dilation,
+ groups=call.attrs.groups,
data_layout=call.attrs.data_layout,
kernel_layout=call.attrs.kernel_layout,
out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None,
diff --git a/python/tvm/topi/nn/conv1d.py b/python/tvm/topi/nn/conv1d.py
index ee388b4297..af856c2ed5 100644
--- a/python/tvm/topi/nn/conv1d.py
+++ b/python/tvm/topi/nn/conv1d.py
@@ -25,6 +25,7 @@ def conv1d(
strides=1,
padding="VALID",
dilation=1,
+ groups=1,
data_layout="NCW",
kernel_layout="",
out_dtype=None,
@@ -60,7 +61,9 @@ def conv1d(
out_dtype : str
The output data type. If None then output is same type as input.
"""
- return conv(data, kernel, strides, padding, dilation, 1, data_layout, kernel_layout, out_dtype)
+ return conv(
+ data, kernel, strides, padding, dilation, groups, data_layout, kernel_layout, out_dtype
+ )
def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index f70d749e0f..0bb298ecce 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -885,7 +885,7 @@ def conv(
# compute the output shape
out_channel = num_filter
out_dimensions = [
- simplify(d - (k - 1) * dil - 1 + pb + pe) // stride + 1
+ simplify((d - (k - 1) * dil - 1 + pb + pe) // stride + 1)
for d, k, dil, pb, pe, stride in zip(
dimensions, kernel_dimensions, dilations, pad_begin, pad_end, strides
)
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 0737b2784c..c2acd52105 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -44,23 +44,23 @@ def test_conv1d():
return gv
@T.prim_func(private=True)
- def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")):
+ def conv1d(A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), B: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), group_conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")):
T.func_attr({"tir.noalias": True})
pad_temp = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(30)))
for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(30)):
with T.block("pad_temp"):
v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)])
+ T.reads(A[v_i0, v_i1, v_i2 - T.int64(1)])
T.writes(pad_temp[v_i0, v_i1, v_i2])
- pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0))
- for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(128), T.int64(3)):
- with T.block("conv1d_ncw"):
+ pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), A[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0))
+ for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(16), T.int64(3)):
+ with T.block("group_conv1d_ncw"):
v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry])
- T.reads(pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], rxplaceholder_1[v_ff, v_rc, v_ry])
- T.writes(conv1d_ncw[v_nn, v_ff, v_yy])
+ T.reads(pad_temp[v_nn, v_ff // T.int64(8) * T.int64(16) + v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], B[v_ff, v_rc, v_ry])
+ T.writes(group_conv1d_ncw[v_nn, v_ff, v_yy])
with T.init():
- conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
- conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * rxplaceholder_1[v_ff, v_rc, v_ry]
+ group_conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
+ group_conv1d_ncw[v_nn, v_ff, v_yy] = group_conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_ff // T.int64(8) * T.int64(16) + v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * B[v_ff, v_rc, v_ry]
# fmt: on
mod = LegalizeOps()(Conv1d)
@@ -171,7 +171,7 @@ def test_conv1d_symbolic():
w = T.int64()
kw = T.int64()
c = T.int64()
- gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w - kw + 1), dtype="float32"))
+ gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w + 1 - kw), dtype="float32"))
return gv
@T.prim_func(private=True)
@@ -181,7 +181,7 @@ def test_conv1d_symbolic():
rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, w))
f, kw = T.int64(), T.int64()
rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kw))
- conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w - kw + T.int64(1)))
+ conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w + T.int64(1) - kw))
# with T.block("root"):
pad_temp = T.alloc_buffer((n, c, w))
for i0, i1, i2 in T.grid(n, c, w):
@@ -349,7 +349,7 @@ def test_conv2d_symbolic():
kh = T.int64()
w = T.int64()
kw = T.int64()
- gv = R.call_tir(Expected.conv2d, (x, kernel), R.Tensor((n, f, ((h - kh) + 1), ((w - kw) + 1)), dtype="float32"))
+ gv = R.call_tir(Expected.conv2d, (x, kernel), R.Tensor((n, f, h + 1 - kh, w + 1 - kw), dtype="float32"))
return gv
@T.prim_func(private=True)
@@ -364,7 +364,7 @@ def test_conv2d_symbolic():
w = T.int64()
rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w], dtype="float32")
rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [f, c, kh, kw], dtype="float32")
- conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h - kh + T.int64(1), w - kw + T.int64(1)], dtype="float32")
+ conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h + T.int64(1) - kh, w + T.int64(1) - kw], dtype="float32")
pad_temp = T.alloc_buffer([n, c, h, w], dtype="float32")
for i0, i1, i2, i3 in T.grid(n, c, h, w):
with T.block("pad_temp"):