This is an automated email from the ASF dual-hosted git repository.
cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 04f06b5ac3 [Relax] Add edge padding mode (#18558)
04f06b5ac3 is described below
commit 04f06b5ac3dbcb86f6596b87c303e3689f1d4c42
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Dec 9 20:38:09 2025 +0800
[Relax] Add edge padding mode (#18558)
- Add edge padding mode
- Add auto pad test
---
python/tvm/relax/frontend/common.py | 4 +-
tests/python/relax/test_frontend_common.py | 174 +++++++++++++++++++++++++++++
2 files changed, 176 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py
index c1e9296ca3..5b18d5e27d 100644
--- a/python/tvm/relax/frontend/common.py
+++ b/python/tvm/relax/frontend/common.py
@@ -123,5 +123,5 @@ def autopad(
topi.nn.mirror_pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), "REFLECT"
)
else:
- # TODO(gigiblender) Support edge mode.
- raise NotImplementedError("Pad mode {} not implemented".format(pad_type))
+ # edge mode - replicate border values
+ return bb.emit_te(topi.nn.replicate_pad, data, pad[:, 0].tolist(), pad[:, 1].tolist())
diff --git a/tests/python/relax/test_frontend_common.py b/tests/python/relax/test_frontend_common.py
index 21becb2c85..85424df2f6 100644
--- a/tests/python/relax/test_frontend_common.py
+++ b/tests/python/relax/test_frontend_common.py
@@ -16,7 +16,11 @@
# under the License.
import tvm
import tvm.testing
+from tvm import relax
from tvm.relax.frontend import detach_params
+from tvm.relax.frontend.common import autopad
+from tvm.script import ir as I
+from tvm.script import tir as T
from tvm.script.parser import relax as R
@@ -37,5 +41,175 @@ def test_detach_params():
tvm.testing.assert_allclose(detached_params["func"][0].numpy(), param.numpy())
+class TestAutopad:
+ def _test_autopad(self, pad_type, expected):
+ bb = relax.BlockBuilder()
+ input_shape = (1, 1, 4, 4)
+ x = relax.Var("x", relax.TensorStructInfo(input_shape, "float32"))
+
+ with bb.function("main", [x]):
+ with bb.dataflow():
+ result = autopad(
+ bb,
+ x,
+ strides=[2, 2],
+ kernel_shape=[3, 3],
+ dilations=(1, 1),
+ pad_type=pad_type,
+ deconv=False,
+ mode="SAME_UPPER",
+ pad_value=0.0,
+ )
+ out = bb.emit_output(result)
+ bb.emit_func_output(out)
+
+ tvm.ir.assert_structural_equal(bb.get(), expected)
+
+ def test_constant(self):
+ @I.ir_module
+ class expected:
+ @T.prim_func(private=True)
+ def pad(
+ x: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(4)), "float32"),
+ PadInput: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)):
+ with T.block("PadInput"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(x[v_i0, v_i1, v_i2, v_i3])
+ T.writes(PadInput[v_i0, v_i1, v_i2, v_i3])
+ PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(
+ T.int64(0) <= v_i2
+ and v_i2 < T.int64(4)
+ and T.int64(0) <= v_i3
+ and v_i3 < T.int64(4),
+ x[v_i0, v_i1, v_i2, v_i3],
+ T.float32(0.0),
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor((1, 1, 4, 4), dtype="float32")
+ ) -> R.Tensor((1, 1, 5, 5), dtype="float32"):
+ cls = expected
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32")
+ )
+ gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ self._test_autopad("constant", expected)
+
+ def test_edge(self):
+ @I.ir_module
+ class expected:
+ @T.prim_func(private=True)
+ def replicate_pad(
+ x: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(4)), "float32"),
+ ReplicatePadInput: T.Buffer(
+ (T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32"
+ ),
+ ):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)):
+ with T.block("ReplicatePadInput"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(
+ x[
+ T.int64(0),
+ T.int64(0),
+ T.int64(0) : T.int64(4),
+ T.int64(0) : T.int64(4),
+ ]
+ )
+ T.writes(ReplicatePadInput[v_i0, v_i1, v_i2, v_i3])
+ ReplicatePadInput[v_i0, v_i1, v_i2, v_i3] = x[
+ T.if_then_else(
+ v_i0 < T.int64(0),
+ T.int64(0),
+ T.if_then_else(T.int64(1) <= v_i0, T.int64(0), v_i0),
+ ),
+ T.if_then_else(
+ v_i1 < T.int64(0),
+ T.int64(0),
+ T.if_then_else(T.int64(1) <= v_i1, T.int64(0), v_i1),
+ ),
+ T.if_then_else(
+ v_i2 < T.int64(0),
+ T.int64(0),
+ T.if_then_else(T.int64(4) <= v_i2, T.int64(3), v_i2),
+ ),
+ T.if_then_else(
+ v_i3 < T.int64(0),
+ T.int64(0),
+ T.if_then_else(T.int64(4) <= v_i3, T.int64(3), v_i3),
+ ),
+ ]
+
+ @R.function
+ def main(
+ x: R.Tensor((1, 1, 4, 4), dtype="float32")
+ ) -> R.Tensor((1, 1, 5, 5), dtype="float32"):
+ cls = expected
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.replicate_pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32")
+ )
+ gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ self._test_autopad("edge", expected)
+
+ def test_reflect(self):
+ @I.ir_module
+ class expected:
+ @T.prim_func(private=True)
+ def mirror_pad(
+ x: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(4)), "float32"),
+ MirrorPadInput: T.Buffer(
+ (T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32"
+ ),
+ ):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)):
+ with T.block("MirrorPadInput"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(x[v_i0, v_i1, T.int64(0) : T.int64(4), T.int64(0) : T.int64(4)])
+ T.writes(MirrorPadInput[v_i0, v_i1, v_i2, v_i3])
+ MirrorPadInput[v_i0, v_i1, v_i2, v_i3] = x[
+ v_i0,
+ v_i1,
+ T.if_then_else(
+ T.int64(4) <= v_i2,
+ T.int64(6) - v_i2,
+ T.if_then_else(v_i2 < T.int64(0), v_i2 * T.int64(-1), v_i2),
+ ),
+ T.if_then_else(
+ T.int64(4) <= v_i3,
+ T.int64(6) - v_i3,
+ T.if_then_else(v_i3 < T.int64(0), v_i3 * T.int64(-1), v_i3),
+ ),
+ ]
+
+ @R.function
+ def main(
+ x: R.Tensor((1, 1, 4, 4), dtype="float32")
+ ) -> R.Tensor((1, 1, 5, 5), dtype="float32"):
+ cls = expected
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.mirror_pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32")
+ )
+ gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ self._test_autopad("reflect", expected)
+
+
if __name__ == "__main__":
tvm.testing.main()