You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sx...@apache.org on 2020/10/31 18:39:13 UTC
[incubator-mxnet] branch master updated: operators for the sliding
window self-attention (#19387)
This is an automated email from the ASF dual-hosted git repository.
sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new cec6bcf operators for the sliding window self-attention (#19387)
cec6bcf is described below
commit cec6bcf4cc31751370a0964ec1ff6760ba9fa319
Author: Ziyue Huang <zi...@apache.org>
AuthorDate: Sun Nov 1 02:37:59 2020 +0800
operators for the sliding window self-attention (#19387)
* operators for the sliding window attention
* address comments
* lint
* fix CI
---
python/mxnet/amp/lists/symbol_fp16.py | 3 +
src/operator/contrib/transformer-inl.h | 222 +++++++++++++++++++++++++++++++++
src/operator/contrib/transformer.cc | 191 ++++++++++++++++++++++++++++
src/operator/contrib/transformer.cu | 16 +++
tests/python/unittest/test_operator.py | 78 ++++++++++++
5 files changed, 510 insertions(+)
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index b7e3dcb..4b54448 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -591,6 +591,9 @@ FP32_FUNCS = [
'_npx_deformable_convolution',
'_npx_modulated_deformable_convolution',
'_contrib_DeformablePSROIPooling',
+ '_contrib_sldwin_atten_score',
+ '_contrib_sldwin_atten_mask_like',
+ '_contrib_sldwin_atten_context',
]
if Features().is_enabled('MKLDNN'):
diff --git a/src/operator/contrib/transformer-inl.h b/src/operator/contrib/transformer-inl.h
index da48ffa..0e7f6c4 100644
--- a/src/operator/contrib/transformer-inl.h
+++ b/src/operator/contrib/transformer-inl.h
@@ -61,6 +61,228 @@ static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs,
});
}
+
+
+/*!
+ * \brief Parameters shared by the sliding window attention operators
+ *        (mask_like, score and context).
+ */
+struct SldWinAttenParam : public dmlc::Parameter<SldWinAttenParam> {
+  /*! \brief one-sided window length; the window has w + 1 or 2 * w + 1 slots */
+  int w;
+  /*! \brief whether the window also extends to the right of each token */
+  bool symmetric;
+  DMLC_DECLARE_PARAMETER(SldWinAttenParam) {
+    // A negative length would yield a non-positive window size and thus an
+    // invalid output shape, so it is rejected while parsing the attributes.
+    DMLC_DECLARE_FIELD(w)
+    .set_lower_bound(0)
+    .describe("The one-sided window length");
+    DMLC_DECLARE_FIELD(symmetric)
+    .describe("If false, each token will only attend to itself and the previous tokens.");
+  }
+};
+
+
+/*!
+ * \brief Kernel writing one element of the sliding window attention mask.
+ *
+ * The flat index i is unravelled as (batch, position, head, slot) over a row-major
+ * layout whose last axis has w + 1 slots (2 * w + 1 when symmetric). An element
+ * is 0 when
+ *   - the slot lies left of the sequence start for this position and dilation,
+ *   - the position itself is not below the valid length of its batch entry, or
+ *   - (symmetric only) the slot lies right of the last valid position;
+ * otherwise it is 1.
+ */
+struct SldWinAttenMaskLike {
+  MSHADOW_XINLINE static void Map(int i, float *out, int32_t *dilation, int32_t *val_length,
+                                  bool symmetric, int w, int seq_length, int num_heads) {
+    const int win_size = symmetric ? (w + w + 1) : (w + 1);
+    const int per_position = num_heads * win_size;
+    const int per_batch = seq_length * per_position;
+
+    const int batch = i / per_batch;
+    const int in_batch = i % per_batch;
+    const int position = in_batch / per_position;
+    const int in_position = in_batch % per_position;
+    const int head = in_position / win_size;
+    const int slot = in_position % win_size;
+
+    out[i] = 1;
+    // The conditions are tested in the same order as a short-circuit "or".
+    if (slot < (w - position / dilation[head])) {
+      out[i] = 0;
+      return;
+    }
+    if (position >= val_length[batch]) {
+      out[i] = 0;
+      return;
+    }
+    if (symmetric &&
+        (win_size - slot - 1) < (w - (val_length[batch] - position - 1) / dilation[head])) {
+      out[i] = 0;
+    }
+  }
+};
+
+
+/*!
+ * \brief Forward pass of the sliding window attention mask operator.
+ *
+ * Launches one SldWinAttenMaskLike thread per output element. Only the shape of
+ * the score input is used; its values are never read.
+ *
+ * \param attrs node attributes holding the parsed SldWinAttenParam
+ * \param ctx operator context providing the stream
+ * \param inputs {score, dilation (int32), valid_length (int32)}
+ * \param req write request of the output; only kWriteTo is supported
+ * \param outputs {mask (float32)}
+ */
+template<typename xpu>
+void SldWinAttenMaskLikeForward(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(req[0], kWriteTo) << "Currently only support kWriteTo";
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+
+  // shape_[1] and shape_[2] are read below, so the rank has to be verified first.
+  CHECK_EQ(inputs[0].ndim(), 4)
+    << "score must be 4-D: (batch_size, seq_length, num_heads, window)";
+  CHECK_EQ(outputs[0].type_flag_, kFloat32);
+  CHECK_EQ(inputs[1].type_flag_, kInt32);
+  CHECK_EQ(inputs[2].type_flag_, kInt32);
+
+  const int batch_size = inputs[0].shape_[0];
+  const int seq_length = inputs[0].shape_[1];
+  const int num_heads = inputs[0].shape_[2];
+  // The kernel reads dilation[head] and val_length[batch] without bounds checks.
+  CHECK_EQ(inputs[1].Size(), static_cast<size_t>(num_heads))
+    << "dilation must hold one entry per head";
+  CHECK_EQ(inputs[2].Size(), static_cast<size_t>(batch_size))
+    << "valid_length must hold one entry per batch element";
+
+  float* out = outputs[0].dptr<float>();
+  int32_t* dilation = inputs[1].dptr<int32_t>();
+  int32_t* val_length = inputs[2].dptr<int32_t>();
+  const int num_threads = outputs[0].Size();
+
+  mxnet_op::Kernel<SldWinAttenMaskLike, xpu>::Launch(s, num_threads, out, dilation,
+      val_length, param.symmetric, param.w, seq_length, num_heads);
+}
+
+
+
+
+/*!
+ * \brief Kernel for the banded matrix products of sliding window attention.
+ *
+ * Each thread computes one element of out. Its flat index tid is unravelled as
+ * (batch, position, head, column) over a row-major layout with out_last_dim
+ * columns. Three products are supported:
+ *   - diagonal_lhs == false: column c is the dot product of the lhs row at this
+ *     position with the rhs row at position + dilation * (c - w).
+ *   - diagonal_lhs == true, transpose_lhs == false: entry k of the lhs row at this
+ *     position weights the rhs row at position + dilation * (k - w).
+ *   - diagonal_lhs == true, transpose_lhs == true: entry (w_right + w - k) of the
+ *     lhs row at position + dilation * (k - w_right) weights the rhs row at that
+ *     same position.
+ * Positions outside [0, seq_length) contribute nothing. batch_size is accepted
+ * but not used.
+ */
+struct DiagMM {
+  MSHADOW_XINLINE static void Map(int tid, float *out, float *lhs, float *rhs,
+                                  int32_t *dilation, int batch_size, int seq_length,
+                                  int num_heads, int out_last_dim, int lhs_last_dim, int w,
+                                  int w_right, bool diagonal_lhs, bool transpose_lhs) {
+    out[tid] = 0;
+
+    // Unravel the flat output index.
+    const int out_row = num_heads * out_last_dim;
+    const int out_batch = seq_length * out_row;
+    const int batch = tid / out_batch;
+    const int in_batch = tid % out_batch;
+    const int position = in_batch / out_row;
+    const int in_row = in_batch % out_row;
+    const int head = in_row / out_last_dim;
+    const int column = in_row % out_last_dim;
+
+    const int step = dilation[head];
+    const int lhs_row = num_heads * lhs_last_dim;
+    // Offset of (batch, 0, head, 0) in a tensor whose last axis is lhs_last_dim.
+    const int lhs_base = batch * (seq_length * lhs_row) + head * lhs_last_dim;
+
+    if (!diagonal_lhs) {
+      const int other = position + step * (column - w);
+      if (other >= seq_length || other < 0) return;
+      // Both operands share the lhs layout in this mode.
+      const float *left = lhs + lhs_base + position * lhs_row;
+      const float *right = rhs + lhs_base + other * lhs_row;
+      for (int k = 0; k < lhs_last_dim; ++k) {
+        out[tid] += left[k] * right[k];
+      }
+      return;
+    }
+
+    // Offset of (batch, 0, head, column) in a tensor whose last axis is out_last_dim.
+    const int rhs_base = batch * out_batch + head * out_last_dim + column;
+    const int shift = transpose_lhs ? w_right : w;
+    for (int k = 0; k < lhs_last_dim; ++k) {
+      const int other = position + step * (k - shift);
+      if (other >= seq_length || other < 0) continue;
+      const int band_position = transpose_lhs ? other : position;
+      const int band_column = transpose_lhs ? ((w_right + w) - k) : k;
+      out[tid] += lhs[lhs_base + band_position * lhs_row + band_column] *
+                  rhs[rhs_base + other * out_row];
+    }
+  }
+};
+
+
+
+/*!
+ * \brief Validate the operands and launch the DiagMM kernel, one thread per
+ *        element of out.
+ *
+ * \param ctx operator context providing the stream
+ * \param out 4-D float32 output, fully overwritten
+ * \param lhs 4-D float32 left operand; the windowed tensor when diagonal_lhs is true
+ * \param rhs 4-D float32 right operand
+ * \param dilation int32 tensor with one entry per head
+ * \param diagonal_lhs whether lhs (true) or out (false) carries the window axis
+ * \param transpose_lhs whether the banded lhs is used transposed
+ * \param w window length to the left
+ * \param w_right window length to the right
+ */
+template<typename xpu>
+void DiagMMImpl(const OpContext& ctx, const TBlob& out, const TBlob& lhs,
+                const TBlob& rhs, const TBlob& dilation, bool diagonal_lhs,
+                bool transpose_lhs, int w, int w_right) {
+  using namespace mshadow;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  CHECK_EQ(out.type_flag_, kFloat32);
+  CHECK_EQ(lhs.type_flag_, kFloat32);
+  CHECK_EQ(rhs.type_flag_, kFloat32);
+  CHECK_EQ(dilation.type_flag_, kInt32);
+
+  // The kernel does unchecked flat indexing, so every extent is validated here.
+  CHECK_EQ(out.ndim(), 4);
+  CHECK_EQ(lhs.ndim(), 4);
+  CHECK_EQ(rhs.ndim(), 4);
+  for (int axis = 0; axis < 3; ++axis) {
+    CHECK_EQ(rhs.shape_[axis], lhs.shape_[axis])
+      << "lhs and rhs must agree on batch_size, seq_length and num_heads";
+    CHECK_EQ(out.shape_[axis], lhs.shape_[axis])
+      << "out and lhs must agree on batch_size, seq_length and num_heads";
+  }
+
+  const int batch_size = lhs.shape_[0];
+  const int seq_length = lhs.shape_[1];
+  const int num_heads = lhs.shape_[2];
+  const int lhs_last_dim = lhs.shape_[3];
+  const int out_last_dim = out.shape_[3];
+  const int win_len = w + w_right + 1;
+  if (diagonal_lhs) {
+    // lhs is the windowed tensor; rhs is indexed with the strides of out.
+    CHECK_EQ(lhs_last_dim, win_len);
+    CHECK_EQ(rhs.shape_[3], out.shape_[3]);
+  } else {
+    // out is the windowed tensor; rhs is indexed with the strides of lhs.
+    CHECK_EQ(out_last_dim, win_len);
+    CHECK_EQ(rhs.shape_[3], lhs.shape_[3]);
+  }
+  CHECK_EQ(dilation.Size(), static_cast<size_t>(num_heads))
+    << "dilation must hold one entry per head";
+
+  float* lhs_data = lhs.dptr<float>();
+  float* rhs_data = rhs.dptr<float>();
+  int32_t* dilation_data = dilation.dptr<int32_t>();
+  float* out_data = out.dptr<float>();
+  const int num_threads = out.Size();
+
+  mxnet_op::Kernel<DiagMM, xpu>::Launch(s, num_threads, out_data, lhs_data, rhs_data,
+      dilation_data, batch_size, seq_length, num_heads, out_last_dim, lhs_last_dim, w,
+      w_right, diagonal_lhs, transpose_lhs);
+}
+
+
+/*!
+ * \brief Forward pass of the sliding window attention score operator.
+ *
+ * \param attrs node attributes holding the parsed SldWinAttenParam
+ * \param ctx operator context
+ * \param inputs {query, key, dilation}
+ * \param req write request of the output; accumulation (kAddTo) is not supported
+ * \param outputs {score}
+ */
+template<typename xpu>
+void SldWinAttenScoreForward(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<TBlob>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  // The kernel always overwrites its output, so accumulating would be silently wrong.
+  CHECK_NE(req[0], kAddTo) << "sldwin_atten_score does not support kAddTo";
+  using namespace mshadow;
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int w_right = param.symmetric ? param.w : 0;
+  // score = banded matmul(query, key.T)
+  DiagMMImpl<xpu>(ctx, outputs.at(0), inputs.at(0), inputs.at(1), inputs.at(2),
+                  false, false, param.w, w_right);
+}
+
+
+/*!
+ * \brief Backward pass of the sliding window attention score operator.
+ *
+ * \param attrs node attributes holding the parsed SldWinAttenParam
+ * \param ctx operator context
+ * \param inputs {grad_score, query, key, dilation}
+ * \param req write requests of the three outputs; kAddTo is not supported for
+ *        grad_query and grad_key
+ * \param outputs {grad_query, grad_key, grad_dilation}
+ */
+template<typename xpu>
+void SldWinAttenScoreBackward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 4U);
+  CHECK_EQ(outputs.size(), 3U);
+  CHECK_EQ(req.size(), 3U);
+  using namespace mshadow;
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int w_right = param.symmetric ? param.w : 0;
+  // The kernel always overwrites its output, so accumulating would be silently wrong.
+  CHECK_NE(req[0], kAddTo) << "_backward_sldwin_atten_score does not support kAddTo";
+  CHECK_NE(req[1], kAddTo) << "_backward_sldwin_atten_score does not support kAddTo";
+  if (req[0] != kNullOp) {
+    // grad_query = matmul(grad_score, key)
+    DiagMMImpl<xpu>(ctx, outputs.at(0), inputs.at(0), inputs.at(2), inputs.at(3),
+                    true, false, param.w, w_right);
+  }
+  if (req[1] != kNullOp) {
+    // grad_key = matmul(grad_score.T, query)
+    DiagMMImpl<xpu>(ctx, outputs.at(1), inputs.at(0), inputs.at(1), inputs.at(3),
+                    true, true, param.w, w_right);
+  }
+  // dilation only selects indices and has no gradient; write zeros when a value
+  // is requested instead of leaving the buffer uninitialized. For kAddTo adding
+  // zero is a no-op, so nothing has to be done.
+  if (req[2] == kWriteTo || req[2] == kWriteInplace) {
+    Stream<xpu>* s = ctx.get_stream<xpu>();
+    MSHADOW_TYPE_SWITCH(outputs[2].type_flag_, DType, {
+      mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+          s, outputs[2].Size(), outputs[2].dptr<DType>());
+    });
+  }
+}
+
+
+
+/*!
+ * \brief Forward pass of the sliding window attention context operator.
+ *
+ * \param attrs node attributes holding the parsed SldWinAttenParam
+ * \param ctx operator context
+ * \param inputs {score, value, dilation}
+ * \param req write request of the output; accumulation (kAddTo) is not supported
+ * \param outputs {context_vec}
+ */
+template<typename xpu>
+void SldWinAttenContextForward(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const std::vector<TBlob>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  // The kernel always overwrites its output, so accumulating would be silently wrong.
+  CHECK_NE(req[0], kAddTo) << "sldwin_atten_context does not support kAddTo";
+  using namespace mshadow;
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int w_right = param.symmetric ? param.w : 0;
+  // context_vec = matmul(score, value)
+  DiagMMImpl<xpu>(ctx, outputs.at(0), inputs.at(0), inputs.at(1), inputs.at(2),
+                  true, false, param.w, w_right);
+}
+
+
+/*!
+ * \brief Backward pass of the sliding window attention context operator.
+ *
+ * \param attrs node attributes holding the parsed SldWinAttenParam
+ * \param ctx operator context
+ * \param inputs {grad_context, score, value, dilation}
+ * \param req write requests of the three outputs; kAddTo is not supported for
+ *        grad_score and grad_value
+ * \param outputs {grad_score, grad_value, grad_dilation}
+ */
+template<typename xpu>
+void SldWinAttenContextBackward(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 4U);
+  CHECK_EQ(outputs.size(), 3U);
+  CHECK_EQ(req.size(), 3U);
+  using namespace mshadow;
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int w_right = param.symmetric ? param.w : 0;
+  // The kernel always overwrites its output, so accumulating would be silently wrong.
+  CHECK_NE(req[0], kAddTo) << "_backward_sldwin_atten_context does not support kAddTo";
+  CHECK_NE(req[1], kAddTo) << "_backward_sldwin_atten_context does not support kAddTo";
+  if (req[0] != kNullOp) {
+    // grad_score = matmul(grad_context, value.T)
+    DiagMMImpl<xpu>(ctx, outputs.at(0), inputs.at(0), inputs.at(2), inputs.at(3),
+                    false, false, param.w, w_right);
+  }
+  if (req[1] != kNullOp) {
+    // grad_value = matmul(score.T, grad_context)
+    DiagMMImpl<xpu>(ctx, outputs.at(1), inputs.at(1), inputs.at(0), inputs.at(3),
+                    true, true, param.w, w_right);
+  }
+  // dilation only selects indices and has no gradient; write zeros when a value
+  // is requested instead of leaving the buffer uninitialized. For kAddTo adding
+  // zero is a no-op, so nothing has to be done.
+  if (req[2] == kWriteTo || req[2] == kWriteInplace) {
+    Stream<xpu>* s = ctx.get_stream<xpu>();
+    MSHADOW_TYPE_SWITCH(outputs[2].type_flag_, DType, {
+      mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+          s, outputs[2].Size(), outputs[2].dptr<DType>());
+    });
+  }
+}
+
+
+
+
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_CONTRIB_TRANSFORMER_INL_H_
diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc
index a839811..43c322e 100644
--- a/src/operator/contrib/transformer.cc
+++ b/src/operator/contrib/transformer.cc
@@ -841,5 +841,196 @@ MXNET_OPERATOR_REGISTER_UNARY(_contrib_div_sqrt_dim)
.set_attr<FCompute>("FCompute<cpu>", DivSqrtDimForward_<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_contrib_div_sqrt_dim"});
+
+DMLC_REGISTER_PARAMETER(SldWinAttenParam);
+
+// Mask operator: produces a float32 0/1 tensor shaped like the windowed score.
+NNVM_REGISTER_OP(_contrib_sldwin_atten_mask_like)
+.add_alias("_npx_sldwin_atten_mask_like")
+.describe(R"code(Compute the mask for the sliding window attention score, used in
+Longformer (https://arxiv.org/pdf/2004.05150.pdf). In this attention pattern,
+given a fixed window size *2w*, each token attends to *w* tokens on the left side
+if we use causal attention (setting *symmetric* to *False*),
+otherwise each token attends to *w* tokens on each side.
+
+The shapes of the inputs are:
+- *score* : (batch_size, seq_length, num_heads, w + w + 1) if symmetric is True,
+ otherwise (batch_size, seq_length, num_heads, w + 1).
+- *dilation* : (num_heads,)
+- *valid_length* : (batch_size,)
+
+The shape of the output is:
+- *mask* : same as the shape of *score*
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"score", "dilation", "valid_length"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"mask"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+                                                 mxnet::ShapeVector *in_attrs,
+                                                 mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const mxnet::TShape& dshape = (*in_attrs)[0];
+  if (!shape_is_known(dshape)) return false;
+  // The forward pass reads dimensions 1 and 2 of score, so the rank must be 4.
+  CHECK_EQ(dshape.ndim(), 4)
+    << "score must be 4-D: (batch_size, seq_length, num_heads, window)";
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs &attrs,
+                                              std::vector<int> *in_attrs,
+                                              std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  // The forward pass requires int32 dilation and valid_length; state it here so
+  // a mismatch is reported during type inference.
+  TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kInt32);
+  TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kInt32);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+  return out_attrs->at(0) != -1;
+})
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenMaskLikeForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("score", "NDArray-or-Symbol", "sliding window attention score")
+.add_argument("dilation", "NDArray-or-Symbol", "dilation")
+.add_argument("valid_length", "NDArray-or-Symbol", "valid length")
+.add_arguments(SldWinAttenParam::__FIELDS__());
+
+
+// Score operator: banded product of query and key over the sliding window.
+NNVM_REGISTER_OP(_contrib_sldwin_atten_score)
+.add_alias("_npx_sldwin_atten_score")
+.describe(R"code(Compute the sliding window attention score, which is used in
+Longformer (https://arxiv.org/pdf/2004.05150.pdf). In this attention pattern,
+given a fixed window size *2w*, each token attends to *w* tokens on the left side
+if we use causal attention (setting *symmetric* to *False*),
+otherwise each token attends to *w* tokens on each side.
+
+The shapes of the inputs are:
+- *query* : (batch_size, seq_length, num_heads, num_head_units)
+- *key* : (batch_size, seq_length, num_heads, num_head_units)
+- *dilation* : (num_heads,)
+
+The shape of the output is:
+- *score* : (batch_size, seq_length, num_heads, w + w + 1) if symmetric is True,
+ otherwise (batch_size, seq_length, num_heads, w + 1).
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"query", "key", "dilation"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"score"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+                                                 mxnet::ShapeVector *inshapes,
+                                                 mxnet::ShapeVector *outshapes) {
+  CHECK_EQ(inshapes->size(), 3U);
+  CHECK_EQ(outshapes->size(), 1U);
+  const mxnet::TShape& qshape = inshapes->at(0);
+  const mxnet::TShape& kshape = inshapes->at(1);
+  const mxnet::TShape& dshape = inshapes->at(2);
+  // Dimensions are indexed below, so postpone inference until all are known.
+  if (!shape_is_known(qshape) || !shape_is_known(kshape) || !shape_is_known(dshape)) {
+    return false;
+  }
+  CHECK_EQ(qshape.ndim(), 4)
+    << "query must be 4-D: (batch_size, seq_length, num_heads, num_head_units)";
+  CHECK_EQ(kshape.ndim(), 4)
+    << "key must be 4-D: (batch_size, seq_length, num_heads, num_head_units)";
+  CHECK_GE(dshape.ndim(), 1);
+  // The kernel addresses key with the strides of query, so all four extents must match.
+  for (int axis = 0; axis < 4; ++axis) {
+    CHECK_EQ(kshape[axis], qshape[axis]) << "query and key must have the same shape";
+  }
+  CHECK_EQ(dshape[0], qshape[2]);
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  mxnet::TShape oshape = qshape;
+  oshape[3] = param.symmetric ? (param.w + param.w + 1) : (param.w + 1);
+  SHAPE_ASSIGN_CHECK(*outshapes, 0, oshape);
+  return true;
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs &attrs,
+                                              std::vector<int> *in_attrs,
+                                              std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  // The compute function only accepts float32 data and int32 dilation.
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kInt32);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+  return out_attrs->at(0) != -1;
+})
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenScoreForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_sldwin_atten_score"})
+.add_argument("query", "NDArray-or-Symbol", "query")
+.add_argument("key", "NDArray-or-Symbol", "key")
+.add_argument("dilation", "NDArray-or-Symbol", "dilation")
+.add_arguments(SldWinAttenParam::__FIELDS__());
+
+
+// Gradient node of _contrib_sldwin_atten_score: four inputs, three outputs.
+NNVM_REGISTER_OP(_backward_sldwin_atten_score)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_num_inputs(4)
+.set_num_outputs(3)
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenScoreBackward<cpu>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true);
+
+
+// Context operator: applies the windowed score to value.
+NNVM_REGISTER_OP(_contrib_sldwin_atten_context)
+.add_alias("_npx_sldwin_atten_context")
+.describe(R"code(Compute the context vector for sliding window attention, used in
+Longformer (https://arxiv.org/pdf/2004.05150.pdf). In this attention pattern,
+given a fixed window size *2w*, each token attends to *w* tokens on the left side
+if we use causal attention (setting *symmetric* to *False*),
+otherwise each token attends to *w* tokens on each side.
+
+The shapes of the inputs are:
+- *score* : (batch_size, seq_length, num_heads, w + w + 1) if symmetric is True,
+ otherwise (batch_size, seq_length, num_heads, w + 1).
+- *value* : (batch_size, seq_length, num_heads, num_head_units)
+- *dilation* : (num_heads,)
+
+The shape of the output is:
+- *context_vec* : (batch_size, seq_length, num_heads, num_head_units)
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"score", "value", "dilation"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"context_vec"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+                                                 mxnet::ShapeVector *inshapes,
+                                                 mxnet::ShapeVector *outshapes) {
+  CHECK_EQ(inshapes->size(), 3U);
+  CHECK_EQ(outshapes->size(), 1U);
+  const mxnet::TShape& sshape = inshapes->at(0);
+  const mxnet::TShape& vshape = inshapes->at(1);
+  const mxnet::TShape& dshape = inshapes->at(2);
+  // Dimensions are indexed below, so postpone inference until all are known.
+  if (!shape_is_known(sshape) || !shape_is_known(vshape) || !shape_is_known(dshape)) {
+    return false;
+  }
+  CHECK_EQ(sshape.ndim(), 4)
+    << "score must be 4-D: (batch_size, seq_length, num_heads, window)";
+  CHECK_EQ(vshape.ndim(), 4)
+    << "value must be 4-D: (batch_size, seq_length, num_heads, num_head_units)";
+  CHECK_GE(dshape.ndim(), 1);
+  // The kernel walks value with the batch, sequence and head extents of score.
+  for (int axis = 0; axis < 3; ++axis) {
+    CHECK_EQ(vshape[axis], sshape[axis])
+      << "score and value must agree on batch_size, seq_length and num_heads";
+  }
+  CHECK_EQ(dshape[0], sshape[2]);
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int w_len = param.symmetric ? (param.w + param.w + 1) : (param.w + 1);
+  CHECK_EQ(sshape[3], w_len);
+
+  mxnet::TShape oshape = sshape;
+  oshape[3] = vshape[3];
+  SHAPE_ASSIGN_CHECK(*outshapes, 0, oshape);
+
+  return true;
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs &attrs,
+                                              std::vector<int> *in_attrs,
+                                              std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  // The compute function only accepts float32 data and int32 dilation.
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kInt32);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+  return out_attrs->at(0) != -1;
+})
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenContextForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_sldwin_atten_context"})
+.add_argument("score", "NDArray-or-Symbol", "score")
+.add_argument("value", "NDArray-or-Symbol", "value")
+.add_argument("dilation", "NDArray-or-Symbol", "dilation")
+.add_arguments(SldWinAttenParam::__FIELDS__());
+
+
+// Gradient node of _contrib_sldwin_atten_context: four inputs, three outputs.
+NNVM_REGISTER_OP(_backward_sldwin_atten_context)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_num_inputs(4)
+.set_num_outputs(3)
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenContextBackward<cpu>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true);
+
+
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu
index bfa4993..ab20f1c 100644
--- a/src/operator/contrib/transformer.cu
+++ b/src/operator/contrib/transformer.cu
@@ -682,5 +682,21 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt)
NNVM_REGISTER_OP(_contrib_div_sqrt_dim)
.set_attr<FCompute>("FCompute<gpu>", DivSqrtDimForward_<gpu>);
+
+// GPU compute functions of the sliding window attention operators.
+// Forward operators:
+NNVM_REGISTER_OP(_contrib_sldwin_atten_mask_like)
+.set_attr<FCompute>("FCompute<gpu>", SldWinAttenMaskLikeForward<gpu>);
+
+NNVM_REGISTER_OP(_contrib_sldwin_atten_score)
+.set_attr<FCompute>("FCompute<gpu>", SldWinAttenScoreForward<gpu>);
+
+NNVM_REGISTER_OP(_contrib_sldwin_atten_context)
+.set_attr<FCompute>("FCompute<gpu>", SldWinAttenContextForward<gpu>);
+
+// Backward operators:
+NNVM_REGISTER_OP(_backward_sldwin_atten_score)
+.set_attr<FCompute>("FCompute<gpu>", SldWinAttenScoreBackward<gpu>);
+
+NNVM_REGISTER_OP(_backward_sldwin_atten_context)
+.set_attr<FCompute>("FCompute<gpu>", SldWinAttenContextBackward<gpu>);
+
} // namespace op
} // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3660d7f..b3a6ce4 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -9250,3 +9250,81 @@ def test_broadcast_ops_on_misaligned_input_oneside(dtype, lead_dim, both_ways):
mx.nd.waitall()
assert_almost_equal(f, expected)
+
+def test_sldwin_selfatten_operators():
+ def gen_sliding_window_mask_full(batch_size, num_heads, seq_length, w, symmetric, d):
+ mask_np = np.zeros((batch_size, num_heads, seq_length, seq_length))
+ for i in range(seq_length):
+ end = (i + 1 + w * d) if symmetric else (i + 1)
+ for j in range(i - w * d, end, d):
+ if j >= 0 and j < seq_length:
+ mask_np[:, :, i, j] = 1
+ return mask_np
+
+ def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads,
+ num_head_units, w, symmetric, d):
+ # Generate the data
+ query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
+ key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
+ value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
+ valid_length = np.zeros((batch_size,))
+ valid_length[:] = seq_length
+
+ query = mx.np.array(query, dtype=np.float32)
+ key = mx.np.array(key, dtype=np.float32)
+ value = mx.np.array(value, dtype=np.float32)
+ dilation = mx.np.ones((num_heads,), dtype=np.int32)
+ dilation[:] = d
+ valid_length = mx.np.array(valid_length, dtype=np.int32)
+
+ query.attach_grad()
+ key.attach_grad()
+ value.attach_grad()
+
+ with mx.autograd.record():
+ score = mx.npx.sldwin_atten_score(query, key, dilation,
+ w=w, symmetric=symmetric)
+ mask = mx.npx.sldwin_atten_mask_like(score, dilation, valid_length,
+ w=w, symmetric=symmetric)
+ score = score * mask
+ out = mx.npx.sldwin_atten_context(score, value, dilation,
+ w=w, symmetric=symmetric)
+ out.backward()
+
+ out_np = out.asnumpy()
+ grad_query = query.grad.asnumpy()
+ grad_key = key.grad.asnumpy()
+ grad_value = value.grad.asnumpy()
+
+ query.grad[:] = 0
+ key.grad[:] = 0
+ value.grad[:] = 0
+
+ mask_np = gen_sliding_window_mask_full(batch_size, num_heads, seq_length,
+ w, symmetric, d)
+ mask = mx.np.array(mask_np, dtype=np.float32)
+
+ with mx.autograd.record():
+ score = mx.npx.batch_dot(mx.np.swapaxes(query, 1, 2),
+ mx.np.swapaxes(key, 1, 2),
+ transpose_b=True)
+ score = score * mask
+ out = mx.npx.batch_dot(score,
+ mx.np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
+ out.backward()
+
+ out_np_gt = out.asnumpy()
+ grad_query_gt = query.grad.asnumpy()
+ grad_key_gt = key.grad.asnumpy()
+ grad_value_gt = value.grad.asnumpy()
+
+ assert_allclose(out_np_gt, out_np, 1E-3, 1E-3)
+ assert_allclose(grad_query_gt, grad_query, 1E-3, 1E-3)
+ assert_allclose(grad_key_gt, grad_key, 1E-3, 1E-3)
+ assert_allclose(grad_value_gt, grad_value, 1E-3, 1E-3)
+
+ for symmetric in [True, False]:
+ for d in [1, 2, 3]:
+ test_sldwin_atten_op_impl(2, 128, 2, 8, 16, symmetric, d)
+ test_sldwin_atten_op_impl(1, 8, 2, 4, 2, symmetric, d)
+