You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/04 07:35:45 UTC

[incubator-mxnet] branch v1.9.x updated: Assign attributes of transformer operators (#20902)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.9.x by this push:
     new edba375  Assign attributes of transformer operators (#20902)
edba375 is described below

commit edba3755e2d2b287e936886f43c63f02a80777a2
Author: Paweł Głomski <pa...@intel.com>
AuthorDate: Fri Mar 4 08:33:11 2022 +0100

    Assign attributes of transformer operators (#20902)
---
 src/operator/subgraph/mkldnn/mkldnn_transformer_property.h | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
index f022bcc..8228f4e 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -90,16 +90,15 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
     new_sym.outputs.emplace_back(last_node);
     std::ostringstream node_name;
     std::string op_name;
-    MKLDNNSelfAttParam new_param;
     DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
       if (node->op() &&
           (node->op()->name == SELFATT_QK ||
            node->op()->name == SELFATT_VALATT)) {
         op_name = node->op()->name;
         auto param = nnvm::get<InterleavedMatMulParam>(node->attrs.parsed);
-        new_param.heads = param.heads;
-        new_param.quantized = false;
-        new_param.enable_float_output = false;
+        n->attrs.dict["heads"] = std::to_string(param.heads);
+        n->attrs.dict["quantized"] = "False";
+        n->attrs.dict["enable_float_output"] = "False";
       }
     });
     node_name << NameMapping.at(op_name) << "_" << std::to_string(subgraph_id);
@@ -109,7 +108,7 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
     n->attrs.op = Op::Get(OpMapping.at(op_name));
     CHECK(n->attrs.op);
     n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
-    n->attrs.parsed = new_param;
+    n->op()->attr_parser(&(n->attrs));
     return n;
   }