You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sx...@apache.org on 2020/10/31 18:39:13 UTC

[incubator-mxnet] branch master updated: operators for the sliding window self-attention (#19387)

This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cec6bcf  operators for the sliding window self-attention (#19387)
cec6bcf is described below

commit cec6bcf4cc31751370a0964ec1ff6760ba9fa319
Author: Ziyue Huang <zi...@apache.org>
AuthorDate: Sun Nov 1 02:37:59 2020 +0800

    operators for the sliding window self-attention (#19387)
    
    * operators for the sliding window attention
    
    * address comments
    
    * lint
    
    * fix CI
---
 python/mxnet/amp/lists/symbol_fp16.py  |   3 +
 src/operator/contrib/transformer-inl.h | 222 +++++++++++++++++++++++++++++++++
 src/operator/contrib/transformer.cc    | 191 ++++++++++++++++++++++++++++
 src/operator/contrib/transformer.cu    |  16 +++
 tests/python/unittest/test_operator.py |  78 ++++++++++++
 5 files changed, 510 insertions(+)

diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index b7e3dcb..4b54448 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -591,6 +591,9 @@ FP32_FUNCS = [
     '_npx_deformable_convolution',
     '_npx_modulated_deformable_convolution',
     '_contrib_DeformablePSROIPooling',
+    '_contrib_sldwin_atten_score',
+    '_contrib_sldwin_atten_mask_like',
+    '_contrib_sldwin_atten_context',
     ]
 
 if Features().is_enabled('MKLDNN'):
diff --git a/src/operator/contrib/transformer-inl.h b/src/operator/contrib/transformer-inl.h
index da48ffa..0e7f6c4 100644
--- a/src/operator/contrib/transformer-inl.h
+++ b/src/operator/contrib/transformer-inl.h
@@ -61,6 +61,228 @@ static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs,
   });
 }
 
