Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/05/24 22:50:41 UTC
[incubator-mxnet] branch master updated: MKLDNN RNN Inference Integration(fp32 LSTM and vRNN with tanh and relu) (#14713)
This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 653cbb4 MKLDNN RNN Inference Integration(fp32 LSTM and vRNN with tanh and relu) (#14713)
653cbb4 is described below
commit 653cbb4f1485a450a81d53f6e76421c980be2689
Author: Hao Li <ha...@intel.com>
AuthorDate: Sat May 25 06:50:20 2019 +0800
MKLDNN RNN Inference Integration(fp32 LSTM and vRNN with tanh and relu) (#14713)
* trigger the ci
* integrate mkldnn rnn fp32 inference(LSTM and vRNN with tanh and relu)
* fix bug about comparison between signed and unsigned integer expressions
* fix unix-gpu issue
* fix unix gpu bug
* fix unix-gpu issues
* fix some comments
* fix issue
* fix comment
* rename `cached` to `initialized`
* support IType
* TODO for MKLDNN GRU
* fix bugs in memory adjustment
* Reformat TODO for MKLDNN GRU
* Reserve original RNN path
* Remove MKLDNN GRU
* Fix bug for rnn forward
* Remove `__CUDAACC__`
* Move `RNNStatefulComputeCPU` to rnn.cc
* Remove redundant macro of `__CUDACC__`
* Remove the last macro `__CUDACC__` from rnn*
---
src/operator/nn/mkldnn/mkldnn_rnn_impl.h | 740 +++++++++++++++++++++++++++++++
src/operator/rnn-inl.h | 161 +++++--
src/operator/rnn.cc | 431 ++++++++++++++++++
src/operator/rnn_impl.h | 7 +
4 files changed, 1297 insertions(+), 42 deletions(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h
new file mode 100644
index 0000000..ea8e07e
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h
@@ -0,0 +1,740 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_
+#if MXNET_USE_MKLDNN == 1
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/storage.h>
+#include <algorithm>
+#include <map>
+#include <vector>
+#include <utility>
+#include <string>
+#include "../../math_functions-inl.h"
+#include "../../operator_common.h"
+#include "../../rnn_impl.h"
+#include "../../rnn-inl.h"
+#include "mkldnn.hpp"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+static algorithm GetMKLDNNRNNAlgo(int mode,
+ int* ngates,
+ int* nstates) {
+ algorithm algo = algorithm::vanilla_rnn;
+ switch (mode) {
+ case rnn_enum::kLstm:
+ *ngates = 4;
+ *nstates = 2;
+ algo = algorithm::vanilla_lstm;
+ break;
+ case rnn_enum::kGru:
+ *ngates = 3;
+ *nstates = 1;
+ algo = algorithm::vanilla_gru;
+ break;
+ case rnn_enum::kRnnRelu:
+ case rnn_enum::kRnnTanh:
+ *ngates = 1;
+ *nstates = 1;
+ algo = algorithm::vanilla_rnn;
+ break;
+ default:
+ LOG(FATAL) << "unsupported RNN mode: " << mode;
+ break;
+ }
+ return algo;
+}
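+// Usage sketch (illustrative, mirroring the switch above): querying the LSTM
+// configuration yields four gates and two states (h and c).
+//
+//   int ngates = 0, nstates = 0;
+//   algorithm algo = GetMKLDNNRNNAlgo(rnn_enum::kLstm, &ngates, &nstates);
+//   // algo == algorithm::vanilla_lstm, ngates == 4, nstates == 2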
+
+static void ConcatData(mkldnn::memory::format src_format,
+ mkldnn::memory::format dst_format,
+ std::vector<mkldnn::memory::dims> srcs_cds,
+ mkldnn::memory::dims dst_cds,
+ mkldnn::memory::data_type mkldnn_dtype,
+ int concat_dimension,
+ std::vector<void*> srcs_data,
+ const mkldnn::memory &dst) {
+ auto cpu_engine = CpuEngine::Get()->get_engine();
+ std::vector<mkldnn::memory::primitive_desc> srcs_pd;
+ std::vector<mkldnn::memory> srcs;
+ for (size_t i = 0; i < srcs_cds.size(); i++) {
+ auto desc = mkldnn::memory::desc(srcs_cds[i], mkldnn_dtype, src_format);
+ auto mpd = mkldnn::memory::primitive_desc(desc, cpu_engine);
+ auto src_memory = mkldnn::memory(mpd, srcs_data[i]);
+ srcs_pd.push_back(mpd);
+ srcs.push_back(src_memory);
+ }
+ std::vector<primitive::at> inputs;
+ for (size_t i = 0; i < srcs_cds.size(); i++) {
+ inputs.push_back(srcs[i]);
+ }
+ auto dst_desc = mkldnn::memory::desc(dst_cds, mkldnn_dtype, dst_format);
+ auto concat_pd = concat::primitive_desc(dst_desc, concat_dimension, srcs_pd);
+ MKLDNNStream::Get()->RegisterPrim(concat(concat_pd, inputs, dst));
+ MKLDNNStream::Get()->Submit();
+}
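+// Usage sketch (variables illustrative; mirrors the calls further below):
+// concatenating the forward and backward weights of one bidirectional layer
+// along the direction dimension (dim 1) of an ldgoi layout.
+//
+//   std::vector<void*> srcs = {wx, back_wx};  // hypothetical weight pointers
+//   ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
+//              {{1, 1, I, ngates, H}, {1, 1, I, ngates, H}}, {1, 2, I, ngates, H},
+//              mkldnn_dtype, 1, srcs, dst_wx_memory);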
+
+// Cached MKLDNN memory:
+// the first layer's wx and wh, then the next L - 1 layers' wx and wh,
+// plus hx and cx for all L layers, and the src/dst data/iter memory, etc.
+// Memory is prepared both before and after reorder and concat.
+// For unidirectional RNNs, layers are fused as 1 + (L - 1) when I != H.
+// For bidirectional RNNs, forward and backward data (weight, bias, iter, etc.)
+// are fused, and the first layer must be distinguished from the later ones.
+static size_t GetMKLDNNRNNCacheMemorySize(int L,
+ int D,
+ int T,
+ int N,
+ int I,
+ int H,
+ int mode) {
+ size_t size = 0;
+ switch (mode) {
+ case rnn_enum::kLstm:
+ size = 2 * (D * (I + H) * 4 * H + (L - 1) * D * (D * H + H) * 4 * H +
+ L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 4 * H + (L + 2) * D * 2 * N * H +
+ 6 * D * (I + H + 2) * 4 * H + T * N * I * 2;
+ break;
+ case rnn_enum::kGru:
+ size = 2 * (D * (I + H) * 3 * H + (L - 1) * D * (D * H + H) * 3 * H +
+ L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 3 * H + (L + 2) * D * 2 * N * H +
+ 6 * D * (I + H + 2) * 3 * H + T * N * I * 2;
+ break;
+ case rnn_enum::kRnnRelu:
+ case rnn_enum::kRnnTanh:
+ size = 2 * (D * (I + H) * 1 * H + (L - 1) * D * (D * H + H) * 1 * H +
+ L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 1 * H + (L + 2) * D * 2 * N * H +
+ 6 * D * (I + H + 2) * 1 * H + T * N * I * 2;
+ break;
+ default:
+ LOG(FATAL) << "unknown RNN mode " << mode;
+ break;
+ }
+ return size;
+}
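+// A worked example with illustrative numbers: for a single unidirectional LSTM
+// layer with L = 1, D = 1, T = 2, N = 4, I = 8, H = 8, the formula above gives
+//   2 * (1*(8+8)*4*8 + 0 + 1*1*2*4*8)  // = 2 * (512 + 64) = 1152
+//   + 2*4*1*8                          // T*N*D*H          =   64
+//   + 1*2*1*4*8                        // L*2*D*4*H        =   64
+//   + (1+2)*1*2*4*8                    // (L+2)*D*2*N*H    =  192
+//   + 6*1*(8+8+2)*4*8                  // 6*D*(I+H+2)*4*H  = 3456
+//   + 2*4*8*2                          // T*N*I*2          =  128
+// for a total of 5056 elements; the caller multiplies by sizeof(DType) for bytes.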
+
+template <typename DType>
+static void AdjustGruWeightGateOrder(DType* weight,
+ const int I,
+ const int H) {
+ // mxnet gru gate order is reset, update and new gates
+ // mkldnn gru gate order is update, reset and new gates
+ const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+ DType* weight_reset = weight;
+ DType* weight_update = weight + I * H;
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int i = 0; i < I * H; i++) {
+ DType tmp = weight_update[i];
+ weight_update[i] = weight_reset[i];
+ weight_reset[i] = tmp;
+ }
+}
+
+template <typename DType>
+static void AdjustGruBiasGateOrder(DType* bias,
+ const int H) {
+ // mxnet gru gate order is reset, update and new gates
+ // mkldnn gru gate order is update, reset and new gates
+ const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+ DType* bias_reset = bias;
+ DType* bias_update = bias + H;
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int i = 0; i < H; i++) {
+ DType tmp = bias_update[i];
+ bias_update[i] = bias_reset[i];
+ bias_reset[i] = tmp;
+ }
+}
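+// Gate-order illustration (hypothetical I = 1, so each gate block is H long):
+// MXNet stores GRU gates as [reset | update | new], MKLDNN as
+// [update | reset | new]; the helpers above swap the first two blocks in place
+// and leave the "new" gate untouched.
+//
+//   // w = {r..., u..., n...}                  // MXNet order
+//   AdjustGruWeightGateOrder(w, /*I=*/1, H);
+//   // w = {u..., r..., n...}                  // MKLDNN order
+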
+// Since the semantics of MKLDNN's fused RNN differ from those of MXNet's FusedRNN,
+// bidirectional layers are fused layer by layer, while unidirectional layers are
+// computed as fused 1 + fused (L - 1) layers, or as fused L layers when I == H.
+
+template <typename DType>
+static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs,
+ const int T,
+ const int N,
+ const int I,
+ const int H,
+ DType* x_ptr,
+ mkldnn::memory *user_src_layer_memory,
+ DType* hx_ptr,
+ DType* cx_ptr,
+ DType* w_ptr,
+ DType* b_ptr,
+ DType* y_ptr,
+ DType* hy_ptr,
+ DType* cy_ptr,
+ std::vector<mkldnn::memory> *concat_weight_memory,
+ std::vector<mkldnn::memory> *concat_iter_memory,
+ std::vector<mkldnn::memory> *x_memory,
+ std::vector<mkldnn::memory> *hcx_memory,
+ std::vector<mkldnn::memory> *wx_memory,
+ std::vector<mkldnn::memory> *wh_memory,
+ std::vector<mkldnn::memory> *bias_memory,
+ std::vector<mkldnn::memory> *y_memory,
+ std::vector<mkldnn::memory> *hcy_memory,
+ std::vector<primitive> *rnn_forward_prim,
+ int layer_index,
+ bool *has_cache,
+ int lvalue,
+ int dtype,
+ bool is_train,
+ int mode) {
+ int ngates = 0, nstates = 0;
+ algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates);
+ mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype);
+ const int single_cell_size = N * H;
+ const int single_b_size = ngates * H;
+ DType* wx = w_ptr; // ngates * H, I
+ DType* wh = w_ptr + I * H * ngates; // ngates * H, H
+ DType* back_wx = w_ptr + ngates * H * (I + H);
+ DType* back_wh = back_wx + I * H * ngates;
+ DType* bx = b_ptr;
+ DType* bh = b_ptr + H * ngates;
+ DType* back_bx = b_ptr + single_b_size * 2;
+ DType* back_bh = back_bx + H * ngates;
+ const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+ auto cpu_engine = CpuEngine::Get()->get_engine();
+ auto null_memory_ = null_memory(cpu_engine);
+ int offset1 = 0, offset2 = 0;
+ bool initialized = *has_cache;
+ mkldnn::memory::dims src_layer_tz = {T, N, I};
+ mkldnn::memory::dims dst_layer_tz = {T, N, 2 * H};
+ mkldnn::memory::dims weights_layer_tz = {1, 2, I, ngates, H}; // ldigo
+ mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder
+ mkldnn::memory::dims weights_iter_tz = {1, 2, H, ngates, H}; // ldigo
+ mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder
+ mkldnn::memory::dims bias_tz = {1, 2, ngates, H};
+ mkldnn::memory::dims src_iter_tz = {1, 2, nstates, N, H}; // ldsnc
+ mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H}; // ldsnc
+
+ if (!initialized) {
+ if (mode == rnn_enum::kGru) {
+ AdjustGruWeightGateOrder(wx, I, H);
+ AdjustGruWeightGateOrder(back_wx, I, H);
+ AdjustGruWeightGateOrder(wh, H, H);
+ AdjustGruWeightGateOrder(back_wh, H, H);
+ AdjustGruBiasGateOrder(bx, H);
+ AdjustGruBiasGateOrder(back_bx, H);
+ AdjustGruBiasGateOrder(bh, H);
+ AdjustGruBiasGateOrder(back_bh, H);
+ }
+ auto src_wx = (*concat_weight_memory)[2 * layer_index];
+ auto src_wh = (*concat_weight_memory)[2 * layer_index + 1];
+ std::vector<void*> srcs_data1;
+ srcs_data1.push_back(wx);
+ srcs_data1.push_back(back_wx);
+ ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
+ {weights_layer_r_tz, weights_layer_r_tz}, weights_layer_tz,
+ mkldnn_dtype, 1, srcs_data1, src_wx);
+ srcs_data1.clear();
+ srcs_data1.push_back(wh);
+ srcs_data1.push_back(back_wh);
+ ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
+ {weights_iter_r_tz, weights_iter_r_tz}, weights_iter_tz,
+ mkldnn_dtype, 1, srcs_data1, src_wh);
+ int tmpvalue = 0;
+ if (lvalue > 0) {
+ tmpvalue = lvalue + 1;
+ }
+ MKLDNNStream::Get()->RegisterPrim(reorder(src_wx, (*wx_memory)[tmpvalue]));
+ MKLDNNStream::Get()->RegisterPrim(reorder(src_wh, (*wh_memory)[tmpvalue]));
+
+ DType* user_bias = reinterpret_cast<DType *>
+ ((*bias_memory)[tmpvalue].get_data_handle());
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int j = 0; j < single_b_size; j++) {
+ user_bias[j] = bx[j] + bh[j];
+ user_bias[single_b_size + j] = back_bx[j] + back_bh[j];
+ }
+ }
+ if (lvalue > 0) {
+ (*wx_memory)[layer_index].set_data_handle((*wx_memory)[lvalue + 1].get_data_handle());
+ (*wh_memory)[layer_index].set_data_handle((*wh_memory)[lvalue + 1].get_data_handle());
+ (*bias_memory)[layer_index].set_data_handle((*bias_memory)[lvalue + 1].get_data_handle());
+ }
+
+ auto src_layer_md = mkldnn::memory::desc(
+ { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+ auto weight_layer_md = mkldnn::memory::desc(
+ { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto weight_iter_md = mkldnn::memory::desc(
+ { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto dst_layer_md = mkldnn::memory::desc(
+ { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+ auto dst_iter_md = mkldnn::memory::desc(
+ { dst_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ auto src_iter_md = mkldnn::memory::desc(
+ {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ auto bias_md = mkldnn::memory::desc({bias_tz},
+ mkldnn_dtype, mkldnn::memory::format::ldgo);
+
+ auto user_src_iter_memory = (*concat_iter_memory)[2];
+ if (mode == rnn_enum::kLstm) {
+ std::vector<void*> srcs_data1;
+ srcs_data1.push_back(hx_ptr);
+ srcs_data1.push_back(cx_ptr);
+ auto tmp1_src_iter_memory = (*concat_iter_memory)[0];
+ ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc,
+ {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2,
+ srcs_data1, tmp1_src_iter_memory);
+ std::vector<void*> srcs_data2;
+ srcs_data2.push_back(hx_ptr + single_cell_size);
+ srcs_data2.push_back(cx_ptr + single_cell_size);
+ auto tmp2_src_iter_memory = (*concat_iter_memory)[1];
+ ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc,
+ {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2,
+ srcs_data2, tmp2_src_iter_memory);
+ std::vector<void*> srcs_data3;
+ srcs_data3.push_back(reinterpret_cast<DType *>(tmp1_src_iter_memory.get_data_handle()));
+ srcs_data3.push_back(reinterpret_cast<DType *>(tmp2_src_iter_memory.get_data_handle()));
+ ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc,
+ {{1, 1, nstates, N, H}, {1, 1, nstates, N, H}}, {1, 2, nstates, N, H},
+ mkldnn_dtype, 1, srcs_data3, user_src_iter_memory);
+ } else {
+ user_src_iter_memory.set_data_handle(hx_ptr);
+ }
+ (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle());
+
+ rnn_cell::desc rnn_cell(nalgorithm,
+ mode == rnn_enum::kRnnRelu ? algorithm::eltwise_relu : algorithm::eltwise_tanh);
+
+ rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell,
+ rnn_direction::bidirectional_concat, src_layer_md,
+ src_iter_md, weight_layer_md, weight_iter_md,
+ bias_md, dst_layer_md, dst_iter_md);
+
+ auto prim_desc
+ = rnn_forward::primitive_desc(layer_desc, cpu_engine);
+
+ if (x_ptr && layer_index == 0) {
+ (*x_memory)[layer_index].set_data_handle(x_ptr);
+ } else {
+ (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle());
+ }
+ (*y_memory)[layer_index].set_data_handle(y_ptr);
+
+ if (rnn_forward_prim->size() <= (size_t)layer_index) {
+ primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index],
+ (*hcx_memory)[layer_index], (*wx_memory)[layer_index],
+ (*wh_memory)[layer_index], (*bias_memory)[layer_index],
+ (*y_memory)[layer_index],
+ (*hcy_memory)[layer_index], null_memory_);
+ rnn_forward_prim->push_back(rnn_prim);
+ }
+ MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]);
+ MKLDNNStream::Get()->Submit();
+
+ if (state_outputs) {
+ DType* dst_hcy = reinterpret_cast<DType *> ((*hcy_memory)[layer_index].get_data_handle());
+ if (mode == rnn_enum::kLstm) {
+ offset1 = nstates * single_cell_size;
+ offset2 = (nstates + 1) * single_cell_size;
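+ // dst_hcy (ldsnc with l = 1, d = 2, s = 2) holds four contiguous N * H blocks
+ // in the order [fwd h, fwd c, bwd h, bwd c]; hy takes blocks 0 and 2, and
+ // cy takes blocks 1 and 3.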
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int n = 0; n < single_cell_size; n++) {
+ hy_ptr[n] = dst_hcy[n];
+ hy_ptr[n + single_cell_size] = dst_hcy[n + offset1];
+ cy_ptr[n] = dst_hcy[n + single_cell_size];
+ cy_ptr[n + single_cell_size] = dst_hcy[n + offset2];
+ }
+ } else {
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int n = 0; n < 2 * single_cell_size; n++) {
+ hy_ptr[n] = dst_hcy[n];
+ }
+ }
+ }
+}
+
+
+template <typename DType>
+static void MKLDNNRNNForwardUnidi(bool state_outputs,
+ const int L,
+ const int T,
+ const int N,
+ const int I,
+ const int H,
+ DType* x_ptr,
+ mkldnn::memory *user_src_layer_memory,
+ DType* hx_ptr,
+ DType* cx_ptr,
+ DType* w_ptr,
+ DType* b_ptr,
+ DType* y_ptr,
+ DType* hy_ptr,
+ DType* cy_ptr,
+ std::vector<mkldnn::memory> *concat_weight_memory,
+ std::vector<mkldnn::memory> *concat_iter_memory,
+ std::vector<mkldnn::memory> *x_memory,
+ std::vector<mkldnn::memory> *hcx_memory,
+ std::vector<mkldnn::memory> *wx_memory,
+ std::vector<mkldnn::memory> *wh_memory,
+ std::vector<mkldnn::memory> *bias_memory,
+ std::vector<mkldnn::memory> *y_memory,
+ std::vector<mkldnn::memory> *hcy_memory,
+ std::vector<primitive> *rnn_forward_prim,
+ int layer_index,
+ bool *has_cache,
+ int dtype,
+ bool is_train,
+ int mode) {
+ int ngates = 0, nstates = 0;
+ algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates);
+ mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype);
+ const int cell_size = N * H;
+ const int single_cell_size = N * H;
+ const int single_b_size = ngates * H;
+ int w_size = (I + H) * H * ngates;
+ const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+ auto cpu_engine = CpuEngine::Get()->get_engine();
+ auto null_memory_ = null_memory(cpu_engine);
+ int offset1 = 0, offset2 = 0;
+ bool initialized = *has_cache;
+
+ mkldnn::memory::dims src_layer_tz = {T, N, I};
+ mkldnn::memory::dims dst_layer_tz = {T, N, H};
+ mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo
+ mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo
+ mkldnn::memory::dims bias_tz = {L, 1, ngates, H};
+ mkldnn::memory::dims src_iter_tz = {L, 1, nstates, N, H}; // ldsnc
+ mkldnn::memory::dims dst_iter_tz = {L, 1, nstates, N, H}; // ldsnc
+ mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder
+ mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder
+
+ auto weight_layer_md = mkldnn::memory::desc(
+ { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto weight_iter_md = mkldnn::memory::desc(
+ { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto src_layer_md = mkldnn::memory::desc(
+ { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+ auto dst_layer_md = mkldnn::memory::desc(
+ {dst_layer_tz}, mkldnn_dtype, mkldnn::memory::format::tnc);
+ auto src_iter_md = mkldnn::memory::desc(
+ {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ auto bias_md = mkldnn::memory::desc({bias_tz},
+ mkldnn_dtype, mkldnn::memory::format::ldgo);
+ auto dst_iter_md = mkldnn::memory::desc(
+ {dst_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+
+ for (int l = 0; l < L; l++) {
+ if (mode == rnn_enum::kLstm) {
+ std::vector<void*> srcs_data;
+ srcs_data.push_back(hx_ptr);
+ srcs_data.push_back(cx_ptr);
+ auto tmp_src_iter_memory = (*concat_iter_memory)[l + layer_index];
+ ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc,
+ {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype,
+ 2, srcs_data, tmp_src_iter_memory);
+ } else {
+ (*concat_iter_memory)[l + layer_index].set_data_handle(hx_ptr);
+ }
+ hx_ptr += cell_size;
+ if (mode == rnn_enum::kLstm) {
+ cx_ptr += cell_size;
+ }
+ }
+
+ auto user_src_iter_memory = null_memory_;
+ if (L == 1) {
+ user_src_iter_memory = (*concat_iter_memory)[layer_index];
+ } else {
+ user_src_iter_memory = (*concat_iter_memory)[L + layer_index];
+ std::vector<void*> src_l_data;
+ std::vector<mkldnn::memory::dims> src_l_dim;
+ for (int l = 0; l < L; l++) {
+ src_l_data.push_back(reinterpret_cast<DType *>
+ ((*concat_iter_memory)[l + layer_index].get_data_handle()));
+ src_l_dim.push_back({1, 1, nstates, N, H});
+ }
+ ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, src_l_dim,
+ {L, 1, nstates, N, H}, mkldnn_dtype, 0, src_l_data, user_src_iter_memory);
+ }
+ (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle());
+
+ auto src_wx_f = (*concat_weight_memory)[2 * layer_index];
+ auto src_wh_f = (*concat_weight_memory)[2 * layer_index + 1];
+
+ std::vector<void*> srcs_data_x;
+ std::vector<void*> srcs_data_h;
+ std::vector<mkldnn::memory::dims> src_l_dim_x;
+ std::vector<mkldnn::memory::dims> src_l_dim_h;
+ if (!initialized) {
+ if (L == 1) {
+ DType* wx = w_ptr;
+ DType* wh = w_ptr + I * H * ngates;
+ if (mode == rnn_enum::kGru) {
+ AdjustGruWeightGateOrder(wx, I, H);
+ AdjustGruWeightGateOrder(wh, H, H);
+ AdjustGruBiasGateOrder(b_ptr, H);
+ AdjustGruBiasGateOrder(b_ptr + H * ngates, H);
+ }
+ src_wx_f.set_data_handle(wx);
+ src_wh_f.set_data_handle(wh);
+ } else {
+ for (int l = 0; l < L; l++) {
+ DType* wx = w_ptr;
+ DType* wh = w_ptr + I * H * ngates;
+ DType* bx = b_ptr + l * ngates * H * 2;
+ DType* bh = b_ptr + l * ngates * H * 2 + H * ngates;
+ if (mode == rnn_enum::kGru) {
+ AdjustGruWeightGateOrder(wx, I, H);
+ AdjustGruWeightGateOrder(wh, H, H);
+ AdjustGruBiasGateOrder(bx, H);
+ AdjustGruBiasGateOrder(bh, H);
+ }
+ srcs_data_x.push_back(wx);
+ srcs_data_h.push_back(wh);
+ src_l_dim_x.push_back(weights_layer_r_tz);
+ src_l_dim_h.push_back(weights_iter_r_tz);
+ w_ptr = w_ptr + w_size;
+ }
+ ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
+ src_l_dim_x, weights_layer_tz, mkldnn_dtype, 0, srcs_data_x, src_wx_f);
+ ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
+ src_l_dim_h, weights_iter_tz, mkldnn_dtype, 0, srcs_data_h, src_wh_f);
+ }
+ MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, (*wx_memory)[layer_index]));
+ MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, (*wh_memory)[layer_index]));
+
+ DType* user_bias_f = reinterpret_cast<DType *> ((*bias_memory)[layer_index].get_data_handle());
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int j = 0; j < L * single_b_size; j++) {
+ int k = j / single_b_size;
+ user_bias_f[j] = b_ptr[j + k * single_b_size] + b_ptr[j + k * single_b_size + single_b_size];
+ }
+ }
+
+ rnn_cell::desc rnn_cell(nalgorithm,
+ mode == rnn_enum::kRnnRelu ? algorithm::eltwise_relu : algorithm::eltwise_tanh);
+
+ rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell,
+ rnn_direction::unidirectional, src_layer_md,
+ src_iter_md, weight_layer_md, weight_iter_md,
+ bias_md, dst_layer_md, dst_iter_md);
+
+ auto prim_desc
+ = rnn_forward::primitive_desc(layer_desc, cpu_engine);
+
+ if (x_ptr && layer_index == 0) {
+ (*x_memory)[layer_index].set_data_handle(x_ptr);
+ } else {
+ (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle());
+ }
+ (*y_memory)[layer_index].set_data_handle(y_ptr);
+
+ if (rnn_forward_prim->size() <= (size_t)layer_index) {
+ primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index],
+ (*hcx_memory)[layer_index], (*wx_memory)[layer_index],
+ (*wh_memory)[layer_index], (*bias_memory)[layer_index],
+ (*y_memory)[layer_index],
+ (*hcy_memory)[layer_index], null_memory_);
+ rnn_forward_prim->push_back(rnn_prim);
+ }
+ MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]);
+ MKLDNNStream::Get()->Submit();
+
+ if (state_outputs) {
+ DType* dst_hcy = reinterpret_cast<DType *> ((*hcy_memory)[layer_index].get_data_handle());
+ if (mode == rnn_enum::kLstm) {
+ for (int l = 0; l < L; l++) {
+ offset1 = l * single_cell_size;
+ offset2 = l * nstates * single_cell_size;
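+ // For each fused layer, dst_hcy stores the h block followed by the c block
+ // (nstates = 2 for LSTM), each single_cell_size elements long.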
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int n = 0; n < single_cell_size; n++) {
+ hy_ptr[offset1 + n] = dst_hcy[offset2 + n];
+ cy_ptr[offset1 + n] = dst_hcy[offset2 + n + single_cell_size];
+ }
+ }
+ } else {
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int n = 0; n < L * single_cell_size; n++) {
+ hy_ptr[n] = dst_hcy[n];
+ }
+ }
+ }
+}
+
+template <typename DType>
+static void MKLDNNRNNForward(bool state_outputs,
+ const int L,
+ const int D,
+ const int T,
+ const int N,
+ const int I,
+ const int H,
+ DType* x_ptr,
+ DType* hx_ptr,
+ DType* cx_ptr,
+ DType* w_ptr,
+ DType* b_ptr,
+ DType* y_ptr,
+ DType* hy_ptr,
+ DType* cy_ptr,
+ std::vector<mkldnn::memory> *concat_weight_memory,
+ std::vector<mkldnn::memory> *concat_iter_memory,
+ std::vector<mkldnn::memory> *x_memory,
+ std::vector<mkldnn::memory> *hcx_memory,
+ std::vector<mkldnn::memory> *wx_memory,
+ std::vector<mkldnn::memory> *wh_memory,
+ std::vector<mkldnn::memory> *bias_memory,
+ std::vector<mkldnn::memory> *y_memory,
+ std::vector<mkldnn::memory> *hcy_memory,
+ std::vector<primitive> *rnn_forward_prim,
+ bool *has_cache,
+ int dtype,
+ bool is_train,
+ int mode) {
+ int ngates = 0, nstates = 0;
+ GetMKLDNNRNNAlgo(mode, &ngates, &nstates);
+ const int b_size = 2 * H * ngates * D;
+ const int cell_size = N * H * D;
+ // First layer
+ int w_size = (I + H) * H * ngates * D;
+ auto cpu_engine = CpuEngine::Get()->get_engine();
+ auto null_memory_ = null_memory(cpu_engine);
+ DType* tmpNull = NULL;
+ // when D = 1 and I == H, L layers can be fused together
+ if (D == 1 && I == H) {
+ MKLDNNRNNForwardUnidi(state_outputs, L, T, N, I, H, x_ptr, &null_memory_,
+ hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory,
+ concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory,
+ bias_memory, y_memory, hcy_memory, rnn_forward_prim,
+ 0, has_cache, dtype, is_train, mode);
+ } else {
+ auto user_src_layer_memory_l = null_memory_;
+ if (D == 2) {
+ MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, I, H, x_ptr, &user_src_layer_memory_l,
+ hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory,
+ concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory,
+ bias_memory, y_memory, hcy_memory, rnn_forward_prim,
+ 0, has_cache, 0, dtype, is_train, mode);
+ } else {
+ MKLDNNRNNForwardUnidi(state_outputs, 1, T, N, I, H, x_ptr, &user_src_layer_memory_l,
+ hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory,
+ concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory,
+ bias_memory, y_memory, hcy_memory, rnn_forward_prim,
+ 0, has_cache, dtype, is_train, mode);
+ }
+ if (L > 1) {
+ user_src_layer_memory_l = (*y_memory)[0];
+ // Process the next L - 1 layers:
+ // layer by layer if D = 2, or as one fused block of L - 1 layers if D = 1.
+ w_ptr += w_size;
+ b_ptr += b_size;
+ if (D == 2) {
+ w_size = (H * D + H) * H * ngates * D;
+ for (int l = 0; l < L - 1; l++) {
+ if (state_outputs) {
+ hy_ptr += cell_size;
+ if (mode == rnn_enum::kLstm) {
+ cy_ptr += cell_size;
+ }
+ }
+ hx_ptr += cell_size;
+ if (mode == rnn_enum::kLstm) {
+ cx_ptr += cell_size;
+ }
+ MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, D * H, H, tmpNull,
+ &user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr,
+ cy_ptr, concat_weight_memory, concat_iter_memory, x_memory,
+ hcx_memory, wx_memory, wh_memory, bias_memory,
+ y_memory, hcy_memory, rnn_forward_prim,
+ 1, has_cache, l + 1, dtype, is_train, mode);
+ user_src_layer_memory_l = (*y_memory)[1];
+ w_ptr += w_size;
+ b_ptr += b_size;
+ }
+ }
+ if (D == 1) {
+ if (state_outputs) {
+ hy_ptr += cell_size;
+ if (mode == rnn_enum::kLstm) {
+ cy_ptr += cell_size;
+ }
+ }
+ w_size = (H + H) * H * ngates;
+ MKLDNNRNNForwardUnidi(state_outputs, L - 1, T, N, H, H, tmpNull, &user_src_layer_memory_l,
+ hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory,
+ concat_iter_memory, x_memory, hcx_memory, wx_memory,
+ wh_memory, bias_memory, y_memory, hcy_memory,
+ rnn_forward_prim, 1, has_cache, dtype, is_train, mode);
+ }
+ }
+ }
+ *has_cache = true;
+}
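+// Dispatch summary for MKLDNNRNNForward, as implemented above:
+//   D == 1 && I == H : one fused MKLDNNRNNForwardUnidi call over all L layers;
+//   D == 1 && I != H : MKLDNNRNNForwardUnidi on layer 0, then one fused call
+//                      over the remaining L - 1 layers (whose input size is H);
+//   D == 2           : MKLDNNRNNForwardSingleLayerBi layer by layer, with the
+//                      input size of layers 1..L-1 widened to D * H.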
+
+template <typename DType>
+static void MKLDNNRNNForwardInference(bool state_outputs,
+ const int num_layers,
+ const int direction,
+ const int seq_length,
+ const int batch_size,
+ const int input_size,
+ const int state_size,
+ DType* x_ptr,
+ DType* hx_ptr,
+ DType* cx_ptr,
+ DType* w_ptr,
+ DType* b_ptr,
+ DType* y_ptr,
+ DType* hy_ptr,
+ DType* cy_ptr,
+ std::vector<mkldnn::memory>* concat_weight_memory,
+ std::vector<mkldnn::memory>* concat_iter_memory,
+ std::vector<mkldnn::memory> *x_memory,
+ std::vector<mkldnn::memory> *hcx_memory,
+ std::vector<mkldnn::memory> *wx_memory,
+ std::vector<mkldnn::memory> *wh_memory,
+ std::vector<mkldnn::memory> *bias_memory,
+ std::vector<mkldnn::memory> *y_memory,
+ std::vector<mkldnn::memory> *hcy_memory,
+ std::vector<primitive> *rnn_forward_prim,
+ bool *has_cache,
+ int dtype,
+ bool is_train,
+ int mode) {
+ switch (mode) {
+ case rnn_enum::kLstm:
+ case rnn_enum::kGru:
+ case rnn_enum::kRnnTanh:
+ case rnn_enum::kRnnRelu:
+ MKLDNNRNNForward<DType>(state_outputs, num_layers, direction, seq_length,
+ batch_size, input_size, state_size, x_ptr, hx_ptr,
+ cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr,
+ concat_weight_memory, concat_iter_memory, x_memory,
+ hcx_memory, wx_memory, wh_memory,
+ bias_memory, y_memory, hcy_memory, rnn_forward_prim,
+ has_cache, dtype, is_train, mode);
+ break;
+ default:
+ LOG(FATAL) << "unknown RNN mode: " << mode;
+ break;
+ }
+}
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_USE_MKLDNN == 1
+#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index e43b3c9..9785be2 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -44,17 +44,13 @@
#include "./math_functions-inl.h"
#include "./operator_common.h"
#include "./rnn_impl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./nn/mkldnn/mkldnn_rnn_impl.h"
+#endif
namespace mxnet {
namespace op {
-namespace rnn_enum {
- enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength};
- enum RNNOpOutputs {kOut, kStateOut, kStateCellOut};
- enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru};
- enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace};
-}
-
inline int GetRnnParamSize(int num_layer,
int input_size,
int state_size,
@@ -400,9 +396,29 @@ class RNNOp {
public:
RNNParam param_;
Context ctx_;
+ #if MXNET_USE_MKLDNN == 1
+ std::vector<mkldnn::memory> concat_weight_memory;
+ std::vector<mkldnn::memory> concat_iter_memory;
+ std::vector<primitive> rnn_forward_prim;
+ std::vector<mkldnn::memory> x_memory;
+ std::vector<mkldnn::memory> hcx_memory;
+ std::vector<mkldnn::memory> wx_memory;
+ std::vector<mkldnn::memory> wh_memory;
+ std::vector<mkldnn::memory> bias_memory;
+ std::vector<mkldnn::memory> y_memory;
+ std::vector<mkldnn::memory> hcy_memory;
+ bool has_cache;
+ bool init_mem_;
+ size_t reserve_mem_size_;
+ Storage::Handle mem_space_;
+ #endif
explicit RNNOp(RNNParam param, Context ctx) {
this->param_ = param;
this->ctx_ = ctx;
+ #if MXNET_USE_MKLDNN == 1
+ init_mem_ = false;
+ reserve_mem_size_ = 0;
+ #endif
#if MXNET_USE_CUDNN_RNN
init_cudnn_ = false;
dtype_ = mshadow::DataType<DType>::kCudnnFlag;
@@ -410,8 +426,8 @@ class RNNOp {
// No tests in place for fp16 RNNs, so leave TensorCore disabled for now.
cudnn_tensor_core_ = false;
// When fp16 RNN tests are introduced, we can enable TensorCore as follows:
-// cudnn_tensor_core =
-// mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
+ // cudnn_tensor_core =
+ // mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
// Defaults
input_mode_ = CUDNN_LINEAR_INPUT; // Don't support this yet
// RNN Mode
@@ -492,7 +508,6 @@ class RNNOp {
this->temp_init_space_ = false;
this->reserve_cpu_space_size_ = 0;
this->temp_cpu_space_size_ = 0;
-
if (param_.projection_size.has_value()) {
LOG(FATAL) <<
"hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
@@ -505,6 +520,12 @@ class RNNOp {
}
~RNNOp() {
+ #if MXNET_USE_MKLDNN == 1
+ if (init_mem_) {
+ Storage::Get()->Free(mem_space_);
+ init_mem_ = false;
+ }
+ #endif
#if MXNET_USE_CUDNN_RNN
CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
@@ -829,22 +850,23 @@ class RNNOp {
#endif
if (ctx_.dev_type == kCPU) {
- // allocate temp space
- const size_t work_cpu_space_size =
- GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
- param_.state_size, direction, param_.mode);
- if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
- Storage::Get()->Free(temp_cpu_space_);
- temp_init_space_ = false;
- }
- if (!temp_init_space_) {
- temp_cpu_space_ = Storage::Get()->Alloc
- (work_cpu_space_size * sizeof(DType), Context::CPU());
- temp_cpu_space_size_ = work_cpu_space_size;
- temp_init_space_ = true;
- }
- DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
if (ctx.is_train) {
+ // allocate temp space
+ const size_t work_cpu_space_size =
+ GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+ param_.state_size, direction, param_.mode);
+ if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
+ Storage::Get()->Free(temp_cpu_space_);
+ temp_init_space_ = false;
+ }
+ if (!temp_init_space_) {
+ temp_cpu_space_ = Storage::Get()->Alloc
+ (work_cpu_space_size * sizeof(DType), Context::CPU());
+ temp_cpu_space_size_ = work_cpu_space_size;
+ temp_init_space_ = true;
+ }
+ DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
+
const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
param_.seq_length_, param_.batch_size_,
param_.state_size, param_.mode);
@@ -880,23 +902,78 @@ class RNNOp {
param_.p,
param_.mode);
} else {
- RNNForwardInference<DType>(work_cpu_space,
- param_.state_outputs,
- param_.num_layers,
- direction,
- param_.seq_length_,
- param_.batch_size_,
- param_.input_size_,
- param_.state_size,
- x.dptr_,
- hx.dptr_,
- cx_ptr,
- w.dptr_,
- b_ptr,
- y.dptr_,
- hy_ptr,
- cy_ptr,
- param_.mode);
+ #if MXNET_USE_MKLDNN == 1
+ if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) {
+ // TODO(zixuanweeei): MKLDNN GRU has a precision issue. A stable
+ // implementation will be added to MXNet once the issue is resolved.
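+ // Note that dmlc::GetEnv defaults MXNET_USE_MKLDNN_RNN to 1, so this path is
+ // taken unless the user sets MXNET_USE_MKLDNN_RNN=0, which falls back to the
+ // native RNNForwardInference path below.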
+ int dtype = in_data[rnn_enum::kData].type_flag_;
+ MKLDNNRNNForwardInference<DType>(param_.state_outputs,
+ param_.num_layers,
+ direction,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ param_.state_size,
+ x.dptr_,
+ hx.dptr_,
+ cx_ptr,
+ w.dptr_,
+ b_ptr,
+ y.dptr_,
+ hy_ptr,
+ cy_ptr,
+ &concat_weight_memory,
+ &concat_iter_memory,
+ &x_memory,
+ &hcx_memory,
+ &wx_memory,
+ &wh_memory,
+ &bias_memory,
+ &y_memory,
+ &hcy_memory,
+ &rnn_forward_prim,
+ &has_cache,
+ dtype,
+ ctx.is_train,
+ param_.mode);
+ } else {
+ #endif
+ // Until MKLDNN GRU fp32 inference is integrated, keep the original
+ // path below so the operator remains functional.
+ const size_t work_cpu_space_size =
+ GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+ param_.state_size, direction, param_.mode);
+ if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
+ Storage::Get()->Free(temp_cpu_space_);
+ temp_init_space_ = false;
+ }
+ if (!temp_init_space_) {
+ temp_cpu_space_ = Storage::Get()->Alloc
+ (work_cpu_space_size * sizeof(DType), Context::CPU());
+ temp_cpu_space_size_ = work_cpu_space_size;
+ temp_init_space_ = true;
+ }
+ DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
+ RNNForwardInference<DType>(work_cpu_space,
+ param_.state_outputs,
+ param_.num_layers,
+ direction,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ param_.state_size,
+ x.dptr_,
+ hx.dptr_,
+ cx_ptr,
+ w.dptr_,
+ b_ptr,
+ y.dptr_,
+ hy_ptr,
+ cy_ptr,
+ param_.mode);
+ #if MXNET_USE_MKLDNN == 1
+ }
+ #endif
}
}
}
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 9b412a2..3218494 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -167,6 +167,21 @@ static bool RNNType(const nnvm::NodeAttrs& attrs,
return true;
}
+inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ DispatchMode wanted_mode = DispatchMode::kFCompute;
+
+ #if MXNET_USE_MKLDNN == 1
+ wanted_mode = DispatchMode::kFComputeEx;
+ #endif
+
+ return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+ dispatch_mode, wanted_mode);
+}
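+// With MKLDNN enabled, RNN is dispatched through the FStatefulComputeEx path
+// (RNNStatefulComputeCPU below); otherwise the stateful FCompute path
+// (RNNStatefulCompute) is used.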
+
struct RNNGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr &n,
@@ -191,6 +206,417 @@ struct RNNGrad {
}
};
+#if MXNET_USE_MKLDNN == 1
+static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ std::vector<TBlob> in_blobs;
+ std::vector<TBlob> out_blobs;
+ std::vector<NDArray> temp_ndarrays_i;
+ std::vector<NDArray> temp_ndarrays_o;
+ for (const NDArray& in : inputs) {
+ if (in.storage_type() == kDefaultStorage) {
+ temp_ndarrays_i.push_back(in.Reorder2Default());
+ in_blobs.emplace_back(temp_ndarrays_i.back().data());
+ } else {
+ in_blobs.emplace_back(in.data());
+ }
+ }
+
+ for (const NDArray& out : outputs) {
+ if (out.storage_type() == kDefaultStorage) {
+ temp_ndarrays_o.push_back(out.Reorder2Default());
+ out_blobs.emplace_back(temp_ndarrays_o.back().data());
+ } else {
+ out_blobs.emplace_back(out.data());
+ }
+ }
+ int dtype = in_blobs[rnn_enum::kData].type_flag_;
+ int itype = in_blobs[inputs.size()-1].type_flag_;
+ mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype);
+ Stream<cpu> *s = ctx.get_stream<cpu>();
+ auto cpu_engine = CpuEngine::Get()->get_engine();
+ MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+ MSHADOW_TYPE_SWITCH(itype, IType, {
+ RNNOp<cpu, DType, IType>& op = state_ptr.get_state<RNNOp<cpu, DType, IType>>();
+ const RNNParam& param = op.param_;
+ int ngates = 0, nstates = 0;
+ GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates);
+ int D = param.bidirectional ? 2 : 1;
+ Tensor<cpu, 3, DType> x = in_blobs[rnn_enum::kData].get<cpu, 3, DType>(s);
+ int T = x.shape_[0];
+ int N = x.shape_[1];
+ int I = x.shape_[2];
+ int H = param.state_size;
+ int L = param.num_layers;
+
+ const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode);
+ if (op.init_mem_ && op.reserve_mem_size_ < r_size) {
+ Storage::Get()->Free(op.mem_space_);
+ op.init_mem_ = false;
+ }
+ if (!op.init_mem_) {
+ op.mem_space_ = Storage::Get()->Alloc(
+ r_size * sizeof(DType),
+ Context::CPU());
+ op.reserve_mem_size_ = r_size;
+ op.init_mem_ = true;
+ op.has_cache = false;
+ }
+ if (op.has_cache && op.x_memory.size() == 0) {
+ op.has_cache = false;
+ }
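+ // The cached workspace only grows: it is reallocated when a larger r_size
+ // arrives, and has_cache is then reset so the weights are prepared again.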
+
+ DType* workptr = static_cast<DType*>(op.mem_space_.dptr);
+ mkldnn::memory::dims src_layer_tz_0 = {T, N, I};
+ mkldnn::memory::dims src_layer_tz = {T, N, D * H};
+ mkldnn::memory::dims dst_layer_tz = {T, N, D * H};
+ auto dst_layer_md = mkldnn::memory::desc(
+ { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+ if (op.x_memory.size() == 0) {
+ if (D == 1 && I == H) {
+ auto user_src_layer_md = mkldnn::memory::desc(
+ { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+ auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine });
+ op.x_memory.push_back(user_src_layer_memory_n);
+
+ mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo
+ mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo
+ mkldnn::memory::dims bias_tz = {L, 1, ngates, H};
+ auto user_weight_layer_md = mkldnn::memory::desc(
+ { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto user_weight_iter_md = mkldnn::memory::desc(
+ { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto user_bias_md = mkldnn::memory::desc({ bias_tz },
+ mkldnn_dtype, mkldnn::memory::format::ldgo);
+ DType* weight_layer_n = workptr; // L * I * ngates * H
+ auto user_weight_layer_memory_n
+ = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
+ op.wx_memory.push_back(user_weight_layer_memory_n);
+
+ DType* weight_iter_n = weight_layer_n + L * I * ngates * H; // L * H * ngates * H
+ auto user_weight_iter_memory_n
+ = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
+ op.wh_memory.push_back(user_weight_iter_memory_n);
+
+ DType* bias_n = weight_iter_n + L * H * ngates * H; // L * ngates * H
+ auto user_bias_memory_n =
+ mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
+ op.bias_memory.push_back(user_bias_memory_n);
+
+ auto wx_md_n = mkldnn::memory::desc(
+ { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+ DType* wx_n = bias_n + L * ngates * H; // L * ngates * I * H
+ auto wx_memory_n =
+ mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
+ DType* wh_n = wx_n + L * ngates * I * H; // L * ngates * H * H
+ auto wh_md_n = mkldnn::memory::desc(
+ { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+ auto wh_memory_n =
+ mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
+
+ op.concat_weight_memory.push_back(wx_memory_n);
+ op.concat_weight_memory.push_back(wh_memory_n);
+ workptr = wh_n + L * ngates * H * H;
+
+ mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc
+ auto src_iter_md_n1 = mkldnn::memory::desc(
+ { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ for (int l = 0; l < L; l++) {
+ DType* src_iter_n1 = workptr; // nstates * N * H
+ auto src_iter_memory_n1 =
+ mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1);
+ op.concat_iter_memory.push_back(src_iter_memory_n1);
+ workptr = src_iter_n1 + nstates * N * H;
+ }
+ mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc
+ auto src_iter_md_n = mkldnn::memory::desc(
+ { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ DType* src_iter_n = workptr; // L * nstates * N * H
+ auto src_iter_memory_n =
+ mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n);
+ op.concat_iter_memory.push_back(src_iter_memory_n);
+ op.hcx_memory.push_back(src_iter_memory_n);
+ DType* dst_layer_n = src_iter_n + L * nstates * N * H; // T * N * D * H
+ auto dst_layer_memory_n
+ = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
+ op.y_memory.push_back(dst_layer_memory_n);
+
+ mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc
+ auto dst_iter_md_n = mkldnn::memory::desc(
+ { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ DType* dst_iter_n = dst_layer_n + T * N * D * H; // L * nstates * N * H
+ auto dst_iter_memory_n =
+ mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
+ op.hcy_memory.push_back(dst_iter_memory_n);
+ workptr = dst_iter_n + L * nstates * N * H;
+
+ } else {
+ auto user_src_layer_md_0 = mkldnn::memory::desc(
+ { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc);
+ auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine });
+ op.x_memory.push_back(user_src_layer_memory_0);
+
+ mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H}; // ldigo
+ mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H}; // ldigo
+ mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H};
+ auto user_weight_layer_md_0 = mkldnn::memory::desc(
+ { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto user_weight_iter_md_0 = mkldnn::memory::desc(
+ { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 },
+ mkldnn_dtype, mkldnn::memory::format::ldgo);
+
+ DType* weight_layer_0 = workptr; // D * I * ngates * H
+ auto user_weight_layer_memory_0
+ = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0);
+ op.wx_memory.push_back(user_weight_layer_memory_0);
+
+ DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H; // D * H * ngates * H
+ auto user_weight_iter_memory_0
+ = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0);
+ op.wh_memory.push_back(user_weight_iter_memory_0);
+
+ DType* bias_0 = weight_iter_0 + D * H * ngates * H; // D * ngates * H
+ auto user_bias_memory_0 =
+ mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0);
+ op.bias_memory.push_back(user_bias_memory_0);
+ workptr = bias_0 + D * ngates * H;
+
+ auto wx_md_0 = mkldnn::memory::desc(
+ { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+ auto wx_memory_0 =
+ mkldnn::memory({ wx_md_0, cpu_engine });
+ auto wh_md_0 = mkldnn::memory::desc(
+ { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+ auto wh_memory_0 =
+ mkldnn::memory({ wh_md_0, cpu_engine });
+ if (D == 2) {
+ DType* wx_0 = workptr; // D * ngates * I * H
+ wx_memory_0.set_data_handle(wx_0);
+ DType* wh_0 = wx_0 + D * ngates * I * H; // D * ngates * H * H
+ wh_memory_0.set_data_handle(wh_0);
+ workptr = wh_0 + D * ngates * H * H;
+ }
+ op.concat_weight_memory.push_back(wx_memory_0);
+ op.concat_weight_memory.push_back(wh_memory_0);
+
+ mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H}; // ldsnc
+ auto src_iter_undi_md_0 = mkldnn::memory::desc(
+ { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ DType* src_iter_undi_0 = workptr; // nstates * N * H
+ auto src_iter_undi_memory_0 =
+ mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0);
+ op.concat_iter_memory.push_back(src_iter_undi_memory_0);
+ workptr = src_iter_undi_0 + nstates * N * H;
+ if (D == 1) {
+ op.hcx_memory.push_back(src_iter_undi_memory_0);
+ } else {
+ DType* src_iter_undi2_0 = workptr; // nstates * N * H
+ auto src_iter_undi2_memory_0 =
+ mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0);
+ op.concat_iter_memory.push_back(src_iter_undi2_memory_0);
+
+ mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc
+ auto src_iter_md_0 = mkldnn::memory::desc(
+ { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H; // D * nstates * N * H
+ auto src_iter_memory_0 =
+ mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0);
+ op.concat_iter_memory.push_back(src_iter_memory_0);
+ op.hcx_memory.push_back(src_iter_memory_0);
+ workptr = src_iter_0 + D * nstates * N * H;
+ }
+
+ DType* dst_layer_0 = workptr; // T * N * D * H
+ auto dst_layer_memory_0
+ = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0);
+ op.y_memory.push_back(dst_layer_memory_0);
+
+ mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc
+ auto dst_iter_md_0 = mkldnn::memory::desc(
+ { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ DType* dst_iter_0 = dst_layer_0 + T * N * D * H; // D * nstates * N * H
+ auto dst_iter_memory_0 =
+ mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0);
+ op.hcy_memory.push_back(dst_iter_memory_0);
+ workptr = dst_iter_0 + D * nstates * N * H;
+
+ // next L - 1 layers
+ if (L > 1 && D == 1) {
+ auto user_src_layer_md = mkldnn::memory::desc(
+ { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+ auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine });
+ op.x_memory.push_back(user_src_layer_memory);
+
+ mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H}; // ldigo
+ mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H}; // ldigo
+ mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H};
+ auto user_weight_layer_md = mkldnn::memory::desc(
+ { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto user_weight_iter_md = mkldnn::memory::desc(
+ { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto user_bias_md = mkldnn::memory::desc({ bias_tz },
+ mkldnn_dtype, mkldnn::memory::format::ldgo);
+
+ DType* weight_layer_n = workptr; // (L - 1) * H * ngates * H
+ auto user_weight_layer_memory_n
+ = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
+ op.wx_memory.push_back(user_weight_layer_memory_n);
+
+ DType* weight_iter_n = weight_layer_n +
+ (L - 1) * H * ngates * H; // (L - 1) * H * ngates * H
+ auto user_weight_iter_memory_n
+ = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
+ op.wh_memory.push_back(user_weight_iter_memory_n);
+
+ DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // (L - 1) * ngates * H
+ auto user_bias_memory_n =
+ mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
+ op.bias_memory.push_back(user_bias_memory_n);
+
+ auto wx_md_n = mkldnn::memory::desc(
+ { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+ DType* wx_n = bias_n + (L - 1) * ngates * H; // (L - 1) * ngates * H * H
+ auto wx_memory_n =
+ mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
+ DType* wh_n = wx_n + (L - 1) * ngates * H * H; // (L - 1) * ngates * H * H
+ auto wh_md_n = mkldnn::memory::desc(
+ { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+ auto wh_memory_n =
+ mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
+
+ op.concat_weight_memory.push_back(wx_memory_n);
+ op.concat_weight_memory.push_back(wh_memory_n);
+ workptr = wh_n + (L - 1) * ngates * H * H;
+
+ mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc
+ auto src_iter_md_n1 = mkldnn::memory::desc(
+ { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ for (int l = 0; l < L - 1; l++) {
+ DType* src_iter_n1 = workptr; // nstates * N * H
+ auto src_iter_memory_n1 =
+ mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1);
+ op.concat_iter_memory.push_back(src_iter_memory_n1);
+ workptr = src_iter_n1 + nstates * N * H;
+ }
+ mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc
+ auto src_iter_md_n = mkldnn::memory::desc(
+ { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ DType* src_iter_n = workptr; // (L - 1) * nstates * N * H
+ auto src_iter_memory_n =
+ mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n);
+ op.concat_iter_memory.push_back(src_iter_memory_n);
+ op.hcx_memory.push_back(src_iter_memory_n);
+
+ DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H; // T * N * D * H
+ auto dst_layer_memory_n
+ = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
+ op.y_memory.push_back(dst_layer_memory_n);
+
+ mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc
+ auto dst_iter_md_n = mkldnn::memory::desc(
+ { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ DType* dst_iter_n = dst_layer_n + T * N * D * H; // (L - 1) * nstates * N * H
+ auto dst_iter_memory_n =
+ mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
+ op.hcy_memory.push_back(dst_iter_memory_n);
+ }
+
+ if (L > 1 && D == 2) {
+ mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H}; // ldigo
+ mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H}; // ldigo
+ mkldnn::memory::dims bias_tz = {1, D, ngates, H};
+ auto user_weight_layer_md = mkldnn::memory::desc(
+ { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto user_weight_iter_md = mkldnn::memory::desc(
+ { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+ auto user_bias_md = mkldnn::memory::desc({ bias_tz },
+ mkldnn_dtype, mkldnn::memory::format::ldgo);
+
+ auto user_src_layer_md = mkldnn::memory::desc(
+ { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+ auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine });
+ op.x_memory.push_back(user_src_layer_memory);
+
+ auto wx_md_n = mkldnn::memory::desc(
+ { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+ auto wh_md_n = mkldnn::memory::desc(
+ { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+
+ for (int l = 0; l < L; l++) {
+ DType* weight_layer_n = workptr; // D * (H * D) * ngates * H
+ auto user_weight_layer_memory_n
+ = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
+ op.wx_memory.push_back(user_weight_layer_memory_n);
+
+ DType* weight_iter_n = weight_layer_n +
+ D * (H * D) * ngates * H; // D * H * ngates * H
+ auto user_weight_iter_memory_n
+ = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
+ op.wh_memory.push_back(user_weight_iter_memory_n);
+
+ DType* bias_n = weight_iter_n + D * H * ngates * H; // D * ngates * H
+ auto user_bias_memory_n =
+ mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
+ op.bias_memory.push_back(user_bias_memory_n);
+ workptr = bias_n + D * ngates * H;
+ }
+
+ DType* wx_n = workptr; // D * ngates * (D * H) * H
+ DType* wh_n = wx_n + D * ngates * (D * H) * H; // D * ngates * H * H
+ auto wx_memory_n =
+ mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
+ auto wh_memory_n =
+ mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
+ op.concat_weight_memory.push_back(wx_memory_n);
+ op.concat_weight_memory.push_back(wh_memory_n);
+
+ mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H}; // ldsnc
+ auto src_iter_undi_md = mkldnn::memory::desc(
+ { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ DType* src_iter_undi = wh_n + D * ngates * H * H; // nstates * N * H
+ auto src_iter_undi_memory =
+ mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi);
+ op.concat_iter_memory.push_back(src_iter_undi_memory);
+
+ DType* src_iter_undi2 = src_iter_undi + nstates * N * H; // nstates * N * H
+ auto src_iter_undi2_memory =
+ mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2);
+ op.concat_iter_memory.push_back(src_iter_undi2_memory);
+
+ mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H}; // ldsnc
+ auto src_iter_md = mkldnn::memory::desc(
+ { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ DType* src_iter = src_iter_undi2 + nstates * N * H; // D * nstates * N * H
+ auto src_iter_memory =
+ mkldnn::memory({ src_iter_md, cpu_engine }, src_iter);
+ op.concat_iter_memory.push_back(src_iter_memory);
+ op.hcx_memory.push_back(src_iter_memory);
+
+ DType* dst_layer_n = src_iter + D * nstates * N * H; // T * N * D * H
+ auto dst_layer_memory_n
+ = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
+ op.y_memory.push_back(dst_layer_memory_n);
+
+ mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H}; // ldsnc
+ auto dst_iter_md_n = mkldnn::memory::desc(
+ { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+ DType* dst_iter_n = dst_layer_n + T * N * D * H; // D * nstates * N * H
+ auto dst_iter_memory_n =
+ mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
+ op.hcy_memory.push_back(dst_iter_memory_n);
+ }
+ }
+ }
+ op.Forward(ctx, in_blobs, req, out_blobs);
+ });
+ });
+}
+#endif
+
NNVM_REGISTER_OP(RNN)
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
implemented, with both multi-layer and bidirectional support.
@@ -269,8 +695,13 @@ The definition of GRU here is slightly different from paper but compatible with
})
.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
.set_attr<nnvm::FInferType>("FInferType", RNNType)
+.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeCPU)
+#endif
.set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
.set_attr<FResourceRequestEx>("FResourceRequestEx",
[](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h
index e1b4a2b..425ea4a 100644
--- a/src/operator/rnn_impl.h
+++ b/src/operator/rnn_impl.h
@@ -44,6 +44,13 @@
namespace mxnet {
namespace op {
+namespace rnn_enum {
+ enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength};
+ enum RNNOpOutputs {kOut, kStateOut, kStateCellOut};
+ enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru};
+ enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace};
+}
+
template<typename DType>
inline DType sigmoid(DType x) {
return 1.0f / (1.0f + exp(-x));