Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/01/09 10:39:36 UTC
[GitHub] [incubator-mxnet] Wallart opened a new issue #17256: Sparse compression causes errors
URL: https://github.com/apache/incubator-mxnet/issues/17256
Hello everyone,
I am trying to use sparse tensors to save memory in my Transformer architecture, and I am applying F.sparse.cast_storage to the attention weights tensor.
```
import math
from mxnet import gluon


class ScaledDotProductAttn(gluon.HybridBlock):
    def __init__(self, dim_k, *args, **kwargs):
        super(ScaledDotProductAttn, self).__init__(*args, **kwargs)
        self._dim_k = dim_k

    def hybrid_forward(self, F, *args, **kwargs):
        query, key, value, mask, sparse_pattern = args
        matmul_qk = F.linalg.gemm2(query, key, transpose_b=True)  # (..., seq_len_q, seq_len_k)
        scaled_attn_logits = matmul_qk / math.sqrt(self._dim_k)
        if mask is not None:
            scaled_attn_logits = F.broadcast_add(scaled_attn_logits, mask * -1e9)
        attn_weights = F.softmax(scaled_attn_logits)  # (..., seq_len_q, seq_len_k)
        if sparse_pattern is not None:
            attn_weights = F.sparse.cast_storage(attn_weights * sparse_pattern, 'csr')
        output = F.linalg.gemm2(attn_weights, value)
        return output, attn_weights
```
As you can see, the sparse NDArray is densified on the fly to produce output (because value is not sparse). The block therefore returns a dense output and a sparse attn_weights: output is eventually used to compute the loss, and attn_weights is kept for plotting when needed.
The error occurs when I update the loss metric, which calls asnumpy internally; since MXNet executes operators asynchronously, that synchronization point is where the failure from cast_storage actually surfaces:
```
Traceback (most recent call last):
File "/home/wallart/workspaces/Transformer/trainer/transformer_trainer.py", line 77, in train
self._loss_metric.update(0, [l * self._opts.batch_size for l in losses])
File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/metric.py", line 1687, in update
loss = ndarray.sum(pred).asscalar()
File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/ndarray/ndarray.py", line 2553, in asscalar
return self.asnumpy()[0]
File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/ndarray/ndarray.py", line 2535, in asnumpy
ctypes.c_size_t(data.size)))
File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/base.py", line 255, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [10:01:05] src/operator/tensor/././cast_storage-inl.cuh:470: Check failed: dns.shape_.ndim() == 2 (4 vs. 2)
```
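For what it's worth, the failure should be reproducible with a minimal standalone driver along these lines (all shapes here are hypothetical, picked only to exercise the sparse branch; the point is that multi-head attention weights come out 4-D, and the error surfaces at the asnumpy call because of MXNet's asynchronous execution):
```
import mxnet as mx

# Hypothetical dimensions: batch=2, heads=8, seq_len=10, dim_k=64.
attn = ScaledDotProductAttn(dim_k=64)
attn.initialize()

q = mx.nd.random.uniform(shape=(2, 8, 10, 64))
k = mx.nd.random.uniform(shape=(2, 8, 10, 64))
v = mx.nd.random.uniform(shape=(2, 8, 10, 64))
mask = mx.nd.zeros((2, 8, 10, 10))    # nothing masked
pattern = mx.nd.ones((2, 8, 10, 10))  # keep every weight

# attn_weights is (2, 8, 10, 10), i.e. 4-D, which is what the CSR cast
# inside hybrid_forward ends up rejecting.
out, weights = attn(q, k, v, mask, pattern)
out.asnumpy()  # forces execution; the MXNetError should surface here
```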
The issue occurs on both MXNet 1.5.1 and 1.6.0rc0.
Everything works if I disable the F.sparse.cast_storage call.
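Given the failing check (dns.shape_.ndim() == 2), it looks like the CSR cast only accepts 2-D input, while multi-head attention weights are 4-D (batch, heads, seq_len_q, seq_len_k). A possible workaround, sketched here with plain mx.nd calls and hypothetical shapes, is to collapse the leading axes into one before casting:
```
import mxnet as mx

# Hypothetical 4-D attention weights: (batch=2, heads=8, seq_q=10, seq_k=10).
attn = mx.nd.softmax(mx.nd.random.uniform(shape=(2, 8, 10, 10)), axis=-1)

# Collapse the leading axes so cast_storage sees a 2-D matrix,
# which is what the CSR format (and the failing check) expects.
flat = attn.reshape((-1, attn.shape[-1]))             # (160, 10)
csr_weights = mx.nd.sparse.cast_storage(flat, 'csr')  # succeeds: input is 2-D

# Densify and restore the original shape when the 4-D layout is needed again.
restored = csr_weights.tostype('default').reshape(attn.shape)
```
Inside a HybridBlock the same idea would need F.reshape with a precomputed shape, since symbolic tensors do not expose .shape.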