+
+
+/*!
+ * \brief Parameters shared by the sliding-window attention operators
+ *        (_contrib_sldwin_atten_score / _mask_like / _context).
+ *
+ * The window holds w + w + 1 positions when `symmetric` is true and w + 1
+ * positions otherwise.
+ */
+struct SldWinAttenParam : public dmlc::Parameter<SldWinAttenParam> {
+  int w;
+  bool symmetric;
+  DMLC_DECLARE_PARAMETER(SldWinAttenParam) {
+    // A negative w would give a non-positive window length and negative
+    // offsets into the score tensor, so reject it when parsing.
+    DMLC_DECLARE_FIELD(w)
+    .set_lower_bound(0)
+    .describe("The one-sided window length");
+    DMLC_DECLARE_FIELD(symmetric)
+    .describe("If false, each token will only attend to itself and the previous tokens.");
+  }
+};
+
+
+/*!
+ * \brief Element-wise kernel that builds the 0/1 mask for a sliding-window
+ *        attention score tensor.
+ *
+ * The flat index i is decomposed as (batch idx_0, sequence idx_1, head idx_2,
+ * window idx_3) over a layout of (batch_size, seq_length, num_heads, w_len).
+ * An entry is set to 0 when
+ *  - the key position idx_1 + dilation[idx_2] * (idx_3 - w) would fall before
+ *    the start of the sequence,
+ *  - the query position idx_1 is at or beyond val_length[idx_0], or
+ *  - (symmetric only) the key position would fall at or beyond val_length[idx_0];
+ * otherwise it is 1.
+ * NOTE(review): indices are plain `int`, so more than INT_MAX elements would
+ * overflow. dilation values are used as divisors and are assumed positive;
+ * confirm that callers guarantee this.
+ */
+struct SldWinAttenMaskLike {
+  MSHADOW_XINLINE static void Map(int i, float *out, int32_t *dilation, int32_t *val_length,
+                                  bool symmetric, int w, int seq_length, int num_heads) {
+    out[i] = 1;
+    int w_len = symmetric ? (w + w + 1) : (w + 1);
+    int idx_0 = i / (seq_length * num_heads * w_len);  // batch idx
+    int tmp = i % (seq_length * num_heads * w_len);
+    int idx_1 = tmp / (num_heads * w_len);  // sequence idx
+    tmp = tmp % (num_heads * w_len);
+    int idx_2 = tmp / w_len;  // head idx
+    int idx_3 = tmp % w_len;  // win idx
+
+    // Left term: key lies before position 0. Middle: query row is padding.
+    // Right term (symmetric): key lies at or beyond the valid length.
+    bool is_zero = idx_3 < (w - idx_1/dilation[idx_2]) || idx_1 >= val_length[idx_0] \
+      || (symmetric && (w_len-idx_3-1) < (w - (val_length[idx_0]-idx_1-1)/dilation[idx_2]));
+    if (is_zero) out[i] = 0;
+  }
+};
+
+
+/*!
+ * \brief Forward of _contrib_sldwin_atten_mask_like.
+ *
+ * inputs[0]: score (batch_size, seq_length, num_heads, w_len); only its shape is read
+ * inputs[1]: dilation (num_heads,), int32
+ * inputs[2]: valid_length (batch_size,), int32
+ * outputs[0]: float32 mask with the same shape as score
+ */
+template<typename xpu>
+void SldWinAttenMaskLikeForward(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(req[0], kWriteTo) << "Currently only support kWriteTo";
+  using namespace mshadow;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  CHECK_EQ(inputs[0].ndim(), 4)
+    << "score must be 4-D (batch_size, seq_length, num_heads, w_len)";
+  const int w_len = param.symmetric ? (param.w + param.w + 1) : (param.w + 1);
+  // The kernel decomposes flat indices with w_len, so the score's last axis
+  // must match the window implied by the parameters.
+  CHECK_EQ(inputs[0].shape_[3], w_len)
+    << "the last axis of score does not match the window length given by w/symmetric";
+  CHECK_EQ(outputs[0].type_flag_, kFloat32);
+  float* out = outputs[0].dptr<float>();
+  CHECK_EQ(inputs[1].type_flag_, kInt32);
+  int32_t* dilation = inputs[1].dptr<int32_t>();
+  CHECK_EQ(inputs[2].type_flag_, kInt32);
+  int32_t* val_length = inputs[2].dptr<int32_t>();
+
+  int seq_length = inputs[0].shape_[1];
+  int num_heads = inputs[0].shape_[2];
+  // The kernel reads dilation[head] and val_length[batch].
+  CHECK_EQ(inputs[1].Size(), static_cast<size_t>(num_heads))
+    << "dilation must have num_heads elements";
+  CHECK_EQ(inputs[2].Size(), static_cast<size_t>(inputs[0].shape_[0]))
+    << "valid_length must have batch_size elements";
+  int num_threads = outputs[0].Size();
+
+  mxnet_op::Kernel<SldWinAttenMaskLike, xpu>::Launch(s, num_threads, out, dilation,
+      val_length, param.symmetric, param.w, seq_length, num_heads);
+}
+
+
+
+
+/*!
+ * \brief Element-wise kernel for the banded ("diagonal") matrix products used
+ *        by sliding-window attention. Each thread computes one output element.
+ *
+ * All tensors are addressed as (batch, seq, head, last_dim). The flat output
+ * index tid is decomposed as (idx_0, idx_1, idx_2, idx_3) over
+ * (batch, seq_length, num_heads, out_last_dim). With d = dilation[idx_2]:
+ *  - !diagonal_lhs:
+ *      out[b,t,h,k] = sum_i lhs[b,t,h,i] * rhs[b, t + d*(k - w), h, i]
+ *      (out is the window tensor; the result is 0 if the rhs row is out of range)
+ *  - diagonal_lhs && !transpose_lhs:
+ *      out[b,t,h,j] = sum_k lhs[b,t,h,k] * rhs[b, t + d*(k - w), h, j]
+ *  - diagonal_lhs && transpose_lhs:
+ *      out[b,t,h,j] = sum_i lhs[b, t', h, (w_right + w) - i] * rhs[b, t', h, j]
+ *      with t' = t + d*(i - w_right), i.e. the product with the transposed band
+ * Rows outside [0, seq_length) contribute nothing. The output is always
+ * overwritten, never accumulated.
+ *
+ * NOTE(review): in the !diagonal_lhs branch rhs is strided with lhs_last_dim,
+ * so rhs is assumed to share lhs's last dimension. batch_size is passed but
+ * not read. Indices are `int` and would overflow past INT_MAX elements.
+ */
+struct DiagMM {
+  MSHADOW_XINLINE static void Map(int tid, float *out, float *lhs, float *rhs,
+                                  int32_t *dilation, int batch_size, int seq_length,
+                                  int num_heads, int out_last_dim, int lhs_last_dim, int w,
+                                  int w_right, bool diagonal_lhs, bool transpose_lhs) {
+    out[tid] = 0;
+    int stride = seq_length * num_heads * out_last_dim;
+    int idx_0 = tid / stride;  // batch idx
+    int tmp = tid % stride;
+    stride = num_heads * out_last_dim;
+    int idx_1 = tmp / stride;  // sequence idx
+    tmp = tmp % stride;
+    int idx_2 = tmp / out_last_dim;  // head idx
+    int idx_3 = tmp % out_last_dim;  // window idx or hidden feature idx
+
+    if (!diagonal_lhs) {
+      // Dot product of row idx_1 with the row at window offset idx_3 - w.
+      int tmp_idx = idx_1 + dilation[idx_2] * (idx_3 - w);
+      if (tmp_idx >= seq_length || tmp_idx < 0) return;
+      for (int i = 0; i < lhs_last_dim; i++) {
+        int lhs_idx = idx_0 * (seq_length * num_heads * lhs_last_dim) + \
+          idx_1 * (num_heads * lhs_last_dim) + idx_2 * lhs_last_dim + i;
+        int rhs_idx = idx_0 * (seq_length * num_heads * lhs_last_dim) + \
+          tmp_idx * (num_heads * lhs_last_dim) + idx_2 * lhs_last_dim + i;
+        out[tid] += lhs[lhs_idx] * rhs[rhs_idx];
+      }
+    } else {
+      if (!transpose_lhs) {
+        // Weighted sum of rhs rows in the window of idx_1, weights from lhs row idx_1.
+        for (int i = 0; i < lhs_last_dim; i++) {
+          int tmp_idx = idx_1 + dilation[idx_2] * (i - w);
+          if (tmp_idx >= seq_length || tmp_idx < 0) continue;
+          int lhs_idx = idx_0 * (seq_length * num_heads * lhs_last_dim) + \
+            idx_1 * (num_heads * lhs_last_dim) + idx_2 * lhs_last_dim + i;
+          int rhs_idx = idx_0 * (seq_length * num_heads * out_last_dim) + \
+            tmp_idx * (num_heads * out_last_dim) + idx_2 * out_last_dim + idx_3;
+          out[tid] += lhs[lhs_idx] * rhs[rhs_idx];
+        }
+      } else {
+        // Transposed band: gather the rows whose window contains idx_1, reading
+        // the mirrored window slot (w_right + w) - i from each of them.
+        for (int i = 0; i < lhs_last_dim; i++) {
+          int tmp_idx = idx_1 + dilation[idx_2] * (i - w_right);
+          if (tmp_idx >= seq_length || tmp_idx < 0) continue;
+          int lhs_idx = idx_0 * (seq_length * num_heads * lhs_last_dim) + \
+            tmp_idx * (num_heads * lhs_last_dim) + idx_2 * lhs_last_dim + ((w_right + w) - i);
+          int rhs_idx = idx_0 * (seq_length * num_heads * out_last_dim) + \
+            tmp_idx * (num_heads * out_last_dim) + idx_2 * out_last_dim + idx_3;
+          out[tid] += lhs[lhs_idx] * rhs[rhs_idx];
+        }
+      }
+    }
+  }
+};
+
+
+
+/*!
+ * \brief Validates operands and launches the DiagMM kernel.
+ *
+ * \param out           float32 (batch, seq, heads, out_last_dim), overwritten
+ * \param lhs           float32 (batch, seq, heads, lhs_last_dim)
+ * \param rhs           float32 (batch, seq, heads, *)
+ * \param dilation      int32 with one entry per head
+ * \param diagonal_lhs  true if lhs is the banded (window) operand, false if out is
+ * \param transpose_lhs use the transposed band (only meaningful with diagonal_lhs)
+ * \param w             left window length
+ * \param w_right       right window length (w when symmetric, otherwise 0)
+ */
+template<typename xpu>
+void DiagMMImpl(const OpContext& ctx, const TBlob& out, const TBlob& lhs,
+                const TBlob& rhs, const TBlob& dilation, bool diagonal_lhs,
+                bool transpose_lhs, int w, int w_right) {
+  using namespace mshadow;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  CHECK_EQ(out.type_flag_, kFloat32);
+  CHECK_EQ(lhs.type_flag_, kFloat32);
+  CHECK_EQ(rhs.type_flag_, kFloat32);
+  CHECK_EQ(dilation.type_flag_, kInt32);
+  CHECK_EQ(out.ndim(), 4);
+  CHECK_EQ(lhs.ndim(), 4);
+  CHECK_EQ(rhs.ndim(), 4);
+  // The kernel addresses out, lhs and rhs with lhs's batch/sequence/head
+  // extents, so those leading axes must agree across all three tensors.
+  for (int axis = 0; axis < 3; ++axis) {
+    CHECK_EQ(out.shape_[axis], lhs.shape_[axis]) << "out/lhs mismatch at axis " << axis;
+    CHECK_EQ(rhs.shape_[axis], lhs.shape_[axis]) << "rhs/lhs mismatch at axis " << axis;
+  }
+  const int w_len = w + w_right + 1;
+  if (diagonal_lhs) {
+    // lhs is the band; rhs and out share the feature axis.
+    CHECK_EQ(lhs.shape_[3], w_len) << "lhs last axis must equal the window length";
+    CHECK_EQ(rhs.shape_[3], out.shape_[3]);
+  } else {
+    // out is the band; lhs and rhs share the feature axis.
+    CHECK_EQ(out.shape_[3], w_len) << "out last axis must equal the window length";
+    CHECK_EQ(rhs.shape_[3], lhs.shape_[3]);
+  }
+  // The kernel reads dilation[head].
+  CHECK_EQ(dilation.Size(), static_cast<size_t>(lhs.shape_[2]))
+    << "dilation must have num_heads elements";
+
+  float* lhs_data = lhs.dptr<float>();
+  float* rhs_data = rhs.dptr<float>();
+  int32_t* dilation_data = dilation.dptr<int32_t>();
+  float* out_data = out.dptr<float>();
+
+  int batch_size = lhs.shape_[0];
+  int seq_length = lhs.shape_[1];
+  int num_heads = lhs.shape_[2];
+  int lhs_last_dim = lhs.shape_[3];
+  int out_last_dim = out.shape_[3];
+  int num_threads = out.Size();
+
+  mxnet_op::Kernel<DiagMM, xpu>::Launch(s, num_threads, out_data, lhs_data, rhs_data,
+      dilation_data, batch_size, seq_length, num_heads, out_last_dim, lhs_last_dim, w,
+      w_right, diagonal_lhs, transpose_lhs);
+}
+
+
+/*!
+ * \brief Forward of _contrib_sldwin_atten_score.
+ *
+ * inputs: (query, key, dilation); outputs: (score).
+ * score[b, t, h, k] = <query[b, t, h, :], key[b, t + dilation[h] * (k - w), h, :]>,
+ * and 0 where that key position falls outside the sequence.
+ */
+template<typename xpu>
+void SldWinAttenScoreForward(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<TBlob>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  // DiagMMImpl overwrites its output, so accumulation cannot be honoured.
+  CHECK_NE(req[0], kAddTo) << "Currently kAddTo is not supported";
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int w_right = param.symmetric ? param.w : 0;
+  DiagMMImpl<xpu>(ctx, outputs.at(0), inputs.at(0), inputs.at(1), inputs.at(2),
+    false, false, param.w, w_right);
+}
+
+
+/*!
+ * \brief Backward of _contrib_sldwin_atten_score.
+ *
+ * inputs: (grad_score, query, key, dilation)
+ * outputs: (grad_query, grad_key, grad_dilation)
+ * grad_dilation is written as zeros because dilation is an integer index
+ * tensor that the score does not depend on differentiably.
+ */
+template<typename xpu>
+void SldWinAttenScoreBackward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 4U);
+  CHECK_EQ(outputs.size(), 3U);
+  CHECK_EQ(req.size(), 3U);
+  // DiagMMImpl overwrites its output, so accumulation cannot be honoured.
+  CHECK_NE(req[0], kAddTo) << "Currently kAddTo is not supported";
+  CHECK_NE(req[1], kAddTo) << "Currently kAddTo is not supported";
+  using namespace mshadow;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int w_right = param.symmetric ? param.w : 0;
+  // grad_query = matmul(grad_score, key)
+  if (req[0] != kNullOp) {
+    DiagMMImpl<xpu>(ctx, outputs.at(0), inputs.at(0), inputs.at(2), inputs.at(3),
+        true, false, param.w, w_right);
+  }
+  // grad_key = matmul(grad_score.T, query)
+  if (req[1] != kNullOp) {
+    DiagMMImpl<xpu>(ctx, outputs.at(1), inputs.at(0), inputs.at(1), inputs.at(3),
+        true, true, param.w, w_right);
+  }
+  // Never leave the dilation gradient buffer uninitialised. kAddTo of zero is
+  // a no-op and kNullOp requests nothing.
+  if (req[2] == kWriteTo || req[2] == kWriteInplace) {
+    MSHADOW_TYPE_SWITCH(outputs[2].type_flag_, DType, {
+      mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+          s, outputs[2].Size(), outputs[2].dptr<DType>());
+    });
+  }
+}
+
+
+
+/*!
+ * \brief Forward of _contrib_sldwin_atten_context.
+ *
+ * inputs: (score, value, dilation); outputs: (context_vec).
+ * context_vec[b, t, h, :] = sum_k score[b, t, h, k] * value[b, t + dilation[h] * (k - w), h, :],
+ * skipping window slots that fall outside the sequence.
+ */
+template<typename xpu>
+void SldWinAttenContextForward(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const std::vector<TBlob>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  // DiagMMImpl overwrites its output, so accumulation cannot be honoured.
+  CHECK_NE(req[0], kAddTo) << "Currently kAddTo is not supported";
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int w_right = param.symmetric ? param.w : 0;
+  // context_vec = matmul(score, value)
+  DiagMMImpl<xpu>(ctx, outputs.at(0), inputs.at(0), inputs.at(1), inputs.at(2),
+      true, false, param.w, w_right);
+}
+
+
+/*!
+ * \brief Backward of _contrib_sldwin_atten_context.
+ *
+ * inputs: (grad_context, score, value, dilation)
+ * outputs: (grad_score, grad_value, grad_dilation)
+ * grad_dilation is written as zeros because dilation is an integer index
+ * tensor that the context does not depend on differentiably.
+ */
+template<typename xpu>
+void SldWinAttenContextBackward(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 4U);
+  CHECK_EQ(outputs.size(), 3U);
+  CHECK_EQ(req.size(), 3U);
+  // DiagMMImpl overwrites its output, so accumulation cannot be honoured.
+  CHECK_NE(req[0], kAddTo) << "Currently kAddTo is not supported";
+  CHECK_NE(req[1], kAddTo) << "Currently kAddTo is not supported";
+  using namespace mshadow;
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int w_right = param.symmetric ? param.w : 0;
+  // grad_score = matmul(grad_context, value.T)
+  if (req[0] != kNullOp) {
+    DiagMMImpl<xpu>(ctx, outputs.at(0), inputs.at(0), inputs.at(2), inputs.at(3),
+        false, false, param.w, w_right);
+  }
+  // grad_value = matmul(score.T, grad_context)
+  if (req[1] != kNullOp) {
+    DiagMMImpl<xpu>(ctx, outputs.at(1), inputs.at(1), inputs.at(0), inputs.at(3),
+        true, true, param.w, w_right);
+  }
+  // Never leave the dilation gradient buffer uninitialised. kAddTo of zero is
+  // a no-op and kNullOp requests nothing.
+  if (req[2] == kWriteTo || req[2] == kWriteInplace) {
+    MSHADOW_TYPE_SWITCH(outputs[2].type_flag_, DType, {
+      mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+          s, outputs[2].Size(), outputs[2].dptr<DType>());
+    });
+  }
+}
+
+
+
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_CONTRIB_TRANSFORMER_INL_H_
diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc
index a839811..43c322e 100644
--- a/src/operator/contrib/transformer.cc
+++ b/src/operator/contrib/transformer.cc
@@ -841,5 +841,196 @@ MXNET_OPERATOR_REGISTER_UNARY(_contrib_div_sqrt_dim)
 .set_attr<FCompute>("FCompute<cpu>", DivSqrtDimForward_<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_contrib_div_sqrt_dim"});
 
+
+DMLC_REGISTER_PARAMETER(SldWinAttenParam);
+
+NNVM_REGISTER_OP(_contrib_sldwin_atten_mask_like)
+.add_alias("_npx_sldwin_atten_mask_like")
+.describe(R"code(Compute the mask for the sliding window attention score, used in
+Longformer (https://arxiv.org/pdf/2004.05150.pdf). In this attention pattern,
+given a fixed window size *2w*, each token attends to *w* tokens on the left side
+if we use causal attention (setting *symmetric* to *False*),
+otherwise each token attends to *w* tokens on each side.
+
+The shapes of the inputs are:
+- *score* : (batch_size, seq_length, num_heads, w + w + 1) if symmetric is True,
+            otherwise (batch_size, seq_length, num_heads, w + 1).
+- *dilation* : (num_heads,)
+- *valid_length* : (batch_size,)
+
+The shape of the output is:
+- *mask* : same as the shape of *score*
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"score", "dilation", "valid_length"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"mask"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+                                              mxnet::ShapeVector *in_attrs,
+                                              mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const mxnet::TShape& dshape = (*in_attrs)[0];
+  if (!shape_is_known(dshape)) return false;
+  CHECK_EQ(dshape.ndim(), 4)
+    << "score must be 4-D (batch_size, seq_length, num_heads, w_len)";
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int64_t w_len = param.symmetric ? (param.w + param.w + 1) : (param.w + 1);
+  // The kernel decomposes indices with w_len, so it must match score's last axis.
+  CHECK_EQ(dshape[3], w_len)
+    << "the last axis of score does not match the window length given by w/symmetric";
+  // dilation is indexed per head and valid_length per batch element.
+  SHAPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::Shape1(dshape[2]));
+  SHAPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::Shape1(dshape[0]));
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs &attrs,
+                                             std::vector<int> *in_attrs,
+                                             std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  // The compute function requires int32 dilation and valid_length.
+  TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kInt32);
+  TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kInt32);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+  return out_attrs->at(0) != -1;
+})
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenMaskLikeForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("score", "NDArray-or-Symbol", "sliding window attention score")
+.add_argument("dilation", "NDArray-or-Symbol", "dilation")
+.add_argument("valid_length", "NDArray-or-Symbol", "valid length")
+.add_arguments(SldWinAttenParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(_contrib_sldwin_atten_score)
+.add_alias("_npx_sldwin_atten_score")
+.describe(R"code(Compute the sliding window attention score, which is used in
+Longformer (https://arxiv.org/pdf/2004.05150.pdf). In this attention pattern,
+given a fixed window size *2w*, each token attends to *w* tokens on the left side
+if we use causal attention (setting *symmetric* to *False*),
+otherwise each token attends to *w* tokens on each side.
+
+The shapes of the inputs are:
+- *query* : (batch_size, seq_length, num_heads, num_head_units)
+- *key* : (batch_size, seq_length, num_heads, num_head_units)
+- *dilation* : (num_heads,)
+
+The shape of the output is:
+- *score* : (batch_size, seq_length, num_heads, w + w + 1) if symmetric is True,
+            otherwise (batch_size, seq_length, num_heads, w + 1).
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"query", "key", "dilation"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"score"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+                                              mxnet::ShapeVector *inshapes,
+                                              mxnet::ShapeVector *outshapes) {
+  CHECK_EQ(inshapes->size(), 3U);
+  CHECK_EQ(outshapes->size(), 1U);
+  const mxnet::TShape& qshape = inshapes->at(0);
+  const mxnet::TShape& kshape = inshapes->at(1);
+  // Defer until both query and key shapes are available.
+  if (!shape_is_known(qshape) || !shape_is_known(kshape)) return false;
+  CHECK_EQ(qshape.ndim(), 4)
+    << "query must be 4-D (batch_size, seq_length, num_heads, num_head_units)";
+  CHECK_EQ(kshape.ndim(), 4)
+    << "key must be 4-D (batch_size, seq_length, num_heads, num_head_units)";
+  // The kernel indexes key with query's extents, so every axis must agree.
+  for (int i = 0; i < 4; ++i) {
+    CHECK_EQ(qshape[i], kshape[i]) << "query and key shapes differ at axis " << i;
+  }
+  SHAPE_ASSIGN_CHECK(*inshapes, 2, mshadow::Shape1(qshape[2]));
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int w_len = param.symmetric ? (param.w + param.w + 1) : (param.w + 1);
+  SHAPE_ASSIGN_CHECK(*outshapes, 0,
+                     mshadow::Shape4(qshape[0], qshape[1], qshape[2], w_len));
+  return true;
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs &attrs,
+                                             std::vector<int> *in_attrs,
+                                             std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  // The compute function requires float32 query/key and int32 dilation.
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kInt32);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+  return out_attrs->at(0) != -1;
+})
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenScoreForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_sldwin_atten_score"})
+.add_argument("query", "NDArray-or-Symbol", "query")
+.add_argument("key", "NDArray-or-Symbol", "key")
+.add_argument("dilation", "NDArray-or-Symbol", "dilation")
+.add_arguments(SldWinAttenParam::__FIELDS__());
+
+
+// Backward of _contrib_sldwin_atten_score (wired up through ElemwiseGradUseIn):
+// four inputs (grad_score, query, key, dilation) and three outputs, one
+// gradient per forward input.
+NNVM_REGISTER_OP(_backward_sldwin_atten_score)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_num_inputs(4)
+.set_num_outputs(3)
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenScoreBackward<cpu>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true);
+
+
+NNVM_REGISTER_OP(_contrib_sldwin_atten_context)
+.add_alias("_npx_sldwin_atten_context")
+.describe(R"code(Compute the context vector for sliding window attention, used in
+Longformer (https://arxiv.org/pdf/2004.05150.pdf). In this attention pattern,
+given a fixed window size *2w*, each token attends to *w* tokens on the left side
+if we use causal attention (setting *symmetric* to *False*),
+otherwise each token attends to *w* tokens on each side.
+
+The shapes of the inputs are:
+- *score* : (batch_size, seq_length, num_heads, w + w + 1) if symmetric is True,
+            otherwise (batch_size, seq_length, num_heads, w + 1).
+- *value* : (batch_size, seq_length, num_heads, num_head_units)
+- *dilation* : (num_heads,)
+
+The shape of the output is:
+- *context_vec* : (batch_size, seq_length, num_heads, num_head_units)
+
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"score", "value", "dilation"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"context_vec"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+                                              mxnet::ShapeVector *inshapes,
+                                              mxnet::ShapeVector *outshapes) {
+  CHECK_EQ(inshapes->size(), 3U);
+  CHECK_EQ(outshapes->size(), 1U);
+  const mxnet::TShape& sshape = inshapes->at(0);
+  const mxnet::TShape& vshape = inshapes->at(1);
+  // Defer until both score and value shapes are available.
+  if (!shape_is_known(sshape) || !shape_is_known(vshape)) return false;
+  CHECK_EQ(sshape.ndim(), 4)
+    << "score must be 4-D (batch_size, seq_length, num_heads, w_len)";
+  CHECK_EQ(vshape.ndim(), 4)
+    << "value must be 4-D (batch_size, seq_length, num_heads, num_head_units)";
+  // The kernel indexes value with score's batch/sequence/head extents.
+  for (int i = 0; i < 3; ++i) {
+    CHECK_EQ(sshape[i], vshape[i]) << "score and value shapes differ at axis " << i;
+  }
+  SHAPE_ASSIGN_CHECK(*inshapes, 2, mshadow::Shape1(sshape[2]));
+  const SldWinAttenParam& param = nnvm::get<SldWinAttenParam>(attrs.parsed);
+  const int64_t w_len = param.symmetric ? (param.w + param.w + 1) : (param.w + 1);
+  CHECK_EQ(sshape[3], w_len)
+    << "the last axis of score does not match the window length given by w/symmetric";
+
+  // context_vec has exactly the value's shape.
+  SHAPE_ASSIGN_CHECK(*outshapes, 0, vshape);
+
+  return true;
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs &attrs,
+                                             std::vector<int> *in_attrs,
+                                             std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  // The compute function requires float32 score/value and int32 dilation.
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kInt32);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+  return out_attrs->at(0) != -1;
+})
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenContextForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_sldwin_atten_context"})
+.add_argument("score", "NDArray-or-Symbol", "score")
+.add_argument("value", "NDArray-or-Symbol", "value")
+.add_argument("dilation", "NDArray-or-Symbol", "dilation")
+.add_arguments(SldWinAttenParam::__FIELDS__());
+
+
+// Backward of _contrib_sldwin_atten_context (wired up through ElemwiseGradUseIn):
+// four inputs (grad_context, score, value, dilation) and three outputs, one
+// gradient per forward input.
+NNVM_REGISTER_OP(_backward_sldwin_atten_context)
+.set_attr_parser(ParamParser<SldWinAttenParam>)
+.set_num_inputs(4)
+.set_num_outputs(3)
+.set_attr<FCompute>("FCompute<cpu>", SldWinAttenContextBackward<cpu>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true);
+
+
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu
index bfa4993..ab20f1c 100644
--- a/src/operator/contrib/transformer.cu
+++ b/src/operator/contrib/transformer.cu
@@ -682,5 +682,21 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt)
 NNVM_REGISTER_OP(_contrib_div_sqrt_dim)
 .set_attr<FCompute>("FCompute<gpu>", DivSqrtDimForward_<gpu>);
 
