You are viewing a plain-text version of this content. (The hyperlink to the canonical version is not preserved in this plain-text rendering.)
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/04 07:35:45 UTC
[incubator-mxnet] branch v1.9.x updated: Assign attributes of transformer operators (#20902)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.9.x by this push:
new edba375 Assign attributes of transformer operators (#20902)
edba375 is described below
commit edba3755e2d2b287e936886f43c63f02a80777a2
Author: Paweł Głomski <pa...@intel.com>
AuthorDate: Fri Mar 4 08:33:11 2022 +0100
Assign attributes of transformer operators (#20902)
---
src/operator/subgraph/mkldnn/mkldnn_transformer_property.h | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
index f022bcc..8228f4e 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -90,16 +90,15 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
new_sym.outputs.emplace_back(last_node);
std::ostringstream node_name;
std::string op_name;
- MKLDNNSelfAttParam new_param;
DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
if (node->op() &&
(node->op()->name == SELFATT_QK ||
node->op()->name == SELFATT_VALATT)) {
op_name = node->op()->name;
auto param = nnvm::get<InterleavedMatMulParam>(node->attrs.parsed);
- new_param.heads = param.heads;
- new_param.quantized = false;
- new_param.enable_float_output = false;
+ n->attrs.dict["heads"] = std::to_string(param.heads);
+ n->attrs.dict["quantized"] = "False";
+ n->attrs.dict["enable_float_output"] = "False";
}
});
node_name << NameMapping.at(op_name) << "_" << std::to_string(subgraph_id);
@@ -109,7 +108,7 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
n->attrs.op = Op::Get(OpMapping.at(op_name));
CHECK(n->attrs.op);
n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
- n->attrs.parsed = new_param;
+ n->op()->attr_parser(&(n->attrs));
return n;
}