Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/04/20 13:25:47 UTC

[incubator-mxnet] branch v1.x updated: [Feature] Add oneDNN support for interleaved_matmul_selfatt_* operators (fp32/int8) (#20163)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 16d1da9  [Feature] Add oneDNN support for interleaved_matmul_selfatt_* operators (fp32/int8) (#20163)
16d1da9 is described below

commit 16d1da9055140c53029dcdfa290e2b36ad26b20f
Author: bgawrych <ba...@intel.com>
AuthorDate: Tue Apr 20 15:22:53 2021 +0200

    [Feature] Add oneDNN support for interleaved_matmul_selfatt_* operators (fp32/int8) (#20163)
    
    * Add oneDNN code to interleaved kernels
    
    * check
    
    * Fix selfattQK subgraph
    
    * fix qk
    
    * Fixes QK
    
    * add test for oneDNN self_att qk
    
    * basic valatt
    
    * add valatt test
    
    * refactor valatt
    
    * fix review
    
    * Change param struct name
    
    * Fix sanity
    
    * Fix sanity
    
    Co-authored-by: grygielski <ad...@gmail.com>
---
 .../subgraph/mkldnn/mkldnn_subgraph_property.cc    |  23 +-
 .../subgraph/mkldnn/mkldnn_transformer-inl.h       |  58 ++
 src/operator/subgraph/mkldnn/mkldnn_transformer.cc | 670 +++++++++++++++++++++
 .../mkldnn_transformer_post_quantize_property.h    | 207 +++++++
 .../subgraph/mkldnn/mkldnn_transformer_property.h  | 136 +++++
 tests/python/mkl/test_subgraph.py                  | 110 +++-
 6 files changed, 1187 insertions(+), 17 deletions(-)
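
For readers trying the new operators, below is a minimal sketch (not taken from this patch) of how the fp32 fusion can be exercised from Python on an MXNet build with oneDNN enabled; the 'MKLDNN' backend name follows the registrations in mkldnn_subgraph_property.cc, and the shapes are illustrative only:

    import mxnet as mx

    # Interleaved QKV input: (qkv_seq_len, batch, 3 * embed_dim), as expected by
    # interleaved_matmul_selfatt_qk; embed_dim must be divisible by the number of heads.
    # Illustrative shapes only.
    seq_len, batch, embed_dim, heads = 128, 8, 1024, 16
    data = mx.symbol.Variable('data', shape=(seq_len, batch, 3 * embed_dim), dtype='float32')
    qk = mx.symbol.contrib.interleaved_matmul_selfatt_qk(queries_keys_values=data, heads=heads)

    # fp32 path: graph partitioning replaces the operator with _sg_mkldnn_selfatt_qk.
    qk_fused = qk.get_backend_symbol('MKLDNN')
    assert 'sg_mkldnn_selfatt_qk' in ''.join(qk_fused.get_internals().list_outputs())

The int8 path additionally goes through mx.contrib.quant.quantize_model() with the 'MKLDNN_QUANTIZE' backend, as exercised by the new tests in tests/python/mkl/test_subgraph.py below.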

diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index 18cd303..9190ba4 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -25,6 +25,8 @@
 #include "mkldnn_fc_post_quantize_property.h"
 #include "mkldnn_elemwisemul_post_quantize_property.h"
 #include "mkldnn_post_quantize_align_scale_property.h"
+#include "mkldnn_transformer_property.h"
+#include "mkldnn_transformer_post_quantize_property.h"
 
 namespace mxnet {
 namespace op {
@@ -35,34 +37,29 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
 
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerProperty);
+
 MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE)
 .set_attr("context", Context::CPU());
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty)
 .set_attr("quantize", true);
 
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty)
 .set_attr("quantize", true);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerProperty);
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerPostQuantizeProperty);
+
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty);
-#endif  // MXNET_USE_MKLDNN == 1
 
-#if MXNET_USE_MKLDNN == 1
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h
new file mode 100644
index 0000000..d400435
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
+
+#include "../../mxnet_op.h"
+#include "../../mshadow_op.h"
+
+
+namespace mxnet {
+namespace op {
+
+struct MKLDNNSelfAttParam : public dmlc::Parameter<MKLDNNSelfAttParam> {
+  int heads;
+  bool quantized;
+  bool enable_float_output;
+  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
+  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
+  DMLC_DECLARE_PARAMETER(MKLDNNSelfAttParam) {
+    DMLC_DECLARE_FIELD(heads)
+    .describe("Set number of heads");
+    DMLC_DECLARE_FIELD(quantized).set_default(false)
+    .describe("Whether it's a quantized InterleavedMatMul operator");
+    DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
+    .describe("Whether to enable float32 output");
+    DMLC_DECLARE_FIELD(min_calib_range)
+    .set_default(dmlc::optional<float>())
+    .describe("The minimum scalar value in the form of float32 obtained "
+              "through calibration. If present, it will be used to by "
+              "quantized InterleavedMatMul op to calculate primitive scale");
+    DMLC_DECLARE_FIELD(max_calib_range)
+    .set_default(dmlc::optional<float>())
+    .describe("The maximum scalar value in the form of float32 obtained "
+              "through calibration. If present, it will be used to by "
+              "quantized InterleavedMatMul op to calculate primitive scale");
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
new file mode 100644
index 0000000..8757da0
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
@@ -0,0 +1,670 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <utility>
+#include <vector>
+#include <string>
+#include "../common.h"
+#include "./mkldnn_transformer-inl.h"
+#include "../../contrib/transformer-inl.h"
+#include "../../tensor/elemwise_unary_op.h"
+
+#include "../../quantization/quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(MKLDNNSelfAttParam);
+
+template<int base_num_inputs>
+static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs,
+                                 mxnet::ShapeVector* in_shapes,
+                                 mxnet::ShapeVector* out_shapes) {
+  const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    mxnet::ShapeVector base_in_shapes;
+    mxnet::ShapeVector base_out_shapes = {out_shapes->at(0)};
+
+    for (int i = 0; i < base_num_inputs; i++) {
+      base_in_shapes.emplace_back(in_shapes->at(i));
+    }
+    bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes);
+
+    for (size_t i = 0; i < in_shapes->size(); ++i) {
+      if (i < base_in_shapes.size())
+        in_shapes->at(i) = base_in_shapes[i];
+      else
+        SHAPE_ASSIGN_CHECK(*in_shapes, i, mxnet::TShape({1}));
+    }
+    out_shapes->resize(3);
+    out_shapes->at(0) = base_out_shapes[0];
+    if (!param.enable_float_output) {
+      SHAPE_ASSIGN_CHECK(*out_shapes, 1, mxnet::TShape({1}));      // min output
+      SHAPE_ASSIGN_CHECK(*out_shapes, 2, mxnet::TShape({1}));      // max output
+    }
+
+    return ret;
+  } else {
+    return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes);
+  }
+}
+
+static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs,
+                                       std::vector<int> *in_types,
+                                       std::vector<int> *out_types) {
+  const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+     CHECK(in_types->at(0) == mshadow::kInt8)
+        << "QuantizedInterleavedMatMulSelfAttQK only supports int8 input, while "
+        << in_types->at(0) << " is given.";
+
+    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);  // min value
+    TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);  // max value
+
+    if (param.enable_float_output) {
+      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);     // output
+    } else {
+      if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);      // output
+      } else {
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);     // output
+      }
+      TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);     // min output
+      TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);     // max output
+    }
+    return true;
+  } else {
+    return DefaultSubgraphOpType(attrs, in_types, out_types);
+  }
+}
+
+template<int base_num_inputs>
+static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs,
+                                       const int dev_mask,
+                                       DispatchMode *dispatch_mode,
+                                       std::vector<int> *in_attrs,
+                                       std::vector<int> *out_attrs) {
+  auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    std::vector<int> base_in_attrs;
+    std::vector<int> base_out_attrs{out_attrs->at(0)};
+
+    for (int i = 0; i < base_num_inputs; i++) {
+      base_in_attrs.emplace_back(in_attrs->at(i));
+    }
+    bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+                                            &base_in_attrs, &base_out_attrs);
+
+    for (size_t i = 0; i < in_attrs->size(); ++i) {
+      if (i < base_in_attrs.size())
+        in_attrs->at(i) = base_in_attrs[i];
+      else
+        type_assign(&in_attrs->at(i), mxnet::kDefaultStorage);
+    }
+
+    out_attrs->at(0) = base_out_attrs[0];
+    if (!param.enable_float_output) {
+      type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
+      type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
+    }
+    return ret;
+  } else {
+    return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+                                        in_attrs, out_attrs);
+  }
+}
+
+class SgMKLDNNSelfAttQKOp {
+ public:
+  explicit SgMKLDNNSelfAttQKOp(const nnvm::NodeAttrs &attrs) :
+    param_(nnvm::get<MKLDNNSelfAttParam>(attrs.parsed)) {}
+
+  void Forward(const OpContext &ctx,
+               const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<NDArray> &outputs);
+
+  void Backward(const OpContext &ctx,
+                const std::vector<NDArray> &inputs,
+                const std::vector<OpReqType> &req,
+                const std::vector<NDArray> &outputs) {
+    LOG(FATAL) << "Not implemented: subgraph mkldnn self-attention QK only supports "
+                  "inference computation.";
+  }
+
+  void Initialize(const OpContext &ctx,
+                  const std::vector<NDArray> &inputs,
+                  const std::vector<OpReqType> &req,
+                  const std::vector<NDArray> &outputs);
+
+  bool IsInitialized() {
+    return initialized_;
+  }
+
+ private:
+  bool initialized_{false};
+  MKLDNNSelfAttParam param_;
+  mkldnn_args_map_t args_;
+  std::shared_ptr<dnnl::matmul> fwd_;
+  std::shared_ptr<dnnl::memory> cached_query_mem_;
+  std::shared_ptr<dnnl::memory> cached_key_mem_;
+  std::shared_ptr<dnnl::memory> cached_out_mem_;
+  float min_data_;
+  float max_data_;
+  float min_output_;
+  float max_output_;
+  float data_scale_{0.0f};
+};
+
+static OpStatePtr CreateSgMKLDNNSelfAttQKState(const nnvm::NodeAttrs &attrs,
+                                               Context ctx,
+                                               const mxnet::ShapeVector &in_shapes,
+                                               const std::vector<int> &in_types) {
+  return OpStatePtr::Create<SgMKLDNNSelfAttQKOp>(attrs);
+}
+
+static void SgMKLDNNSelfAttQKForward(const OpStatePtr &state_pointer,
+                                     const OpContext &ctx,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<NDArray> &outputs) {
+  SgMKLDNNSelfAttQKOp &op = state_pointer.get_state<SgMKLDNNSelfAttQKOp>();
+  if (!op.IsInitialized()) {
+    op.Initialize(ctx, inputs, req, outputs);
+  }
+  op.Forward(ctx, inputs, req, outputs);
+}
+
+void SgMKLDNNSelfAttQKOp::Initialize(const OpContext &ctx,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<NDArray> &outputs) {
+    using namespace mkldnn;
+    const auto qkv_tensor = inputs[0];
+    const auto out_tensor = outputs[0];
+    const auto qkv_dtype = get_mkldnn_type(qkv_tensor.dtype());
+
+    const memory::dim heads          = param_.heads;
+    const memory::dim sequences      = inputs[0].shape()[1];
+    const memory::dim qkv_seq_len    = inputs[0].shape()[0];
+    const memory::dim output_lin_dim = inputs[0].shape()[2];
+    const memory::dim embed_dim      = output_lin_dim / 3;
+    const memory::dim head_dim       = embed_dim / heads;
+    const memory::dim attn_batches   = heads * sequences;
+    const memory::dim lead_dim       = attn_batches * 3 * head_dim;
+    const memory::dim batch_stride   = 3 * head_dim;
+
+    if (param_.quantized) {
+      min_data_ = inputs[1].data().dptr<float>()[0];
+      max_data_ = inputs[2].data().dptr<float>()[0];
+    }
+
+    const auto engine = CpuEngine::Get()->get_engine();
+
+    memory::dims query_dims    = {attn_batches, qkv_seq_len, head_dim};
+    memory::dims key_dims      = {attn_batches, head_dim, qkv_seq_len};
+    memory::dims out_dims      = {attn_batches, qkv_seq_len, qkv_seq_len};
+
+    memory::dims query_strides = {batch_stride, lead_dim, 1};
+    memory::dims key_strides   = {batch_stride, 1, lead_dim};
+
+    auto query_md = memory::desc(query_dims, qkv_dtype, query_strides);
+    auto key_md   = memory::desc(key_dims, qkv_dtype, key_strides);
+
+    memory::desc out_md;
+
+    float oscale = 1.0f;
+    if (param_.quantized) {
+      data_scale_ = GetQuantizeScale(qkv_tensor.dtype(), min_data_, max_data_);
+
+      if (param_.min_calib_range.has_value() &&
+          param_.max_calib_range.has_value()) {
+        min_output_ = param_.min_calib_range.value();
+        max_output_ = param_.max_calib_range.value();
+        oscale =
+            GetQuantizeScale(out_tensor.dtype(), min_output_, max_output_) /
+            (data_scale_ * data_scale_);
+        out_md = memory::desc(out_dims, memory::data_type::s8, memory::format_tag::abc);
+      } else if (param_.enable_float_output) {
+        oscale = 1.0f / (data_scale_ * data_scale_);
+        out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abc);
+      } else {
+        mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+        mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
+              s, 1, &min_output_, &max_output_, &min_data_, &max_data_, &min_data_,
+              &max_data_);
+        out_md = dnnl::memory::desc(out_dims, memory::data_type::s32, memory::format_tag::abc);
+      }
+    } else {
+      out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abc);
+    }
+    oscale /= sqrt(static_cast<float>(head_dim));  // combine quantized scale and sqrt(head_dim)
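+    // Resulting output scale for the quantized QK matmul:
+    //   int8 output:  oscale = s_out / (s_qkv * s_qkv * sqrt(head_dim))
+    //   fp32 output:  oscale = 1 / (s_qkv * s_qkv * sqrt(head_dim))
+    // where s_qkv = GetQuantizeScale(qkv) and s_out = GetQuantizeScale(output).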
+
+    dnnl::primitive_attr attr;
+    attr.set_output_scales(0, {oscale});
+    auto matmul_d = matmul::desc(query_md, key_md, out_md);
+    auto matmul_pd = matmul::primitive_desc(matmul_d, attr, engine);
+
+    fwd_ = std::make_shared<matmul>(matmul_pd);
+
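+    // queries_keys_values is interleaved per head as [q | k | v], each of size head_dim;
+    // the strided query/key memory descriptors read q (offset 0) and k (offset head_dim)
+    // directly from the same buffer, so no copy or transpose is needed.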
+    MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+      DType* query_mem_ptr = inputs[0].data().dptr<DType>();
+      DType* key_mem_ptr   = query_mem_ptr + head_dim;
+      cached_query_mem_ = std::make_shared<memory>(query_md, engine, query_mem_ptr);
+      cached_key_mem_ = std::make_shared<memory>(key_md, engine, key_mem_ptr);
+    });
+    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+      cached_out_mem_ = std::make_shared<memory>(out_md, engine, outputs[0].data().dptr<DType>());
+    });
+
+    args_[DNNL_ARG_SRC]     = *cached_query_mem_;
+    args_[DNNL_ARG_WEIGHTS] = *cached_key_mem_;
+    args_[DNNL_ARG_DST]     = *cached_out_mem_;
+    initialized_ = true;
+}
+
+
+void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx,
+                                  const std::vector<NDArray> &inputs,
+                                  const std::vector<OpReqType> &req,
+                                  const std::vector<NDArray> &outputs) {
+    const size_t head_dim = inputs[0].shape()[2] / 3 / param_.heads;
+
+    MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+      DType* query_mem_ptr = inputs[0].data().dptr<DType>();
+      DType* key_mem_ptr   = query_mem_ptr + head_dim;
+      cached_query_mem_->set_data_handle(query_mem_ptr);
+      cached_key_mem_->set_data_handle(key_mem_ptr);
+    });
+
+    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+      cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
+    });
+
+    MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_);
+    MKLDNNStream::Get()->Submit();
+
+    if (param_.quantized && !param_.enable_float_output) {
+      float* output_min = outputs[1].data().dptr<float>();
+      float* output_max = outputs[2].data().dptr<float>();
+
+      *output_min = min_output_;
+      *output_max = max_output_;
+    }
+}
+
+nnvm::ObjectPtr SgMKLDNNSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  node->attrs.op = Op::Get("_sg_mkldnn_selfatt_qk");
+  node->attrs.name = "quantized_" + attrs.name;
+  node->attrs.dict = attrs.dict;
+  node->attrs.dict["heads"] = std::to_string(param.heads);
+  node->attrs.dict["quantized"] = "True";
+  node->attrs.subgraphs.reserve(attrs.subgraphs.size());
+  for (auto sub : attrs.subgraphs) {
+    node->attrs.subgraphs.push_back(sub);
+  }
+  node->op()->attr_parser(&(node->attrs));
+  return node;
+}
+
+NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk)
+.describe(R"code(_sg_mkldnn_selfatt_qk)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized && !param.enable_float_output) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
+.set_attr_parser(ParamParser<MKLDNNSelfAttParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  std::vector<std::string> input_names {"queries_keys_values"};
+  if (param.quantized) {
+    input_names.emplace_back("min_qkv");
+    input_names.emplace_back("max_qkv");
+  }
+  return input_names;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  std::vector<std::string> output_names {"output"};
+  if (param.quantized && !param.enable_float_output) {
+    output_names.emplace_back("min_output");
+    output_names.emplace_back("max_output");
+  }
+  return output_names;
+})
+.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNSelfAttShape<1>)
+.set_attr<nnvm::FInferType>("FInferType", SgMKLDNNSelfAttQKInferType)
+.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNSelfAttStorageType<1>)
+.set_attr<FCreateOpState>("FCreateOpState", CreateSgMKLDNNSelfAttQKState)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgMKLDNNSelfAttQKForward)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
+    return QuantizeType::kMust;
+})
+.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNSelfAttQKQuantizedOp)
+.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values")
+.add_arguments(MKLDNNSelfAttParam::__FIELDS__());
+
+/**********************************_sg_mkldnn_selfatt_valatt**********************************/
+
+static bool SgMKLDNNSelfAttValAttInferType(const nnvm::NodeAttrs &attrs,
+                                         std::vector<int> *in_types,
+                                         std::vector<int> *out_types) {
+  const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8);   // qkv input
+    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kUint8);  // att input
+
+    // min qkv, max qkv, min att, max att
+    for (size_t i = 2; i < in_types->size(); ++i) {
+      TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+    }
+
+    if (param.enable_float_output) {
+      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);     // output
+    } else {
+      if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);      // output
+      } else {
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);     // output
+      }
+      TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);     // min output
+      TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);     // max output
+    }
+    return true;
+  } else {
+    return DefaultSubgraphOpType(attrs, in_types, out_types);
+  }
+}
+
+nnvm::ObjectPtr SgMKLDNNSelfAttValAttQuantizedOp(const NodeAttrs& attrs) {
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  node->attrs.op = Op::Get("_sg_mkldnn_selfatt_valatt");
+  node->attrs.name = "quantized_" + attrs.name;
+  node->attrs.dict = attrs.dict;
+  node->attrs.dict["heads"] = std::to_string(param.heads);
+  node->attrs.dict["quantized"] = "True";
+  node->attrs.subgraphs.reserve(attrs.subgraphs.size());
+  for (auto sub : attrs.subgraphs) {
+    node->attrs.subgraphs.push_back(sub);
+  }
+  node->op()->attr_parser(&(node->attrs));
+  return node;
+}
+
+class MKLDNNSelfAttValAttOp {
+ public:
+  explicit MKLDNNSelfAttValAttOp(const nnvm::NodeAttrs &attrs) :
+    param_(nnvm::get<MKLDNNSelfAttParam>(attrs.parsed)) {}
+
+  void Forward(const OpContext &ctx,
+               const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<NDArray> &outputs);
+
+  void Backward(const OpContext &ctx,
+                const std::vector<NDArray> &inputs,
+                const std::vector<OpReqType> &req,
+                const std::vector<NDArray> &outputs) {
+    LOG(FATAL) << "Not implemented: subgraph mkldnn self-attention val-att only supports "
+                  "inference computation.";
+  }
+
+  void Initialize(const OpContext &ctx,
+                  const std::vector<NDArray> &inputs,
+                  const std::vector<OpReqType> &req,
+                  const std::vector<NDArray> &outputs);
+
+  bool IsInitialized() {
+    return initialized_;
+  }
+
+ private:
+  bool initialized_{false};
+  MKLDNNSelfAttParam param_;
+  mkldnn_args_map_t args_;
+  std::shared_ptr<dnnl::matmul> fwd_;
+  std::shared_ptr<dnnl::memory> cached_att_mem_;
+  std::shared_ptr<dnnl::memory> cached_qkv_mem_;
+  std::shared_ptr<dnnl::memory> cached_out_mem_;
+  float min_qkv_;
+  float max_qkv_;
+  float min_att_;
+  float max_att_;
+  float min_output_;
+  float max_output_;
+  float qkv_scale_{0.0f};
+  float att_scale_{0.0f};
+};
+
+static OpStatePtr CreateMKLDNNSelfAttValAttState(const nnvm::NodeAttrs &attrs,
+                                                 Context ctx,
+                                                 const mxnet::ShapeVector &in_shapes,
+                                                 const std::vector<int> &in_types) {
+  return OpStatePtr::Create<MKLDNNSelfAttValAttOp>(attrs);
+}
+
+static void MKLDNNSelfAttValAttForward(const OpStatePtr &state_pointer,
+                                       const OpContext &ctx,
+                                       const std::vector<NDArray> &inputs,
+                                       const std::vector<OpReqType> &req,
+                                       const std::vector<NDArray> &outputs) {
+  MKLDNNSelfAttValAttOp &op = state_pointer.get_state<MKLDNNSelfAttValAttOp>();
+  if (!op.IsInitialized()) {
+    op.Initialize(ctx, inputs, req, outputs);
+  }
+  op.Forward(ctx, inputs, req, outputs);
+}
+
+void MKLDNNSelfAttValAttOp::Initialize(const OpContext &ctx,
+                                       const std::vector<NDArray> &inputs,
+                                       const std::vector<OpReqType> &req,
+                                       const std::vector<NDArray> &outputs) {
+  const dnnl::memory::dim qkv_seq_len    = inputs[0].shape()[0];
+  const dnnl::memory::dim sequences      = inputs[0].shape()[1];
+  const dnnl::memory::dim output_lin_dim = inputs[0].shape()[2];
+  const dnnl::memory::dim embed_dim      = output_lin_dim / 3;
+  const dnnl::memory::dim head_dim       = embed_dim / param_.heads;
+  const dnnl::memory::dim attn_batches   = param_.heads * sequences;
+  const dnnl::memory::dim lead_dim       = attn_batches * 3 * head_dim;
+  const dnnl::memory::dim batch_stride   = 3 * head_dim;
+
+
+  dnnl::memory::dims att_dims = {attn_batches, qkv_seq_len, qkv_seq_len};
+  dnnl::memory::dims qkv_dims = {attn_batches, qkv_seq_len, head_dim};
+  dnnl::memory::dims dst_dims = {attn_batches, qkv_seq_len, head_dim};
+
+  dnnl::memory::dims att_strides = {qkv_seq_len * qkv_seq_len, qkv_seq_len, 1};
+  dnnl::memory::dims qkv_strides = {batch_stride, lead_dim, 1};
+
+  auto att_dtype = inputs[1].dtype();
+  auto qkv_dtype = inputs[0].dtype();
+  auto out_dtype = outputs[0].dtype();
+  auto att_md = dnnl::memory::desc(att_dims, get_mkldnn_type(att_dtype), att_strides);
+  auto qkv_md = dnnl::memory::desc(qkv_dims, get_mkldnn_type(qkv_dtype), qkv_strides);
+
+  dnnl::memory::desc out_md;
+  dnnl::primitive_attr attr;
+
+  float oscale = 1.0f;
+  if (param_.quantized) {
+    min_qkv_ = inputs[2].data().dptr<float>()[0];
+    max_qkv_ = inputs[3].data().dptr<float>()[0];
+    min_att_ = inputs[4].data().dptr<float>()[0];
+    max_att_ = inputs[5].data().dptr<float>()[0];
+    qkv_scale_ = GetQuantizeScale(qkv_dtype, min_qkv_, max_qkv_);
+    att_scale_ = GetQuantizeScale(att_dtype, min_att_, max_att_);
+
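+    // Output scale for the u8 (attention) x s8 (values) matmul:
+    //   int8 output:  oscale = s_out / (s_qkv * s_att)
+    //   fp32 output:  oscale = 1 / (s_qkv * s_att)
+    // where s_qkv and s_att are the input quantization scales computed above.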
+    if (param_.min_calib_range.has_value() &&
+        param_.max_calib_range.has_value()) {
+      min_output_ = param_.min_calib_range.value();
+      max_output_ = param_.max_calib_range.value();
+
+      oscale = GetQuantizeScale(out_dtype, min_output_, max_output_)  / (qkv_scale_ * att_scale_);
+      attr.set_output_scales(0, {oscale});
+    } else if (param_.enable_float_output) {
+      oscale = 1.0f / (qkv_scale_ * att_scale_);
+      attr.set_output_scales(0, {oscale});
+    } else {
+      mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+      mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
+              s, 1, &min_output_, &max_output_, &min_qkv_, &max_qkv_, &min_att_,
+              &max_att_);
+    }
+  }
+  out_md = dnnl::memory::desc(dst_dims, get_mkldnn_type(out_dtype), dnnl::memory::format_tag::bac);
+
+  const auto engine = CpuEngine::Get()->get_engine();
+  auto matmul_d = dnnl::matmul::desc(att_md, qkv_md, out_md);
+  auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine);
+
+  fwd_ = std::make_shared<dnnl::matmul>(matmul_pd);
+
+  MSHADOW_TYPE_SWITCH(att_dtype, DType, {
+    DType* att_ptr = inputs[1].data().dptr<DType>();
+    cached_att_mem_ = std::make_shared<dnnl::memory>(att_md, engine, att_ptr);
+  });
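+  // The value tensor starts 2 * head_dim elements into the interleaved QKV buffer
+  // ([q | k | v] per head); the strided qkv memory descriptor reads it in place.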
+  MSHADOW_TYPE_SWITCH(qkv_dtype, DType, {
+    DType* value_ptr = inputs[0].data().dptr<DType>() + 2*head_dim;
+    cached_qkv_mem_ = std::make_shared<dnnl::memory>(qkv_md, engine, value_ptr);
+  });
+  MSHADOW_TYPE_SWITCH(out_dtype, DType, {
+    DType* out_ptr = outputs[0].data().dptr<DType>();
+    cached_out_mem_ = std::make_shared<dnnl::memory>(out_md, engine, out_ptr);
+  });
+
+  args_[DNNL_ARG_SRC] = *cached_att_mem_;
+  args_[DNNL_ARG_WEIGHTS] = *cached_qkv_mem_;
+  args_[DNNL_ARG_DST] = *cached_out_mem_;
+  initialized_ = true;
+}
+
+void MKLDNNSelfAttValAttOp::Forward(const OpContext &ctx,
+                                    const std::vector<NDArray> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &outputs) {
+  const auto engine = CpuEngine::Get()->get_engine();
+  const size_t head_dim = inputs[0].shape()[2] / param_.heads / 3;
+  MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
+    DType* att_ptr  = inputs[1].data().dptr<DType>();
+    cached_att_mem_->set_data_handle(att_ptr);
+  });
+  MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+    DType* value_ptr = inputs[0].data().dptr<DType>() + 2*head_dim;
+    cached_qkv_mem_->set_data_handle(value_ptr);
+  });
+  MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+    DType* out_ptr  = outputs[0].data().dptr<DType>();
+    cached_out_mem_->set_data_handle(out_ptr);
+  });
+
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_);
+  MKLDNNStream::Get()->Submit();
+
+  if (param_.quantized && !param_.enable_float_output) {
+    float* output_min = outputs[1].data().dptr<float>();
+    float* output_max = outputs[2].data().dptr<float>();
+
+    *output_min = min_output_;
+    *output_max = max_output_;
+  }
+}
+
+NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt)
+.describe(R"code(_sg_mkldnn_selfatt_valatt)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    return 6;
+  } else {
+    return 2;
+  }
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized && !param.enable_float_output) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
+.set_attr_parser(ParamParser<MKLDNNSelfAttParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  std::vector<std::string> input_names {"queries_keys_values", "attention"};
+  if (param.quantized) {
+    input_names.emplace_back("min_qkv");
+    input_names.emplace_back("max_qkv");
+
+    input_names.emplace_back("min_attention");
+    input_names.emplace_back("max_attention");
+  }
+  return input_names;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  std::vector<std::string> output_names {"output"};
+  if (param.quantized && !param.enable_float_output) {
+    output_names.emplace_back("min_output");
+    output_names.emplace_back("max_output");
+  }
+  return output_names;
+})
+.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNSelfAttShape<2>)
+.set_attr<nnvm::FInferType>("FInferType", SgMKLDNNSelfAttValAttInferType)
+.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNSelfAttStorageType<2>)
+.set_attr<FCreateOpState>("FCreateOpState", CreateMKLDNNSelfAttValAttState)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", MKLDNNSelfAttValAttForward)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
+    return QuantizeType::kMust;
+})
+.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNSelfAttValAttQuantizedOp)
+.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved")
+.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
+.add_arguments(MKLDNNSelfAttParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
new file mode 100644
index 0000000..adf6230
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include <string>
+#include <vector>
+#include "../../quantization/requantize-inl.h"
+#include "../common.h"
+#include "mkldnn_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgMKLDNNTransformerPostQuantizeSelector : public SubgraphSelector {
+ public:
+  /*! \brief pattern match status */
+  enum SelectStatus {
+    kFail = 0,
+    kStart,
+    kRequantize,
+    kSuccess,
+  };
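+  // Typical progression: kStart (on _sg_mkldnn_selfatt_*) -> kRequantize (requantize with
+  // calibrated ranges) -> kSuccess (optionally folding a following dequantize into float output).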
+
+ private:
+  bool disable_all;
+  bool disable_float_output;
+  SelectStatus status;
+  std::vector<const nnvm::Node *> matched_list;
+
+ public:
+  explicit SgMKLDNNTransformerPostQuantizeSelector(const bool dis_all,
+                                                   const bool dis_float_output)
+      : disable_all(dis_all),
+        disable_float_output(dis_float_output) {}
+
+  bool Select(const nnvm::Node &n) override {
+    if ((!disable_all) &&
+        (n.op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
+         n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) {
+      status = disable_all ? kSuccess : kStart;
+      matched_list.clear();
+      matched_list.push_back(&n);
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    if (status == kFail || status == kSuccess || new_node.is_variable())
+      return false;
+    // If n isn't the last matched node, then we encountered an internal
+    // branch; we should pop out the nodes behind n and stop fusion.
+    if (matched_list.back() != &n) {
+      if (std::find(matched_list.begin(), matched_list.end(), &n) !=
+        matched_list.end()) {
+        while (matched_list.back() != &n) {
+          matched_list.pop_back();
+        }
+      }
+
+      status = kSuccess;
+      return false;
+    }
+
+    switch (status) {
+      case kStart:
+        if (new_node.op() == Op::Get("_contrib_requantize")) {
+          auto const &param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
+          if (param.min_calib_range.has_value() &&
+              param.max_calib_range.has_value()) {
+            matched_list.push_back(&new_node);
+            status = kRequantize;
+            return true;
+          }
+        }
+      case kRequantize:
+        if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
+            matched_list.push_back(&new_node);
+            status = kSuccess;
+            return true;
+        }
+      default:
+        status = kSuccess;
+        return false;
+    }
+  }
+
+  std::vector<nnvm::Node *> Filter(
+      const std::vector<nnvm::Node *> &candidates) override {
+    if ((status != kSuccess) || (matched_list.size() <= 1)) {
+      return std::vector<nnvm::Node *>(0);
+    } else {
+      std::vector<nnvm::Node *> ret;
+      for (auto i : matched_list) {
+        auto non_const_i = const_cast<nnvm::Node *>(i);
+        if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
+            candidates.end()) {
+          ret.push_back(non_const_i);
+        }
+      }
+      return ret;
+    }
+  }
+
+  void Reset() override {
+    CHECK_GE(matched_list.size(), 1);
+    auto new_selector = SgMKLDNNTransformerPostQuantizeSelector(disable_all, disable_float_output);
+    new_selector.Select(*matched_list[0]);
+    *this = new_selector;
+  }
+};
+
+class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty {
+ public:
+  SgMKLDNNTransformerPostQuantizeProperty() {
+    disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QTRANSFORMER_FUSE_ALL", false);
+    disable_float_output = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QTRANSFORMER_FLOAT_OUTPUT", false);
+  }
+
+  static SubgraphPropertyPtr Create() {
+    static const std::string &name = "MKLDNN Transformer post-quantization optimization pass";
+    auto property = std::make_shared<SgMKLDNNTransformerPostQuantizeProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                   const int subgraph_id = 0) const override {
+    nnvm::ObjectPtr interleaved_node = nullptr;
+    nnvm::ObjectPtr requantize_node = nullptr;
+    nnvm::ObjectPtr dequantize_node = nullptr;
+
+    DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) {
+      if (node->is_variable()) return;
+      if (node->op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
+          node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) {
+        interleaved_node = node;
+      } else if (node->op() == Op::Get("_contrib_requantize")) {
+        requantize_node = node;
+      } else if (node->op() == Op::Get("_contrib_dequantize")) {
+        dequantize_node = node;
+      }
+    });
+
+    CHECK_NOTNULL(interleaved_node);
+    CHECK_NOTNULL(requantize_node);
+    auto const &requantize_param =
+        nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
+    CHECK(requantize_param.min_calib_range.has_value());
+    CHECK(requantize_param.max_calib_range.has_value());
+
+    // When fusing only quantized_interleaved_matmul and requantize, set min/max_calib_range;
+    // when fusing quantized_interleaved_matmul + requantize + dequantize,
+    // enable float output instead.
+    if (dequantize_node != nullptr) {
+      interleaved_node->attrs.dict["enable_float_output"] = "True";
+    } else {
+      interleaved_node->attrs.dict["min_calib_range"] =
+          std::to_string(requantize_param.min_calib_range.value());
+      interleaved_node->attrs.dict["max_calib_range"] =
+          std::to_string(requantize_param.max_calib_range.value());
+    }
+    interleaved_node->op()->attr_parser(&(interleaved_node->attrs));
+    return interleaved_node;
+  }
+
+  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+    auto selector =
+        std::make_shared<SgMKLDNNTransformerPostQuantizeSelector>(disable_fuse_all,
+                                                         disable_float_output);
+    return selector;
+  }
+
+ private:
+  bool disable_fuse_all;
+  bool disable_float_output;
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
new file mode 100644
index 0000000..f022bcc
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include <map>
+#include <string>
+#include <vector>
+#include "../common.h"
+#include "../../tensor/matrix_op-inl.h"
+#include "../../contrib/transformer-inl.h"
+#include "mkldnn_transformer-inl.h"
+#include "mkldnn_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define SELFATT_QK     "_contrib_interleaved_matmul_selfatt_qk"
+#define SELFATT_VALATT "_contrib_interleaved_matmul_selfatt_valatt"
+
+const std::map<std::string, std::string> OpMapping = {
+  {SELFATT_QK,     "_sg_mkldnn_selfatt_qk"},
+  {SELFATT_VALATT, "_sg_mkldnn_selfatt_valatt"}
+};
+
+const std::map<std::string, std::string> NameMapping = {
+  {SELFATT_QK,     "sg_mkldnn_selfatt_qk"},
+  {SELFATT_VALATT, "sg_mkldnn_selfatt_valatt"}
+};
+
+class SgMKLDNNTransformerSelector : public SubgraphSelector {
+ public:
+  bool Select(const nnvm::Node &n, const std::shared_ptr<NodeAttr>& node_attr) override {
+    if (n.op() == Op::Get(SELFATT_QK) ||
+        n.op() == Op::Get(SELFATT_VALATT)) {
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;
+  }
+};
+
+class SgMKLDNNTransformerProperty : public SubgraphProperty {
+ public:
+  SgMKLDNNTransformerProperty() {}
+
+  static SubgraphPropertyPtr Create() {
+    static const std::string &name = "MKLDNN Transformer optimization pass";
+    auto property = std::make_shared<SgMKLDNNTransformerProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_TRANSFORMER_OPT", 0)) {
+      property->SetAttr<bool>("disable", true);
+    }
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                   const int subgraph_id = 0) const override {
+    nnvm::ObjectPtr n = nnvm::Node::Create();
+    // This op has a single output; remove duplicated entries.
+    auto last_node = sym.outputs[0].node;
+    nnvm::Symbol new_sym;
+    new_sym.outputs.emplace_back(last_node);
+    std::ostringstream node_name;
+    std::string op_name;
+    MKLDNNSelfAttParam new_param;
+    DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
+      if (node->op() &&
+          (node->op()->name == SELFATT_QK ||
+           node->op()->name == SELFATT_VALATT)) {
+        op_name = node->op()->name;
+        auto param = nnvm::get<InterleavedMatMulParam>(node->attrs.parsed);
+        new_param.heads = param.heads;
+        new_param.quantized = false;
+        new_param.enable_float_output = false;
+      }
+    });
+    node_name << NameMapping.at(op_name) << "_" << std::to_string(subgraph_id);
+
+
+    n->attrs.name = node_name.str();
+    n->attrs.op = Op::Get(OpMapping.at(op_name));
+    CHECK(n->attrs.op);
+    n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+    n->attrs.parsed = new_param;
+    return n;
+  }
+
+  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+    auto selector = std::make_shared<SgMKLDNNTransformerSelector>();
+    return selector;
+  }
+
+  void ConnectSubgraphOutputs(
+      const nnvm::ObjectPtr n,
+      std::vector<nnvm::NodeEntry *> *output_entries) const override {
+    // Connect all external output entries to output[0]
+    for (size_t i = 0; i < output_entries->size(); ++i) {
+      auto entry_ptr = output_entries->at(i);
+      *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+    }
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 65b73e4..79494a0 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -45,6 +45,14 @@ config =  {
   'fc': {
     OP_NAME: 'sg_mkldnn_fully_connected',
     QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected'
+  },
+  'selfatt_qk': {
+    OP_NAME: 'sg_mkldnn_selfatt_qk',
+    QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_selfatt_qk'
+  },
+  'selfatt_valatt': {
+    OP_NAME: 'sg_mkldnn_selfatt_valatt',
+    QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_selfatt_valatt'
   }
 }
 
@@ -52,6 +60,10 @@ DATA_SHAPE=[(64, 4, 10, 10), (4, 3, 24, 24), (1, 16, 32, 32)]
 fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu',
                   'square', 'square_root', 'abs', 'exp', 'bounded_relu']
 
+quant_op_fp32_output_support = ("quantized_sg_mkldnn_fully_connected",
+                                "quantized_sg_mkldnn_selfatt_qk",
+                                "quantized_sg_mkldnn_selfatt_valatt")
+
 def check_qsym_calibrated(qsym, out_type, name='conv'):
   quantized_op_name = 'quantized_' + name
   assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1
@@ -59,7 +71,8 @@ def check_qsym_calibrated(qsym, out_type, name='conv'):
     if k.find('_quantize') != -1:
       assert v['out_type'] == out_type
     if k.find(quantized_op_name) != -1:
-      if quantized_op_name.startswith("quantized_sg_mkldnn_fully_connected") and 'enable_float_output' in v:
+      if ('enable_float_output' in v
+          and quantized_op_name.startswith(quant_op_fp32_output_support)):
         continue
       assert 'min_calib_range' in v
       assert 'max_calib_range' in v
@@ -119,9 +132,11 @@ def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape):
 class CalibIter(mx.io.DataIter):
     def __init__(self, batch, data_shape, batch_size):
         super(CalibIter, self).__init__(batch_size)
-        self.data_shape = data_shape
         self.label_shape = (batch_size,)
-        self.provide_data = [('data', self.data_shape)]
+        if isinstance(data_shape, tuple):
+          self.provide_data = [('data', data_shape)]
+        else:
+          self.provide_data = data_shape
         self.provide_label = []
         self.batch = batch
 
@@ -249,7 +264,6 @@ def check_fusion(sym, data_shape, attrs_dict, check_fp32_fusion=True, check_quan
     if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1:
       check_quantization = False
       data_min = 0
-
     sym_sg = sym.get_backend_symbol(SG_PASS_NAME)
     for name, attrs in attrs_dict.items():
       if name in config:
@@ -677,6 +691,13 @@ def fc_eltwise(no_bias, data_shape, flatten=True, alg='relu'):
 
   return sym, attr
 
+def single_selfatt_qk(data_shape, nheads=16):
+  attr = {'selfatt_qk': {}}
+  data = mx.symbol.Variable('data', shape=data_shape, dtype='float32')
+  qk = mx.symbol.contrib.interleaved_matmul_selfatt_qk(queries_keys_values=data,
+                                                       heads=nheads)
+  return qk, attr
+
 # fc + relu can't be fusion case
 # eg.1
 # fc -----------> relu
@@ -866,6 +887,87 @@ def test_fc_eltwise():
       check_fusion(syms, dshape, attrs, check_quantization=False)
 
 @with_seed()
+def test_selfatt_qk():
+  batchsizes = [1, 8]
+  seq_lengths = [180, 384]
+  num_hidden = [1024, 3072]
+  num_heads = [8, 16]
+  for bs, seqlen, nhidden, nheads in itertools.product(batchsizes, seq_lengths, num_hidden, num_heads):
+    dshape = (seqlen, bs, nhidden)
+    syms, attrs = single_selfatt_qk(dshape, nheads)
+    check_fusion(syms, dshape, attrs, out_types=['int8', 'auto'], check_quantization=True)
+
+@with_seed()
+def test_selfatt_valatt():
+  batchsizes = [1, 8]
+  seq_lengths = [18, 255, 384]
+  num_hidden = [1024, 3072]
+  num_heads = [1, 16]
+
+  def get_valatt_symbol(qkv_shape, attention_shape, nheads):
+      qkv = mx.symbol.Variable('qkv', shape=qkv_shape, dtype='float32')
+      attention = mx.symbol.Variable('attention', shape=attention_shape, dtype='float32')
+      # CalibIter assumes that batch_size is always the first dimension;
+      # the following operators change the shapes to the proper ones
+      qkv_swap = mx.symbol.swapaxes(data=qkv, dim1=0, dim2=1)
+      attention_reshape = mx.symbol.reshape(data=attention, shape=(-1, 0, 0), reverse=True)
+      sym = mx.symbol.contrib.interleaved_matmul_selfatt_valatt(queries_keys_values=qkv_swap,
+                                                                attention=attention_reshape,
+                                                                heads=nheads)
+      return sym
+
+  def check_valatt_quantize(sym, qkv_shape, att_shape):
+      qkv_nd    = mx.nd.random.uniform(low=-1, high=1, shape=qkv_shape)
+      weight_nd  = mx.nd.random.uniform(low=0, high=1, shape=att_shape)
+      arg_params = {
+          'qkv': qkv_nd,
+          'attention': weight_nd
+      }
+
+      ex = sym.bind(mx.cpu(), arg_params, args_grad=None)
+      ex.forward()
+      ref_out = ex.outputs
+
+      sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
+
+      batch = mx.io.DataBatch([qkv_nd, weight_nd], [])
+      calib_data = CalibIter(batch, [('qkv', qkv_shape), ('attention', att_shape)], bs)
+      qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg,
+                                                                       arg_params=arg_params,
+                                                                       aux_params={},
+                                                                       ctx=mx.cpu(),
+                                                                       excluded_sym_names=None,
+                                                                       excluded_op_names=None,
+                                                                       quantize_granularity='tensor-wise',
+                                                                       quantized_dtype='auto',
+                                                                       calib_mode='naive',
+                                                                       calib_data=calib_data,
+                                                                       data_names=('qkv', 'attention'),
+                                                                       label_names=None,
+                                                                       num_calib_examples=1,
+                                                                       quantize_mode='full')
+      qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
+
+      qex = qsym.bind(mx.cpu(), arg_params, args_grad=None)
+      qex.forward()
+      quantized_out = qex.outputs
+
+      for i in range(len(ref_out)):
+        min_range = mx.nd.min(ref_out[i]).asscalar()
+        max_range = mx.nd.max(ref_out[i]).asscalar()
+        atol = 0.1 * max(abs(min_range), abs(max_range))
+        assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+
+  for bs, seqlen, nhidden, nheads in itertools.product(batchsizes, seq_lengths, num_hidden, num_heads):
+    qkv_shape = (bs, seqlen, 3*nhidden)
+    att_shape = (bs, nheads, seqlen, seqlen)
+
+    sym = get_valatt_symbol(qkv_shape, att_shape, nheads)
+    check_fusion(sym, None, {'selfatt_valatt': {}}, check_quantization=False)
+    check_valatt_quantize(sym, qkv_shape, att_shape)
+
+
+@with_seed()
 def test_neg_fc_relu():
   for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]):
     syms, attrs, excluded_attrs = neg_fc_relu(no_bias, dshape, flatten)