Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/05/19 06:59:26 UTC
[incubator-mxnet] branch master updated: fix transformer optimization for gpt-2 (#21007)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0497efd4e4 fix transformer optimization for gpt-2 (#21007)
0497efd4e4 is described below
commit 0497efd4e4fcd18ed1cc459218daf6e67b38d536
Author: bgawrych <ba...@intel.com>
AuthorDate: Thu May 19 08:59:07 2022 +0200
fix transformer optimization for gpt-2 (#21007)
---
.../subgraph/dnnl/dnnl_transformer_qk_property.h | 78 ++++++++++++++--------
.../python/dnnl/subgraphs/test_matmul_subgraph.py | 64 ++++++++++++++----
2 files changed, 101 insertions(+), 41 deletions(-)
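For context, here is a minimal sketch of the query-key (QK) subgraph this commit's selector targets, mirroring MultiHeadAttention.forward from the test file in this diff; the variable names and shapes are illustrative only, not part of the commit:

import math
import mxnet as mx

mx.npx.set_np()  # NumPy semantics, as in the tests

batch, seq, units, num_heads = 2, 8, 32, 4
out = mx.np.random.uniform(size=(batch, seq, 3 * units))  # joint QKV projection
# _split_v2: 3 sections on the last axis
query, key, value = mx.np.split(out, 3, axis=-1)
# _npx_reshape: (batch, seq, heads, head_dim) for query and key
query = mx.npx.reshape(query, (-2, -2, num_heads, -1))
key = mx.npx.reshape(key, (-2, -2, num_heads, -1))
# SwapAxis: (batch, heads, seq, head_dim)
query = mx.np.swapaxes(query, 1, 2)
key = mx.np.swapaxes(key, 1, 2)
# batch_dot is the seed node the selector starts matching from
scores = mx.npx.batch_dot(query, key, transpose_b=True) / math.sqrt(units // num_heads)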
diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
index c117cf67fe..fc14df37e2 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
@@ -51,7 +51,7 @@
namespace mxnet {
namespace op {
-class SgDNNLTransformerQKSelector : public SubgraphSelector {
+class SgDNNLTransformerQKSelector : public SubgraphSelectorV2 {
enum SelectStatus {
kFail = 0,
kStart,
@@ -69,59 +69,83 @@ class SgDNNLTransformerQKSelector : public SubgraphSelector {
private:
SelectStatus status_;
- std::vector<const nnvm::Node*> matched_list_;
+ std::vector<const BiDirectedNode*> matched_list_;
+
+ bool CheckSplitConditions(const BiDirectedNode& node) {
+ const SplitParam& param = dmlc::get<SplitParam>(node.node->attrs.parsed);
+
+ if (param.axis != -1 || param.sections != 3 || param.squeeze_axis)
+ return false;
+
+ const auto first_reshape = (*(matched_list_.end() - 2))->node;
+ const auto second_reshape = (*(matched_list_.end() - 1))->node;
+ if (first_reshape->op() != Op::Get("_npx_reshape") ||
+ second_reshape->op() != Op::Get("_npx_reshape")) {
+ return false;
+ }
+ // 3 sections - ensure that every output is used only once
+ if (node.outputs.size() == 3 && node.outputs.count(first_reshape) &&
+ node.outputs.count(second_reshape)) {
+ return true;
+ }
+
+ return false;
+ }
public:
- bool Select(const nnvm::Node& n, const std::shared_ptr<NodeAttr>& node_attr) override {
- if (n.op() == Op::Get("batch_dot")) {
+ bool Select(const BiDirectedNode& seed_node,
+ const std::shared_ptr<NodeAttr>& node_attr) override {
+ if (seed_node.node->op() == Op::Get("batch_dot")) {
status_ = kStart;
matched_list_.clear();
- matched_list_.push_back(&n);
+ matched_list_.push_back(&seed_node);
return true;
}
return false;
}
- bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
- if (status_ == kFail || status_ == kSuccess || new_node.is_variable())
+ bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
+ if (status_ == kFail || status_ == kSuccess || input_node.node->is_variable())
return false;
-
+ const auto& raw_input_node = *input_node.node;
switch (status_) {
case kStart:
- if (new_node.op() == Op::Get("SwapAxis")) {
- if (CheckSwapAxisConditions(new_node)) {
+ if (raw_input_node.op() == Op::Get("SwapAxis")) {
+ if (CheckSwapAxisConditions(raw_input_node)) {
status_ = kFirstSwapAx;
- matched_list_.push_back(&new_node);
+ matched_list_.push_back(&input_node);
}
return true;
}
case kFirstSwapAx:
- if (new_node.op() == Op::Get("SwapAxis")) {
- if (CheckSwapAxisConditions(new_node)) {
+ if (raw_input_node.op() == Op::Get("SwapAxis")) {
+ if (CheckSwapAxisConditions(raw_input_node)) {
status_ = kSecondSwapAx;
- matched_list_.push_back(&new_node);
+ matched_list_.push_back(&input_node);
return true;
}
}
case kSecondSwapAx:
- if (new_node.op() == Op::Get("_npx_reshape")) {
+ if (raw_input_node.op() == Op::Get("_npx_reshape")) {
// input to reshape must be first or second output from split
- if (CheckReshapeConditions(new_node, 0) || CheckReshapeConditions(new_node, 1)) {
+ if (CheckReshapeConditions(raw_input_node, 0) ||
+ CheckReshapeConditions(raw_input_node, 1)) {
status_ = kFirstReshape;
- matched_list_.push_back(&new_node);
+ matched_list_.push_back(&input_node);
return true;
}
}
case kFirstReshape:
- if (new_node.op() == Op::Get("_npx_reshape")) {
- if (CheckReshapeConditions(new_node, 0) || CheckReshapeConditions(new_node, 1)) {
+ if (raw_input_node.op() == Op::Get("_npx_reshape")) {
+ if (CheckReshapeConditions(raw_input_node, 0) ||
+ CheckReshapeConditions(raw_input_node, 1)) {
status_ = kSecondReshape;
- matched_list_.push_back(&new_node);
+ matched_list_.push_back(&input_node);
return true;
}
}
case kSecondReshape:
- if (new_node.op() == Op::Get("_split_v2")) {
+ if (raw_input_node.op() == Op::Get("_split_v2") && CheckSplitConditions(input_node)) {
status_ = kSuccess;
return true;
}
@@ -132,17 +156,17 @@ class SgDNNLTransformerQKSelector : public SubgraphSelector {
return false;
}
- bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+ bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
return false;
}
- std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) override {
+ std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
if (status_ != kSuccess) {
- return std::vector<nnvm::Node*>(0);
+ return std::vector<BiDirectedNode*>(0);
} else {
- std::vector<nnvm::Node*> ret;
+ std::vector<BiDirectedNode*> ret;
for (auto i : matched_list_) {
- auto non_const_i = const_cast<nnvm::Node*>(i);
+ auto non_const_i = const_cast<BiDirectedNode*>(i);
if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) {
ret.push_back(non_const_i);
}
@@ -201,7 +225,7 @@ class SgDNNLTransformerQKProperty : public SubgraphProperty {
return n;
}
- SubgraphSelectorPtr CreateSubgraphSelector() const override {
+ SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
auto selector = std::make_shared<SgDNNLTransformerQKSelector>();
return selector;
}
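The move from SubgraphSelector to SubgraphSelectorV2 matters because BiDirectedNode exposes a node's consumers: the new CheckSplitConditions can require that the split has exactly three outputs and that the two matched reshapes read from it directly. A GPT-2-style graph, where a split output passes through another op (or is reused elsewhere) before the reshape, is therefore rejected instead of being fused incorrectly. Continuing the sketch above (same out and num_heads; this mirrors the negative test added below):

# Variant that must NOT be fused: an extra elementwise op consumes the
# split's key output, so the split no longer feeds the matched reshapes alone
query, key, value = mx.np.split(out, 3, axis=-1)
key = key * 2  # extra consumer between _split_v2 and _npx_reshape
query = mx.npx.reshape(query, (-2, -2, num_heads, -1))
key = mx.npx.reshape(key, (-2, -2, num_heads, -1))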
diff --git a/tests/python/dnnl/subgraphs/test_matmul_subgraph.py b/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
index 5d95d25da4..a96c8c31ba 100644
--- a/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
@@ -27,16 +27,19 @@ import math
class MultiHeadAttention(nn.HybridBlock):
- def __init__(self, units, num_heads, dtype='float32', **kwargs):
+ def __init__(self, units, num_heads, dtype='float32', negative_case=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self._units = units
self._num_heads = num_heads
self._fc = nn.Dense(in_units=self._units, units=3*self._units, flatten=False, dtype=dtype)
self._scale = math.sqrt(self._units // self._num_heads)
+ self.negative_case = negative_case
def forward(self, x, mask):
out = self._fc(x)
query, key, value = mx.np.split(out, 3, axis=-1)
+ if self.negative_case:
+ key = key * 2
query = mx.npx.reshape(query, (-2, -2, self._num_heads, -1))
key = mx.npx.reshape(key, (-2, -2, self._num_heads, -1))
value = mx.npx.reshape(value, (-2, -2, self._num_heads, -1))
@@ -48,7 +51,6 @@ class MultiHeadAttention(nn.HybridBlock):
context_vec = mx.npx.batch_dot(attn_weights,
mx.np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
context_vec = mx.npx.reshape(context_vec, (-2, -2, -1))
-
return context_vec
@use_np
@@ -70,8 +72,7 @@ def test_self_attention(batch_size, seq_length, units, num_heads):
out = fused_net(in_data, mask)
mx.nd.waitall()
- for i in range(len(out)):
- assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())
+ assert_almost_equal(out.asnumpy(), ref_out.asnumpy())
calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=1)
qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
@@ -85,11 +86,47 @@ def test_self_attention(batch_size, seq_length, units, num_heads):
qout = qnet(in_data, mask)
mx.nd.waitall()
- for i in range(len(ref_out)):
- min_range = np.min(ref_out[i].asnumpy())
- max_range = np.max(ref_out[i].asnumpy())
- atol = 0.1 * max(abs(min_range), abs(max_range))
- assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+ min_range = np.min(ref_out.asnumpy())
+ max_range = np.max(ref_out.asnumpy())
+ atol = 0.1 * max(abs(min_range), abs(max_range))
+ assert_almost_equal_with_err(qout.asnumpy(), ref_out.asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+
+@use_np
+@pytest.mark.parametrize('batch_size', [1, 32])
+@pytest.mark.parametrize('seq_length', [124, 384])
+@pytest.mark.parametrize('units', [256, 768])
+@pytest.mark.parametrize('num_heads', [4, 8])
+def test_self_attention_negative(batch_size, seq_length, units, num_heads):
+ net = MultiHeadAttention(units, num_heads, negative_case=True)
+ in_data = mx.np.random.uniform(size=[batch_size, seq_length, units], dtype='float32')
+ mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, seq_length], dtype='int32')
+
+ net.initialize()
+ fused_net = net
+ net.hybridize()
+ ref_out = net(in_data, mask)
+
+ fused_net.optimize_for(in_data, mask, backend="ONEDNN")
+ out = fused_net(in_data, mask)
+ mx.nd.waitall()
+
+ assert_almost_equal(out.asnumpy(), ref_out.asnumpy())
+
+ calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=1)
+ qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
+ exclude_layers=None,
+ exclude_layers_match=None,
+ calib_data=calib_data,
+ calib_mode='naive',
+ num_calib_batches=1,
+ ctx=mx.cpu())
+
+ qout = qnet(in_data, mask)
+ mx.nd.waitall()
+ min_range = np.min(ref_out.asnumpy())
+ max_range = np.max(ref_out.asnumpy())
+ atol = 0.1 * max(abs(min_range), abs(max_range))
+ assert_almost_equal_with_err(qout.asnumpy(), ref_out.asnumpy(), rtol=0.1, atol=atol, etol=0.2)
@use_np
@pytest.mark.parametrize('batch_size', [1, 32])
@@ -133,8 +170,7 @@ def test_batch_dot(batch_size, seq_length, units, num_heads):
qout = qnet(lhs_data, rhs_data)
mx.nd.waitall()
- for i in range(len(ref_out)):
- min_range = np.min(ref_out[i].asnumpy())
- max_range = np.max(ref_out[i].asnumpy())
- atol = 0.1 * max(abs(min_range), abs(max_range))
- assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.1)
+ min_range = np.min(ref_out.asnumpy())
+ max_range = np.max(ref_out.asnumpy())
+ atol = 0.1 * max(abs(min_range), abs(max_range))
+ assert_almost_equal_with_err(qout.asnumpy(), ref_out.asnumpy(), rtol=0.1, atol=atol, etol=0.1)
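To exercise the new coverage locally, the tests can be run with pytest; a hypothetical invocation from the repository root:

import pytest
# runs only the negative-case test added by this commit
pytest.main(["tests/python/dnnl/subgraphs/test_matmul_subgraph.py",
             "-k", "test_self_attention_negative", "-q"])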