You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/08/24 20:28:48 UTC

[tvm] branch aluo/play-with-layer-norm created (now ce5e9ae637)

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a change to branch aluo/play-with-layer-norm
in repository https://gitbox.apache.org/repos/asf/tvm.git


      at ce5e9ae637 stash work

This branch includes the following new commits:

     new ce5e9ae637 stash work

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[tvm] 01/01: stash work

Posted by an...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/play-with-layer-norm
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ce5e9ae637a44bcea108edb90c03501d1a550238
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Wed Aug 24 13:28:20 2022 -0700

    stash work
---
 src/relay/op/nn/nn.cc                      |  21 ++
 src/relay/op/tensor/reduce.cc              | 285 +--------------------------
 src/relay/op/tensor/reduce.h               | 303 +++++++++++++++++++++++++++++
 src/relay/transforms/simplify_inference.cc |   9 +-
 4 files changed, 329 insertions(+), 289 deletions(-)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 9e73c64564..4e5aaf7d67 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -35,12 +35,15 @@
 #include <tvm/topi/nn/softmax.h>
 
 #include <algorithm>
+#include <limits>
+#include <numeric>
 #include <string>
 #include <vector>
 
 #include "../../transforms/infer_layout_utils.h"
 #include "../make_op.h"
 #include "../op_common.h"
+#include "../tensor/reduce.h"
 #include "../type_relations.h"
 
 namespace tvm {
@@ -976,6 +979,22 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, b
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm").set_body_typed(MakeLayerNorm);
 
+/*!
+ * \brief FTVMCompute for nn.layer_norm.
+ *
+ * Computes out = (data - mean) / sqrt(var + epsilon), optionally multiplied by gamma
+ * (param->scale) and shifted by beta (param->center). mean and var are the population
+ * mean/variance over LayerNormAttrs::axis, kept as size-1 dims so they broadcast back
+ * over data.
+ *
+ * \param attrs LayerNormAttrs of the call.
+ * \param inputs {data, gamma, beta}; gamma/beta are assumed 1-D with extent data.shape[axis]
+ *        (TODO confirm against LayerNormRel).
+ * \param out_type Unused; the result has the same shape and dtype as data.
+ * \return A single tensor with the shape of data.
+ *
+ * NOTE(review): nn.layer_norm is registered with TOpPattern kElemWise, but this compute contains
+ * reductions; kOpaque (or kOutEWiseFusable) looks more appropriate — confirm before enabling.
+ */
+Array<te::Tensor> LayerNormCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                   const Type& out_type) {
+  // nn.layer_norm carries LayerNormAttrs; casting to ReduceAttrs always yields nullptr.
+  const auto* param = attrs.as<LayerNormAttrs>();
+  ICHECK(param != nullptr);
+  ICHECK_EQ(inputs.size(), 3U) << "layer_norm expects (data, gamma, beta)";
+  const te::Tensor data = inputs[0];
+  const te::Tensor gamma = inputs[1];
+  const te::Tensor beta = inputs[2];
+
+  const int ndim = static_cast<int>(data->shape.size());
+  const int axis = param->axis < 0 ? param->axis + ndim : param->axis;
+  ICHECK(axis >= 0 && axis < ndim) << "layer_norm axis " << param->axis
+                                   << " is out of bounds for rank " << ndim;
+  const Array<Integer> axes = {Integer(axis)};
+
+  // The element count is cast to data's dtype so the divisions below are not promoted to the
+  // index type.
+  const PrimExpr count = cast(data->dtype, data->shape[axis]);
+  const te::Tensor mean = topi::divide(topi::sum(data, axes, /*keepdims=*/true), count);
+  const te::Tensor centered = topi::subtract(data, mean);
+  const te::Tensor var =
+      topi::divide(topi::sum(topi::multiply(centered, centered), axes, /*keepdims=*/true), count);
+  const PrimExpr epsilon = tir::make_const(data->dtype, param->epsilon);
+  te::Tensor out = topi::divide(centered, topi::sqrt(topi::add(var, epsilon)));
+
+  // Reshape the 1-D gamma/beta to [1, ..., C, ..., 1] so they broadcast along `axis` only.
+  Array<PrimExpr> bcast_shape;
+  for (int i = 0; i < ndim; ++i) {
+    bcast_shape.push_back(i == axis ? data->shape[axis]
+                                    : tir::make_const(data->shape[axis].dtype(), 1));
+  }
+  if (param->scale) {
+    out = topi::multiply(out, topi::reshape(gamma, bcast_shape));
+  }
+  if (param->center) {
+    out = topi::add(out, topi::reshape(beta, bcast_shape));
+  }
+  return {out};
+}
+
 RELAY_REGISTER_OP("nn.layer_norm")
     .describe(R"code(
 )code" TVM_ADD_FILELINE)
