You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2021/12/02 18:16:26 UTC

[incubator-mxnet] branch master updated: Automatic Layout Management (#20718)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 40359ce  Automatic Layout Management (#20718)
40359ce is described below

commit 40359ceda150ca75da6e45b1ea35d747ef53deac
Author: Vladimir Cherepanov <56...@users.noreply.github.com>
AuthorDate: Thu Dec 2 10:14:13 2021 -0800

    Automatic Layout Management (#20718)
    
    * Automatic Layout Management
    
    Originally authored by Dawid Tracz <dt...@nvidia.com>
    
    * Fix clang-format
    
    * Fix clang-format in mshadow
    
    * Print layout name instead of a number
    
    * Generalize NHWC target layout to other dimensions
    
    * Change layout optimization API
    
    * Add layout optimization tests
    
    * Add backward check to tests
    
    * Generalize tests to 1..3 spatial dims
    
    * Add NWC layout to ConvolutionParams
    
    * Enable layout optimization tests only with cuDNN
    
    Co-authored-by: Vladimir Cherepanov <vc...@nvidia.com>
---
 3rdparty/mshadow/mshadow/base.h                 |  60 +++++++
 3rdparty/mshadow/mshadow/tensor.h               |  91 +++++++++++
 include/mxnet/c_api.h                           |  10 ++
 python/mxnet/amp/amp.py                         |   8 +-
 src/c_api/c_api.cc                              |  13 ++
 src/common/alm.cc                               | 209 ++++++++++++++++++++++++
 src/common/alm.h                                | 100 ++++++++++++
 src/imperative/cached_op.h                      |   3 +
 src/operator/cudnn_ops.cc                       |   2 +-
 src/operator/elemwise_op_common.h               |  10 ++
 src/operator/leaky_relu.cc                      |  13 ++
 src/operator/nn/batch_norm.cc                   |  17 ++
 src/operator/nn/convolution-inl.h               |   1 +
 src/operator/nn/convolution.cc                  |  27 +++
 src/operator/nn/deconvolution.cc                |  25 +++
 src/operator/nn/pooling.cc                      |  18 ++
 src/operator/operator_common.h                  |   1 +
 src/operator/tensor/amp_cast.cc                 |  17 ++
 src/operator/tensor/elemwise_binary_op.h        |   1 +
 src/operator/tensor/elemwise_binary_scalar_op.h |   2 +
 src/operator/tensor/elemwise_unary_op.h         |   2 +
 src/operator/tensor/matrix_op.cc                |  17 ++
 tests/python/gpu/test_amp_init.py               |  96 ++++++++++-
 23 files changed, 737 insertions(+), 6 deletions(-)

diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h
index e018551..5f6fb0c 100644
--- a/3rdparty/mshadow/mshadow/base.h
+++ b/3rdparty/mshadow/mshadow/base.h
@@ -496,6 +496,8 @@ const int index_type_flag = DataType<lapack_index_t>::kFlag;
 
 /*! layout flag */
 enum LayoutFlag {
+  kUNKNOWN = -1,
+
   kNCHW = 0,
   kNHWC,
   kCHWN,
@@ -509,6 +511,64 @@ enum LayoutFlag {
   kCDHWN
 };
 
+/*!
+ * \brief Parses a layout name such as "NHWC" into its LayoutFlag.
+ * \param layoutstr layout name; matching is exact and case-sensitive
+ * \return the matching flag, or kUNKNOWN when the name is not one of the known layouts
+ */
+inline LayoutFlag layoutFlag(std::string layoutstr) {
+  struct NamedLayout {
+    const char* name;
+    LayoutFlag flag;
+  };
+  static const NamedLayout kKnownLayouts[] = {
+      {"NHWC", kNHWC},   {"NCHW", kNCHW},   {"CHWN", kCHWN},
+      {"NWC", kNWC},     {"NCW", kNCW},     {"CWN", kCWN},
+      {"NDHWC", kNDHWC}, {"NCDHW", kNCDHW}, {"CDHWN", kCDHWN},
+  };
+  for (const auto& entry : kKnownLayouts) {
+    if (layoutstr == entry.name)
+      return entry.flag;
+  }
+  return kUNKNOWN;
+}
+
+/*!
+ * \brief Returns the name of a layout flag, e.g. "NHWC" for kNHWC.
+ * \param layout the layout to name
+ * \return the layout name; an empty string for kUNKNOWN or any value not listed below
+ */
+inline std::string toString(LayoutFlag layout) {
+  struct FlagName {
+    LayoutFlag flag;
+    const char* name;
+  };
+  static const FlagName kNames[] = {
+      {kNCHW, "NCHW"},   {kNHWC, "NHWC"},   {kCHWN, "CHWN"},
+      {kNCW, "NCW"},     {kNWC, "NWC"},     {kCWN, "CWN"},
+      {kNCDHW, "NCDHW"}, {kNDHWC, "NDHWC"}, {kCDHWN, "CDHWN"},
+  };
+  for (const auto& entry : kNames) {
+    if (entry.flag == layout)
+      return entry.name;
+  }
+  return "";  // kUNKNOWN and unlisted values
+}
+
 template<int layout>
 struct LayoutType;
 
diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h
index e417fbb..fdf5e06 100644
--- a/3rdparty/mshadow/mshadow/tensor.h
+++ b/3rdparty/mshadow/mshadow/tensor.h
@@ -391,6 +391,97 @@ inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layou
 }
 
 /*!
+ * \brief returns the axes of the transpose that converts data laid out as src_layout
+ *        into dst_layout, e.g. (kNCHW, kNHWC) yields {0, 2, 3, 1}
+ * \param src_layout input layout
+ * \param dst_layout output layout; must have the same number of dimensions as src_layout
+ * \return vector of dim_t describing the axes of the transpose operation.
+ *         kUNKNOWN or unsupported layouts fail with LOG(FATAL); a rank mismatch fails a CHECK.
+ */
+template <typename dim_t>
+inline std::vector<dim_t> getTranspAxes(const LayoutFlag src_layout, const LayoutFlag dst_layout) {
+  // Applies permutation `op` to `v`: ret[i] = v[op[i]].
+  auto apply = [](const std::vector<dim_t>& v, const std::vector<dim_t>& op) {
+    CHECK_EQ(v.size(), op.size()) << "Layout ndims does not match";
+    std::vector<dim_t> ret(v.size());
+    for (size_t i = 0; i < v.size(); i++) {
+      ret[i] = v[op[i]];
+    }
+    return ret;
+  };
+  std::vector<dim_t> axes;
+  // transpose from `case` to ND?H?WC
+  switch (src_layout) {
+    case kUNKNOWN:
+      LOG(FATAL) << "Unknown source layout";
+      break;
+    case kNHWC:
+      axes = std::vector<dim_t>({0, 1, 2, 3});
+      break;
+    case kNCHW:
+      axes = std::vector<dim_t>({0, 2, 3, 1});
+      break;
+    case kCHWN:
+      axes = std::vector<dim_t>({3, 1, 2, 0});
+      break;
+    case kNWC:
+      axes = std::vector<dim_t>({0, 1, 2});
+      break;
+    case kNCW:
+      axes = std::vector<dim_t>({0, 2, 1});
+      break;
+    case kCWN:
+      axes = std::vector<dim_t>({2, 1, 0});
+      break;
+    case kNDHWC:
+      axes = std::vector<dim_t>({0, 1, 2, 3, 4});
+      break;
+    case kNCDHW:
+      axes = std::vector<dim_t>({0, 2, 3, 4, 1});
+      break;
+    case kCDHWN:
+      axes = std::vector<dim_t>({4, 1, 2, 3, 0});
+      break;
+    default:
+      LOG(FATAL) << "Invalid source layout " << src_layout;
+  }
+  // transpose from ND?H?WC to `case`
+  switch (dst_layout) {
+    case kUNKNOWN:
+      LOG(FATAL) << "Unknown destination layout";
+      break;
+    case kNHWC:
+      axes = apply(axes, {0, 1, 2, 3});
+      break;
+    case kNCHW:
+      axes = apply(axes, {0, 3, 1, 2});
+      break;
+    case kCHWN:
+      axes = apply(axes, {3, 1, 2, 0});
+      break;
+    case kNWC:
+      axes = apply(axes, {0, 1, 2});
+      break;
+    case kNCW:
+      axes = apply(axes, {0, 2, 1});
+      break;
+    case kCWN:
+      axes = apply(axes, {2, 1, 0});
+      break;
+    case kNDHWC:
+      axes = apply(axes, {0, 1, 2, 3, 4});
+      break;
+    case kNCDHW:
+      axes = apply(axes, {0, 4, 1, 2, 3});
+      break;
+    case kCDHWN:
+      axes = apply(axes, {4, 1, 2, 3, 0});
+      break;
+    default:
+      // Report the offending destination value (previously printed src_layout by mistake).
+      LOG(FATAL) << "Invalid destination layout " << dst_layout;
+  }
+  return axes;
+}
+
+/*!
  * \brief computaion stream structure, used for asynchronous computations
  */
 template<typename Device>
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index b25ccad..94609de 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -3161,6 +3161,16 @@ MXNET_DLL int MXCUDAProfilerStart();
  */
 MXNET_DLL int MXCUDAProfilerStop();
 
+/*!
+ * \brief Turns on or off Layout Optimization
+ * \param val true to enable automatic layout optimization, false to disable it
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSetOptimizeLayout(bool val);
+
+/*!
+ * \brief Get current Layout Optimization status
+ * \param val (out) receives true if layout optimization is enabled, false otherwise
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXGetOptimizeLayout(bool* val);
+
 #ifdef __cplusplus
 }
 #endif  // __cplusplus
diff --git a/python/mxnet/amp/amp.py b/python/mxnet/amp/amp.py
index c7aab71..750b3d0 100644
--- a/python/mxnet/amp/amp.py
+++ b/python/mxnet/amp/amp.py
@@ -307,7 +307,7 @@ def warn_if_model_exists():
                 return
 
 def init(target_dtype='float16', target_precision_ops=None,
-         conditional_fp32_ops=None, fp32_ops=None):
+         conditional_fp32_ops=None, fp32_ops=None, layout_optimization=False):
     """Initialize AMP (automatic mixed precision).
 
     This needs to be done before model creation.
@@ -333,7 +333,11 @@ def init(target_dtype='float16', target_precision_ops=None,
         assert target_dtype in ['float16', np.float16, 'bfloat16', bfloat16], \
                "AMP currently supports only float16 or bfloat16 as a target_dtype"
         _amp_initialized = True
-        logging.info("Using AMP")
+        log_msg = "Using AMP"
+        if layout_optimization:
+            log_msg += "\n - layout optimization: enabled"
+            check_call(_LIB.MXSetOptimizeLayout(ctypes.c_bool(True)))
+        logging.info(log_msg)
         if target_dtype == "bfloat16":
             target_dtype = bfloat16
         else:
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 61a47b0..d533a2a 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -55,6 +55,7 @@
 #include "../operator/tvmop/op_module.h"
 #include "../operator/subgraph/partitioner/custom_subgraph_property.h"
 #include "../operator/subgraph/subgraph_property.h"
+#include "../common/alm.h"
 #include "../common/utils.h"
 #include "../profiler/profiler.h"
 #include "../serialization/cnpy.h"
@@ -4004,3 +4005,15 @@ int MXCUDAProfilerStop() {
 #endif
   API_END();
 }
+
+// Sets the ALMParams::optimize flag; it is consulted when CachedOp forward graphs are created.
+int MXSetOptimizeLayout(bool val) {
+  API_BEGIN();
+  mxnet::alm::ALMParams::get().optimize = val;
+  API_END();
+}
+
+// Writes the current ALMParams::optimize flag to `*val`. A null `val` from an FFI caller is
+// reported as an API error instead of crashing the process.
+int MXGetOptimizeLayout(bool* val) {
+  API_BEGIN();
+  CHECK(val != nullptr) << "MXGetOptimizeLayout: output pointer must not be null";
+  *val = mxnet::alm::ALMParams::get().optimize;
+  API_END();
+}
diff --git a/src/common/alm.cc b/src/common/alm.cc
new file mode 100644
index 0000000..3a38ee5
--- /dev/null
+++ b/src/common/alm.cc
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file alm.cc
+ * \brief Automatic Layout Manager
+ * \author Dawid Tracz, Vladimir Cherepanov
+ */
+
+#include "alm.h"
+
+#include <algorithm>
+#include <sstream>
+#include <unordered_set>
+#include <utility>
+
+#include "../operator/nn/convolution-inl.h"
+#include "../operator/nn/deconvolution-inl.h"
+#include "../operator/tensor/matrix_op-inl.h"
+
+namespace mxnet {
+namespace alm {
+
+namespace {
+
+// Builds a detached `transpose` node named `name` whose "axes" attribute holds `axes`, then runs
+// the op's attribute parser so the node's parsed parameters are populated.
+nnvm::ObjectPtr CreateTransposeNode(const std::string& name, const alm::Transpose& axes) {
+  std::ostringstream axes_str;
+  axes_str << mxnet::TShape(axes.begin(), axes.end());
+
+  nnvm::ObjectPtr node     = nnvm::Node::Create();
+  node->attrs.op           = nnvm::Op::Get("transpose");
+  node->attrs.name         = name;
+  node->attrs.dict["axes"] = axes_str.str();
+  node->op()->attr_parser(&(node->attrs));
+  return node;
+}
+
+// Preferred layout for a node: Convolution and Deconvolution map to the channels-last layout for
+// their kernel rank (1 -> NWC, 2 -> NHWC, 3 -> NDHWC; other ranks fail a CHECK). Every other op
+// returns kUNKNOWN, meaning "no preferred layout".
+mshadow::LayoutFlag TargetLayout(const nnvm::ObjectPtr& node) {
+  static const Op* conv_op   = Op::Get("Convolution");
+  static const Op* deconv_op = Op::Get("Deconvolution");
+
+  // Number of kernel (spatial) dimensions -> channels-last layout.
+  static const std::unordered_map<int, mshadow::LayoutFlag> ndim2layout{
+      {1, mshadow::kNWC},
+      {2, mshadow::kNHWC},
+      {3, mshadow::kNDHWC},
+  };
+
+  const auto channels_last_for = [](const auto& param) -> mshadow::LayoutFlag {
+    auto it = ndim2layout.find(param.kernel.ndim());
+    CHECK(it != ndim2layout.end()) << "Unexpected kernel dimensions: " << param.kernel;
+    return it->second;
+  };
+
+  const bool is_conv = node->op() == conv_op;
+  if (!is_conv && node->op() != deconv_op)
+    return mshadow::kUNKNOWN;
+
+  return is_conv ? channels_last_for(nnvm::get<op::ConvolutionParam>(node->attrs.parsed)) :
+                   channels_last_for(nnvm::get<op::DeconvolutionParam>(node->attrs.parsed));
+}
+
+}  // namespace
+
+/*!
+ * \brief Rewrites the graph so layout-sensitive ops run in their preferred layouts.
+ *
+ * Nodes are visited with DFSVisit (producers before consumers). For each node:
+ *  - inputs fed by a `transpose` with explicit identity axes are scheduled to be bypassed;
+ *  - pending transposes recorded for its inputs (in `changed`) are collected;
+ *  - if the op registers FChangeLayout and the node is not a graph output, FChangeLayout may
+ *    change the node's layout, absorbing input transposes and recording pending output ones;
+ *  - every input transpose still pending afterwards becomes an explicit `transpose` node.
+ * Nodes are modified in place; the returned graph carries only `g.outputs` (no graph attrs).
+ */
+nnvm::Graph OptimizeLayout(nnvm::Graph&& g) {
+  static const auto& op_map     = Op::GetAttr<mxnet::alm::FChangeLayout>("FChangeLayout");
+  static const Op* transpose_op = Op::Get("transpose");
+  // Graph outputs keep their layout: FChangeLayout is never called on them.
+  std::unordered_set<nnvm::ObjectPtr> outputs;
+  for (auto& o : g.outputs)
+    outputs.insert(o.node);
+  // Pending transposes attached to node outputs by FChangeLayout.
+  nnvm::NodeEntryMap<alm::Transpose> changed;
+  struct ToDelete {
+    nnvm::ObjectPtr node;  // consumer of the identity transpose's output
+    size_t input_idx;
+  };
+  std::vector<ToDelete> to_delete;
+  struct ToAdd {
+    nnvm::ObjectPtr node;
+    size_t input_idx;
+    alm::Transpose axes;
+  };
+  std::vector<ToAdd> to_add;
+  DFSVisit(g.outputs, [&outputs, &changed, &to_add, &to_delete](const nnvm::ObjectPtr& node) {
+    std::vector<alm::Transpose> input_axes(node->inputs.size());
+    for (size_t i = 0; i < node->inputs.size(); ++i) {
+      if (node->inputs[i].node->op() == transpose_op) {
+        const auto& param = nnvm::get<op::TransposeParam>(node->inputs[i].node->attrs.parsed);
+        // Empty `axes` is transpose's default and reverses all axes, which is not an identity
+        // for ndim > 1, so only transposes with explicit identity axes may be bypassed.
+        if (param.axes.ndim() > 0 && IsIdentity(FromTShape(param.axes))) {
+          to_delete.push_back({node, i});
+          continue;
+        }
+      }
+      auto it = changed.find(node->inputs[i]);
+      if (it == changed.end())
+        continue;
+      input_axes[i] = it->second;
+    }
+    auto fchange = op_map.get(node->op(), nullptr);
+    if (fchange && outputs.count(node) == 0) {
+      std::vector<alm::Transpose> output_axes;
+      if (fchange(&node->attrs, TargetLayout(node), &input_axes, &output_axes))
+        node->op()->attr_parser(&node->attrs);
+      for (size_t i = 0; i < output_axes.size(); ++i) {
+        if (IsIdentity(output_axes[i]))
+          continue;
+        changed.insert(std::make_pair(nnvm::NodeEntry(node, i, 0), output_axes[i]));
+      }
+    }
+    for (size_t i = 0; i < input_axes.size(); ++i) {
+      if (IsIdentity(input_axes[i]))
+        continue;
+      to_add.push_back({node, i, input_axes[i]});
+    }
+  });
+  for (const auto& t : to_delete) {
+    nnvm::NodeEntry& entry = t.node->inputs[t.input_idx];
+    CHECK_EQ(entry.node->inputs.size(), 1);
+    // Copy the transpose's input before overwriting `entry`: the assignment may release the last
+    // reference to the transpose node, which would free the entry being copied from.
+    nnvm::NodeEntry bypass = entry.node->inputs[0];
+    entry                  = bypass;
+  }
+  size_t node_no = 0;
+  for (const auto& t : to_add) {
+    auto tnode = CreateTransposeNode("ALM_transpose_" + std::to_string(node_no++), t.axes);
+    tnode->inputs.push_back(t.node->inputs[t.input_idx]);
+    t.node->inputs[t.input_idx] = nnvm::NodeEntry(tnode);
+  }
+  nnvm::Graph ret;
+  ret.outputs = g.outputs;
+  return ret;
+}
+
+// Returns the inverse permutation of `axes`, so Compose(axes, Reverse(axes)) is the identity.
+Transpose Reverse(const Transpose& axes) {
+  Transpose inverse(axes.size());
+  size_t pos = 0;
+  for (const size_t dst : axes)
+    inverse[dst] = pos++;
+  return inverse;
+}
+
+/*!
+ * \brief Composes two transposes: the result r satisfies r[i] == lhs[rhs[i]].
+ * An empty operand is treated as the identity and the other operand is returned unchanged;
+ * non-empty operands must have the same rank (checked).
+ */
+Transpose Compose(const Transpose& lhs, const Transpose& rhs) {
+  if (lhs.empty())
+    return rhs;
+  if (rhs.empty())
+    return lhs;
+  CHECK_EQ(lhs.size(), rhs.size());
+  Transpose ret(lhs.size());
+  // size_t index: avoids the signed/unsigned comparison of the previous `auto i = 0`.
+  for (size_t i = 0; i < ret.size(); ++i)
+    ret[i] = lhs[rhs[i]];
+  return ret;
+}
+
+// True iff `t` maps every axis to itself; an empty transpose counts as the identity.
+bool IsIdentity(const Transpose& t) {
+  size_t expected = 0;
+  for (const size_t axis : t) {
+    if (axis != expected++)
+      return false;
+  }
+  return true;
+}
+
+// Permutes the letters of `layout`'s name by `axes` (via the std::string overload) and parses
+// the result back into a LayoutFlag; fails a CHECK when the permuted name is not a known layout.
+mshadow::LayoutFlag ApplyTranspose(mshadow::LayoutFlag layout, const Transpose& axes) {
+  const std::string permuted = ApplyTranspose(mshadow::toString(layout), axes);
+  const mshadow::LayoutFlag ret = mshadow::layoutFlag(permuted);
+  CHECK_NE(ret, mshadow::kUNKNOWN);
+  return ret;
+}
+
+/*!
+ * \brief Permutes the characters of a layout name: result[i] = layout[axes[i]].
+ * E.g. ApplyTranspose("NCHW", {0, 2, 3, 1}) == "NHWC".
+ * Fails a CHECK (instead of reading out of bounds) when `axes` has fewer entries than `layout`
+ * or names an axis outside the layout.
+ */
+std::string ApplyTranspose(const std::string& layout, const Transpose& axes) {
+  CHECK_GE(axes.size(), layout.size()) << "Transpose has fewer axes than layout " << layout;
+  std::string ret(layout.size(), ' ');
+  for (size_t i = 0; i < ret.size(); i++) {
+    CHECK_LT(axes[i], layout.size()) << "Invalid transpose axis for layout " << layout;
+    ret[i] = layout[axes[i]];
+  }
+  return ret;
+}
+
+// Copies the entries of `s` (e.g. a transpose's `axes` parameter) into a Transpose.
+Transpose FromTShape(const mxnet::TShape& s) {
+  Transpose axes;
+  axes.reserve(s.ndim());
+  for (const auto dim : s)
+    axes.push_back(static_cast<size_t>(dim));
+  return axes;
+}
+
+/*!
+ * \brief Takes the first non-identity transpose in `*axes` as the common one and returns it,
+ * rewriting the entries relative to it:
+ *  - the chosen entry is replaced by an empty (identity) transpose;
+ *  - every later non-identity entry t becomes Compose(t, Reverse(common)), so an entry equal to
+ *    the common transpose becomes the identity;
+ *  - identity entries (empty or explicit) are left untouched.
+ * Returns an empty Transpose when every entry is an identity.
+ *
+ * NOTE(review): identity entries are not composed with Reverse(common), unlike non-identity
+ * ones; confirm this is intended for callers whose inputs mix transposed and untransposed data.
+ */
+Transpose FactorCommonTranspose(std::vector<Transpose>* axes) {
+  Transpose ret;
+  for (auto& t : *axes) {
+    if (IsIdentity(t))
+      continue;
+    // `ret` is still empty: this is the first non-identity entry; take it as the common one.
+    if (IsIdentity(ret)) {
+      std::swap(t, ret);
+      continue;
+    }
+    // Express the remaining entry relative to the common transpose.
+    auto rev = Reverse(ret);
+    t        = Compose(t, rev);
+  }
+  return ret;
+}
+
+}  // namespace alm
+}  // namespace mxnet
diff --git a/src/common/alm.h b/src/common/alm.h
new file mode 100644
index 0000000..923f4eb
--- /dev/null
+++ b/src/common/alm.h
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file alm.h
+ * \brief Automatic Layout Manager
+ * \author Dawid Tracz, Vladimir Cherepanov
+ */
+
+#ifndef MXNET_COMMON_ALM_H_
+#define MXNET_COMMON_ALM_H_
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/node.h>
+#include <functional>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace mxnet {
+namespace alm {
+
+/*!
+ *  \brief A singleton flag, set and read by MXSetOptimizeLayout and MXGetOptimizeLayout.
+ *  get() returns a function-local static instance.
+ *  NOTE(review): `optimize` is a plain bool with no synchronization.
+ */
+struct ALMParams {
+  bool optimize = false;  // consulted when CachedOp forward graphs are created
+
+  static ALMParams& get() {
+    static ALMParams alm;
+    return alm;
+  }
+};
+
+/*!
+ * \brief Top-level function to run layout optimization.
+ * \param g graph to optimize; its nodes may be modified in place.
+ * \return a graph with the same outputs as `g`.
+ */
+nnvm::Graph OptimizeLayout(nnvm::Graph&& g);
+
+/*!
+ * \brief Transpose, represented by permutation of axes. An empty vector denotes the identity.
+ */
+using Transpose = std::vector<size_t>;
+
+/*! \brief True if `t` maps every axis to itself; an empty transpose counts as the identity. */
+bool IsIdentity(const Transpose& t);
+/*! \brief Inverse permutation: Compose(axes, Reverse(axes)) is the identity. */
+Transpose Reverse(const Transpose& axes);
+
+/*!
+ * \brief Compose 2 transposes. Not commutative: a * b means b is applied first, then a.
+ * The result r satisfies r[i] == lhs[rhs[i]]; an empty operand is treated as the identity.
+ */
+Transpose Compose(const Transpose& lhs, const Transpose& rhs);
+
+/*!
+ * \brief Permute a layout: result[i] = layout[axes[i]] (e.g. "NCHW" with {0, 2, 3, 1} gives
+ * "NHWC"). The LayoutFlag overload fails a CHECK if the result is not a known layout.
+ */
+mshadow::LayoutFlag ApplyTranspose(mshadow::LayoutFlag layout, const Transpose& axes);
+std::string ApplyTranspose(const std::string& layout, const Transpose& axes);
+
+/*! \brief Copy the entries of a TShape (e.g. a transpose's `axes`) into a Transpose. */
+Transpose FromTShape(const mxnet::TShape& s);
+
+/*!
+ * \brief May change operator's layout. Used by OptimizeLayout.
+ *
+ * \param target_layout The target layout to change to, or kUNKNOWN. In the latter case the target
+ * layout is calculated based on in_axes, with a goal to cancel them out (at least some, ideally -
+ * all).
+ * \param in_axes (in/out) On input - pending inputs' transposes. On output - inputs' transposes,
+ * required by the new layout.
+ * \param out_axes (out) Outputs' transposes, required to convert to the original layouts.
+ * \return true if attrs changed and params need to be reparsed.
+ */
+using FChangeLayout = std::function<bool(nnvm::NodeAttrs*,
+                                         mshadow::LayoutFlag target_layout,
+                                         std::vector<Transpose>* in_axes,
+                                         std::vector<Transpose>* out_axes)>;
+
+/*!
+ * \brief Factors out and returns a common transpose, or a default-constructed (empty) Transpose
+ * if every entry of `axes` (in/out parameter) is an identity.
+ */
+Transpose FactorCommonTranspose(std::vector<Transpose>* axes);
+
+}  // namespace alm
+}  // namespace mxnet
+
+#endif  // MXNET_COMMON_ALM_H_
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 97ac23c..079a56e 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -28,6 +28,7 @@
 #include <string>
 #include <unordered_map>
 #include <map>
+#include "../common/alm.h"
 #include "../operator/operator_common.h"
 #include "../operator/subgraph/common.h"
 #include "./imperative_utils.h"
@@ -208,6 +209,8 @@ void CreateForwardGraph(const nnvm::Symbol& sym, nnvm::Graph* fwd_graph) {
       fwd_graph->outputs.push_back(nodeEntry);
     }
   }
