You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/08/22 11:54:02 UTC
[incubator-mxnet] branch master updated: Fix fused resnet low accuracy (#21122)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new daac02c785 Fix fused resnet low accuracy (#21122)
daac02c785 is described below
commit daac02c7854ffa71bc11fd950c2d6c9ea356b394
Author: hankaj <ha...@intel.com>
AuthorDate: Mon Aug 22 13:53:44 2022 +0200
Fix fused resnet low accuracy (#21122)
* Change flag to postop
* Add attributes to batch norm relu node
* Refactor code with batch norm relu op
* Delete fuse_norm_relu flag
* Delete fuse_norm_relu flag
* Refactor BN operator
* Review suggestions
* Review suggestions once again
* Fix formatting
* Fix lint errors
---
python/mxnet/amp/lists/symbol_bf16.py | 2 +-
python/mxnet/amp/lists/symbol_fp16.py | 2 +-
python/mxnet/gluon/nn/basic_layers.py | 81 +---
src/operator/nn/batch_norm-inl.h | 2 +-
src/operator/nn/batch_norm.cc | 6 +-
src/operator/nn/dnnl/dnnl_batch_norm-inl.h | 420 ++++-----------------
.../{dnnl_batch_norm-inl.h => dnnl_batch_norm.cc} | 287 +++++---------
.../quantization/dnnl/dnnl_quantized_batch_norm.cc | 2 +-
.../dnnl/dnnl_bn_relu.cc} | 86 +----
src/operator/subgraph/dnnl/dnnl_bn_relu_property.h | 7 +-
tests/python/dnnl/op_cfg.py | 2 +-
tests/python/dnnl/test_dnnl.py | 58 ---
12 files changed, 189 insertions(+), 766 deletions(-)
diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py
index 89ddea2820..5b9df27497 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -41,7 +41,6 @@ if Features.instance.is_enabled('ONEDNN'):
# are dtype neutral (can work in both bf16 and fp32)
BF16_FP32_FUNCS = [
'_contrib_AdaptiveAvgPooling2D',
- '_contrib_BatchNormWithReLU',
'Activation',
'BatchNorm',
'LayerNorm',
@@ -102,6 +101,7 @@ WIDEST_TYPE_CASTS = [
if Features.instance.is_enabled('ONEDNN'):
WIDEST_TYPE_CASTS.extend([
'_sg_onednn_batch_dot',
+ '_sg_onednn_batch_norm',
])
# Functions that when running with Bfloat16, the params that still need float32.
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index 1cd5316361..76e9488f69 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -43,7 +43,6 @@ FP16_FP32_FUNCS = [
'BlockGrad',
'Cast',
'cast_storage',
- '_contrib_BatchNormWithReLU',
'_contrib_allclose',
'_contrib_arange_like',
'_contrib_dynamic_reshape',
@@ -637,6 +636,7 @@ if Features().is_enabled('ONEDNN'):
'_sg_onednn_selfatt_qk',
'_sg_onednn_selfatt_valatt',
'_sg_onednn_batch_dot',
+ '_sg_onednn_batch_norm',
'_sg_pow_mul_scalar'
])
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 4afffb7d47..883b714b16 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -19,7 +19,7 @@
# pylint: disable= arguments-differ
"""Basic neural network layers."""
__all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
- 'BatchNorm', 'SyncBatchNorm', 'BatchNormReLU', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
+ 'BatchNorm', 'SyncBatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
'Flatten', 'Lambda', 'HybridLambda', 'Concatenate', 'HybridConcatenate', 'Identity']
import warnings
import uuid
@@ -322,8 +322,6 @@ class _BatchNorm(HybridBlock):
If True, use global moving statistics instead of local batch-norm. This will force
change batch-norm into a scale shift operator.
If False, use local batch-norm.
- fuse_relu: bool, default False
- If True, this operator is equal to `BN+ReLU`.
beta_initializer: str or `Initializer`, default 'zeros'
Initializer for the beta weight.
gamma_initializer: str or `Initializer`, default 'ones'
@@ -345,14 +343,13 @@ class _BatchNorm(HybridBlock):
- **out**: output tensor with the same shape as `data`.
"""
def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
- use_global_stats=False, fuse_relu=False,
+ use_global_stats=False,
beta_initializer='zeros', gamma_initializer='ones',
running_mean_initializer='zeros', running_variance_initializer='ones',
in_channels=0, **kwargs):
super(_BatchNorm, self).__init__(**kwargs)
self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
'fix_gamma': not scale, 'use_global_stats': use_global_stats}
- self.fuse_relu = fuse_relu
self._axis = axis
if in_channels != 0:
self.in_channels = in_channels
@@ -383,13 +380,7 @@ class _BatchNorm(HybridBlock):
def forward(self, x):
device = x.device
- if self.fuse_relu:
- return npx.batch_norm_with_relu(x, self.gamma.data(device), self.beta.data(device),
- self.running_mean.data(device),
- self.running_var.data(device),
- name='fwd', **self._kwargs)
- else:
- return npx.batch_norm(x, self.gamma.data(device), self.beta.data(device),
+ return npx.batch_norm(x, self.gamma.data(device), self.beta.data(device),
self.running_mean.data(device),
self.running_var.data(device),
name='fwd', **self._kwargs)
@@ -467,71 +458,7 @@ class BatchNorm(_BatchNorm):
super(BatchNorm, self).__init__(
axis=axis, momentum=momentum, epsilon=epsilon, center=center,
scale=scale,
- use_global_stats=use_global_stats, fuse_relu=False,
- beta_initializer=beta_initializer,
- gamma_initializer=gamma_initializer,
- running_mean_initializer=running_mean_initializer,
- running_variance_initializer=running_variance_initializer,
- in_channels=in_channels, **kwargs)
-
-
-class BatchNormReLU(_BatchNorm):
- """Batch normalization layer (Ioffe and Szegedy, 2014).
- Normalizes the input at each batch, i.e. applies a transformation
- that maintains the mean activation close to 0 and the activation
- standard deviation close to 1.
-
- Parameters
- ----------
- axis : int, default 1
- The axis that should be normalized. This is typically the channels
- (C) axis. For instance, after a `Conv2D` layer with `layout='NCHW'`,
- set `axis=1` in `BatchNorm`. If `layout='NHWC'`, then set `axis=3`.
- momentum: float, default 0.9
- Momentum for the moving average.
- epsilon: float, default 1e-5
- Small float added to variance to avoid dividing by zero.
- center: bool, default True
- If True, add offset of `beta` to normalized tensor.
- If False, `beta` is ignored.
- scale: bool, default True
- If True, multiply by `gamma`. If False, `gamma` is not used.
- When the next layer is linear (also e.g. `nn.relu`),
- this can be disabled since the scaling
- will be done by the next layer.
- use_global_stats: bool, default False
- If True, use global moving statistics instead of local batch-norm. This will force
- change batch-norm into a scale shift operator.
- If False, use local batch-norm.
- beta_initializer: str or `Initializer`, default 'zeros'
- Initializer for the beta weight.
- gamma_initializer: str or `Initializer`, default 'ones'
- Initializer for the gamma weight.
- running_mean_initializer: str or `Initializer`, default 'zeros'
- Initializer for the running mean.
- running_variance_initializer: str or `Initializer`, default 'ones'
- Initializer for the running variance.
- in_channels : int, default 0
- Number of channels (feature maps) in input data. If not specified,
- initialization will be deferred to the first time `forward` is called
- and `in_channels` will be inferred from the shape of input data.
-
-
- Inputs:
- - **data**: input tensor with arbitrary shape.
-
- Outputs:
- - **out**: output tensor with the same shape as `data`.
- """
- def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
- use_global_stats=False,
- beta_initializer='zeros', gamma_initializer='ones',
- running_mean_initializer='zeros', running_variance_initializer='ones',
- in_channels=0, **kwargs):
- super(BatchNormReLU, self).__init__(
- axis=axis, momentum=momentum, epsilon=epsilon,
- center=center, scale=scale,
- use_global_stats=use_global_stats, fuse_relu=True,
+ use_global_stats=use_global_stats,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
running_mean_initializer=running_mean_initializer,
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 92eded093d..d863302040 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -46,7 +46,7 @@
#endif
/*! \brief inverse standard deviation <-> variance */
-#define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0 / std::sqrt((__var$) + DType(__eps$)))
+#define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0 / std::sqrt((__var$) + (__eps$)))
#define INVSTD_TO_VARIANCE(__invstd$, __eps$) ((1.0 / ((__invstd$) * (__invstd$))) - (__eps$))
namespace mxnet {
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index dc09ebeb22..154471f109 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -475,9 +475,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 5U);
if (SupportDNNLBN(inputs[0])) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
- DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
- DNNLRun(DNNLBatchNormForward<DTYPE, /*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
- });
+ DNNLRun(DNNLBatchNormForward</*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
@@ -491,7 +489,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
if (SupportDNNLBN(inputs[0])) {
DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
- DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
+ DNNLRun(DNNLBatchNormBackward, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
index 2780c9685f..40cf88ade1 100644
--- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
@@ -18,7 +18,7 @@
*/
/*!
- * \file dnnl_batch_norm.cc
+ * \file dnnl_batch_norm-inl.h
* \brief
* \author Tao Lv
*/
@@ -28,11 +28,12 @@
#if MXNET_USE_ONEDNN == 1
#include <dnnl.hpp>
+
#include <utility>
#include <vector>
-#include "operator/nn/batch_norm-inl.h"
#include "dnnl_base-inl.h"
+#include "operator/nn/batch_norm-inl.h"
namespace mxnet {
namespace op {
@@ -44,8 +45,7 @@ typedef dnnl::batch_normalization_backward::desc t_bn_b_desc;
inline static dnnl::normalization_flags _GetFlags(const std::vector<NDArray>& in_data,
const std::vector<NDArray>& aux_states,
- bool is_train_and_not_global_stats,
- bool fuse_relu) {
+ bool is_train_and_not_global_stats) {
dnnl::normalization_flags flags = static_cast<dnnl::normalization_flags>(0U);
if (in_data.size() == 3U) {
flags |= dnnl::normalization_flags::use_scale_shift;
@@ -56,15 +56,12 @@ inline static dnnl::normalization_flags _GetFlags(const std::vector<NDArray>& in
if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
flags |= dnnl::normalization_flags::use_global_stats;
}
-
- if (fuse_relu) {
- flags |= dnnl::normalization_flags::fuse_norm_relu;
- }
return flags;
}
inline static t_bn_f_pdesc _GetFwd(const dnnl::memory& data_mem,
bool is_train,
+ bool fuse_relu,
float eps,
dnnl::normalization_flags flags) {
auto data_md = data_mem.get_desc();
@@ -73,6 +70,18 @@ inline static t_bn_f_pdesc _GetFwd(const dnnl::memory& data_mem,
if (is_train) {
t_bn_f_desc bnFwd_desc(dnnl::prop_kind::forward_training, data_md, eps, flags);
return t_bn_f_pdesc(bnFwd_desc, engine);
+ }
+
+ if (fuse_relu) {
+ const float scale = 1.f;
+ const float alpha = 0.f;
+ const float beta = 0.f;
+ dnnl::post_ops post_ops;
+ post_ops.append_eltwise(scale, dnnl::algorithm::eltwise_relu, alpha, beta);
+ dnnl::primitive_attr attr;
+ attr.set_post_ops(post_ops);
+ t_bn_f_desc bnFwd_desc(dnnl::prop_kind::forward_inference, data_md, eps, flags);
+ return t_bn_f_pdesc(bnFwd_desc, attr, engine);
} else {
t_bn_f_desc bnFwd_desc(dnnl::prop_kind::forward_inference, data_md, eps, flags);
return t_bn_f_pdesc(bnFwd_desc, engine);
@@ -88,7 +97,7 @@ inline static t_bn_b_pdesc _GetBwd(const dnnl::memory& data_mem,
auto engine = CpuEngine::Get()->get_engine();
t_bn_b_desc bnBwd_desc(dnnl::prop_kind::backward, diff_md, data_md, eps, flags);
- return t_bn_b_pdesc(bnBwd_desc, engine, _GetFwd(data_mem, true, eps, flags));
+ return t_bn_b_pdesc(bnBwd_desc, engine, _GetFwd(data_mem, true, false, eps, flags));
}
typedef ParamOpSign<BatchNormParam> DNNLBNSignature;
@@ -100,63 +109,39 @@ class DNNLBNForward {
t_bn_f_pdesc pd;
public:
- DNNLBNForward(const t_bn_f_pdesc& _pd, bool is_train_and_not_global_stats) : pd(_pd) {
- weight_m.reset(new dnnl::memory(pd.weights_desc(), CpuEngine::Get()->get_engine()));
- fwd.reset(new dnnl::batch_normalization_forward(pd));
- this->is_train_and_not_global_stats = is_train_and_not_global_stats;
- }
+ DNNLBNForward(const t_bn_f_pdesc& _pd, bool is_train_and_not_global_stats);
- const dnnl::memory& GetWeight() const {
- return *weight_m;
- }
+ const dnnl::memory& GetWeight() const;
- const t_bn_f_pdesc& GetPd() const {
- return pd;
- }
+ const t_bn_f_pdesc& GetPd() const;
- const dnnl::batch_normalization_forward& GetFwd() const {
- return *fwd;
- }
-};
+ const dnnl::batch_normalization_forward& GetFwd() const;
-template <typename DType>
-static DNNLBNForward& GetBNForward(const BatchNormParam& param,
- const OpContext& ctx,
- const dnnl::memory* data_mem,
- dnnl::normalization_flags flags) {
-#if DMLC_CXX11_THREAD_LOCAL
- static thread_local std::unordered_map<DNNLBNSignature, DNNLBNForward, OpHash> fwds;
-#else
- static MX_THREAD_LOCAL std::unordered_map<DNNLBNSignature, DNNLBNForward, OpHash> fwds;
-#endif
- DNNLBNSignature key(param);
- key.AddSign(ctx.is_train);
- key.AddSign(*data_mem);
- key.AddSign(static_cast<int>(flags));
-
- auto it = fwds.find(key);
- if (it == fwds.end()) {
- auto fwd_pd = _GetFwd(*data_mem, ctx.is_train, param.eps, flags);
- DNNLBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
- it = AddToCache(&fwds, key, fwd);
- }
- return it->second;
-}
+ static DNNLBNForward& GetCached(const BatchNormParam& param,
+ const OpContext& ctx,
+ const dnnl::memory* data_mem,
+ bool fuse_relu,
+ dnnl::normalization_flags flags);
+ void Execute(const OpContext& ctx,
+ const BatchNormParam& param,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs,
+ bool fuse_relu);
+};
-template <typename DType>
-void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs,
- bool fuse_relu) {
+template <bool fuse_relu>
+void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
mxnet::TShape shape = inputs[batchnorm::kData].shape();
const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
CHECK_LT(real_axis, shape.ndim());
- NDArray out = outputs[batchnorm::kOut];
if (param.axis != 1 || shape.ndim() != 4) {
// reshape to (N, C, 1, D)
mxnet::TShape new_shape{
@@ -165,109 +150,18 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
1,
static_cast<index_t>(shape.ProdShape(real_axis + 1, static_cast<int>(shape.ndim())))};
in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
- out = out.Reshape(new_shape);
}
const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
dnnl::normalization_flags flags =
- _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats, fuse_relu);
+ _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats);
NDArray& data = in_data[batchnorm::kData];
if (data.IsDNNLData() && data.IsView())
data = data.Reorder2Default();
- auto data_mem = data.GetDNNLData();
- auto& fwd = GetBNForward<DType>(param, ctx, data_mem, flags);
-
- // for output memory
- auto fwd_dst_desc = fwd.GetPd().dst_desc();
- auto out_mem = const_cast<NDArray&>(out).CreateDNNLData(&fwd_dst_desc);
-
- // mxnet will always use scale shift.
- // But if fix_gamma is true, then all scale elements will be set to 1.0f
- if (static_cast<int>(flags) & static_cast<int>(dnnl::normalization_flags::use_scale_shift)) {
- const NDArray& gamma = in_data[batchnorm::kGamma];
- const NDArray& beta = in_data[batchnorm::kBeta];
- CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
- CHECK_EQ(beta.storage_type(), mxnet::kDefaultStorage);
-
- const dnnl::memory& weight_mem = fwd.GetWeight();
- float* weight_buf = reinterpret_cast<float*>(weight_mem.get_data_handle());
-
- index_t channels_ = data.shape()[1];
- CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(float) * 2);
- float* weight_ptr = gamma.data().dptr<float>();
- float* bias_ptr = beta.data().dptr<float>();
- const size_t copy_size = sizeof(weight_buf[0]) * channels_;
- if (!param.fix_gamma) {
- memcpy(weight_buf, weight_ptr, copy_size);
- memcpy(&weight_buf[channels_], bias_ptr, copy_size);
- } else if (IsBNWriting(req[batchnorm::kGamma])) {
- for (index_t i = 0; i < channels_; i++) {
- weight_buf[i] = 1.0f;
- weight_ptr[i] = 1.0f;
- weight_buf[channels_ + i] = bias_ptr[i]; // bias
- }
- } else {
- for (index_t i = 0; i < channels_; i++) {
- weight_buf[i] = 1.0f;
- weight_buf[channels_ + i] = bias_ptr[i]; // bias
- }
- }
-
- dnnl_args_map_t net_args;
- net_args[DNNL_ARG_SRC] = *data_mem;
- net_args[DNNL_ARG_SCALE_SHIFT] = weight_mem;
- net_args[DNNL_ARG_DST] = *out_mem;
- if (fuse_relu) {
- const NDArray* workspace = nullptr;
- workspace = &outputs[3];
- auto engine = CpuEngine::Get()->get_engine();
- if (workspace == nullptr) {
- LOG(FATAL) << "oneDNN BatchNorm: incorrect workspace input";
- }
- auto ws = std::make_shared<dnnl::memory>(
- fwd.GetPd().workspace_desc(), engine, workspace->GetDNNLData()->get_data_handle());
- net_args[DNNL_ARG_WORKSPACE] = *ws;
- }
- if (!ctx.is_train || param.use_global_stats) {
- float* omean = outputs[batchnorm::kMean].data().dptr<float>();
- float* ovar = outputs[batchnorm::kVar].data().dptr<float>();
- float* inmean = aux_states[batchnorm::kMovingMean].data().dptr<float>();
- float* invar = aux_states[batchnorm::kMovingVar].data().dptr<float>();
- // to align with origin implmentation: batch_norm.cc: L164
- for (index_t i = 0; i < channels_; i++) {
- omean[i] = inmean[i];
- ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
- }
- net_args[DNNL_ARG_MEAN] = *(aux_states[batchnorm::kMovingMean].GetDNNLData());
- net_args[DNNL_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetDNNLData());
- DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
- DNNLStream::Get()->Submit();
- } else { // training
- const NDArray& outMean = outputs[batchnorm::kMean];
- const NDArray& outVar = outputs[batchnorm::kVar];
- net_args[DNNL_ARG_MEAN] = *(outMean.GetDNNLData());
- net_args[DNNL_ARG_VARIANCE] = *(outVar.GetDNNLData());
- DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
- DNNLStream::Get()->Submit();
-
- float* ovar = outVar.data().dptr<float>();
- for (index_t i = 0; i < channels_; i++) {
- ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps);
- }
- }
- } else { // no input gamma and beta
- LOG(FATAL) << "oneDNN batch normalization: should not reach here ...";
- }
-}
-
-template <typename DType, bool fuse_relu>
-void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- DNNLBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
+ auto data_mem = data.GetDNNLData();
+ DNNLBNForward& fwd = DNNLBNForward::GetCached(param, ctx, data_mem, fuse_relu, flags);
+ fwd.Execute(ctx, param, inputs, req, outputs, fuse_relu);
}
class DNNLBNBackward {
@@ -278,95 +172,50 @@ class DNNLBNBackward {
public:
const t_bn_b_pdesc pd;
- explicit DNNLBNBackward(const t_bn_b_pdesc& _pd)
- : weight_m(new dnnl::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())),
- gradw_m(new dnnl::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())),
- pd(_pd) {
- bwd.reset(new dnnl::batch_normalization_backward(pd));
- }
+ explicit DNNLBNBackward(const t_bn_b_pdesc& _pd);
- const dnnl::memory& GetWeight() const {
- return *weight_m;
- }
+ const dnnl::memory& GetWeight() const;
- const dnnl::memory& GetGradw() const {
- return *gradw_m;
- }
+ const dnnl::memory& GetGradw() const;
- const dnnl::batch_normalization_backward& GetBwd() const {
- return *bwd;
- }
-};
+ const dnnl::batch_normalization_backward& GetBwd() const;
-template <typename DType>
-static DNNLBNBackward& GetBNBackward(const BatchNormParam& param,
- const OpContext& ctx,
- const NDArray& in_data,
- const dnnl::memory& in_mem,
- const NDArray& diff_data,
- const dnnl::memory& diff_mem,
- dnnl::normalization_flags flags) {
-#if DMLC_CXX11_THREAD_LOCAL
- static thread_local std::unordered_map<DNNLBNSignature, DNNLBNBackward, OpHash> bwds;
-#else
- static MX_THREAD_LOCAL std::unordered_map<DNNLBNSignature, DNNLBNBackward, OpHash> bwds;
-#endif
- DNNLBNSignature key(param);
- key.AddSign(in_data);
- key.AddSign(diff_data);
- key.AddSign(static_cast<int>(flags));
-
- auto it = bwds.find(key);
- if (it == bwds.end()) {
- auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags);
- DNNLBNBackward bwd(bwd_pd);
- it = AddToCache(&bwds, key, bwd);
- }
- return it->second;
-}
+ static DNNLBNBackward& GetCached(const BatchNormParam& param,
+ const OpContext& ctx,
+ const NDArray& in_data,
+ const dnnl::memory& in_mem,
+ const NDArray& diff_data,
+ const dnnl::memory& diff_mem,
+ dnnl::normalization_flags flags);
-template <typename DType>
-void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs,
- bool fuse_relu) {
- if (fuse_relu) {
- CHECK_EQ(inputs.size(), 9U);
- } else {
- CHECK_EQ(inputs.size(), 8U);
- }
+ void Execute(const BatchNormParam& param,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+};
+
+inline void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
std::vector<NDArray> out_grad(1);
- std::vector<NDArray> out_data(3);
std::vector<NDArray> in_data(3);
std::vector<NDArray> aux_states(2);
- out_grad[0] = inputs[0];
- out_data[batchnorm::kMean] = inputs[1];
- out_data[batchnorm::kVar] = inputs[2];
- in_data[batchnorm::kData] = inputs[3];
- in_data[batchnorm::kGamma] = inputs[4];
- in_data[batchnorm::kBeta] = inputs[5];
- aux_states[batchnorm::kMovingMean] = inputs[6];
- aux_states[batchnorm::kMovingVar] = inputs[7];
- const std::vector<NDArray>& in_grad = outputs;
+ out_grad[0] = inputs[0];
+ in_data[batchnorm::kData] = inputs[3];
+ in_data[batchnorm::kGamma] = inputs[4];
+ in_data[batchnorm::kBeta] = inputs[5];
+ aux_states[batchnorm::kMovingMean] = inputs[6];
+ aux_states[batchnorm::kMovingVar] = inputs[7];
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
dnnl::normalization_flags flags =
- _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats, fuse_relu);
-
- NDArray data = in_data[batchnorm::kData];
- NDArray diff = out_grad[batchnorm::kOut];
- NDArray gradIn = in_grad[batchnorm::kData];
- const NDArray& moving_mean = aux_states[batchnorm::kMovingMean];
- const NDArray& moving_var = aux_states[batchnorm::kMovingVar];
- const NDArray& out_mean = out_data[batchnorm::kMean];
- const NDArray& out_var = out_data[batchnorm::kVar];
+ _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats);
- CHECK(out_mean.IsDefaultData());
- CHECK(out_var.IsDefaultData());
- CHECK(moving_mean.IsDefaultData());
- CHECK(moving_var.IsDefaultData());
+ NDArray data = in_data[batchnorm::kData];
+ NDArray diff = out_grad[batchnorm::kOut];
mxnet::TShape shape = data.shape();
const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
@@ -378,15 +227,12 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
shape[real_axis],
1,
static_cast<index_t>(shape.ProdShape(real_axis + 1, static_cast<int>(shape.ndim())))};
- data = data.Reshape(new_shape);
- diff = diff.Reshape(new_shape);
- gradIn = gradIn.Reshape(new_shape);
+ data = data.Reshape(new_shape);
+ diff = diff.Reshape(new_shape);
}
auto data_mem = data.GetDNNLData();
auto diff_mem = diff.GetDNNLData();
- // DNNL batchnorm should run on special layouts. If one of them isn't, we
- // should reorder them.
if (data.IsDefaultData()) {
auto diff_desc = diff_mem->get_desc();
data_mem = data.GetDNNLDataReorder(&diff_desc);
@@ -394,113 +240,9 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
auto data_desc = data_mem->get_desc();
diff_mem = diff.GetDNNLDataReorder(&data_desc);
}
- auto& bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
- auto gradi_mem =
- CreateDNNLMem(const_cast<NDArray&>(gradIn), bwd.pd.diff_src_desc(), req[batchnorm::kData]);
-
- if (static_cast<int>(flags) & static_cast<int>(dnnl::normalization_flags::use_scale_shift)) {
- const NDArray& gamma = in_data[batchnorm::kGamma];
- const NDArray& beta = in_data[batchnorm::kBeta];
- DType* weight_buf = reinterpret_cast<DType*>(bwd.GetWeight().get_data_handle());
- index_t channels_ = data.shape()[1];
- DType* weight_ptr = gamma.data().dptr<DType>();
- DType* bias_ptr = beta.data().dptr<DType>();
- const size_t copy_size = sizeof(DType) * channels_;
- if (!param.fix_gamma) {
- memcpy(weight_buf, weight_ptr, copy_size);
- memcpy(&weight_buf[channels_], bias_ptr, copy_size);
- } else {
- for (index_t i = 0; i < channels_; i++) {
- weight_buf[i] = static_cast<DType>(1.0f);
- }
- memcpy(&weight_buf[channels_], bias_ptr, copy_size);
- }
- dnnl_args_map_t net_args;
- net_args[DNNL_ARG_SRC] = *data_mem;
- net_args[DNNL_ARG_DIFF_SRC] = *gradi_mem.second;
- net_args[DNNL_ARG_SCALE_SHIFT] = bwd.GetWeight();
- net_args[DNNL_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
- net_args[DNNL_ARG_DIFF_DST] = *diff_mem;
-
- if (fuse_relu) {
- const NDArray* workspace = nullptr;
- workspace = &inputs[8];
- if (workspace != nullptr) {
- net_args[DNNL_ARG_WORKSPACE] = *(workspace->GetDNNLData());
- }
- }
-
- // training but no input mean and variance
- if (ctx.is_train && !param.use_global_stats) {
- DType* moving_mean_ptr = moving_mean.data().dptr<DType>();
- DType* moving_var_ptr = moving_var.data().dptr<DType>();
- DType* out_mean_ptr = out_mean.data().dptr<DType>();
- DType* out_var_ptr = out_var.data().dptr<DType>();
- dnnl::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
- DType* tmp_var_ptr = reinterpret_cast<DType*>(var_mem.get_data_handle());
-
- DType minus_mom = (1.0f - param.momentum);
- for (index_t i = 0; i < channels_; i++) {
- moving_mean_ptr[i] = moving_mean_ptr[i] * param.momentum + out_mean_ptr[i] * minus_mom;
- float variance = INVSTD_TO_VARIANCE(out_var_ptr[i], param.eps);
- tmp_var_ptr[i] = variance;
- moving_var_ptr[i] = moving_var_ptr[i] * param.momentum + variance * minus_mom;
- }
- net_args[DNNL_ARG_MEAN] = *(out_mean.GetDNNLData());
- net_args[DNNL_ARG_VARIANCE] = var_mem;
- } else {
- net_args[DNNL_ARG_MEAN] = *(moving_mean.GetDNNLData());
- net_args[DNNL_ARG_VARIANCE] = *(moving_var.GetDNNLData());
- }
- DNNLStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
- CommitOutput(gradIn, gradi_mem);
- DNNLStream::Get()->Submit();
-
- // copy data from gradw_mem to in_grad[1] and in_grad[2]
- DType* gw_buf = reinterpret_cast<DType*>(bwd.GetGradw().get_data_handle());
- DType* w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<DType>();
- DType* w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<DType>();
-
- // the gradient of gamma
- if (!param.fix_gamma) {
- if (req[batchnorm::kGamma] != kNullOp) {
- if (req[batchnorm::kGamma] != kAddTo) {
- memcpy(w_grad_1, gw_buf, copy_size);
- } else {
- for (index_t i = 0; i < channels_; i++) {
- w_grad_1[i] += gw_buf[i];
- }
- }
- }
- } else {
- for (index_t i = 0; i < channels_; i++) {
- (in_grad[1].data().dptr<DType>())[i] = 0.0f;
- }
- }
-
- // the gradient of beta
- if (req[batchnorm::kBeta] != kNullOp) {
- if (req[batchnorm::kBeta] != kAddTo) {
- memcpy(w_grad_2, &gw_buf[channels_], copy_size);
- } else {
- DType* grad_beta = &gw_buf[channels_];
- for (index_t i = 0; i < channels_; i++) {
- w_grad_2[i] += grad_beta[i];
- }
- }
- }
- } else {
- LOG(FATAL) << "oneDNN batch normalization backward: should not reach here ...";
- }
-}
-
-template <typename DType, bool fuse_relu>
-void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- DNNLBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
+ DNNLBNBackward& bwd =
+ DNNLBNBackward::GetCached(param, ctx, data, *data_mem, diff, *diff_mem, flags);
+ bwd.Execute(param, ctx, inputs, req, outputs);
}
} // namespace op
} // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm.cc
similarity index 56%
copy from src/operator/nn/dnnl/dnnl_batch_norm-inl.h
copy to src/operator/nn/dnnl/dnnl_batch_norm.cc
index 2780c9685f..e4a6e1691a 100644
--- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_norm.cc
@@ -19,20 +19,17 @@
/*!
* \file dnnl_batch_norm.cc
- * \brief
- * \author Tao Lv
*/
-#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_BATCH_NORM_INL_H_
-#define MXNET_OPERATOR_NN_DNNL_DNNL_BATCH_NORM_INL_H_
-
#if MXNET_USE_ONEDNN == 1
#include <dnnl.hpp>
+
#include <utility>
#include <vector>
-#include "operator/nn/batch_norm-inl.h"
#include "dnnl_base-inl.h"
+#include "dnnl_batch_norm-inl.h"
+#include "operator/nn/batch_norm-inl.h"
namespace mxnet {
namespace op {
@@ -42,115 +39,57 @@ typedef dnnl::batch_normalization_forward::desc t_bn_f_desc;
typedef dnnl::batch_normalization_backward::primitive_desc t_bn_b_pdesc;
typedef dnnl::batch_normalization_backward::desc t_bn_b_desc;
-inline static dnnl::normalization_flags _GetFlags(const std::vector<NDArray>& in_data,
- const std::vector<NDArray>& aux_states,
- bool is_train_and_not_global_stats,
- bool fuse_relu) {
- dnnl::normalization_flags flags = static_cast<dnnl::normalization_flags>(0U);
- if (in_data.size() == 3U) {
- flags |= dnnl::normalization_flags::use_scale_shift;
- }
-
- // aux_states[0]: inMean
- // aux_states[1]: inVariance
- if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
- flags |= dnnl::normalization_flags::use_global_stats;
- }
-
- if (fuse_relu) {
- flags |= dnnl::normalization_flags::fuse_norm_relu;
- }
- return flags;
+DNNLBNForward::DNNLBNForward(const t_bn_f_pdesc& _pd, bool is_train_and_not_global_stats)
+ : pd(_pd) {
+ weight_m.reset(new dnnl::memory(pd.weights_desc(), CpuEngine::Get()->get_engine()));
+ fwd.reset(new dnnl::batch_normalization_forward(pd));
+ this->is_train_and_not_global_stats = is_train_and_not_global_stats;
}
-inline static t_bn_f_pdesc _GetFwd(const dnnl::memory& data_mem,
- bool is_train,
- float eps,
- dnnl::normalization_flags flags) {
- auto data_md = data_mem.get_desc();
- auto engine = CpuEngine::Get()->get_engine();
-
- if (is_train) {
- t_bn_f_desc bnFwd_desc(dnnl::prop_kind::forward_training, data_md, eps, flags);
- return t_bn_f_pdesc(bnFwd_desc, engine);
- } else {
- t_bn_f_desc bnFwd_desc(dnnl::prop_kind::forward_inference, data_md, eps, flags);
- return t_bn_f_pdesc(bnFwd_desc, engine);
- }
+const dnnl::memory& DNNLBNForward::GetWeight() const {
+ return *weight_m;
}
-inline static t_bn_b_pdesc _GetBwd(const dnnl::memory& data_mem,
- const dnnl::memory& diff_mem,
- float eps,
- dnnl::normalization_flags flags) {
- auto data_md = data_mem.get_desc();
- auto diff_md = diff_mem.get_desc();
- auto engine = CpuEngine::Get()->get_engine();
-
- t_bn_b_desc bnBwd_desc(dnnl::prop_kind::backward, diff_md, data_md, eps, flags);
- return t_bn_b_pdesc(bnBwd_desc, engine, _GetFwd(data_mem, true, eps, flags));
+const t_bn_f_pdesc& DNNLBNForward::GetPd() const {
+ return pd;
}
-typedef ParamOpSign<BatchNormParam> DNNLBNSignature;
-
-class DNNLBNForward {
- std::shared_ptr<const dnnl::memory> weight_m;
- std::shared_ptr<dnnl::batch_normalization_forward> fwd;
- bool is_train_and_not_global_stats;
- t_bn_f_pdesc pd;
-
- public:
- DNNLBNForward(const t_bn_f_pdesc& _pd, bool is_train_and_not_global_stats) : pd(_pd) {
- weight_m.reset(new dnnl::memory(pd.weights_desc(), CpuEngine::Get()->get_engine()));
- fwd.reset(new dnnl::batch_normalization_forward(pd));
- this->is_train_and_not_global_stats = is_train_and_not_global_stats;
- }
-
- const dnnl::memory& GetWeight() const {
- return *weight_m;
- }
-
- const t_bn_f_pdesc& GetPd() const {
- return pd;
- }
-
- const dnnl::batch_normalization_forward& GetFwd() const {
- return *fwd;
- }
-};
+const dnnl::batch_normalization_forward& DNNLBNForward::GetFwd() const {
+ return *fwd;
+}
-template <typename DType>
-static DNNLBNForward& GetBNForward(const BatchNormParam& param,
- const OpContext& ctx,
- const dnnl::memory* data_mem,
- dnnl::normalization_flags flags) {
+DNNLBNForward& DNNLBNForward::GetCached(const BatchNormParam& param,
+ const OpContext& ctx,
+ const dnnl::memory* data_mem,
+ bool fuse_relu,
+ dnnl::normalization_flags flags) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<DNNLBNSignature, DNNLBNForward, OpHash> fwds;
#else
static MX_THREAD_LOCAL std::unordered_map<DNNLBNSignature, DNNLBNForward, OpHash> fwds;
#endif
+
DNNLBNSignature key(param);
key.AddSign(ctx.is_train);
key.AddSign(*data_mem);
key.AddSign(static_cast<int>(flags));
+ key.AddSign(fuse_relu);
auto it = fwds.find(key);
if (it == fwds.end()) {
- auto fwd_pd = _GetFwd(*data_mem, ctx.is_train, param.eps, flags);
+ auto fwd_pd = _GetFwd(*data_mem, ctx.is_train, fuse_relu, param.eps, flags);
DNNLBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}
-template <typename DType>
-void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs,
- bool fuse_relu) {
- const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+void DNNLBNForward::Execute(const OpContext& ctx,
+ const BatchNormParam& param,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs,
+ bool fuse_relu) {
std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
mxnet::TShape shape = inputs[batchnorm::kData].shape();
@@ -171,15 +110,14 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
dnnl::normalization_flags flags =
- _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats, fuse_relu);
+ _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats);
NDArray& data = in_data[batchnorm::kData];
if (data.IsDNNLData() && data.IsView())
data = data.Reorder2Default();
auto data_mem = data.GetDNNLData();
- auto& fwd = GetBNForward<DType>(param, ctx, data_mem, flags);
// for output memory
- auto fwd_dst_desc = fwd.GetPd().dst_desc();
+ auto fwd_dst_desc = GetPd().dst_desc();
auto out_mem = const_cast<NDArray&>(out).CreateDNNLData(&fwd_dst_desc);
// mxnet will always use scale shift.
@@ -190,7 +128,7 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
CHECK_EQ(beta.storage_type(), mxnet::kDefaultStorage);
- const dnnl::memory& weight_mem = fwd.GetWeight();
+ const dnnl::memory& weight_mem = GetWeight();
float* weight_buf = reinterpret_cast<float*>(weight_mem.get_data_handle());
index_t channels_ = data.shape()[1];
@@ -218,17 +156,6 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
net_args[DNNL_ARG_SRC] = *data_mem;
net_args[DNNL_ARG_SCALE_SHIFT] = weight_mem;
net_args[DNNL_ARG_DST] = *out_mem;
- if (fuse_relu) {
- const NDArray* workspace = nullptr;
- workspace = &outputs[3];
- auto engine = CpuEngine::Get()->get_engine();
- if (workspace == nullptr) {
- LOG(FATAL) << "oneDNN BatchNorm: incorrect workspace input";
- }
- auto ws = std::make_shared<dnnl::memory>(
- fwd.GetPd().workspace_desc(), engine, workspace->GetDNNLData()->get_data_handle());
- net_args[DNNL_ARG_WORKSPACE] = *ws;
- }
if (!ctx.is_train || param.use_global_stats) {
float* omean = outputs[batchnorm::kMean].data().dptr<float>();
float* ovar = outputs[batchnorm::kVar].data().dptr<float>();
@@ -241,14 +168,14 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
}
net_args[DNNL_ARG_MEAN] = *(aux_states[batchnorm::kMovingMean].GetDNNLData());
net_args[DNNL_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetDNNLData());
- DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
+ DNNLStream::Get()->RegisterPrimArgs(GetFwd(), net_args);
DNNLStream::Get()->Submit();
} else { // training
const NDArray& outMean = outputs[batchnorm::kMean];
const NDArray& outVar = outputs[batchnorm::kVar];
net_args[DNNL_ARG_MEAN] = *(outMean.GetDNNLData());
net_args[DNNL_ARG_VARIANCE] = *(outVar.GetDNNLData());
- DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
+ DNNLStream::Get()->RegisterPrimArgs(GetFwd(), net_args);
DNNLStream::Get()->Submit();
float* ovar = outVar.data().dptr<float>();
@@ -261,51 +188,32 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
}
}
-template <typename DType, bool fuse_relu>
-void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- DNNLBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
+DNNLBNBackward::DNNLBNBackward(const t_bn_b_pdesc& _pd)
+ : weight_m(new dnnl::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())),
+ gradw_m(new dnnl::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())),
+ pd(_pd) {
+ bwd.reset(new dnnl::batch_normalization_backward(pd));
}
-class DNNLBNBackward {
- std::shared_ptr<dnnl::batch_normalization_backward> bwd;
- const std::shared_ptr<dnnl::memory> weight_m;
- const std::shared_ptr<dnnl::memory> gradw_m;
-
- public:
- const t_bn_b_pdesc pd;
-
- explicit DNNLBNBackward(const t_bn_b_pdesc& _pd)
- : weight_m(new dnnl::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())),
- gradw_m(new dnnl::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())),
- pd(_pd) {
- bwd.reset(new dnnl::batch_normalization_backward(pd));
- }
+const dnnl::memory& DNNLBNBackward::GetWeight() const {
+ return *weight_m;
+}
- const dnnl::memory& GetWeight() const {
- return *weight_m;
- }
+const dnnl::memory& DNNLBNBackward::GetGradw() const {
+ return *gradw_m;
+}
- const dnnl::memory& GetGradw() const {
- return *gradw_m;
- }
+const dnnl::batch_normalization_backward& DNNLBNBackward::GetBwd() const {
+ return *bwd;
+}
- const dnnl::batch_normalization_backward& GetBwd() const {
- return *bwd;
- }
-};
-
-template <typename DType>
-static DNNLBNBackward& GetBNBackward(const BatchNormParam& param,
- const OpContext& ctx,
- const NDArray& in_data,
- const dnnl::memory& in_mem,
- const NDArray& diff_data,
- const dnnl::memory& diff_mem,
- dnnl::normalization_flags flags) {
+DNNLBNBackward& DNNLBNBackward::GetCached(const BatchNormParam& param,
+ const OpContext& ctx,
+ const NDArray& in_data,
+ const dnnl::memory& in_mem,
+ const NDArray& diff_data,
+ const dnnl::memory& diff_mem,
+ dnnl::normalization_flags flags) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<DNNLBNSignature, DNNLBNBackward, OpHash> bwds;
#else
@@ -325,19 +233,12 @@ static DNNLBNBackward& GetBNBackward(const BatchNormParam& param,
return it->second;
}
-template <typename DType>
-void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs,
- bool fuse_relu) {
- if (fuse_relu) {
- CHECK_EQ(inputs.size(), 9U);
- } else {
- CHECK_EQ(inputs.size(), 8U);
- }
- const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+void DNNLBNBackward::Execute(const BatchNormParam& param,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ CHECK_EQ(inputs.size(), 8U);
std::vector<NDArray> out_grad(1);
std::vector<NDArray> out_data(3);
std::vector<NDArray> in_data(3);
@@ -353,7 +254,7 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& in_grad = outputs;
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
dnnl::normalization_flags flags =
- _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats, fuse_relu);
+ _GetFlags(in_data, aux_states, ctx.is_train && !param.use_global_stats);
NDArray data = in_data[batchnorm::kData];
NDArray diff = out_grad[batchnorm::kOut];
@@ -394,52 +295,43 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
auto data_desc = data_mem->get_desc();
diff_mem = diff.GetDNNLDataReorder(&data_desc);
}
- auto& bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
auto gradi_mem =
- CreateDNNLMem(const_cast<NDArray&>(gradIn), bwd.pd.diff_src_desc(), req[batchnorm::kData]);
+ CreateDNNLMem(const_cast<NDArray&>(gradIn), pd.diff_src_desc(), req[batchnorm::kData]);
if (static_cast<int>(flags) & static_cast<int>(dnnl::normalization_flags::use_scale_shift)) {
const NDArray& gamma = in_data[batchnorm::kGamma];
const NDArray& beta = in_data[batchnorm::kBeta];
- DType* weight_buf = reinterpret_cast<DType*>(bwd.GetWeight().get_data_handle());
+ float* weight_buf = reinterpret_cast<float*>(GetWeight().get_data_handle());
index_t channels_ = data.shape()[1];
- DType* weight_ptr = gamma.data().dptr<DType>();
- DType* bias_ptr = beta.data().dptr<DType>();
- const size_t copy_size = sizeof(DType) * channels_;
+ float* weight_ptr = gamma.data().dptr<float>();
+ float* bias_ptr = beta.data().dptr<float>();
+ const size_t copy_size = sizeof(weight_buf[0]) * channels_;
if (!param.fix_gamma) {
memcpy(weight_buf, weight_ptr, copy_size);
memcpy(&weight_buf[channels_], bias_ptr, copy_size);
} else {
for (index_t i = 0; i < channels_; i++) {
- weight_buf[i] = static_cast<DType>(1.0f);
+ weight_buf[i] = 1.0f;
}
memcpy(&weight_buf[channels_], bias_ptr, copy_size);
}
dnnl_args_map_t net_args;
net_args[DNNL_ARG_SRC] = *data_mem;
net_args[DNNL_ARG_DIFF_SRC] = *gradi_mem.second;
- net_args[DNNL_ARG_SCALE_SHIFT] = bwd.GetWeight();
- net_args[DNNL_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
+ net_args[DNNL_ARG_SCALE_SHIFT] = GetWeight();
+ net_args[DNNL_ARG_DIFF_SCALE_SHIFT] = GetGradw();
net_args[DNNL_ARG_DIFF_DST] = *diff_mem;
- if (fuse_relu) {
- const NDArray* workspace = nullptr;
- workspace = &inputs[8];
- if (workspace != nullptr) {
- net_args[DNNL_ARG_WORKSPACE] = *(workspace->GetDNNLData());
- }
- }
-
// training but no input mean and variance
if (ctx.is_train && !param.use_global_stats) {
- DType* moving_mean_ptr = moving_mean.data().dptr<DType>();
- DType* moving_var_ptr = moving_var.data().dptr<DType>();
- DType* out_mean_ptr = out_mean.data().dptr<DType>();
- DType* out_var_ptr = out_var.data().dptr<DType>();
- dnnl::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
- DType* tmp_var_ptr = reinterpret_cast<DType*>(var_mem.get_data_handle());
-
- DType minus_mom = (1.0f - param.momentum);
+ float* moving_mean_ptr = moving_mean.data().dptr<float>();
+ float* moving_var_ptr = moving_var.data().dptr<float>();
+ float* out_mean_ptr = out_mean.data().dptr<float>();
+ float* out_var_ptr = out_var.data().dptr<float>();
+ dnnl::memory var_mem(pd.variance_desc(), CpuEngine::Get()->get_engine());
+ float* tmp_var_ptr = reinterpret_cast<float*>(var_mem.get_data_handle());
+
+ float minus_mom = (1.0f - param.momentum);
for (index_t i = 0; i < channels_; i++) {
moving_mean_ptr[i] = moving_mean_ptr[i] * param.momentum + out_mean_ptr[i] * minus_mom;
float variance = INVSTD_TO_VARIANCE(out_var_ptr[i], param.eps);
@@ -452,14 +344,14 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
net_args[DNNL_ARG_MEAN] = *(moving_mean.GetDNNLData());
net_args[DNNL_ARG_VARIANCE] = *(moving_var.GetDNNLData());
}
- DNNLStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
+ DNNLStream::Get()->RegisterPrimArgs(GetBwd(), net_args);
CommitOutput(gradIn, gradi_mem);
DNNLStream::Get()->Submit();
// copy data from gradw_mem to in_grad[1] and in_grad[2]
- DType* gw_buf = reinterpret_cast<DType*>(bwd.GetGradw().get_data_handle());
- DType* w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<DType>();
- DType* w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<DType>();
+ float* gw_buf = reinterpret_cast<float*>(GetGradw().get_data_handle());
+ float* w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<float>();
+ float* w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<float>();
// the gradient of gamma
if (!param.fix_gamma) {
@@ -474,7 +366,7 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
}
} else {
for (index_t i = 0; i < channels_; i++) {
- (in_grad[1].data().dptr<DType>())[i] = 0.0f;
+ (in_grad[1].data().dptr<float>())[i] = 0.0f;
}
}
@@ -483,7 +375,7 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
if (req[batchnorm::kBeta] != kAddTo) {
memcpy(w_grad_2, &gw_buf[channels_], copy_size);
} else {
- DType* grad_beta = &gw_buf[channels_];
+ float* grad_beta = &gw_buf[channels_];
for (index_t i = 0; i < channels_; i++) {
w_grad_2[i] += grad_beta[i];
}
@@ -494,15 +386,6 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
}
}
-template <typename DType, bool fuse_relu>
-void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- DNNLBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
-}
} // namespace op
} // namespace mxnet
-#endif // MXNET_USE_ONEDNN
-#endif // MXNET_OPERATOR_NN_DNNL_DNNL_BATCH_NORM_INL_H_
+#endif // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc b/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
index ba3b2f84ce..39bbb03d78 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
@@ -83,7 +83,7 @@ static void DNNLQuantizedBatchNormForward(const nnvm::NodeAttrs& attrs,
dnnl::normalization_flags flags =
dnnl::normalization_flags::use_global_stats | dnnl::normalization_flags::use_scale_shift;
- auto& fwd = GetBNForward<float>(param, ctx, data_mem, flags);
+ auto& fwd = DNNLBNForward::GetCached(param, ctx, data_mem, false, flags);
const dnnl::memory& weight_mem = fwd.GetWeight();
CHECK_EQ(weight_mem.get_desc().get_size(), channel_count * sizeof(float) * 2);
float* weight_buf = reinterpret_cast<float*>(weight_mem.get_data_handle());
diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/subgraph/dnnl/dnnl_bn_relu.cc
similarity index 73%
rename from src/operator/contrib/batch_norm_relu.cc
rename to src/operator/subgraph/dnnl/dnnl_bn_relu.cc
index a0f158f42b..383beb0eab 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/subgraph/dnnl/dnnl_bn_relu.cc
@@ -18,17 +18,18 @@
*/
/*!
- * \file batch_norm_relu.cc
+ * \file dnnl_bn_relu.cc
* \brief
* \author Xinyu Chen
*/
-#include "../nn/batch_norm-inl.h"
#include <nnvm/op_attr_types.h>
-#include "../elemwise_op_common.h"
-#include "../operator_common.h"
+
+#include "operator/elemwise_op_common.h"
+#include "operator/nn/batch_norm-inl.h"
+#include "operator/operator_common.h"
#if MXNET_USE_ONEDNN == 1
-#include "../nn/dnnl/dnnl_batch_norm-inl.h"
+#include "operator/nn/dnnl/dnnl_batch_norm-inl.h"
#endif
namespace mxnet {
@@ -144,27 +145,12 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
if (SupportDNNLBNReLU(inputs[0])) {
CHECK_GT(outputs.size(), 3U);
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
- DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
- DNNLRun(DNNLBatchNormForward<DTYPE, /*fuse_relu*/ true>, attrs, ctx, inputs, req, outputs);
- });
+ DNNLRun(DNNLBatchNormForward</*fuse_relu*/ true>, attrs, ctx, inputs, req, outputs);
return;
}
LOG(FATAL) << "BatchNormWithReLU operator only supports oneDNN Backend.";
}
-void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- if (SupportDNNLBNReLU(inputs[0])) {
- CHECK_EQ(inputs.size(), 9U);
- DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
- DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ true>, attrs, ctx, inputs, req, outputs);
- return;
- }
- LOG(FATAL) << "BatchNormWithReLU operator only supports oneDNN Backend.";
-}
#endif
static inline bool BatchNormWithReLUStorageType(const nnvm::NodeAttrs& attrs,
@@ -200,47 +186,7 @@ static inline bool BatchNormWithReLUStorageType(const nnvm::NodeAttrs& attrs,
return dispatched;
}
-std::vector<nnvm::NodeEntry> BatchNormWithReLUGrad(const nnvm::ObjectPtr& n,
- const std::vector<nnvm::NodeEntry>& ograds) {
- std::vector<nnvm::NodeEntry> out_data;
- out_data.reserve(n->num_outputs());
- for (size_t i = 0; i < n->num_outputs(); ++i)
- out_data.emplace_back(n, i, 0);
- std::vector<nnvm::NodeEntry> heads;
- heads.reserve(9);
- heads.emplace_back(ograds.at(0));
- heads.emplace_back(out_data.at(batchnormrelu::kMean));
- heads.emplace_back(out_data.at(batchnormrelu::kVar));
- heads.emplace_back(n->inputs.at(batchnormrelu::kData));
- heads.emplace_back(n->inputs.at(batchnormrelu::kGamma));
- heads.emplace_back(n->inputs.at(batchnormrelu::kBeta));
- heads.emplace_back(n->inputs.at(batchnormrelu::kInMovingMean));
- heads.emplace_back(n->inputs.at(batchnormrelu::kInMovingVar));
- heads.emplace_back(out_data.at(batchnormrelu::kWorkspace));
-
- nnvm::ObjectPtr gnode = nnvm::Node::Create();
- gnode->inputs = std::move(heads);
- gnode->control_deps.emplace_back(n);
- gnode->attrs = n->attrs;
- gnode->attrs.op = nnvm::Op::Get("_backward_contrib_BatchNormWithReLU");
- gnode->attrs.name = n->attrs.name + "_backward";
- // The input of batchnorm
- std::vector<nnvm::NodeEntry> in_grad;
- in_grad.reserve(5);
- for (size_t i = 0; i < 3; ++i)
- in_grad.emplace_back(gnode, i, 0);
- // attach no gradient node to forbid gradient on aux_state
- nnvm::ObjectPtr ng = nnvm::Node::Create();
- ng->attrs.op = Op::Get("_NoGradient");
- ng->attrs.name = "NoGradient";
- // the aux state of batchnorm
- for (size_t i = 3; i < 5; ++i)
- in_grad.emplace_back(ng);
- return in_grad;
-}
-
-NNVM_REGISTER_OP(_contrib_BatchNormWithReLU)
- .add_alias("_npx_batch_norm_with_relu")
+NNVM_REGISTER_OP(_sg_onednn_batch_norm)
.describe(R"code(Batch normalization with ReLU fusion.
An extented operator of Batch normalization which can fuse ReLU activation.
@@ -275,7 +221,6 @@ An extented operator of Batch normalization which can fuse ReLU activation.
#if MXNET_USE_ONEDNN == 1
.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormWithReLUComputeExCPU)
#endif
- .set_attr<nnvm::FGradient>("FGradient", BatchNormWithReLUGrad)
#if MXNET_USE_ONEDNN == 1
.set_attr<bool>("TIsDNNL", true)
.set_attr<FResourceRequest>("FResourceRequest",
@@ -301,20 +246,5 @@ An extented operator of Batch normalization which can fuse ReLU activation.
}
});
-NNVM_REGISTER_OP(_backward_contrib_BatchNormWithReLU)
- .set_num_inputs(9)
- .set_num_outputs(3)
- .set_attr<nnvm::TIsBackward>("TIsBackward", true)
- .set_attr<FInferStorageType>("FInferStorageType", BatchNormWithReLUStorageType)
-#if MXNET_USE_ONEDNN == 1
- .set_attr<FResourceRequest>("FResourceRequest",
- [](const NodeAttrs& n) {
- return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
- })
- .set_attr<bool>("TIsDNNL", true)
- .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormWithReLUGradComputeExCPU)
-#endif
- .set_attr_parser(ParamParser<BatchNormParam>);
-
} // namespace op
} // namespace mxnet
diff --git a/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h b/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h
index 792236a50d..c350fb90a4 100644
--- a/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_bn_relu_property.h
@@ -25,10 +25,10 @@
#include <string>
#include <vector>
+#include "dnnl_subgraph_base-inl.h"
#include "operator/nn/dnnl/dnnl_act-inl.h"
#include "operator/nn/dnnl/dnnl_batch_norm-inl.h"
#include "operator/subgraph/common.h"
-#include "dnnl_subgraph_base-inl.h"
namespace mxnet {
namespace op {
@@ -116,10 +116,11 @@ class SgDNNLBNReLUProperty : public SubgraphProperty {
});
n->attrs.name = node_name.str();
- n->attrs.op = Op::Get("_contrib_BatchNormWithReLU");
+ n->attrs.op = Op::Get("_sg_onednn_batch_norm");
CHECK(n->attrs.op);
n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(sym));
- n->attrs.parsed = param;
+ param.SetAttrDict(&(n->attrs.dict));
+ n->op()->attr_parser(&(n->attrs));
return n;
}
diff --git a/tests/python/dnnl/op_cfg.py b/tests/python/dnnl/op_cfg.py
index 9effb305b1..c4739ab807 100644
--- a/tests/python/dnnl/op_cfg.py
+++ b/tests/python/dnnl/op_cfg.py
@@ -197,7 +197,6 @@ def get_all_ops_cfgs(dtype):
TensorArg(lambda shape, dtype: mx.nd.random.uniform(0, 1, shape, dtype))
],
},
- '_contrib_BatchNormWithReLU': {CFG_BASED_ON: 'BatchNorm'},
'LRN': {
'data,nsize': [(default_tensor(2, dtype), 3)]
},
@@ -289,6 +288,7 @@ def get_all_ops_cfgs(dtype):
CFG_BASED_ON: 'batch_dot',
CFG_SUBGRAPH: [SubgraphCfg('batch_dot', 'ONEDNN')],
},
+ '_sg_onednn_batch_norm': {CFG_BASED_ON: 'BatchNorm'},
'_sg_onednn_selfatt_qk': {
CFG_SUBGRAPH: [SubgraphCfg('_sg_onednn_selfatt_qk', 'ONEDNN')],
'queries_keys_values': [mx.nd.random.normal(0, 1, (1, 4, 3*2*8), dtype)],
diff --git a/tests/python/dnnl/test_dnnl.py b/tests/python/dnnl/test_dnnl.py
index 102af2e38d..8eb8373784 100644
--- a/tests/python/dnnl/test_dnnl.py
+++ b/tests/python/dnnl/test_dnnl.py
@@ -288,64 +288,6 @@ def test_batchnorm():
for stype in stypes:
check_batchnorm_training(stype)
-def test_batchnorm_relu_fusion():
- def check_batchnorm_relu_fusion(shape):
- x = mx.sym.Variable('x')
- in_data = mx.nd.random.normal(shape=shape)
- grad_out = mx.nd.random.uniform(0, 1, shape)
- bn = mx.sym.BatchNorm(data=x, fix_gamma=False)
- relu = mx.sym.Activation(data=bn, act_type='relu', name='relu')
- exe = relu._simple_bind(ctx=mx.cpu(), x=shape, grad_req='write')
- exe.arg_arrays[0][:] = in_data
- exe.forward(is_train=True)
- exe.backward(grad_out)
- no_fuse_outputs = exe.outputs
- no_fuse_grads = exe.grad_arrays
-
- bnrelu = mx.sym.contrib.BatchNormWithReLU(data=x, fix_gamma=False)
- exe_fuse = bnrelu._simple_bind(ctx=mx.cpu(), x=shape, grad_req='write')
- exe_fuse.arg_arrays[0][:] = in_data
- exe_fuse.forward(is_train=True)
- exe_fuse.backward(grad_out)
- fuse_outputs = exe_fuse.outputs
- fuse_grads = exe_fuse.grad_arrays
-
- for i in range(len(no_fuse_outputs)):
- assert_almost_equal(no_fuse_outputs[i], fuse_outputs[i])
- for i in range(len(no_fuse_grads)):
- assert_almost_equal(no_fuse_grads[i], fuse_grads[i])
-
- def check_batchnorm_relu_fusion_gluon(shape):
- class BNNet(gluon.HybridBlock):
- def __init__(self, fuse_relu):
- super(BNNet, self).__init__()
- self.fuse_relu = fuse_relu
- if self.fuse_relu:
- self.bn = gluon.nn.BatchNormReLU()
- else:
- self.bn = gluon.nn.BatchNorm()
- self.relu = gluon.nn.Activation('relu')
-
- def forward(self, x):
- y = self.bn(x)
- if not self.fuse_relu:
- y = self.relu(y)
- return y
- fused_net = BNNet(fuse_relu=True)
- unfused_net = BNNet(fuse_relu=False)
- fused_net.initialize()
- unfused_net.initialize()
- in_data = mx.np.random.normal(size=shape)
- no_fuse_outputs = unfused_net.forward(in_data)
- fuse_outputs = fused_net.forward(in_data)
-
- for i in range(len(no_fuse_outputs)):
- assert_almost_equal(no_fuse_outputs[i], fuse_outputs[i])
-
- check_batchnorm_relu_fusion((1, 3, 224, 224))
- check_batchnorm_relu_fusion((8, 3, 224, 224))
- check_batchnorm_relu_fusion_gluon((1, 3, 224, 224))
- check_batchnorm_relu_fusion_gluon((8, 3, 224, 224))
def test_softmax():
def check_softmax_training(stype):