This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/play-with-layer-norm
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ce5e9ae637a44bcea108edb90c03501d1a550238
Author: Andrew Zhao Luo <[email protected]>
AuthorDate: Wed Aug 24 13:28:20 2022 -0700

    stash work
---
 src/relay/op/nn/nn.cc                      |  21 ++
 src/relay/op/tensor/reduce.cc              | 285 +--------------------------
 src/relay/op/tensor/reduce.h               | 303 +++++++++++++++++++++++++++++
 src/relay/transforms/simplify_inference.cc |   9 +-
 4 files changed, 329 insertions(+), 289 deletions(-)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 9e73c64564..4e5aaf7d67 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -35,12 +35,15 @@
 #include <tvm/topi/nn/softmax.h>
 
 #include <algorithm>
+#include <limits>
+#include <numeric>
 #include <string>
 #include <vector>
 
 #include "../../transforms/infer_layout_utils.h"
 #include "../make_op.h"
 #include "../op_common.h"
+#include "../tensor/reduce.h"
 #include "../type_relations.h"
 
 namespace tvm {
@@ -976,6 +979,22 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int 
axis, double epsilon, b
 
 
TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm").set_body_typed(MakeLayerNorm);
 
+/*!
+ * \brief Work-in-progress compute for nn.layer_norm: averages inputs[0] over the reduce axes.
+ *
+ * Multiplies together the extents of the axes returned by GetReduceAxes, sums inputs[0]
+ * through ReduceCompute with topi::sum, and divides that sum by the element count.
+ *
+ * NOTE(review): the attrs are read as ReduceAttrs, while the nn.layer_norm registration
+ * instantiates NormalizationInferCorrectLayout with LayerNormAttrs; if the op really carries
+ * LayerNormAttrs the ICHECK below will fire -- confirm before relying on this.
+ * NOTE(review): gamma, beta and epsilon are not used here, so the result is a mean rather
+ * than a full layer normalization -- presumably unfinished.
+ *
+ * \param attrs Operator attributes; must be castable to ReduceAttrs.
+ * \param inputs Input tensors; only inputs[0] is read.
+ * \param out_type Output type, forwarded to ReduceCompute.
+ * \return A single-element array holding the sum divided by the element count.
+ */
+Array<te::Tensor> LayerNormCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                   const Type& out_type) {
+  IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
+  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
+  ICHECK(param != nullptr);
+  for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), param->axis, param->exclude)) {
+    count *= inputs[0]->shape[i];
+  }
+  // Although count is created as inputs[0]->dtype,
+  // its type may be changed (promoted) during multiplication
+  count = cast(inputs[0]->dtype, count);
+  auto res = ReduceCompute(attrs, inputs, out_type, topi::sum);
+  return {topi::divide(res[0], count)};
+}
+
 RELAY_REGISTER_OP("nn.layer_norm")
     .describe(R"code(
 )code" TVM_ADD_FILELINE)
@@ -986,6 +1005,8 @@ RELAY_REGISTER_OP("nn.layer_norm")
     .add_argument("beta", "Tensor", "The beta offset factor.")
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                    
NormalizationInferCorrectLayout<LayerNormAttrs>)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise)
+    .set_attr<FTVMCompute>("FTVMCompute", LayerNormCompute)
     .set_support_level(1)
     .add_type_rel("LayerNorm", LayerNormRel);
 
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index 2b1afc6e55..37e884e5f5 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -21,6 +21,8 @@
  * \file reduce.cc
  * \brief Reduction operators.
  */
+#include "reduce.h"
+
 #include <tvm/relay/attrs/reduce.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
@@ -41,289 +43,6 @@ TVM_REGISTER_NODE_TYPE(ReduceAttrs);
 TVM_REGISTER_NODE_TYPE(ArgReduceAttrs);
 TVM_REGISTER_NODE_TYPE(VarianceAttrs);
 
