This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fed71ef6a6 [Relax] Add native size operator (#18667)
fed71ef6a6 is described below
commit fed71ef6a69facc6031144959f191cf70e963a67
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Jan 20 21:34:36 2026 +0800
[Relax] Add native size operator (#18667)
## Why
ONNX models use the Size operator to get the total element count of a
tensor. Relax did not have a native equivalent.
## How
- Adds an R.size(tensor) operator that returns the total number of elements
in a tensor as a scalar (0-d) int64 tensor; a usage sketch follows below
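A minimal usage sketch, adapted from the test added in this change (the module
name `SizeExample` below is illustrative, not part of the patch):

```python
import numpy as np

import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class SizeExample:
    @R.function
    def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((), "int64"):
        # R.size returns the total element count as a 0-d int64 tensor.
        return R.size(x)


# Build and run on CPU; the result is 2 * 3 = 6.
ex = relax.build(SizeExample, tvm.target.Target("llvm"))
vm = relax.VirtualMachine(ex, tvm.cpu())
res = vm["main"](tvm.runtime.tensor(np.zeros((2, 3), "float32")))
assert int(res.numpy()) == 6
```

During legalization, R.size lowers to prod(shape_to_tensor(shape_of(x))), as
registered in inspect_op.py below.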
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 3 +-
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/base.py | 28 ++++++++--
.../tvm/relax/transform/legalize_ops/inspect_op.py | 6 +++
python/tvm/script/ir_builder/relax/ir.py | 2 +
src/relax/op/op.cc | 26 +++++++++
tests/python/relax/test_op_size.py | 63 ++++++++++++++++++++++
7 files changed, 122 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index e14e2ed956..9968eb5ed8 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -911,8 +911,7 @@ class Size(OnnxOpConverter):
@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
- # TODO(tvm-team): add native support for size op
-        return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))
+ return relax.op.size(inputs[0])
class EyeLike(OnnxOpConverter):
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index c6504d79c9..2ebca3811f 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -40,6 +40,7 @@ from .base import (
register_gradient,
shape_of,
shape_to_tensor,
+ size,
tensor_to_shape,
to_vdevice,
)
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index ffa19fbaa0..d46aa883f0 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -634,6 +634,22 @@ def shape_of(expr: Expr) -> Expr:
return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member
+def size(expr: Expr) -> Expr:
+ """Get the total number of elements in a tensor.
+
+ Parameters
+ ----------
+ expr : Expr
+ The input tensor.
+
+ Returns
+ -------
+ result : Expr
+ A scalar tensor of dtype int64 containing the total number of elements.
+ """
+ return _ffi_api.size(expr) # type: ignore # pylint: disable=no-member
+
+
def tensor_to_shape(expr: Expr) -> Expr:
"""Convert tensor to shape expr.
Parameters
@@ -777,11 +793,13 @@ def call_pure_packed(
sinfo_args = [sinfo_args]
sinfo_args = [
- sinfo()
- if callable(sinfo)
- else sinfo.asobject()
- if isinstance(sinfo, ObjectConvertible)
- else sinfo
+ (
+ sinfo()
+ if callable(sinfo)
+ else sinfo.asobject()
+ if isinstance(sinfo, ObjectConvertible)
+ else sinfo
+ )
for sinfo in sinfo_args
]
diff --git a/python/tvm/relax/transform/legalize_ops/inspect_op.py b/python/tvm/relax/transform/legalize_ops/inspect_op.py
index e031386e6e..a41c74cae0 100644
--- a/python/tvm/relax/transform/legalize_ops/inspect_op.py
+++ b/python/tvm/relax/transform/legalize_ops/inspect_op.py
@@ -23,6 +23,7 @@ from tvm.script import tir as T
from ...block_builder import BlockBuilder
from ...expr import Call, Expr
+from ... import op
from .common import register_legalize
@@ -126,3 +127,8 @@ def _tensor_elem_offset(bb: BlockBuilder, call: Call) -> Expr:
gvar = bb.add_func(_get_tensor_elem_offset, "_get_tensor_elem_offset")
return Call(gvar, call.args)
+
+
+@register_legalize("relax.size")
+def _size(_bb: BlockBuilder, call: Call) -> Expr:
+ return op.prod(op.shape_to_tensor(op.shape_of(call.args[0])))
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index e0a009a94e..5410c3c03a 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -163,6 +163,7 @@ from tvm.relax.op import (
sign,
sin,
sinh,
+ size,
slice_scatter,
sort,
split,
@@ -938,6 +939,7 @@ __all__ = [
"shape",
"shape_of",
"ShapeExpr",
+ "size",
"std",
"str",
"sum",
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 3acfb53b27..d7d68766dd 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -1125,6 +1125,32 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef().def("relax.op.shape_of", MakeShapeOf);
}
+// size
+
+StructInfo InferStructInfoSize(const Call& call, const BlockBuilder& ctx) {
+ auto arg_sinfo = GetStructInfo(call->args[0]);
+ auto* tensor_sinfo = GetStructInfo(call->args[0]).as<TensorStructInfoNode>();
+  CHECK(tensor_sinfo) << "size expects a tensor input, but received " << arg_sinfo
+                      << "; use MatchCast if necessary";
+  return TensorStructInfo(ShapeExpr(ffi::Array<PrimExpr>{}), DataType::Int(64));
+}
+
+TVM_REGISTER_OP("relax.size")
+ .set_num_inputs(1)
+ .add_argument("input", "Expr", "The input tensor")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSize)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+Expr MakeSize(Expr expr) {
+ static const Op& op = Op::Get("relax.size");
+ return Call(op, {expr}, {}, {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("relax.op.size", MakeSize);
+}
+
// tensor_to_shape
StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& ctx) {
diff --git a/tests/python/relax/test_op_size.py b/tests/python/relax/test_op_size.py
new file mode 100644
index 0000000000..77c5ebef5a
--- /dev/null
+++ b/tests/python/relax/test_op_size.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R
+
+
+def test_op_size():
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((), "int64"):
+ return R.size(x)
+
+ x_np = np.random.rand(2, 3).astype("float32")
+ x = tvm.runtime.tensor(x_np)
+
+ target = tvm.target.Target("llvm")
+ ex = relax.build(Module, target)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ res = vm["main"](x)
+ assert res.numpy() == 6
+
+
+def test_op_size_dynamic():
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((), "int64"):
+ return R.size(x)
+
+ x_np = np.random.rand(4, 5).astype("float32")
+ x = tvm.runtime.tensor(x_np)
+
+ target = tvm.target.Target("llvm")
+ ex = relax.build(Module, target)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ res = vm["main"](x)
+ assert res.numpy() == 20
+
+
+if __name__ == "__main__":
+ tvm.testing.main()