This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 64cea4a572 [Relax][PyTorch] Add MaxPool 1D and 3D Op Support for Exported Program and FX graph (#17919)
64cea4a572 is described below
commit 64cea4a5727eb4624bef48c67a64f37139a210a7
Author: Deivanayaki S <[email protected]>
AuthorDate: Thu May 8 18:10:09 2025 +0530
[Relax][PyTorch] Add MaxPool 1D and 3D Op Support for Exported Program and FX graph (#17919)
* add max_pool 1d and 3d op support and refactor 2d op
* update the layout used in max pool 1d op
* add mappings into fx translator and fix lint issues
* fix missing and incorrect mappings and add module functions
* update output tensor struct info for maxpool1d
* add docs for handling edge cases
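As a usage sketch (illustrative only, not part of this commit; it assumes the
existing from_exported_program entry point in tvm.relax.frontend.torch), the
new op support can be exercised like this:

    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class Pool(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool1d(kernel_size=2)

        def forward(self, x):
            return self.pool(x)

    # Trace with torch.export, then translate the ExportedProgram to Relax.
    exported = torch.export.export(Pool(), (torch.randn(1, 3, 8),))
    mod = from_exported_program(exported)
    mod.show()  # body should contain R.nn.max_pool1d(..., layout="NCW")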
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
.../frontend/torch/base_fx_graph_translator.py | 100 ++++-
.../frontend/torch/exported_program_translator.py | 2 +
python/tvm/relax/frontend/torch/fx_translator.py | 26 ++
src/relax/op/nn/pooling.cc | 2 +-
.../relax/test_frontend_from_exported_program.py | 199 +++++++++
tests/python/relax/test_frontend_from_fx.py | 200 +++++++++
tests/python/relax/test_op_nn_pooling.py | 446 +++++++++++++++++++++
7 files changed, 973 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f8634f5da7..f683e62d24 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -923,6 +923,50 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
bias = args[2] if len(args) > 2 else None
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))
+ def _max_pool1d_impl(
+ self,
+ x: relax.Expr,
+ kernel_size: Union[int, Tuple[int]] = 1,
+ stride: Optional[Union[int, Tuple[int]]] = None,
+ padding: Optional[int] = 0,
+ dilation: Optional[int] = 1,
+ ceil_mode: Optional[bool] = False,
+ ) -> relax.Var:
+ # Expand to 3D by adding batch dim if input is 2D
+ x_ndim = x.struct_info.ndim
+ if x_ndim == 2:
+ x = relax.op.expand_dims(x, axis=0)
+
+ stride = kernel_size if stride is None else stride
+
+ result = self.block_builder.emit(
+ relax.op.nn.max_pool1d(
+ x,
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ ceil_mode=ceil_mode,
+ layout="NCW",
+ )
+ )
+
+ # Remove added batch dim from result
+ if x_ndim == 2:
+ result = relax.op.squeeze(result, axis=[0])
+ return result
+
+ def _max_pool1d(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ kernel_size = args[1]
+ stride = args[2] if len(args) > 2 else None
+ padding = args[3] if len(args) > 3 else 0
+ dilation = args[4] if len(args) > 4 else 1
+ ceil_mode = args[5] if len(args) > 5 else False
+
+ return self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+
def _max_pool2d_impl(
self,
x: relax.Expr,
@@ -932,8 +976,14 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dilation: Optional[int] = 1,
ceil_mode: Optional[bool] = False,
) -> relax.Var:
+ # Expand to 4D by adding batch dim if input is 3D
+ x_ndim = x.struct_info.ndim
+ if x_ndim == 3:
+ x = relax.op.expand_dims(x, axis=0)
+
stride = kernel_size if stride is None else stride
- return self.block_builder.emit(
+
+ result = self.block_builder.emit(
relax.op.nn.max_pool2d(
x,
pool_size=kernel_size,
@@ -945,6 +995,11 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
)
)
+ # Remove added batch dim from result
+ if x_ndim == 3:
+ result = relax.op.squeeze(result, axis=[0])
+ return result
+
def _max_pool2d(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
@@ -956,6 +1011,49 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+ def _max_pool3d_impl(
+ self,
+ x: relax.Expr,
+ kernel_size: Union[int, Tuple[int, int, int]] = (1, 1, 1),
+ stride: Optional[Union[int, Tuple[int, int, int]]] = None,
+ padding: Optional[int] = 0,
+ dilation: Optional[int] = 1,
+ ceil_mode: Optional[bool] = False,
+ ) -> relax.Var:
+ # Expand to 5D by adding batch dim if input is 4D
+ x_ndim = x.struct_info.ndim
+ if x_ndim == 4:
+ x = relax.op.expand_dims(x, axis=0)
+
+ stride = kernel_size if stride is None else stride
+
+ result = self.block_builder.emit(
+ relax.op.nn.max_pool3d(
+ x,
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ ceil_mode=ceil_mode,
+ layout="NCDHW",
+ )
+ )
+
+ # Remove added batch dim from result
+ if x_ndim == 4:
+ result = relax.op.squeeze(result, axis=[0])
+ return result
+
+ def _max_pool3d(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ kernel_size = args[1]
+ stride = args[2] if len(args) > 2 else None
+ padding = args[3] if len(args) > 3 else 0
+ dilation = args[4] if len(args) > 4 else 1
+ ceil_mode = args[5] if len(args) > 5 else False
+ return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+
def _pad(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
pad = node.args[1]
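Note on the expand/squeeze handling added above: it mirrors PyTorch's own
acceptance of unbatched pooling inputs. For reference, plain PyTorch behaves
the same way (illustrative snippet, not part of the patch):

    import torch

    # An unbatched (C, W) input is accepted and pooled to (C, W_out); the
    # importer reproduces this by expanding to (1, C, W), pooling in NCW
    # layout, and squeezing the batch axis back out.
    y = torch.nn.functional.max_pool1d(torch.randn(3, 8), kernel_size=2)
    assert tuple(y.shape) == (3, 4)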
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 42a57273af..0600dfa552 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -408,7 +408,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"group_norm.default": self._group_norm,
"layer_norm.default": self._layer_norm,
"linear.default": self._linear,
+ "max_pool1d.default": self._max_pool1d,
"max_pool2d.default": self._max_pool2d,
+ "max_pool3d.default": self._max_pool3d,
"scaled_dot_product_attention.default":
self._scaled_dot_product_attention,
"unbind.int": self._unbind,
"upsample_bilinear2d.vec": self._upsample_bilinear2d,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 199e58cb1d..0a94c679f5 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -449,6 +449,17 @@ class TorchFXImporter(BaseFXGraphImporter):
bias = self.params.get(module.bias, None)
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))
+ def _max_pool1d_module(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ kernel_size = module.kernel_size
+ stride = module.stride
+ padding = module.padding
+ dilation = module.dilation
+ ceil_mode = module.ceil_mode
+
+ return self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+
def _max_pool2d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -460,6 +471,17 @@ class TorchFXImporter(BaseFXGraphImporter):
return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+ def _max_pool3d_module(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ kernel_size = module.kernel_size
+ stride = module.stride
+ padding = module.padding
+ dilation = module.dilation
+ ceil_mode = module.ceil_mode
+
+ return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+
def _pixel_shuffle_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -661,7 +683,9 @@ class TorchFXImporter(BaseFXGraphImporter):
nn.GroupNorm: self._group_norm_module,
nn.LayerNorm: self._layer_norm_module,
nn.Linear: self._linear_module,
+ nn.MaxPool1d: self._max_pool1d_module,
nn.MaxPool2d: self._max_pool2d_module,
+ nn.MaxPool3d: self._max_pool3d_module,
nn.modules.sparse.Embedding: self._embedding_module,
nn.PixelShuffle: self._pixel_shuffle_module,
# tensor manipulation
@@ -774,7 +798,9 @@ class TorchFXImporter(BaseFXGraphImporter):
"interpolate": self._interpolate,
"layer_norm": self._layer_norm,
"linear": self._linear,
+ "max_pool1d": self._max_pool1d,
"max_pool2d": self._max_pool2d,
+ "max_pool3d": self._max_pool3d,
"scaled_dot_product_attention": self._scaled_dot_product_attention,
"stochastic_depth": lambda node: self.env[node.args[0]],
"unbind": self._unbind,
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 565e6a00c6..391edda9ef 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -95,7 +95,7 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) {
PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1;
if (attrs->ceil_mode) {
- numerator_w += attrs->strides[1] - 1;
+ numerator_w += attrs->strides[0] - 1;
}
out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1);
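The one-line fix above matters because a 1D pool carries a single stride, so
reading strides[1] in ceil_mode indexed past the attribute. For reference,
the output width computed by this function follows the usual pooling formula
(illustrative Python, symmetric padding assumed):

    def pool1d_out_width(w, kernel, stride, padding, dilation, ceil_mode):
        # Effective window is dilation * (kernel - 1) + 1.
        numerator = w + 2 * padding - dilation * (kernel - 1) - 1
        if ceil_mode:
            # The corrected line: use the only stride, strides[0].
            numerator += stride - 1
        return numerator // stride + 1

    # Matches test_max_pool1d_infer_struct_info_ceil_mode below:
    assert pool1d_out_width(32, 3, 2, 0, 1, True) == 16
    assert pool1d_out_width(32, 5, 2, 0, 1, True) == 15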
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index c375992dca..bcd96369d4 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2364,6 +2364,101 @@ def test_linear():
verify_model(model, example_args, binding, expected2)
+def test_maxpool1d():
+ class MaxPool1d(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool1d(kernel_size=2)
+
+ def forward(self, input):
+ return self.pool(input)
+
+ class MaxPool1d_functional(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return torch.nn.functional.max_pool1d(input, kernel_size=2)
+
+ class MaxPool1d2(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=2)
+
+ def forward(self, input):
+ return self.pool(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 8), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
+ with R.dataflow():
+ lv = R.nn.max_pool1d(
+ input_1,
+ pool_size=[2],
+ strides=[2],
+ dilation=[1],
+ padding=[0, 0],
+ layout="NCW",
+ out_layout="NCW",
+ )
+ gv = (lv,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 8), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
+ with R.dataflow():
+ lv = R.nn.max_pool1d(
+ input_1,
+ pool_size=[2],
+ strides=[2],
+ dilation=[1],
+ padding=[0, 0],
+ layout="NCW",
+ out_layout="NCW",
+ )
+ gv = (lv,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
+ with R.dataflow():
+ lv = R.nn.max_pool1d(
+ input_1,
+ pool_size=[3],
+ strides=[2],
+ dilation=[1],
+ padding=[0, 0],
+ layout="NCW",
+ out_layout="NCW",
+ )
+ gv = (lv,)
+ R.output(gv)
+ return gv
+
+ # Example inputs
+ example_args1 = (torch.randn(1, 3, 8, dtype=torch.float32),)
+ example_args2 = (torch.randn(1, 3, 8, dtype=torch.float32),)
+ example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),)
+
+ # Verify the models
+ verify_model(MaxPool1d(), example_args1, {}, expected1)
+ verify_model(MaxPool1d_functional(), example_args2, {}, expected2)
+ verify_model(MaxPool1d2(), example_args3, {}, expected3)
+
+
def test_maxpool2d():
class MaxPool2d(Module):
def __init__(self):
@@ -2466,6 +2561,110 @@ def test_maxpool2d():
verify_model(MaxPool2d3(), example_args, {}, expected3)
+def test_maxpool3d():
+ class MaxPool3d(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool3d(kernel_size=[1, 1, 1])
+
+ def forward(self, input):
+ return self.pool(input)
+
+ class MaxPool3d_functional(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return torch.nn.functional.max_pool3d(input, kernel_size=[1, 1, 1])
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")):
+ with R.dataflow():
+ lv = R.nn.max_pool3d(
+ input_1,
+ pool_size=[1, 1, 1],
+ strides=[1, 1, 1],
+ dilation=[1, 1, 1],
+ padding=[0, 0, 0, 0, 0, 0],
+ layout="NCDHW",
+ out_layout="NCDHW",
+ )
+ gv = (lv,)
+ R.output(gv)
+ return gv
+
+ class MaxPool3d2(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[2, 2, 2])
+
+ def forward(self, input):
+ return self.pool(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")):
+ with R.dataflow():
+ lv = R.nn.max_pool3d(
+ input_1,
+ pool_size=[2, 2, 2],
+ strides=[2, 2, 2],
+ dilation=[2, 2, 2],
+ padding=[0, 0, 0, 0, 0, 0],
+ layout="NCDHW",
+ out_layout="NCDHW",
+ )
+ gv = (lv,)
+ R.output(gv)
+ return gv
+
+ class MaxPool3d3(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], padding=1, stride=2)
+
+ def forward(self, input):
+ return self.pool(input)
+
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")):
+ with R.dataflow():
+ lv = R.nn.max_pool3d(
+ input_1,
+ pool_size=[3, 3, 3],
+ strides=[2, 2, 2],
+ dilation=[1, 1, 1],
+ padding=[1, 1, 1, 1, 1, 1],
+ layout="NCDHW",
+ out_layout="NCDHW",
+ )
+ gv = (lv,)
+ R.output(gv)
+ return gv
+
+ # Example input tensors
+ example_args1 = (torch.randn(1, 3, 4, 4, 4, dtype=torch.float32),)
+ example_args2 = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
+ example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)
+
+ # Verify the models with expected IR modules
+ verify_model(MaxPool3d(), example_args1, {}, expected1)
+ verify_model(MaxPool3d_functional(), example_args1, {}, expected1)
+ verify_model(MaxPool3d2(), example_args2, {}, expected2)
+ verify_model(MaxPool3d3(), example_args3, {}, expected3)
+
+
def test_scaled_dot_product_attention():
class Attention1(Module):
def forward(self, q, k, v):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 643372750b..56f76bd3e9 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -984,6 +984,106 @@ def test_prelu():
verify_model(Prelu2(), input_info, {}, expected)
+def test_maxpool1d():
+ input_info = [([1, 3, 10], "float32")]
+
+ class MaxPool1d(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool1d(kernel_size=2)
+
+ def forward(self, input):
+ return self.pool(input)
+
+ class MaxPool1d_functional(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return torch.nn.functional.max_pool1d(input, kernel_size=2)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 5), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.max_pool1d(
+ input_1,
+ pool_size=[2],
+ strides=[2],
+ dilation=[1],
+ padding=[0, 0],
+ layout="NCW",
+ out_layout="NCW",
+ )
+ gv: R.Tensor((1, 3, 5), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ class MaxPool1d2(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
+
+ def forward(self, input):
+ return self.pool(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.max_pool1d(
+ input_1,
+ pool_size=[3],
+ strides=[1],
+ dilation=[1],
+ padding=[1, 1],
+ layout="NCW",
+ out_layout="NCW",
+ )
+ gv: R.Tensor((1, 3, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ class MaxPool1d3(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=2, dilation=2)
+
+ def forward(self, input):
+ return self.pool(input)
+
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 3), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 3), dtype="float32") = R.nn.max_pool1d(
+ input_1,
+ pool_size=[3],
+ strides=[2],
+ dilation=[2],
+ padding=[0, 0],
+ layout="NCW",
+ out_layout="NCW",
+ )
+ gv: R.Tensor((1, 3, 3), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(MaxPool1d(), input_info, {}, expected1)
+ verify_model(MaxPool1d_functional(), input_info, {}, expected1)
+ verify_model(MaxPool1d2(), input_info, {}, expected2)
+ verify_model(MaxPool1d3(), input_info, {}, expected3)
+
+
def test_maxpool2d():
input_info = [([1, 3, 10, 10], "float32")]
@@ -1087,6 +1187,106 @@ def test_maxpool2d():
verify_model(MaxPool2d3(), input_info, {}, expected3)
+def test_maxpool3d():
+ input_info = [([1, 3, 10, 10, 10], "float32")]
+
+ class MaxPool3d(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool3d(kernel_size=[1, 1, 1])
+
+ def forward(self, input):
+ return self.pool(input)
+
+ class MaxPool3d_functional(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return torch.nn.functional.max_pool3d(input, kernel_size=[1, 1, 1])
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10, 10), dtype="float32") = R.nn.max_pool3d(
+ input_1,
+ pool_size=[1, 1, 1],
+ strides=[1, 1, 1],
+ dilation=[1, 1, 1],
+ padding=[0, 0, 0, 0, 0, 0],
+ layout="NCDHW",
+ out_layout="NCDHW",
+ )
+ gv: R.Tensor((1, 3, 10, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ class MaxPool3d2(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[1, 2, 2])
+
+ def forward(self, input):
+ return self.pool(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 5, 4, 4), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 5, 4, 4), dtype="float32") = R.nn.max_pool3d(
+ input_1,
+ pool_size=[2, 2, 2],
+ strides=[2, 2, 2],
+ dilation=[1, 2, 2],
+ padding=[0, 0, 0, 0, 0, 0],
+ layout="NCDHW",
+ out_layout="NCDHW",
+ )
+ gv: R.Tensor((1, 3, 5, 4, 4), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ class MaxPool3d3(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], padding=1, stride=2)
+
+ def forward(self, input):
+ return self.pool(input)
+
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 5, 5, 5), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.nn.max_pool3d(
+ input_1,
+ pool_size=[3, 3, 3],
+ strides=[2, 2, 2],
+ dilation=[1, 1, 1],
+ padding=[1, 1, 1, 1, 1, 1],
+ layout="NCDHW",
+ out_layout="NCDHW",
+ )
+ gv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(MaxPool3d(), input_info, {}, expected1)
+ verify_model(MaxPool3d_functional(), input_info, {}, expected1)
+ verify_model(MaxPool3d2(), input_info, {}, expected2)
+ verify_model(MaxPool3d3(), input_info, {}, expected3)
+
+
def test_avgpool2d():
input_info = [([1, 3, 10, 10], "float32")]
diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py
index 2533a2fcad..0d58af1cbe 100644
--- a/tests/python/relax/test_op_nn_pooling.py
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -25,7 +25,11 @@ from tvm.script import relax as R
def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ x1 = relax.Var("x1", R.Tensor((2, 3, 64), "float32"))
+ x2 = relax.Var("x2", R.Tensor((2, 3, 8, 28, 28), "float32"))
+ assert relax.op.nn.max_pool1d(x1).op == Op.get("relax.nn.max_pool1d")
assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d")
+ assert relax.op.nn.max_pool3d(x2).op == Op.get("relax.nn.max_pool3d")
assert relax.op.nn.avg_pool2d(x).op == Op.get("relax.nn.avg_pool2d")
assert relax.op.nn.adaptive_avg_pool2d(x).op == Op.get("relax.nn.adaptive_avg_pool2d")
@@ -35,6 +39,197 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+def test_max_pool1d_infer_struct_info():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor(ndim=3))
+ x3 = relax.Var("x", R.Tensor("float32"))
+ x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0))
+
+ _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float32"))
+ _check_inference(
+ bb, relax.op.nn.max_pool1d(x5), relax.TensorStructInfo((2, 3, 32), "float32", vdev0)
+ )
+ _check_inference(
+ bb, relax.op.nn.max_pool1d(x0, pool_size=3), relax.TensorStructInfo((2, 3, 30), "float32")
+ )
+ _check_inference(
+ bb, relax.op.nn.max_pool1d(x0, strides=2), relax.TensorStructInfo((2, 3, 16), "float32")
+ )
+ _check_inference(
+ bb, relax.op.nn.max_pool1d(x0, padding=1), relax.TensorStructInfo((2, 3, 34), "float32")
+ )
+ _check_inference(
+ bb, relax.op.nn.max_pool1d(x0, dilation=2), relax.TensorStructInfo((2, 3, 32), "float32")
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool1d(x0, layout="NCW", out_layout="NWC"),
+ relax.TensorStructInfo((2, 32, 3), "float32"),
+ )
+ _check_inference(
+ bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo(dtype="", ndim=3))
+ _check_inference(
+ bb, relax.op.nn.max_pool1d(x3), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(bb, relax.op.nn.max_pool1d(x4), relax.TensorStructInfo(dtype="", ndim=3))
+
+
+def test_max_pool1d_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ c = tir.Var("c", "int64")
+ w = tir.Var("w", "int64")
+ c16 = tir.Var("c16", "int64")
+
+ x0 = relax.Var("x", R.Tensor((n, c, w), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, c, w, c16), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool1d(x0, pool_size=3, strides=3, padding=2, dilation=2),
+ relax.TensorStructInfo(
+ (
+ n,
+ c,
+ tvm.tir.floordiv(w - 1, 3) + 1,
+ ),
+ "float32",
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool1d(x1, layout="NCW16c", out_layout="NWC"),
+ relax.TensorStructInfo((n, w, c * 16), "float32"),
+ )
+
+
+def test_max_pool1d_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3))
+ s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+ s2 = relax.Var("s", relax.ShapeStructInfo())
+
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+ x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+ _check_inference(
+ bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool1d(x1, layout="NCW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool1d(x2),
+ relax.TensorStructInfo(dtype="float32", ndim=3),
+ )
+
+
+def test_max_pool1d_infer_struct_info_ceil_mode():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 32), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool1d(x, pool_size=3, strides=2, ceil_mode=True),
+ relax.TensorStructInfo((2, 3, 16), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool1d(x, pool_size=5, strides=2, ceil_mode=True),
+ relax.TensorStructInfo((2, 3, 15), "float32"),
+ )
+
+
+def test_max_pool1d_infer_struct_info_ceil_mode_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ c = tir.Var("c", "int64")
+ w = tir.Var("w", "int64")
+ x = relax.Var("x", R.Tensor((n, c, w), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool1d(x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True),
+ relax.TensorStructInfo((n, c, tvm.tir.floordiv(w, 2)), "float32"),
+ )
+
+
+def test_max_pool1d_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 32), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 32), "int8"))
+ x2 = relax.Var("x", R.Tensor((2, 3, 32), "int64"))
+
+ _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float16"))
+ _check_inference(bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo((2, 3, 32), "int8"))
+ _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo((2, 3, 32), "int64"))
+
+
+def test_max_pool1d_stride_padding_dilation_int64():
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ max_pool1d = relax.op.nn.max_pool1d(x, pool_size=3, strides=1, padding=1, dilation=1)
+
+ assert max_pool1d.attrs.strides[0].dtype == "int64"
+ assert max_pool1d.attrs.padding[0].dtype == "int64"
+ assert max_pool1d.attrs.padding[1].dtype == "int64"
+ assert max_pool1d.attrs.dilation[0].dtype == "int64"
+
+
+def test_max_pool1d_wrong_pool_size_strides_padding_dilation_length():
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ with pytest.raises(TVMError):
+ relax.op.nn.max_pool1d(x, pool_size=(1, 2))
+ with pytest.raises(TVMError):
+ relax.op.nn.max_pool1d(x, strides=(1, 2))
+ with pytest.raises(TVMError):
+ relax.op.nn.max_pool1d(x, padding=(1, 2, 3))
+ with pytest.raises(TVMError):
+ relax.op.nn.max_pool1d(x, dilation=(1, 2))
+
+
+def test_max_pool1d_infer_struct_info_wrong_layout_string():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool1d(x, layout="OIW"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool1d(x, out_layout="OWI"))
+
+
+def test_max_pool1d_wrong_input_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=5))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool1d(x0))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool1d(x1))
+
+
+def test_max_pool1d_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28)))
+ x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28),
"float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool1d(x0))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool1d(x1))
+
+
def test_max_pool2d_infer_struct_info():
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")
@@ -265,6 +460,257 @@ def test_max_pool2d_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.nn.max_pool2d(x1))
+def test_max_pool3d_infer_struct_info():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ x0 = relax.Var("x", R.Tensor((2, 3, 16, 32, 32), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 16, 32, 32, 3), "float32"))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=5))
+ x3 = relax.Var("x", R.Tensor("float32"))
+ x4 = relax.Var("x", R.Tensor(ndim=5))
+ x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 4, 16, 32, 32, 16), "float32"))
+ x7 = relax.Var("x", R.Tensor((2, 3, 16, 32, 32), "float32", vdev0))
+
+ _check_inference(
+ bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo((2, 3, 16, 32, 32), "float32")
+ )
+ _check_inference(
+ bb, relax.op.nn.max_pool3d(x7), relax.TensorStructInfo((2, 3, 16, 32, 32), "float32", vdev0)
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x0, pool_size=3),
+ relax.TensorStructInfo((2, 3, 14, 30, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x0, pool_size=(3, 5, 3)),
+ relax.TensorStructInfo((2, 3, 14, 28, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x0, padding=1),
+ relax.TensorStructInfo((2, 3, 18, 34, 34), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x0, padding=[1, 2, 3]),
+ relax.TensorStructInfo((2, 3, 18, 36, 38), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x0, strides=2),
+ relax.TensorStructInfo((2, 3, 8, 16, 16), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x0, dilation=2),
+ relax.TensorStructInfo((2, 3, 16, 32, 32), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x1, layout="NDHWC"),
+ relax.TensorStructInfo((2, 16, 32, 32, 3), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x0, out_layout="NDHWC"),
+ relax.TensorStructInfo((2, 16, 32, 32, 3), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"),
+ relax.TensorStructInfo((2, 16, 32, 32, 4, 16), "float32"),
+ )
+ _check_inference(
+ bb, relax.op.nn.max_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5)
+ )
+ _check_inference(
+ bb, relax.op.nn.max_pool3d(x3), relax.TensorStructInfo(dtype="float32", ndim=5)
+ )
+ _check_inference(bb, relax.op.nn.max_pool3d(x4), relax.TensorStructInfo(dtype="", ndim=5))
+ _check_inference(bb, relax.op.nn.max_pool3d(x5), relax.TensorStructInfo(dtype="", ndim=5))
+
+
+def test_max_pool3d_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ c = tir.Var("c", "int64")
+ c16 = tir.Var("c16", "int64")
+ id_ = tir.Var("id", "int64")
+ ih = tir.Var("ih", "int64")
+ iw = tir.Var("iw", "int64")
+ x0 = relax.Var("x", R.Tensor((n, c, id, ih, iw), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, c, id, ih, iw, c16), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(
+ x0, pool_size=(3, 3, 3), strides=(3, 3, 3), padding=(2, 2, 2), dilation=(2, 2, 2)
+ ),
+ relax.TensorStructInfo(
+ (
+ n,
+ c,
+ tvm.tir.floordiv(id_ - 1, 3) + 1,
+ tvm.tir.floordiv(ih - 1, 3) + 1,
+ tvm.tir.floordiv(iw - 1, 3) + 1,
+ ),
+ "float32",
+ ),
+ )
+
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"),
+ relax.TensorStructInfo((n, id_, ih, iw, c * 16), "float32"),
+ )
+
+
+def test_max_pool3d_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+ s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6))
+ s2 = relax.Var("s", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+ x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+ _check_inference(
+ bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo(dtype="float32", ndim=5)
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x1, layout="NCDHW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=6),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x2),
+ relax.TensorStructInfo(dtype="float32", ndim=5),
+ )
+
+
+def test_max_pool3d_infer_struct_info_ceil_mode():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x, pool_size=3, strides=2, ceil_mode=True),
+ relax.TensorStructInfo((2, 3, 16, 16, 16), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(x, pool_size=(5, 3, 3), strides=2, ceil_mode=True),
+ relax.TensorStructInfo((2, 3, 15, 16, 16), "float32"),
+ )
+
+
+def test_max_pool3d_infer_struct_info_ceil_mode_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ c = tir.Var("c", "int64")
+ id_ = tir.Var("id", "int64")
+ ih = tir.Var("ih", "int64")
+ iw = tir.Var("iw", "int64")
+ x = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.max_pool3d(
+ x,
+ pool_size=(3, 3, 3),
+ strides=(2, 2, 2),
+ padding=(1, 1, 1),
+ dilation=(2, 2, 2),
+ ceil_mode=True,
+ ),
+ relax.TensorStructInfo(
+ (n, c, tvm.tir.floordiv(id_, 2), tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)),
+ "float32",
+ ),
+ )
+
+
+def test_max_pool3d_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int8"))
+ x2 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int64"))
+ _check_inference(
+ bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float16")
+ )
+ _check_inference(
+ bb, relax.op.nn.max_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8")
+ )
+ _check_inference(
+ bb, relax.op.nn.max_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64")
+ )
+
+
+def test_max_pool3d_stride_padding_dilation_int64():
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+ max_pool3d = relax.op.nn.max_pool3d(
+ x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)
+ )
+
+ assert max_pool3d.attrs.strides[0].dtype == "int64"
+ assert max_pool3d.attrs.strides[1].dtype == "int64"
+ assert max_pool3d.attrs.strides[2].dtype == "int64"
+ assert max_pool3d.attrs.padding[0].dtype == "int64"
+ assert max_pool3d.attrs.padding[1].dtype == "int64"
+ assert max_pool3d.attrs.padding[2].dtype == "int64"
+ assert max_pool3d.attrs.padding[3].dtype == "int64"
+ assert max_pool3d.attrs.padding[4].dtype == "int64"
+ assert max_pool3d.attrs.dilation[0].dtype == "int64"
+ assert max_pool3d.attrs.dilation[1].dtype == "int64"
+ assert max_pool3d.attrs.dilation[2].dtype == "int64"
+
+
+def test_max_pool3d_wrong_pool_size_strides_padding_dilation_length():
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+ with pytest.raises(TVMError):
+ relax.op.nn.max_pool3d(x, pool_size=(1, 2, 3, 4))
+ with pytest.raises(TVMError):
+ relax.op.nn.max_pool3d(x, strides=(1, 2, 3, 4))
+ with pytest.raises(TVMError):
+ relax.op.nn.max_pool3d(x, padding=(1, 2, 3, 4))
+ with pytest.raises(TVMError):
+ relax.op.nn.max_pool3d(x, dilation=(1, 2, 3, 4))
+
+
+def test_max_pool3d_infer_struct_info_wrong_layout_string():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool3d(x, layout="OIHW"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool3d(x, out_layout="OHWI"))
+
+
+def test_max_pool3d_wrong_input_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28, 3), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool3d(x0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool3d(x1))
+
+
+def test_max_pool3d_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28)))
+ x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28),
"float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool3d(x0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.max_pool3d(x1))
+
+
def test_avg_pool2d_infer_struct_info():
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")