This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 7d8095e8be62137305bafb29490e9190482b3675
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Feb 14 15:01:07 2023 -0500

    [Unity] Relax op: statistical (#13991)
    
    This PR is about the high-level tensor computation operators in Relax.
    
    This PR includes the statistical operators.
---
 include/tvm/relax/attrs/statistical.h              |  48 +++++
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/op_attrs.py                    |   5 +
 python/tvm/relax/op/statistical.py                 | 218 +++++++++++++++++++++
 python/tvm/script/ir_builder/relax/ir.py           |  18 ++
 src/relax/op/tensor/statistical.cc                 |  96 +++++++++
 src/relax/op/tensor/statistical.h                  |  92 +++++++++
 tests/python/relax/test_op_statistical.py          | 204 +++++++++++++++++++
 .../relax/test_tvmscript_parser_op_statistical.py  | 174 ++++++++++++++++
 9 files changed, 856 insertions(+)

diff --git a/include/tvm/relax/attrs/statistical.h 
b/include/tvm/relax/attrs/statistical.h
new file mode 100644
index 0000000000..bb1ab2195d
--- /dev/null
+++ b/include/tvm/relax/attrs/statistical.h
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/attrs/statistical.h
+ * \brief Attributes for statistical operators.
+ */
+#ifndef TVM_RELAX_ATTRS_STATISTICAL_H_
+#define TVM_RELAX_ATTRS_STATISTICAL_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Attributes for statistical operators */
struct StatisticalAttrs : public tvm::AttrsNode<StatisticalAttrs> {
  // The axes to reduce over; an undefined Optional means "reduce over all axes".
  Optional<Array<Integer>> axis;
  // Whether reduced axes are retained in the result as dimensions of size one.
  bool keepdims;

  TVM_DECLARE_ATTRS(StatisticalAttrs, "relax.attrs.StatisticalAttrs") {
    TVM_ATTR_FIELD(axis).describe("The axis or axes along which to perform the reduction.");
    TVM_ATTR_FIELD(keepdims).describe(
        "If this is set to `True`, the reduced axes are left in the result as dimension with size "
        "one.");
  }
};  // struct StatisticalAttrs
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_ATTRS_STATISTICAL_H_
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 344576fe13..68152c2056 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -24,6 +24,7 @@ from .datatype import *
 from .index import *
 from .manipulate import *
 from .op_attrs import *
+from .statistical import *
 from .set import *
 from .ternary import *
 from .unary import *
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index fb64443b7e..1fb8853040 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -34,6 +34,11 @@ class StridedSliceAttrs(Attrs):
     """Attributes used in strided_slice operator"""
 
 
@tvm._ffi.register_object("relax.attrs.StatisticalAttrs")
class StatisticalAttrs(Attrs):
    """Attributes used in statistical operators"""
+
+
 @tvm._ffi.register_object("relax.attrs.Resize2DAttrs")
 class Resize2DAttrs(Attrs):
     """Attributes used in image resize2d operator"""
diff --git a/python/tvm/relax/op/statistical.py 
b/python/tvm/relax/op/statistical.py
new file mode 100644
index 0000000000..4669c783ad
--- /dev/null
+++ b/python/tvm/relax/op/statistical.py
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+"""Statistical operators."""
+from typing import List, Optional, Union
+
+from . import _ffi_api
+from ..expr import Expr
+
+
def max(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
    """Compute the maximum of tensor elements over the given axes.

    Parameters
    ----------
    x : relax.Expr
        The input data tensor

    axis : Optional[Union[int, List[int]]]
        The axis or axes to reduce over. When axis is None (the default), the
        maximum of every element in the input tensor is computed.
        Negative indices are supported.

    keepdims : bool
        When True, the reduced axes are retained in the result as dimensions of
        size one, so the result broadcasts correctly against the input tensor.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    axes = [axis] if isinstance(axis, int) else axis
    return _ffi_api.max(x, axes, keepdims)  # type: ignore
+
+
def mean(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
    """Compute the arithmetic mean of tensor elements over the given axes.

    Parameters
    ----------
    x : relax.Expr
        The input data tensor

    axis : Optional[Union[int, List[int]]]
        The axis or axes to reduce over. When axis is None (the default), the
        mean of every element in the input tensor is computed.
        Negative indices are supported.

    keepdims : bool
        When True, the reduced axes are retained in the result as dimensions of
        size one, so the result broadcasts correctly against the input tensor.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    axes = [axis] if isinstance(axis, int) else axis
    return _ffi_api.mean(x, axes, keepdims)  # type: ignore
+
+
def min(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
    """Compute the minimum of tensor elements over the given axes.

    Parameters
    ----------
    x : relax.Expr
        The input data tensor

    axis : Optional[Union[int, List[int]]]
        The axis or axes to reduce over. When axis is None (the default), the
        minimum of every element in the input tensor is computed.
        Negative indices are supported.

    keepdims : bool
        When True, the reduced axes are retained in the result as dimensions of
        size one, so the result broadcasts correctly against the input tensor.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    axes = [axis] if isinstance(axis, int) else axis
    return _ffi_api.min(x, axes, keepdims)  # type: ignore
+
+
def prod(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
    """Compute the product of tensor elements over the given axes.

    Parameters
    ----------
    x : relax.Expr
        The input data tensor

    axis : Optional[Union[int, List[int]]]
        The axis or axes to reduce over. When axis is None (the default), the
        product of every element in the input tensor is computed.
        Negative indices are supported.

    keepdims : bool
        When True, the reduced axes are retained in the result as dimensions of
        size one, so the result broadcasts correctly against the input tensor.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    axes = [axis] if isinstance(axis, int) else axis
    return _ffi_api.prod(x, axes, keepdims)  # type: ignore
+
+
def std(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
    """Compute the standard deviation of tensor elements over the given axes.

    Parameters
    ----------
    x : relax.Expr
        The input data tensor

    axis : Optional[Union[int, List[int]]]
        The axis or axes to reduce over. When axis is None (the default), the
        standard deviation of every element in the input tensor is computed.
        Negative indices are supported.

    keepdims : bool
        When True, the reduced axes are retained in the result as dimensions of
        size one, so the result broadcasts correctly against the input tensor.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    axes = [axis] if isinstance(axis, int) else axis
    return _ffi_api.std(x, axes, keepdims)  # type: ignore
+
+
def sum(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
    """Compute the sum of tensor elements over the given axes.

    Parameters
    ----------
    x : relax.Expr
        The input data tensor

    axis : Optional[Union[int, List[int]]]
        The axis or axes to reduce over. When axis is None (the default), every
        element of the input tensor is summed.
        Negative indices are supported.

    keepdims : bool
        When True, the reduced axes are retained in the result as dimensions of
        size one, so the result broadcasts correctly against the input tensor.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    axes = [axis] if isinstance(axis, int) else axis
    return _ffi_api.sum(x, axes, keepdims)  # type: ignore
+
+
def variance(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
    """Compute the variance of tensor elements over the given axes.

    Parameters
    ----------
    x : relax.Expr
        The input data tensor

    axis : Optional[Union[int, List[int]]]
        The axis or axes to reduce over. When axis is None (the default), the
        variance of every element in the input tensor is computed.
        Negative indices are supported.

    keepdims : bool
        When True, the reduced axes are retained in the result as dimensions of
        size one, so the result broadcasts correctly against the input tensor.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    axes = [axis] if isinstance(axis, int) else axis
    return _ffi_api.variance(x, axes, keepdims)  # type: ignore
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index a5cb574a06..47779a6024 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -63,15 +63,24 @@ from tvm.relax.op import (
     less_equal,
     log,
     make_closure,
+    max,
+    mean,
     memory,
+    min,
     multiply,
     negative,
     not_equal,
     null_value,
     print,
+    prod,
     reshape,
     round,
     shape_of,
+    std,
+    strided_slice,
+    sum,
+    take,
+    variance,
     sigmoid,
     sign,
     sin,
@@ -486,7 +495,10 @@ __all__ = [
     "less_equal",
     "log",
     "make_closure",
+    "max",
+    "mean",
     "memory",
+    "min",
     "multiply",
     "negative",
     "not_equal",
@@ -494,10 +506,15 @@ __all__ = [
     "output",
     "prim_value",
     "print",
+    "prod",
     "reshape",
     "round",
     "shape",
     "shape_of",
+    "std",
+    "str",
+    "strided_slice",
+    "sum",
     "sigmoid",
     "sign",
     "sin",
@@ -511,5 +528,6 @@ __all__ = [
     "tan",
     "tanh",
     "tuple",
+    "variance",
     "unique",
 ]
diff --git a/src/relax/op/tensor/statistical.cc 
b/src/relax/op/tensor/statistical.cc
new file mode 100644
index 0000000000..41b99fbe36
--- /dev/null
+++ b/src/relax/op/tensor/statistical.cc
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file statistical.cc
+ * \brief Statistical operators.
+ */
+
+#include "statistical.h"
+
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
/*!
 * \brief Infer the output StructInfo of a Relax statistical (reduction) call.
 * \param call The call to one of the statistical ops; its attrs are StatisticalAttrs.
 * \param ctx The block builder, used for input checking and error reporting.
 * \return The TensorStructInfo of the reduction result.
 */
StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) {
  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
  const auto* attrs = call->attrs.as<StatisticalAttrs>();

  // Normalize the requested axes (resolve negatives, reject out-of-range and
  // repeated axes) — only possible when both the input ndim and axes are known.
  std::vector<int> axes;
  if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
    axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value());
  }

  // Output rank: keepdims preserves the input rank (which may itself be
  // unknown); a full reduction yields rank 0; otherwise one dimension is
  // removed per reduced axis.
  int out_ndim;
  if (attrs->keepdims) {
    out_ndim = data_sinfo->ndim;
  } else if (!attrs->axis.defined()) {
    out_ndim = 0;
  } else if (data_sinfo->IsUnknownNdim()) {
    out_ndim = kUnknownNDim;
  } else {
    out_ndim = data_sinfo->ndim - axes.size();
    ICHECK_GE(out_ndim, 0);
  }

  // The inference rule for reduction operator output shapes:
  // - axes is None, keepdims is false -> return the zero-rank shape;
  // - axes is None, keepdims is true -> return the shape whose ndim is the same as input and every
  // value is 1.
  // - axes is not None, keepdims is false -> the returned shape does not contain the input axes.
  // - axes is not None, keepdims is true -> the returned shape has value 1 at the positions of the
  // input axes
  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
  if (data_shape == nullptr) {
    // The concrete shape is unknown. A full reduction with keepdims still has
    // a fully determined all-ones shape; otherwise only dtype (and possibly
    // rank) can be reported.
    if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) {
      return TensorStructInfo(
          ShapeExpr(Array<PrimExpr>(out_ndim, IntImm(DataType::Int(64), /*value=*/1))),
          data_sinfo->dtype);
    } else {
      return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array<PrimExpr>()), data_sinfo->dtype)
                           : TensorStructInfo(data_sinfo->dtype, out_ndim);
    }
  }

  // Concrete input shape: keep non-reduced dims, replace reduced dims with 1
  // when keepdims, drop them otherwise.
  Array<PrimExpr> out_shape;
  out_shape.reserve(out_ndim);
  for (int i = 0; i < data_sinfo->ndim; ++i) {
    if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) {
      out_shape.push_back(data_shape->values[i]);
    } else if (attrs->keepdims) {
      out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1));
    }
  }
  ICHECK_EQ(static_cast<int>(out_shape.size()), out_ndim);
  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
}
+
TVM_REGISTER_NODE_TYPE(StatisticalAttrs);

// Register each statistical reduction op and expose its `relax.op.<name>`
// make-function; they all share StatisticalAttrs and the inference rule above.
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(mean);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(min);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(prod);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(std);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(sum);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(variance);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/tensor/statistical.h 
b/src/relax/op/tensor/statistical.h
new file mode 100644
index 0000000000..7d322d1129
--- /dev/null
+++ b/src/relax/op/tensor/statistical.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file statistical.h
+ * \brief The functions to make Relax statistical operator calls.
+ */
+#ifndef TVM_RELAX_OP_TENSOR_STATISTICAL_H_
+#define TVM_RELAX_OP_TENSOR_STATISTICAL_H_
+
+#include <tvm/relax/attrs/statistical.h>
+
+#include <algorithm>
+#include <utility>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Quick helper macro
+ * - Expose a make function to construct the node.
+ * - Register op to the registry.
+ * \param OpName The name of operator to register. The name passed in will
+ *  1. be prepended with a prefix "relax.op." as the FFI identifier string for 
the make function,
+ *  2. be prepended with a prefix "relax." as the identifier string in the 
operator registry.
+ */
#define RELAX_REGISTER_STATISTICAL_OP_INTERFACE(OpName)                  \
  Expr OpName(Expr x, Optional<Array<Integer>> axis, bool keepdims) {    \
    ObjectPtr<StatisticalAttrs> attrs = make_object<StatisticalAttrs>(); \
    attrs->axis = std::move(axis);                                       \
    attrs->keepdims = keepdims;                                          \
    static const Op& op = Op::Get("relax." #OpName);                     \
    return Call(op, {std::move(x)}, Attrs{attrs}, {});                   \
  }                                                                      \
  TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName);       \
  TVM_REGISTER_OP("relax." #OpName)                                      \
      .set_num_inputs(1)                                                 \
      .add_argument("x", "Tensor", "The input data tensor")              \
      .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoStatistical)

/*!
 * \brief Computes the maximum value of tensor elements over given axes.
 * \param x The input data tensor
 * \param axis Axis or axes along which a max is performed. Being `NullOpt` means to max all the
 * elements of the input tensor
 * \param keepdims If this is set to True, the axes which are reduced are left in the result as
 * dimensions with size one. With this option, the result will broadcast correctly against the
 * input tensor.
 * \return The result after reduction.
 */
Expr max(Expr x, Optional<Array<Integer>> axis, bool keepdims);

// The ops below share max's parameter semantics; see the documentation above.

/*! \brief Computes the mean of tensor elements over given axes. */
Expr mean(Expr x, Optional<Array<Integer>> axis, bool keepdims);

/*! \brief Computes the min of tensor elements over given axes. */
Expr min(Expr x, Optional<Array<Integer>> axis, bool keepdims);

/*! \brief Computes the product of tensor elements over given axes. */
Expr prod(Expr x, Optional<Array<Integer>> axis, bool keepdims);

/*! \brief Computes the standard deviation of tensor elements over given axes. */
Expr std(Expr x, Optional<Array<Integer>> axis, bool keepdims);

/*! \brief Computes the sum of tensor elements over given axes. */
Expr sum(Expr x, Optional<Array<Integer>> axis, bool keepdims);

/*! \brief Computes the variance of tensor elements over given axes. */
Expr variance(Expr x, Optional<Array<Integer>> axis, bool keepdims);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_TENSOR_STATISTICAL_H_
diff --git a/tests/python/relax/test_op_statistical.py 
b/tests/python/relax/test_op_statistical.py
new file mode 100644
index 0000000000..b1bdd8e44d
--- /dev/null
+++ b/tests/python/relax/test_op_statistical.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
def test_op_correctness():
    """Each statistical builder must create a call to the matching registered op."""
    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
    for op_name in ("max", "mean", "min", "prod", "std", "sum", "variance"):
        assert getattr(relax.op, op_name)(x).op == Op.get("relax." + op_name)
+
+
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
    """Normalize `call` via the block builder and assert its inferred struct info."""
    ret = bb.normalize(call)
    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
def test_statistical_infer_struct_info():
    """Struct info inference across static, ndim-only, unknown, and dtype-less inputs."""
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))  # fully static shape
    x1 = relax.Var("x", R.Tensor("float32", ndim=4))  # only ndim known
    x2 = relax.Var("x", R.Tensor("float32"))  # shape and ndim unknown
    x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))  # static shape, unknown dtype

    # Static shape: reduced dims are dropped, or kept as 1 with keepdims.
    _check_inference(bb, relax.op.sum(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32"))
    _check_inference(
        bb,
        relax.op.sum(x0, axis=[1, 2], keepdims=True),
        relax.TensorStructInfo((2, 1, 1, 5), "float32"),
    )
    _check_inference(bb, relax.op.sum(x0, axis=None), relax.TensorStructInfo((), "float32"))
    _check_inference(
        bb,
        relax.op.sum(x0, axis=None, keepdims=True),
        relax.TensorStructInfo((1, 1, 1, 1), "float32"),
    )
    # Known ndim only: ranks are inferred; full reduction still gives a scalar
    # or all-ones shape.
    _check_inference(
        bb, relax.op.mean(x1, axis=[1, 2]), relax.TensorStructInfo(dtype="float32", ndim=2)
    )
    _check_inference(
        bb,
        relax.op.mean(x1, axis=[1, 2], keepdims=True),
        relax.TensorStructInfo(dtype="float32", ndim=4),
    )
    _check_inference(bb, relax.op.mean(x1, axis=None), relax.TensorStructInfo((), "float32"))
    _check_inference(
        bb,
        relax.op.mean(x1, axis=None, keepdims=True),
        relax.TensorStructInfo((1, 1, 1, 1), "float32"),
    )
    # Unknown ndim: only dtype survives, except a full non-keepdims reduction.
    _check_inference(
        bb, relax.op.variance(x2, axis=[1, 2]), relax.TensorStructInfo(dtype="float32")
    )
    _check_inference(
        bb,
        relax.op.variance(x2, axis=[1, 2], keepdims=True),
        relax.TensorStructInfo(dtype="float32"),
    )
    _check_inference(bb, relax.op.variance(x2, axis=None), relax.TensorStructInfo((), "float32"))
    _check_inference(
        bb,
        relax.op.variance(x2, axis=None, keepdims=True),
        relax.TensorStructInfo(dtype="float32"),
    )
    # Unknown dtype propagates as the empty dtype string.
    _check_inference(bb, relax.op.max(x3, axis=[1, 2]), relax.TensorStructInfo((2, 5), dtype=""))
    _check_inference(
        bb,
        relax.op.max(x3, axis=[1, 2], keepdims=True),
        relax.TensorStructInfo((2, 1, 1, 5), dtype=""),
    )
    _check_inference(bb, relax.op.max(x3, axis=None), relax.TensorStructInfo((), dtype=""))
    _check_inference(
        bb,
        relax.op.max(x3, axis=None, keepdims=True),
        relax.TensorStructInfo((1, 1, 1, 1), dtype=""),
    )
    _check_inference(bb, relax.op.prod(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32"))
    _check_inference(
        bb,
        relax.op.prod(x0, axis=[1, 2], keepdims=True),
        relax.TensorStructInfo((2, 1, 1, 5), "float32"),
    )
    _check_inference(bb, relax.op.std(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32"))
    _check_inference(
        bb,
        relax.op.std(x0, axis=[1, 2], keepdims=True),
        relax.TensorStructInfo((2, 1, 1, 5), "float32"),
    )
    # Negative axes are normalized; an empty axis list reduces nothing.
    _check_inference(bb, relax.op.sum(x0, axis=[-1, -4]), relax.TensorStructInfo((3, 4), "float32"))
    _check_inference(bb, relax.op.sum(x0, axis=[]), relax.TensorStructInfo((2, 3, 4, 5), "float32"))
+
+
def test_statistical_infer_struct_info_shape_symbolic():
    """Struct info inference when the input shape uses symbolic dimensions."""
    bb = relax.BlockBuilder()
    dim_a, dim_b, dim_c, dim_d = [tir.Var(name, "int64") for name in ("a", "b", "c", "d")]
    x = relax.Var("x", R.Tensor((dim_a, dim_b, dim_c, dim_d), "float32"))

    # Non-reduced symbolic dims are carried through; reduced ones drop or become 1.
    _check_inference(
        bb, relax.op.min(x, axis=[1, 2]), relax.TensorStructInfo((dim_a, dim_d), "float32")
    )
    _check_inference(
        bb,
        relax.op.min(x, axis=[1, 2], keepdims=True),
        relax.TensorStructInfo((dim_a, 1, 1, dim_d), "float32"),
    )
    _check_inference(bb, relax.op.min(x, axis=None), relax.TensorStructInfo((), "float32"))
    _check_inference(
        bb,
        relax.op.min(x, axis=None, keepdims=True),
        relax.TensorStructInfo((1, 1, 1, 1), "float32"),
    )
+
+
def test_statistical_infer_struct_info_shape_var():
    """Struct info inference when the input shape is an opaque shape variable."""
    bb = relax.BlockBuilder()
    shape_ndim4 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
    shape_unknown = relax.Var("s", relax.ShapeStructInfo())
    data_ndim4 = relax.Var("x", relax.TensorStructInfo(shape_ndim4, "float32"))
    data_unknown = relax.Var("x", relax.TensorStructInfo(shape_unknown, "float32"))

    # Known ndim: full reductions still yield fully-determined shapes.
    _check_inference(bb, relax.op.max(data_ndim4), relax.TensorStructInfo((), dtype="float32"))
    _check_inference(
        bb,
        relax.op.max(data_ndim4, keepdims=True),
        relax.TensorStructInfo((1, 1, 1, 1), dtype="float32"),
    )
    _check_inference(
        bb, relax.op.max(data_ndim4, axis=[2, 3]), relax.TensorStructInfo(dtype="float32", ndim=2)
    )
    _check_inference(
        bb,
        relax.op.max(data_ndim4, axis=[2, 3], keepdims=True),
        relax.TensorStructInfo(dtype="float32", ndim=4),
    )
    # Unknown ndim: only dtype is inferable, except the full non-keepdims reduction.
    _check_inference(bb, relax.op.max(data_unknown), relax.TensorStructInfo((), dtype="float32"))
    _check_inference(
        bb, relax.op.max(data_unknown, keepdims=True), relax.TensorStructInfo(dtype="float32")
    )
    _check_inference(
        bb, relax.op.max(data_unknown, axis=[2, 3]), relax.TensorStructInfo(dtype="float32")
    )
    _check_inference(
        bb,
        relax.op.max(data_unknown, axis=[2, 3], keepdims=True),
        relax.TensorStructInfo(dtype="float32"),
    )
+
+
def test_statistical_infer_struct_info_more_input_dtype():
    """The input dtype is propagated to the reduction result unchanged."""
    bb = relax.BlockBuilder()
    for dtype in ("float16", "int8"):
        x = relax.Var("x", R.Tensor((2, 3, 4, 5), dtype))
        _check_inference(bb, relax.op.sum(x), relax.TensorStructInfo((), dtype))
+
+
def test_statistical_infer_struct_info_axis_out_of_range_repetitive():
    """Out-of-range or repeated axes must be rejected at normalization time."""
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
    x1 = relax.Var("x", R.Tensor("float32", ndim=4))

    # Each case is either out of [-4, 3] or names the same axis twice
    # (note -1 and 3 refer to the same axis of a 4-d tensor).
    bad_cases = [(x0, [4]), (x1, [3, 3]), (x0, [-1, 3]), (x1, [-4, -4]), (x0, [-5])]
    for data, bad_axis in bad_cases:
        with pytest.raises(TVMError):
            bb.normalize(relax.op.mean(data, axis=bad_axis))
+
+
def test_statistical_infer_struct_info_wrong_input_type():
    """Non-tensor arguments to a statistical op must be rejected."""
    bb = relax.BlockBuilder()
    shape_arg = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
    func_arg = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32")))

    for bad_arg in (shape_arg, func_arg):
        with pytest.raises(TVMError):
            bb.normalize(relax.op.variance(bad_arg))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_statistical.py 
b/tests/python/relax/test_tvmscript_parser_op_statistical.py
new file mode 100644
index 0000000000..221d2a17a8
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_statistical.py
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
def _check(
    parsed: Union[relax.Function, IRModule],
    expect: Optional[Union[relax.Function, IRModule]],
):
    """Round-trip `parsed` through TVMScript printing/parsing; compare to `expect`."""
    printed = parsed.script(show_meta=True)
    reparsed = tvm.script.from_source(printed)
    tvm.ir.assert_structural_equal(parsed, reparsed)
    if expect:
        tvm.ir.assert_structural_equal(parsed, expect)
+
+
def test_sum():
    """Parse `R.sum` with an explicit axis list and compare to a BlockBuilder reference."""

    @R.function
    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3), "float32"):
        gv: R.Tensor((1, 3), "float32") = R.sum(x, axis=[1, 3])
        return gv

    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.sum(x, axis=[1, 3]))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])
+
+
def test_sum_without_specified_axis():
    """Parse `R.sum` with the default axis=None, which reduces to a scalar tensor."""

    @R.function
    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((), "float32"):
        gv: R.Tensor((), "float32") = R.sum(x)
        return gv

    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.sum(x))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])
+
+
def test_sum_keep_dims():
    """Parse `R.sum` with keepdims=True, so reduced axes stay as size-1 dims."""

    @R.function
    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 1, 3, 1), "float32"):
        gv: R.Tensor((1, 1, 3, 1), "float32") = R.sum(x, axis=[1, 3], keepdims=True)
        return gv

    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.sum(x, axis=[1, 3], keepdims=True))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])
+
+
def test_mean():
    """Parse `R.mean` with an explicit axis list and compare to a BlockBuilder reference."""

    @R.function
    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3), "float32"):
        gv: R.Tensor((1, 3), "float32") = R.mean(x, axis=[1, 3])
        return gv

    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.mean(x, axis=[1, 3]))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])
+
+
def test_variance():
    """Parse `R.variance` with negative axes and compare to a BlockBuilder reference."""

    @R.function
    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1,), "float32"):
        gv: R.Tensor((1,), "float32") = R.variance(x, axis=[-1, -2, -3])
        return gv

    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.variance(x, axis=[-1, -2, -3]))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])
+
+
def test_max():
    """Parse `R.max` with keepdims and compare to a BlockBuilder reference.

    Fixed: this test previously constructed `R.variance` (copy-paste from
    test_variance), so `R.max` was never exercised by the parser tests.
    The (1, 1, 1, 1) result shape is unchanged: reducing axes [-1, -2, -3]
    of a (1, 2, 3, 4) tensor with keepdims=True keeps them as size-1 dims.
    """

    @R.function
    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"):
        gv: R.Tensor((1, 1, 1, 1), "float32") = R.max(x, axis=[-1, -2, -3], keepdims=True)
        return gv

    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.max(x, axis=[-1, -2, -3], keepdims=True))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])
+
+
def test_min():
    """Parse `R.min` with a single integer axis and compare to a BlockBuilder reference."""

    @R.function
    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"):
        gv: R.Tensor((1, 3, 4), "float32") = R.min(x, axis=1)
        return gv

    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.min(x, axis=1))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])
+
+
def test_prod():
    """Parse `R.prod` with a single integer axis and compare to a BlockBuilder reference."""

    @R.function
    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"):
        gv: R.Tensor((1, 3, 4), "float32") = R.prod(x, axis=1)
        return gv

    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.prod(x, axis=1))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])
+
+
def test_std():
    """Parse `R.std` with a single integer axis and compare to a BlockBuilder reference."""

    @R.function
    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"):
        gv: R.Tensor((1, 3, 4), "float32") = R.std(x, axis=1)
        return gv

    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.std(x, axis=1))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()


Reply via email to