@@ -986,6 +1005,8 @@ RELAY_REGISTER_OP("nn.layer_norm")
     .add_argument("beta", "Tensor", "The beta offset factor.")
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                    NormalizationInferCorrectLayout<LayerNormAttrs>)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise)
+    .set_attr<FTVMCompute>("FTVMCompute", LayerNormCompute)
     .set_support_level(1)
     .add_type_rel("LayerNorm", LayerNormRel);
 
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index 2b1afc6e55..37e884e5f5 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -21,6 +21,8 @@
  * \file reduce.cc
  * \brief Reduction operators.
  */
+#include "reduce.h"
+
 #include <tvm/relay/attrs/reduce.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
@@ -41,289 +43,6 @@ TVM_REGISTER_NODE_TYPE(ReduceAttrs);
 TVM_REGISTER_NODE_TYPE(ArgReduceAttrs);
 TVM_REGISTER_NODE_TYPE(VarianceAttrs);
 
-/*!
- * \brief GetReduceAxes, get the new axis from indim and other arguments
- * \param indim Number of dimensions of input data.
- * \param axis The input axis vector.
- * \param exclude Whether 'axis' input given is the excluded axis.
- * \return r_axes The new reduced axes of the output.
- */
-inline std::vector<int64_t> GetReduceAxes(const uint32_t indim, const Array<Integer>& inaxis,
-                                          bool exclude) {
-  if (!inaxis.defined() || inaxis.empty()) {
-    std::vector<int64_t> r_axes(indim);
-    std::iota(r_axes.begin(), r_axes.end(), 0);
-    return r_axes;
-  }
-
-  std::vector<int64_t> in_axes;
-  for (auto i : inaxis) {
-    int64_t axis = i->value;
-    if (axis < 0) {
-      axis = axis + indim;
-    }
-
-    // Check out of bounds error
-    ICHECK(axis >= 0) << "Axis out of bounds in reduce operator.";
-    ICHECK(axis < indim) << "Axis out of bounds in reduce operator.";
-    in_axes.push_back(axis);
-  }
-
-  ICHECK(in_axes[in_axes.size() - 1] < indim)
-      << "Reduction axis " << in_axes[in_axes.size() - 1] << " exceeds input dimensions " << indim;
-
-  std::sort(in_axes.begin(), in_axes.end());
-
-  if (!exclude) {
-    return in_axes;
-  }
-
-  auto r_size = indim - in_axes.size();
-  std::vector<int64_t> r_axes(r_size);
-  for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) {
-    if (j < in_axes.size() && in_axes[j] == i) {
-      ++j;
-      continue;
-    }
-    r_axes[k++] = i;
-  }
-  return r_axes;
-}
-
-// Get axis under exclude condition.
-Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
-  ICHECK(inaxis.defined()) << "Cannot set exclude when axis=None";
-  std::vector<bool> axis_flag(indim, true);
-  for (auto i : inaxis) {
-    int64_t axis = i->value;
-    if (axis < 0) {
-      axis = axis + static_cast<int64_t>(indim);
-    }
-    // Check out of bounds error
-    ICHECK_GE(axis, 0) << "Axis out of bounds in reduce operator.";
-    ICHECK_LT(axis, static_cast<int64_t>(indim)) << "Axis out of bounds in reduce operator.";
-    axis_flag[axis] = false;
-  }
-
-  Array<Integer> r_axes;
-
-  for (size_t i = 0; i < axis_flag.size(); ++i) {
-    if (axis_flag[i]) {
-      r_axes.push_back(static_cast<int>(i));
-    }
-  }
-  return r_axes;
-}
-
-// Return the modified layout for AlterOpLayout pass.
-template <typename T>
-InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
-                                                  const Array<Layout>& new_in_layouts,
-                                                  const Array<Layout>& old_in_layouts,
-                                                  const Array<tvm::relay::Type>& old_in_types) {
-  const auto* attrs_ptr = attrs.as<T>();
-  ICHECK(attrs_ptr);
-  ObjectPtr<T> params = make_object<T>(*attrs_ptr);
-
-  // Get the reduce axes.
-  Array<Array<IndexExpr>> old_in_shapes;
-  for (auto old_in_t : old_in_types) {
-    ICHECK(old_in_t.as<TensorTypeNode>());
-    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
-  }
-  uint32_t indim = old_in_shapes[0].size();
-  auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
-
-  Layout inferred_in = Layout::Undef();
-  Layout inferred_out = Layout::Undef();
-
-  // Infer [in_layout, out_layout, new_r_axes] from old_in_layout or new_in_layout
-  auto infer = [&](const Layout& layout) {
-    // 1) Collect the original axes
-    std::unordered_set<std::string> old_r_dims;
-    for (auto r_axis : r_axes) {
-      old_r_dims.emplace(old_in_layouts[0][r_axis].name());
-    }
-
-    // 2) Collect the new axes by walking new_layout.
-    tvm::Array<tvm::Integer> new_r_axes;
-    std::string inferred_in_string = "";
-    std::string inferred_out_string = "";
-    auto push_new_axis = [&](const std::string& layout_dim, int axis) {
-      if ((old_r_dims.count(layout_dim) && !params->exclude) ||
-          (!old_r_dims.count(layout_dim) && params->exclude)) {
-        new_r_axes.push_back(tvm::Integer(axis));
-        return true;
-      }
-      return false;
-    };
-    for (size_t axis_index = 0; axis_index < layout->axes.size(); ++axis_index) {
-      const auto& layout_axis = LayoutAxis::Get(layout->axes[axis_index]);
-      const std::string& layout_dim = layout_axis.name();
-      if (layout_axis.IsPrimal()) {
-        push_new_axis(layout_dim, axis_index);
-        inferred_in_string += layout_dim;
-        if (!old_r_dims.count(layout_dim) || params->keepdims) {
-          inferred_out_string += layout_dim;
-        }
-      } else {
-        // For example, if the original layout is NCHW, the new layout is NCHW8c, and the original
-        // reduce axes is [1], the new reduce axes become [1, 4].
-        auto primal_dim = layout_axis.ToPrimal().name();
-        auto packed_dim = std::to_string(layout.FactorOf(layout_axis)) + layout_dim;
-        inferred_in_string += packed_dim;
-        if (push_new_axis(primal_dim, axis_index)) {
-          if (params->exclude) {
-            // The primal axis is not reduced, so keep the input packed dim.
-            inferred_out_string += packed_dim;
-          } else if (params->keepdims) {
-            // If the primal axis is part of reduce axes in the original layout, the inner dim
-            // becomes 1 after reduction.
-            inferred_out_string += "1" + layout_dim;
-          }
-        } else {
-          inferred_out_string += packed_dim;
-        }
-      }
-    }
-
-    // 3) Set the new axis and layout.
-    return std::make_tuple(Layout(inferred_in_string), Layout(inferred_out_string), new_r_axes);
-  };
-
-  std::string new_layout_string;
-  Array<Integer> new_r_axes;
-  Array<Layout> new_input_layouts;
-
-  auto check_num_input_layouts = [](Array<Layout> in_layouts) {
-    // The second case is for variance op
-    ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2);
-  };
-
-  if (new_in_layouts.defined() && r_axes.size()) {
-    // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
-    // modified layout axes.
-    check_num_input_layouts(new_in_layouts);
-    check_num_input_layouts(old_in_layouts);
-
-    // Get inferred_in and inferred_out from new_in_layout.
-    std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
-    params->axis = new_r_axes;
-  } else if (old_in_layouts.defined()) {
-    check_num_input_layouts(old_in_layouts);
-
-    // If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
-    if (old_in_layouts[0].defined()) {
-      std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
-    }
-  }
-
-  new_input_layouts.push_back(inferred_in);
-
-  if (old_in_layouts.size() == 2) {
-    new_input_layouts.push_back(inferred_in);
-  }
-
-  return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, Attrs(params));
-}
-
-template <typename F>
-Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
-                                const Type& out_type, F f) {
-  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
-  ICHECK(param != nullptr);
-  if (inputs[0]->shape.size() == 0) {
-    return {topi::identity(inputs[0])};
-  }
-  auto axes = param->axis;
-  if (param->exclude) {
-    axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
-    if (axes.size() == 0) {
-      return {topi::identity(inputs[0])};
-    }
-  }
-
-  return {f(inputs[0], axes, param->keepdims, false)};
-}
-
-template <typename F>
-Array<te::Tensor> ArgReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
-                                   const Type& out_type, F f) {
-  const ArgReduceAttrs* param = attrs.as<ArgReduceAttrs>();
-  ICHECK(param != nullptr);
-  if (inputs[0]->shape.size() == 0) {
-    return {topi::identity(inputs[0])};
-  }
-  auto axes = param->axis;
-  if (param->exclude) {
-    axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
-    if (axes.size() == 0) {
-      return {topi::identity(inputs[0])};
-    }
-  }
-
-  return {f(inputs[0], axes, param->keepdims, false, param->select_last_index)};
-}
-
-/*!
- * \brief ReduceShapeImpl get the outshape for the reduction operator
- * \param in_shape Shape of input data.
- * \param param Attrs details.
- * \param reporter The reporter to report solution to.
- * \return oshape Output shape inferred.
- * \tparam AttrsType The attribute type.
- */
-template <typename AttrsType>
-inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_shape,
-                                              const AttrsType* param,
-                                              const TypeReporter& reporter) {
-  uint32_t indim = in_shape.size();
-  auto r_axes = GetReduceAxes(indim, param->axis, param->exclude);
-  if (!r_axes.size()) {
-    return in_shape;
-  }
-
-  auto max_shape = tir::make_const(DataType::Int(64), 1);
-  bool is_dynamic_input = false;
-  for (int64_t axis : r_axes) {
-    if (in_shape[axis].as<IntImmNode>()) {
-      max_shape *= in_shape[axis];
-    } else {
-      is_dynamic_input = true;
-      break;
-    }
-  }
-
-  if (is_dynamic_input) {
-    ICHECK(reporter->Assert(
-        max_shape < tir::make_const(DataType::Int(64), std::numeric_limits<int32_t>::max())))
-        << "The maximum possible index of reduced shape cannot be more than int32 max.";
-  }
-
-  if (param->keepdims) {
-    std::vector<IndexExpr> oshape(in_shape);
-    for (unsigned i = 0, j = 0; i < indim; ++i) {
-      if (j >= r_axes.size() || !(r_axes[j] == i)) {
-        continue;
-      }
-      oshape[i] = 1;
-      ++j;
-    }
-    return oshape;
-  } else {
-    auto osize = indim - r_axes.size();
-    std::vector<IndexExpr> oshape(osize);
-    for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
-      if (j < r_axes.size() && (r_axes[j] == i)) {
-        ++j;
-        continue;
-      }
-      oshape[k++] = in_shape[i];
-    }
-    return oshape;
-  }
-}
-
 template <class T>
 bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                       const TypeReporter& reporter) {
diff --git a/src/relay/op/tensor/reduce.h b/src/relay/op/tensor/reduce.h
new file mode 100644
index 0000000000..54b0f3951d
--- /dev/null
+++ b/src/relay/op/tensor/reduce.h
@@ -0,0 +1,303 @@
+#ifndef TVM_RELAY_OP_TENSOR_REDUCE_H
+#define TVM_RELAY_OP_TENSOR_REDUCE_H
+
+#include <tvm/relay/attrs/reduce.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/topi/elemwise.h>
+#include <tvm/topi/reduction.h>
+
+#include <algorithm>
+#include <limits>
+#include <numeric>
+#include <string>
+#include <tuple>
+#include <unordered_set>
+#include <vector>
+
+#include "../make_op.h"
+#include "../op_common.h"
+#include "../type_relations.h"
+
+namespace tvm {
+namespace relay {
+/*!
+ * \brief GetReduceAxes, get the axes to reduce from indim and other arguments.
+ * \param indim Number of dimensions of input data.
+ * \param inaxis The input axis vector; negative entries count from the end. If it is
+ *        undefined or empty, all axes are returned regardless of `exclude`.
+ * \param exclude Whether 'inaxis' lists the axes to keep instead of the axes to reduce.
+ * \return r_axes The axes to reduce, in ascending order and without duplicates.
+ */
+inline std::vector<int64_t> GetReduceAxes(const uint32_t indim, const Array<Integer>& inaxis,
+                                          bool exclude) {
+  if (!inaxis.defined() || inaxis.empty()) {
+    std::vector<int64_t> r_axes(indim);
+    std::iota(r_axes.begin(), r_axes.end(), 0);
+    return r_axes;
+  }
+
+  std::vector<int64_t> in_axes;
+  for (auto i : inaxis) {
+    int64_t axis = i->value;
+    if (axis < 0) {
+      axis = axis + indim;
+    }
+
+    // Check out of bounds error
+    ICHECK(axis >= 0) << "Axis out of bounds in reduce operator.";
+    ICHECK(axis < indim) << "Axis out of bounds in reduce operator.";
+    in_axes.push_back(axis);
+  }
+
+  // Sort and drop duplicates. The exclude branch below sizes its result as
+  // indim - in_axes.size(), and callers (e.g. ReduceShapeImpl) size shapes the same way, so a
+  // repeated axis would otherwise cause out-of-bounds writes.
+  std::sort(in_axes.begin(), in_axes.end());
+  in_axes.erase(std::unique(in_axes.begin(), in_axes.end()), in_axes.end());
+
+  if (!exclude) {
+    return in_axes;
+  }
+
+  // Complement of in_axes within [0, indim), walking the sorted list with a second cursor.
+  auto r_size = indim - in_axes.size();
+  std::vector<int64_t> r_axes(r_size);
+  for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) {
+    if (j < in_axes.size() && in_axes[j] == i) {
+      ++j;
+      continue;
+    }
+    r_axes[k++] = i;
+  }
+  return r_axes;
+}
+
+/*!
+ * \brief Get the complement of `inaxis`, i.e. the axes to reduce when `exclude` is set.
+ * \param indim Number of dimensions of input data.
+ * \param inaxis Axes to keep; negative entries count from the end. Must be defined.
+ * \return The axes in [0, indim) not listed in `inaxis`, in ascending order.
+ * \note Must be `inline`: this header is included from more than one translation unit
+ *       (reduce.cc and nn.cc), and a non-inline definition would be multiply defined at link time.
+ */
+inline Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
+  ICHECK(inaxis.defined()) << "Cannot set exclude when axis=None";
+  // true = axis is reduced; every axis listed in inaxis is cleared below.
+  std::vector<bool> axis_flag(indim, true);
+  for (auto i : inaxis) {
+    int64_t axis = i->value;
+    if (axis < 0) {
+      axis = axis + static_cast<int64_t>(indim);
+    }
+    // Check out of bounds error
+    ICHECK_GE(axis, 0) << "Axis out of bounds in reduce operator.";
+    ICHECK_LT(axis, static_cast<int64_t>(indim)) << "Axis out of bounds in reduce operator.";
+    axis_flag[axis] = false;
+  }
+
+  Array<Integer> r_axes;
+
+  for (size_t i = 0; i < axis_flag.size(); ++i) {
+    if (axis_flag[i]) {
+      r_axes.push_back(static_cast<int>(i));
+    }
+  }
+  return r_axes;
+}
+
+/*!
+ * \brief Return the modified layout for the AlterOpLayout pass (reduction ops).
+ *
+ * When `new_in_layouts` is defined and there is at least one reduce axis, the reduce axes of
+ * old_in_layouts[0] are mapped onto new_in_layouts[0] and the copied attrs' `axis` is rewritten
+ * to the new indices. Otherwise, if `old_in_layouts` is defined, the layouts are re-derived from
+ * old_in_layouts[0] (when that layout is defined) and `axis` is left unchanged. If neither
+ * applies, both inferred layouts stay undefined.
+ *
+ * \tparam T Attrs type; must expose `axis`, `exclude` and `keepdims` (all read below).
+ * \param attrs Op attributes; copied, the original is not modified.
+ * \param new_in_layouts Proposed input layouts (1 entry, or 2 for two-input reductions).
+ * \param old_in_layouts Original input layouts (1 or 2 entries).
+ * \param old_in_types Original input types; each must be a TensorType, only the rank of the
+ *        first is used.
+ * \return Inferred input layout(s), the inferred output layout, and the (possibly updated) attrs.
+ */
+template <typename T>
+InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
+                                                  const Array<Layout>& new_in_layouts,
+                                                  const Array<Layout>& old_in_layouts,
+                                                  const Array<tvm::relay::Type>& old_in_types) {
+  const auto* attrs_ptr = attrs.as<T>();
+  ICHECK(attrs_ptr);
+  // Work on a copy so the returned attrs can carry rewritten axes.
+  ObjectPtr<T> params = make_object<T>(*attrs_ptr);
+
+  // Get the reduce axes.
+  Array<Array<IndexExpr>> old_in_shapes;
+  for (auto old_in_t : old_in_types) {
+    ICHECK(old_in_t.as<TensorTypeNode>());
+    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+  }
+  uint32_t indim = old_in_shapes[0].size();
+  auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
+
+  Layout inferred_in = Layout::Undef();
+  Layout inferred_out = Layout::Undef();
+
+  // Infer [in_layout, out_layout, new_r_axes] from old_in_layout or new_in_layout
+  auto infer = [&](const Layout& layout) {
+    // 1) Collect the original axes
+    std::unordered_set<std::string> old_r_dims;
+    for (auto r_axis : r_axes) {
+      old_r_dims.emplace(old_in_layouts[0][r_axis].name());
+    }
+
+    // 2) Collect the new axes by walking new_layout.
+    tvm::Array<tvm::Integer> new_r_axes;
+    std::string inferred_in_string = "";
+    std::string inferred_out_string = "";
+    // Records `axis` into new_r_axes when layout_dim was a reduce dim and the op is not
+    // exclude-mode, or was not a reduce dim and the op is exclude-mode. Returns whether it did.
+    auto push_new_axis = [&](const std::string& layout_dim, int axis) {
+      if ((old_r_dims.count(layout_dim) && !params->exclude) ||
+          (!old_r_dims.count(layout_dim) && params->exclude)) {
+        new_r_axes.push_back(tvm::Integer(axis));
+        return true;
+      }
+      return false;
+    };
+    for (size_t axis_index = 0; axis_index < layout->axes.size(); ++axis_index) {
+      const auto& layout_axis = LayoutAxis::Get(layout->axes[axis_index]);
+      const std::string& layout_dim = layout_axis.name();
+      if (layout_axis.IsPrimal()) {
+        push_new_axis(layout_dim, axis_index);
+        inferred_in_string += layout_dim;
+        // A primal dim disappears from the output only if it was reduced and keepdims is off.
+        if (!old_r_dims.count(layout_dim) || params->keepdims) {
+          inferred_out_string += layout_dim;
+        }
+      } else {
+        // For example, if the original layout is NCHW, the new layout is NCHW8c, and the original
+        // reduce axes is [1], the new reduce axes become [1, 4].
+        auto primal_dim = layout_axis.ToPrimal().name();
+        auto packed_dim = std::to_string(layout.FactorOf(layout_axis)) + layout_dim;
+        inferred_in_string += packed_dim;
+        if (push_new_axis(primal_dim, axis_index)) {
+          if (params->exclude) {
+            // The primal axis is not reduced, so keep the input packed dim.
+            inferred_out_string += packed_dim;
+          } else if (params->keepdims) {
+            // If the primal axis is part of reduce axes in the original layout, the inner dim
+            // becomes 1 after reduction.
+            inferred_out_string += "1" + layout_dim;
+          }
+        } else {
+          inferred_out_string += packed_dim;
+        }
+      }
+    }
+
+    // 3) Set the new axis and layout.
+    return std::make_tuple(Layout(inferred_in_string), Layout(inferred_out_string), new_r_axes);
+  };
+
+  // NOTE(review): new_layout_string is never read; candidate for removal.
+  std::string new_layout_string;
+  Array<Integer> new_r_axes;
+  Array<Layout> new_input_layouts;
+
+  auto check_num_input_layouts = [](Array<Layout> in_layouts) {
+    // The second case is for variance op
+    ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2);
+  };
+
+  if (new_in_layouts.defined() && r_axes.size()) {
+    // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
+    // modified layout axes.
+    check_num_input_layouts(new_in_layouts);
+    check_num_input_layouts(old_in_layouts);
+
+    // Get inferred_in and inferred_out from new_in_layout.
+    std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
+    params->axis = new_r_axes;
+  } else if (old_in_layouts.defined()) {
+    check_num_input_layouts(old_in_layouts);
+
+    // If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
+    if (old_in_layouts[0].defined()) {
+      std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
+    }
+  }
+
+  new_input_layouts.push_back(inferred_in);
+
+  // Two-input reductions get the same inferred layout for their second input.
+  if (old_in_layouts.size() == 2) {
+    new_input_layouts.push_back(inferred_in);
+  }
+
+  return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, Attrs(params));
+}
+
+/*!
+ * \brief Shared FTVMCompute body for ops whose attrs are ReduceAttrs.
+ * \param attrs ReduceAttrs of the call.
+ * \param inputs Op inputs; only inputs[0] is read.
+ * \param out_type Unused.
+ * \param f Reduction callable, invoked as f(data, axes, keepdims, false).
+ * \return One tensor: the identity of the input when it is rank-0 or when `exclude` leaves no
+ *         axis to reduce, otherwise the result of `f`.
+ */
+template <typename F>
+Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                const Type& out_type, F f) {
+  const auto* reduce_attrs = attrs.as<ReduceAttrs>();
+  ICHECK(reduce_attrs != nullptr);
+  const te::Tensor data = inputs[0];
+
+  // Nothing to reduce on a scalar.
+  if (data->shape.empty()) {
+    return {topi::identity(data)};
+  }
+
+  Array<Integer> reduce_axes = reduce_attrs->axis;
+  if (reduce_attrs->exclude) {
+    reduce_axes = GetExcludeAxes(data->shape.size(), reduce_attrs->axis);
+    // Every axis was excluded, so the op degenerates to a copy.
+    if (reduce_axes.empty()) {
+      return {topi::identity(data)};
+    }
+  }
+
+  return {f(data, reduce_axes, reduce_attrs->keepdims, false)};
+}
+
+/*!
+ * \brief Shared FTVMCompute body for ops whose attrs are ArgReduceAttrs (argmax/argmin style).
+ * \param attrs ArgReduceAttrs of the call.
+ * \param inputs Op inputs; only inputs[0] is read.
+ * \param out_type Unused.
+ * \param f Callable invoked as f(data, axes, keepdims, false, select_last_index).
+ * \return One tensor: the identity of the input when it is rank-0 or when `exclude` leaves no
+ *         axis to reduce, otherwise the result of `f`.
+ */
+template <typename F>
+Array<te::Tensor> ArgReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                   const Type& out_type, F f) {
+  const auto* arg_attrs = attrs.as<ArgReduceAttrs>();
+  ICHECK(arg_attrs != nullptr);
+  const te::Tensor data = inputs[0];
+
+  // Nothing to reduce on a scalar.
+  if (data->shape.empty()) {
+    return {topi::identity(data)};
+  }
+
+  Array<Integer> reduce_axes = arg_attrs->axis;
+  if (arg_attrs->exclude) {
+    reduce_axes = GetExcludeAxes(data->shape.size(), arg_attrs->axis);
+    // Every axis was excluded, so the op degenerates to a copy.
+    if (reduce_axes.empty()) {
+      return {topi::identity(data)};
+    }
+  }
+
+  return {f(data, reduce_axes, arg_attrs->keepdims, false, arg_attrs->select_last_index)};
+}
+
+/*!
+ * \brief Compute the output shape of a reduction.
+ *
+ * Reduced axes become extent 1 when `keepdims` is set and are dropped otherwise. If any reduced
+ * axis has a non-constant extent, the reporter is asked to assert that the product of the
+ * constant extents seen before it stays below int32 max.
+ *
+ * \tparam AttrsType Attrs type exposing `axis`, `exclude` and `keepdims`.
+ * \param in_shape Shape of input data.
+ * \param param Attrs details.
+ * \param reporter The reporter used for the assertion above.
+ * \return The inferred output shape (in_shape itself when there is nothing to reduce).
+ */
+template <typename AttrsType>
+inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_shape,
+                                              const AttrsType* param,
+                                              const TypeReporter& reporter) {
+  const uint32_t ndim = in_shape.size();
+  const std::vector<int64_t> reduce_axes = GetReduceAxes(ndim, param->axis, param->exclude);
+  if (reduce_axes.empty()) {
+    return in_shape;
+  }
+
+  // Multiply the constant extents of the reduced axes, stopping at the first symbolic one.
+  PrimExpr static_extent = tir::make_const(DataType::Int(64), 1);
+  bool has_symbolic_axis = false;
+  for (const int64_t ax : reduce_axes) {
+    if (!in_shape[ax].as<IntImmNode>()) {
+      has_symbolic_axis = true;
+      break;
+    }
+    static_extent *= in_shape[ax];
+  }
+
+  if (has_symbolic_axis) {
+    ICHECK(reporter->Assert(
+        static_extent < tir::make_const(DataType::Int(64), std::numeric_limits<int32_t>::max())))
+        << "The maximum possible index of reduced shape cannot be more than int32 max.";
+  }
+
+  if (param->keepdims) {
+    // Same rank; each reduced axis collapses to extent 1.
+    std::vector<IndexExpr> out_shape(in_shape);
+    for (const int64_t ax : reduce_axes) {
+      out_shape[ax] = 1;
+    }
+    return out_shape;
+  }
+
+  // Drop the reduced axes. reduce_axes is sorted, so a single cursor walks it in step.
+  std::vector<IndexExpr> out_shape;
+  auto next_reduced = reduce_axes.begin();
+  for (uint32_t dim = 0; dim < ndim; ++dim) {
+    if (next_reduced != reduce_axes.end() && *next_reduced == dim) {
+      ++next_reduced;
+      continue;
+    }
+    out_shape.push_back(in_shape[dim]);
+  }
+  return out_shape;
+}
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_OP_TENSOR_REDUCE_H
diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc
index e7eef41e41..a7d12740b0 100644
--- a/src/relay/transforms/simplify_inference.cc
+++ b/src/relay/transforms/simplify_inference.cc
@@ -115,6 +115,7 @@ Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta,
   return out;
 }
 
