This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 899556d2da [Relax][Op][PyTorch] Supported Median operator (#18626)
899556d2da is described below
commit 899556d2da3f0bc191ec01cfb696f90f69f01b66
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Fri Jan 2 22:12:48 2026 +0700
[Relax][Op][PyTorch] Supported Median operator (#18626)
## Summary:
- Support the Median operator: add relax.median and wire the median op into the
exported_program_translator
- Input: Tensor, Axis, KeepDim
- Output: (Values, Indices)
## Expected:
### 1. Axis = None, KeepDim = False
```
class MedianWithoutDim(nn.Module):
def forward(self, x):
return torch.median(x)
```
```
class Module:
def main(x: R.Tensor((2, 3, 4), dtype="float32")) ->
R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((), dtype="float32") = R.median(x, axis=None,
keepdims=False)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
R.output(gv)
return gv
```
### 2. Axis = 0, KeepDim = False
```
class MedianDim(nn.Module):
def forward(self, x):
return torch.median(x, dim=0)
```
```
class Module:
def main(x: R.Tensor((2, 3, 4), dtype="float32")) ->
R.Tuple(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="int64")):
with R.dataflow():
lv: R.Tuple(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4),
dtype="int64")) = R.median(x, axis=[0], keepdims=False)
lv1: R.Tensor((3, 4), dtype="float32") = lv[0]
lv2: R.Tensor((3, 4), dtype="int64") = lv[1]
gv: R.Tuple(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4),
dtype="int64")) = lv1, lv2
R.output(gv)
return gv
```
### 3. Axis = -1, KeepDim = True
```
class MedianKeepDim(nn.Module):
def forward(self, x):
return torch.median(x, dim=-1, keepdim=True)
```
```
class Module:
def main(x: R.Tensor((2, 3, 4), dtype="float32")) ->
R.Tuple(R.Tensor((2, 3, 1), dtype="float32"), R.Tensor((2, 3, 1),
dtype="int64")):
with R.dataflow():
lv: R.Tuple(R.Tensor((2, 3, 1), dtype="float32"), R.Tensor((2,
3, 1), dtype="int64")) = R.median(x, axis=[-1], keepdims=True)
lv1: R.Tensor((2, 3, 1), dtype="float32") = lv[0]
lv2: R.Tensor((2, 3, 1), dtype="int64") = lv[1]
gv: R.Tuple(R.Tensor((2, 3, 1), dtype="float32"), R.Tensor((2,
3, 1), dtype="int64")) = lv1, lv2
R.output(gv)
return gv
```
---
.../frontend/torch/base_fx_graph_translator.py | 7 +
.../frontend/torch/exported_program_translator.py | 2 +
python/tvm/relax/op/__init__.py | 2 +-
python/tvm/relax/op/statistical.py | 27 +++
.../relax/transform/legalize_ops/statistical.py | 47 ++++-
python/tvm/script/ir_builder/relax/ir.py | 2 +
src/relax/op/tensor/statistical.cc | 82 ++++++++
src/relax/op/tensor/statistical.h | 3 +
.../relax/test_frontend_from_exported_program.py | 68 +++++++
tests/python/relax/test_op_statistical.py | 226 +++++++++++++++++++++
...st_transform_legalize_ops_search_statistical.py | 78 +++++++
.../relax/test_tvmscript_parser_op_statistical.py | 19 ++
12 files changed, 561 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f7d54a6216..d04dfbb6c3 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1572,6 +1572,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
keepdim = args[2] if len(node.args) > 2 else
node.kwargs.get("keepdim", False)
return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
+ def _median(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+ keepdim = args[2] if len(node.args) > 2 else
node.kwargs.get("keepdim", False)
+ return self.block_builder.emit(relax.op.median(x, dim,
keepdims=keepdim))
+
def _norm(self, node: fx.Node) -> relax.Var:
data = self.env[node.args[0]]
dtype = data.struct_info.dtype
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index b6b9723c13..0a97614eb5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1384,6 +1384,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"sum.dim_IntList": self._sum,
"var.correction": self._var,
"max.dim": self._max_dim,
+ "median.dim": self._median,
+ "median.default": self._median,
# search
"argmax.default": self._argmax_argmin(relax.op.argmax),
"argmin.default": self._argmax_argmin(relax.op.argmin),
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 19096decd9..c6504d79c9 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -119,7 +119,7 @@ from .sampling import multinomial_from_uniform
from .search import argmax, argmin, where, bucketize
from .set import nonzero, unique
from .sorting import argsort, sort, topk
-from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum,
variance
+from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum,
variance, median
from .ternary import ewise_fma
from .unary import (
abs,
diff --git a/python/tvm/relax/op/statistical.py
b/python/tvm/relax/op/statistical.py
index 502d058ffd..f11d31604a 100644
--- a/python/tvm/relax/op/statistical.py
+++ b/python/tvm/relax/op/statistical.py
@@ -341,3 +341,30 @@ def variance(x: Expr, axis: Optional[Union[int,
List[int]]] = None, keepdims: bo
if isinstance(axis, int):
axis = [axis]
return _ffi_api.variance(x, axis, keepdims) # type: ignore
+
+
+def median(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims:
bool = False) -> Expr:
+ """Computes the median of tensor elements over given axes.
+
+ Parameters
+ ----------
+ x : relax.Expr
+ The input data tensor
+
+ axis : Optional[Union[int, List[int]]]
+ Axis along which the median is computed. The default (None) is to
compute
+ the median of the entire flattened tensor.
+
+ keepdims : bool
+ If this is set to True, the axes which are reduced are left in the
result as dimensions
+ with size one.
+ With this option, the result will broadcast correctly against the
input tensor.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ if isinstance(axis, int):
+ axis = [axis]
+ return _ffi_api.median(x, axis, keepdims) # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py
b/python/tvm/relax/transform/legalize_ops/statistical.py
index bdb79126f0..0c140187db 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Default legalization function for statistical operators."""
-from typing import List
+from typing import List, Union, Tuple
from tvm import topi, tir, te
from ...block_builder import BlockBuilder
from ...expr import Call, Expr
@@ -53,6 +53,40 @@ def _te_variance(x: te.Tensor, axis: List[tir.IntImm],
keepdims: bool) -> te.Ten
# return _te_mean(x * x, axis, keepdims) - mean * mean
+def _te_median(
+ x: te.Tensor, axis: List[tir.IntImm], keepdims: bool
+) -> Union[te.Tensor, Tuple[te.Tensor, te.Tensor]]:
+ # currently only supports one axis or no axis ~ same pytorch
+ # todo: support multiple axis ~ same numpy
+ shape_prod = _compute_shape_prod(x, axis)
+ mid_index = (shape_prod - 1) // 2
+
+ if axis is None or len(axis) == 0:
+ x = topi.reshape(x, [shape_prod.value])
+ ax = -1
+ else:
+ ax = axis[0].value
+ index_sorted = topi.argsort(x, axis=ax, is_ascend=True, dtype="int64")
+ x_sorted = topi.gather(x, axis=ax, indices=index_sorted)
+
+ new_shape = list(x.shape)
+ new_shape[ax] = 1
+ indices = topi.full(new_shape, fill_value=mid_index, dtype="int64")
+
+ median_val = topi.gather(x_sorted, axis=ax, indices=indices)
+ median_idx = topi.gather(index_sorted, axis=ax, indices=indices)
+
+ if axis is None or len(axis) == 0:
+ return median_val if keepdims else topi.squeeze(median_val, axis=axis)
+
+ val = median_val
+ idx = median_idx
+ if not keepdims:
+ val = topi.squeeze(val, axis=axis)
+ idx = topi.squeeze(idx, axis=axis)
+ return val, idx
+
+
@register_legalize("relax.mean")
def _mean(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
@@ -81,6 +115,17 @@ def _variance(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.median")
+def _median(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(
+ _te_median,
+ call.args[0],
+ call.attrs.axis,
+ call.attrs.keepdims,
+ primfunc_name_hint="median",
+ )
+
+
register_legalize("relax.max", _statistical(topi.max))
register_legalize("relax.min", _statistical(topi.min))
register_legalize("relax.prod", _statistical(topi.prod))
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 141361a729..354a4d77ba 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -128,6 +128,7 @@ from tvm.relax.op import (
max,
maximum,
mean,
+ median,
memory,
meshgrid,
min,
@@ -874,6 +875,7 @@ __all__ = [
"max",
"maximum",
"mean",
+ "median",
"memory",
"meshgrid",
"metal",
diff --git a/src/relax/op/tensor/statistical.cc
b/src/relax/op/tensor/statistical.cc
index 621c23d363..771f6ffb13 100644
--- a/src/relax/op/tensor/statistical.cc
+++ b/src/relax/op/tensor/statistical.cc
@@ -180,6 +180,68 @@ StructInfo InferStructInfoScan(const Call& call, const
BlockBuilder& ctx) {
}
}
+StructInfo InferStructInfoStatisticalExtension(const Call& call, const
BlockBuilder& ctx) {
+ TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ const auto* attrs = call->attrs.as<StatisticalAttrs>();
+
+ std::vector<int> axes;
+ if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
+ axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value());
+ }
+
+ int out_ndim;
+ if (attrs->keepdims) {
+ out_ndim = data_sinfo->ndim;
+ } else if (!attrs->axis.defined()) {
+ out_ndim = 0;
+ } else if (data_sinfo->IsUnknownNdim()) {
+ out_ndim = kUnknownNDim;
+ } else {
+ out_ndim = data_sinfo->ndim - axes.size();
+ ICHECK_GE(out_ndim, 0);
+ }
+
+ // The inference rule for median operator output shapes:
+ // - axes is None || len(axes) > 1, keepdims is false -> return the
zero-rank shape;
+ // - axes is None || len(axes) > 1, keepdims is true -> return the shape
whose ndim
+ // is the same as input and every value is 1.
+ // - len(axes) == 1, keepdims is false -> the returned shape does not
contain the input axis.
+ // - len(axes) == 1, keepdims is true -> the returned shape has value 1 at
the positions of the
+ // input axis
+ const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+ if (data_shape == nullptr) {
+ if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim)
{
+ return TensorStructInfo(
+ ShapeExpr(ffi::Array<PrimExpr>(out_ndim, IntImm(DataType::Int(64),
/*value=*/1))),
+ data_sinfo->dtype, data_sinfo->vdevice);
+ }
+ if (out_ndim == 0) {
+ return TensorStructInfo(ShapeExpr(ffi::Array<PrimExpr>()),
data_sinfo->dtype,
+ data_sinfo->vdevice);
+ }
+ return TupleStructInfo({TensorStructInfo(data_sinfo->dtype, out_ndim,
data_sinfo->vdevice),
+ TensorStructInfo(DataType::Int(64), out_ndim,
data_sinfo->vdevice)});
+ }
+
+ ffi::Array<PrimExpr> out_shape;
+ out_shape.reserve(out_ndim);
+ for (int i = 0; i < data_sinfo->ndim; ++i) {
+ if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) ==
axes.end()) {
+ out_shape.push_back(data_shape->values[i]);
+ } else if (attrs->keepdims) {
+ out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1));
+ }
+ }
+ ICHECK_EQ(static_cast<int>(out_shape.size()), out_ndim);
+
+ if (!attrs->axis.defined() || axes.size() > 1)
+ return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype,
data_sinfo->vdevice);
+ else
+ return TupleStructInfo(
+ {TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype,
data_sinfo->vdevice),
+ TensorStructInfo(ShapeExpr(out_shape), DataType::Int(64),
data_sinfo->vdevice)});
+}
+
/* relax.cumprod */
Expr cumprod(Expr data, ffi::Optional<int64_t> axis, ffi::Optional<DataType>
dtype,
Bool exclusive) {
@@ -227,6 +289,26 @@ TVM_REGISTER_OP("relax.cumsum")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScan)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.median */
+Expr median(Expr data, ffi::Optional<ffi::Array<Integer>> axis, bool keepdims)
{
+ ObjectPtr<StatisticalAttrs> attrs = ffi::make_object<StatisticalAttrs>();
+ attrs->axis = std::move(axis);
+ attrs->keepdims = keepdims;
+ static const Op& op = Op::Get("relax.median");
+ return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("relax.op.median", median);
+}
+
+TVM_REGISTER_OP("relax.median")
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoStatisticalExtension)
+ .set_attr<Bool>("FPurity", Bool(true));
+
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(mean);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(min);
diff --git a/src/relax/op/tensor/statistical.h
b/src/relax/op/tensor/statistical.h
index a80ef72868..0a4f83687d 100644
--- a/src/relax/op/tensor/statistical.h
+++ b/src/relax/op/tensor/statistical.h
@@ -119,6 +119,9 @@ Expr cumsum(Expr data, ffi::Optional<int64_t> axis =
std::nullopt,
/*! \brief Computes the variance of tensor elements over given axes. */
Expr variance(Expr x, ffi::Optional<ffi::Array<Integer>> axis, bool keepdims);
+/*! \brief Computes the median of tensor elements over given axes. */
+Expr median(Expr x, ffi::Optional<ffi::Array<Integer>> axis, bool keepdims);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 9f8842ddcb..01a24ada1f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4957,6 +4957,74 @@ def test_mean():
verify_model(MeanWithoutDim(), example_args, {}, Expected3)
+def test_median():
+ class Median(Module):
+ def forward(self, input):
+ return input.median(-1)
+
+ class MedianKeepDim(Module):
+ def forward(self, input):
+ return input.median(-1, keepdim=True)
+
+ class MedianWithoutDim(Module):
+ def forward(self, input):
+ return input.median()
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tuple(R.Tensor((256,), dtype="float32"), R.Tensor((256,),
dtype="int64")):
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((256,), dtype="float32"), R.Tensor((256,),
dtype="int64")
+ ) = R.median(inp_0, axis=[-1], keepdims=False)
+ lv1: R.Tensor((256,), dtype="float32") = lv[0]
+ lv2: R.Tensor((256,), dtype="int64") = lv[1]
+ gv: R.Tuple(R.Tensor((256,), dtype="float32"),
R.Tensor((256,), dtype="int64")) = (
+ lv1,
+ lv2,
+ )
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tuple(R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1),
dtype="int64")):
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1),
dtype="int64")
+ ) = R.median(inp_0, axis=[-1], keepdims=True)
+ lv1: R.Tensor((256, 1), dtype="float32") = lv[0]
+ lv2: R.Tensor((256, 1), dtype="int64") = lv[1]
+ gv: R.Tuple(
+ R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1),
dtype="int64")
+ ) = (lv1, lv2)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected3:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tuple(R.Tensor((), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="float32") = R.median(inp_0, axis=None,
keepdims=False)
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(256, 256, dtype=torch.float32),)
+ verify_model(Median(), example_args, {}, Expected1)
+ verify_model(MedianKeepDim(), example_args, {}, Expected2)
+ verify_model(MedianWithoutDim(), example_args, {}, Expected3)
+
+
def test_sum():
class Sum(Module):
def forward(self, x):
diff --git a/tests/python/relax/test_op_statistical.py
b/tests/python/relax/test_op_statistical.py
index a0cfc81e55..5dccbb33cc 100644
--- a/tests/python/relax/test_op_statistical.py
+++ b/tests/python/relax/test_op_statistical.py
@@ -33,6 +33,7 @@ def test_op_correctness():
assert relax.op.std(x).op == Op.get("relax.std")
assert relax.op.sum(x).op == Op.get("relax.sum")
assert relax.op.variance(x).op == Op.get("relax.variance")
+ assert relax.op.median(x).op == Op.get("relax.median")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
@@ -275,5 +276,230 @@ def
test_scan_opinfer_struct_info_wrong_input_type(scan_op: Callable):
bb.normalize(scan_op(x1, axis=1))
+def test_statistical_ext_infer_struct_info():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+ x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
+
+ _check_inference(
+ bb,
+ relax.op.median(x0, axis=[1]),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((2, 4, 5), "float32"),
+ relax.TensorStructInfo((2, 4, 5), "int64"),
+ ]
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x0, axis=[1], keepdims=True),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((2, 1, 4, 5), "float32"),
+ relax.TensorStructInfo((2, 1, 4, 5), "int64"),
+ ]
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x1, axis=[1]),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(dtype="float32", ndim=3),
+ relax.TensorStructInfo(dtype="int64", ndim=3),
+ ]
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x1, axis=[1], keepdims=True),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ relax.TensorStructInfo(dtype="int64", ndim=4),
+ ]
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x1, axis=None, keepdims=True),
+ relax.TensorStructInfo((1, 1, 1, 1), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x2, axis=[1]),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(dtype="float32"),
+ relax.TensorStructInfo(dtype="int64"),
+ ]
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x2, axis=[1], keepdims=True),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(dtype="float32"),
+ relax.TensorStructInfo(dtype="int64"),
+ ]
+ ),
+ )
+ _check_inference(bb, relax.op.median(x2, axis=None),
relax.TensorStructInfo((), "float32"))
+ _check_inference(
+ bb,
+ relax.op.median(x3, axis=[1], keepdims=True),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((2, 1, 4, 5), dtype=""),
+ relax.TensorStructInfo((2, 1, 4, 5), dtype="int64"),
+ ]
+ ),
+ )
+ _check_inference(bb, relax.op.median(x3, axis=None),
relax.TensorStructInfo((), dtype=""))
+ _check_inference(
+ bb,
+ relax.op.median(x3, axis=None, keepdims=True),
+ relax.TensorStructInfo((1, 1, 1, 1), dtype=""),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x4, axis=[1]),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((2, 4, 5), "float32", vdev0),
+ relax.TensorStructInfo((2, 4, 5), "int64", vdev0),
+ ]
+ ),
+ )
+
+
+def test_statistical_ext_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ c = tir.Var("c", "int64")
+ d = tir.Var("d", "int64")
+ x = relax.Var("x", R.Tensor((a, b, c, d), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.median(x, axis=[1]),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((a, c, d), "float32"),
+ relax.TensorStructInfo((a, c, d), "int64"),
+ ]
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x, axis=[1], keepdims=True),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((a, 1, c, d), "float32"),
+ relax.TensorStructInfo((a, 1, c, d), "int64"),
+ ]
+ ),
+ )
+ _check_inference(bb, relax.op.median(x, axis=None),
relax.TensorStructInfo((), "float32"))
+ _check_inference(
+ bb,
+ relax.op.median(x, axis=None, keepdims=True),
+ relax.TensorStructInfo((1, 1, 1, 1), "float32"),
+ )
+
+
+def test_statistical_ext_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+ s1 = relax.Var("s", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+ _check_inference(bb, relax.op.median(x0), relax.TensorStructInfo((),
dtype="float32"))
+ _check_inference(
+ bb,
+ relax.op.median(x0, keepdims=True),
+ relax.TensorStructInfo((1, 1, 1, 1), dtype="float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x0, axis=[2]),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(dtype="float32", ndim=3),
+ relax.TensorStructInfo(dtype="int64", ndim=3),
+ ]
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x0, axis=[2], keepdims=True),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ relax.TensorStructInfo(dtype="int64", ndim=4),
+ ]
+ ),
+ )
+ _check_inference(bb, relax.op.median(x1), relax.TensorStructInfo((),
dtype="float32"))
+ _check_inference(
+ bb,
+ relax.op.median(x1, keepdims=True),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(dtype="float32"),
+ relax.TensorStructInfo(dtype="int64"),
+ ]
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x1, axis=[2]),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(dtype="float32"),
+ relax.TensorStructInfo(dtype="int64"),
+ ]
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.median(x1, axis=[2], keepdims=True),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(dtype="float32"),
+ relax.TensorStructInfo(dtype="int64"),
+ ]
+ ),
+ )
+
+
+def test_statistical_ext_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8"))
+
+ _check_inference(bb, relax.op.median(x0), relax.TensorStructInfo((),
"float16"))
+ _check_inference(bb, relax.op.median(x1), relax.TensorStructInfo((),
"int8"))
+
+
+def test_statistical_ext_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
+ x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5),
"float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.median(x0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.median(x1))
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git
a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 7edfff3dfc..b28451da1b 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -684,6 +684,84 @@ def test_mean_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_median():
+ # fmt: off
+ @tvm.script.ir_module
+ class Median:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) ->
R.Tuple(R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5),
dtype="int64")):
+ gv: R.Tuple(R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4,
5), dtype="int64")) = R.median(x, axis=[0], keepdims=False)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) ->
R.Tuple(R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5),
dtype="int64")):
+ gv = R.call_tir(Expected.median, (x,), out_sinfo=[R.Tensor((3, 4,
5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")])
+ return gv
+
+ @T.prim_func(private=True)
+ def median(var_x: T.handle, T_squeeze: T.Buffer((T.int64(3),
T.int64(4), T.int64(5)), "float32"), T_squeeze_1: T.Buffer((T.int64(3),
T.int64(4), T.int64(5)), "int64")):
+ T.func_attr({"tir.noalias": True})
+ data_buf = T.match_buffer(var_x, (T.int64(2), T.int64(3),
T.int64(4), T.int64(5)), align=8)
+ # with T.block("root"):
+ T_full = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4),
T.int64(5)), "int64")
+ out_buf = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "int64", align=8)
+ T_gather = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ T_gather_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4),
T.int64(5)))
+ T_gather_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4),
T.int64(5)), "int64")
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_full"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads()
+ T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_full[v_ax0, v_ax1, v_ax2, v_ax3] = 0
+ with T.block("argsort_cpu"):
+ T.reads()
+ T.writes()
+ T.call_packed("tvm.contrib.sort.argsort",
T.tvm_stack_make_array(data_buf.data,
+
T.tvm_stack_make_shape(T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
+
0, 4, T.float32(0.0), T.int64(0)),
+
T.tvm_stack_make_array(out_buf.data,
+
T.tvm_stack_make_shape(T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
+
0, 4, T.int64(0), T.int64(0)),
+ 0, T.bool(True))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_gather"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(data_buf[out_buf[v_ax0, v_ax1, v_ax2, v_ax3],
v_ax1, v_ax2, v_ax3], out_buf[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_gather[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_gather[v_ax0, v_ax1, v_ax2, v_ax3] =
data_buf[out_buf[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_gather_1"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_gather[T_full[v_ax0, v_ax1, v_ax2, v_ax3],
v_ax1, v_ax2, v_ax3], T_full[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_gather_1[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_gather_1[v_ax0, v_ax1, v_ax2, v_ax3] =
T_gather[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3]
+ for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_squeeze"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_gather_1[T.int64(0), v_ax0, v_ax1, v_ax2])
+ T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
+ T_squeeze[v_ax0, v_ax1, v_ax2] = T_gather_1[T.int64(0),
v_ax0, v_ax1, v_ax2]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_gather_2"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(out_buf[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1,
v_ax2, v_ax3], T_full[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_gather_2[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_gather_2[v_ax0, v_ax1, v_ax2, v_ax3] =
out_buf[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3]
+ for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_squeeze_1"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_gather_2[T.int64(0), v_ax0, v_ax1, v_ax2])
+ T.writes(T_squeeze_1[v_ax0, v_ax1, v_ax2])
+ T_squeeze_1[v_ax0, v_ax1, v_ax2] = T_gather_2[T.int64(0),
v_ax0, v_ax1, v_ax2]
+ # fmt: on
+
+ mod = LegalizeOps()(Median)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_std():
# fmt: off
@tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_statistical.py
b/tests/python/relax/test_tvmscript_parser_op_statistical.py
index 910c08bf1e..6ba90c5651 100644
--- a/tests/python/relax/test_tvmscript_parser_op_statistical.py
+++ b/tests/python/relax/test_tvmscript_parser_op_statistical.py
@@ -95,6 +95,25 @@ def test_mean():
_check(foo, bb.get()["foo"])
+def test_median():
+ @R.function
+ def foo(
+ x: R.Tensor((1, 2, 3, 4), "float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 4), "float32"), R.Tensor((1, 3, 4), "int64")):
+ gv: R.Tuple(R.Tensor((1, 3, 4), "float32"), R.Tensor((1, 3, 4),
"int64")) = R.median(
+ x, axis=[1]
+ )
+ return gv
+
+ x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x]):
+ gv = bb.emit(relax.op.median(x, axis=[1]))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
def test_variance():
@R.function
def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1,), "float32"):