Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/04/20 13:25:47 UTC
[incubator-mxnet] branch v1.x updated: [Feature] Add oneDNN support
for interleaved_matmul_selfatt_* operators (fp32/int8) (#20163)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 16d1da9 [Feature] Add oneDNN support for interleaved_matmul_selfatt_* operators (fp32/int8) (#20163)
16d1da9 is described below
commit 16d1da9055140c53029dcdfa290e2b36ad26b20f
Author: bgawrych <ba...@intel.com>
AuthorDate: Tue Apr 20 15:22:53 2021 +0200
[Feature] Add oneDNN support for interleaved_matmul_selfatt_* operators (fp32/int8) (#20163)
* Add oneDNN code to interleaved kernels
* check
* Fix selfattQK subgraph
* fix qk
* Fixes QK
* add test for oneDNN self_att qk
* basic valatt
* add valatt test
* refactor valatt
* fix review
* Change param struct name
* Fix sanity
* Fix sanity
Co-authored-by: grygielski <ad...@gmail.com>
---
.../subgraph/mkldnn/mkldnn_subgraph_property.cc | 23 +-
.../subgraph/mkldnn/mkldnn_transformer-inl.h | 58 ++
src/operator/subgraph/mkldnn/mkldnn_transformer.cc | 670 +++++++++++++++++++++
.../mkldnn_transformer_post_quantize_property.h | 207 +++++++
.../subgraph/mkldnn/mkldnn_transformer_property.h | 136 +++++
tests/python/mkl/test_subgraph.py | 110 +++-
6 files changed, 1187 insertions(+), 17 deletions(-)
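For context, the fp32 path of the new fusion can be exercised from Python as in
the sketch below (illustrative only: the shape follows the (seq_len, batch,
3*embed_dim) interleaved layout of the contrib operator, and 'MKLDNN' is the
backend name registered in the property file changed below):

    import mxnet as mx

    # Interleaved QKV input: (seq_len, batch_size, 3 * embed_dim)
    data = mx.symbol.Variable('data', shape=(384, 8, 3072), dtype='float32')
    qk = mx.symbol.contrib.interleaved_matmul_selfatt_qk(
        queries_keys_values=data, heads=16)

    # The MKLDNN subgraph pass replaces the contrib op with the fused
    # _sg_mkldnn_selfatt_qk node added by this commit.
    qk_sg = qk.get_backend_symbol('MKLDNN')
    print(qk_sg.get_internals().list_outputs())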
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index 18cd303..9190ba4 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -25,6 +25,8 @@
#include "mkldnn_fc_post_quantize_property.h"
#include "mkldnn_elemwisemul_post_quantize_property.h"
#include "mkldnn_post_quantize_align_scale_property.h"
+#include "mkldnn_transformer_property.h"
+#include "mkldnn_transformer_post_quantize_property.h"
namespace mxnet {
namespace op {
@@ -35,34 +37,29 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerProperty);
+
MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE)
.set_attr("context", Context::CPU());
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty)
.set_attr("quantize", true);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty)
.set_attr("quantize", true);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerProperty);
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerPostQuantizeProperty);
+
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
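With these registrations, the MKLDNN_QUANTIZE backend applies the transformer
property before the post-quantize properties, so an INT8 flow fuses the contrib
op first and folds requantize/dequantize afterwards. A hedged sketch of that
two-pass flow, condensed from the tests added at the bottom of this commit
(shapes and calibration data are illustrative, not from the commit):

    import mxnet as mx

    data = mx.symbol.Variable('data', dtype='float32')
    qk = mx.symbol.contrib.interleaved_matmul_selfatt_qk(
        queries_keys_values=data, heads=16)

    sym_sg = qk.get_backend_symbol('MKLDNN_QUANTIZE')   # pre-quantization fusion
    calib = mx.io.NDArrayIter(
        mx.nd.random.uniform(-1, 1, shape=(64, 8, 3072)), batch_size=8)
    qsym, qargs, qaux = mx.contrib.quant.quantize_model(
        sym=sym_sg, arg_params={}, aux_params={}, ctx=mx.cpu(),
        calib_mode='naive', calib_data=calib, num_calib_examples=1,
        label_names=None, quantized_dtype='auto', quantize_mode='full')
    qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')   # post-quantization fusion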
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h
new file mode 100644
index 0000000..d400435
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
+
+#include "../../mxnet_op.h"
+#include "../../mshadow_op.h"
+
+
+namespace mxnet {
+namespace op {
+
+struct MKLDNNSelfAttParam : public dmlc::Parameter<MKLDNNSelfAttParam> {
+ int heads;
+ bool quantized;
+ bool enable_float_output;
+ dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
+ dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
+ DMLC_DECLARE_PARAMETER(MKLDNNSelfAttParam) {
+ DMLC_DECLARE_FIELD(heads)
+ .describe("Set number of heads");
+ DMLC_DECLARE_FIELD(quantized).set_default(false)
+ .describe("Whether it's a quantized InterleavedMatMul operator");
+ DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
+ .describe("Whether to enable float32 output");
+ DMLC_DECLARE_FIELD(min_calib_range)
+ .set_default(dmlc::optional<float>())
+ .describe("The minimum scalar value in the form of float32 obtained "
+ "through calibration. If present, it will be used to by "
+ "quantized InterleavedMatMul op to calculate primitive scale");
+ DMLC_DECLARE_FIELD(max_calib_range)
+ .set_default(dmlc::optional<float>())
+ .describe("The maximum scalar value in the form of float32 obtained "
+ "through calibration. If present, it will be used to by "
+ "quantized InterleavedMatMul op to calculate primitive scale");
+ }
+};
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
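The min/max_calib_range fields above are what turn calibration results into a
single oneDNN primitive scale. A numeric sketch of that computation for the QK
kernel (it mirrors the oscale logic in mkldnn_transformer.cc below; for int8,
GetQuantizeScale is 127 / max(|min|, |max|), and the ranges here are made up):

    import math

    def quantize_scale(min_v, max_v, q_range=127.0):  # int8 range
        return q_range / max(abs(min_v), abs(max_v))

    head_dim = 64
    data_scale = quantize_scale(-1.0, 1.0)     # from min/max of the qkv input
    out_scale = quantize_scale(-10.0, 10.0)    # from min/max_calib_range
    # int8 output: rescale from data_scale^2 to out_scale, then fold in
    # the attention scaling 1/sqrt(head_dim)
    oscale = out_scale / (data_scale * data_scale) / math.sqrt(head_dim)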
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
new file mode 100644
index 0000000..8757da0
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
@@ -0,0 +1,670 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <utility>
+#include <vector>
+#include <string>
+#include "../common.h"
+#include "./mkldnn_transformer-inl.h"
+#include "../../contrib/transformer-inl.h"
+#include "../../tensor/elemwise_unary_op.h"
+
+#include "../../quantization/quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(MKLDNNSelfAttParam);
+
+template<int base_num_inputs>
+static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs,
+ mxnet::ShapeVector* in_shapes,
+ mxnet::ShapeVector* out_shapes) {
+ const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ if (param.quantized) {
+ mxnet::ShapeVector base_in_shapes;
+ mxnet::ShapeVector base_out_shapes = {out_shapes->at(0)};
+
+ for (int i = 0; i < base_num_inputs; i++) {
+ base_in_shapes.emplace_back(in_shapes->at(i));
+ }
+ bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes);
+
+ for (size_t i = 0; i < in_shapes->size(); ++i) {
+ if (i < base_in_shapes.size())
+ in_shapes->at(i) = base_in_shapes[i];
+ else
+ SHAPE_ASSIGN_CHECK(*in_shapes, i, mxnet::TShape({1}));
+ }
+ out_shapes->resize(3);
+ out_shapes->at(0) = base_out_shapes[0];
+ if (!param.enable_float_output) {
+ SHAPE_ASSIGN_CHECK(*out_shapes, 1, mxnet::TShape({1})); // min output
+ SHAPE_ASSIGN_CHECK(*out_shapes, 2, mxnet::TShape({1})); // max output
+ }
+
+ return ret;
+ } else {
+ return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes);
+ }
+}
+
+static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs,
+ std::vector<int> *in_types,
+ std::vector<int> *out_types) {
+ const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ if (param.quantized) {
+ CHECK(in_types->at(0) == mshadow::kInt8)
+ << "QuantizedInterleavedMatMulSelfAttQK only supports int8 input, while "
+ << in_types->at(0) << " is given.";
+
+ TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32); // min value
+ TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32); // max value
+
+ if (param.enable_float_output) {
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32); // output
+ } else {
+ if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); // output
+ } else {
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32); // output
+ }
+ TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32); // min output
+ TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32); // max output
+ }
+ return true;
+ } else {
+ return DefaultSubgraphOpType(attrs, in_types, out_types);
+ }
+}
+
+template<int base_num_inputs>
+static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs,
+ const int dev_mask,
+ DispatchMode *dispatch_mode,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ if (param.quantized) {
+ std::vector<int> base_in_attrs;
+ std::vector<int> base_out_attrs{out_attrs->at(0)};
+
+ for (int i = 0; i < base_num_inputs; i++) {
+ base_in_attrs.emplace_back(in_attrs->at(i));
+ }
+ bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+ &base_in_attrs, &base_out_attrs);
+
+ for (size_t i = 0; i < in_attrs->size(); ++i) {
+ if (i < base_in_attrs.size())
+ in_attrs->at(i) = base_in_attrs[i];
+ else
+ type_assign(&in_attrs->at(i), mxnet::kDefaultStorage);
+ }
+
+ out_attrs->at(0) = base_out_attrs[0];
+ if (!param.enable_float_output) {
+ type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
+ type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
+ }
+ return ret;
+ } else {
+ return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+ in_attrs, out_attrs);
+ }
+}
+
+class SgMKLDNNSelfAttQKOp {
+ public:
+ explicit SgMKLDNNSelfAttQKOp(const nnvm::NodeAttrs &attrs) :
+ param_(nnvm::get<MKLDNNSelfAttParam>(attrs.parsed)) {}
+
+ void Forward(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs);
+
+ void Backward(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ LOG(FATAL) << "Not implemented: subgraph mkldnn self-attention only supports "
+ "inference computation.";
+ }
+
+ void Initialize(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs);
+
+ bool IsInitialized() {
+ return initialized_;
+ }
+
+ private:
+ bool initialized_{false};
+ MKLDNNSelfAttParam param_;
+ mkldnn_args_map_t args_;
+ std::shared_ptr<dnnl::matmul> fwd_;
+ std::shared_ptr<dnnl::memory> cached_query_mem_;
+ std::shared_ptr<dnnl::memory> cached_key_mem_;
+ std::shared_ptr<dnnl::memory> cached_out_mem_;
+ float min_data_;
+ float max_data_;
+ float min_output_;
+ float max_output_;
+ float data_scale_{0.0f};
+};
+
+static OpStatePtr CreateSgMKLDNNSelfAttQKState(const nnvm::NodeAttrs &attrs,
+ Context ctx,
+ const mxnet::ShapeVector &in_shapes,
+ const std::vector<int> &in_types) {
+ return OpStatePtr::Create<SgMKLDNNSelfAttQKOp>(attrs);
+}
+
+static void SgMKLDNNSelfAttQKForward(const OpStatePtr &state_pointer,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ SgMKLDNNSelfAttQKOp &op = state_pointer.get_state<SgMKLDNNSelfAttQKOp>();
+ if (!op.IsInitialized()) {
+ op.Initialize(ctx, inputs, req, outputs);
+ }
+ op.Forward(ctx, inputs, req, outputs);
+}
+
+void SgMKLDNNSelfAttQKOp::Initialize(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ using namespace mkldnn;
+ const auto qkv_tensor = inputs[0];
+ const auto out_tensor = outputs[0];
+ const auto qkv_dtype = get_mkldnn_type(qkv_tensor.dtype());
+
+ const memory::dim heads = param_.heads;
+ const memory::dim sequences = inputs[0].shape()[1];
+ const memory::dim qkv_seq_len = inputs[0].shape()[0];
+ const memory::dim output_lin_dim = inputs[0].shape()[2];
+ const memory::dim embed_dim = output_lin_dim / 3;
+ const memory::dim head_dim = embed_dim / heads;
+ const memory::dim attn_batches = heads * sequences;
+ const memory::dim lead_dim = attn_batches * 3 * head_dim;
+ const memory::dim batch_stride = 3 * head_dim;
+
+ float min_data = 0.0f;
+ float max_data = 0.0f;
+
+ if (param_.quantized) {
+ min_data_ = inputs[1].data().dptr<float>()[0];
+ max_data_ = inputs[2].data().dptr<float>()[0];
+ }
+
+ const auto engine = CpuEngine::Get()->get_engine();
+
+ memory::dims query_dims = {attn_batches, qkv_seq_len, head_dim};
+ memory::dims key_dims = {attn_batches, head_dim, qkv_seq_len};
+ memory::dims out_dims = {attn_batches, qkv_seq_len, qkv_seq_len};
+
+ memory::dims query_strides = {batch_stride, lead_dim, 1};
+ memory::dims key_strides = {batch_stride, 1, lead_dim};
+
+ auto query_md = memory::desc(query_dims, qkv_dtype, query_strides);
+ auto key_md = memory::desc(key_dims, qkv_dtype, key_strides);
+
+ memory::desc out_md;
+
+ float oscale = 1.0f;
+ if (param_.quantized) {
+ data_scale_ = GetQuantizeScale(qkv_tensor.dtype(), min_data_, max_data_);
+
+ if (param_.min_calib_range.has_value() &&
+ param_.max_calib_range.has_value()) {
+ min_output_ = param_.min_calib_range.value();
+ max_output_ = param_.max_calib_range.value();
+ oscale =
+ GetQuantizeScale(out_tensor.dtype(), min_output_, max_output_) /
+ (data_scale_ * data_scale_);
+ out_md = memory::desc(out_dims, memory::data_type::s8, memory::format_tag::abc);
+ } else if (param_.enable_float_output) {
+ oscale = 1.0f / (data_scale_ * data_scale_);
+ out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abc);
+ } else {
+ mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+ mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
+ s, 1, &min_output_, &max_output_, &min_data, &max_data, &min_data,
+ &max_data);
+ out_md = dnnl::memory::desc(out_dims, memory::data_type::s32, memory::format_tag::abc);
+ }
+ } else {
+ out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abc);
+ }
+ oscale /= sqrt(static_cast<float>(head_dim)); // fold attention's 1/sqrt(head_dim) scaling into the output scale
+
+ dnnl::primitive_attr attr;
+ attr.set_output_scales(0, {oscale});
+ auto matmul_d = matmul::desc(query_md, key_md, out_md);
+ auto matmul_pd = matmul::primitive_desc(matmul_d, attr, engine);
+
+ fwd_ = std::make_shared<matmul>(matmul_pd);
+
+ MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+ DType* query_mem_ptr = inputs[0].data().dptr<DType>();
+ DType* key_mem_ptr = query_mem_ptr + head_dim;
+ cached_query_mem_ = std::make_shared<memory>(query_md, engine, query_mem_ptr);
+ cached_key_mem_ = std::make_shared<memory>(key_md, engine, key_mem_ptr);
+ });
+ MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+ cached_out_mem_ = std::make_shared<memory>(out_md, engine, outputs[0].data().dptr<DType>());
+ });
+
+ args_[DNNL_ARG_SRC] = *cached_query_mem_;
+ args_[DNNL_ARG_WEIGHTS] = *cached_key_mem_;
+ args_[DNNL_ARG_DST] = *cached_out_mem_;
+ initialized_ = true;
+}
+
+
+void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ const size_t head_dim = inputs[0].shape()[2] / 3 / param_.heads;
+
+ MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+ DType* query_mem_ptr = inputs[0].data().dptr<DType>();
+ DType* key_mem_ptr = query_mem_ptr + head_dim;
+ cached_query_mem_->set_data_handle(query_mem_ptr);
+ cached_key_mem_->set_data_handle(key_mem_ptr);
+ });
+
+ MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+ cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
+ });
+
+ MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_);
+ MKLDNNStream::Get()->Submit();
+
+ if (param_.quantized && !param_.enable_float_output) {
+ float* output_min = outputs[1].data().dptr<float>();
+ float* output_max = outputs[2].data().dptr<float>();
+
+ *output_min = min_output_;
+ *output_max = max_output_;
+ }
+}
+
+nnvm::ObjectPtr SgMKLDNNSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
+ nnvm::ObjectPtr node = nnvm::Node::Create();
+ auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ node->attrs.op = Op::Get("_sg_mkldnn_selfatt_qk");
+ node->attrs.name = "quantized_" + attrs.name;
+ node->attrs.dict = attrs.dict;
+ node->attrs.dict["heads"] = std::to_string(param.heads);
+ node->attrs.dict["quantized"] = "True";
+ node->attrs.subgraphs.reserve(attrs.subgraphs.size());
+ for (auto sub : attrs.subgraphs) {
+ node->attrs.subgraphs.push_back(sub);
+ }
+ node->op()->attr_parser(&(node->attrs));
+ return node;
+}
+
+NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk)
+.describe(R"code(_sg_mkldnn_selfatt_qk)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+ auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ if (param.quantized) {
+ return 3;
+ } else {
+ return 1;
+ }
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+ auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ if (param.quantized && !param.enable_float_output) {
+ return 3;
+ } else {
+ return 1;
+ }
+})
+.set_attr_parser(ParamParser<MKLDNNSelfAttParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+ auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ std::vector<std::string> input_names {"queries_keys_values"};
+ if (param.quantized) {
+ input_names.emplace_back("min_qkv");
+ input_names.emplace_back("max_qkv");
+ }
+ return input_names;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+ auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ std::vector<std::string> output_names {"output"};
+ if (param.quantized && !param.enable_float_output) {
+ output_names.emplace_back("min_output");
+ output_names.emplace_back("max_output");
+ }
+ return output_names;
+})
+.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNSelfAttShape<1>)
+.set_attr<nnvm::FInferType>("FInferType", SgMKLDNNSelfAttQKInferType)
+.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNSelfAttStorageType<1>)
+.set_attr<FCreateOpState>("FCreateOpState", CreateSgMKLDNNSelfAttQKState)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgMKLDNNSelfAttQKForward)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
+ return QuantizeType::kMust;
+})
+.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNSelfAttQKQuantizedOp)
+.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values")
+.add_arguments(MKLDNNSelfAttParam::__FIELDS__());
+
+/**********************************_sg_mkldnn_selfatt_valatt**********************************/
+
+static bool SgMKLDNNSelfAttValAttInferType(const nnvm::NodeAttrs &attrs,
+ std::vector<int> *in_types,
+ std::vector<int> *out_types) {
+ const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ if (param.quantized) {
+ TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8); // qkv input
+ TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kUint8); // att input
+
+ // min qkv, max qkv, min att, max att
+ for (size_t i = 2; i < in_types->size(); ++i) {
+ TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+ }
+
+ if (param.enable_float_output) {
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32); // output
+ } else {
+ if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); // output
+ } else {
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32); // output
+ }
+ TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32); // min output
+ TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32); // max output
+ }
+ return true;
+ } else {
+ return DefaultSubgraphOpType(attrs, in_types, out_types);
+ }
+}
+
+nnvm::ObjectPtr SgMKLDNNSelfAttValAttQuantizedOp(const NodeAttrs& attrs) {
+ nnvm::ObjectPtr node = nnvm::Node::Create();
+ auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ node->attrs.op = Op::Get("_sg_mkldnn_selfatt_valatt");
+ node->attrs.name = "quantized_" + attrs.name;
+ node->attrs.dict = attrs.dict;
+ node->attrs.dict["heads"] = std::to_string(param.heads);
+ node->attrs.dict["quantized"] = "True";
+ node->attrs.subgraphs.reserve(attrs.subgraphs.size());
+ for (auto sub : attrs.subgraphs) {
+ node->attrs.subgraphs.push_back(sub);
+ }
+ node->op()->attr_parser(&(node->attrs));
+ return node;
+}
+
+class MKLDNNSelfAttValAttOp {
+ public:
+ explicit MKLDNNSelfAttValAttOp(const nnvm::NodeAttrs &attrs) :
+ param_(nnvm::get<MKLDNNSelfAttParam>(attrs.parsed)) {}
+
+ void Forward(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs);
+
+ void Backward(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ LOG(FATAL) << "Not implemented: subgraph mkldnn self-attention only supports "
+ "inference computation.";
+ }
+
+ void Initialize(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs);
+
+ bool IsInitialized() {
+ return initialized_;
+ }
+
+ private:
+ bool initialized_{false};
+ MKLDNNSelfAttParam param_;
+ mkldnn_args_map_t args_;
+ std::shared_ptr<dnnl::matmul> fwd_;
+ std::shared_ptr<dnnl::memory> cached_att_mem_;
+ std::shared_ptr<dnnl::memory> cached_qkv_mem_;
+ std::shared_ptr<dnnl::memory> cached_out_mem_;
+ float min_qkv_;
+ float max_qkv_;
+ float min_att_;
+ float max_att_;
+ float min_output_;
+ float max_output_;
+ float qkv_scale_{0.0f};
+ float att_scale_{0.0f};
+};
+
+static OpStatePtr CreateMKLDNNSelfAttValAttState(const nnvm::NodeAttrs &attrs,
+ Context ctx,
+ const mxnet::ShapeVector &in_shapes,
+ const std::vector<int> &in_types) {
+ return OpStatePtr::Create<MKLDNNSelfAttValAttOp>(attrs);
+}
+
+static void MKLDNNSelfAttValAttForward(const OpStatePtr &state_pointer,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ MKLDNNSelfAttValAttOp &op = state_pointer.get_state<MKLDNNSelfAttValAttOp>();
+ if (!op.IsInitialized()) {
+ op.Initialize(ctx, inputs, req, outputs);
+ }
+ op.Forward(ctx, inputs, req, outputs);
+}
+
+void MKLDNNSelfAttValAttOp::Initialize(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ const dnnl::memory::dim qkv_seq_len = inputs[0].shape()[0];
+ const dnnl::memory::dim sequences = inputs[0].shape()[1];
+ const dnnl::memory::dim output_lin_dim = inputs[0].shape()[2];
+ const dnnl::memory::dim embed_dim = output_lin_dim / 3;
+ const dnnl::memory::dim head_dim = embed_dim / param_.heads;
+ const dnnl::memory::dim attn_batches = param_.heads * sequences;
+ const dnnl::memory::dim lead_dim = attn_batches * 3 * head_dim;
+ const dnnl::memory::dim batch_stride = 3 * head_dim;
+
+
+ dnnl::memory::dims att_dims = {attn_batches, qkv_seq_len, qkv_seq_len};
+ dnnl::memory::dims qkv_dims = {attn_batches, qkv_seq_len, head_dim};
+ dnnl::memory::dims dst_dims = {attn_batches, qkv_seq_len, head_dim};
+
+ dnnl::memory::dims att_strides = {qkv_seq_len * qkv_seq_len, qkv_seq_len, 1};
+ dnnl::memory::dims qkv_strides = {batch_stride, lead_dim, 1};
+
+ auto att_dtype = inputs[1].dtype();
+ auto qkv_dtype = inputs[0].dtype();
+ auto out_dtype = outputs[0].dtype();
+ auto att_md = dnnl::memory::desc(att_dims, get_mkldnn_type(att_dtype), att_strides);
+ auto qkv_md = dnnl::memory::desc(qkv_dims, get_mkldnn_type(qkv_dtype), qkv_strides);
+
+ dnnl::memory::desc out_md;
+ dnnl::primitive_attr attr;
+
+ float oscale = 1.0f;
+ if (param_.quantized) {
+ min_qkv_ = inputs[2].data().dptr<float>()[0];
+ max_qkv_ = inputs[3].data().dptr<float>()[0];
+ min_att_ = inputs[4].data().dptr<float>()[0];
+ max_att_ = inputs[5].data().dptr<float>()[0];
+ qkv_scale_ = GetQuantizeScale(qkv_dtype, min_qkv_, max_qkv_);
+ att_scale_ = GetQuantizeScale(att_dtype, min_att_, max_att_);
+
+ if (param_.min_calib_range.has_value() &&
+ param_.max_calib_range.has_value()) {
+ min_output_ = param_.min_calib_range.value();
+ max_output_ = param_.max_calib_range.value();
+
+ oscale = GetQuantizeScale(out_dtype, min_output_, max_output_) / (qkv_scale_ * att_scale_);
+ attr.set_output_scales(0, {oscale});
+ } else if (param_.enable_float_output) {
+ oscale = 1.0f / (qkv_scale_ * att_scale_);
+ attr.set_output_scales(0, {oscale});
+ } else {
+ mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+ mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
+ s, 1, &min_output_, &max_output_, &min_qkv_, &max_qkv_, &min_att_,
+ &max_att_);
+ }
+ }
+ out_md = dnnl::memory::desc(dst_dims, get_mkldnn_type(out_dtype), dnnl::memory::format_tag::bac);
+
+ const auto engine = CpuEngine::Get()->get_engine();
+ auto matmul_d = dnnl::matmul::desc(att_md, qkv_md, out_md);
+ auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine);
+
+ fwd_ = std::make_shared<dnnl::matmul>(matmul_pd);
+
+ MSHADOW_TYPE_SWITCH(att_dtype, DType, {
+ DType* att_ptr = inputs[1].data().dptr<DType>();
+ cached_att_mem_ = std::make_shared<dnnl::memory>(att_md, engine, att_ptr);
+ });
+ MSHADOW_TYPE_SWITCH(qkv_dtype, DType, {
+ DType* value_ptr = inputs[0].data().dptr<DType>() + 2*head_dim;
+ cached_qkv_mem_ = std::make_shared<dnnl::memory>(qkv_md, engine, value_ptr);
+ });
+ MSHADOW_TYPE_SWITCH(out_dtype, DType, {
+ DType* out_ptr = outputs[0].data().dptr<DType>();
+ cached_out_mem_ = std::make_shared<dnnl::memory>(out_md, engine, out_ptr);
+ });
+
+ args_[DNNL_ARG_SRC] = *cached_att_mem_;
+ args_[DNNL_ARG_WEIGHTS] = *cached_qkv_mem_;
+ args_[DNNL_ARG_DST] = *cached_out_mem_;
+ initialized_ = true;
+}
+
+void MKLDNNSelfAttValAttOp::Forward(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ const auto engine = CpuEngine::Get()->get_engine();
+ const size_t head_dim = inputs[0].shape()[2] / param_.heads / 3;
+ MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
+ DType* att_ptr = inputs[1].data().dptr<DType>();
+ cached_att_mem_->set_data_handle(att_ptr);
+ });
+ MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+ DType* value_ptr = inputs[0].data().dptr<DType>() + 2*head_dim;
+ cached_qkv_mem_->set_data_handle(value_ptr);
+ });
+ MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+ DType* out_ptr = outputs[0].data().dptr<DType>();
+ cached_out_mem_->set_data_handle(out_ptr);
+ });
+
+ MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_);
+ MKLDNNStream::Get()->Submit();
+
+ if (param_.quantized && !param_.enable_float_output) {
+ float* output_min = outputs[1].data().dptr<float>();
+ float* output_max = outputs[2].data().dptr<float>();
+
+ *output_min = min_output_;
+ *output_max = max_output_;
+ }
+}
+
+NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt)
+.describe(R"code(_sg_mkldnn_selfatt_valatt)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+ auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ if (param.quantized) {
+ return 6;
+ } else {
+ return 2;
+ }
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+ auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ if (param.quantized && !param.enable_float_output) {
+ return 3;
+ } else {
+ return 1;
+ }
+})
+.set_attr_parser(ParamParser<MKLDNNSelfAttParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+ auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ std::vector<std::string> input_names {"queries_keys_values", "attention"};
+ if (param.quantized) {
+ input_names.emplace_back("min_qkv");
+ input_names.emplace_back("max_qkv");
+
+ input_names.emplace_back("min_attention");
+ input_names.emplace_back("max_attention");
+ }
+ return input_names;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+ auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+ std::vector<std::string> output_names {"output"};
+ if (param.quantized && !param.enable_float_output) {
+ output_names.emplace_back("min_output");
+ output_names.emplace_back("max_output");
+ }
+ return output_names;
+})
+.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNSelfAttShape<2>)
+.set_attr<nnvm::FInferType>("FInferType", SgMKLDNNSelfAttValAttInferType)
+.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNSelfAttStorageType<2>)
+.set_attr<FCreateOpState>("FCreateOpState", CreateMKLDNNSelfAttValAttState)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", MKLDNNSelfAttValAttForward)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
+ return QuantizeType::kMust;
+})
+.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNSelfAttValAttQuantizedOp)
+.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved")
+.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
+.add_arguments(MKLDNNSelfAttParam::__FIELDS__());
+
+} // namespace op
+} // namespace mxnet
+
+#endif
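The key trick in both kernels is that no explicit transpose or slice of the
interleaved QKV tensor is ever materialized: the strided oneDNN memory
descriptors read Q at offset 0, K at offset head_dim, and V at offset
2*head_dim straight out of the input buffer. A NumPy sketch of the equivalent
views and of what _sg_mkldnn_selfatt_qk computes (toy sizes, made-up data):

    import numpy as np

    seq_len, batch, heads, head_dim = 4, 2, 3, 5
    qkv = np.random.rand(seq_len, batch, heads * 3 * head_dim).astype(np.float32)

    # The last dimension interleaves [q, k, v] chunks of head_dim per head.
    view = qkv.reshape(seq_len, batch * heads, 3, head_dim)
    q = view[:, :, 0, :].transpose(1, 0, 2)  # (attn_batches, seq_len, head_dim)
    k = view[:, :, 1, :].transpose(1, 2, 0)  # (attn_batches, head_dim, seq_len)
    scores = np.matmul(q, k) / np.sqrt(head_dim)  # fp32 QK output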
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
new file mode 100644
index 0000000..adf6230
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include <string>
+#include <vector>
+#include "../../quantization/requantize-inl.h"
+#include "../common.h"
+#include "mkldnn_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgMKLDNNTransformerPostQuantizeSelector : public SubgraphSelector {
+ public:
+ /*! \brief pattern match status */
+ enum SelectStatus {
+ kFail = 0,
+ kStart,
+ kRequantize,
+ kSuccess,
+ };
+
+ private:
+ bool disable_all;
+ bool disable_float_output;
+ SelectStatus status;
+ std::vector<const nnvm::Node *> matched_list;
+
+ public:
+ explicit SgMKLDNNTransformerPostQuantizeSelector(const bool dis_all,
+ const bool dis_float_output)
+ : disable_all(dis_all),
+ disable_float_output(dis_float_output) {}
+
+ bool Select(const nnvm::Node &n) override {
+ if ((!disable_all) &&
+ (n.op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
+ n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) {
+ status = disable_all ? kSuccess : kStart;
+ matched_list.clear();
+ matched_list.push_back(&n);
+ return true;
+ }
+ return false;
+ }
+
+ bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+ return false;
+ }
+
+ bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+ if (status == kFail || status == kSuccess || new_node.is_variable())
+ return false;
+ // If n isn't the last matched node, then we encountered an internal
+ // branch; we should pop out the nodes behind n and stop fusion.
+ if (matched_list.back() != &n) {
+ if (std::find(matched_list.begin(), matched_list.end(), &n) !=
+ matched_list.end()) {
+ while (matched_list.back() != &n) {
+ matched_list.pop_back();
+ }
+ }
+
+ status = kSuccess;
+ return false;
+ }
+
+ switch (status) {
+ case kStart:
+ if (new_node.op() == Op::Get("_contrib_requantize")) {
+ auto const &param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
+ if (param.min_calib_range.has_value() &&
+ param.max_calib_range.has_value()) {
+ matched_list.push_back(&new_node);
+ status = kRequantize;
+ return true;
+ }
+ }
+ case kRequantize:
+ if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
+ matched_list.push_back(&new_node);
+ status = kSuccess;
+ return true;
+ }
+ default:
+ status = kSuccess;
+ return false;
+ }
+ }
+
+ std::vector<nnvm::Node *> Filter(
+ const std::vector<nnvm::Node *> &candidates) override {
+ if ((status != kSuccess) || (matched_list.size() <= 1)) {
+ return std::vector<nnvm::Node *>(0);
+ } else {
+ std::vector<nnvm::Node *> ret;
+ for (auto i : matched_list) {
+ auto non_const_i = const_cast<nnvm::Node *>(i);
+ if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
+ candidates.end()) {
+ ret.push_back(non_const_i);
+ }
+ }
+ return ret;
+ }
+ }
+
+ void Reset() override {
+ CHECK_GE(matched_list.size(), 1);
+ auto new_selector = SgMKLDNNTransformerPostQuantizeSelector(disable_all, disable_float_output);
+ new_selector.Select(*matched_list[0]);
+ *this = new_selector;
+ }
+};
+
+class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty {
+ public:
+ SgMKLDNNTransformerPostQuantizeProperty() {
+ disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QTRANSFORMER_FUSE_ALL", false);
+ disable_float_output = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QTRANSFORMER_FLOAT_OUTPUT", false);
+ }
+
+ static SubgraphPropertyPtr Create() {
+ static const std::string &name = "MKLDNN Transformer post-quantization optimization pass";
+ auto property = std::make_shared<SgMKLDNNTransformerPostQuantizeProperty>();
+ property->SetAttr<std::string>("property_name", name);
+ property->SetAttr<bool>("inference_only", true);
+ return property;
+ }
+
+ nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
+ const int subgraph_id = 0) const override {
+ nnvm::ObjectPtr interleaved_node = nullptr;
+ nnvm::ObjectPtr requantize_node = nullptr;
+ nnvm::ObjectPtr dequantize_node = nullptr;
+
+ DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) {
+ if (node->is_variable()) return;
+ if (node->op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
+ node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) {
+ interleaved_node = node;
+ } else if (node->op() == Op::Get("_contrib_requantize")) {
+ requantize_node = node;
+ } else if (node->op() == Op::Get("_contrib_dequantize")) {
+ dequantize_node = node;
+ }
+ });
+
+ CHECK_NOTNULL(interleaved_node);
+ CHECK_NOTNULL(requantize_node);
+ auto const &requantize_param =
+ nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
+ CHECK(requantize_param.min_calib_range.has_value());
+ CHECK(requantize_param.max_calib_range.has_value());
+
+ // When only fusing quantized_interleaved_matmul and requantize, set min/max_calib_range;
+ // when fusing quantized_interleaved_matmul + requantize + dequantize,
+ // set the enable_float_output flag to true.
+ if (dequantize_node != nullptr) {
+ interleaved_node->attrs.dict["enable_float_output"] = "True";
+ } else {
+ interleaved_node->attrs.dict["min_calib_range"] =
+ std::to_string(requantize_param.min_calib_range.value());
+ interleaved_node->attrs.dict["max_calib_range"] =
+ std::to_string(requantize_param.max_calib_range.value());
+ }
+ interleaved_node->op()->attr_parser(&(interleaved_node->attrs));
+ return interleaved_node;
+ }
+
+ SubgraphSelectorPtr CreateSubgraphSelector() const override {
+ auto selector =
+ std::make_shared<SgMKLDNNTransformerPostQuantizeSelector>(disable_fuse_all,
+ disable_float_output);
+ return selector;
+ }
+
+ private:
+ bool disable_fuse_all;
+ bool disable_float_output;
+};
+
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_MKLDNN == 1
+#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
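Both tuning knobs in this property are plain environment variables read at
construction time, so they have to be set before the MKLDNN_QUANTIZE pass
runs. A usage sketch (variable names come from the constructor above):

    import os

    # Keep the fused op and the requantize node separate:
    os.environ['MXNET_DISABLE_MKLDNN_QTRANSFORMER_FUSE_ALL'] = '1'
    # Never absorb a trailing dequantize (no float32 output mode):
    os.environ['MXNET_DISABLE_MKLDNN_QTRANSFORMER_FLOAT_OUTPUT'] = '1'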
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
new file mode 100644
index 0000000..f022bcc
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include <map>
+#include <string>
+#include <vector>
+#include "../common.h"
+#include "../../tensor/matrix_op-inl.h"
+#include "../../contrib/transformer-inl.h"
+#include "mkldnn_transformer-inl.h"
+#include "mkldnn_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define SELFATT_QK "_contrib_interleaved_matmul_selfatt_qk"
+#define SELFATT_VALATT "_contrib_interleaved_matmul_selfatt_valatt"
+
+const std::map<std::string, std::string> OpMapping = {
+ {SELFATT_QK, "_sg_mkldnn_selfatt_qk"},
+ {SELFATT_VALATT, "_sg_mkldnn_selfatt_valatt"}
+};
+
+const std::map<std::string, std::string> NameMapping = {
+ {SELFATT_QK, "sg_mkldnn_selfatt_qk"},
+ {SELFATT_VALATT, "sg_mkldnn_selfatt_valatt"}
+};
+
+class SgMKLDNNTransformerSelector : public SubgraphSelector {
+ public:
+ bool Select(const nnvm::Node &n, const std::shared_ptr<NodeAttr>& node_attr) override {
+ if (n.op() == Op::Get(SELFATT_QK) ||
+ n.op() == Op::Get(SELFATT_VALATT)) {
+ return true;
+ }
+ return false;
+ }
+
+ bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+ return false;
+ }
+
+ bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+ return false;
+ }
+};
+
+class SgMKLDNNTransformerProperty : public SubgraphProperty {
+ public:
+ SgMKLDNNTransformerProperty() {}
+
+ static SubgraphPropertyPtr Create() {
+ static const std::string &name = "MKLDNN Transformer optimization pass";
+ auto property = std::make_shared<SgMKLDNNTransformerProperty>();
+ property->SetAttr<std::string>("property_name", name);
+ property->SetAttr<bool>("inference_only", true);
+ if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_TRANSFORMER_OPT", 0)) {
+ property->SetAttr<bool>("disable", true);
+ }
+ return property;
+ }
+
+ nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
+ const int subgraph_id = 0) const override {
+ nnvm::ObjectPtr n = nnvm::Node::Create();
+ // This op has a single output; remove duplicated entries.
+ auto last_node = sym.outputs[0].node;
+ nnvm::Symbol new_sym;
+ new_sym.outputs.emplace_back(last_node);
+ std::ostringstream node_name;
+ std::string op_name;
+ MKLDNNSelfAttParam new_param;
+ DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
+ if (node->op() &&
+ (node->op()->name == SELFATT_QK ||
+ node->op()->name == SELFATT_VALATT)) {
+ op_name = node->op()->name;
+ auto param = nnvm::get<InterleavedMatMulParam>(node->attrs.parsed);
+ new_param.heads = param.heads;
+ new_param.quantized = false;
+ new_param.enable_float_output = false;
+ }
+ });
+ node_name << NameMapping.at(op_name) << "_" << std::to_string(subgraph_id);
+
+
+ n->attrs.name = node_name.str();
+ n->attrs.op = Op::Get(OpMapping.at(op_name));
+ CHECK(n->attrs.op);
+ n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+ n->attrs.parsed = new_param;
+ return n;
+ }
+
+ SubgraphSelectorPtr CreateSubgraphSelector() const override {
+ auto selector = std::make_shared<SgMKLDNNTransformerSelector>();
+ return selector;
+ }
+
+ void ConnectSubgraphOutputs(
+ const nnvm::ObjectPtr n,
+ std::vector<nnvm::NodeEntry *> *output_entries) const override {
+ // Connect all external output entries to output[0]
+ for (size_t i = 0; i < output_entries->size(); ++i) {
+ auto entry_ptr = output_entries->at(i);
+ *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+ }
+ }
+};
+
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_MKLDNN == 1
+#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_
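The fp32 pass can likewise be switched off globally, which leaves the original
_contrib_interleaved_matmul_selfatt_* operators untouched in the graph (a
sketch; the variable name comes from Create() above):

    import os
    os.environ['MXNET_DISABLE_MKLDNN_TRANSFORMER_OPT'] = '1'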
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 65b73e4..79494a0 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -45,6 +45,14 @@ config = {
'fc': {
OP_NAME: 'sg_mkldnn_fully_connected',
QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected'
+ },
+ 'selfatt_qk': {
+ OP_NAME: 'sg_mkldnn_selfatt_qk',
+ QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_selfatt_qk'
+ },
+ 'selfatt_valatt': {
+ OP_NAME: 'sg_mkldnn_selfatt_valatt',
+ QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_selfatt_valatt'
}
}
@@ -52,6 +60,10 @@ DATA_SHAPE=[(64, 4, 10, 10), (4, 3, 24, 24), (1, 16, 32, 32)]
fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu',
'square', 'square_root', 'abs', 'exp', 'bounded_relu']
+quant_op_fp32_output_support = ("quantized_sg_mkldnn_fully_connected",
+ "quantized_sg_mkldnn_selfatt_qk",
+ "quantized_sg_mkldnn_selfatt_valatt")
+
def check_qsym_calibrated(qsym, out_type, name='conv'):
quantized_op_name = 'quantized_' + name
assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1
@@ -59,7 +71,8 @@ def check_qsym_calibrated(qsym, out_type, name='conv'):
if k.find('_quantize') != -1:
assert v['out_type'] == out_type
if k.find(quantized_op_name) != -1:
- if quantized_op_name.startswith("quantized_sg_mkldnn_fully_connected") and 'enable_float_output' in v:
+ if ('enable_float_output' in v
+ and quantized_op_name.startswith(quant_op_fp32_output_support)):
continue
assert 'min_calib_range' in v
assert 'max_calib_range' in v
@@ -119,9 +132,11 @@ def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape):
class CalibIter(mx.io.DataIter):
def __init__(self, batch, data_shape, batch_size):
super(CalibIter, self).__init__(batch_size)
- self.data_shape = data_shape
self.label_shape = (batch_size,)
- self.provide_data = [('data', self.data_shape)]
+ if isinstance(data_shape, tuple):
+ self.provide_data = [('data', data_shape)]
+ else:
+ self.provide_data = data_shape
self.provide_label = []
self.batch = batch
@@ -249,7 +264,6 @@ def check_fusion(sym, data_shape, attrs_dict, check_fp32_fusion=True, check_quan
if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1:
check_quantization = False
data_min = 0
-
sym_sg = sym.get_backend_symbol(SG_PASS_NAME)
for name, attrs in attrs_dict.items():
if name in config:
@@ -677,6 +691,13 @@ def fc_eltwise(no_bias, data_shape, flatten=True, alg='relu'):
return sym, attr
+def single_selfatt_qk(data_shape, nheads=16):
+ attr = {'selfatt_qk': {}}
+ data = mx.symbol.Variable('data', shape=data_shape, dtype='float32')
+ qk = mx.symbol.contrib.interleaved_matmul_selfatt_qk(queries_keys_values=data,
+ heads=nheads)
+ return qk, attr
+
# fc + relu can't be fusion case
# eg.1
# fc -----------> relu
@@ -866,6 +887,87 @@ def test_fc_eltwise():
check_fusion(syms, dshape, attrs, check_quantization=False)
@with_seed()
+def test_selfatt_qk():
+ batchsizes = [1, 8]
+ seq_lengths = [180, 384]
+ num_hidden = [1024, 3072]
+ num_heads = [8, 16]
+ for bs, seqlen, nhidden, nheads in itertools.product(batchsizes, seq_lengths, num_hidden, num_heads):
+ dshape = (seqlen, bs, nhidden)
+ syms, attrs = single_selfatt_qk(dshape, nheads)
+ check_fusion(syms, dshape, attrs, out_types=['int8', 'auto'], check_quantization=True)
+
+@with_seed()
+def test_selfatt_valatt():
+ batchsizes = [1, 8]
+ seq_lengths = [18, 255, 384]
+ num_hidden = [1024, 3072]
+ num_heads = [1, 16]
+
+ def get_valatt_symbol(qkv_shape, attention_shape, nheads):
+ qkv = mx.symbol.Variable('qkv', shape=qkv_shape, dtype='float32')
+ attention = mx.symbol.Variable('attention', shape=attention_shape, dtype='float32')
+ # CalibIter assumes that batch_size is always the first dimension;
+ # the following operators change the shapes to the expected layout
+ qkv_swap = mx.symbol.swapaxes(data=qkv, dim1=0, dim2=1)
+ attention_reshape = mx.symbol.reshape(data=attention, shape=(-1, 0, 0), reverse=True)
+ sym = mx.symbol.contrib.interleaved_matmul_selfatt_valatt(queries_keys_values=qkv_swap,
+ attention=attention_reshape,
+ heads=nheads)
+ return sym
+
+ def check_valatt_quantize(sym, qkv_shape, att_shape):
+ qkv_nd = mx.nd.random.uniform(low=-1, high=1, shape=qkv_shape)
+ weight_nd = mx.nd.random.uniform(low=0, high=1, shape=att_shape)
+ arg_params = {
+ 'qkv': qkv_nd,
+ 'attention': weight_nd
+ }
+
+ ex = sym.bind(mx.cpu(), arg_params, args_grad=None)
+ ex.forward()
+ ref_out = ex.outputs
+
+ sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
+
+ batch = mx.io.DataBatch([qkv_nd, weight_nd], [])
+ calib_data = CalibIter(batch, [('qkv', qkv_shape), ('attention', att_shape)], bs)
+ qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg,
+ arg_params=arg_params,
+ aux_params={},
+ ctx=mx.cpu(),
+ excluded_sym_names=None,
+ excluded_op_names=None,
+ quantize_granularity='tensor-wise',
+ quantized_dtype='auto',
+ calib_mode='naive',
+ calib_data=calib_data,
+ data_names=('qkv', 'attention'),
+ label_names=None,
+ num_calib_examples=1,
+ quantize_mode='full')
+ qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
+
+ qex = qsym.bind(mx.cpu(), arg_params, args_grad=None)
+ qex.forward()
+ quantized_out = qex.outputs
+
+ for i in range(len(ref_out)):
+ min_range = mx.nd.min(ref_out[i]).asscalar()
+ max_range = mx.nd.max(ref_out[i]).asscalar()
+ atol = 0.1 * max(abs(min_range), abs(max_range))
+ assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+
+ for bs, seqlen, nhidden, nheads in itertools.product(batchsizes, seq_lengths, num_hidden, num_heads):
+ qkv_shape = (bs, seqlen, 3*nhidden)
+ att_shape = (bs, nheads, seqlen, seqlen)
+
+ sym = get_valatt_symbol(qkv_shape, att_shape, nheads)
+ check_fusion(sym, None, {'selfatt_valatt': {}}, check_quantization=False)
+ check_valatt_quantize(sym, qkv_shape, att_shape)
+
+
+@with_seed()
def test_neg_fc_relu():
for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]):
syms, attrs, excluded_attrs = neg_fc_relu(no_bias, dshape, flatten)
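The new coverage can be run in isolation with, e.g.,
"pytest tests/python/mkl/test_subgraph.py -k selfatt" on an MKLDNN-enabled
build (invocation is illustrative; use the test runner of your branch's CI).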