Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/01/09 10:39:36 UTC

[GitHub] [incubator-mxnet] Wallart opened a new issue #17256: Sparse compression causes errors

Wallart opened a new issue #17256: Sparse compression causes errors 
URL: https://github.com/apache/incubator-mxnet/issues/17256
 
 
   Hello everyone,
   I am trying to use sparse tensors to save memory in my Transformer architecture, and I'm applying F.sparse.cast_storage to the attention weights tensor.
   
   ```
   import math

   from mxnet import gluon


   class ScaledDotProductAttn(gluon.HybridBlock):

       def __init__(self, dim_k, *args, **kwargs):
           super(ScaledDotProductAttn, self).__init__(*args, **kwargs)
           self._dim_k = dim_k

       def hybrid_forward(self, F, *args, **kwargs):
           query, key, value, mask, sparse_pattern = args

           matmul_qk = F.linalg.gemm2(query, key, transpose_b=True)  # seq_len_q, seq_len_k
           scaled_attn_logits = matmul_qk / math.sqrt(self._dim_k)

           if mask is not None:
               scaled_attn_logits = F.broadcast_add(scaled_attn_logits, mask * -1e9)

           attn_weights = F.softmax(scaled_attn_logits)  # seq_len_q, seq_len_k
           if sparse_pattern is not None:
               # zero out the entries not selected by the sparse pattern and store the result as CSR
               attn_weights = F.sparse.cast_storage(attn_weights * sparse_pattern, 'csr')

           output = F.linalg.gemm2(attn_weights, value)  # seq_len_q, dim_v
           return output, attn_weights
   ```
   
   As you can see, the sparse NDArray is densified on the fly to produce `output` (because `value` is not sparse). I then return a dense `output` and a sparse `attn_weights`: `output` is eventually used to compute the loss, while `attn_weights` is only kept for plotting when needed.
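   
   For context, here is roughly how I invoke the block; the batch size, head count, sequence length, and model dimension below are placeholder values, not my real hyper-parameters:
   
   ```
   import mxnet as mx
   
   batch, heads, seq_len, dim_k = 2, 8, 16, 64  # placeholder sizes
   ctx = mx.gpu()
   
   query = mx.nd.random.uniform(shape=(batch, heads, seq_len, dim_k), ctx=ctx)
   key = mx.nd.random.uniform(shape=(batch, heads, seq_len, dim_k), ctx=ctx)
   value = mx.nd.random.uniform(shape=(batch, heads, seq_len, dim_k), ctx=ctx)
   mask = mx.nd.zeros((batch, 1, 1, seq_len), ctx=ctx)
   sparse_pattern = mx.nd.random.uniform(shape=(batch, heads, seq_len, seq_len), ctx=ctx) > 0.5
   
   attn = ScaledDotProductAttn(dim_k)
   attn.initialize(ctx=ctx)
   
   # the forward call itself returns; the MXNetError only surfaces later,
   # when something calls asnumpy() on the outputs
   output, attn_weights = attn(query, key, value, mask, sparse_pattern)
   ```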
   
   The error occurs when I update the loss metric, which calls asnumpy internally.
   ```
   Traceback (most recent call last):
     File "/home/wallart/workspaces/Transformer/trainer/transformer_trainer.py", line 77, in train
       self._loss_metric.update(0, [l * self._opts.batch_size for l in losses])
     File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/metric.py", line 1687, in update
       loss = ndarray.sum(pred).asscalar()
     File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/ndarray/ndarray.py", line 2553, in asscalar
       return self.asnumpy()[0]
     File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/ndarray/ndarray.py", line 2535, in asnumpy
       ctypes.c_size_t(data.size)))
     File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/base.py", line 255, in check_call
       raise MXNetError(py_str(_LIB.MXGetLastError()))
   mxnet.base.MXNetError: [10:01:05] src/operator/tensor/././cast_storage-inl.cuh:470: Check failed: dns.shape_.ndim() == 2 (4 vs. 2)
   ```
   
   The issue occurs on both MXNet 1.5.1 and 1.6.0.rc0.
   Everything works if I disable the F.sparse.cast_storage call.
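   
   If I read the check correctly, cast_storage to 'csr' only accepts 2-D input (hence the `4 vs. 2` above), while my attention weights are 4-D (batch, heads, seq_len_q, seq_len_k). The snippet below is only a sketch of what I believe triggers the failure, and of how flattening to 2-D before the cast avoids it; the shapes are placeholders:
   
   ```
   import mxnet as mx
   
   ctx = mx.gpu()
   
   # 4-D tensor with the same layout as my attention weights (placeholder sizes)
   x = mx.nd.random.uniform(shape=(2, 8, 16, 16), ctx=ctx)
   
   # this should hit the same "dns.shape_.ndim() == 2 (4 vs. 2)" check once evaluated:
   # csr = mx.nd.sparse.cast_storage(x, 'csr')
   # csr.asnumpy()
   
   # flattening the leading dimensions keeps the cast 2-D and seems to be a viable workaround
   flat = x.reshape((-1, x.shape[-1]))  # (2 * 8 * 16, 16)
   csr = mx.nd.sparse.cast_storage(flat, 'csr')
   print(csr.asnumpy().shape)
   ```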
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services