You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/08/26 21:29:11 UTC
[incubator-mxnet] branch v1.x updated: Assign attributes of transformer operators (#20555)
This is an automated email from the ASF dual-hosted git repository.
akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 2b9607a Assign attributes of transformer operators (#20555)
2b9607a is described below
commit 2b9607a9f5eb2ca72c05305c27cc772a4b00db80
Author: Paweł Głomski <pa...@intel.com>
AuthorDate: Thu Aug 26 23:27:32 2021 +0200
Assign attributes of transformer operators (#20555)
---
src/operator/subgraph/mkldnn/mkldnn_transformer_property.h | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
index 03a1d21..ee523a7 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -87,14 +87,13 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
new_sym.outputs.emplace_back(last_node);
std::ostringstream node_name;
std::string op_name;
- MKLDNNSelfAttParam new_param;
DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr& node) {
if (node->op() && (node->op()->name == SELFATT_QK || node->op()->name == SELFATT_VALATT)) {
op_name = node->op()->name;
auto param = nnvm::get<InterleavedMatMulParam>(node->attrs.parsed);
- new_param.heads = param.heads;
- new_param.quantized = false;
- new_param.enable_float_output = false;
+ n->attrs.dict["heads"] = std::to_string(param.heads);
+ n->attrs.dict["quantized"] = "False";
+ n->attrs.dict["enable_float_output"] = "False";
}
});
node_name << NameMapping.at(op_name) << "_" << std::to_string(subgraph_id);
@@ -103,7 +102,7 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
n->attrs.op = Op::Get(OpMapping.at(op_name));
CHECK(n->attrs.op);
n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
- n->attrs.parsed = new_param;
+ n->op()->attr_parser(&(n->attrs));
return n;
}