-/*!
- * \brief GetReduceAxes, get the new axis from indim and other arguments
- * \param indim Number of dimensions of input data.
- * \param axis The input axis vector.
- * \param exclude Whether 'axis' input given is the excluded axis.
- * \return r_axes The new reduced axes of the output.
- */
-inline std::vector<int64_t> GetReduceAxes(const uint32_t indim, const 
Array<Integer>& inaxis,
-                                          bool exclude) {
-  if (!inaxis.defined() || inaxis.empty()) {
-    std::vector<int64_t> r_axes(indim);
-    std::iota(r_axes.begin(), r_axes.end(), 0);
-    return r_axes;
-  }
-
-  std::vector<int64_t> in_axes;
-  for (auto i : inaxis) {
-    int64_t axis = i->value;
-    if (axis < 0) {
-      axis = axis + indim;
-    }
-
-    // Check out of bounds error
-    ICHECK(axis >= 0) << "Axis out of bounds in reduce operator.";
-    ICHECK(axis < indim) << "Axis out of bounds in reduce operator.";
-    in_axes.push_back(axis);
-  }
-
-  ICHECK(in_axes[in_axes.size() - 1] < indim)
-      << "Reduction axis " << in_axes[in_axes.size() - 1] << " exceeds input 
dimensions " << indim;
-
-  std::sort(in_axes.begin(), in_axes.end());
-
-  if (!exclude) {
-    return in_axes;
-  }
-
-  auto r_size = indim - in_axes.size();
-  std::vector<int64_t> r_axes(r_size);
-  for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) {
-    if (j < in_axes.size() && in_axes[j] == i) {
-      ++j;
-      continue;
-    }
-    r_axes[k++] = i;
-  }
-  return r_axes;
-}
-
-// Get axis under exclude condition.
-Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
-  ICHECK(inaxis.defined()) << "Cannot set exclude when axis=None";
-  std::vector<bool> axis_flag(indim, true);
-  for (auto i : inaxis) {
-    int64_t axis = i->value;
-    if (axis < 0) {
-      axis = axis + static_cast<int64_t>(indim);
-    }
-    // Check out of bounds error
-    ICHECK_GE(axis, 0) << "Axis out of bounds in reduce operator.";
-    ICHECK_LT(axis, static_cast<int64_t>(indim)) << "Axis out of bounds in 
reduce operator.";
-    axis_flag[axis] = false;
-  }
-
-  Array<Integer> r_axes;
-
-  for (size_t i = 0; i < axis_flag.size(); ++i) {
-    if (axis_flag[i]) {
-      r_axes.push_back(static_cast<int>(i));
-    }
-  }
-  return r_axes;
-}
-
-// Return the modified layout for AlterOpLayout pass.
-template <typename T>
-InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
-                                                  const Array<Layout>& 
new_in_layouts,
-                                                  const Array<Layout>& 
old_in_layouts,
-                                                  const 
Array<tvm::relay::Type>& old_in_types) {
-  const auto* attrs_ptr = attrs.as<T>();
-  ICHECK(attrs_ptr);
-  ObjectPtr<T> params = make_object<T>(*attrs_ptr);
-
-  // Get the reduce axes.
-  Array<Array<IndexExpr>> old_in_shapes;
-  for (auto old_in_t : old_in_types) {
-    ICHECK(old_in_t.as<TensorTypeNode>());
-    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
-  }
-  uint32_t indim = old_in_shapes[0].size();
-  auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
-
-  Layout inferred_in = Layout::Undef();
-  Layout inferred_out = Layout::Undef();
-
-  // Infer [in_layout, out_layout, new_r_axes] from old_in_layout or 
new_in_layout
-  auto infer = [&](const Layout& layout) {
-    // 1) Collect the original axes
-    std::unordered_set<std::string> old_r_dims;
-    for (auto r_axis : r_axes) {
-      old_r_dims.emplace(old_in_layouts[0][r_axis].name());
-    }
-
-    // 2) Collect the new axes by walking new_layout.
-    tvm::Array<tvm::Integer> new_r_axes;
-    std::string inferred_in_string = "";
-    std::string inferred_out_string = "";
-    auto push_new_axis = [&](const std::string& layout_dim, int axis) {
-      if ((old_r_dims.count(layout_dim) && !params->exclude) ||
-          (!old_r_dims.count(layout_dim) && params->exclude)) {
-        new_r_axes.push_back(tvm::Integer(axis));
-        return true;
-      }
-      return false;
-    };
-    for (size_t axis_index = 0; axis_index < layout->axes.size(); 
++axis_index) {
-      const auto& layout_axis = LayoutAxis::Get(layout->axes[axis_index]);
-      const std::string& layout_dim = layout_axis.name();
-      if (layout_axis.IsPrimal()) {
-        push_new_axis(layout_dim, axis_index);
-        inferred_in_string += layout_dim;
-        if (!old_r_dims.count(layout_dim) || params->keepdims) {
-          inferred_out_string += layout_dim;
-        }
-      } else {
-        // For example, if the original layout is NCHW, the new layout is 
NCHW8c, and the original
-        // reduce axes is [1], the new reduce axes become [1, 4].
-        auto primal_dim = layout_axis.ToPrimal().name();
-        auto packed_dim = std::to_string(layout.FactorOf(layout_axis)) + 
layout_dim;
-        inferred_in_string += packed_dim;
-        if (push_new_axis(primal_dim, axis_index)) {
-          if (params->exclude) {
-            // The primal axis is not reduced, so keep the input packed dim.
-            inferred_out_string += packed_dim;
-          } else if (params->keepdims) {
-            // If the primal axis is part of reduce axes in the original 
layout, the inner dim
-            // becomes 1 after reduction.
-            inferred_out_string += "1" + layout_dim;
-          }
-        } else {
-          inferred_out_string += packed_dim;
-        }
-      }
-    }
-
-    // 3) Set the new axis and layout.
-    return std::make_tuple(Layout(inferred_in_string), 
Layout(inferred_out_string), new_r_axes);
-  };
-
-  std::string new_layout_string;
-  Array<Integer> new_r_axes;
-  Array<Layout> new_input_layouts;
-
-  auto check_num_input_layouts = [](Array<Layout> in_layouts) {
-    // The second case is for variance op
-    ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2);
-  };
-
-  if (new_in_layouts.defined() && r_axes.size()) {
-    // Adapt to new layout. The axis has to change. Record original reduce 
axes. Convert to the
-    // modified layout axes.
-    check_num_input_layouts(new_in_layouts);
-    check_num_input_layouts(old_in_layouts);
-
-    // Get inferred_in and inferred_out from new_in_layout.
-    std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
-    params->axis = new_r_axes;
-  } else if (old_in_layouts.defined()) {
-    check_num_input_layouts(old_in_layouts);
-
-    // If the new layout is undefined, get inferred_in and inferred_out from 
old_in_layout.
-    if (old_in_layouts[0].defined()) {
-      std::tie(inferred_in, inferred_out, std::ignore) = 
infer(old_in_layouts[0]);
-    }
-  }
-
-  new_input_layouts.push_back(inferred_in);
-
-  if (old_in_layouts.size() == 2) {
-    new_input_layouts.push_back(inferred_in);
-  }
-
-  return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, 
Attrs(params));
-}
-
-template <typename F>
-Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& 
inputs,
-                                const Type& out_type, F f) {
-  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
-  ICHECK(param != nullptr);
-  if (inputs[0]->shape.size() == 0) {
-    return {topi::identity(inputs[0])};
-  }
-  auto axes = param->axis;
-  if (param->exclude) {
-    axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
-    if (axes.size() == 0) {
-      return {topi::identity(inputs[0])};
-    }
-  }
-
-  return {f(inputs[0], axes, param->keepdims, false)};
-}
-
-template <typename F>
-Array<te::Tensor> ArgReduceCompute(const Attrs& attrs, const 
Array<te::Tensor>& inputs,
-                                   const Type& out_type, F f) {
-  const ArgReduceAttrs* param = attrs.as<ArgReduceAttrs>();
-  ICHECK(param != nullptr);
-  if (inputs[0]->shape.size() == 0) {
-    return {topi::identity(inputs[0])};
-  }
-  auto axes = param->axis;
-  if (param->exclude) {
-    axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
-    if (axes.size() == 0) {
-      return {topi::identity(inputs[0])};
-    }
-  }
-
-  return {f(inputs[0], axes, param->keepdims, false, 
param->select_last_index)};
-}
-
-/*!
- * \brief ReduceShapeImpl get the outshape for the reduction operator
- * \param in_shape Shape of input data.
- * \param param Attrs details.
- * \param reporter The reporter to report solution to.
- * \return oshape Output shape inferred.
- * \tparam AttrsType The attribute type.
- */
-template <typename AttrsType>
-inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& 
in_shape,
-                                              const AttrsType* param,
-                                              const TypeReporter& reporter) {
-  uint32_t indim = in_shape.size();
-  auto r_axes = GetReduceAxes(indim, param->axis, param->exclude);
-  if (!r_axes.size()) {
-    return in_shape;
-  }
-
-  auto max_shape = tir::make_const(DataType::Int(64), 1);
-  bool is_dynamic_input = false;
-  for (int64_t axis : r_axes) {
-    if (in_shape[axis].as<IntImmNode>()) {
-      max_shape *= in_shape[axis];
-    } else {
-      is_dynamic_input = true;
-      break;
-    }
-  }
-
-  if (is_dynamic_input) {
-    ICHECK(reporter->Assert(
-        max_shape < tir::make_const(DataType::Int(64), 
std::numeric_limits<int32_t>::max())))
-        << "The maximum possible index of reduced shape cannot be more than 
int32 max.";
-  }
-
-  if (param->keepdims) {
-    std::vector<IndexExpr> oshape(in_shape);
-    for (unsigned i = 0, j = 0; i < indim; ++i) {
-      if (j >= r_axes.size() || !(r_axes[j] == i)) {
-        continue;
-      }
-      oshape[i] = 1;
-      ++j;
-    }
-    return oshape;
-  } else {
-    auto osize = indim - r_axes.size();
-    std::vector<IndexExpr> oshape(osize);
-    for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
-      if (j < r_axes.size() && (r_axes[j] == i)) {
-        ++j;
-        continue;
-      }
-      oshape[k++] = in_shape[i];
-    }
-    return oshape;
-  }
-}
-
 template <class T>
 bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
                       const TypeReporter& reporter) {
diff --git a/src/relay/op/tensor/reduce.h b/src/relay/op/tensor/reduce.h
new file mode 100644
index 0000000000..54b0f3951d
--- /dev/null
+++ b/src/relay/op/tensor/reduce.h
@@ -0,0 +1,303 @@
+#include <tvm/relay/attrs/reduce.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/topi/elemwise.h>
+#include <tvm/topi/reduction.h>
+
+#include <algorithm>
+#include <limits>
+#include <numeric>
+#include <string>
+#include <tuple>
+#include <unordered_set>
+#include <vector>
+
+#include "../make_op.h"
+#include "../op_common.h"
+#include "../type_relations.h"
+
+#ifndef TVM_RELAY_OP_TENSOR_REDUCE_H
+#define TVM_RELAY_OP_TENSOR_REDUCE_H
+
+namespace tvm {
+namespace relay {
+/*!
+ * \brief GetReduceAxes, get the new axis from indim and other arguments
+ * \param indim Number of dimensions of input data.
+ * \param inaxis The input axis vector; negative entries are offset by indim.
+ * \param exclude Whether 'inaxis' lists the axes to leave out of the reduction.
+ * \return r_axes The reduced axes, sorted ascending with duplicates removed.
+ */
+inline std::vector<int64_t> GetReduceAxes(const uint32_t indim, const Array<Integer>& inaxis,
+                                          bool exclude) {
+  // No axis given: every dimension is reduced.
+  if (!inaxis.defined() || inaxis.empty()) {
+    std::vector<int64_t> r_axes(indim);
+    std::iota(r_axes.begin(), r_axes.end(), 0);
+    return r_axes;
+  }
+
+  std::vector<int64_t> in_axes;
+  in_axes.reserve(inaxis.size());
+  for (auto i : inaxis) {
+    int64_t axis = i->value;
+    if (axis < 0) {
+      axis = axis + indim;
+    }
+
+    // Check out of bounds error
+    ICHECK(axis >= 0) << "Axis out of bounds in reduce operator.";
+    ICHECK(axis < indim) << "Axis out of bounds in reduce operator.";
+    in_axes.push_back(axis);
+  }
+
+  // Sort, then drop repeated axes. Callers size their outputs as
+  // indim - r_axes.size(), so a repeated axis (e.g. [1, 1]) would make that
+  // count too small and lead to writes past the end of the output vector.
+  std::sort(in_axes.begin(), in_axes.end());
+  in_axes.erase(std::unique(in_axes.begin(), in_axes.end()), in_axes.end());
+
+  if (!exclude) {
+    return in_axes;
+  }
+
+  // Exclude mode: return every dimension that is not listed in in_axes.
+  std::vector<int64_t> r_axes;
+  r_axes.reserve(indim - in_axes.size());
+  for (uint32_t i = 0, j = 0; i < indim; ++i) {
+    if (j < in_axes.size() && in_axes[j] == i) {
+      ++j;
+      continue;
+    }
+    r_axes.push_back(i);
+  }
+  return r_axes;
+}
+
+/*!
+ * \brief Get axis under exclude condition.
+ *
+ * Declared inline because this header is included from more than one translation unit
+ * (reduce.cc and nn/nn.cc); a non-inline definition here would be emitted in each of them
+ * and clash at link time.
+ *
+ * \param indim Number of dimensions of input data.
+ * \param inaxis The axes to exclude; negative entries are offset by indim.
+ * \return The axes in [0, indim) that are not listed in inaxis, in ascending order.
+ */
+inline Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
+  ICHECK(inaxis.defined()) << "Cannot set exclude when axis=None";
+  // axis_flag[d] stays true while dimension d has not been listed in inaxis.
+  std::vector<bool> axis_flag(indim, true);
+  for (auto i : inaxis) {
+    int64_t axis = i->value;
+    if (axis < 0) {
+      axis = axis + static_cast<int64_t>(indim);
+    }
+    // Check out of bounds error
+    ICHECK_GE(axis, 0) << "Axis out of bounds in reduce operator.";
+    ICHECK_LT(axis, static_cast<int64_t>(indim)) << "Axis out of bounds in reduce operator.";
+    axis_flag[axis] = false;
+  }
+
+  Array<Integer> r_axes;
+
+  for (size_t i = 0; i < axis_flag.size(); ++i) {
+    if (axis_flag[i]) {
+      r_axes.push_back(static_cast<int>(i));
+    }
+  }
+  return r_axes;
+}
+
+/*!
+ * \brief Return the modified layout for AlterOpLayout pass.
+ *
+ * Copies the attrs, works out the reduce axes from the old input shape, and walks either the
+ * new or the old input layout to build the inferred input layout, the inferred output layout
+ * and (when a new layout is given) the reduce axes expressed in that new layout.
+ *
+ * \tparam T The attribute type; its axis, exclude and keepdims fields are read.
+ * \param attrs The operator attributes, which must be of type T.
+ * \param new_in_layouts The new input layouts; may be undefined.
+ * \param old_in_layouts The old input layouts (one entry, or two for the variance op).
+ * \param old_in_types The old input types; each must be a tensor type.
+ * \return The inferred input layouts, the inferred output layout and the copied attrs
+ *         (with axis rewritten when the new-layout branch is taken).
+ */
+template <typename T>
+InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
+                                                  const Array<Layout>& new_in_layouts,
+                                                  const Array<Layout>& old_in_layouts,
+                                                  const Array<tvm::relay::Type>& old_in_types) {
+  const auto* attrs_ptr = attrs.as<T>();
+  ICHECK(attrs_ptr);
+  ObjectPtr<T> params = make_object<T>(*attrs_ptr);
+
+  // Get the reduce axes.
+  Array<Array<IndexExpr>> old_in_shapes;
+  for (auto old_in_t : old_in_types) {
+    ICHECK(old_in_t.as<TensorTypeNode>());
+    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+  }
+  uint32_t indim = old_in_shapes[0].size();
+  auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
+
+  Layout inferred_in = Layout::Undef();
+  Layout inferred_out = Layout::Undef();
+
+  // Infer [in_layout, out_layout, new_r_axes] from old_in_layout or new_in_layout
+  auto infer = [&](const Layout& layout) {
+    // 1) Collect the original axes
+    std::unordered_set<std::string> old_r_dims;
+    for (auto r_axis : r_axes) {
+      old_r_dims.emplace(old_in_layouts[0][r_axis].name());
+    }
+
+    // 2) Collect the new axes by walking new_layout.
+    tvm::Array<tvm::Integer> new_r_axes;
+    std::string inferred_in_string = "";
+    std::string inferred_out_string = "";
+    // Records `axis` as a reduce axis when layout_dim is (exclude == false) or is not
+    // (exclude == true) one of the collected dims; returns whether it was recorded.
+    auto push_new_axis = [&](const std::string& layout_dim, int axis) {
+      if ((old_r_dims.count(layout_dim) && !params->exclude) ||
+          (!old_r_dims.count(layout_dim) && params->exclude)) {
+        new_r_axes.push_back(tvm::Integer(axis));
+        return true;
+      }
+      return false;
+    };
+    for (size_t axis_index = 0; axis_index < layout->axes.size(); ++axis_index) {
+      const auto& layout_axis = LayoutAxis::Get(layout->axes[axis_index]);
+      const std::string& layout_dim = layout_axis.name();
+      if (layout_axis.IsPrimal()) {
+        push_new_axis(layout_dim, axis_index);
+        inferred_in_string += layout_dim;
+        if (!old_r_dims.count(layout_dim) || params->keepdims) {
+          inferred_out_string += layout_dim;
+        }
+      } else {
+        // For example, if the original layout is NCHW, the new layout is NCHW8c, and the original
+        // reduce axes is [1], the new reduce axes become [1, 4].
+        auto primal_dim = layout_axis.ToPrimal().name();
+        auto packed_dim = std::to_string(layout.FactorOf(layout_axis)) + layout_dim;
+        inferred_in_string += packed_dim;
+        if (push_new_axis(primal_dim, axis_index)) {
+          if (params->exclude) {
+            // The primal axis is not reduced, so keep the input packed dim.
+            inferred_out_string += packed_dim;
+          } else if (params->keepdims) {
+            // If the primal axis is part of reduce axes in the original layout, the inner dim
+            // becomes 1 after reduction.
+            inferred_out_string += "1" + layout_dim;
+          }
+        } else {
+          inferred_out_string += packed_dim;
+        }
+      }
+    }
+
+    // 3) Set the new axis and layout.
+    return std::make_tuple(Layout(inferred_in_string), Layout(inferred_out_string), new_r_axes);
+  };
+
+  Array<Integer> new_r_axes;
+  Array<Layout> new_input_layouts;
+
+  auto check_num_input_layouts = [](const Array<Layout>& in_layouts) {
+    // The second case is for variance op
+    ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2);
+  };
+
+  if (new_in_layouts.defined() && r_axes.size()) {
+    // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
+    // modified layout axes.
+    check_num_input_layouts(new_in_layouts);
+    check_num_input_layouts(old_in_layouts);
+
+    // Get inferred_in and inferred_out from new_in_layout.
+    std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
+    params->axis = new_r_axes;
+  } else if (old_in_layouts.defined()) {
+    check_num_input_layouts(old_in_layouts);
+
+    // If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
+    if (old_in_layouts[0].defined()) {
+      std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
+    }
+  }
+
+  new_input_layouts.push_back(inferred_in);
+
+  // A second input (variance op) gets the same inferred layout as the first.
+  if (old_in_layouts.size() == 2) {
+    new_input_layouts.push_back(inferred_in);
+  }
+
+  return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, Attrs(params));
+}
+
+/*!
+ * \brief Apply the reduction functor f to inputs[0] as described by ReduceAttrs.
+ *
+ * A zero-dimensional input, or an exclude list that leaves no axis to reduce, is passed
+ * through topi::identity instead of calling f.
+ *
+ * \tparam F Callable invoked as f(tensor, axes, keepdims, false).
+ * \param attrs Operator attributes; must be ReduceAttrs.
+ * \param inputs Input tensors; only inputs[0] is read.
+ * \param out_type Output type (not used by this function).
+ * \param f The reduction to apply.
+ * \return A single-element array holding the result.
+ */
+template <typename F>
+Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                const Type& out_type, F f) {
+  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
+  ICHECK(param != nullptr);
+
+  const te::Tensor data = inputs[0];
+  const size_t ndim = data->shape.size();
+  if (ndim == 0) {
+    return {topi::identity(data)};
+  }
+
+  // With exclude set, the listed axes are the ones to keep; reduce over the rest.
+  Array<Integer> reduce_axes =
+      param->exclude ? GetExcludeAxes(ndim, param->axis) : param->axis;
+  if (param->exclude && reduce_axes.size() == 0) {
+    return {topi::identity(data)};
+  }
+
+  return {f(data, reduce_axes, param->keepdims, false)};
+}
+
+/*!
+ * \brief Apply the arg-reduction functor f to inputs[0] as described by ArgReduceAttrs.
+ *
+ * A zero-dimensional input, or an exclude list that leaves no axis to reduce, is passed
+ * through topi::identity instead of calling f.
+ *
+ * \tparam F Callable invoked as f(tensor, axes, keepdims, false, select_last_index).
+ * \param attrs Operator attributes; must be ArgReduceAttrs.
+ * \param inputs Input tensors; only inputs[0] is read.
+ * \param out_type Output type (not used by this function).
+ * \param f The arg-reduction to apply.
+ * \return A single-element array holding the result.
+ */
+template <typename F>
+Array<te::Tensor> ArgReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                   const Type& out_type, F f) {
+  const ArgReduceAttrs* param = attrs.as<ArgReduceAttrs>();
+  ICHECK(param != nullptr);
+
+  const te::Tensor data = inputs[0];
+  const size_t ndim = data->shape.size();
+  if (ndim == 0) {
+    return {topi::identity(data)};
+  }
+
+  // With exclude set, the listed axes are the ones to keep; reduce over the rest.
+  Array<Integer> reduce_axes =
+      param->exclude ? GetExcludeAxes(ndim, param->axis) : param->axis;
+  if (param->exclude && reduce_axes.size() == 0) {
+    return {topi::identity(data)};
+  }
+
+  return {f(data, reduce_axes, param->keepdims, false, param->select_last_index)};
+}
+
+/*!
+ * \brief Infer the output shape of a reduction operator.
+ * \param in_shape Shape of input data.
+ * \param param Attrs details; axis, exclude and keepdims are read.
+ * \param reporter The reporter to report solution to.
+ * \return oshape Output shape inferred.
+ * \tparam AttrsType The attribute type.
+ */
+template <typename AttrsType>
+inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_shape,
+                                              const AttrsType* param,
+                                              const TypeReporter& reporter) {
+  const uint32_t indim = in_shape.size();
+  const auto r_axes = GetReduceAxes(indim, param->axis, param->exclude);
+  if (r_axes.empty()) {
+    return in_shape;
+  }
+
+  // Multiply the constant extents of the reduced axes, stopping at the first
+  // extent that is not an integer immediate.
+  auto max_shape = tir::make_const(DataType::Int(64), 1);
+  bool is_dynamic_input = false;
+  for (int64_t axis : r_axes) {
+    if (!in_shape[axis].as<IntImmNode>()) {
+      is_dynamic_input = true;
+      break;
+    }
+    max_shape *= in_shape[axis];
+  }
+
+  if (is_dynamic_input) {
+    const auto int32_limit =
+        tir::make_const(DataType::Int(64), std::numeric_limits<int32_t>::max());
+    ICHECK(reporter->Assert(max_shape < int32_limit))
+        << "The maximum possible index of reduced shape cannot be more than int32 max.";
+  }
+
+  if (param->keepdims) {
+    // Same rank as the input; each reduced dimension collapses to extent 1.
+    std::vector<IndexExpr> oshape(in_shape);
+    size_t next = 0;
+    for (uint32_t dim = 0; dim < indim && next < r_axes.size(); ++dim) {
+      if (r_axes[next] == dim) {
+        oshape[dim] = 1;
+        ++next;
+      }
+    }
+    return oshape;
+  }
+
+  // Reduced dimensions are dropped; the remaining extents keep their order.
+  std::vector<IndexExpr> oshape(indim - r_axes.size());
+  size_t next = 0;
+  size_t out_pos = 0;
+  for (uint32_t dim = 0; dim < indim; ++dim) {
+    const bool is_reduced = next < r_axes.size() && r_axes[next] == dim;
+    if (is_reduced) {
+      ++next;
+    } else {
+      oshape[out_pos++] = in_shape[dim];
+    }
+  }
+  return oshape;
+}
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_OP_TENSOR_REDUCE_H
diff --git a/src/relay/transforms/simplify_inference.cc 
b/src/relay/transforms/simplify_inference.cc
index e7eef41e41..a7d12740b0 100644
--- a/src/relay/transforms/simplify_inference.cc
+++ b/src/relay/transforms/simplify_inference.cc
@@ -115,6 +115,7 @@ Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, 
Expr gamma, Expr beta,
   return out;
 }
 
