Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/29 06:45:12 UTC

[GitHub] XiaotaoChen opened a new pull request #11083: fix bug that running dot(csr.T, dns)=dns would told: storage type fallback detected

URL: https://github.com/apache/incubator-mxnet/pull/11083
 
 
   Description
   
   This PR fixes a bug that prevents the dot(csr.T, dns)=dns implementation from being called on CPU and GPU. The cause is that DotForwardInferStorageType in dot-inl.h does not map the case of input stypes [csr, default], output stype [default] and transpose_a=True to the dot(csr.T, dns)=dns implementation, so the operator falls back to dense storage instead.
   
   Details
   
   The following code can reproduce this bug.
   
       import mxnet as mx
       from mxnet.test_utils import rand_ndarray
       
       def testBug(dev):
           shape_lhs = (200, 200)
           shape_rhs = (200, 200)
           mx_sparse = rand_ndarray(shape_lhs, 'csr', density=0.01).as_in_context(dev)
           mx_dns = rand_ndarray(shape_rhs, 'default', density=1.0).as_in_context(dev)
           mx.nd.dot(mx_sparse, mx_dns, transpose_a=True, transpose_b=False,
                     forward_stype='default')
           mx.nd.waitall()
           
       if __name__ == "__main__":
           print('test dot(csr.T, dns)=dns on cpu')
           testBug(mx.cpu())
           print('test dot(csr.T, dns)=dns on gpu')
           testBug(mx.gpu())
   
   The log below shows that dot(csr.T, dns)=dns hits a storage type fallback: the inputs are converted to temporary dense arrays, and what actually runs is dot(dns, dns)=dns.
   
       test dot(csr.T, dns)=dns on cpu
       [21:56:09] src/operator/nn/./../../common/utils.h:416: 
       Storage type fallback detected:
       operator = dot
       input storage types = [csr, default, ]
       output storage types = [default, ]
       params = {"forward_stype" : default, "transpose_b" : False, "transpose_a" : True, }
       context.dev_mask = cpu
       The operator with default storage type will be dispatched for execution. You're seeing this warning message because the operator above is unable to process the given ndarrays with specified storage types, context and parameter. Temporary dense ndarrays are generated in order to execute the operator. You can set environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to 0 to suppress this warning.
       test dot(csr.T, dns)=dns on gpu
       [21:57:01] src/operator/nn/./../../common/utils.h:416: 
       Storage type fallback detected:
       operator = dot
       input storage types = [csr, default, ]
       output storage types = [default, ]
       params = {"forward_stype" : default, "transpose_b" : False, "transpose_a" : True, }
       context.dev_mask = gpu
       The operator with default storage type will be dispatched for execution. You're seeing this warning message because the operator above is unable to process the given ndarrays with specified storage types, context and parameter. Temporary dense ndarrays are generated in order to execute the operator. You can set environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to 0 to suppress this warning.
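The fallback converts the CSR input into a temporary dense array, so every zero takes part in the multiplication. For contrast, the pure-Python sketch below shows the general technique a sparse-aware dot(csr.T, dns)=dns kernel can use: walk the standard CSR arrays (data, indices, indptr) of A and add each nonzero A[i, j] times row i of B into row j of the output, so A.T is never built. It is an illustration under these assumptions, not MXNet's actual kernel; the helper names `csr_t_dot_dense` and `dense_t_dot_dense` are hypothetical.

```python
def csr_t_dot_dense(data, indices, indptr, num_cols, dense):
    """Compute A.T @ B without forming A.T.

    A is an (m x k) CSR matrix given by data/indices/indptr, k == num_cols.
    B is a dense (m x n) list of lists. Returns a dense (k x n) list of lists.
    """
    m = len(indptr) - 1
    n = len(dense[0]) if dense else 0
    out = [[0.0] * n for _ in range(num_cols)]
    for i in range(m):
        b_row = dense[i]
        # Each stored A[i, j] contributes A[i, j] * B[i, :] to row j of A.T @ B.
        for p in range(indptr[i], indptr[i + 1]):
            j, a_ij = indices[p], data[p]
            out_row = out[j]
            for c in range(n):
                out_row[c] += a_ij * b_row[c]
    return out


def dense_t_dot_dense(a, b):
    """Reference: dense A.T @ B, touching every entry of A (like the fallback)."""
    m, k, n = len(a), len(a[0]), len(b[0])
    return [[sum(a[i][j] * b[i][c] for i in range(m)) for c in range(n)]
            for j in range(k)]


# A is 2x3 with nonzeros A[0,0]=1, A[0,2]=2, A[1,1]=3.
data, indices, indptr = [1.0, 2.0, 3.0], [0, 2, 1], [0, 2, 3]
a_dense = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]]
b = [[1.0, 2.0], [3.0, 4.0]]
result = csr_t_dot_dense(data, indices, indptr, 3, b)
print(result)  # [[1.0, 2.0], [9.0, 12.0], [2.0, 4.0]]
assert result == dense_t_dot_dense(a_dense, b)
```

The sparse path does nnz(A) * n multiply-adds (plus k * n writes to zero the output), while the dense path does m * k * n, which is why avoiding the fallback matters at low density such as the 0.01 used in the script above.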
   
   Fix the bug

   Based on the analysis above, the fallback happens because DotForwardInferStorageType does not map this case to dot(csr.T, dns)=dns.

   The relevant code is shown below. If target_stype equals kRowSparseStorage, dot is dispatched to dot(csr.T, dns)=row_sparse. When target_stype equals kDefaultStorage (which is what forward_stype='default' in the script above requests), this branch assigns nothing, so dot(csr.T, dns)=dns is never dispatched.
   
         if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) {
           // csr.T, rsp/dns -> rsp
           target_stype = hint_has_value ? target_stype : kRowSparseStorage;
           if (target_stype == kRowSparseStorage) {
             dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
                                              dispatch_mode, DispatchMode::kFComputeEx);
           }
         }
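To see why forward_stype='default' ends in a fallback, the branch above can be traced with a small Python model. This is a hypothetical simplification, not MXNet code: storage types are plain strings, `hint` stands in for forward_stype (None when it is not set), and returning None means "this branch did not dispatch".

```python
# Hypothetical model of the ORIGINAL "csr.T, rsp/dns" branch of
# DotForwardInferStorageType; a simplification, not MXNet code.
CSR, ROW_SPARSE, DEFAULT = "csr", "row_sparse", "default"


def infer_original(lhs_stype, rhs_stype, transpose_a, transpose_b, hint=None):
    """Return (out_stype, dispatch_mode) if this branch dispatches, else None."""
    only_lhs_transpose = transpose_a and not transpose_b
    rhs_rsp_or_dns = rhs_stype in (ROW_SPARSE, DEFAULT)
    if lhs_stype == CSR and only_lhs_transpose and rhs_rsp_or_dns:
        # csr.T, rsp/dns -> rsp
        target_stype = hint if hint is not None else ROW_SPARSE
        if target_stype == ROW_SPARSE:
            return (ROW_SPARSE, "FComputeEx")
        # A DEFAULT hint falls through: nothing handles it here.
    # Not dispatched; in the reported case the operator then falls back to dense.
    return None


print(infer_original(CSR, DEFAULT, True, False))                # ('row_sparse', 'FComputeEx')
print(infer_original(CSR, DEFAULT, True, False, hint=DEFAULT))  # None
```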
   
   The fix adds an else-if branch that maps this case to dot(csr.T, dns)=dns. The modified code is:
   
         if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) {
           // csr.T, rsp/dns -> rsp
           target_stype = hint_has_value ? target_stype : kRowSparseStorage;
           if (target_stype == kRowSparseStorage) {
             dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
                                              dispatch_mode, DispatchMode::kFComputeEx);
           } else if (target_stype == kDefaultStorage) {
             // csr.T, rsp/dns -> dns
             dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                              dispatch_mode, DispatchMode::kFComputeEx);
           }
         }
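The effect of the extra branch can be checked with the same kind of hypothetical Python model (a simplification of the C++ above, not MXNet code; `hint` again stands in for forward_stype and None means "not dispatched by this branch").

```python
# Hypothetical model of the FIXED "csr.T, rsp/dns" branch; not MXNet code.
CSR, ROW_SPARSE, DEFAULT = "csr", "row_sparse", "default"


def infer_fixed(lhs_stype, rhs_stype, transpose_a, transpose_b, hint=None):
    """Return (out_stype, dispatch_mode) if this branch dispatches, else None."""
    only_lhs_transpose = transpose_a and not transpose_b
    rhs_rsp_or_dns = rhs_stype in (ROW_SPARSE, DEFAULT)
    if lhs_stype == CSR and only_lhs_transpose and rhs_rsp_or_dns:
        target_stype = hint if hint is not None else ROW_SPARSE
        if target_stype == ROW_SPARSE:
            # csr.T, rsp/dns -> rsp
            return (ROW_SPARSE, "FComputeEx")
        elif target_stype == DEFAULT:
            # csr.T, rsp/dns -> dns (the newly added mapping)
            return (DEFAULT, "FComputeEx")
    return None


# lhs=csr, rhs=default, transpose_a=True, for each forward_stype hint.
for hint in (None, ROW_SPARSE, DEFAULT):
    print(hint, "->", infer_fixed(CSR, DEFAULT, True, False, hint=hint))
```

With the new branch, the forward_stype='default' case from the bug report gets an FComputeEx dispatch with a dense output instead of falling through to the dense fallback.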
   
   After this change, I ran the test script again and the storage type fallback warning no longer appears.
   @pengzhao-intel @TaoLv  
