This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4df6b1750b [Relax][ONNX] Add roi_pool op and MaxRoiPool frontend support (#18952)
4df6b1750b is described below

commit 4df6b1750b790cb413c833f15f9741904871df4b
Author: YinHanke <[email protected]>
AuthorDate: Mon Mar 30 00:29:43 2026 +0800

    [Relax][ONNX] Add roi_pool op and MaxRoiPool frontend support (#18952)
    
    ## Summary
    
    Add Relax `roi_pool` support and wire it through the ONNX frontend for
    `MaxRoiPool`.
    
    ## Changes
    
    - add `relax.vision.roi_pool`, including attrs, Python wrapper, struct
    info inference, and legalization
    - add TOPI `roi_pool` compute for NCHW layout
    - support ONNX `MaxRoiPool` in the Relax ONNX frontend
    - handle empty / out-of-bound pooled bins according to ONNX/reference
    semantics, returning `0` instead of propagating invalid reductions
    - add regression tests for Relax op inference, legalization, and ONNX
    frontend import
    - add out-of-bound ROI coverage to make sure fully invalid pooled bins
    still match ONNX Runtime
    
    ## Validation
    
    - `pytest tests/python/relax/test_op_vision.py -k roi_pool`
    - `pytest tests/python/relax/test_frontend_onnx.py -k 'max_roi_pool'`
    
    
    This PR completes the `MaxRoiPool` portion of the Relax ONNX frontend
    operator work tracked in #18945.
---
 include/tvm/relax/attrs/vision.h                  |  18 ++-
 python/tvm/relax/frontend/onnx/onnx_frontend.py   |  24 +++-
 python/tvm/relax/op/__init__.py                   |   1 +
 python/tvm/relax/op/op_attrs.py                   |   5 +
 python/tvm/relax/op/vision/__init__.py            |   1 +
 python/tvm/relax/op/vision/roi_pool.py            |  57 ++++++++++
 python/tvm/relax/transform/legalize_ops/vision.py |  12 ++
 python/tvm/runtime/support.py                     |  11 +-
 python/tvm/s_tir/meta_schedule/utils.py           |  11 +-
 python/tvm/topi/vision/__init__.py                |   1 +
 python/tvm/topi/vision/roi_pool.py                |  94 ++++++++++++++++
 src/relax/op/vision/roi_pool.cc                   | 128 ++++++++++++++++++++++
 src/relax/op/vision/roi_pool.h                    |  42 +++++++
 tests/python/relax/test_frontend_onnx.py          |  46 ++++++++
 tests/python/relax/test_op_vision.py              |  99 ++++++++++++++++-
 15 files changed, 538 insertions(+), 12 deletions(-)

diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 69ce458e7e..8971127d76 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -73,6 +73,23 @@ struct ROIAlignAttrs : public 
AttrsNodeReflAdapter<ROIAlignAttrs> {
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", 
ROIAlignAttrs, BaseAttrsNode);
 };  // struct ROIAlignAttrs
 
+/*! \brief Attributes used in ROIPool (max ROI pooling) operator */
+struct ROIPoolAttrs : public AttrsNodeReflAdapter<ROIPoolAttrs> {
+  ffi::Array<int64_t> pooled_size;  // Output bin grid (pooled_h, pooled_w).
+  double spatial_scale;             // Multiplier applied to ROI coordinates.
+  ffi::String layout;               // Data layout; struct-info inference accepts only "NCHW".
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<ROIPoolAttrs>()
+        .def_ro("pooled_size", &ROIPoolAttrs::pooled_size, "Output size of roi pool.")
+        .def_ro("spatial_scale", &ROIPoolAttrs::spatial_scale,
+                "Ratio of input feature map height (or width) to raw image height (or width).")
+        .def_ro("layout", &ROIPoolAttrs::layout, "Dimension ordering of the input data.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIPoolAttrs", ROIPoolAttrs, BaseAttrsNode);
+};  // struct ROIPoolAttrs
+
 /*! \brief Attributes used in GetValidCounts operator */
 struct GetValidCountsAttrs : public AttrsNodeReflAdapter<GetValidCountsAttrs> {
   double score_threshold;
@@ -132,7 +149,6 @@ struct NonMaximumSuppressionAttrs
                                     NonMaximumSuppressionAttrs, BaseAttrsNode);
 };  // struct NonMaximumSuppressionAttrs
 