+
+// GPU compute functions for the sliding-window attention operators. The
+// remaining attributes of these ops are registered in transformer.cc.
+
+// Mask (forward only; its gradient is zero).
+NNVM_REGISTER_OP(_contrib_sldwin_atten_mask_like)
+.set_attr<FCompute>("FCompute<gpu>", SldWinAttenMaskLikeForward<gpu>);
+
+// Score: forward and backward.
+NNVM_REGISTER_OP(_contrib_sldwin_atten_score)
+.set_attr<FCompute>("FCompute<gpu>", SldWinAttenScoreForward<gpu>);
+NNVM_REGISTER_OP(_backward_sldwin_atten_score)
+.set_attr<FCompute>("FCompute<gpu>", SldWinAttenScoreBackward<gpu>);
+
+// Context vector: forward and backward.
+NNVM_REGISTER_OP(_contrib_sldwin_atten_context)
+.set_attr<FCompute>("FCompute<gpu>", SldWinAttenContextForward<gpu>);
+NNVM_REGISTER_OP(_backward_sldwin_atten_context)
+.set_attr<FCompute>("FCompute<gpu>", SldWinAttenContextBackward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3660d7f..b3a6ce4 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -9250,3 +9250,81 @@ def test_broadcast_ops_on_misaligned_input_oneside(dtype, lead_dim, both_ways):
     mx.nd.waitall()
     assert_almost_equal(f, expected)
 
+
+def test_sldwin_selfatten_operators():
+    def gen_sliding_window_mask_full(batch_size, num_heads, seq_length, w, symmetric, d):
+        mask_np = np.zeros((batch_size, num_heads, seq_length, seq_length))
+        for i in range(seq_length):
+            end = (i + 1 + w * d) if symmetric else (i + 1)
+            for j in range(i - w * d, end, d):
+                if j >= 0 and j < seq_length:
+                    mask_np[:, :, i, j] = 1
+        return mask_np
+
+    def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads,
+                                  num_head_units, w, symmetric, d):
+        # Generate the data
+        query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
+        key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
+        value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
+        valid_length = np.zeros((batch_size,))
+        valid_length[:] = seq_length
+
+        query = mx.np.array(query, dtype=np.float32)
+        key = mx.np.array(key, dtype=np.float32)
+        value = mx.np.array(value, dtype=np.float32)
+        dilation = mx.np.ones((num_heads,), dtype=np.int32)
+        dilation[:] = d
+        valid_length = mx.np.array(valid_length, dtype=np.int32)
+
+        query.attach_grad()
+        key.attach_grad()
+        value.attach_grad()
+
+        with mx.autograd.record():
+            score = mx.npx.sldwin_atten_score(query, key, dilation,
+                w=w, symmetric=symmetric)
+            mask = mx.npx.sldwin_atten_mask_like(score, dilation, valid_length,
+                w=w, symmetric=symmetric)
+            score = score * mask
+            out = mx.npx.sldwin_atten_context(score, value, dilation,
+                w=w, symmetric=symmetric)
+            out.backward()
+
+        out_np = out.asnumpy()
+        grad_query = query.grad.asnumpy()
+        grad_key = key.grad.asnumpy()
+        grad_value = value.grad.asnumpy()
+
+        query.grad[:] = 0
+        key.grad[:] = 0
+        value.grad[:] = 0
+
+        mask_np = gen_sliding_window_mask_full(batch_size, num_heads, seq_length,
+                                               w, symmetric, d)
+        mask = mx.np.array(mask_np, dtype=np.float32)
+
+        with mx.autograd.record():
+            score = mx.npx.batch_dot(mx.np.swapaxes(query, 1, 2),
+                                     mx.np.swapaxes(key, 1, 2),
+                                     transpose_b=True)
+            score = score * mask
+            out = mx.npx.batch_dot(score,
+                                   mx.np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
+            out.backward()
+
+        out_np_gt = out.asnumpy()
+        grad_query_gt = query.grad.asnumpy()
+        grad_key_gt = key.grad.asnumpy()
+        grad_value_gt = value.grad.asnumpy()
+
+        assert_allclose(out_np_gt, out_np, 1E-3, 1E-3)
+        assert_allclose(grad_query_gt, grad_query, 1E-3, 1E-3)
+        assert_allclose(grad_key_gt, grad_key, 1E-3, 1E-3)
+        assert_allclose(grad_value_gt, grad_value, 1E-3, 1E-3)
+
+    for symmetric in [True, False]:
+        for d in [1, 2, 3]:
+            test_sldwin_atten_op_impl(2, 128, 2, 8, 16, symmetric, d)
+            test_sldwin_atten_op_impl(1, 8, 2, 4, 2, symmetric, d)
+