+  if (alm::ALMParams::get().optimize)
+    *fwd_graph = alm::OptimizeLayout(std::move(*fwd_graph));
 }
 
 /* \brief construct grad_graph from fwd_graph and ograd_entries*/
diff --git a/src/operator/cudnn_ops.cc b/src/operator/cudnn_ops.cc
index 2778f7b..e7e649f 100644
--- a/src/operator/cudnn_ops.cc
+++ b/src/operator/cudnn_ops.cc
@@ -433,7 +433,7 @@ cudnnBackendHeurMode_t HeurMode() {
 
 std::string ConvParamStr(const ConvParam& param) {
   std::ostringstream ss;
-  ss << " layout: " << param.layout.value();
+  ss << mshadow::toString(static_cast<mshadow::LayoutFlag>(param.layout.value()));
   ss << " kernel: " << param.kernel;
   ss << " stride: " << param.stride;
   ss << " dilate: " << param.dilate;
diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h
index 27ed029..5884d99 100644
--- a/src/operator/elemwise_op_common.h
+++ b/src/operator/elemwise_op_common.h
@@ -36,6 +36,7 @@
 #include <utility>
 #include "./operator_common.h"
 #include "./mxnet_op.h"
+#include "../common/alm.h"
 
 namespace mxnet {
 namespace op {
@@ -197,6 +198,15 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs,
       attrs, in_attrs, out_attrs, -1);
 }
 
+/*!
+ * \brief FChangeLayout for element-wise ops. Requires targetLayout == kUNKNOWN, factors the
+ * common pending input transpose out and forwards it to every output. Never modifies attrs.
+ */
+inline bool ElemwiseChangeLayout(nnvm::NodeAttrs* attrs,
+                                 mshadow::LayoutFlag targetLayout,
+                                 std::vector<alm::Transpose>* inpTransposes,
+                                 std::vector<alm::Transpose>* outTransposes) {
+  CHECK_EQ(targetLayout, mshadow::kUNKNOWN);
+  const alm::Transpose common = alm::FactorCommonTranspose(inpTransposes);
+  const auto n_out            = attrs->op->num_outputs;
+  outTransposes->assign(n_out, common);
+  return false;
+}
+
 // Special case of ElemwiseType. Constrains dtype to integer types
 template <index_t n_in, index_t n_out>
 inline bool ElemwiseIntType(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc
index ff2ce4a..39aa11d 100644
--- a/src/operator/leaky_relu.cc
+++ b/src/operator/leaky_relu.cc
@@ -24,6 +24,7 @@
  */
 
 #include "./leaky_relu-inl.h"
+#include "../common/alm.h"
 #if MXNET_USE_ONEDNN == 1
 #include "./nn/dnnl/dnnl_base-inl.h"
 #include "./nn/dnnl/dnnl_ops-inl.h"
@@ -145,6 +146,17 @@ inline static bool BackwardLeakyReLUStorageType(const nnvm::NodeAttrs& attrs,
 }
 #endif  // MXNET_USE_ONEDNN == 1
 
+/*!
+ * \brief FChangeLayout for LeakyReLU. Requires target_layout == kUNKNOWN, factors the common
+ * pending input transpose out and forwards it to output 0. For act_type "rrelu" a second output
+ * slot with an empty (identity) transpose is added. Never modifies attrs.
+ */
+static bool LRChangeLayout(nnvm::NodeAttrs* attrs,
+                           mshadow::LayoutFlag target_layout,
+                           std::vector<alm::Transpose>* in_axes,
+                           std::vector<alm::Transpose>* out_axes) {
+  CHECK_EQ(target_layout, mshadow::kUNKNOWN);
+  out_axes->assign(1, alm::FactorCommonTranspose(in_axes));
+  // Use find() rather than operator[]: operator[] would insert an empty "act_type" entry into the
+  // attribute dict when the user relied on the default activation, breaking later re-parsing.
+  const auto act_type = attrs->dict.find("act_type");
+  if (act_type != attrs->dict.end() && act_type->second == "rrelu")
+    out_axes->resize(2);
+  return false;
+}
+
 NNVM_REGISTER_OP(LeakyReLU)
     .describe(R"code(Applies Leaky rectified linear unit activation element-wise to the input.
 
@@ -195,6 +207,7 @@ The following modified ReLU Activation functions are supported:
                                       })
     .set_attr<mxnet::FInferShape>("FInferShape", LeakyReLUShape)
     .set_attr<nnvm::FInferType>("FInferType", LeakyReLUType)
+    .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", LRChangeLayout)
     .set_attr<FCompute>("FCompute<cpu>", LeakyReLUCompute<cpu>)
 #if MXNET_USE_ONEDNN == 1
     .set_attr<bool>("TIsDNNL", true)
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index d3502b9..04cc78a 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -27,6 +27,7 @@
 
 #include "../elemwise_op_common.h"
 #include "../operator_common.h"
+#include "../../common/alm.h"
 
 #include "batch_norm-inl.h"
 #if MXNET_USE_ONEDNN == 1
@@ -445,6 +446,21 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+/*!
+ * \brief FChangeLayout for BatchNorm. Requires targetLayout == kUNKNOWN. The common pending input
+ * transpose t is forwarded to output 0; when t is not the identity, the channel axis is remapped
+ * to t[axis] (attrs change, so params must be reparsed).
+ */
+static bool BNChangeLayout(nnvm::NodeAttrs* attrs,
+                           mshadow::LayoutFlag targetLayout,
+                           std::vector<alm::Transpose>* inpTransposes,
+                           std::vector<alm::Transpose>* outTransposes) {
+  CHECK_EQ(targetLayout, mshadow::kUNKNOWN);
+  auto t = alm::FactorCommonTranspose(inpTransposes);
+  outTransposes->assign(1, t);
+  if (alm::IsIdentity(t))
+    return false;
+  const auto& param = nnvm::get<BatchNormParam>(attrs->parsed);
+  const int ndim    = static_cast<int>(t.size());
+  // `axis` may be negative (counted from the end); normalize it before indexing `t`. Comparing
+  // the raw int against t.size() would also convert a negative axis to a huge unsigned value.
+  const int axis = param.axis < 0 ? param.axis + ndim : param.axis;
+  CHECK(axis >= 0 && axis < ndim) << "BatchNorm axis " << param.axis << " is out of range for "
+                                  << ndim << " dimensions: " << attrs->name;
+  attrs->dict["axis"] = std::to_string(t[axis]);
+  return true;
+}
+
 #if MXNET_USE_ONEDNN == 1
 static inline bool SupportDNNLBN(const NDArray& input, const BatchNormParam& param) {
   if (mxnet::op::batchnorm::disable_mkl)
@@ -641,6 +657,7 @@ then set ``gamma`` to 1 and its gradient to 0.
                                    })
     .set_attr<mxnet::FInferShape>("FInferShape", BatchNormShape)
     .set_attr<nnvm::FInferType>("FInferType", BatchNormType)
+    .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", BNChangeLayout)
     .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
     .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
 #if MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h
index b611542..9994c7b 100644
--- a/src/operator/nn/convolution-inl.h
+++ b/src/operator/nn/convolution-inl.h
@@ -100,6 +100,7 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
         .add_enum("NCW", mshadow::kNCW)
         .add_enum("NCHW", mshadow::kNCHW)
         .add_enum("NCDHW", mshadow::kNCDHW)
+        .add_enum("NWC", mshadow::kNWC)
         .add_enum("NHWC", mshadow::kNHWC)
         .add_enum("NDHWC", mshadow::kNDHWC)
         .set_default(dmlc::optional<int>())
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 787fbc0..a39fa3f 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -23,9 +23,12 @@
  * \author Bing Xu, Jun Wu, Da Zheng
  */
 
+#include <mshadow/base.h>
+#include <mshadow/tensor.h>
 #include "./convolution-inl.h"
 #include "../elemwise_op_common.h"
 #include "../operator_common.h"
+#include "../../common/alm.h"
 #if MXNET_USE_ONEDNN == 1
 #include "./dnnl/dnnl_base-inl.h"
 #include "./dnnl/dnnl_ops-inl.h"
@@ -79,6 +82,29 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
 }
 #endif
 
