This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 64cea4a572 [Relax][PyTorch] Add MaxPool 1D and 3D Op Support for Exported Program and FX graph (#17919)
64cea4a572 is described below

commit 64cea4a5727eb4624bef48c67a64f37139a210a7
Author: Deivanayaki S <[email protected]>
AuthorDate: Thu May 8 18:10:09 2025 +0530

    [Relax][PyTorch] Add MaxPool 1D and 3D Op Support for Exported Program and FX graph (#17919)
    
    * add max_pool1d and max_pool3d op support and refactor the max_pool2d op
    
    * update the layout used in max pool 1d op
    
    * add mappings to the FX translator and fix lint issues
    
    * fix missing and incorrect mappings and add module functions
    
    * update output tensor struct info for maxpool1d
    
    * add comments documenting how edge cases (unbatched inputs) are handled
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
 .../frontend/torch/base_fx_graph_translator.py     | 100 ++++-
 .../frontend/torch/exported_program_translator.py  |   2 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  26 ++
 src/relax/op/nn/pooling.cc                         |   2 +-
 .../relax/test_frontend_from_exported_program.py   | 199 +++++++++
 tests/python/relax/test_frontend_from_fx.py        | 200 +++++++++
 tests/python/relax/test_op_nn_pooling.py           | 446 +++++++++++++++++++++
 7 files changed, 973 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f8634f5da7..f683e62d24 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -923,6 +923,50 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         bias = args[2] if len(args) > 2 else None
         return self.block_builder.emit(relax.op.linear(x, weight, bias, 
"float32"))
 
+    def _max_pool1d_impl(
+        self,
+        x: relax.Expr,
+        kernel_size: Union[int, Tuple[int]] = 1,
+        stride: Optional[Union[int, Tuple[int]]] = None,
+        padding: Optional[int] = 0,
+        dilation: Optional[int] = 1,
+        ceil_mode: Optional[bool] = False,
+    ) -> relax.Var:
+        # Expand to 3D by adding batch dim if input is 2D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 2:
+            x = relax.op.expand_dims(x, axis=0)
+
+        stride = kernel_size if stride is None else stride
+
+        result = self.block_builder.emit(
+            relax.op.nn.max_pool1d(
+                x,
+                pool_size=kernel_size,
+                strides=stride,
+                padding=padding,
+                dilation=dilation,
+                ceil_mode=ceil_mode,
+                layout="NCW",
+            )
+        )
+
+        # Remove added batch dim from result
+        if x_ndim == 2:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
+
+    def _max_pool1d(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        kernel_size = args[1]
+        stride = args[2] if len(args) > 2 else None
+        padding = args[3] if len(args) > 3 else 0
+        dilation = args[4] if len(args) > 4 else 1
+        ceil_mode = args[5] if len(args) > 5 else False
+
+        return self._max_pool1d_impl(x, kernel_size, stride, padding, 
dilation, ceil_mode)
+
     def _max_pool2d_impl(
         self,
         x: relax.Expr,
@@ -932,8 +976,14 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dilation: Optional[int] = 1,
         ceil_mode: Optional[bool] = False,
     ) -> relax.Var:
+        # Expand to 4D by adding batch dim if input is 3D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 3:
+            x = relax.op.expand_dims(x, axis=0)
+
         stride = kernel_size if stride is None else stride
-        return self.block_builder.emit(
+
+        result = self.block_builder.emit(
             relax.op.nn.max_pool2d(
                 x,
                 pool_size=kernel_size,
@@ -945,6 +995,11 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             )
         )
 
+        # Remove added batch dim from result
+        if x_ndim == 3:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
+
     def _max_pool2d(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
@@ -956,6 +1011,49 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return self._max_pool2d_impl(x, kernel_size, stride, padding, 
dilation, ceil_mode)
 
+    def _max_pool3d_impl(
+        self,
+        x: relax.Expr,
+        kernel_size: Union[int, Tuple[int, int, int]] = (1, 1, 1),
+        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
+        padding: Optional[int] = 0,
+        dilation: Optional[int] = 1,
+        ceil_mode: Optional[bool] = False,
+    ) -> relax.Var:
+        # Expand to 5D by adding batch dim if input is 4D
+        x_ndim = x.struct_info.ndim
+        if x_ndim == 4:
+            x = relax.op.expand_dims(x, axis=0)
+
+        stride = kernel_size if stride is None else stride
+
+        result = self.block_builder.emit(
+            relax.op.nn.max_pool3d(
+                x,
+                pool_size=kernel_size,
+                strides=stride,
+                padding=padding,
+                dilation=dilation,
+                ceil_mode=ceil_mode,
+                layout="NCDHW",
+            )
+        )
+
+        # Remove added batch dim from result
+        if x_ndim == 4:
+            result = relax.op.squeeze(result, axis=[0])
+        return result
+
+    def _max_pool3d(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        kernel_size = args[1]
+        stride = args[2] if len(args) > 2 else None
+        padding = args[3] if len(args) > 3 else 0
+        dilation = args[4] if len(args) > 4 else 1
+        ceil_mode = args[5] if len(args) > 5 else False
+        return self._max_pool3d_impl(x, kernel_size, stride, padding, 
dilation, ceil_mode)
+
     def _pad(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         pad = node.args[1]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 42a57273af..0600dfa552 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -408,7 +408,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "group_norm.default": self._group_norm,
             "layer_norm.default": self._layer_norm,
             "linear.default": self._linear,
+            "max_pool1d.default": self._max_pool1d,
             "max_pool2d.default": self._max_pool2d,
+            "max_pool3d.default": self._max_pool3d,
             "scaled_dot_product_attention.default": 
self._scaled_dot_product_attention,
             "unbind.int": self._unbind,
             "upsample_bilinear2d.vec": self._upsample_bilinear2d,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 199e58cb1d..0a94c679f5 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -449,6 +449,17 @@ class TorchFXImporter(BaseFXGraphImporter):
         bias = self.params.get(module.bias, None)
         return self.block_builder.emit(relax.op.linear(x, weight, bias, 
"float32"))
 
+    def _max_pool1d_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        kernel_size = module.kernel_size
+        stride = module.stride
+        padding = module.padding
+        dilation = module.dilation
+        ceil_mode = module.ceil_mode
+
+        return self._max_pool1d_impl(x, kernel_size, stride, padding, 
dilation, ceil_mode)
+
     def _max_pool2d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -460,6 +471,17 @@ class TorchFXImporter(BaseFXGraphImporter):
 
         return self._max_pool2d_impl(x, kernel_size, stride, padding, 
dilation, ceil_mode)
 
+    def _max_pool3d_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        kernel_size = module.kernel_size
+        stride = module.stride
+        padding = module.padding
+        dilation = module.dilation
+        ceil_mode = module.ceil_mode
+
+        return self._max_pool3d_impl(x, kernel_size, stride, padding, 
dilation, ceil_mode)
+
     def _pixel_shuffle_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -661,7 +683,9 @@ class TorchFXImporter(BaseFXGraphImporter):
             nn.GroupNorm: self._group_norm_module,
             nn.LayerNorm: self._layer_norm_module,
             nn.Linear: self._linear_module,
+            nn.MaxPool1d: self._max_pool1d_module,
             nn.MaxPool2d: self._max_pool2d_module,
+            nn.MaxPool3d: self._max_pool3d_module,
             nn.modules.sparse.Embedding: self._embedding_module,
             nn.PixelShuffle: self._pixel_shuffle_module,
             # tensor manipulation
@@ -774,7 +798,9 @@ class TorchFXImporter(BaseFXGraphImporter):
             "interpolate": self._interpolate,
             "layer_norm": self._layer_norm,
             "linear": self._linear,
+            "max_pool1d": self._max_pool1d,
             "max_pool2d": self._max_pool2d,
+            "max_pool3d": self._max_pool3d,
             "scaled_dot_product_attention": self._scaled_dot_product_attention,
             "stochastic_depth": lambda node: self.env[node.args[0]],
             "unbind": self._unbind,
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 565e6a00c6..391edda9ef 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -95,7 +95,7 @@ StructInfo InferStructInfoPool1D(const Call& call, const 
BlockBuilder& ctx) {
 
   PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w 
- 1) - 1;
   if (attrs->ceil_mode) {
-    numerator_w += attrs->strides[1] - 1;
+    numerator_w += attrs->strides[0] - 1;
   }
   out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, 
attrs->strides[0]) + 1);
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index c375992dca..bcd96369d4 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2364,6 +2364,101 @@ def test_linear():
     verify_model(model, example_args, binding, expected2)
 
 
+def test_maxpool1d():
+    class MaxPool1d(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool1d(kernel_size=2)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class MaxPool1d_functional(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, input):
+            return torch.nn.functional.max_pool1d(input, kernel_size=2)
+
+    class MaxPool1d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=2)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 8), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.max_pool1d(
+                    input_1,
+                    pool_size=[2],
+                    strides=[2],
+                    dilation=[1],
+                    padding=[0, 0],
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 8), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.max_pool1d(
+                    input_1,
+                    pool_size=[2],
+                    strides=[2],
+                    dilation=[1],
+                    padding=[0, 0],
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.max_pool1d(
+                    input_1,
+                    pool_size=[3],
+                    strides=[2],
+                    dilation=[1],
+                    padding=[0, 0],
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    # Example inputs
+    example_args1 = (torch.randn(1, 3, 8, dtype=torch.float32),)
+    example_args2 = (torch.randn(1, 3, 8, dtype=torch.float32),)
+    example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),)
+
+    # Verify the models
+    verify_model(MaxPool1d(), example_args1, {}, expected1)
+    verify_model(MaxPool1d_functional(), example_args2, {}, expected2)
+    verify_model(MaxPool1d2(), example_args3, {}, expected3)
+
+
 def test_maxpool2d():
     class MaxPool2d(Module):
         def __init__(self):
@@ -2466,6 +2561,110 @@ def test_maxpool2d():
     verify_model(MaxPool2d3(), example_args, {}, expected3)
 
 
+def test_maxpool3d():
+    class MaxPool3d(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool3d(kernel_size=[1, 1, 1])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class MaxPool3d_functional(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, input):
+            return torch.nn.functional.max_pool3d(input, kernel_size=[1, 1, 1])
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.max_pool3d(
+                    input_1,
+                    pool_size=[1, 1, 1],
+                    strides=[1, 1, 1],
+                    dilation=[1, 1, 1],
+                    padding=[0, 0, 0, 0, 0, 0],
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    class MaxPool3d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[2, 
2, 2])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.max_pool3d(
+                    input_1,
+                    pool_size=[2, 2, 2],
+                    strides=[2, 2, 2],
+                    dilation=[2, 2, 2],
+                    padding=[0, 0, 0, 0, 0, 0],
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    class MaxPool3d3(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], padding=1, 
stride=2)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.max_pool3d(
+                    input_1,
+                    pool_size=[3, 3, 3],
+                    strides=[2, 2, 2],
+                    dilation=[1, 1, 1],
+                    padding=[1, 1, 1, 1, 1, 1],
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    # Example input tensors
+    example_args1 = (torch.randn(1, 3, 4, 4, 4, dtype=torch.float32),)
+    example_args2 = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
+    example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)
+
+    # Verify the models with expected IR modules
+    verify_model(MaxPool3d(), example_args1, {}, expected1)
+    verify_model(MaxPool3d_functional(), example_args1, {}, expected1)
+    verify_model(MaxPool3d2(), example_args2, {}, expected2)
+    verify_model(MaxPool3d3(), example_args3, {}, expected3)
+
+
 def test_scaled_dot_product_attention():
     class Attention1(Module):
         def forward(self, q, k, v):
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 643372750b..56f76bd3e9 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -984,6 +984,106 @@ def test_prelu():
     verify_model(Prelu2(), input_info, {}, expected)
 
 
+def test_maxpool1d():
+    input_info = [([1, 3, 10], "float32")]
+
+    class MaxPool1d(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool1d(kernel_size=2)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class MaxPool1d_functional(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, input):
+            return torch.nn.functional.max_pool1d(input, kernel_size=2)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 5), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.max_pool1d(
+                    input_1,
+                    pool_size=[2],
+                    strides=[2],
+                    dilation=[1],
+                    padding=[0, 0],
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv: R.Tensor((1, 3, 5), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class MaxPool1d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.max_pool1d(
+                    input_1,
+                    pool_size=[3],
+                    strides=[1],
+                    dilation=[1],
+                    padding=[1, 1],
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv: R.Tensor((1, 3, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class MaxPool1d3(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=2, dilation=2)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 3), dtype="float32"):  # Corrected here
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 3), dtype="float32") = R.nn.max_pool1d(
+                    input_1,
+                    pool_size=[3],
+                    strides=[2],
+                    dilation=[2],
+                    padding=[0, 0],
+                    layout="NCW",
+                    out_layout="NCW",
+                )
+                gv: R.Tensor((1, 3, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(MaxPool1d(), input_info, {}, expected1)
+    verify_model(MaxPool1d_functional(), input_info, {}, expected1)
+    verify_model(MaxPool1d2(), input_info, {}, expected2)
+    verify_model(MaxPool1d3(), input_info, {}, expected3)
+
+
 def test_maxpool2d():
     input_info = [([1, 3, 10, 10], "float32")]
 
@@ -1087,6 +1187,106 @@ def test_maxpool2d():
     verify_model(MaxPool2d3(), input_info, {}, expected3)
 
 
+def test_maxpool3d():
+    input_info = [([1, 3, 10, 10, 10], "float32")]
+
+    class MaxPool3d(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool3d(kernel_size=[1, 1, 1])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class MaxPool3d_functional(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, input):
+            return torch.nn.functional.max_pool3d(input, kernel_size=[1, 1, 1])
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10, 10), dtype="float32") = 
R.nn.max_pool3d(
+                    input_1,
+                    pool_size=[1, 1, 1],
+                    strides=[1, 1, 1],
+                    dilation=[1, 1, 1],
+                    padding=[0, 0, 0, 0, 0, 0],
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv: R.Tensor((1, 3, 10, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class MaxPool3d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[1, 
2, 2])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 5, 4, 4), dtype="float32"):  # Fixed here
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 4, 4), dtype="float32") = 
R.nn.max_pool3d(
+                    input_1,
+                    pool_size=[2, 2, 2],
+                    strides=[2, 2, 2],
+                    dilation=[1, 2, 2],
+                    padding=[0, 0, 0, 0, 0, 0],
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv: R.Tensor((1, 3, 5, 4, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class MaxPool3d3(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], padding=1, 
stride=2)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 5, 5, 5), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = 
R.nn.max_pool3d(
+                    input_1,
+                    pool_size=[3, 3, 3],
+                    strides=[2, 2, 2],
+                    dilation=[1, 1, 1],
+                    padding=[1, 1, 1, 1, 1, 1],
+                    layout="NCDHW",
+                    out_layout="NCDHW",
+                )
+                gv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(MaxPool3d(), input_info, {}, expected1)
+    verify_model(MaxPool3d_functional(), input_info, {}, expected1)
+    verify_model(MaxPool3d2(), input_info, {}, expected2)
+    verify_model(MaxPool3d3(), input_info, {}, expected3)
+
+
 def test_avgpool2d():
     input_info = [([1, 3, 10, 10], "float32")]
 
diff --git a/tests/python/relax/test_op_nn_pooling.py 
b/tests/python/relax/test_op_nn_pooling.py
index 2533a2fcad..0d58af1cbe 100644
--- a/tests/python/relax/test_op_nn_pooling.py
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -25,7 +25,11 @@ from tvm.script import relax as R
 
 def test_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x1", R.Tensor((2, 3, 64), "float32"))
+    x2 = relax.Var("x2", R.Tensor((2, 3, 8, 28, 28), "float32"))
+    assert relax.op.nn.max_pool1d(x1).op == Op.get("relax.nn.max_pool1d")
     assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d")
+    assert relax.op.nn.max_pool3d(x2).op == Op.get("relax.nn.max_pool3d")
     assert relax.op.nn.avg_pool2d(x).op == Op.get("relax.nn.avg_pool2d")
     assert relax.op.nn.adaptive_avg_pool2d(x).op == 
Op.get("relax.nn.adaptive_avg_pool2d")
 
@@ -35,6 +39,197 @@ def _check_inference(bb: relax.BlockBuilder, call: 
relax.Call, expected_sinfo: r
     tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
 
 
+def test_max_pool1d_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor(ndim=3))
+    x3 = relax.Var("x", R.Tensor("float32"))
+    x4 = relax.Var("x", R.Tensor())
+    x5 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0))
+
+    _check_inference(bb, relax.op.nn.max_pool1d(x0), 
relax.TensorStructInfo((2, 3, 32), "float32"))
+    _check_inference(
+        bb, relax.op.nn.max_pool1d(x5), relax.TensorStructInfo((2, 3, 32), 
"float32", vdev0)
+    )
+    _check_inference(
+        bb, relax.op.nn.max_pool1d(x0, pool_size=3), 
relax.TensorStructInfo((2, 3, 30), "float32")
+    )
+    _check_inference(
+        bb, relax.op.nn.max_pool1d(x0, strides=2), relax.TensorStructInfo((2, 
3, 16), "float32")
+    )
+    _check_inference(
+        bb, relax.op.nn.max_pool1d(x0, padding=1), relax.TensorStructInfo((2, 
3, 34), "float32")
+    )
+    _check_inference(
+        bb, relax.op.nn.max_pool1d(x0, dilation=2), relax.TensorStructInfo((2, 
3, 32), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool1d(x0, layout="NCW", out_layout="NWC"),
+        relax.TensorStructInfo((2, 32, 3), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.max_pool1d(x1), 
relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(bb, relax.op.nn.max_pool1d(x2), 
relax.TensorStructInfo(dtype="", ndim=3))
+    _check_inference(
+        bb, relax.op.nn.max_pool1d(x3), 
relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(bb, relax.op.nn.max_pool1d(x4), 
relax.TensorStructInfo(dtype="", ndim=3))
+
+
+def test_max_pool1d_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    w = tir.Var("w", "int64")
+    c16 = tir.Var("c16", "int64")
+
+    x0 = relax.Var("x", R.Tensor((n, c, w), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c, w, c16), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool1d(x0, pool_size=3, strides=3, padding=2, 
dilation=2),
+        relax.TensorStructInfo(
+            (
+                n,
+                c,
+                tvm.tir.floordiv(w - 1, 3) + 1,
+            ),
+            "float32",
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool1d(x1, layout="NCW16c", out_layout="NWC"),
+        relax.TensorStructInfo((n, w, c * 16), "float32"),
+    )
+
+
+def test_max_pool1d_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.max_pool1d(x0), 
relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool1d(x1, layout="NCW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool1d(x2),
+        relax.TensorStructInfo(dtype="float32", ndim=3),
+    )
+
+
+def test_max_pool1d_infer_struct_info_ceil_mode():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 32), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool1d(x, pool_size=3, strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool1d(x, pool_size=5, strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 15), "float32"),
+    )
+
+
+def test_max_pool1d_infer_struct_info_ceil_mode_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    w = tir.Var("w", "int64")
+    x = relax.Var("x", R.Tensor((n, c, w), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool1d(x, pool_size=3, strides=2, padding=1, 
dilation=2, ceil_mode=True),
+        relax.TensorStructInfo((n, c, tvm.tir.floordiv(w, 2)), "float32"),
+    )
+
+
+def test_max_pool1d_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 32), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 32), "int64"))
+
+    _check_inference(bb, relax.op.nn.max_pool1d(x0), 
relax.TensorStructInfo((2, 3, 32), "float16"))
+    _check_inference(bb, relax.op.nn.max_pool1d(x1), 
relax.TensorStructInfo((2, 3, 32), "int8"))
+    _check_inference(bb, relax.op.nn.max_pool1d(x2), 
relax.TensorStructInfo((2, 3, 32), "int64"))
+
+
+def test_max_pool1d_stride_padding_dilation_int64():
+    x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+    max_pool1d = relax.op.nn.max_pool1d(x, pool_size=3, strides=1, padding=1, 
dilation=1)
+
+    assert max_pool1d.attrs.strides[0].dtype == "int64"
+    assert max_pool1d.attrs.padding[0].dtype == "int64"
+    assert max_pool1d.attrs.padding[1].dtype == "int64"
+    assert max_pool1d.attrs.dilation[0].dtype == "int64"
+
+
+def test_max_pool1d_wrong_pool_size_strides_padding_dilation_length():
+    x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+    with pytest.raises(TVMError):
+        relax.op.nn.max_pool1d(x, pool_size=(1, 2))
+    with pytest.raises(TVMError):
+        relax.op.nn.max_pool1d(x, strides=(1, 2))
+    with pytest.raises(TVMError):
+        relax.op.nn.max_pool1d(x, padding=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.max_pool1d(x, dilation=(1, 2))
+
+
+def test_max_pool1d_infer_struct_info_wrong_layout_string():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool1d(x, layout="OIW"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool1d(x, out_layout="OWI"))
+
+
+def test_max_pool1d_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=5))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool1d(x0))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool1d(x1))
+
+
+def test_max_pool1d_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), 
"float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool1d(x0))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool1d(x1))
+
+
 def test_max_pool2d_infer_struct_info():
     bb = relax.BlockBuilder()
     vdev0 = VDevice("llvm")
