Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/05/19 06:59:26 UTC

[incubator-mxnet] branch master updated: fix transformer optimization for gpt-2 (#21007)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0497efd4e4 fix transformer optimization for gpt-2 (#21007)
0497efd4e4 is described below

commit 0497efd4e4fcd18ed1cc459218daf6e67b38d536
Author: bgawrych <ba...@intel.com>
AuthorDate: Thu May 19 08:59:07 2022 +0200

    fix transformer optimization for gpt-2 (#21007)
---
 .../subgraph/dnnl/dnnl_transformer_qk_property.h   | 78 ++++++++++++++--------
 .../python/dnnl/subgraphs/test_matmul_subgraph.py  | 64 ++++++++++++++----
 2 files changed, 101 insertions(+), 41 deletions(-)
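
For context, the selector touched by this patch matches the GPT-2 style query/key interaction, where one joint dense projection is split into query, key and value. A minimal sketch of that pattern in MXNet's numpy API, reconstructed from the operators the selector looks for (split -> reshape -> swapaxes -> batch_dot) and from the updated test below; the tensor names and sizes here are illustrative only, not part of the commit:

    import mxnet as mx

    units, num_heads = 256, 4
    # output of the joint QKV projection: (batch, seq, 3 * units)
    x = mx.np.random.uniform(size=(1, 32, 3 * units))
    # _split_v2 with 3 sections on the last axis
    query, key, value = mx.np.split(x, 3, axis=-1)
    # _npx_reshape per attention head
    query = mx.npx.reshape(query, (-2, -2, num_heads, -1))
    key   = mx.npx.reshape(key,   (-2, -2, num_heads, -1))
    # SwapAxis + batch_dot: the QK product the oneDNN pass fuses
    scores = mx.npx.batch_dot(mx.np.swapaxes(query, 1, 2),
                              mx.np.swapaxes(key, 1, 2),
                              transpose_b=True)

The fix moves the selector to SubgraphSelectorV2 so the new CheckSplitConditions can inspect the split node's outputs and fuse only when all three outputs feed the two expected _npx_reshape nodes; graphs that consume a split output elsewhere first, like the new negative test where the key is multiplied by 2 before the reshape, are left unfused.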

diff --git a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
index c117cf67fe..fc14df37e2 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_transformer_qk_property.h
@@ -51,7 +51,7 @@
 namespace mxnet {
 namespace op {
 
-class SgDNNLTransformerQKSelector : public SubgraphSelector {
+class SgDNNLTransformerQKSelector : public SubgraphSelectorV2 {
   enum SelectStatus {
     kFail = 0,
     kStart,
@@ -69,59 +69,83 @@ class SgDNNLTransformerQKSelector : public SubgraphSelector {
 
  private:
   SelectStatus status_;
-  std::vector<const nnvm::Node*> matched_list_;
+  std::vector<const BiDirectedNode*> matched_list_;
+
+  bool CheckSplitConditions(const BiDirectedNode& node) {
+    const SplitParam& param = dmlc::get<SplitParam>(node.node->attrs.parsed);
+
+    if (param.axis != -1 || param.sections != 3 || param.squeeze_axis)
+      return false;
+
+    const auto first_reshape  = (*(matched_list_.end() - 2))->node;
+    const auto second_reshape = (*(matched_list_.end() - 1))->node;
+    if (first_reshape->op() != Op::Get("_npx_reshape") ||
+        second_reshape->op() != Op::Get("_npx_reshape")) {
+      return false;
+    }
+    // 3 sections - ensure that every output is used only once
+    if (node.outputs.size() == 3 && node.outputs.count(first_reshape) &&
+        node.outputs.count(second_reshape)) {
+      return true;
+    }
+
+    return false;
+  }
 
  public:
-  bool Select(const nnvm::Node& n, const std::shared_ptr<NodeAttr>& node_attr) override {
-    if (n.op() == Op::Get("batch_dot")) {
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    if (seed_node.node->op() == Op::Get("batch_dot")) {
       status_ = kStart;
       matched_list_.clear();
-      matched_list_.push_back(&n);
+      matched_list_.push_back(&seed_node);
       return true;
     }
     return false;
   }
 
-  bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
-    if (status_ == kFail || status_ == kSuccess || new_node.is_variable())
+  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
+    if (status_ == kFail || status_ == kSuccess || input_node.node->is_variable())
       return false;
-
+    const auto& raw_input_node = *input_node.node;
     switch (status_) {
       case kStart:
-        if (new_node.op() == Op::Get("SwapAxis")) {
-          if (CheckSwapAxisConditions(new_node)) {
+        if (raw_input_node.op() == Op::Get("SwapAxis")) {
+          if (CheckSwapAxisConditions(raw_input_node)) {
             status_ = kFirstSwapAx;
-            matched_list_.push_back(&new_node);
+            matched_list_.push_back(&input_node);
           }
           return true;
         }
       case kFirstSwapAx:
-        if (new_node.op() == Op::Get("SwapAxis")) {
-          if (CheckSwapAxisConditions(new_node)) {
+        if (raw_input_node.op() == Op::Get("SwapAxis")) {
+          if (CheckSwapAxisConditions(raw_input_node)) {
             status_ = kSecondSwapAx;
-            matched_list_.push_back(&new_node);
+            matched_list_.push_back(&input_node);
             return true;
           }
         }
       case kSecondSwapAx:
-        if (new_node.op() == Op::Get("_npx_reshape")) {
+        if (raw_input_node.op() == Op::Get("_npx_reshape")) {
           // input to reshape must be first or second output from split
-          if (CheckReshapeConditions(new_node, 0) || CheckReshapeConditions(new_node, 1)) {
+          if (CheckReshapeConditions(raw_input_node, 0) ||
+              CheckReshapeConditions(raw_input_node, 1)) {
             status_ = kFirstReshape;
-            matched_list_.push_back(&new_node);
+            matched_list_.push_back(&input_node);
             return true;
           }
         }
       case kFirstReshape:
-        if (new_node.op() == Op::Get("_npx_reshape")) {
-          if (CheckReshapeConditions(new_node, 0) || CheckReshapeConditions(new_node, 1)) {
+        if (raw_input_node.op() == Op::Get("_npx_reshape")) {
+          if (CheckReshapeConditions(raw_input_node, 0) ||
+              CheckReshapeConditions(raw_input_node, 1)) {
             status_ = kSecondReshape;
-            matched_list_.push_back(&new_node);
+            matched_list_.push_back(&input_node);
             return true;
           }
         }
       case kSecondReshape:
-        if (new_node.op() == Op::Get("_split_v2")) {
+        if (raw_input_node.op() == Op::Get("_split_v2") && CheckSplitConditions(input_node)) {
           status_ = kSuccess;
           return true;
         }
@@ -132,17 +156,17 @@ class SgDNNLTransformerQKSelector : public SubgraphSelector {
     return false;
   }
 
-  bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
     return false;
   }
 
-  std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) override {
+  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
     if (status_ != kSuccess) {
-      return std::vector<nnvm::Node*>(0);
+      return std::vector<BiDirectedNode*>(0);
     } else {
-      std::vector<nnvm::Node*> ret;
+      std::vector<BiDirectedNode*> ret;
       for (auto i : matched_list_) {
-        auto non_const_i = const_cast<nnvm::Node*>(i);
+        auto non_const_i = const_cast<BiDirectedNode*>(i);
         if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) {
           ret.push_back(non_const_i);
         }
@@ -201,7 +225,7 @@ class SgDNNLTransformerQKProperty : public SubgraphProperty {
     return n;
   }
 
-  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+  SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
     auto selector = std::make_shared<SgDNNLTransformerQKSelector>();
     return selector;
   }
diff --git a/tests/python/dnnl/subgraphs/test_matmul_subgraph.py b/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
index 5d95d25da4..a96c8c31ba 100644
--- a/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_matmul_subgraph.py
@@ -27,16 +27,19 @@ import math
 
 
 class MultiHeadAttention(nn.HybridBlock):
-  def __init__(self, units, num_heads, dtype='float32', **kwargs):
+  def __init__(self, units, num_heads, dtype='float32', negative_case=False, **kwargs):
       super(MultiHeadAttention, self).__init__(**kwargs)
       self._units = units
       self._num_heads = num_heads
       self._fc = nn.Dense(in_units=self._units, units=3*self._units, flatten=False, dtype=dtype)
       self._scale = math.sqrt(self._units // self._num_heads)
+      self.negative_case = negative_case
 
   def forward(self, x, mask):
       out = self._fc(x)
       query, key, value = mx.np.split(out, 3, axis=-1)
+      if self.negative_case:
+        key = key * 2
       query = mx.npx.reshape(query, (-2, -2, self._num_heads, -1))
       key = mx.npx.reshape(key, (-2, -2, self._num_heads, -1))
       value = mx.npx.reshape(value, (-2, -2, self._num_heads, -1))
@@ -48,7 +51,6 @@ class MultiHeadAttention(nn.HybridBlock):
       context_vec = mx.npx.batch_dot(attn_weights,
                                      mx.np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
       context_vec = mx.npx.reshape(context_vec, (-2, -2, -1))
-      
       return context_vec
 
 @use_np
@@ -70,8 +72,7 @@ def test_self_attention(batch_size, seq_length, units, num_heads):
   out = fused_net(in_data, mask)
   mx.nd.waitall()
 
-  for i in range(len(out)):
-    assert_almost_equal(out[i].asnumpy(), ref_out[i].asnumpy())
+  assert_almost_equal(out.asnumpy(), ref_out.asnumpy())
 
   calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=1)
   qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
@@ -85,11 +86,47 @@ def test_self_attention(batch_size, seq_length, units, num_heads):
   qout = qnet(in_data, mask)
   mx.nd.waitall()
 
-  for i in range(len(ref_out)):
-      min_range = np.min(ref_out[i].asnumpy())
-      max_range = np.max(ref_out[i].asnumpy())
-      atol = 0.1 * max(abs(min_range), abs(max_range))
-      assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+  min_range = np.min(ref_out.asnumpy())
+  max_range = np.max(ref_out.asnumpy())
+  atol = 0.1 * max(abs(min_range), abs(max_range))
+  assert_almost_equal_with_err(qout.asnumpy(), ref_out.asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+
+@use_np
+@pytest.mark.parametrize('batch_size', [1, 32])
+@pytest.mark.parametrize('seq_length', [124, 384])
+@pytest.mark.parametrize('units', [256, 768])
+@pytest.mark.parametrize('num_heads', [4, 8])
+def test_self_attention_negative(batch_size, seq_length, units, num_heads):
+  net = MultiHeadAttention(units, num_heads, negative_case=True)
+  in_data = mx.np.random.uniform(size=[batch_size, seq_length, units], dtype='float32')
+  mask = mx.np.random.uniform(low=0, high=2, size=[batch_size, seq_length, seq_length], dtype='int32')
+
+  net.initialize()
+  fused_net = net
+  net.hybridize()
+  ref_out = net(in_data, mask)
+
+  fused_net.optimize_for(in_data, mask, backend="ONEDNN")
+  out = fused_net(in_data, mask)
+  mx.nd.waitall()
+
+  assert_almost_equal(out.asnumpy(), ref_out.asnumpy())
+
+  calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(in_data, mask), batch_size=1)
+  qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
+                                            exclude_layers=None,
+                                            exclude_layers_match=None,
+                                            calib_data=calib_data,
+                                            calib_mode='naive',
+                                            num_calib_batches=1,
+                                            ctx=mx.cpu())
+
+  qout = qnet(in_data, mask)
+  mx.nd.waitall()
+  min_range = np.min(ref_out.asnumpy())
+  max_range = np.max(ref_out.asnumpy())
+  atol = 0.1 * max(abs(min_range), abs(max_range))
+  assert_almost_equal_with_err(qout.asnumpy(), ref_out.asnumpy(), rtol=0.1, atol=atol, etol=0.2)
 
 @use_np
 @pytest.mark.parametrize('batch_size', [1, 32])
@@ -133,8 +170,7 @@ def test_batch_dot(batch_size, seq_length, units, num_heads):
   qout = qnet(lhs_data, rhs_data)
   mx.nd.waitall()
 
-  for i in range(len(ref_out)):
-      min_range = np.min(ref_out[i].asnumpy())
-      max_range = np.max(ref_out[i].asnumpy())
-      atol = 0.1 * max(abs(min_range), abs(max_range))
-      assert_almost_equal_with_err(qout[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.1)
+  min_range = np.min(ref_out.asnumpy())
+  max_range = np.max(ref_out.asnumpy())
+  atol = 0.1 * max(abs(min_range), abs(max_range))
+  assert_almost_equal_with_err(qout.asnumpy(), ref_out.asnumpy(), rtol=0.1, atol=atol, etol=0.1)