+/*!
+ * \brief FChangeLayout for Convolution.
+ *
+ * With a known target_layout, t = getTranspAxes(current layout, target_layout). The data (0) and
+ * weight (1) pending input transposes are composed with t. Inputs past index 1 (e.g. bias) are
+ * left unchanged. The output is marked with Reverse(t), which OptimizeLayout materializes as a
+ * transpose unless a consumer absorbs it.
+ * With kUNKNOWN, t is the common pending input transpose and the new layout is
+ * ApplyTranspose(layout, t).
+ * Rewrites the "layout" attribute and returns true only when t is not the identity.
+ *
+ * NOTE(review): in the kUNKNOWN branch the output gets Reverse(t) and the layout becomes
+ * ApplyTranspose(layout, t), whereas PoolChangeLayout uses t and ApplyTranspose(layout,
+ * Reverse(t)) for the same situation; confirm which convention is intended.
+ */
+static bool ConvChangeLayout(nnvm::NodeAttrs* attrs,
+                             mshadow::LayoutFlag target_layout,
+                             std::vector<alm::Transpose>* in_axes,
+                             std::vector<alm::Transpose>* out_axes) {
+  const auto& param = nnvm::get<ConvolutionParam>(attrs->parsed);
+  CHECK(param.layout) << "Current layout of convolution should be known: " << attrs->name;
+  auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
+  // Target known: transpose from the current to the target layout.
+  // Otherwise: the transpose shared by the pending inputs.
+  auto t      = target_layout != mshadow::kUNKNOWN ?
+               mshadow::getTranspAxes<size_t>(layout, target_layout) :
+               alm::FactorCommonTranspose(in_axes);
+  out_axes->assign(1, alm::Reverse(t));
+  if (alm::IsIdentity(t))
+    return false;
+  if (target_layout != mshadow::kUNKNOWN) {
+    // Data and weight must both be converted to the new layout.
+    for (auto i : {0, 1})
+      in_axes->at(i) = alm::Compose(t, in_axes->at(i));
+  } else {
+    target_layout = alm::ApplyTranspose(layout, t);
+  }
+  attrs->dict["layout"] = mshadow::toString(target_layout);
+  return true;
+}
+
 static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
                              mxnet::ShapeVector* in_shape,
                              mxnet::ShapeVector* out_shape) {
@@ -502,6 +528,7 @@ There are other options to tune the performance.
                                       })
     .set_attr<mxnet::FInferShape>("FInferShape", ConvolutionShape)
     .set_attr<nnvm::FInferType>("FInferType", ConvolutionType)
