This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3eb86f78ed [Relax][TOPI] Add relax.vision.multibox_transform_loc for
SSD/TFLite box decode (#18942)
3eb86f78ed is described below
commit 3eb86f78ed1bdb2111118924d16e92bf2d1b054d
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Sat Mar 28 12:20:22 2026 +0800
[Relax][TOPI] Add relax.vision.multibox_transform_loc for SSD/TFLite box
decode (#18942)
Introduce relax.vision.multibox_transform_loc with
MultiboxTransformLocAttrs. The op decodes center-size offsets against
ltrb (left, top, right, bottom) priors, applies softmax to the class
logits, and optionally clips boxes, masks scores below a threshold, and
zeroes the background score. Register the C++ op with FInferStructInfo
checks for shapes and dtypes (including batch and 4*N consistency), and
legalize it to topi.vision.multibox_transform_loc.
Add tests for struct inference, invalid inputs, Legalize+e2e on LLVM,
attribute branches, and TVMScript roundtrip. Add a standalone numpy
reference (_multibox_ref_numpy, defined locally in
tests/python/relax/test_op_vision.py and not exported from
tvm.topi.testing, to avoid pulling scipy).
Update the TFLite frontend NotImplementedError messages for
DETECTION_POSTPROCESS and NON_MAX_SUPPRESSION_V5 to link tracking issue
#18928; the DETECTION_POSTPROCESS message also notes that
relax.vision.multibox_transform_loc is now available.
---
include/tvm/relax/attrs/vision.h | 24 ++
.../tvm/relax/frontend/tflite/tflite_frontend.py | 12 +-
python/tvm/relax/op/__init__.py | 2 +-
python/tvm/relax/op/op_attrs.py | 5 +
python/tvm/relax/op/vision/__init__.py | 1 +
.../tvm/relax/op/vision/multibox_transform_loc.py | 85 +++++++
python/tvm/relax/transform/legalize_ops/vision.py | 24 ++
python/tvm/topi/vision/__init__.py | 1 +
python/tvm/topi/vision/multibox_transform_loc.py | 121 +++++++++
src/relax/op/vision/multibox_transform_loc.cc | 204 +++++++++++++++
src/relax/op/vision/multibox_transform_loc.h | 42 +++
tests/python/relax/test_op_vision.py | 283 +++++++++++++++++++++
.../relax/test_tvmscript_parser_op_vision.py | 42 +++
13 files changed, 839 insertions(+), 7 deletions(-)
diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 59a1dd7314..4e3351bb90 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -73,6 +73,30 @@ struct ROIAlignAttrs : public
AttrsNodeReflAdapter<ROIAlignAttrs> {
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs",
ROIAlignAttrs, BaseAttrsNode);
}; // struct ROIAlignAttrs
+/*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box
decode). */
+struct MultiboxTransformLocAttrs : public
AttrsNodeReflAdapter<MultiboxTransformLocAttrs> {
+ bool clip;
+ double threshold;
+ ffi::Array<double> variances;
+ bool keep_background;
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<MultiboxTransformLocAttrs>()
+ .def_ro("clip", &MultiboxTransformLocAttrs::clip,
+ "Clip decoded ymin,xmin,ymax,xmax to [0,1].")
+ .def_ro("threshold", &MultiboxTransformLocAttrs::threshold,
+ "After softmax, zero scores strictly below this value.")
+ .def_ro("variances", &MultiboxTransformLocAttrs::variances,
+ "(x,y,w,h) scales = TFLite
1/x_scale,1/y_scale,1/w_scale,1/h_scale on "
+ "encodings. Very large w/h scales can overflow exp in decode.")
+ .def_ro("keep_background", &MultiboxTransformLocAttrs::keep_background,
+ "If false, force output scores[:,0,:] to 0 (background
class).");
+ }
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MultiboxTransformLocAttrs",
+ MultiboxTransformLocAttrs, BaseAttrsNode);
+}; // struct MultiboxTransformLocAttrs
+
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 5c73af18ad..435180dfee 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -3205,9 +3205,10 @@ class OperatorConverter:
def convert_detection_postprocess(self, op):
"""Convert TFLite_Detection_PostProcess"""
raise NotImplementedError(
- "DETECTION_POSTPROCESS requires vision ops
(multibox_transform_loc, "
- "non_max_suppression, get_valid_counts) not yet available in
Relax. "
- "See https://github.com/apache/tvm/issues/XXXX"
+ "DETECTION_POSTPROCESS is not wired in this frontend yet: it still
needs "
+ "Relax NMS / get_valid_counts / related vision helpers (see dead
code below). "
+ "relax.vision.multibox_transform_loc exists; tracking: "
+ "https://github.com/apache/tvm/issues/18928"
)
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
custom_options = FlexBufferDecoder(flexbuffer).decode()
@@ -3340,9 +3341,8 @@ class OperatorConverter:
"""Convert TFLite NonMaxSuppressionV5"""
#
https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v5
raise NotImplementedError(
- "NON_MAX_SUPPRESSION_V5 requires vision ops (get_valid_counts, "
- "non_max_suppression) not yet available in Relax. "
- "See https://github.com/apache/tvm/issues/XXXX"
+ "NON_MAX_SUPPRESSION_V5 is not wired in this frontend yet (needs
get_valid_counts, "
+ "non_max_suppression, etc.). Tracking:
https://github.com/apache/tvm/issues/18928"
)
input_tensors = self.get_input_tensors(op)
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 0bc3f65784..ee1a2c2420 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -157,7 +157,7 @@ from .unary import (
tanh,
trunc,
)
-from .vision import all_class_non_max_suppression, roi_align
+from .vision import all_class_non_max_suppression, multibox_transform_loc,
roi_align
def _register_op_make():
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index a3b6544dcc..e8c91f04b4 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -251,6 +251,11 @@ class ROIAlignAttrs(Attrs):
"""Attributes for vision.roi_align"""
+@tvm_ffi.register_object("relax.attrs.MultiboxTransformLocAttrs")
+class MultiboxTransformLocAttrs(Attrs):
+ """Attributes for vision.multibox_transform_loc"""
+
+
@tvm_ffi.register_object("relax.attrs.Conv1DAttrs")
class Conv1DAttrs(Attrs):
"""Attributes for nn.conv1d"""
diff --git a/python/tvm/relax/op/vision/__init__.py
b/python/tvm/relax/op/vision/__init__.py
index 76d9ea35a1..58266c5b2a 100644
--- a/python/tvm/relax/op/vision/__init__.py
+++ b/python/tvm/relax/op/vision/__init__.py
@@ -17,5 +17,6 @@
# under the License.
"""VISION operators."""
+from .multibox_transform_loc import *
from .nms import *
from .roi_align import *
diff --git a/python/tvm/relax/op/vision/multibox_transform_loc.py
b/python/tvm/relax/op/vision/multibox_transform_loc.py
new file mode 100644
index 0000000000..6830b1dc63
--- /dev/null
+++ b/python/tvm/relax/op/vision/multibox_transform_loc.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Multibox location transform for object detection."""
+
+from . import _ffi_api
+
+
+def multibox_transform_loc(
+ cls_pred,
+ loc_pred,
+ anchor,
+ clip=False,
+ threshold=0.0,
+ variances=(1.0, 1.0, 1.0, 1.0),
+ keep_background=True,
+):
+ """SSD / TFLite-style decode: priors + offsets → boxes; logits → softmax
scores.
+
+ Box decode follows TFLite ``DecodeCenterSizeBoxes``; expected tensor
layout matches
+ ``tflite_frontend.convert_detection_postprocess`` (loc reorder yxhw→xywh,
anchor ltrb).
+
+ Parameters
+ ----------
+ cls_pred : relax.Expr
+ ``[B, C, N]`` class logits (pre-softmax).
+ loc_pred : relax.Expr
+ ``[B, 4*N]`` per-anchor encodings as ``(x,y,w,h)`` after reorder (see
above).
+ anchor : relax.Expr
+ ``[1, N, 4]`` priors: ``(left, top, right, bottom)``.
+ clip : bool
+ If True, clip ``ymin,xmin,ymax,xmax`` to ``[0, 1]``.
+ threshold : float
+ After softmax, multiply scores by mask ``(score >= threshold)``.
+ variances : tuple of 4 floats
+ ``(x,y,w,h)`` = TFLite ``1/x_scale, 1/y_scale, 1/w_scale, 1/h_scale``.
+ Use magnitudes consistent with the model: very large ``w``/``h``
entries scale the
+ encoded height/width terms inside ``exp(...)`` and can overflow in
float32/float16.
+ keep_background : bool
+ If False, set output scores at class index 0 to zero.
+
+ Returns
+ -------
+ result : relax.Expr
+ Tuple ``(boxes, scores)``: ``boxes`` is ``[B, N, 4]`` as
``(ymin,xmin,ymax,xmax)``;
+ ``scores`` is ``[B, C, N]`` softmax, post-processed like the
implementation.
+
+ Notes
+ -----
+ **Shape/dtype (checked in ``FInferStructInfo`` when static):**
+
+ - ``cls_pred``: 3-D; ``loc_pred``: 2-D; ``anchor``: 3-D.
+ - ``cls_pred``, ``loc_pred``, ``anchor`` dtypes must match.
+ - ``N = cls_pred.shape[2]``; ``loc_pred.shape[1] == 4*N``; ``anchor.shape
== [1,N,4]``.
+ - ``loc_pred.shape[1]`` must be divisible by 4.
+ - ``cls_pred.shape[0]`` must equal ``loc_pred.shape[0]`` (batch).
+
+ If ``cls_pred`` has **unknown** shape, inference only returns generic
rank-3 tensor
+ struct info for the two outputs; it does **not** verify ``4*N`` vs
``loc_pred`` or
+ ``anchor.shape[1]`` vs ``N``, because ``N`` is not available statically.
Other checks
+ (ranks, dtypes, ``loc_pred.shape[1] % 4 == 0`` when known, batch match
when both batch
+ axes are known, etc.) still run where applicable.
+ """
+ return _ffi_api.multibox_transform_loc(
+ cls_pred,
+ loc_pred,
+ anchor,
+ clip,
+ threshold,
+ variances,
+ keep_background,
+ )
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py
b/python/tvm/relax/transform/legalize_ops/vision.py
index 7a1e305f39..28367a67a3 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -118,3 +118,27 @@ def _roi_align(bb: BlockBuilder, call: Call) -> Expr:
aligned=call.attrs.aligned,
layout=call.attrs.layout,
)
+
+
+@register_legalize("relax.vision.multibox_transform_loc")
+def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr:
+ variances = tuple(float(x) for x in call.attrs.variances)
+
+ def _te(cls_pred, loc_pred, anchor):
+ return topi.vision.multibox_transform_loc(
+ cls_pred,
+ loc_pred,
+ anchor,
+ variances,
+ clip=call.attrs.clip,
+ threshold=call.attrs.threshold,
+ keep_background=call.attrs.keep_background,
+ )
+
+ return bb.call_te(
+ _te,
+ call.args[0],
+ call.args[1],
+ call.args[2],
+ primfunc_name_hint="multibox_transform_loc",
+ )
diff --git a/python/tvm/topi/vision/__init__.py
b/python/tvm/topi/vision/__init__.py
index 75725a8a4b..cb0467c98c 100644
--- a/python/tvm/topi/vision/__init__.py
+++ b/python/tvm/topi/vision/__init__.py
@@ -17,5 +17,6 @@
# under the License.
"""Vision operators."""
+from .multibox_transform_loc import *
from .nms import *
from .roi_align import *
diff --git a/python/tvm/topi/vision/multibox_transform_loc.py
b/python/tvm/topi/vision/multibox_transform_loc.py
new file mode 100644
index 0000000000..ab965e7981
--- /dev/null
+++ b/python/tvm/topi/vision/multibox_transform_loc.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Multibox location transform (SSD / TFLite DetectionPostProcess decode)."""
+
+import tvm
+from tvm import te, topi
+
+
+def multibox_transform_loc(
+ cls_pred,
+ loc_pred,
+ anchor,
+ variances,
+ clip=False,
+ threshold=0.0,
+ keep_background=True,
+):
+ """TFLite ``DecodeCenterSizeBoxes``-style decode + softmax score
post-process.
+
+ Inputs must match Relax op contracts: ``cls_pred [B,C,N]``, ``loc_pred
[B,4*N]``,
+ ``anchor [1,N,4]`` ltrb; per-anchor loc order ``(x,y,w,h)`` after
yxhw→xywh reorder.
+
+ Parameters
+ ----------
+ cls_pred : te.Tensor
+ ``[B, C, N]`` logits.
+ loc_pred : te.Tensor
+ ``[B, 4*N]`` encodings ``(x,y,w,h)`` per anchor.
+ anchor : te.Tensor
+ ``[1, N, 4]`` ``(left, top, right, bottom)``.
+ variances : tuple of 4 float
+ ``(x,y,w,h)`` = ``1/x_scale, 1/y_scale, 1/w_scale, 1/h_scale``
(TFLite).
+ clip : bool
+ Clip ``ymin,xmin,ymax,xmax`` to ``[0,1]``.
+ threshold : float
+ After softmax: ``scores *= (scores >= threshold)``.
+ keep_background : bool
+ If False: ``scores[:,0,:] = 0``.
+
+ Returns
+ -------
+ boxes : te.Tensor
+ ``[B, N, 4]`` as ``(ymin,xmin,ymax,xmax)``.
+ scores : te.Tensor
+ ``[B, C, N]`` softmax, then threshold mask and optional background
zero.
+ """
+ dtype = cls_pred.dtype
+ B = cls_pred.shape[0]
+ num_anchors = cls_pred.shape[2]
+ loc_reshaped = topi.reshape(loc_pred, [B, num_anchors, 4])
+
+ vx = tvm.tirx.const(float(variances[0]), dtype)
+ vy = tvm.tirx.const(float(variances[1]), dtype)
+ vw = tvm.tirx.const(float(variances[2]), dtype)
+ vh = tvm.tirx.const(float(variances[3]), dtype)
+ half = tvm.tirx.const(0.5, dtype)
+ zero = tvm.tirx.const(0.0, dtype)
+ one = tvm.tirx.const(1.0, dtype)
+ th = tvm.tirx.const(float(threshold), dtype)
+
+ def decode_bbox(b, a, k):
+ l = anchor[0, a, 0]
+ t = anchor[0, a, 1]
+ r = anchor[0, a, 2]
+ br = anchor[0, a, 3]
+ ay = (t + br) * half
+ ax = (l + r) * half
+ ah = br - t
+ aw = r - l
+ ex = loc_reshaped[b, a, 0]
+ ey = loc_reshaped[b, a, 1]
+ ew = loc_reshaped[b, a, 2]
+ eh = loc_reshaped[b, a, 3]
+ ycenter = ey * vy * ah + ay
+ xcenter = ex * vx * aw + ax
+ half_h = half * te.exp(eh * vh) * ah
+ half_w = half * te.exp(ew * vw) * aw
+ ymin = ycenter - half_h
+ xmin = xcenter - half_w
+ ymax = ycenter + half_h
+ xmax = xcenter + half_w
+ if clip:
+ ymin = te.max(zero, te.min(one, ymin))
+ xmin = te.max(zero, te.min(one, xmin))
+ ymax = te.max(zero, te.min(one, ymax))
+ xmax = te.max(zero, te.min(one, xmax))
+ return tvm.tirx.Select(
+ k == 0,
+ ymin,
+ tvm.tirx.Select(k == 1, xmin, tvm.tirx.Select(k == 2, ymax, xmax)),
+ )
+
+ boxes = te.compute((B, num_anchors, 4), decode_bbox, name="multibox_boxes")
+
+ scores = topi.nn.softmax(cls_pred, axis=1)
+ mask = topi.cast(topi.greater_equal(scores, th), dtype)
+ scores = scores * mask
+ if not keep_background:
+
+ def zero_bg(b, c, n):
+ s = scores[b, c, n]
+ return te.if_then_else(c == 0, zero, s)
+
+ scores = te.compute(scores.shape, zero_bg, name="multibox_scores_bg")
+
+ return [boxes, scores]
diff --git a/src/relax/op/vision/multibox_transform_loc.cc
b/src/relax/op/vision/multibox_transform_loc.cc
new file mode 100644
index 0000000000..e01e569b78
--- /dev/null
+++ b/src/relax/op/vision/multibox_transform_loc.cc
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file multibox_transform_loc.cc
+ * \brief Multibox transform (location decode) for object detection.
+ */
+
+#include "multibox_transform_loc.h"
+
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/relax/struct_info.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+TVM_FFI_STATIC_INIT_BLOCK() { MultiboxTransformLocAttrs::RegisterReflection();
}
+
+Expr multibox_transform_loc(Expr cls_pred, Expr loc_pred, Expr anchor, bool
clip, double threshold,
+ ffi::Array<double> variances, bool
keep_background) {
+ TVM_FFI_ICHECK_EQ(variances.size(), 4)
+ << "multibox_transform_loc: variances must be length 4 (x,y,w,h), got "
<< variances.size();
+
+ auto attrs = ffi::make_object<MultiboxTransformLocAttrs>();
+ attrs->clip = clip;
+ attrs->threshold = threshold;
+ attrs->variances = std::move(variances);
+ attrs->keep_background = keep_background;
+
+ static const Op& op = Op::Get("relax.vision.multibox_transform_loc");
+ return Call(op, {std::move(cls_pred), std::move(loc_pred),
std::move(anchor)}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("relax.op.vision.multibox_transform_loc",
multibox_transform_loc);
+}
+
+/*!
+ * \brief Infer struct info for relax.vision.multibox_transform_loc.
+ *
+ * \note Shape cross-checks that need the anchor count N (e.g.
loc_pred.shape[1] == 4*N,
+ * anchor.shape[1] == N with N = cls_pred.shape[2]) run only when cls_pred has
a known
+ * static shape. If cls_pred shape is unknown, inference returns generic
rank-3 outputs and
+ * skips those N-based relations; other checks (ndim, dtype, loc dim divisible
by 4, etc.)
+ * still apply when their inputs are known.
+ */
+StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const
BlockBuilder& ctx) {
+ if (call->args.size() != 3) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: expected 3 inputs (cls_pred,
loc_pred, anchor), "
+ "got "
+ << call->args.size());
+ }
+
+ ffi::Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call,
ctx);
+ const auto cls_sinfo = input_sinfo[0];
+ const auto loc_sinfo = input_sinfo[1];
+ const auto anchor_sinfo = input_sinfo[2];
+
+ if (!cls_sinfo->IsUnknownNdim() && cls_sinfo->ndim != 3) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: cls_pred must be 3-D [B,
num_classes, N], got "
+ "ndim "
+ << cls_sinfo->ndim);
+ }
+ if (!loc_sinfo->IsUnknownNdim() && loc_sinfo->ndim != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: loc_pred must be 2-D [B,
4*N], got ndim "
+ << loc_sinfo->ndim);
+ }
+ if (!anchor_sinfo->IsUnknownNdim() && anchor_sinfo->ndim != 3) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: anchor must be 3-D [1, N, 4]
ltrb, got ndim "
+ << anchor_sinfo->ndim);
+ }
+
+ if (!cls_sinfo->IsUnknownDtype() && !loc_sinfo->IsUnknownDtype() &&
+ cls_sinfo->dtype != loc_sinfo->dtype) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: cls_pred and loc_pred dtype
must match, got "
+ << cls_sinfo->dtype << " vs " << loc_sinfo->dtype);
+ }
+ if (!cls_sinfo->IsUnknownDtype() && !anchor_sinfo->IsUnknownDtype() &&
+ cls_sinfo->dtype != anchor_sinfo->dtype) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: cls_pred and anchor dtype
must match, got "
+ << cls_sinfo->dtype << " vs " << anchor_sinfo->dtype);
+ }
+
+ auto vdev = cls_sinfo->vdevice;
+ const auto* cls_shape = cls_sinfo->shape.as<ShapeExprNode>();
+ const auto* loc_shape = loc_sinfo->shape.as<ShapeExprNode>();
+ const auto* anchor_shape = anchor_sinfo->shape.as<ShapeExprNode>();
+
+ if (loc_shape != nullptr) {
+ const auto* loc_dim1 = loc_shape->values[1].as<IntImmNode>();
+ if (loc_dim1 != nullptr && loc_dim1->value % 4 != 0) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: loc_pred.shape[1] must be
divisible by 4, got "
+ << loc_dim1->value);
+ }
+ }
+
+ if (cls_shape != nullptr && loc_shape != nullptr) {
+ const auto* cls_b = cls_shape->values[0].as<IntImmNode>();
+ const auto* loc_b = loc_shape->values[0].as<IntImmNode>();
+ if (cls_b != nullptr && loc_b != nullptr && cls_b->value != loc_b->value) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: cls_pred.shape[0] must
match loc_pred.shape[0], "
+ "got B="
+ << cls_b->value << " vs " << loc_b->value);
+ }
+ }
+
+ if (anchor_shape != nullptr) {
+ const auto* anchor_batch = anchor_shape->values[0].as<IntImmNode>();
+ if (anchor_batch != nullptr && anchor_batch->value != 1) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: anchor.shape[0] must be 1,
got "
+ << anchor_batch->value);
+ }
+ const auto* anchor_last = anchor_shape->values[2].as<IntImmNode>();
+ if (anchor_last != nullptr && anchor_last->value != 4) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: anchor.shape[2] must be 4
(ltrb), got "
+ << anchor_last->value);
+ }
+ }
+
+ if (cls_shape == nullptr) {
+ ffi::Array<StructInfo> fields = {TensorStructInfo(cls_sinfo->dtype, 3,
vdev),
+ TensorStructInfo(cls_sinfo->dtype, 3,
vdev)};
+ return TupleStructInfo(fields);
+ }
+
+ const auto& batch = cls_shape->values[0];
+ const auto& num_classes = cls_shape->values[1];
+ const auto& num_anchors = cls_shape->values[2];
+
+ if (loc_shape != nullptr) {
+ const auto* num_anchors_imm = num_anchors.as<IntImmNode>();
+ const auto* loc_dim1 = loc_shape->values[1].as<IntImmNode>();
+ if (num_anchors_imm != nullptr && loc_dim1 != nullptr &&
+ loc_dim1->value != num_anchors_imm->value * 4) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: loc_pred.shape[1] must
equal 4*N with "
+ "N=cls_pred.shape[2], got loc_dim="
+ << loc_dim1->value << ", N=" << num_anchors_imm->value);
+ }
+ }
+ if (anchor_shape != nullptr) {
+ const auto* num_anchors_imm = num_anchors.as<IntImmNode>();
+ const auto* anchor_num_anchors = anchor_shape->values[1].as<IntImmNode>();
+ if (num_anchors_imm != nullptr && anchor_num_anchors != nullptr &&
+ anchor_num_anchors->value != num_anchors_imm->value) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "multibox_transform_loc: anchor.shape[1] must equal
N=cls_pred.shape[2], "
+ "got anchor_N="
+ << anchor_num_anchors->value << ", N=" <<
num_anchors_imm->value);
+ }
+ }
+
+ ffi::Array<PrimExpr> boxes_shape = {batch, num_anchors, Integer(4)};
+ ffi::Array<PrimExpr> scores_shape = {batch, num_classes, num_anchors};
+ ffi::Array<StructInfo> fields = {
+ TensorStructInfo(ShapeExpr(boxes_shape), cls_sinfo->dtype, vdev),
+ TensorStructInfo(ShapeExpr(scores_shape), cls_sinfo->dtype, vdev)};
+ return TupleStructInfo(fields);
+}
+
+TVM_REGISTER_OP("relax.vision.multibox_transform_loc")
+ .describe("Decode SSD/TFLite-style priors and offsets into boxes and
softmax scores. If "
+ "cls_pred shape is unknown, N-based loc/anchor shape checks are
skipped in "
+ "inference. Very large variances (w,h) can overflow exp in half
box sizes.")
+ .set_attrs_type<MultiboxTransformLocAttrs>()
+ .set_num_inputs(3)
+ .add_argument("cls_pred", "Tensor", "[B,C,N] class logits (pre-softmax).")
+ .add_argument("loc_pred", "Tensor",
+ "[B,4*N] box encodings (x,y,w,h); TFLite yxhw order remapped
to xywh.")
+ .add_argument("anchor", "Tensor", "[1,N,4] priors as ltrb
(left,top,right,bottom).")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoMultiboxTransformLoc)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/op/vision/multibox_transform_loc.h
b/src/relax/op/vision/multibox_transform_loc.h
new file mode 100644
index 0000000000..726bc4c0e5
--- /dev/null
+++ b/src/relax/op/vision/multibox_transform_loc.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file multibox_transform_loc.h
+ * \brief The functions to make Relax multibox_transform_loc operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_
+#define TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_
+
+#include <tvm/relax/attrs/vision.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Decode SSD box encodings and prepare class scores
(TFLite-compatible). */
+Expr multibox_transform_loc(Expr cls_pred, Expr loc_pred, Expr anchor, bool
clip, double threshold,
+ ffi::Array<double> variances, bool
keep_background);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_
diff --git a/tests/python/relax/test_op_vision.py
b/tests/python/relax/test_op_vision.py
index b902518b49..cded9f5f29 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -286,6 +286,7 @@ def
test_all_class_non_max_suppression_legalize_dynamic_trim():
)
[email protected]_llvm
def test_all_class_non_max_suppression_legalize_e2e():
@tvm.script.ir_module
class NMSModule:
@@ -344,5 +345,287 @@ def test_all_class_non_max_suppression_legalize_e2e():
tvm.testing.assert_allclose(selected_indices.shape, (num_total_detections,
3))
+def test_multibox_transform_loc_op_correctness():
+ cls = relax.Var("cls", R.Tensor((1, 5, 10), "float32"))
+ loc = relax.Var("loc", R.Tensor((1, 40), "float32"))
+ anc = relax.Var("anc", R.Tensor((1, 10, 4), "float32"))
+ assert (
+ relax.op.vision.multibox_transform_loc(
+ cls, loc, anc, False, 0.0, (1.0, 1.0, 1.0, 1.0), True
+ ).op
+ == Op.get("relax.vision.multibox_transform_loc")
+ )
+
+
+def test_multibox_transform_loc_infer_struct_info():
+ bb = relax.BlockBuilder()
+ cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+ loc = relax.Var("loc", R.Tensor((2, 20), "float32"))
+ anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+ _check_inference(
+ bb,
+ relax.op.vision.multibox_transform_loc(
+ cls, loc, anc, False, 0.0, (0.1, 0.1, 0.2, 0.2), True
+ ),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((2, 5, 4), "float32"),
+ relax.TensorStructInfo((2, 3, 5), "float32"),
+ ]
+ ),
+ )
+
+
+def test_multibox_transform_loc_wrong_cls_ndim():
+ bb = relax.BlockBuilder()
+ cls = relax.Var("cls", R.Tensor((2, 3), "float32"))
+ loc = relax.Var("loc", R.Tensor((2, 20), "float32"))
+ anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc))
+
+
+def test_multibox_transform_loc_wrong_shape_relation():
+ bb = relax.BlockBuilder()
+ cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+ anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+ loc_bad_div = relax.Var("loc_bad_div", R.Tensor((2, 19), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc_bad_div,
anc))
+ # Divisible by 4 but loc_dim != 4*N (N=5 -> expect 20, not 24)
+ loc_bad_n = relax.Var("loc_bad_n", R.Tensor((2, 24), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc_bad_n,
anc))
+
+
+def test_multibox_transform_loc_wrong_anchor_shape():
+ bb = relax.BlockBuilder()
+ cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+ loc = relax.Var("loc", R.Tensor((2, 20), "float32"))
+ anc_bad_batch = relax.Var("anc_bad_batch", R.Tensor((2, 5, 4), "float32"))
+ anc_bad_last = relax.Var("anc_bad_last", R.Tensor((1, 5, 5), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc,
anc_bad_batch))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc,
anc_bad_last))
+
+
+def test_multibox_transform_loc_wrong_dtype():
+ bb = relax.BlockBuilder()
+ cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+ loc = relax.Var("loc", R.Tensor((2, 20), "float16"))
+ anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc))
+
+
+def test_multibox_transform_loc_wrong_batch():
+ bb = relax.BlockBuilder()
+ cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+ loc = relax.Var("loc", R.Tensor((1, 20), "float32"))
+ anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc))
+
+
+def _multibox_ref_numpy(
+ cls_pred, loc_pred, anchor, variances, clip=False, threshold=0.0,
keep_background=True
+):
+ """Numpy reference aligned with ``topi.vision.multibox_transform_loc``."""
+
+ def _softmax(x, axis):
+ x_max = np.max(x, axis=axis, keepdims=True)
+ exp = np.exp(x - x_max)
+ return exp / np.sum(exp, axis=axis, keepdims=True)
+
+ B, C, N = cls_pred.shape
+ loc = loc_pred.reshape(B, N, 4)
+ scores = _softmax(cls_pred.astype("float64"), axis=1).astype(np.float32)
+ if threshold > 0.0:
+ scores = np.where(scores >= threshold, scores, 0.0).astype(np.float32)
+ if not keep_background:
+ scores = scores.copy()
+ scores[:, 0, :] = 0.0
+ vx, vy, vw, vh = variances
+ boxes = np.zeros((B, N, 4), dtype=np.float32)
+ for b in range(B):
+ for a in range(N):
+ l, t, r, br = anchor[0, a, :]
+ ay = (t + br) * 0.5
+ ax = (l + r) * 0.5
+ ah = br - t
+ aw = r - l
+ ex, ey, ew, eh = loc[b, a, :]
+ ycenter = ey * vy * ah + ay
+ xcenter = ex * vx * aw + ax
+ half_h = 0.5 * np.exp(eh * vh) * ah
+ half_w = 0.5 * np.exp(ew * vw) * aw
+ ymin = ycenter - half_h
+ xmin = xcenter - half_w
+ ymax = ycenter + half_h
+ xmax = xcenter + half_w
+ if clip:
+ ymin = np.clip(ymin, 0.0, 1.0)
+ xmin = np.clip(xmin, 0.0, 1.0)
+ ymax = np.clip(ymax, 0.0, 1.0)
+ xmax = np.clip(xmax, 0.0, 1.0)
+ boxes[b, a, :] = (ymin, xmin, ymax, xmax)
+ return boxes, scores
+
+
[email protected]_llvm
+def test_multibox_transform_loc_legalize_e2e():
+ @tvm.script.ir_module
+ class Mod:
+ @R.function
+ def main(
+ cls: R.Tensor((1, 3, 5), "float32"),
+ loc: R.Tensor((1, 20), "float32"),
+ anc: R.Tensor((1, 5, 4), "float32"),
+ ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5),
"float32")):
+ return R.vision.multibox_transform_loc(
+ cls,
+ loc,
+ anc,
+ clip=False,
+ threshold=0.0,
+ variances=(1.0, 1.0, 1.0, 1.0),
+ keep_background=True,
+ )
+
+ cls_data = np.random.randn(1, 3, 5).astype(np.float32)
+ loc_data = np.random.randn(1, 20).astype(np.float32) * 0.05
+ anc_data = np.array(
+ [
+ [
+ [0.1, 0.1, 0.5, 0.5],
+ [0.2, 0.2, 0.6, 0.6],
+ [0.0, 0.0, 1.0, 1.0],
+ [0.3, 0.3, 0.7, 0.7],
+ [0.05, 0.05, 0.45, 0.45],
+ ]
+ ],
+ dtype=np.float32,
+ )
+
+ mod = LegalizeOps()(Mod)
+ exe = tvm.compile(mod, target="llvm")
+ vm = relax.VirtualMachine(exe, tvm.cpu())
+ ref_b, ref_s = _multibox_ref_numpy(cls_data, loc_data, anc_data, (1.0,
1.0, 1.0, 1.0))
+ out = vm["main"](
+ tvm.runtime.tensor(cls_data, tvm.cpu()),
+ tvm.runtime.tensor(loc_data, tvm.cpu()),
+ tvm.runtime.tensor(anc_data, tvm.cpu()),
+ )
+ tvm.testing.assert_allclose(out[0].numpy(), ref_b, rtol=1e-4, atol=1e-5)
+ tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5)
+
+
[email protected]_llvm
+def test_multibox_transform_loc_legalize_e2e_nonunity_variances():
+ @tvm.script.ir_module
+ class Mod:
+ @R.function
+ def main(
+ cls: R.Tensor((1, 3, 5), "float32"),
+ loc: R.Tensor((1, 20), "float32"),
+ anc: R.Tensor((1, 5, 4), "float32"),
+ ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5),
"float32")):
+ return R.vision.multibox_transform_loc(
+ cls,
+ loc,
+ anc,
+ clip=False,
+ threshold=0.0,
+ variances=(0.1, 0.1, 0.2, 0.2),
+ keep_background=True,
+ )
+
+ cls_data = np.random.randn(1, 3, 5).astype(np.float32)
+ loc_data = np.random.randn(1, 20).astype(np.float32) * 0.05
+ anc_data = np.array(
+ [
+ [
+ [0.1, 0.1, 0.5, 0.5],
+ [0.2, 0.2, 0.6, 0.6],
+ [0.0, 0.0, 1.0, 1.0],
+ [0.3, 0.3, 0.7, 0.7],
+ [0.05, 0.05, 0.45, 0.45],
+ ]
+ ],
+ dtype=np.float32,
+ )
+
+ mod = LegalizeOps()(Mod)
+ exe = tvm.compile(mod, target="llvm")
+ vm = relax.VirtualMachine(exe, tvm.cpu())
+ ref_b, ref_s = _multibox_ref_numpy(cls_data, loc_data, anc_data, (0.1,
0.1, 0.2, 0.2))
+ out = vm["main"](
+ tvm.runtime.tensor(cls_data, tvm.cpu()),
+ tvm.runtime.tensor(loc_data, tvm.cpu()),
+ tvm.runtime.tensor(anc_data, tvm.cpu()),
+ )
+ tvm.testing.assert_allclose(out[0].numpy(), ref_b, rtol=1e-4, atol=1e-5)
+ tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5)
+
+
[email protected]_llvm
+def test_multibox_transform_loc_legalize_attr_branches():
+ @tvm.script.ir_module
+ class Mod:
+ @R.function
+ def main(
+ cls: R.Tensor((1, 3, 4), "float32"),
+ loc: R.Tensor((1, 16), "float32"),
+ anc: R.Tensor((1, 4, 4), "float32"),
+ ) -> R.Tuple(R.Tensor((1, 4, 4), "float32"), R.Tensor((1, 3, 4),
"float32")):
+ return R.vision.multibox_transform_loc(
+ cls,
+ loc,
+ anc,
+ clip=True,
+ threshold=0.4,
+ variances=(1.0, 1.0, 1.0, 1.0),
+ keep_background=False,
+ )
+
+ cls_data = np.array(
+ [[[2.0, 0.1, -0.5, 0.0], [0.2, 2.2, 0.3, -1.0], [0.1, 0.4, 2.0, 0.5]]],
+ dtype=np.float32,
+ )
+ loc_data = np.array(
+ [[0.1, -0.2, 0.0, 0.0, -0.2, 0.1, 0.3, -0.1, 0.0, 0.0, 0.8, 0.8, 0.2,
0.2, -0.6, -0.6]],
+ dtype=np.float32,
+ )
+ anc_data = np.array(
+ [[[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.6, 0.6], [0.0, 0.0, 1.0, 1.0],
[0.4, 0.4, 1.2, 1.2]]],
+ dtype=np.float32,
+ )
+
+ mod = LegalizeOps()(Mod)
+ exe = tvm.compile(mod, target="llvm")
+ vm = relax.VirtualMachine(exe, tvm.cpu())
+ ref_b, ref_s = _multibox_ref_numpy(
+ cls_data,
+ loc_data,
+ anc_data,
+ (1.0, 1.0, 1.0, 1.0),
+ clip=True,
+ threshold=0.4,
+ keep_background=False,
+ )
+ out = vm["main"](
+ tvm.runtime.tensor(cls_data, tvm.cpu()),
+ tvm.runtime.tensor(loc_data, tvm.cpu()),
+ tvm.runtime.tensor(anc_data, tvm.cpu()),
+ )
+ boxes = out[0].numpy()
+ scores = out[1].numpy()
+ tvm.testing.assert_allclose(boxes, ref_b, rtol=1e-4, atol=1e-5)
+ tvm.testing.assert_allclose(scores, ref_s, rtol=1e-4, atol=1e-5)
+ assert np.all(boxes >= 0.0) and np.all(boxes <= 1.0)
+ tvm.testing.assert_allclose(scores[:, 0, :], np.zeros_like(scores[:, 0,
:]))
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py
b/tests/python/relax/test_tvmscript_parser_op_vision.py
index c4e8ff0c9d..f053e36744 100644
--- a/tests/python/relax/test_tvmscript_parser_op_vision.py
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -75,6 +75,48 @@ def test_all_class_non_max_suppression():
_check(foo, bb.get()["foo"])
+def test_multibox_transform_loc():
+ @R.function
+ def foo(
+ cls: R.Tensor((1, 3, 5), "float32"),
+ loc: R.Tensor((1, 20), "float32"),
+ anc: R.Tensor((1, 5, 4), "float32"),
+ ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5),
"float32")):
+ gv: R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5),
"float32")) = (
+ R.vision.multibox_transform_loc(
+ cls,
+ loc,
+ anc,
+ clip=False,
+ threshold=0.0,
+ variances=(1.0, 1.0, 1.0, 1.0),
+ keep_background=True,
+ )
+ )
+ return gv
+
+ cls = relax.Var("cls", R.Tensor((1, 3, 5), "float32"))
+ loc = relax.Var("loc", R.Tensor((1, 20), "float32"))
+ anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [cls, loc, anc]):
+ gv = bb.emit(
+ relax.op.vision.multibox_transform_loc(
+ cls,
+ loc,
+ anc,
+ clip=False,
+ threshold=0.0,
+ variances=(1.0, 1.0, 1.0, 1.0),
+ keep_background=True,
+ )
+ )
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
def test_roi_align():
@R.function
def foo(