Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/29 06:45:12 UTC
[GitHub] XiaotaoChen opened a new pull request #11083: fix bug that running dot(csr.T, dns)=dns would told: storage type fallback detected
URL: https://github.com/apache/incubator-mxnet/pull/11083
Description
This PR fixes a bug that prevents the dedicated implementation of dot(csr.T, dns)=dns from being called on both CPU and GPU. The cause is that DotForwardInferStorageType in dot-inl.h does not map the case of input stypes [csr, default], output stype [default] and transpose_a=True to the dot(csr.T, dns)=dns implementation.
Details
The following code can reproduce this bug.
    import mxnet as mx
    from mxnet.test_utils import rand_ndarray


    def testBug(dev):
        shape_lhs = (200, 200)
        shape_rhs = (200, 200)
        # 200x200 CSR matrix with ~1% nonzeros, and a fully dense 200x200 matrix
        mx_sparse = rand_ndarray(shape_lhs, 'csr', density=0.01).as_in_context(dev)
        mx_dns = rand_ndarray(shape_rhs, 'default', density=1.0).as_in_context(dev)
        # dot(csr.T, dns) with a dense output: this call triggers the fallback warning
        mx.nd.dot(mx_sparse, mx_dns, transpose_a=True, transpose_b=False,
                  forward_stype='default')
        mx.nd.waitall()


    if __name__ == "__main__":
        print('test dot(csr.T, dns)=dns on cpu')
        testBug(mx.cpu())
        # the GPU run requires an MXNet build with CUDA support
        print('test dot(csr.T, dns)=dns on gpu')
        testBug(mx.gpu())
The log below shows that the storage types of dot(csr.T, dns)=dns fall back to default (dense) storage, so the code that actually runs is dot(dns, dns)=dns.
test dot(csr.T, dns)=dns on cpu
[21:56:09] src/operator/nn/./../../common/utils.h:416:
Storage type fallback detected:
operator = dot
input storage types = [csr, default, ]
output storage types = [default, ]
params = {"forward_stype" : default, "transpose_b" : False, "transpose_a" : True, }
context.dev_mask = cpu
The operator with default storage type will be dispatched for execution. You're seeing this warning message because the operator above is unable to process the given ndarrays with specified storage types, context and parameter. Temporary dense ndarrays are generated in order to execute the operator. You can set environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to 0 to suppress this warning.
test dot(csr.T, dns)=dns on gpu
[21:57:01] src/operator/nn/./../../common/utils.h:416:
Storage type fallback detected:
operator = dot
input storage types = [csr, default, ]
output storage types = [default, ]
params = {"forward_stype" : default, "transpose_b" : False, "transpose_a" : True, }
context.dev_mask = gpu
The operator with default storage type will be dispatched for execution. You're seeing this warning message because the operator above is unable to process the given ndarrays with specified storage types, context and parameter. Temporary dense ndarrays are generated in order to execute the operator. You can set environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to 0 to suppress this warning.
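To illustrate what the fallback costs, here is a minimal pure-Python sketch (not MXNet's kernel) of dot(csr.T, dns) computed directly from the CSR arrays data, indices and indptr, next to what the fallback effectively does: densify csr, transpose it, and run a dense matmul. The function names are hypothetical. The sparse version only visits the nonzeros of csr and never builds a dense copy of it.

```python
def csr_t_dot_dense(data, indices, indptr, num_cols, dns):
    """Return csr.T @ dns as a list of lists, reading csr only through its CSR arrays.

    csr has len(indptr) - 1 rows and num_cols columns; dns has as many rows as csr.
    """
    num_rows = len(indptr) - 1
    k = len(dns[0])
    out = [[0.0] * k for _ in range(num_cols)]
    for i in range(num_rows):
        for p in range(indptr[i], indptr[i + 1]):
            j, v = indices[p], data[p]
            # csr[i, j] == v, so csr.T[j, i] == v contributes v * dns[i, :] to out[j, :]
            for c in range(k):
                out[j][c] += v * dns[i][c]
    return out


def dense_fallback(data, indices, indptr, num_cols, dns):
    """Mimic the fallback path: build a dense copy of csr, transpose it, dense matmul."""
    num_rows = len(indptr) - 1
    dense = [[0.0] * num_cols for _ in range(num_rows)]
    for i in range(num_rows):
        for p in range(indptr[i], indptr[i + 1]):
            dense[i][indices[p]] = data[p]
    lhs_t = [list(col) for col in zip(*dense)]  # num_cols x num_rows temporary
    k = len(dns[0])
    return [[sum(lhs_t[r][i] * dns[i][c] for i in range(num_rows)) for c in range(k)]
            for r in range(num_cols)]


if __name__ == "__main__":
    # csr = [[1, 0, 2], [0, 3, 0]] in CSR form
    data, indices, indptr = [1.0, 2.0, 3.0], [0, 2, 1], [0, 2, 3]
    dns = [[1.0, 2.0], [3.0, 4.0]]
    print(csr_t_dot_dense(data, indices, indptr, 3, dns))
    # [[1.0, 2.0], [9.0, 12.0], [2.0, 4.0]]
```

Both functions return the same values; the difference is that the fallback allocates and scans a full dense temporary, which is what the warning above refers to.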
Fix the bug
As the analysis above shows, the fallback happens because DotForwardInferStorageType does not map this case to dot(csr.T, dns)=dns.
The relevant code is below. When target_stype is kRowSparseStorage, dot is dispatched to dot(csr.T, dns)=row_sparse. When target_stype is kDefaultStorage, however, this branch dispatches nothing, so dot(csr.T, dns)=dns is never selected.
    if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) {
      // csr.T, rsp/dns -> rsp
      target_stype = hint_has_value ? target_stype : kRowSparseStorage;
      if (target_stype == kRowSparseStorage) {
        dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
                                         dispatch_mode, DispatchMode::kFComputeEx);
      }
    }
The fix adds a branch that maps this case to dot(csr.T, dns)=dns. The modified code is:
    if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) {
      // csr.T, rsp/dns -> rsp
      target_stype = hint_has_value ? target_stype : kRowSparseStorage;
      if (target_stype == kRowSparseStorage) {
        dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
                                         dispatch_mode, DispatchMode::kFComputeEx);
      } else if (target_stype == kDefaultStorage) {
        // csr.T, rsp/dns -> dns
        dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                         dispatch_mode, DispatchMode::kFComputeEx);
      }
    }
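The effect of the added branch can be checked with a small, hypothetical Python model of this one branch (not MXNet code). Here hint stands in for forward_stype, and with_fix toggles the new else-if. It assumes only_lhs_transpose means transpose_a and not transpose_b, and that rhs_rsp_or_dns accepts a row_sparse or default rhs; the rest of DotForwardInferStorageType is omitted.

```python
from typing import Optional, Tuple

# String stand-ins for kDefaultStorage, kRowSparseStorage and kCSRStorage
DEFAULT, ROW_SPARSE, CSR = "default", "row_sparse", "csr"


def infer_csr_t_branch(lhs_stype: str, rhs_stype: str, transpose_a: bool,
                       transpose_b: bool, hint: Optional[str],
                       with_fix: bool) -> Tuple[bool, Optional[str]]:
    """Return (dispatched, out_stype) for the csr.T branch only (hypothetical model)."""
    only_lhs_transpose = transpose_a and not transpose_b
    rhs_rsp_or_dns = rhs_stype in (ROW_SPARSE, DEFAULT)
    if lhs_stype == CSR and only_lhs_transpose and rhs_rsp_or_dns:
        # Without a hint (forward_stype), the target storage type defaults to row_sparse
        target = hint if hint is not None else ROW_SPARSE
        if target == ROW_SPARSE:
            return True, ROW_SPARSE  # csr.T, rsp/dns -> rsp
        if with_fix and target == DEFAULT:
            return True, DEFAULT  # csr.T, rsp/dns -> dns (the branch this PR adds)
    # Not dispatched here: other branches or, eventually, the dense fallback take over
    return False, None


if __name__ == "__main__":
    for fix in (False, True):
        print(fix, infer_csr_t_branch(CSR, DEFAULT, True, False, DEFAULT, fix))
    # False (False, None)
    # True (True, 'default')
```

With forward_stype='default' the unpatched branch returns "not dispatched", which matches the fallback warning in the log; with the new branch the same inputs resolve to a default-storage output.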
After this modification, running the test script no longer prints the storage type fallback warning.
@pengzhao-intel @TaoLv