+    .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", ConvChangeLayout)
 #if MXNET_USE_ONEDNN == 1
     .set_attr<FInferStorageType>("FInferStorageType", ConvStorageType)
 #endif
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index 86cde82..2bef3fc 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -25,6 +25,7 @@
 
 #include "./deconvolution-inl.h"
 #include "../operator_common.h"
+#include "../../common/alm.h"
 #include "../../common/utils.h"
 #if MXNET_USE_ONEDNN == 1
 #include "./dnnl/dnnl_base-inl.h"
@@ -401,6 +402,29 @@ struct DeconvolutionGrad {
   }
 };
 
+/*!
+ * \brief FChangeLayout for Deconvolution.
+ *
+ * With a known target_layout, t = getTranspAxes(current layout, target_layout). The data (0) and
+ * weight (1) pending input transposes are composed with t. Inputs past index 1 (e.g. bias) are
+ * left unchanged. The output is marked with Reverse(t).
+ * With kUNKNOWN, t is the common pending input transpose and the new layout is
+ * ApplyTranspose(layout, t).
+ * Rewrites the "layout" attribute and returns true only when t is not the identity.
+ *
+ * NOTE(review): the kUNKNOWN branch uses Reverse(t) for the output and ApplyTranspose(layout, t)
+ * for the layout, the opposite of PoolChangeLayout; confirm which convention is intended.
+ */
+static bool DeconvChangeLayout(nnvm::NodeAttrs* attrs,
+                               mshadow::LayoutFlag target_layout,
+                               std::vector<alm::Transpose>* in_axes,
+                               std::vector<alm::Transpose>* out_axes) {
+  const auto& param = nnvm::get<DeconvolutionParam>(attrs->parsed);
+  // Name the right operator in the error (previously said "convolution").
+  CHECK(param.layout) << "Current layout of deconvolution should be known: " << attrs->name;
+  auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
+  // Target known: transpose from the current to the target layout.
+  // Otherwise: the transpose shared by the pending inputs.
+  auto t      = target_layout != mshadow::kUNKNOWN ?
+               mshadow::getTranspAxes<size_t>(layout, target_layout) :
+               alm::FactorCommonTranspose(in_axes);
+  out_axes->assign(1, alm::Reverse(t));
+  if (alm::IsIdentity(t))
+    return false;
+  if (target_layout != mshadow::kUNKNOWN) {
+    // Data and weight must both be converted to the new layout.
+    for (auto i : {0, 1})
+      in_axes->at(i) = alm::Compose(t, in_axes->at(i));
+  } else {
+    target_layout = alm::ApplyTranspose(layout, t);
+  }
+  attrs->dict["layout"] = mshadow::toString(target_layout);
+  return true;
+}
+
 DMLC_REGISTER_PARAMETER(DeconvolutionParam);
 
 NNVM_REGISTER_OP(Deconvolution)
