You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/08/26 21:29:11 UTC

[incubator-mxnet] branch v1.x updated: Assign attributes of transformer operators (#20555)

This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 2b9607a  Assign attributes of transformer operators (#20555)
2b9607a is described below

commit 2b9607a9f5eb2ca72c05305c27cc772a4b00db80
Author: Paweł Głomski <pa...@intel.com>
AuthorDate: Thu Aug 26 23:27:32 2021 +0200

    Assign attributes of transformer operators (#20555)
---
 src/operator/subgraph/mkldnn/mkldnn_transformer_property.h | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
index 03a1d21..ee523a7 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -87,14 +87,13 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
     new_sym.outputs.emplace_back(last_node);
     std::ostringstream node_name;
     std::string op_name;
-    MKLDNNSelfAttParam new_param;
     DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr& node) {
       if (node->op() && (node->op()->name == SELFATT_QK || node->op()->name == SELFATT_VALATT)) {
         op_name                       = node->op()->name;
         auto param                    = nnvm::get<InterleavedMatMulParam>(node->attrs.parsed);
-        new_param.heads               = param.heads;
-        new_param.quantized           = false;
-        new_param.enable_float_output = false;
+        n->attrs.dict["heads"]               = std::to_string(param.heads);
+        n->attrs.dict["quantized"]           = "False";
+        n->attrs.dict["enable_float_output"] = "False";
       }
     });
     node_name << NameMapping.at(op_name) << "_" << std::to_string(subgraph_id);
@@ -103,7 +102,7 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
     n->attrs.op   = Op::Get(OpMapping.at(op_name));
     CHECK(n->attrs.op);
     n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
-    n->attrs.parsed = new_param;
+    n->op()->attr_parser(&(n->attrs));
     return n;
   }