Posted to commits@mxnet.apache.org by pt...@apache.org on 2021/12/02 18:16:26 UTC
[incubator-mxnet] branch master updated: Automatic Layout Management (#20718)
This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 40359ce Automatic Layout Management (#20718)
40359ce is described below
commit 40359ceda150ca75da6e45b1ea35d747ef53deac
Author: Vladimir Cherepanov <56...@users.noreply.github.com>
AuthorDate: Thu Dec 2 10:14:13 2021 -0800
Automatic Layout Management (#20718)
* Automatic Layout Management
Originally authored by Dawid Tracz <dt...@nvidia.com>
* Fix clang-format
* Fix clang-format in mshadow
* Print layout name instead of a number
* Generalize NHWC target layout to other dimensions
* Change layout optimization API
* Add layout optimization tests
* Add backward check to tests
* Generalize tests to 1..3 spatial dims
* Add NWC layout to ConvolutionParams
* Enable layout optimization tests only with cuDNN
Co-authored-by: Vladimir Cherepanov <vc...@nvidia.com>
---
3rdparty/mshadow/mshadow/base.h | 60 +++++++
3rdparty/mshadow/mshadow/tensor.h | 91 +++++++++++
include/mxnet/c_api.h | 10 ++
python/mxnet/amp/amp.py | 8 +-
src/c_api/c_api.cc | 13 ++
src/common/alm.cc | 209 ++++++++++++++++++++++++
src/common/alm.h | 100 ++++++++++++
src/imperative/cached_op.h | 3 +
src/operator/cudnn_ops.cc | 2 +-
src/operator/elemwise_op_common.h | 10 ++
src/operator/leaky_relu.cc | 13 ++
src/operator/nn/batch_norm.cc | 17 ++
src/operator/nn/convolution-inl.h | 1 +
src/operator/nn/convolution.cc | 27 +++
src/operator/nn/deconvolution.cc | 25 +++
src/operator/nn/pooling.cc | 18 ++
src/operator/operator_common.h | 1 +
src/operator/tensor/amp_cast.cc | 17 ++
src/operator/tensor/elemwise_binary_op.h | 1 +
src/operator/tensor/elemwise_binary_scalar_op.h | 2 +
src/operator/tensor/elemwise_unary_op.h | 2 +
src/operator/tensor/matrix_op.cc | 17 ++
tests/python/gpu/test_amp_init.py | 96 ++++++++++-
23 files changed, 737 insertions(+), 6 deletions(-)
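For context, a minimal usage sketch of the feature added by this commit (the model below is hypothetical, not part of the diff): layout optimization is opted into through amp.init() and takes effect on hybridized Gluon models, whose cached forward graph is rewritten by alm::OptimizeLayout.

    import mxnet as mx
    from mxnet import amp

    mx.npx.set_np()  # NumPy semantics, as the tests below enable via a fixture

    # New in this commit: layout_optimization=True calls MXSetOptimizeLayout(True)
    amp.init(target_dtype='float16', layout_optimization=True)

    net = mx.gluon.nn.Conv2D(10, 3)  # hypothetical model
    net.initialize(ctx=mx.gpu())
    net.hybridize()                  # ALM runs when the forward graph is created
    out = net(mx.np.ones((1, 4, 8, 8), ctx=mx.gpu()))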
diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h
index e018551..5f6fb0c 100644
--- a/3rdparty/mshadow/mshadow/base.h
+++ b/3rdparty/mshadow/mshadow/base.h
@@ -496,6 +496,8 @@ const int index_type_flag = DataType<lapack_index_t>::kFlag;
/*! layout flag */
enum LayoutFlag {
+ kUNKNOWN = -1,
+
kNCHW = 0,
kNHWC,
kCHWN,
@@ -509,6 +511,64 @@ enum LayoutFlag {
kCDHWN
};
+inline LayoutFlag layoutFlag(std::string layoutstr) {
+ switch (layoutstr.length()) {
+ case 4:
+ if (layoutstr == "NHWC")
+ return kNHWC;
+ if (layoutstr == "NCHW")
+ return kNCHW;
+ if (layoutstr == "CHWN")
+ return kCHWN;
+ return kUNKNOWN;
+ case 3:
+ if (layoutstr == "NWC")
+ return kNWC;
+ if (layoutstr == "NCW")
+ return kNCW;
+ if (layoutstr == "CWN")
+ return kCWN;
+ return kUNKNOWN;
+ case 5:
+ if (layoutstr == "NDHWC")
+ return kNDHWC;
+ if (layoutstr == "NCDHW")
+ return kNCDHW;
+ if (layoutstr == "CDHWN")
+ return kCDHWN;
+ return kUNKNOWN;
+ default:
+ return kUNKNOWN;
+ }
+}
+
+inline std::string toString(LayoutFlag layout) {
+ switch (layout) {
+ case kUNKNOWN:
+ return "";
+ case kNCHW:
+ return "NCHW";
+ case kNHWC:
+ return "NHWC";
+ case kCHWN:
+ return "CHWN";
+ case kNCW:
+ return "NCW";
+ case kNWC:
+ return "NWC";
+ case kCWN:
+ return "CWN";
+ case kNCDHW:
+ return "NCDHW";
+ case kNDHWC:
+ return "NDHWC";
+ case kCDHWN:
+ return "CDHWN";
+ default:
+ return "";
+ }
+}
+
template<int layout>
struct LayoutType;
diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h
index e417fbb..fdf5e06 100644
--- a/3rdparty/mshadow/mshadow/tensor.h
+++ b/3rdparty/mshadow/mshadow/tensor.h
@@ -391,6 +391,97 @@ inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layou
}
/*!
+ * \brief returns axes of transpose operation
+ * that needs to be performed between src layout and dst
+ * \param src_layout input layout
+ * \param dst_layout output layout
+ * \return vector of required type describing axes of a transpose operation
+ */
+template <typename dim_t>
+inline std::vector<dim_t> getTranspAxes(const LayoutFlag src_layout, const LayoutFlag dst_layout) {
+ auto apply = [](const std::vector<dim_t>& v, const std::vector<dim_t>& op) {
+ CHECK_EQ(v.size(), op.size()) << "Layout ndims does not match";
+ std::vector<dim_t> ret(v.size());
+ for (size_t i = 0; i < v.size(); i++) {
+ ret[i] = v[op[i]];
+ }
+ return ret;
+ };
+ std::vector<dim_t> axes;
+ // transpose from src_layout to the channels-last layout (N[D][H]WC)
+ switch (src_layout) {
+ case kUNKNOWN:
+ LOG(FATAL) << "Unknown source layout";
+ break;
+ case kNHWC:
+ axes = std::vector<dim_t>({0, 1, 2, 3});
+ break;
+ case kNCHW:
+ axes = std::vector<dim_t>({0, 2, 3, 1});
+ break;
+ case kCHWN:
+ axes = std::vector<dim_t>({3, 1, 2, 0});
+ break;
+ case kNWC:
+ axes = std::vector<dim_t>({0, 1, 2});
+ break;
+ case kNCW:
+ axes = std::vector<dim_t>({0, 2, 1});
+ break;
+ case kCWN:
+ axes = std::vector<dim_t>({2, 1, 0});
+ break;
+ case kNDHWC:
+ axes = std::vector<dim_t>({0, 1, 2, 3, 4});
+ break;
+ case kNCDHW:
+ axes = std::vector<dim_t>({0, 2, 3, 4, 1});
+ break;
+ case kCDHWN:
+ axes = std::vector<dim_t>({4, 1, 2, 3, 0});
+ break;
+ default:
+ LOG(FATAL) << "Invalid source layout " << src_layout;
+ }
+ // transpose from the channels-last layout (N[D][H]WC) to dst_layout
+ switch (dst_layout) {
+ case kUNKNOWN:
+ LOG(FATAL) << "Unknown destination layout";
+ break;
+ case kNHWC:
+ axes = apply(axes, {0, 1, 2, 3});
+ break;
+ case kNCHW:
+ axes = apply(axes, {0, 3, 1, 2});
+ break;
+ case kCHWN:
+ axes = apply(axes, {3, 1, 2, 0});
+ break;
+ case kNWC:
+ axes = apply(axes, {0, 1, 2});
+ break;
+ case kNCW:
+ axes = apply(axes, {0, 2, 1});
+ break;
+ case kCWN:
+ axes = apply(axes, {2, 1, 0});
+ break;
+ case kNDHWC:
+ axes = apply(axes, {0, 1, 2, 3, 4});
+ break;
+ case kNCDHW:
+ axes = apply(axes, {0, 4, 1, 2, 3});
+ break;
+ case kCDHWN:
+ axes = apply(axes, {4, 1, 2, 3, 0});
+ break;
+ default:
+ LOG(FATAL) << "Invalid destination layout " << src_layout;
+ }
+ return axes;
+}
+
+/*!
* \brief computation stream structure, used for asynchronous computations
*/
template<typename Device>
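getTranspAxes pivots through the channels-last layout: the first switch maps src_layout to N[D][H]WC, the second composes that with the mapping from N[D][H]WC to dst_layout. A rough Python model of the 4-D case (names are mine, for illustration only):

    # Axes taking each layout to its channels-last form (first switch above).
    TO_CHANNELS_LAST = {'NCHW': [0, 2, 3, 1], 'NHWC': [0, 1, 2, 3], 'CHWN': [3, 1, 2, 0]}
    # Axes taking channels-last back to each layout (second switch above).
    FROM_CHANNELS_LAST = {'NCHW': [0, 3, 1, 2], 'NHWC': [0, 1, 2, 3], 'CHWN': [3, 1, 2, 0]}

    def get_transp_axes(src, dst):
        first = TO_CHANNELS_LAST[src]      # src -> NHWC
        second = FROM_CHANNELS_LAST[dst]   # NHWC -> dst
        return [first[a] for a in second]  # apply(): ret[i] = v[op[i]]

    assert get_transp_axes('NCHW', 'NHWC') == [0, 2, 3, 1]
    assert get_transp_axes('NHWC', 'NCHW') == [0, 3, 1, 2]
    assert get_transp_axes('NCHW', 'NCHW') == [0, 1, 2, 3]  # round trip is identity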
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index b25ccad..94609de 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -3161,6 +3161,16 @@ MXNET_DLL int MXCUDAProfilerStart();
*/
MXNET_DLL int MXCUDAProfilerStop();
+/*!
+ * \brief Turns on or off Layout Optimization
+ */
+MXNET_DLL int MXSetOptimizeLayout(bool val);
+
+/*!
+ * \brief Get current Layout Optimization status
+ */
+MXNET_DLL int MXGetOptimizeLayout(bool* val);
+
#ifdef __cplusplus
}
#endif // __cplusplus
diff --git a/python/mxnet/amp/amp.py b/python/mxnet/amp/amp.py
index c7aab71..750b3d0 100644
--- a/python/mxnet/amp/amp.py
+++ b/python/mxnet/amp/amp.py
@@ -307,7 +307,7 @@ def warn_if_model_exists():
return
def init(target_dtype='float16', target_precision_ops=None,
- conditional_fp32_ops=None, fp32_ops=None):
+ conditional_fp32_ops=None, fp32_ops=None, layout_optimization=False):
"""Initialize AMP (automatic mixed precision).
This needs to be done before model creation.
@@ -333,7 +333,11 @@ def init(target_dtype='float16', target_precision_ops=None,
assert target_dtype in ['float16', np.float16, 'bfloat16', bfloat16], \
"AMP currently supports only float16 or bfloat16 as a target_dtype"
_amp_initialized = True
- logging.info("Using AMP")
+ log_msg = "Using AMP"
+ if layout_optimization:
+ log_msg += "\n - layout optimization: enabled"
+ check_call(_LIB.MXSetOptimizeLayout(ctypes.c_bool(True)))
+ logging.info(log_msg)
if target_dtype == "bfloat16":
target_dtype = bfloat16
else:
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 61a47b0..d533a2a 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -55,6 +55,7 @@
#include "../operator/tvmop/op_module.h"
#include "../operator/subgraph/partitioner/custom_subgraph_property.h"
#include "../operator/subgraph/subgraph_property.h"
+#include "../common/alm.h"
#include "../common/utils.h"
#include "../profiler/profiler.h"
#include "../serialization/cnpy.h"
@@ -4004,3 +4005,15 @@ int MXCUDAProfilerStop() {
#endif
API_END();
}
+
+int MXSetOptimizeLayout(bool val) {
+ API_BEGIN();
+ mxnet::alm::ALMParams::get().optimize = val;
+ API_END();
+}
+
+int MXGetOptimizeLayout(bool* val) {
+ API_BEGIN();
+ *val = mxnet::alm::ALMParams::get().optimize;
+ API_END();
+}
diff --git a/src/common/alm.cc b/src/common/alm.cc
new file mode 100644
index 0000000..3a38ee5
--- /dev/null
+++ b/src/common/alm.cc
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file alm.cc
+ * \brief Automatic Layout Manager
+ * \author Dawid Tracz, Vladimir Cherepanov
+ */
+
+#include "alm.h"
+
+#include <algorithm>
+#include <sstream>
+#include <unordered_set>
+#include <utility>
+
+#include "../operator/nn/convolution-inl.h"
+#include "../operator/nn/deconvolution-inl.h"
+#include "../operator/tensor/matrix_op-inl.h"
+
+namespace mxnet {
+namespace alm {
+
+namespace {
+
+nnvm::ObjectPtr CreateTransposeNode(const std::string& name, const alm::Transpose& axes) {
+ nnvm::ObjectPtr newptr = nnvm::Node::Create();
+ newptr->attrs.op = nnvm::Op::Get("transpose");
+ newptr->attrs.name = name;
+ // set transpose axes
+ std::ostringstream ss;
+ ss << mxnet::TShape(axes.begin(), axes.end());
+ newptr->attrs.dict["axes"] = ss.str();
+ newptr->op()->attr_parser(&(newptr->attrs));
+ return newptr;
+}
+
+mshadow::LayoutFlag TargetLayout(const nnvm::ObjectPtr& node) {
+ static const Op* conv_op = Op::Get("Convolution");
+ static const Op* deconv_op = Op::Get("Deconvolution");
+
+ static const std::unordered_map<int, mshadow::LayoutFlag> ndim2layout{
+ {1, mshadow::kNWC},
+ {2, mshadow::kNHWC},
+ {3, mshadow::kNDHWC},
+ };
+
+ auto target_layout = [](const auto& param) {
+ auto it = ndim2layout.find(param.kernel.ndim());
+ CHECK(it != ndim2layout.end()) << "Unexpected kernel dimensions: " << param.kernel;
+ return it->second;
+ };
+
+ if (node->op() == conv_op)
+ return target_layout(nnvm::get<op::ConvolutionParam>(node->attrs.parsed));
+
+ if (node->op() == deconv_op)
+ return target_layout(nnvm::get<op::DeconvolutionParam>(node->attrs.parsed));
+
+ return mshadow::kUNKNOWN;
+}
+
+} // namespace
+
+nnvm::Graph OptimizeLayout(nnvm::Graph&& g) {
+ static const auto& op_map = Op::GetAttr<mxnet::alm::FChangeLayout>("FChangeLayout");
+ static const Op* transpose_op = Op::Get("transpose");
+ std::unordered_set<nnvm::ObjectPtr> outputs;
+ for (auto& o : g.outputs)
+ outputs.insert(o.node);
+ nnvm::NodeEntryMap<alm::Transpose> changed;
+ struct ToDelete {
+ nnvm::ObjectPtr node; // output of the transpose
+ size_t input_idx;
+ };
+ std::vector<ToDelete> to_delete;
+ struct ToAdd {
+ nnvm::ObjectPtr node;
+ size_t input_idx;
+ alm::Transpose axes;
+ };
+ std::vector<ToAdd> to_add;
+ DFSVisit(g.outputs, [&outputs, &changed, &to_add, &to_delete](const nnvm::ObjectPtr& node) {
+ std::vector<alm::Transpose> input_axes(node->inputs.size());
+ for (size_t i = 0; i < node->inputs.size(); ++i) {
+ if (node->inputs[i].node->op() == transpose_op) {
+ const auto& param = nnvm::get<op::TransposeParam>(node->inputs[i].node->attrs.parsed);
+ if (IsIdentity(FromTShape(param.axes))) {
+ to_delete.push_back({node, i});
+ continue;
+ }
+ }
+ auto it = changed.find(node->inputs[i]);
+ if (it == changed.end())
+ continue;
+ input_axes[i] = it->second;
+ }
+ auto fchange = op_map.get(node->op(), nullptr);
+ if (fchange && outputs.count(node) == 0) {
+ std::vector<alm::Transpose> output_axes;
+ if (fchange(&node->attrs, TargetLayout(node), &input_axes, &output_axes))
+ node->op()->attr_parser(&node->attrs);
+ for (size_t i = 0; i < output_axes.size(); ++i) {
+ if (IsIdentity(output_axes[i]))
+ continue;
+ changed.insert(std::make_pair(nnvm::NodeEntry(node, i, 0), output_axes[i]));
+ }
+ }
+ for (size_t i = 0; i < input_axes.size(); ++i) {
+ if (IsIdentity(input_axes[i]))
+ continue;
+ to_add.push_back({node, i, input_axes[i]});
+ }
+ });
+ for (const auto& t : to_delete) {
+ auto& tnode = t.node->inputs[t.input_idx].node;
+ CHECK_EQ(tnode->inputs.size(), 1);
+ t.node->inputs[t.input_idx] = tnode->inputs[0];
+ }
+ size_t node_no = 0;
+ for (const auto& t : to_add) {
+ auto tnode = CreateTransposeNode("ALM_transpose_" + std::to_string(node_no++), t.axes);
+ tnode->inputs.push_back(t.node->inputs[t.input_idx]);
+ t.node->inputs[t.input_idx] = nnvm::NodeEntry(tnode);
+ }
+ nnvm::Graph ret;
+ ret.outputs = g.outputs;
+ return ret;
+}
+
+Transpose Reverse(const Transpose& axes) {
+ Transpose rev(axes.size());
+ for (size_t i = 0; i < rev.size(); i++)
+ rev[axes[i]] = i;
+ return rev;
+}
+
+Transpose Compose(const Transpose& lhs, const Transpose& rhs) {
+ if (lhs.empty())
+ return rhs;
+ if (rhs.empty())
+ return lhs;
+ CHECK_EQ(lhs.size(), rhs.size());
+ Transpose ret(lhs.size());
+ for (size_t i = 0; i < ret.size(); ++i)
+ ret[i] = lhs[rhs[i]];
+ return ret;
+}
+
+bool IsIdentity(const Transpose& t) {
+ for (size_t i = 0; i < t.size(); ++i) {
+ if (t[i] != i)
+ return false;
+ }
+ return true;
+}
+
+mshadow::LayoutFlag ApplyTranspose(mshadow::LayoutFlag layout, const Transpose& axes) {
+ auto ret = mshadow::layoutFlag(ApplyTranspose(mshadow::toString(layout), axes));
+ CHECK_NE(ret, mshadow::kUNKNOWN);
+ return ret;
+}
+
+std::string ApplyTranspose(const std::string& layout, const Transpose& axes) {
+ std::string ret(layout.size(), ' ');
+ for (size_t i = 0; i < ret.size(); i++)
+ ret[i] = layout[axes[i]];
+ return ret;
+}
+
+Transpose FromTShape(const mxnet::TShape& s) {
+ Transpose ret(s.ndim());
+ std::copy(s.begin(), s.end(), ret.begin());
+ return ret;
+}
+
+Transpose FactorCommonTranspose(std::vector<Transpose>* axes) {
+ Transpose ret;
+ for (auto& t : *axes) {
+ if (IsIdentity(t))
+ continue;
+ if (IsIdentity(ret)) {
+ std::swap(t, ret);
+ continue;
+ }
+ auto rev = Reverse(ret);
+ t = Compose(t, rev);
+ }
+ return ret;
+}
+
+} // namespace alm
+} // namespace mxnet
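The functions above form a small permutation algebra: Reverse inverts a transpose, Compose chains two (right operand applied first), ApplyTranspose permutes a layout string, and FactorCommonTranspose pulls one shared transpose out of a list. A self-contained Python sketch of these semantics (an illustration of the C++ above, not code from this commit):

    def is_identity(t):                 # alm::IsIdentity
        return all(a == i for i, a in enumerate(t))

    def reverse(axes):                  # alm::Reverse: inverse permutation
        rev = [0] * len(axes)
        for i, a in enumerate(axes):
            rev[a] = i
        return rev

    def compose(lhs, rhs):              # alm::Compose: rhs applied first, then lhs
        if not lhs or not rhs:          # an empty vector means identity
            return rhs or lhs
        return [lhs[a] for a in rhs]

    def apply_transpose(layout, axes):  # alm::ApplyTranspose (string version)
        return ''.join(layout[a] for a in axes)

    def factor_common_transpose(axes):  # alm::FactorCommonTranspose
        ret = []
        for i, t in enumerate(axes):
            if is_identity(t):
                continue
            if is_identity(ret):
                axes[i], ret = [], t    # factor out the first non-identity transpose
            else:
                axes[i] = compose(t, reverse(ret))  # leftover relative to ret
        return ret

    t = [0, 2, 3, 1]                                # NCHW -> NHWC
    assert apply_transpose('NCHW', t) == 'NHWC'
    assert compose(reverse(t), t) == [0, 1, 2, 3]   # a transpose cancels its inverse

    # Two inputs with the same pending transpose share a single common factor:
    in_axes = [t[:], t[:]]
    assert factor_common_transpose(in_axes) == t
    assert all(is_identity(x) for x in in_axes)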
diff --git a/src/common/alm.h b/src/common/alm.h
new file mode 100644
index 0000000..923f4eb
--- /dev/null
+++ b/src/common/alm.h
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file alm.h
+ * \brief Automatic Layout Manager
+ * \author Dawid Tracz, Vladimir Cherepanov
+ */
+
+#ifndef MXNET_COMMON_ALM_H_
+#define MXNET_COMMON_ALM_H_
+
+#include <mxnet/base.h>
+#include <nnvm/graph.h>
+#include <nnvm/node.h>
+#include <functional>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace mxnet {
+namespace alm {
+
+/*!
+ * \brief A singleton flag, set and read by MXSetOptimizeLayout and MXGetOptimizeLayout
+ */
+struct ALMParams {
+ bool optimize = false;
+
+ static ALMParams& get() {
+ static ALMParams alm;
+ return alm;
+ }
+};
+
+/*!
+ * \brief Top-level function to run layout optimization.
+ */
+nnvm::Graph OptimizeLayout(nnvm::Graph&& g);
+
+/*!
+ * \brief Transpose, represented by permutation of axes.
+ */
+using Transpose = std::vector<size_t>;
+
+bool IsIdentity(const Transpose& t);
+Transpose Reverse(const Transpose& axes);
+
+/*!
+ * \brief Compose 2 transposes. Not commutative: a * b means b is applied first, then a.
+ */
+Transpose Compose(const Transpose& lhs, const Transpose& rhs);
+
+mshadow::LayoutFlag ApplyTranspose(mshadow::LayoutFlag layout, const Transpose& axes);
+std::string ApplyTranspose(const std::string& layout, const Transpose& axes);
+
+Transpose FromTShape(const mxnet::TShape& s);
+
+/*!
+ * \brief May change an operator's layout. Used by OptimizeLayout.
+ *
+ * \param target_layout The target layout to change to, or kUNKNOWN. In the latter case the target
+ * layout is calculated based on in_axes, with the goal of cancelling them out (at least some,
+ * ideally all).
+ * \param in_axes (in/out) On input - pending inputs' transposes. On output - inputs' transposes,
+ * required by the new layout.
+ * \param out_axes (out) Outputs' transposes, required to convert to the original layouts.
+ * \return true if attrs changed and params need to be reparsed.
+ */
+using FChangeLayout = std::function<bool(nnvm::NodeAttrs*,
+ mshadow::LayoutFlag target_layout,
+ std::vector<Transpose>* in_axes,
+ std::vector<Transpose>* out_axes)>;
+
+/*!
+ * \brief Factors out and returns a common transpose, or default-constructed Transpose if all
+ * axes (in/out parameter) are empty.
+ */
+Transpose FactorCommonTranspose(std::vector<Transpose>* axes);
+
+} // namespace alm
+} // namespace mxnet
+
+#endif // MXNET_COMMON_ALM_H_
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 97ac23c..079a56e 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -28,6 +28,7 @@
#include <string>
#include <unordered_map>
#include <map>
+#include "../common/alm.h"
#include "../operator/operator_common.h"
#include "../operator/subgraph/common.h"
#include "./imperative_utils.h"
@@ -208,6 +209,8 @@ void CreateForwardGraph(const nnvm::Symbol& sym, nnvm::Graph* fwd_graph) {
fwd_graph->outputs.push_back(nodeEntry);
}
}
+ if (alm::ALMParams::get().optimize)
+ *fwd_graph = alm::OptimizeLayout(std::move(*fwd_graph));
}
/* \brief construct grad_graph from fwd_graph and ograd_entries*/
diff --git a/src/operator/cudnn_ops.cc b/src/operator/cudnn_ops.cc
index 2778f7b..e7e649f 100644
--- a/src/operator/cudnn_ops.cc
+++ b/src/operator/cudnn_ops.cc
@@ -433,7 +433,7 @@ cudnnBackendHeurMode_t HeurMode() {
std::string ConvParamStr(const ConvParam& param) {
std::ostringstream ss;
- ss << " layout: " << param.layout.value();
+ ss << mshadow::toString(static_cast<mshadow::LayoutFlag>(param.layout.value()));
ss << " kernel: " << param.kernel;
ss << " stride: " << param.stride;
ss << " dilate: " << param.dilate;
diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h
index 27ed029..5884d99 100644
--- a/src/operator/elemwise_op_common.h
+++ b/src/operator/elemwise_op_common.h
@@ -36,6 +36,7 @@
#include <utility>
#include "./operator_common.h"
#include "./mxnet_op.h"
+#include "../common/alm.h"
namespace mxnet {
namespace op {
@@ -197,6 +198,15 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs,
attrs, in_attrs, out_attrs, -1);
}
+inline bool ElemwiseChangeLayout(nnvm::NodeAttrs* attrs,
+ mshadow::LayoutFlag targetLayout,
+ std::vector<alm::Transpose>* inpTransposes,
+ std::vector<alm::Transpose>* outTransposes) {
+ CHECK_EQ(targetLayout, mshadow::kUNKNOWN);
+ outTransposes->assign(attrs->op->num_outputs, alm::FactorCommonTranspose(inpTransposes));
+ return false;
+}
+
// Special case of ElemwiseType. Constrains dtype to integer types
template <index_t n_in, index_t n_out>
inline bool ElemwiseIntType(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc
index ff2ce4a..39aa11d 100644
--- a/src/operator/leaky_relu.cc
+++ b/src/operator/leaky_relu.cc
@@ -24,6 +24,7 @@
*/
#include "./leaky_relu-inl.h"
+#include "../common/alm.h"
#if MXNET_USE_ONEDNN == 1
#include "./nn/dnnl/dnnl_base-inl.h"
#include "./nn/dnnl/dnnl_ops-inl.h"
@@ -145,6 +146,17 @@ inline static bool BackwardLeakyReLUStorageType(const nnvm::NodeAttrs& attrs,
}
#endif // MXNET_USE_ONEDNN == 1
+static bool LRChangeLayout(nnvm::NodeAttrs* attrs,
+ mshadow::LayoutFlag target_layout,
+ std::vector<alm::Transpose>* in_axes,
+ std::vector<alm::Transpose>* out_axes) {
+ CHECK_EQ(target_layout, mshadow::kUNKNOWN);
+ out_axes->assign(1, alm::FactorCommonTranspose(in_axes));
+ if (attrs->dict["act_type"] == "rrelu")
+ out_axes->resize(2);
+ return false;
+}
+
NNVM_REGISTER_OP(LeakyReLU)
.describe(R"code(Applies Leaky rectified linear unit activation element-wise to the input.
@@ -195,6 +207,7 @@ The following modified ReLU Activation functions are supported:
})
.set_attr<mxnet::FInferShape>("FInferShape", LeakyReLUShape)
.set_attr<nnvm::FInferType>("FInferType", LeakyReLUType)
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", LRChangeLayout)
.set_attr<FCompute>("FCompute<cpu>", LeakyReLUCompute<cpu>)
#if MXNET_USE_ONEDNN == 1
.set_attr<bool>("TIsDNNL", true)
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index d3502b9..04cc78a 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -27,6 +27,7 @@
#include "../elemwise_op_common.h"
#include "../operator_common.h"
+#include "../../common/alm.h"
#include "batch_norm-inl.h"
#if MXNET_USE_ONEDNN == 1
@@ -445,6 +446,21 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
return true;
}
+static bool BNChangeLayout(nnvm::NodeAttrs* attrs,
+ mshadow::LayoutFlag targetLayout,
+ std::vector<alm::Transpose>* inpTransposes,
+ std::vector<alm::Transpose>* outTransposes) {
+ CHECK_EQ(targetLayout, mshadow::kUNKNOWN);
+ auto t = alm::FactorCommonTranspose(inpTransposes);
+ outTransposes->assign(1, t);
+ if (alm::IsIdentity(t))
+ return false;
+ const auto& param = nnvm::get<BatchNormParam>(attrs->parsed);
+ CHECK_LT(param.axis, t.size());
+ attrs->dict["axis"] = std::to_string(t[param.axis]);
+ return true;
+}
+
#if MXNET_USE_ONEDNN == 1
static inline bool SupportDNNLBN(const NDArray& input, const BatchNormParam& param) {
if (mxnet::op::batchnorm::disable_mkl)
@@ -641,6 +657,7 @@ then set ``gamma`` to 1 and its gradient to 0.
})
.set_attr<mxnet::FInferShape>("FInferShape", BatchNormShape)
.set_attr<nnvm::FInferType>("FInferType", BatchNormType)
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", BNChangeLayout)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
.set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
#if MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h
index b611542..9994c7b 100644
--- a/src/operator/nn/convolution-inl.h
+++ b/src/operator/nn/convolution-inl.h
@@ -100,6 +100,7 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
.add_enum("NCW", mshadow::kNCW)
.add_enum("NCHW", mshadow::kNCHW)
.add_enum("NCDHW", mshadow::kNCDHW)
+ .add_enum("NWC", mshadow::kNWC)
.add_enum("NHWC", mshadow::kNHWC)
.add_enum("NDHWC", mshadow::kNDHWC)
.set_default(dmlc::optional<int>())
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 787fbc0..a39fa3f 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -23,9 +23,12 @@
* \author Bing Xu, Jun Wu, Da Zheng
*/
+#include <mshadow/base.h>
+#include <mshadow/tensor.h>
#include "./convolution-inl.h"
#include "../elemwise_op_common.h"
#include "../operator_common.h"
+#include "../../common/alm.h"
#if MXNET_USE_ONEDNN == 1
#include "./dnnl/dnnl_base-inl.h"
#include "./dnnl/dnnl_ops-inl.h"
@@ -79,6 +82,29 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
}
#endif
+static bool ConvChangeLayout(nnvm::NodeAttrs* attrs,
+ mshadow::LayoutFlag target_layout,
+ std::vector<alm::Transpose>* in_axes,
+ std::vector<alm::Transpose>* out_axes) {
+ const auto& param = nnvm::get<ConvolutionParam>(attrs->parsed);
+ CHECK(param.layout) << "Current layout of convolution should be known: " << attrs->name;
+ auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
+ auto t = target_layout != mshadow::kUNKNOWN ?
+ mshadow::getTranspAxes<size_t>(layout, target_layout) :
+ alm::FactorCommonTranspose(in_axes);
+ out_axes->assign(1, alm::Reverse(t));
+ if (alm::IsIdentity(t))
+ return false;
+ if (target_layout != mshadow::kUNKNOWN) {
+ for (auto i : {0, 1})
+ in_axes->at(i) = alm::Compose(t, in_axes->at(i));
+ } else {
+ target_layout = alm::ApplyTranspose(layout, t);
+ }
+ attrs->dict["layout"] = mshadow::toString(target_layout);
+ return true;
+}
+
static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
@@ -502,6 +528,7 @@ There are other options to tune the performance.
})
.set_attr<mxnet::FInferShape>("FInferShape", ConvolutionShape)
.set_attr<nnvm::FInferType>("FInferType", ConvolutionType)
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", ConvChangeLayout)
#if MXNET_USE_ONEDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", ConvStorageType)
#endif
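ConvChangeLayout above switches the convolution to a channels-last layout and records the inverse transpose in out_axes, so data flows transpose -> conv(channels-last) -> transpose back. The two wrapper transposes are exact inverses, as this small numpy check illustrates (not code from this commit):

    import numpy as np

    x = np.random.rand(2, 3, 4, 5)                  # an NCHW tensor
    t, t_rev = (0, 2, 3, 1), (0, 3, 1, 2)           # t and alm::Reverse(t)
    assert np.array_equal(x, x.transpose(t).transpose(t_rev))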
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index 86cde82..2bef3fc 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -25,6 +25,7 @@
#include "./deconvolution-inl.h"
#include "../operator_common.h"
+#include "../../common/alm.h"
#include "../../common/utils.h"
#if MXNET_USE_ONEDNN == 1
#include "./dnnl/dnnl_base-inl.h"
@@ -401,6 +402,29 @@ struct DeconvolutionGrad {
}
};
+static bool DeconvChangeLayout(nnvm::NodeAttrs* attrs,
+ mshadow::LayoutFlag target_layout,
+ std::vector<alm::Transpose>* in_axes,
+ std::vector<alm::Transpose>* out_axes) {
+ const auto& param = nnvm::get<DeconvolutionParam>(attrs->parsed);
+ CHECK(param.layout) << "Current layout of convolution should be known: " << attrs->name;
+ auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
+ auto t = target_layout != mshadow::kUNKNOWN ?
+ mshadow::getTranspAxes<size_t>(layout, target_layout) :
+ alm::FactorCommonTranspose(in_axes);
+ out_axes->assign(1, alm::Reverse(t));
+ if (alm::IsIdentity(t))
+ return false;
+ if (target_layout != mshadow::kUNKNOWN) {
+ for (auto i : {0, 1})
+ in_axes->at(i) = alm::Compose(t, in_axes->at(i));
+ } else {
+ target_layout = alm::ApplyTranspose(layout, t);
+ }
+ attrs->dict["layout"] = mshadow::toString(target_layout);
+ return true;
+}
+
DMLC_REGISTER_PARAMETER(DeconvolutionParam);
NNVM_REGISTER_OP(Deconvolution)
@@ -428,6 +452,7 @@ NNVM_REGISTER_OP(Deconvolution)
})
.set_attr<mxnet::FInferShape>("FInferShape", DeconvolutionShape)
.set_attr<nnvm::FInferType>("FInferType", DeconvolutionType)
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", DeconvChangeLayout)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index edb6a1e..7b302ee9 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -24,6 +24,7 @@
*/
#include "../elemwise_op_common.h"
#include "./pooling-inl.h"
+#include "../../common/alm.h"
#if MXNET_USE_ONEDNN == 1
#include "./dnnl/dnnl_base-inl.h"
#include "./dnnl/dnnl_pooling-inl.h"
@@ -270,6 +271,22 @@ static bool PoolingShape(const nnvm::NodeAttrs& attrs,
return true;
}
+static bool PoolChangeLayout(nnvm::NodeAttrs* attrs,
+ mshadow::LayoutFlag targetLayout,
+ std::vector<alm::Transpose>* inpTransposes,
+ std::vector<alm::Transpose>* outTransposes) {
+ CHECK_EQ(targetLayout, mshadow::kUNKNOWN);
+ const auto& param = nnvm::get<PoolingParam>(attrs->parsed);
+ CHECK(param.layout) << "Current layout of pooling should be known: " << attrs->name;
+ auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
+ auto t = alm::FactorCommonTranspose(inpTransposes);
+ if (alm::IsIdentity(t))
+ return false;
+ outTransposes->assign(1, t);
+ attrs->dict["layout"] = mshadow::toString(alm::ApplyTranspose(layout, alm::Reverse(t)));
+ return true;
+}
+
#if MXNET_USE_ONEDNN == 1
void PoolingComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -443,6 +460,7 @@ For each window ``X``, the mathematical expression for Lp pooling is:
#endif
.set_attr<nnvm::FInferType>("FInferType", PoolingType)
.set_attr<mxnet::FInferShape>("FInferShape", PoolingShape)
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", PoolChangeLayout)
.set_attr<FCompute>("FCompute<cpu>", PoolingCompute<cpu>)
#if MXNET_USE_ONEDNN == 1
.set_attr<bool>("TIsDNNL", true)
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 8c5beec..9a219aa 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -42,6 +42,7 @@
#include "../common/utils.h"
namespace mxnet {
+
namespace op {
/*!
* \brief assign the expression to out according to request
diff --git a/src/operator/tensor/amp_cast.cc b/src/operator/tensor/amp_cast.cc
index 62e63a1..1899a4c 100644
--- a/src/operator/tensor/amp_cast.cc
+++ b/src/operator/tensor/amp_cast.cc
@@ -23,10 +23,25 @@
*/
#include "./amp_cast.h"
+#include "../../common/alm.h"
namespace mxnet {
namespace op {
+static bool MCastChangeLayout(nnvm::NodeAttrs* attrs,
+ mshadow::LayoutFlag targetLayout,
+ std::vector<alm::Transpose>* inpTransposes,
+ std::vector<alm::Transpose>* outTransposes) {
+ auto n_inps = attrs->op->get_num_inputs(*attrs);
+ auto n_outs = attrs->op->get_num_outputs(*attrs);
+ CHECK_EQ(n_inps, n_outs) << "This operator should have the same number of inputs and outputs";
+ CHECK_EQ(inpTransposes->size(), n_inps);
+ CHECK_EQ(targetLayout, mshadow::kUNKNOWN);
+ *outTransposes = std::move(*inpTransposes);
+ inpTransposes->assign(n_inps, alm::Transpose());
+ return false;
+}
+
DMLC_REGISTER_PARAMETER(AMPCastParam);
DMLC_REGISTER_PARAMETER(AMPMultiCastParam);
@@ -135,6 +150,7 @@ It casts only between low precision float/FP32 and does not do anything for othe
.set_attr_parser(ParamParser<AMPCastParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", AMPCastType)
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", ElemwiseChangeLayout)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int>>{{0, 0}};
@@ -188,6 +204,7 @@ It casts only between low precision float/FP32 and does not do anything for othe
.set_attr_parser(ParamParser<AMPMultiCastParam>)
.set_attr<mxnet::FInferShape>("FInferShape", AMPMultiCastShape)
.set_attr<nnvm::FInferType>("FInferType", AMPMultiCastType)
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", MCastChangeLayout)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
uint32_t num_args =
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index 4f36b8a..732b6a5 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -813,6 +813,7 @@ class ElemwiseBinaryOp : public OpBase {
}) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<2, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) \
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", ElemwiseChangeLayout) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index aa6b7f5..8c025ef 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -31,6 +31,7 @@
#include <string>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
+#include "../../common/alm.h"
#include "elemwise_unary_op.h"
namespace mxnet {
@@ -447,6 +448,7 @@ class BinaryScalarOp : public UnaryOp {
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", ElemwiseChangeLayout) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}}; \
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 5d23c98..0048777 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -35,6 +35,7 @@
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../elemwise_op_common.h"
+#include "../../common/alm.h"
#include "../../common/utils.h"
#include "../../ndarray/ndarray_function.h"
@@ -865,6 +866,7 @@ void NumpyNanToNumOpBackward(const nnvm::NodeAttrs& attrs,
.set_num_outputs(1) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) \
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", ElemwiseChangeLayout) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}}; \
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 787eb5c..b65c7cb 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -334,6 +334,22 @@ inline static bool TransposeStorageType(const nnvm::NodeAttrs& attrs,
}
#endif
+static bool TransposeChangeLayout(nnvm::NodeAttrs* attrs,
+ mshadow::LayoutFlag target_layout,
+ std::vector<alm::Transpose>* in_axes,
+ std::vector<alm::Transpose>* out_axes) {
+ CHECK_EQ(target_layout, mshadow::kUNKNOWN);
+ CHECK_EQ(in_axes->size(), 1);
+ const auto& param = nnvm::get<TransposeParam>(attrs->parsed);
+ auto new_axes = alm::Compose(alm::FromTShape(param.axes), in_axes->at(0));
+ std::ostringstream ss;
+ ss << mxnet::TShape(new_axes.begin(), new_axes.end());
+ attrs->dict["axes"] = ss.str();
+ in_axes->assign(1, alm::Transpose());
+ out_axes->assign(1, alm::Transpose());
+ return true;
+}
+
NNVM_REGISTER_OP(transpose)
.describe(R"code(Permutes the dimensions of an array.
Examples::
@@ -360,6 +376,7 @@ Examples::
.set_attr_parser(ParamParser<TransposeParam>)
.set_attr<mxnet::FInferShape>("FInferShape", TransposeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+ .set_attr<mxnet::alm::FChangeLayout>("FChangeLayout", TransposeChangeLayout)
.set_attr<nnvm::FGradient>(
"FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
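A worked example of the folding done by TransposeChangeLayout (values are hypothetical): a transpose node with axes (0, 3, 1, 2), i.e. NHWC to NCHW, that receives a pending NCHW-to-NHWC transpose (0, 2, 3, 1) absorbs it into its own axes, leaving an identity transpose that OptimizeLayout later deletes:

    param_axes = [0, 3, 1, 2]  # this transpose node: NHWC -> NCHW
    pending = [0, 2, 3, 1]     # transpose pending on its input: NCHW -> NHWC
    # alm::Compose(lhs, rhs): rhs is applied first, then lhs
    new_axes = [param_axes[a] for a in pending]
    assert new_axes == [0, 1, 2, 3]  # identity; the node can now be removed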
diff --git a/tests/python/gpu/test_amp_init.py b/tests/python/gpu/test_amp_init.py
index 2980366..28d1123 100644
--- a/tests/python/gpu/test_amp_init.py
+++ b/tests/python/gpu/test_amp_init.py
@@ -15,12 +15,18 @@
# specific language governing permissions and limitations
# under the License.
-import mxnet as mx
-from mxnet.gluon import nn
-from mxnet import amp
+from contextlib import contextmanager
+import ctypes
+
import numpy as np
import pytest
+import mxnet as mx
+from mxnet import amp
+from mxnet.base import check_call, _LIB
+from mxnet.gluon import nn
+from mxnet.test_utils import assert_allclose
+
@pytest.fixture
def np_shape_array():
@@ -35,6 +41,17 @@ def amp_init():
amp.init()
+@contextmanager
+def optimize_layout(optimize=True):
+ prev = ctypes.c_bool()
+ check_call(_LIB.MXGetOptimizeLayout(ctypes.byref(prev)))
+ check_call(_LIB.MXSetOptimizeLayout(ctypes.c_bool(optimize)))
+ try:
+ yield
+ finally:
+ check_call(_LIB.MXSetOptimizeLayout(prev))
+
+
def test_npi_concatenate_multicast(np_shape_array, amp_init):
class Foo(nn.HybridBlock):
def __init__(self, **kwargs):
@@ -51,3 +68,76 @@ def test_npi_concatenate_multicast(np_shape_array, amp_init):
data = mx.np.ones((32, 8), ctx=mx.gpu())
out = foo(data)
assert out.dtype == np.float32
+
+
+CONV = {1: nn.Conv1D, 2: nn.Conv2D, 3: nn.Conv3D}
+MAX_POOL = {1: nn.MaxPool1D, 2: nn.MaxPool2D, 3: nn.MaxPool3D}
+
+
+class Conv(nn.HybridBlock):
+ def __init__(self, ndim, **kwargs):
+ super().__init__(**kwargs)
+ self.conv = CONV[ndim](10, 3)
+
+ def forward(self, x):
+ y = self.conv(x)
+ return y * 2
+
+
+class ConvBN(nn.HybridBlock):
+ def __init__(self, ndim, **kwargs):
+ super().__init__(**kwargs)
+ self.conv = CONV[ndim](10, 3)
+ self.bn = nn.BatchNorm()
+
+ def forward(self, x):
+ y = self.conv(x)
+ y = self.bn(y)
+ return y * 2 + 10
+
+
+class PoolConv(nn.HybridBlock):
+ def __init__(self, ndim, **kwargs):
+ super().__init__(**kwargs)
+ self.pool = MAX_POOL[ndim]()
+ self.conv = CONV[ndim](10, 3)
+
+ def forward(self, x):
+ y = self.pool(x)
+ y = self.conv(y)
+ return y * 2
+
+
+@pytest.mark.skipif(not mx.runtime.Features().is_enabled('CUDNN'),
+ reason='Channel-last layouts are only supported with cuDNN.')
+@pytest.mark.parametrize('ndim', [1, 2, 3])
+@pytest.mark.parametrize('model', [Conv, ConvBN, PoolConv])
+def test_optimize_layout(np_shape_array, amp_init, model, ndim):
+ m = model(ndim)
+ m.initialize(ctx=mx.gpu())
+ m.hybridize()
+ x = mx.np.random.uniform(low=0, high=10, size=(32, 2, 17, 15, 12)[:ndim + 2], ctx=mx.gpu())
+ m(x)
+ param_init = {k:v.data().copy() for k, v in m.collect_params().items()}
+ for v in m.collect_params().values():
+ v.data().attach_grad()
+ with mx.autograd.record():
+ y = m(x)
+ y.backward()
+ with optimize_layout():
+ m2 = model(ndim)
+ m2.initialize(ctx=mx.gpu())
+ m2.load_dict(param_init, device=mx.gpu())
+ m2.hybridize()
+ for v in m2.collect_params().values():
+ v.data().attach_grad()
+ with mx.autograd.record():
+ y2 = m2(x)
+ y2.backward()
+ rtol = 1e-2
+ atol = 1e-2
+ assert_allclose(y2, y, rtol=rtol, atol=atol)
+ for k, v in m.collect_params().items():
+ if v.grad_req == 'null':
+ continue
+ assert_allclose(m2.collect_params()[k].grad(), v.grad(), rtol=rtol, atol=atol)