@@ -428,6 +452,7 @@ NNVM_REGISTER_OP(Deconvolution)
                                       })
     .set_attr<mxnet::FInferShape>("FInferShape", DeconvolutionShape)
     .set_attr<nnvm::FInferType>("FInferType", DeconvolutionType)
+    .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", DeconvChangeLayout)
     .set_attr<FResourceRequest>("FResourceRequest",
                                 [](const NodeAttrs& n) {
                                   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index edb6a1e..7b302ee9 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -24,6 +24,7 @@
  */
 #include "../elemwise_op_common.h"
 #include "./pooling-inl.h"
+#include "../../common/alm.h"
 #if MXNET_USE_ONEDNN == 1
 #include "./dnnl/dnnl_base-inl.h"
 #include "./dnnl/dnnl_pooling-inl.h"
@@ -270,6 +271,22 @@ static bool PoolingShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+/*!
+ * \brief FChangeLayout for Pooling. Requires targetLayout == kUNKNOWN. When the inputs share a
+ * non-identity pending transpose t, the layout becomes ApplyTranspose(layout, Reverse(t)) and t
+ * is forwarded to the output (returns true). Otherwise nothing changes (returns false).
+ */
+static bool PoolChangeLayout(nnvm::NodeAttrs* attrs,
+                             mshadow::LayoutFlag targetLayout,
+                             std::vector<alm::Transpose>* inpTransposes,
+                             std::vector<alm::Transpose>* outTransposes) {
+  CHECK_EQ(targetLayout, mshadow::kUNKNOWN);
+  const auto& param = nnvm::get<PoolingParam>(attrs->parsed);
+  CHECK(param.layout) << "Current layout of pooling should be known: " << attrs->name;
+  const auto current          = static_cast<mshadow::LayoutFlag>(param.layout.value());
+  const alm::Transpose common = alm::FactorCommonTranspose(inpTransposes);
+  if (!alm::IsIdentity(common)) {
+    outTransposes->assign(1, common);
+    const auto new_layout = alm::ApplyTranspose(current, alm::Reverse(common));
+    attrs->dict["layout"] = mshadow::toString(new_layout);
+    return true;
+  }
+  return false;
+}
+
 #if MXNET_USE_ONEDNN == 1
 void PoolingComputeExCPU(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx,
@@ -443,6 +460,7 @@ For each window ``X``, the mathematical expression for Lp pooling is:
 #endif
     .set_attr<nnvm::FInferType>("FInferType", PoolingType)
     .set_attr<mxnet::FInferShape>("FInferShape", PoolingShape)
+    .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", PoolChangeLayout)
     .set_attr<FCompute>("FCompute<cpu>", PoolingCompute<cpu>)
 #if MXNET_USE_ONEDNN == 1
     .set_attr<bool>("TIsDNNL", true)
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 8c5beec..9a219aa 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -42,6 +42,7 @@
 #include "../common/utils.h"
 
 namespace mxnet {
+
 namespace op {
 /*!
  * \brief assign the expression to out according to request
diff --git a/src/operator/tensor/amp_cast.cc b/src/operator/tensor/amp_cast.cc
index 62e63a1..1899a4c 100644
--- a/src/operator/tensor/amp_cast.cc
+++ b/src/operator/tensor/amp_cast.cc
@@ -23,10 +23,25 @@
  */
 
 #include "./amp_cast.h"
+#include "../../common/alm.h"
 
 namespace mxnet {
 namespace op {
 
+// FChangeLayout hook for amp_multicast. The operator must have exactly as many
+// outputs as inputs. Each input's pending transpose is handed unchanged to the
+// output at the same index, and every input slot is then reset to a
+// default-constructed alm::Transpose. attrs is never modified; the function
+// always returns false.
+static bool MCastChangeLayout(nnvm::NodeAttrs* attrs,
+                              mshadow::LayoutFlag targetLayout,
+                              std::vector<alm::Transpose>* inpTransposes,
+                              std::vector<alm::Transpose>* outTransposes) {
+  const auto num_in  = attrs->op->get_num_inputs(*attrs);
+  const auto num_out = attrs->op->get_num_outputs(*attrs);
+  CHECK_EQ(num_in, num_out) << "This operator should have the same number inputs and outputs";
+  CHECK_EQ(inpTransposes->size(), num_in);
+  CHECK_EQ(targetLayout, mshadow::kUNKNOWN);
+  // Hand the input transposes to the outputs, then overwrite the inputs.
+  outTransposes->swap(*inpTransposes);
+  inpTransposes->assign(num_in, alm::Transpose());
+  return false;
+}
+
 DMLC_REGISTER_PARAMETER(AMPCastParam);
 DMLC_REGISTER_PARAMETER(AMPMultiCastParam);
 
@@ -135,6 +150,7 @@ It casts only between low precision float/FP32 and does not do anything for othe
     .set_attr_parser(ParamParser<AMPCastParam>)
     .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
     .set_attr<nnvm::FInferType>("FInferType", AMPCastType)
+    .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", ElemwiseChangeLayout)
     .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                     [](const NodeAttrs& attrs) {
                                       return std::vector<std::pair<int, int>>{{0, 0}};
@@ -188,6 +204,7 @@ It casts only between low precision float/FP32 and does not do anything for othe
     .set_attr_parser(ParamParser<AMPMultiCastParam>)
     .set_attr<mxnet::FInferShape>("FInferShape", AMPMultiCastShape)
     .set_attr<nnvm::FInferType>("FInferType", AMPMultiCastType)
+    .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", MCastChangeLayout)
     .set_attr<nnvm::FListInputNames>("FListInputNames",
                                      [](const NodeAttrs& attrs) {
                                        uint32_t num_args =
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index 4f36b8a..732b6a5 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -813,6 +813,7 @@ class ElemwiseBinaryOp : public OpBase {
                                        })                                                         \
       .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<2, 1>)                           \
       .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)                               \
+      .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", ElemwiseChangeLayout)                 \
       .set_attr<nnvm::FInplaceOption>("FInplaceOption",                                           \
                                       [](const NodeAttrs& attrs) {                                \
                                         return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index aa6b7f5..8c025ef 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -31,6 +31,7 @@
 #include <string>
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
+#include "../../common/alm.h"
 #include "elemwise_unary_op.h"
 
 namespace mxnet {
@@ -447,6 +448,7 @@ class BinaryScalarOp : public UnaryOp {
       .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)                               \
       .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)                   \
       .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType)                    \
+      .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", ElemwiseChangeLayout)         \
       .set_attr<nnvm::FInplaceOption>("FInplaceOption",                                   \
                                       [](const NodeAttrs& attrs) {                        \
                                         return std::vector<std::pair<int, int> >{{0, 0}}; \
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 5d23c98..0048777 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -35,6 +35,7 @@
 #include "../mshadow_op.h"
 #include "../mxnet_op.h"
 #include "../elemwise_op_common.h"
+#include "../../common/alm.h"
 #include "../../common/utils.h"
 #include "../../ndarray/ndarray_function.h"
 
@@ -865,6 +866,7 @@ void NumpyNanToNumOpBackward(const nnvm::NodeAttrs& attrs,
       .set_num_outputs(1)                                                                 \
       .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)                   \
       .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)                       \
+      .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", ElemwiseChangeLayout)         \
       .set_attr<nnvm::FInplaceOption>("FInplaceOption",                                   \
                                       [](const NodeAttrs& attrs) {                        \
                                         return std::vector<std::pair<int, int> >{{0, 0}}; \
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 787eb5c..b65c7cb 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -334,6 +334,22 @@ inline static bool TransposeStorageType(const nnvm::NodeAttrs& attrs,
 }
 #endif
 
+// FChangeLayout hook for transpose. The pending transpose of the single input is
+// composed (alm::Compose) with this node's own axes and written back to the "axes"
+// attribute. Both the input and the output transposes are then reset to
+// default-constructed alm::Transpose values. Always returns true because "axes"
+// is rewritten.
+static bool TransposeChangeLayout(nnvm::NodeAttrs* attrs,
+                                  mshadow::LayoutFlag target_layout,
+                                  std::vector<alm::Transpose>* in_axes,
+                                  std::vector<alm::Transpose>* out_axes) {
+  CHECK_EQ(target_layout, mshadow::kUNKNOWN);
+  CHECK_EQ(in_axes->size(), 1);
+  const auto& param  = nnvm::get<TransposeParam>(attrs->parsed);
+  mxnet::TShape axes = param.axes;
+  // An empty "axes" is transpose's default: reverse all dimensions. An empty
+  // TShape carries no rank, so FromTShape cannot represent that reversal.
+  // When an input transpose is present, its length gives the rank, so spell
+  // the reversal out explicitly before composing.
+  const int ndim = static_cast<int>(in_axes->at(0).size());
+  if (axes.ndim() <= 0 && ndim > 0) {
+    axes = mxnet::TShape(ndim, -1);
+    for (int i = 0; i < ndim; ++i)
+      axes[i] = ndim - 1 - i;
+  }
+  auto new_axes = alm::Compose(alm::FromTShape(axes), in_axes->at(0));
+  std::ostringstream ss;
+  ss << mxnet::TShape(new_axes.begin(), new_axes.end());
+  attrs->dict["axes"] = ss.str();
+  in_axes->assign(1, alm::Transpose());
+  out_axes->assign(1, alm::Transpose());
+  return true;
+}
+
 NNVM_REGISTER_OP(transpose)
     .describe(R"code(Permutes the dimensions of an array.
 Examples::
@@ -360,6 +376,7 @@ Examples::
     .set_attr_parser(ParamParser<TransposeParam>)
     .set_attr<mxnet::FInferShape>("FInferShape", TransposeShape)
     .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+    .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", TransposeChangeLayout)
     .set_attr<nnvm::FGradient>(
         "FGradient",
         [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
diff --git a/tests/python/gpu/test_amp_init.py b/tests/python/gpu/test_amp_init.py
index 2980366..28d1123 100644
--- a/tests/python/gpu/test_amp_init.py
+++ b/tests/python/gpu/test_amp_init.py
@@ -15,12 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import mxnet as mx
-from mxnet.gluon import nn
-from mxnet import amp
+from contextlib import contextmanager
+import ctypes
+
 import numpy as np
 import pytest
 
+import mxnet as mx
+from mxnet import amp
+from mxnet.base import check_call, _LIB
+from mxnet.gluon import nn
+from mxnet.test_utils import assert_allclose
+
 
 @pytest.fixture
 def np_shape_array():
@@ -35,6 +41,17 @@ def amp_init():
     amp.init()
 
 
+@contextmanager
+def optimize_layout(optimize=True):
+    prev = ctypes.c_bool()
+    check_call(_LIB.MXGetOptimizeLayout(ctypes.byref(prev)))
+    check_call(_LIB.MXSetOptimizeLayout(ctypes.c_bool(optimize)))
+    try:
+        yield
+    finally:
+        check_call(_LIB.MXSetOptimizeLayout(prev))
+
+
 def test_npi_concatenate_multicast(np_shape_array, amp_init):
     class Foo(nn.HybridBlock):
         def __init__(self, **kwargs):
@@ -51,3 +68,76 @@ def test_npi_concatenate_multicast(np_shape_array, amp_init):
     data = mx.np.ones((32, 8), ctx=mx.gpu())
     out = foo(data)
     assert out.dtype == np.float32
+
+
+# Gluon layer classes keyed by the number of spatial dimensions (1, 2 or 3).
+CONV = {1: nn.Conv1D, 2: nn.Conv2D, 3: nn.Conv3D}
+MAX_POOL = {1: nn.MaxPool1D, 2: nn.MaxPool2D, 3: nn.MaxPool3D}
+
+
+class Conv(nn.HybridBlock):
+    def __init__(self, ndim, **kwargs):
+        super().__init__(**kwargs)
+        self.conv = CONV[ndim](10, 3)
+
+    def forward(self, x):
+        y = self.conv(x)
+        return y * 2
+
+
+class ConvBN(nn.HybridBlock):
+    def __init__(self, ndim, **kwargs):
+        super().__init__(**kwargs)
+        self.conv = CONV[ndim](10, 3)
+        self.bn = nn.BatchNorm()
+
+    def forward(self, x):
+        y = self.conv(x)
+        y = self.bn(y)
+        return y * 2 + 10
+
+
+class PoolConv(nn.HybridBlock):
+    def __init__(self, ndim, **kwargs):
+        super().__init__(**kwargs)
+        self.pool = MAX_POOL[ndim]()
+        self.conv = CONV[ndim](10, 3)
+
+    def forward(self, x):
+        y = self.pool(x)
+        y = self.conv(y)
+        return y * 2
+
+
+@pytest.mark.skipif(not mx.runtime.Features().is_enabled('CUDNN'),
+                    reason='Channel-last layouts are only supported with cuDNN.')
+@pytest.mark.parametrize('ndim', [1, 2, 3])
+@pytest.mark.parametrize('model', [Conv, ConvBN, PoolConv])
+def test_optimize_layout(np_shape_array, amp_init, model, ndim):
+    m = model(ndim)
+    m.initialize(ctx=mx.gpu())
+    m.hybridize()
+    x = mx.np.random.uniform(low=0, high=10, size=(32, 2, 17, 15, 12)[:ndim + 2], ctx=mx.gpu())
+    m(x)
+    param_init = {k:v.data().copy() for k, v in m.collect_params().items()}
+    for v in m.collect_params().values():
+        v.data().attach_grad()
+    with mx.autograd.record():
+        y = m(x)
+    y.backward()
+    with optimize_layout():
+        m2 = model(ndim)
+        m2.initialize(ctx=mx.gpu())
+        m2.load_dict(param_init, device=mx.gpu())
+        m2.hybridize()
+        for v in m2.collect_params().values():
+            v.data().attach_grad()
+        with mx.autograd.record():
+            y2 = m2(x)
+        y2.backward()
+    rtol = 1e-2
+    atol = 1e-2
+    assert_allclose(y2, y, rtol=rtol, atol=atol)
+    for k, v in m.collect_params().items():
+        if v.grad_req == 'null':
+            continue
+        assert_allclose(m2.collect_params()[k].grad(), v.grad(), rtol=rtol, atol=atol)