@@ -265,6 +460,257 @@ def test_max_pool2d_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.nn.max_pool2d(x1))
 
 
+def test_max_pool3d_infer_struct_info():
+    """Check struct info inference of max_pool3d over many attribute combinations.
+
+    Covers default attributes, a virtual device, pool_size / padding / strides /
+    dilation overrides, NDHWC and packed NCDHW16c layouts, and inputs whose shape
+    and/or dtype are unknown.
+    """
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    ncdhw = relax.Var("x", R.Tensor((2, 3, 16, 32, 32), "float32"))
+    ndhwc = relax.Var("x", R.Tensor((2, 16, 32, 32, 3), "float32"))
+    rank_only = relax.Var("x", R.Tensor("float32", ndim=5))
+    dtype_only = relax.Var("x", R.Tensor("float32"))
+    rank_no_dtype = relax.Var("x", R.Tensor(ndim=5))
+    no_info = relax.Var("x", R.Tensor())
+    packed = relax.Var("x", R.Tensor((2, 4, 16, 32, 32, 16), "float32"))
+    on_vdevice = relax.Var("x", R.Tensor((2, 3, 16, 32, 32), "float32", vdev0))
+
+    def f32(shape):
+        return relax.TensorStructInfo(shape, "float32")
+
+    cases = [
+        (relax.op.nn.max_pool3d(ncdhw), f32((2, 3, 16, 32, 32))),
+        (
+            relax.op.nn.max_pool3d(on_vdevice),
+            relax.TensorStructInfo((2, 3, 16, 32, 32), "float32", vdev0),
+        ),
+        (relax.op.nn.max_pool3d(ncdhw, pool_size=3), f32((2, 3, 14, 30, 30))),
+        (relax.op.nn.max_pool3d(ncdhw, pool_size=(3, 5, 3)), f32((2, 3, 14, 28, 30))),
+        (relax.op.nn.max_pool3d(ncdhw, padding=1), f32((2, 3, 18, 34, 34))),
+        # A 3-element padding pads each of D, H and W on both sides.
+        (relax.op.nn.max_pool3d(ncdhw, padding=[1, 2, 3]), f32((2, 3, 18, 36, 38))),
+        (relax.op.nn.max_pool3d(ncdhw, strides=2), f32((2, 3, 8, 16, 16))),
+        (relax.op.nn.max_pool3d(ncdhw, dilation=2), f32((2, 3, 16, 32, 32))),
+        (relax.op.nn.max_pool3d(ndhwc, layout="NDHWC"), f32((2, 16, 32, 32, 3))),
+        (relax.op.nn.max_pool3d(ncdhw, out_layout="NDHWC"), f32((2, 16, 32, 32, 3))),
+        (
+            relax.op.nn.max_pool3d(packed, layout="NCDHW16c", out_layout="NDHWC16c"),
+            f32((2, 16, 32, 32, 4, 16)),
+        ),
+        # Without a static shape only the dtype (if any) and the rank are inferred.
+        (relax.op.nn.max_pool3d(rank_only), relax.TensorStructInfo(dtype="float32", ndim=5)),
+        (relax.op.nn.max_pool3d(dtype_only), relax.TensorStructInfo(dtype="float32", ndim=5)),
+        (relax.op.nn.max_pool3d(rank_no_dtype), relax.TensorStructInfo(dtype="", ndim=5)),
+        (relax.op.nn.max_pool3d(no_info), relax.TensorStructInfo(dtype="", ndim=5)),
+    ]
+    for call, expected in cases:
+        _check_inference(bb, call, expected)
+
+
+def test_max_pool3d_infer_struct_info_shape_symbolic():
+    """Check max_pool3d shape inference when the input shape is symbolic.
+
+    With pool_size=3, dilation=2 the window spans (3 - 1) * 2 + 1 = 5 elements,
+    so with padding 2 on both sides and stride 3 each spatial output extent is
+    floor((in + 4 - 5) / 3) + 1 == floordiv(in - 1, 3) + 1.
+    """
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    c16 = tir.Var("c16", "int64")
+    # Trailing underscore avoids shadowing the builtin ``id``; the TIR name stays "id".
+    id_ = tir.Var("id", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    x0 = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c, id_, ih, iw, c16), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool3d(
+            x0, pool_size=(3, 3, 3), strides=(3, 3, 3), padding=(2, 2, 2), dilation=(2, 2, 2)
+        ),
+        relax.TensorStructInfo(
+            (
+                n,
+                c,
+                tvm.tir.floordiv(id_ - 1, 3) + 1,
+                tvm.tir.floordiv(ih - 1, 3) + 1,
+                tvm.tir.floordiv(iw - 1, 3) + 1,
+            ),
+            "float32",
+        ),
+    )
+
+    # Converting NCDHW16c to NDHWC folds the 16-wide inner channel block back into C.
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"),
+        relax.TensorStructInfo((n, id_, ih, iw, c * 16), "float32"),
+    )
+
+
+def test_max_pool3d_infer_struct_info_shape_var():
+    """With an opaque shape variable as input, only dtype and rank are inferred.
+
+    A shape of unknown rank yields a rank-5 result when no layout is given.
+    """
+    bb = relax.BlockBuilder()
+    shape5 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    shape6 = relax.Var("s", relax.ShapeStructInfo(ndim=6))
+    shape_any = relax.Var("s", relax.ShapeStructInfo())
+    x5 = relax.Var("x", relax.TensorStructInfo(shape5, "float32"))
+    x6 = relax.Var("x", relax.TensorStructInfo(shape6, "float32"))
+    x_any = relax.Var("x", relax.TensorStructInfo(shape_any, "float32"))
+
+    expectations = (
+        (relax.op.nn.max_pool3d(x5), 5),
+        (relax.op.nn.max_pool3d(x6, layout="NCDHW16c"), 6),
+        (relax.op.nn.max_pool3d(x_any), 5),
+    )
+    for call, rank in expectations:
+        _check_inference(bb, call, relax.TensorStructInfo(dtype="float32", ndim=rank))
+
+
+def test_max_pool3d_infer_struct_info_ceil_mode():
+    """ceil_mode=True rounds the number of window positions up.
+
+    For an extent of 32 with stride 2: window 3 gives ceil(29 / 2) + 1 = 16 and
+    window 5 gives ceil(27 / 2) + 1 = 15.
+    """
+    bb = relax.BlockBuilder()
+    data = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32"))
+
+    for pool_size, expected_shape in (
+        (3, (2, 3, 16, 16, 16)),
+        ((5, 3, 3), (2, 3, 15, 16, 16)),
+    ):
+        _check_inference(
+            bb,
+            relax.op.nn.max_pool3d(data, pool_size=pool_size, strides=2, ceil_mode=True),
+            relax.TensorStructInfo(expected_shape, "float32"),
+        )
+
+
+def test_max_pool3d_infer_struct_info_ceil_mode_symbolic():
+    """Symbolic ceil_mode output extents simplify to floordiv(in, 2) here."""
+    bb = relax.BlockBuilder()
+    n, c = tir.Var("n", "int64"), tir.Var("c", "int64")
+    depth, height, width = (tir.Var(name, "int64") for name in ("id", "ih", "iw"))
+    data = relax.Var("x", R.Tensor((n, c, depth, height, width), "float32"))
+
+    call = relax.op.nn.max_pool3d(
+        data,
+        pool_size=(3, 3, 3),
+        strides=(2, 2, 2),
+        padding=(1, 1, 1),
+        dilation=(2, 2, 2),
+        ceil_mode=True,
+    )
+    expected_shape = (n, c, *(tvm.tir.floordiv(dim, 2) for dim in (depth, height, width)))
+    _check_inference(bb, call, relax.TensorStructInfo(expected_shape, "float32"))
+
+
+def test_max_pool3d_infer_struct_info_more_input_dtype():
+    """max_pool3d preserves float16, int8 and int64 input dtypes."""
+    bb = relax.BlockBuilder()
+    shape = (2, 3, 32, 32, 32)
+    for dtype in ("float16", "int8", "int64"):
+        data = relax.Var("x", R.Tensor(shape, dtype))
+        _check_inference(bb, relax.op.nn.max_pool3d(data), relax.TensorStructInfo(shape, dtype))
+
+
+def test_max_pool3d_stride_padding_dilation_int64():
+    """Every normalized stride / padding / dilation entry of max_pool3d is int64.
+
+    A 3-element padding is expanded to one entry per side of each spatial axis,
+    so all six padding entries are checked rather than a fixed prefix.
+    """
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    max_pool3d = relax.op.nn.max_pool3d(
+        x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)
+    )
+
+    for name, values, expected_len in (
+        ("strides", max_pool3d.attrs.strides, 3),
+        ("padding", max_pool3d.attrs.padding, 6),
+        ("dilation", max_pool3d.attrs.dilation, 3),
+    ):
+        assert len(values) == expected_len, name
+        # Iterate every entry so the last padding element is covered too.
+        for value in values:
+            assert value.dtype == "int64", name
+
+
+def test_max_pool3d_wrong_pool_size_strides_padding_dilation_length():
+    """A four-element tuple is rejected for every 3-D pooling attribute at build time."""
+    data = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    too_long = (1, 2, 3, 4)
+    for attr in ("pool_size", "strides", "padding", "dilation"):
+        with pytest.raises(TVMError):
+            relax.op.nn.max_pool3d(data, **{attr: too_long})
+
+
+def test_max_pool3d_infer_struct_info_wrong_layout_string():
+    """Invalid input ("OIHW") or output ("OHWI") layouts fail during normalization."""
+    bb = relax.BlockBuilder()
+    data = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    for layout_kwargs in ({"layout": "OIHW"}, {"out_layout": "OHWI"}):
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.nn.max_pool3d(data, **layout_kwargs))
+
+
+def test_max_pool3d_wrong_input_ndim():
+    """A static rank-6 input and an unknown-shape rank-4 input are both rejected."""
+    bb = relax.BlockBuilder()
+    invalid_inputs = (
+        relax.Var("x", R.Tensor((2, 3, 28, 28, 28, 3), "float32")),
+        relax.Var("x", R.Tensor("float32", ndim=4)),
+    )
+    for data in invalid_inputs:
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.nn.max_pool3d(data))
+
+
+def test_max_pool3d_infer_struct_info_wrong_input_type():
+    """Shape-typed and function-typed inputs are rejected when normalized."""
+    bb = relax.BlockBuilder()
+    shape_input = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28)))
+    func_input = relax.Var(
+        "x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28), "float32"))
+    )
+
+    for data in (shape_input, func_input):
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.nn.max_pool3d(data))
+
+
 def test_avg_pool2d_infer_struct_info():
     bb = relax.BlockBuilder()
     vdev0 = VDevice("llvm")


Reply via email to