Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/18 07:39:13 UTC
[GitHub] chinakook opened a new issue #10995: Some mxnet ctc_loss bug~
URL: https://github.com/apache/incubator-mxnet/issues/10995
MXNet's ctc_loss shares nearly all of its source code with Baidu's warp-ctc, with only minor modifications, but it has several bugs.
```python
import mxnet as mx
import numpy as np
import numpy.random as npr
```
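For reference, the loss both implementations compute is the standard CTC forward (alpha) recursion over a blank-extended label. A minimal NumPy sketch of that recursion (an illustration only, not the actual MXNet or warp-ctc code, which works in log space and is vectorized):

```python
import numpy as np

def ctc_forward_loss(probs, label, blank=0):
    """Negative log-likelihood of `label` under CTC.
    probs: (T, C) per-frame softmax outputs; label: class indices, no blanks."""
    T, C = probs.shape
    # Extend the label with blanks: blank, l1, blank, l2, ..., blank
    ext = np.full(2 * len(label) + 1, blank, dtype=np.int64)
    ext[1::2] = label
    S = len(ext)

    alpha = np.zeros((T, S))
    alpha[0, 0] = probs[0, blank]
    alpha[0, 1] = probs[0, ext[1]]
    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1, s]
            if s > 0:
                a += alpha[t - 1, s - 1]
            # Skipping a position is allowed unless it would merge two
            # identical labels or jump over a required blank.
            if s > 1 and ext[s] != blank and ext[s] != ext[s - 2]:
                a += alpha[t - 1, s - 2]
            alpha[t, s] = a * probs[t, ext[s]]
    # A valid path may end on the last label or on the trailing blank.
    return -np.log(alpha[T - 1, S - 1] + alpha[T - 1, S - 2])
```

For two frames of uniform probabilities over {blank, 1} and the label `[1]`, the valid alignments are (1,1), (blank,1), and (1,blank), so the loss is `-log(3 * 0.25)`.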
### Case 1 - mxnet ctc_loss works as expected with float labels
```python
batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 60
x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0)) # float label type
loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
loss = mx.nd.make_loss(loss)
print(loss.asnumpy())
```
### Case 2 - mxnet ctc_loss does not support integer label types
```python
batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 60
x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32) # int32 labels trigger the failure
loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
loss = mx.nd.make_loss(loss)
print(loss.asnumpy())
```
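Since Case 1 shows ctc_loss accepting float labels, one possible workaround (a suggestion, not a fix for the underlying bug) is to cast integer labels to float32 before wrapping them in an NDArray. This is lossless as long as label values stay below 2**24, which class indices always do in practice:

```python
import numpy as np
import numpy.random as npr

y = npr.randint(0, 60, size=(1024, 10))   # integer labels (int64)
y_f = y.astype(np.float32)                # exact for values < 2**24
assert (y_f.astype(np.int64) == y).all()  # round-trip is lossless
# Then build the label NDArray with the default float dtype, as in Case 1:
# Y = mx.nd.array(y_f, ctx=mx.gpu(0))
# loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
```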
### Case 3 - mxnet ctc_loss is slow or crashes when num_classes is large
```python
batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 6000
x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32)
loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
loss = mx.nd.make_loss(loss)
print(loss.asnumpy())
# For comparison, WarpCTC succeeds on the same data (see Case 4)
x = mx.nd.Reshape(x, shape=(-3, -2))  # merge seq_len and batch axes
Y = mx.nd.Reshape(Y, shape=(-1,))     # flatten labels
loss = mx.nd.WarpCTC(data=x, label=Y, label_length=label_len, input_length=seq_len)
print(loss)
```
### Case 4 - WarpCTC handles large num_classes and integer label types correctly
```python
batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 6000
x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32)
x = mx.nd.Reshape(x, shape=(-3, -2))  # merge seq_len and batch axes
Y = mx.nd.Reshape(Y, shape=(-1,))     # flatten labels
loss = mx.nd.WarpCTC(data=x, label=Y, label_length=label_len, input_length=seq_len)
print(loss)
```