+/*
 Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr 
beta, Type tdata) {
   auto ttype = tdata.as<TensorTypeNode>();
   ICHECK(ttype);
@@ -137,6 +138,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, 
Expr gamma, Expr beta,
   }
   return out;
 }
+*/
 
 Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr 
beta, Type tdata) {
   auto ttype = tdata.as<TensorTypeNode>();
@@ -184,7 +186,7 @@ class InferenceSimplifier : public MixedModeMutator {
       : batch_norm_op_(Op::Get("nn.batch_norm")),
         dropout_op_(Op::Get("nn.dropout")),
         instance_norm_op_(Op::Get("nn.instance_norm")),
-        layer_norm_op_(Op::Get("nn.layer_norm")),
+        // layer_norm_op_(Op::Get("nn.layer_norm")),
         group_norm_op_(Op::Get("nn.group_norm")),
         l2_norm_op_(Op::Get("nn.l2_normalize")) {}
 
@@ -207,10 +209,6 @@ class InferenceSimplifier : public MixedModeMutator {
   Expr Rewrite_(const CallNode* n, const Expr& new_n) {
     if (n->op == batch_norm_op_) {
       ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
-    } else if (n->op == layer_norm_op_) {
-      const auto* call = new_n.as<CallNode>();
-      return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], 
call->args[2],
-                                    n->args[0]->checked_type());
     } else if (n->op == group_norm_op_) {
       const auto* call = new_n.as<CallNode>();
       return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], 
call->args[2],
@@ -233,7 +231,6 @@ class InferenceSimplifier : public MixedModeMutator {
   const Op& batch_norm_op_;
   const Op& dropout_op_;
   const Op& instance_norm_op_;
-  const Op& layer_norm_op_;
   const Op& group_norm_op_;
   const Op& l2_norm_op_;
   std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> ty_map_;

Reply via email to