+/*
 Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
   auto ttype = tdata.as<TensorTypeNode>();
   ICHECK(ttype);
@@ -137,6 +138,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta,
   }
   return out;
 }
+*/
 
 Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
   auto ttype = tdata.as<TensorTypeNode>();
@@ -184,7 +186,7 @@ class InferenceSimplifier : public MixedModeMutator {
       : batch_norm_op_(Op::Get("nn.batch_norm")),
         dropout_op_(Op::Get("nn.dropout")),
         instance_norm_op_(Op::Get("nn.instance_norm")),
-        layer_norm_op_(Op::Get("nn.layer_norm")),
+        // layer_norm_op_(Op::Get("nn.layer_norm")),
         group_norm_op_(Op::Get("nn.group_norm")),
         l2_norm_op_(Op::Get("nn.l2_normalize")) {}
 
@@ -207,10 +209,6 @@ class InferenceSimplifier : public MixedModeMutator {
   Expr Rewrite_(const CallNode* n, const Expr& new_n) {
     if (n->op == batch_norm_op_) {
       ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
-    } else if (n->op == layer_norm_op_) {
-      const auto* call = new_n.as<CallNode>();
-      return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
-                                    n->args[0]->checked_type());
     } else if (n->op == group_norm_op_) {
       const auto* call = new_n.as<CallNode>();
       return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
@@ -233,7 +231,6 @@ class InferenceSimplifier : public MixedModeMutator {
   const Op& batch_norm_op_;
   const Op& dropout_op_;
   const Op& instance_norm_op_;
-  const Op& layer_norm_op_;
   const Op& group_norm_op_;
   const Op& l2_norm_op_;
   std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> ty_map_;