This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4df6b1750b [Relax][ONNX] Add roi_pool op and MaxRoiPool frontend
support (#18952)
4df6b1750b is described below
commit 4df6b1750b790cb413c833f15f9741904871df4b
Author: YinHanke <[email protected]>
AuthorDate: Mon Mar 30 00:29:43 2026 +0800
[Relax][ONNX] Add roi_pool op and MaxRoiPool frontend support (#18952)
## Summary
Add Relax `roi_pool` support and wire it through the ONNX frontend for
`MaxRoiPool`.
## Changes
- add `relax.vision.roi_pool`, including attrs, Python wrapper, struct
info inference, and legalization
- add TOPI `roi_pool` compute for NCHW layout
- support ONNX `MaxRoiPool` in the Relax ONNX frontend
- handle empty / out-of-bounds pooled bins according to ONNX/reference
semantics, returning `0` instead of propagating invalid reductions
- add regression tests for Relax op inference, legalization, and ONNX
frontend import
- add out-of-bounds ROI coverage to make sure fully invalid pooled bins
still match ONNX Runtime
## Validation
- `pytest tests/python/relax/test_op_vision.py -k roi_pool`
- `pytest tests/python/relax/test_frontend_onnx.py -k 'max_roi_pool'`
This PR completes the `MaxRoiPool` portion of the Relax ONNX frontend
operator work tracked in #18945.
---
include/tvm/relax/attrs/vision.h | 18 ++-
python/tvm/relax/frontend/onnx/onnx_frontend.py | 24 +++-
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/op_attrs.py | 5 +
python/tvm/relax/op/vision/__init__.py | 1 +
python/tvm/relax/op/vision/roi_pool.py | 57 ++++++++++
python/tvm/relax/transform/legalize_ops/vision.py | 12 ++
python/tvm/runtime/support.py | 11 +-
python/tvm/s_tir/meta_schedule/utils.py | 11 +-
python/tvm/topi/vision/__init__.py | 1 +
python/tvm/topi/vision/roi_pool.py | 94 ++++++++++++++++
src/relax/op/vision/roi_pool.cc | 128 ++++++++++++++++++++++
src/relax/op/vision/roi_pool.h | 42 +++++++
tests/python/relax/test_frontend_onnx.py | 46 ++++++++
tests/python/relax/test_op_vision.py | 99 ++++++++++++++++-
15 files changed, 538 insertions(+), 12 deletions(-)
diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 69ce458e7e..8971127d76 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -73,6 +73,23 @@ struct ROIAlignAttrs : public
AttrsNodeReflAdapter<ROIAlignAttrs> {
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs",
ROIAlignAttrs, BaseAttrsNode);
}; // struct ROIAlignAttrs
+/*!
+ * \brief Attributes used in ROIPool operator.
+ *
+ * The three fields below are exposed read-only through the reflection registry
+ * under the names and descriptions given in RegisterReflection().
+ */
+struct ROIPoolAttrs : public AttrsNodeReflAdapter<ROIPoolAttrs> {
+ ffi::Array<int64_t> pooled_size;  // output size of the roi pool
+ double spatial_scale;  // ratio of feature-map height (or width) to raw-image height (or width)
+ ffi::String layout;  // dimension ordering of the input data
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<ROIPoolAttrs>()
+ .def_ro("pooled_size", &ROIPoolAttrs::pooled_size, "Output size of roi
pool.")
+ .def_ro("spatial_scale", &ROIPoolAttrs::spatial_scale,
+ "Ratio of input feature map height (or width) to raw image
height (or width).")
+ .def_ro("layout", &ROIPoolAttrs::layout, "Dimension ordering of the
input data.");
+ }
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIPoolAttrs", ROIPoolAttrs,
BaseAttrsNode);
+}; // struct ROIPoolAttrs
+
/*! \brief Attributes used in GetValidCounts operator */
struct GetValidCountsAttrs : public AttrsNodeReflAdapter<GetValidCountsAttrs> {
double score_threshold;
@@ -132,7 +149,6 @@ struct NonMaximumSuppressionAttrs
NonMaximumSuppressionAttrs, BaseAttrsNode);
}; // struct NonMaximumSuppressionAttrs
-
/*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box
decode). */
struct MultiboxTransformLocAttrs : public
AttrsNodeReflAdapter<MultiboxTransformLocAttrs> {
bool clip;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 74c8bfe690..fa9d6eb05d 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2519,6 +2519,28 @@ class RoiAlign(OnnxOpConverter):
return cls._impl(bb, inputs, attr, params, b"half_pixel")
+class MaxRoiPool(OnnxOpConverter):
+ """Converts an onnx MaxRoiPool node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ if len(inputs) != 2:
+ raise ValueError("MaxRoiPool expects exactly 2 inputs")
+
+ pooled_shape = attr.get("pooled_shape")
+ if pooled_shape is None:
+ raise ValueError("MaxRoiPool requires pooled_shape attribute")
+
+ spatial_scale = attr.get("spatial_scale", 1.0)
+ return relax.op.vision.roi_pool(
+ inputs[0],
+ inputs[1],
+ pooled_size=tuple(pooled_shape),
+ spatial_scale=spatial_scale,
+ layout="NCHW",
+ )
+
+
class Range(OnnxOpConverter):
"""Converts an onnx Range node into an equivalent Relax expression."""
@@ -4179,7 +4201,7 @@ def _get_convert_map():
"OneHot": OneHot,
"Unique": Unique,
"NonZero": NonZero,
- # "MaxRoiPool": MaxRoiPool,
+ "MaxRoiPool": MaxRoiPool,
"RoiAlign": RoiAlign,
"NonMaxSuppression": NonMaxSuppression,
"AllClassNMS": AllClassNMS,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 0b8dc4e7de..6f985ef36c 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -163,6 +163,7 @@ from .vision import (
multibox_transform_loc,
non_max_suppression,
roi_align,
+ roi_pool,
)
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 7602af7e58..b4c3260bb4 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -266,6 +266,11 @@ class ROIAlignAttrs(Attrs):
"""Attributes for vision.roi_align"""
+@tvm_ffi.register_object("relax.attrs.ROIPoolAttrs")
+class ROIPoolAttrs(Attrs):
+ """Attributes for vision.roi_pool
+
+ Registered under the ``relax.attrs.ROIPoolAttrs`` type key; the class
+ body declares no members of its own.
+ """
+
+
@tvm_ffi.register_object("relax.attrs.MultiboxTransformLocAttrs")
class MultiboxTransformLocAttrs(Attrs):
"""Attributes for vision.multibox_transform_loc"""
diff --git a/python/tvm/relax/op/vision/__init__.py
b/python/tvm/relax/op/vision/__init__.py
index 58266c5b2a..f99bbc95dd 100644
--- a/python/tvm/relax/op/vision/__init__.py
+++ b/python/tvm/relax/op/vision/__init__.py
@@ -20,3 +20,4 @@
from .multibox_transform_loc import *
from .nms import *
from .roi_align import *
+from .roi_pool import *
diff --git a/python/tvm/relax/op/vision/roi_pool.py
b/python/tvm/relax/op/vision/roi_pool.py
new file mode 100644
index 0000000000..f8b7f11463
--- /dev/null
+++ b/python/tvm/relax/op/vision/roi_pool.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""ROI Pool operator"""
+
+from ..base import Expr
+from . import _ffi_api
+
+
+def roi_pool(
+ data: Expr,
+ rois: Expr,
+ pooled_size: int | tuple[int, int] | list[int],
+ spatial_scale: float,
+ layout: str = "NCHW",
+):
+ """ROI Pool operator.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ 4-D input tensor.
+
+ rois : relax.Expr
+ 2-D input tensor with shape `(num_roi, 5)` in
+ `[batch_idx, x1, y1, x2, y2]` format.
+
+ pooled_size : Union[int, Tuple[int, int], List[int]]
+ Output pooled size.
+
+ spatial_scale : float
+ Ratio of input feature map height (or width) to raw image height (or
width).
+
+ layout : str, optional
+ Layout of the input data. Currently only `NCHW` is supported.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ if isinstance(pooled_size, int):
+ pooled_size = (pooled_size, pooled_size)
+ return _ffi_api.roi_pool(data, rois, pooled_size, spatial_scale, layout)
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py
b/python/tvm/relax/transform/legalize_ops/vision.py
index ea0458bfce..7d8586ab52 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -150,6 +150,18 @@ def _non_max_suppression(block_builder: BlockBuilder,
call: Call) -> Expr:
)
+@register_legalize("relax.vision.roi_pool")
+def _roi_pool(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(
+ topi.vision.roi_pool,
+ call.args[0],
+ call.args[1],
+ pooled_size=call.attrs.pooled_size,
+ spatial_scale=call.attrs.spatial_scale,
+ layout=call.attrs.layout,
+ )
+
+
@register_legalize("relax.vision.multibox_transform_loc")
def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr:
variances = tuple(float(x) for x in call.attrs.variances)
diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py
index b0ac671763..f6591b2871 100644
--- a/python/tvm/runtime/support.py
+++ b/python/tvm/runtime/support.py
@@ -146,10 +146,17 @@ def derived_object(cls: type[T]) -> type[T]:
fields = metadata.get("fields", [])
methods = metadata.get("methods", [])
- class TVMDerivedObject(metadata["cls"]): # type: ignore
+ base_cls = metadata["cls"]
+ slots = []
+ if getattr(base_cls, "__dictoffset__", 0) == 0:
+ slots.append("__dict__")
+ if getattr(base_cls, "__weakrefoffset__", 0) == 0:
+ slots.append("__weakref__")
+
+ class TVMDerivedObject(base_cls): # type: ignore
"""The derived object to avoid cyclic dependency."""
- __slots__ = ("__dict__", "__weakref__",)
+ __slots__ = tuple(slots)
_cls = cls
_type = "TVMDerivedObject"
diff --git a/python/tvm/s_tir/meta_schedule/utils.py
b/python/tvm/s_tir/meta_schedule/utils.py
index 2460a6cc26..1344211711 100644
--- a/python/tvm/s_tir/meta_schedule/utils.py
+++ b/python/tvm/s_tir/meta_schedule/utils.py
@@ -106,10 +106,17 @@ def derived_object(cls: type) -> type:
fields = metadata.get("fields", [])
methods = metadata.get("methods", [])
- class TVMDerivedObject(metadata["cls"]): # type: ignore
+ base_cls = metadata["cls"]
+ slots = []
+ if getattr(base_cls, "__dictoffset__", 0) == 0:
+ slots.append("__dict__")
+ if getattr(base_cls, "__weakrefoffset__", 0) == 0:
+ slots.append("__weakref__")
+
+ class TVMDerivedObject(base_cls): # type: ignore
"""The derived object to avoid cyclic dependency."""
- __slots__ = ("__dict__", "__weakref__",)
+ __slots__ = tuple(slots)
_cls = cls
_type = "TVMDerivedObject"
diff --git a/python/tvm/topi/vision/__init__.py
b/python/tvm/topi/vision/__init__.py
index cb0467c98c..93074201f5 100644
--- a/python/tvm/topi/vision/__init__.py
+++ b/python/tvm/topi/vision/__init__.py
@@ -20,3 +20,4 @@
from .multibox_transform_loc import *
from .nms import *
from .roi_align import *
+from .roi_pool import *
diff --git a/python/tvm/topi/vision/roi_pool.py
b/python/tvm/topi/vision/roi_pool.py
new file mode 100644
index 0000000000..54a4aeba50
--- /dev/null
+++ b/python/tvm/topi/vision/roi_pool.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""ROI Pool operator"""
+
+import tvm
+from tvm import te
+
+
+def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
+ """ROI pool operator in NCHW layout.
+
+ Max-pools every region of interest of `data` into a fixed grid of
+ `pooled_size_h x pooled_size_w` bins.
+
+ Parameters
+ ----------
+ data : tensor
+ 4-D input; its shape is unpacked as (batch, channel, height, width).
+
+ rois : tensor
+ 2-D input; each row is read as [batch_index, x1, y1, x2, y2].
+
+ pooled_size : int or pair of int
+ Number of output bins. An int is used for both height and width.
+
+ spatial_scale : float
+ Factor the ROI coordinates are multiplied by before rounding.
+
+ Returns
+ -------
+ output : tensor
+ Shape (num_roi, channel, pooled_size_h, pooled_size_w). A bin that is
+ empty after clamping to the feature map yields 0.
+ """
+ _, channel, height, width = data.shape
+ num_roi, _ = rois.shape
+
+ if isinstance(pooled_size, int):
+ pooled_size_h = pooled_size_w = pooled_size
+ else:
+ pooled_size_h, pooled_size_w = pooled_size
+
+ zero = tvm.tirx.const(0.0, data.dtype)  # value written for empty bins
+ roi_dtype = rois.dtype
+
+ neg_inf = tvm.tirx.const(float("-inf"), data.dtype)  # stands in for positions outside the bin
+
+ def _bin_bounds(i, ph, pw):
+ roi = rois[i]
+ roi_start_w = te.round(roi[1] * spatial_scale).astype("int32")
+ roi_start_h = te.round(roi[2] * spatial_scale).astype("int32")
+ roi_end_w = te.round(roi[3] * spatial_scale).astype("int32")
+ roi_end_h = te.round(roi[4] * spatial_scale).astype("int32")
+
+ roi_h = te.max(roi_end_h - roi_start_h + 1, tvm.tirx.const(1, "int32"))  # at least 1
+ roi_w = te.max(roi_end_w - roi_start_w + 1, tvm.tirx.const(1, "int32"))  # at least 1
+
+ bin_h = tvm.tirx.Cast(roi_dtype, roi_h) /
tvm.tirx.const(float(pooled_size_h), roi_dtype)
+ bin_w = tvm.tirx.Cast(roi_dtype, roi_w) /
tvm.tirx.const(float(pooled_size_w), roi_dtype)
+
+ hstart = te.floor(tvm.tirx.Cast(roi_dtype, ph) * bin_h).astype("int32")
+ wstart = te.floor(tvm.tirx.Cast(roi_dtype, pw) * bin_w).astype("int32")
+ hend = te.ceil(tvm.tirx.Cast(roi_dtype, ph + 1) *
bin_h).astype("int32")
+ wend = te.ceil(tvm.tirx.Cast(roi_dtype, pw + 1) *
bin_w).astype("int32")
+
+ hstart = te.min(te.max(hstart + roi_start_h, 0), height)  # shift by ROI origin, clamp to [0, height]
+ hend = te.min(te.max(hend + roi_start_h, 0), height)
+ wstart = te.min(te.max(wstart + roi_start_w, 0), width)  # shift by ROI origin, clamp to [0, width]
+ wend = te.min(te.max(wend + roi_start_w, 0), width)
+ return hstart, hend, wstart, wend
+
+ def _sample(i, c, ph, pw):
+ roi = rois[i]
+ batch_index = roi[0].astype("int32")
+ hstart, hend, wstart, wend = _bin_bounds(i, ph, pw)
+ valid = tvm.tirx.all(hstart <= rh, rh < hend, wstart <= rw, rw < wend)  # (rh, rw) lies inside the bin
+ return tvm.tirx.if_then_else(valid, data[batch_index, c, rh, rw],
neg_inf)
+
+ def _is_empty(i, ph, pw):
+ hstart, hend, wstart, wend = _bin_bounds(i, ph, pw)
+ return tvm.tirx.any(hend <= hstart, wend <= wstart)
+
+ rh = te.reduce_axis((0, height), name="rh")  # reduction covers the whole feature map; _sample masks it
+ rw = te.reduce_axis((0, width), name="rw")
+ pooled = te.compute(
+ (num_roi, channel, pooled_size_h, pooled_size_w),
+ lambda i, c, ph, pw: te.max(_sample(i, c, ph, pw), axis=[rh, rw]),
+ tag="pool,roi_pool_nchw",
+ )
+
+ return te.compute(
+ (num_roi, channel, pooled_size_h, pooled_size_w),
+ lambda i, c, ph, pw: tvm.tirx.if_then_else(
+ _is_empty(i, ph, pw), zero, pooled[i, c, ph, pw]
+ ),
+ )
+
+
+def roi_pool(data, rois, pooled_size, spatial_scale, layout="NCHW"):
+ """ROI pool operator."""
+ if layout == "NCHW":
+ return roi_pool_nchw(data, rois, pooled_size, spatial_scale)
+ raise ValueError(f"Unsupported layout for roi_pool: {layout}")
diff --git a/src/relax/op/vision/roi_pool.cc b/src/relax/op/vision/roi_pool.cc
new file mode 100644
index 0000000000..93eddb04cb
--- /dev/null
+++ b/src/relax/op/vision/roi_pool.cc
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file roi_pool.cc
+ * \brief ROI Pool operators.
+ */
+
+#include "roi_pool.h"
+
+#include <tvm/ffi/reflection/registry.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+TVM_FFI_STATIC_INIT_BLOCK() { ROIPoolAttrs::RegisterReflection(); }  // registers the ROIPoolAttrs fields with the reflection registry
+
+/*!
+ * \brief Build a call to the `relax.vision.roi_pool` operator.
+ *
+ * A one-element `pooled_size` is expanded to a pair by repeating the element;
+ * any other length than 2 afterwards fails the internal check.
+ */
+Expr roi_pool(Expr data, Expr rois, ffi::Array<int64_t> pooled_size, double spatial_scale,
+              ffi::String layout) {
+  if (pooled_size.size() == 1) {
+    const int64_t extent = pooled_size[0];
+    pooled_size.push_back(extent);
+  }
+  TVM_FFI_ICHECK_EQ(pooled_size.size(), 2)
+      << "The input pooled_size length is expected to be 2. However, the given pooled_size is "
+      << pooled_size;
+
+  auto node = ffi::make_object<ROIPoolAttrs>();
+  node->layout = layout;
+  node->spatial_scale = spatial_scale;
+  node->pooled_size = std::move(pooled_size);
+
+  static const Op& roi_pool_op = Op::Get("relax.vision.roi_pool");
+  return Call(roi_pool_op, {std::move(data), std::move(rois)}, Attrs(node), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  // Publish the call builder as the global function "relax.op.vision.roi_pool".
+  tvm::ffi::reflection::GlobalDef().def("relax.op.vision.roi_pool", roi_pool);
+}
+
+/*!
+ * \brief Infer the output struct info of `relax.vision.roi_pool`.
+ *
+ * Validates that the call has two tensor arguments (4-D data, 2-D rois whose
+ * static last dimension is 5), that the layout is NCHW and that `pooled_size`
+ * holds exactly two entries. Every violation is reported as a fatal diagnostic
+ * attached to the call.
+ *
+ * \param call The roi_pool call being inferred.
+ * \param ctx The block builder used for error reporting.
+ * \return A 4-D tensor struct info with the dtype and vdevice of the data. When both input
+ *         shapes are ShapeExprs the shape is
+ *         (num_roi, channel, pooled_size[0], pooled_size[1]); otherwise only ndim is known.
+ */
+StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects two arguments, while the given number of arguments is "
+                     << call->args.size());
+  }
+
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* rois_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects the input data to be a Tensor, while the given data is "
+                     << call->args[0]->GetTypeKey());
+  }
+  if (rois_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects the rois to be a Tensor, while the given rois is "
+                     << call->args[1]->GetTypeKey());
+  }
+  if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim != 4) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects the input data to be 4-D, while the given data has ndim "
+                     << data_sinfo->ndim);
+  }
+  if (!rois_sinfo->IsUnknownNdim() && rois_sinfo->ndim != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects the rois tensor to be 2-D, while the given rois has ndim "
+                     << rois_sinfo->ndim);
+  }
+
+  const auto* attrs = call->attrs.as<ROIPoolAttrs>();
+  TVM_FFI_ICHECK(attrs != nullptr) << "Invalid ROIPool attrs";
+  if (attrs->layout != "NCHW") {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool only supports NCHW layout, but got " << attrs->layout);
+  }
+  // Attrs that were not produced by roi_pool() are not guaranteed to hold a pair, so check
+  // the length here before pooled_size[0] / pooled_size[1] are read below.
+  if (attrs->pooled_size.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIPool expects pooled_size to have exactly 2 elements, but got "
+                     << attrs->pooled_size.size());
+  }
+
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  const auto* rois_shape = rois_sinfo->shape.as<ShapeExprNode>();
+  if (rois_shape != nullptr) {
+    // Only a statically known last dimension can be validated.
+    const auto* last_dim = rois_shape->values[1].as<IntImmNode>();
+    if (last_dim != nullptr && last_dim->value != 5) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "ROIPool expects rois to have shape (num_roi, 5), but got last "
+                          "dimension "
+                       << last_dim->value);
+    }
+  }
+
+  if (data_shape == nullptr || rois_shape == nullptr) {
+    // Without both shapes only the rank of the result is known.
+    return TensorStructInfo(data_sinfo->dtype, 4, data_sinfo->vdevice);
+  }
+
+  ffi::Array<PrimExpr> out_shape = {rois_shape->values[0], data_shape->values[1],
+                                    Integer(attrs->pooled_size[0]),
+                                    Integer(attrs->pooled_size[1])};
+  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.vision.roi_pool")  // same name that roi_pool() resolves through Op::Get
+ .set_attrs_type<ROIPoolAttrs>()
+ .set_num_inputs(2)  // data and rois
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("rois", "Tensor",
+ "The input rois with shape (num_roi, 5) in [batch_idx, x1,
y1, x2, y2] format.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoROIPool)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/op/vision/roi_pool.h b/src/relax/op/vision/roi_pool.h
new file mode 100644
index 0000000000..738dbee0d8
--- /dev/null
+++ b/src/relax/op/vision/roi_pool.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file roi_pool.h
+ * \brief The functions to make Relax ROI Pool operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_VISION_ROI_POOL_H_
+#define TVM_RELAX_OP_VISION_ROI_POOL_H_
+
+#include <tvm/relax/attrs/vision.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief ROI Pool operator.
+ * \param data The input tensor.
+ * \param rois The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.
+ * \param pooled_size Output size of roi pool.
+ * \param spatial_scale Ratio of input feature map height (or width) to raw image height
+ *        (or width).
+ * \param layout Dimension ordering of the input data.
+ * \return The resulting expression.
+ */
+Expr roi_pool(Expr data, Expr rois, ffi::Array<int64_t> pooled_size, double
spatial_scale,
+ ffi::String layout);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_OP_VISION_ROI_POOL_H_
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index c6b4df6aaa..0fb5f3003a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4423,5 +4423,51 @@ def test_if_nested():
)
[email protected](
+ ("pooled_shape", "rois"),
+ [
+ ((1, 1), np.array([[0.0, 1.0, 1.0, 6.0, 6.0], [0.0, 0.0, 0.0, 7.0,
7.0]], dtype="float32")),
+ (
+ (2, 3),
+ np.array([[0.0, 1.2, 0.5, 6.8, 7.0], [0.0, -1.0, 2.0, 3.5, 5.2]],
dtype="float32"),
+ ),
+ (
+ (2, 2),
+ np.array(
+ [[0.0, 100.0, 100.0, 110.0, 110.0], [0.0, 1.0, 1.0, 6.0,
6.0]], dtype="float32"
+ ),
+ ),
+ ],
+)
+def test_max_roi_pool(pooled_shape, rois):
+ x_shape = [1, 4, 8, 8]
+ out_shape = [2, 4, pooled_shape[0], pooled_shape[1]]
+
+ node = helper.make_node(
+ "MaxRoiPool",
+ inputs=["X", "rois"],
+ outputs=["Y"],
+ pooled_shape=pooled_shape,
+ spatial_scale=1.0,
+ )
+
+ graph = helper.make_graph(
+ [node],
+ "max_roi_pool_test",
+ inputs=[
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+ helper.make_tensor_value_info("rois", TensorProto.FLOAT, [2, 5]),
+ ],
+ outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT,
out_shape)],
+ )
+
+ model = helper.make_model(graph, producer_name="max_roi_pool_test")
+ inputs = {
+ "X": rg.standard_normal(size=x_shape).astype("float32"),
+ "rois": rois,
+ }
+ check_correctness(model, inputs=inputs, opset=16, rtol=1e-5, atol=1e-5)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_op_vision.py
b/tests/python/relax/test_op_vision.py
index 6d04a796ca..b597b325f4 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -1050,6 +1050,96 @@ def test_nms_e2e_index_remap():
np.testing.assert_array_equal(ref_valid_box_count, np.array([[3]],
dtype="int32"))
+def test_roi_pool_op_correctness():
+ x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ rois = relax.Var("rois", R.Tensor((4, 5), "float32"))
+ assert relax.op.vision.roi_pool(x, rois, (7, 7), 1.0).op ==
Op.get("relax.vision.roi_pool")
+
+
+def test_roi_pool_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ rois = relax.Var("rois", R.Tensor((5, 5), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.vision.roi_pool(x, rois, (7, 5), 0.25),
+ relax.TensorStructInfo((5, 3, 7, 5), "float32"),
+ )
+
+
+def test_roi_pool_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ n = tirx.Var("n", "int64")
+ c = tirx.Var("c", "int64")
+ h = tirx.Var("h", "int64")
+ w = tirx.Var("w", "int64")
+ num_roi = tirx.Var("num_roi", "int64")
+
+ x = relax.Var("x", R.Tensor((n, c, h, w), "float32"))
+ rois = relax.Var("rois", R.Tensor((num_roi, 5), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.vision.roi_pool(x, rois, (7, 7), 0.5),
+ relax.TensorStructInfo((num_roi, c, 7, 7), "float32"),
+ )
+
+
+def test_roi_pool_wrong_input_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ rois0 = relax.Var("rois", R.Tensor((4,), "float32"))
+ rois1 = relax.Var("rois", R.Tensor((4, 5), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.roi_pool(x0, rois1, (7, 7), 1.0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.roi_pool(x1, rois0, (7, 7), 1.0))
+
+
+def test_roi_pool_wrong_rois_last_dim():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ rois = relax.Var("rois", R.Tensor((4, 4), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.roi_pool(x, rois, (7, 7), 1.0))
+
+
+def test_roi_pool_wrong_layout():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ rois = relax.Var("rois", R.Tensor((4, 5), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.roi_pool(x, rois, (7, 7), 1.0,
layout="NHWC"))
+
+
+def test_roi_pool_legalize():
+ @tvm.script.ir_module
+ class ROIPool:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 8, 8), "float32"),
+ rois: R.Tensor((2, 5), "float32"),
+ ) -> R.Tensor((2, 2, 3, 2), "float32"):
+ gv: R.Tensor((2, 2, 3, 2), "float32") = R.vision.roi_pool(
+ x,
+ rois,
+ pooled_size=(3, 2),
+ spatial_scale=1.0,
+ layout="NCHW",
+ )
+ return gv
+
+ mod = LegalizeOps()(ROIPool)
+ assert "call_tir" in str(mod)
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TensorStructInfo((2, 2, 3, 2), "float32"),
+ )
def test_all_class_non_max_suppression_infer_struct_info():
bb = relax.BlockBuilder()
batch_size, num_classes, num_boxes = 10, 8, 5
@@ -1201,12 +1291,9 @@ def test_multibox_transform_loc_op_correctness():
cls = relax.Var("cls", R.Tensor((1, 5, 10), "float32"))
loc = relax.Var("loc", R.Tensor((1, 40), "float32"))
anc = relax.Var("anc", R.Tensor((1, 10, 4), "float32"))
- assert (
- relax.op.vision.multibox_transform_loc(
- cls, loc, anc, False, 0.0, (1.0, 1.0, 1.0, 1.0), True
- ).op
- == Op.get("relax.vision.multibox_transform_loc")
- )
+ assert relax.op.vision.multibox_transform_loc(
+ cls, loc, anc, False, 0.0, (1.0, 1.0, 1.0, 1.0), True
+ ).op == Op.get("relax.vision.multibox_transform_loc")
def test_multibox_transform_loc_infer_struct_info():