-
 /*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box 
decode). */
 struct MultiboxTransformLocAttrs : public 
AttrsNodeReflAdapter<MultiboxTransformLocAttrs> {
   bool clip;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 74c8bfe690..fa9d6eb05d 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2519,6 +2519,28 @@ class RoiAlign(OnnxOpConverter):
         return cls._impl(bb, inputs, attr, params, b"half_pixel")
 
 
+class MaxRoiPool(OnnxOpConverter):
+    """Converts an onnx MaxRoiPool node into a relax.op.vision.roi_pool call (NCHW)."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        if len(inputs) != 2:  # data and rois
+            raise ValueError("MaxRoiPool expects exactly 2 inputs")
+
+        pooled_shape = attr.get("pooled_shape")
+        if pooled_shape is None:  # required: no default output size is assumed
+            raise ValueError("MaxRoiPool requires pooled_shape attribute")
+
+        spatial_scale = attr.get("spatial_scale", 1.0)  # 1.0 when the attribute is absent
+        return relax.op.vision.roi_pool(
+            inputs[0],
+            inputs[1],
+            pooled_size=tuple(pooled_shape),
+            spatial_scale=spatial_scale,
+            layout="NCHW",  # the Relax op only accepts NCHW
+        )
+
+
 class Range(OnnxOpConverter):
     """Converts an onnx Range node into an equivalent Relax expression."""
 
@@ -4179,7 +4201,7 @@ def _get_convert_map():
         "OneHot": OneHot,
         "Unique": Unique,
         "NonZero": NonZero,
-        # "MaxRoiPool": MaxRoiPool,
+        "MaxRoiPool": MaxRoiPool,
         "RoiAlign": RoiAlign,
         "NonMaxSuppression": NonMaxSuppression,
         "AllClassNMS": AllClassNMS,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 0b8dc4e7de..6f985ef36c 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -163,6 +163,7 @@ from .vision import (
     multibox_transform_loc,
     non_max_suppression,
     roi_align,
+    roi_pool,
 )
 
 
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 7602af7e58..b4c3260bb4 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -266,6 +266,11 @@ class ROIAlignAttrs(Attrs):
     """Attributes for vision.roi_align"""
 
 
+@tvm_ffi.register_object("relax.attrs.ROIPoolAttrs")
+class ROIPoolAttrs(Attrs):
+    """Attributes for vision.roi_pool: pooled_size, spatial_scale, layout"""
+
+
 @tvm_ffi.register_object("relax.attrs.MultiboxTransformLocAttrs")
 class MultiboxTransformLocAttrs(Attrs):
     """Attributes for vision.multibox_transform_loc"""
diff --git a/python/tvm/relax/op/vision/__init__.py 
b/python/tvm/relax/op/vision/__init__.py
index 58266c5b2a..f99bbc95dd 100644
--- a/python/tvm/relax/op/vision/__init__.py
+++ b/python/tvm/relax/op/vision/__init__.py
@@ -20,3 +20,4 @@
 from .multibox_transform_loc import *
 from .nms import *
 from .roi_align import *
+from .roi_pool import *
diff --git a/python/tvm/relax/op/vision/roi_pool.py 
b/python/tvm/relax/op/vision/roi_pool.py
new file mode 100644
index 0000000000..f8b7f11463
--- /dev/null
+++ b/python/tvm/relax/op/vision/roi_pool.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""ROI Pool operator"""
+
+from ..base import Expr
+from . import _ffi_api
+
+
+def roi_pool(
+    data: Expr,
+    rois: Expr,
+    pooled_size: int | tuple[int, int] | list[int],
+    spatial_scale: float,
+    layout: str = "NCHW",
+):
+    """ROI Pool operator.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        4-D input tensor.
+
+    rois : relax.Expr
+        2-D input tensor with shape `(num_roi, 5)` in
+        `[batch_idx, x1, y1, x2, y2]` format.
+
+    pooled_size : Union[int, Tuple[int, int], List[int]]
+        Output pooled size `(pooled_h, pooled_w)`; an int `s` is treated as `(s, s)`.
+
+    spatial_scale : float
+        Ratio of input feature map height (or width) to raw image height (or width).
+
+    layout : str, optional
+        Layout of the input data. Currently only `NCHW` is supported.
+
+    Returns
+    -------
+    result : relax.Expr
+        The `relax.vision.roi_pool` call; output shape `(num_roi, channel, pooled_h, pooled_w)`.
+    """
+    if isinstance(pooled_size, int):  # broadcast a scalar size to both spatial dims
+        pooled_size = (pooled_size, pooled_size)
+    return _ffi_api.roi_pool(data, rois, pooled_size, spatial_scale, layout)
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py 
b/python/tvm/relax/transform/legalize_ops/vision.py
index ea0458bfce..7d8586ab52 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -150,6 +150,18 @@ def _non_max_suppression(block_builder: BlockBuilder, 
call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.vision.roi_pool")
+def _roi_pool(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(  # lower via topi.vision.roi_pool, forwarding op attrs as kwargs
+        topi.vision.roi_pool,
+        call.args[0],
+        call.args[1],
+        pooled_size=call.attrs.pooled_size,
+        spatial_scale=call.attrs.spatial_scale,
+        layout=call.attrs.layout,
+    )
+
+
 @register_legalize("relax.vision.multibox_transform_loc")
 def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr:
     variances = tuple(float(x) for x in call.attrs.variances)
diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py
index b0ac671763..f6591b2871 100644
--- a/python/tvm/runtime/support.py
+++ b/python/tvm/runtime/support.py
@@ -146,10 +146,17 @@ def derived_object(cls: type[T]) -> type[T]:
     fields = metadata.get("fields", [])
     methods = metadata.get("methods", [])
 
-    class TVMDerivedObject(metadata["cls"]):  # type: ignore
+    base_cls = metadata["cls"]
+    slots = []
+    if getattr(base_cls, "__dictoffset__", 0) == 0:
+        slots.append("__dict__")
+    if getattr(base_cls, "__weakrefoffset__", 0) == 0:
+        slots.append("__weakref__")
+
+    class TVMDerivedObject(base_cls):  # type: ignore
         """The derived object to avoid cyclic dependency."""
 
-        __slots__ = ("__dict__", "__weakref__",)
+        __slots__ = tuple(slots)
 
         _cls = cls
         _type = "TVMDerivedObject"
diff --git a/python/tvm/s_tir/meta_schedule/utils.py 
b/python/tvm/s_tir/meta_schedule/utils.py
index 2460a6cc26..1344211711 100644
--- a/python/tvm/s_tir/meta_schedule/utils.py
+++ b/python/tvm/s_tir/meta_schedule/utils.py
@@ -106,10 +106,17 @@ def derived_object(cls: type) -> type:
     fields = metadata.get("fields", [])
     methods = metadata.get("methods", [])
 
-    class TVMDerivedObject(metadata["cls"]):  # type: ignore
+    base_cls = metadata["cls"]
+    slots = []
+    if getattr(base_cls, "__dictoffset__", 0) == 0:
+        slots.append("__dict__")
+    if getattr(base_cls, "__weakrefoffset__", 0) == 0:
+        slots.append("__weakref__")
+
+    class TVMDerivedObject(base_cls):  # type: ignore
         """The derived object to avoid cyclic dependency."""
 
-        __slots__ = ("__dict__", "__weakref__",)
+        __slots__ = tuple(slots)
 
         _cls = cls
         _type = "TVMDerivedObject"
diff --git a/python/tvm/topi/vision/__init__.py 
b/python/tvm/topi/vision/__init__.py
index cb0467c98c..93074201f5 100644
--- a/python/tvm/topi/vision/__init__.py
+++ b/python/tvm/topi/vision/__init__.py
@@ -20,3 +20,4 @@
 from .multibox_transform_loc import *
 from .nms import *
 from .roi_align import *
+from .roi_pool import *
diff --git a/python/tvm/topi/vision/roi_pool.py 
b/python/tvm/topi/vision/roi_pool.py
new file mode 100644
index 0000000000..54a4aeba50
--- /dev/null
+++ b/python/tvm/topi/vision/roi_pool.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""ROI Pool operator"""
+
+import tvm
+from tvm import te
+
+
+def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
+    """ROI pool operator in NCHW layout.
+
+    Parameters
+    ----------
+    data : te.Tensor
+        4-D input tensor of shape ``(batch, channel, height, width)``.
+
+    rois : te.Tensor
+        2-D tensor of shape ``(num_roi, 5)`` whose rows are
+        ``[batch_idx, x1, y1, x2, y2]``.
+
+    pooled_size : int or pair of int
+        Output ``(pooled_h, pooled_w)``; an int gives a square output.
+
+    spatial_scale : float
+        Multiplier applied to the ROI coordinates before they are rounded to
+        feature-map cells.
+
+    Returns
+    -------
+    output : te.Tensor
+        Tensor of shape ``(num_roi, channel, pooled_h, pooled_w)`` holding the
+        maximum of each pooled bin, or ``0`` for bins left empty after clipping
+        to the feature map.
+    """
+    _, channel, height, width = data.shape
+    num_roi, _ = rois.shape
+
+    if isinstance(pooled_size, int):
+        pooled_size_h = pooled_size_w = pooled_size
+    else:
+        pooled_size_h, pooled_size_w = pooled_size
+
+    # Integer literal so the constant is valid for float and integer dtypes alike.
+    zero = tvm.tirx.const(0, data.dtype)
+    roi_dtype = rois.dtype
+
+    # Fill value for cells outside a bin. It equals the identity element of
+    # ``te.max`` (so masked cells never exceed a real sample) and, unlike a
+    # ``-inf`` literal, it is defined for integer dtypes too.
+    lowest = tvm.tirx.min_value(data.dtype)
+
+    def _bin_bounds(i, ph, pw):
+        """Clipped ``hstart, hend, wstart, wend`` (end-exclusive) of bin (ph, pw) of ROI i."""
+        roi = rois[i]
+        # ROI corners are scaled and rounded to feature-map cells.
+        roi_start_w = te.round(roi[1] * spatial_scale).astype("int32")
+        roi_start_h = te.round(roi[2] * spatial_scale).astype("int32")
+        roi_end_w = te.round(roi[3] * spatial_scale).astype("int32")
+        roi_end_h = te.round(roi[4] * spatial_scale).astype("int32")
+
+        # The end cell is inclusive; degenerate ROIs are clamped to one cell.
+        roi_h = te.max(roi_end_h - roi_start_h + 1, tvm.tirx.const(1, "int32"))
+        roi_w = te.max(roi_end_w - roi_start_w + 1, tvm.tirx.const(1, "int32"))
+
+        bin_h = tvm.tirx.Cast(roi_dtype, roi_h) / tvm.tirx.const(float(pooled_size_h), roi_dtype)
+        bin_w = tvm.tirx.Cast(roi_dtype, roi_w) / tvm.tirx.const(float(pooled_size_w), roi_dtype)
+
+        hstart = te.floor(tvm.tirx.Cast(roi_dtype, ph) * bin_h).astype("int32")
+        wstart = te.floor(tvm.tirx.Cast(roi_dtype, pw) * bin_w).astype("int32")
+        hend = te.ceil(tvm.tirx.Cast(roi_dtype, ph + 1) * bin_h).astype("int32")
+        wend = te.ceil(tvm.tirx.Cast(roi_dtype, pw + 1) * bin_w).astype("int32")
+
+        # Shift into feature-map coordinates and clip to [0, height] x [0, width].
+        hstart = te.min(te.max(hstart + roi_start_h, 0), height)
+        hend = te.min(te.max(hend + roi_start_h, 0), height)
+        wstart = te.min(te.max(wstart + roi_start_w, 0), width)
+        wend = te.min(te.max(wend + roi_start_w, 0), width)
+        return hstart, hend, wstart, wend
+
+    def _sample(i, c, ph, pw):
+        """Value at reduction cell (rh, rw) if it lies inside bin (ph, pw), else ``lowest``."""
+        roi = rois[i]
+        batch_index = roi[0].astype("int32")
+        hstart, hend, wstart, wend = _bin_bounds(i, ph, pw)
+        valid = tvm.tirx.all(hstart <= rh, rh < hend, wstart <= rw, rw < wend)
+        return tvm.tirx.if_then_else(valid, data[batch_index, c, rh, rw], lowest)
+
+    def _is_empty(i, ph, pw):
+        """True when clipping leaves bin (ph, pw) of ROI i without any cell."""
+        hstart, hend, wstart, wend = _bin_bounds(i, ph, pw)
+        return tvm.tirx.any(hend <= hstart, wend <= wstart)
+
+    # Every bin reduces over the whole feature map and masks cells outside it.
+    rh = te.reduce_axis((0, height), name="rh")
+    rw = te.reduce_axis((0, width), name="rw")
+    pooled = te.compute(
+        (num_roi, channel, pooled_size_h, pooled_size_w),
+        lambda i, c, ph, pw: te.max(_sample(i, c, ph, pw), axis=[rh, rw]),
+        tag="pool,roi_pool_nchw",
+    )
+
+    # Empty bins yield 0 rather than the reduction's identity value.
+    return te.compute(
+        (num_roi, channel, pooled_size_h, pooled_size_w),
+        lambda i, c, ph, pw: tvm.tirx.if_then_else(
+            _is_empty(i, ph, pw), zero, pooled[i, c, ph, pw]
+        ),
+    )
+
+
+def roi_pool(data, rois, pooled_size, spatial_scale, layout="NCHW"):
+    """ROI pool operator.
+
+    Only ``layout="NCHW"`` is implemented (see :func:`roi_pool_nchw`); any other
+    layout raises ``ValueError``.
+    """
+    if layout == "NCHW":
+        return roi_pool_nchw(data, rois, pooled_size, spatial_scale)
+    raise ValueError(f"Unsupported layout for roi_pool: {layout}")
diff --git a/src/relax/op/vision/roi_pool.cc b/src/relax/op/vision/roi_pool.cc
new file mode 100644
index 0000000000..93eddb04cb
--- /dev/null
+++ b/src/relax/op/vision/roi_pool.cc
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file roi_pool.cc
+ * \brief ROI Pool operators.
+ */
+
+#include "roi_pool.h"
+
+#include <tvm/ffi/reflection/registry.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+TVM_FFI_STATIC_INIT_BLOCK() { ROIPoolAttrs::RegisterReflection(); }
+
+/*!
+ * \brief Build a call to relax.vision.roi_pool.
+ *
+ * A one-element pooled_size is expanded to a square (s, s); afterwards it must
+ * hold exactly two positive entries (pooled_h, pooled_w).
+ */
+Expr roi_pool(Expr data, Expr rois, ffi::Array<int64_t> pooled_size, double spatial_scale,
+              ffi::String layout) {
+  if (pooled_size.size() == 1) {
+    pooled_size.push_back(pooled_size[0]);
+  }
+  TVM_FFI_ICHECK_EQ(pooled_size.size(), 2)
+      << "The input pooled_size length is expected to be 2. However, the given pooled_size is "
+      << pooled_size;
+  // Zero or negative bin counts would make the lowering divide by zero and
+  // produce an empty or invalid output shape.
+  TVM_FFI_ICHECK(pooled_size[0] > 0 && pooled_size[1] > 0)
+      << "The input pooled_size is expected to be positive. However, the given pooled_size is "
+      << pooled_size;
+
+  auto attrs = ffi::make_object<ROIPoolAttrs>();
+  attrs->pooled_size = std::move(pooled_size);
+  attrs->spatial_scale = spatial_scale;
+  attrs->layout = layout;
+
+  static const Op& op = Op::Get("relax.vision.roi_pool");
+  return Call(op, {std::move(data), std::move(rois)}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.vision.roi_pool", roi_pool);
+}
+
+/*!
+ * \brief Infer the struct info of relax.vision.roi_pool.
+ *
+ * Validates the argument count, ranks, layout and the rois last dimension. When both
+ * input shapes are known the result is (num_roi, channel, pooled_h, pooled_w) in the
+ * data dtype; otherwise only the rank (4) is known.
+ */
+StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects two arguments, while the given number of arguments is "
+                     << call->args.size());
+  }
+
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* rois_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects the input data to be a Tensor, while the given data is "
+                     << call->args[0]->GetTypeKey());
+  }
+  if (rois_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects the rois to be a Tensor, while the given rois is "
+                     << call->args[1]->GetTypeKey());
+  }
+  if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim != 4) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects the input data to be 4-D, while the given data has ndim "
+                     << data_sinfo->ndim);
+  }
+  if (!rois_sinfo->IsUnknownNdim() && rois_sinfo->ndim != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects the rois tensor to be 2-D, while the given rois has ndim "
+                     << rois_sinfo->ndim);
+  }
+
+  const auto* attrs = call->attrs.as<ROIPoolAttrs>();
+  TVM_FFI_ICHECK(attrs != nullptr) << "Invalid ROIPool attrs";
+  if (attrs->layout != "NCHW") {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool only supports NCHW layout, but got " << attrs->layout);
+  }
+
+  const auto* rois_shape = rois_sinfo->shape.as<ShapeExprNode>();
+  if (rois_shape != nullptr) {
+    // rois has rank 2 here (checked above), so values[1] is the per-ROI record length.
+    const auto* last_dim = rois_shape->values[1].as<IntImmNode>();
+    if (last_dim != nullptr && last_dim->value != 5) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "ROIPool expects rois to have shape (num_roi, 5), but got last "
+                          "dimension "
+                       << last_dim->value);
+    }
+  }
+
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr || rois_shape == nullptr) {
+    return TensorStructInfo(data_sinfo->dtype, 4, data_sinfo->vdevice);
+  }
+
+  ffi::Array<PrimExpr> out_shape = {rois_shape->values[0], data_shape->values[1],
+                                    Integer(attrs->pooled_size[0]), Integer(attrs->pooled_size[1])};
+  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.vision.roi_pool")
+    .set_attrs_type<ROIPoolAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("rois", "Tensor",
+                  "The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoROIPool)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/vision/roi_pool.h b/src/relax/op/vision/roi_pool.h
new file mode 100644
index 0000000000..738dbee0d8
--- /dev/null
+++ b/src/relax/op/vision/roi_pool.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file roi_pool.h
+ * \brief The functions to make Relax ROI Pool operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_VISION_ROI_POOL_H_
+#define TVM_RELAX_OP_VISION_ROI_POOL_H_
+
+#include <tvm/relax/attrs/vision.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Build a relax.vision.roi_pool call (max ROI pooling); defined in roi_pool.cc. */
+Expr roi_pool(Expr data, Expr rois, ffi::Array<int64_t> pooled_size, double spatial_scale,
+              ffi::String layout);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_VISION_ROI_POOL_H_
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index c6b4df6aaa..0fb5f3003a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4423,5 +4423,52 @@ def test_if_nested():
     )
 
 
+@pytest.mark.parametrize(
+    ("pooled_shape", "rois"),
+    [
+        ((1, 1), np.array([[0.0, 1.0, 1.0, 6.0, 6.0], [0.0, 0.0, 0.0, 7.0, 7.0]], dtype="float32")),
+        (
+            (2, 3),
+            np.array([[0.0, 1.2, 0.5, 6.8, 7.0], [0.0, -1.0, 2.0, 3.5, 5.2]], dtype="float32"),
+        ),
+        # ROI 0 lies entirely outside the 8x8 feature map, so all of its bins are empty.
+        (
+            (2, 2),
+            np.array(
+                [[0.0, 100.0, 100.0, 110.0, 110.0], [0.0, 1.0, 1.0, 6.0, 6.0]], dtype="float32"
+            ),
+        ),
+    ],
+)
+def test_max_roi_pool(pooled_shape, rois):
+    x_shape = [1, 4, 8, 8]
+    out_shape = [2, 4, pooled_shape[0], pooled_shape[1]]
+
+    node = helper.make_node(
+        "MaxRoiPool",
+        inputs=["X", "rois"],
+        outputs=["Y"],
+        pooled_shape=pooled_shape,
+        spatial_scale=1.0,
+    )
+
+    graph = helper.make_graph(
+        [node],
+        "max_roi_pool_test",
+        inputs=[
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+            helper.make_tensor_value_info("rois", TensorProto.FLOAT, [2, 5]),
+        ],
+        outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, out_shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="max_roi_pool_test")
+    inputs = {
+        "X": rg.standard_normal(size=x_shape).astype("float32"),
+        "rois": rois,
+    }
+    check_correctness(model, inputs=inputs, opset=16, rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_op_vision.py 
b/tests/python/relax/test_op_vision.py
index 6d04a796ca..b597b325f4 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -1050,6 +1050,98 @@ def test_nms_e2e_index_remap():
     np.testing.assert_array_equal(ref_valid_box_count, np.array([[3]], dtype="int32"))
 
 
+def test_roi_pool_op_correctness():
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    rois = relax.Var("rois", R.Tensor((4, 5), "float32"))
+    assert relax.op.vision.roi_pool(x, rois, (7, 7), 1.0).op == Op.get("relax.vision.roi_pool")
+
+
+def test_roi_pool_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    rois = relax.Var("rois", R.Tensor((5, 5), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.vision.roi_pool(x, rois, (7, 5), 0.25),
+        relax.TensorStructInfo((5, 3, 7, 5), "float32"),
+    )
+
+
+def test_roi_pool_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    n = tirx.Var("n", "int64")
+    c = tirx.Var("c", "int64")
+    h = tirx.Var("h", "int64")
+    w = tirx.Var("w", "int64")
+    num_roi = tirx.Var("num_roi", "int64")
+
+    x = relax.Var("x", R.Tensor((n, c, h, w), "float32"))
+    rois = relax.Var("rois", R.Tensor((num_roi, 5), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.vision.roi_pool(x, rois, (7, 7), 0.5),
+        relax.TensorStructInfo((num_roi, c, 7, 7), "float32"),
+    )
+
+
+def test_roi_pool_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    rois0 = relax.Var("rois", R.Tensor((4,), "float32"))
+    rois1 = relax.Var("rois", R.Tensor((4, 5), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.roi_pool(x0, rois1, (7, 7), 1.0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.roi_pool(x1, rois0, (7, 7), 1.0))
+
+
+def test_roi_pool_wrong_rois_last_dim():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    rois = relax.Var("rois", R.Tensor((4, 4), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.roi_pool(x, rois, (7, 7), 1.0))
+
+
+def test_roi_pool_wrong_layout():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    rois = relax.Var("rois", R.Tensor((4, 5), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.roi_pool(x, rois, (7, 7), 1.0, layout="NHWC"))
+
+
+def test_roi_pool_legalize():
+    @tvm.script.ir_module
+    class ROIPool:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 8, 8), "float32"),
+            rois: R.Tensor((2, 5), "float32"),
+        ) -> R.Tensor((2, 2, 3, 2), "float32"):
+            gv: R.Tensor((2, 2, 3, 2), "float32") = R.vision.roi_pool(
+                x,
+                rois,
+                pooled_size=(3, 2),
+                spatial_scale=1.0,
+                layout="NCHW",
+            )
+            return gv
+
+    mod = LegalizeOps()(ROIPool)
+    assert "call_tir" in str(mod)
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TensorStructInfo((2, 2, 3, 2), "float32"),
+    )
+
+
 def test_all_class_non_max_suppression_infer_struct_info():
     bb = relax.BlockBuilder()
     batch_size, num_classes, num_boxes = 10, 8, 5
@@ -1201,12 +1293,9 @@ def test_multibox_transform_loc_op_correctness():
     cls = relax.Var("cls", R.Tensor((1, 5, 10), "float32"))
     loc = relax.Var("loc", R.Tensor((1, 40), "float32"))
     anc = relax.Var("anc", R.Tensor((1, 10, 4), "float32"))
-    assert (
-        relax.op.vision.multibox_transform_loc(
-            cls, loc, anc, False, 0.0, (1.0, 1.0, 1.0, 1.0), True
-        ).op
-        == Op.get("relax.vision.multibox_transform_loc")
-    )
+    assert relax.op.vision.multibox_transform_loc(
+        cls, loc, anc, False, 0.0, (1.0, 1.0, 1.0, 1.0), True
+    ).op == Op.get("relax.vision.multibox_transform_loc")
 
 
 def test_multibox_transform_loc_infer_struct_info():

Reply via email to