Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/18 07:39:13 UTC

[GitHub] chinakook opened a new issue #10995: Some mxnet ctc_loss bug~

URL: https://github.com/apache/incubator-mxnet/issues/10995
 
 
   MXNet's `ctc_loss` shares nearly all of its source code with Baidu's warp-ctc, with only small modifications, but it has some bugs. The four cases below reproduce them on a GPU (`mx.gpu(0)`).
   ```python
   import mxnet as mx
   import numpy as np
   import numpy.random as npr
   ```
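The cases below only check whether the calls succeed. To also check whether a returned loss value is correct, a small NumPy reference can help. The following is a minimal sketch of the standard CTC forward (alpha) recursion in log space for one sequence. It assumes, as the `ctc_loss` defaults appear to, that activations are unnormalized (softmax is applied inside the loss) and that index 0 is the blank. `ctc_nll` is a hypothetical helper name, not an MXNet API.

```python
import numpy as np

def log_softmax(x):
    """Numerically stable log-softmax over the last axis."""
    m = x.max(axis=-1, keepdims=True)
    return x - m - np.log(np.exp(x - m).sum(axis=-1, keepdims=True))

def ctc_nll(acts, label, blank=0):
    """CTC negative log-likelihood of `label` for ONE sequence (hypothetical helper).

    acts:  (T, C) unnormalized activations; softmax is applied here (assumption
           mirroring ctc_loss, which takes activations without softmax).
    label: 1-D sequence of class indices, none equal to `blank`.
    """
    logp = log_softmax(np.asarray(acts, dtype=np.float64))
    T = logp.shape[0]
    # Extended label: blank, l1, blank, l2, ..., lU, blank
    ext = [blank]
    for c in label:
        ext += [int(c), blank]
    S = len(ext)
    alpha = np.full((T, S), -np.inf)
    # A path may start on the leading blank or on the first label
    alpha[0, 0] = logp[0, ext[0]]
    if S > 1:
        alpha[0, 1] = logp[0, ext[1]]
    for t in range(1, T):
        for s in range(S):
            terms = [alpha[t - 1, s]]                # stay on the same symbol
            if s >= 1:
                terms.append(alpha[t - 1, s - 1])    # advance by one
            # Skip over a blank, allowed only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[t - 1, s - 2])
            alpha[t, s] = np.logaddexp.reduce(terms) + logp[t, ext[s]]
    # Valid paths end on the trailing blank or on the last label
    tail = [alpha[T - 1, S - 1]]
    if S > 1:
        tail.append(alpha[T - 1, S - 2])
    return -np.logaddexp.reduce(tail)
```

This runs in O(T·S) per sequence and is meant for spot checks, not speed. As an example, `ctc_nll(x[:, 0, :].asnumpy(), y[0])` could be compared, with a loose tolerance because of float32, against the first element of the Case 1 loss. This assumes `ctc_loss` returns one negative log-likelihood per sequence and that `y[0]` contains no 0.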
   
   ### Case 1 - MXNet ctc_loss works with float labels and small num_classes
   ```python
   batch_size = 1024
   seq_len = 35
   label_len = 10
   num_classes = 60
   
   x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
   y = npr.randint(0, num_classes, size=(batch_size, label_len))
   Y = mx.nd.array(y, ctx=mx.gpu(0)) # float label type
   
   loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
   loss = mx.nd.make_loss(loss)
   print(loss.asnumpy())
   
   ```
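A side note on the label generation used in these cases: `npr.randint(0, num_classes, ...)` can produce 0. As the `ctc_loss` documentation describes the default `blank_label='first'`, 0 is reserved for the blank and is also the padding value. A 0 inside a row is therefore likely treated as the end of that label sequence. The following is a minimal sketch of a generator that avoids this, assuming those semantics. `random_ctc_labels` is a hypothetical helper, not part of MXNet.

```python
import numpy as np

def random_ctc_labels(batch_size, max_len, num_classes, rng=None, pad_value=0):
    """Random padded label matrix for ctc_loss with blank_label='first' (assumed).

    Tokens are drawn from 1..num_classes-1 (0 assumed reserved for the blank),
    each row gets a random length in 1..max_len, and the remainder of the row
    is filled with `pad_value`.
    """
    rng = np.random.default_rng() if rng is None else rng
    labels = np.full((batch_size, max_len), pad_value, dtype=np.float32)
    lengths = rng.integers(1, max_len + 1, size=batch_size)
    for i, n in enumerate(lengths):
        labels[i, :n] = rng.integers(1, num_classes, size=n)  # high is exclusive
    return labels, lengths
```

The result could then be passed as `Y = mx.nd.array(labels, ctx=mx.gpu(0))` in place of the `Y` above.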
   
   ### Case 2 - MXNet ctc_loss does not support integer label types
   ```python
   batch_size = 1024
   seq_len = 35
   label_len = 10
   num_classes = 60
   
   x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
   y = npr.randint(0, num_classes, size=(batch_size, label_len))
   Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32) # int32 label type
   
   loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
   loss = mx.nd.make_loss(loss)
   print(loss.asnumpy())
   ```
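Case 1 shows that float labels work. A possible workaround for integer labels, not verified against this MXNet build, is therefore to cast them before the call, e.g. `Y.astype('float32')` (`NDArray.astype`). The cast is lossless only while every label fits float32's 24-bit significand. The following NumPy sketch makes that check explicit. `int_labels_to_float32` is a hypothetical helper name.

```python
import numpy as np

def int_labels_to_float32(y):
    """Cast integer class labels to float32, refusing casts that would lose precision.

    float32 has a 24-bit significand, so every integer with |value| <= 2**24
    round-trips exactly; num_classes = 6000 is far below that limit.
    """
    y = np.asarray(y)
    if not np.issubdtype(y.dtype, np.integer):
        raise TypeError("expected an integer label array")
    if y.size and np.abs(y).max() > 2**24:
        raise ValueError("label values too large to represent exactly in float32")
    out = y.astype(np.float32)
    # Sanity check: converting back must reproduce the original labels exactly
    if not np.array_equal(out.astype(y.dtype), y):
        raise ValueError("float32 cast was not lossless")
    return out
```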
   
   ### Case 3 - MXNet ctc_loss is slow or crashes when num_classes is large
   ```python
   batch_size = 1024
   seq_len = 35
   label_len = 10
   num_classes = 6000
   
   x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
   y = npr.randint(0, num_classes, size=(batch_size, label_len))
   Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32) # int32 label type
   
   loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
   loss = mx.nd.make_loss(loss)
   print(loss.asnumpy())
   
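   x = mx.nd.Reshape(x, shape=(-3, -2)) # (seq_len, batch_size, num_classes) -> (seq_len*batch_size, num_classes)
   Y = mx.nd.Reshape(Y, shape=(-1,)) # (batch_size, label_len) -> (batch_size*label_len,)
   # mx.nd.WarpCTC is only available when MXNet is built with the warpctc plugin
   loss = mx.nd.WarpCTC(data=x, label=Y, label_length=label_len, input_length=seq_len)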
   print(loss)
   ```
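For scale, a back-of-the-envelope size of the dense float32 activation tensor `x` in these cases is shown below. Whether `ctc_loss` also allocates gradient or softmax buffers of the same shape is an assumption here, not something the cases above measure. If it does, each such buffer adds the same amount again.

```python
def activation_bytes(seq_len, batch_size, num_classes, bytes_per_elem=4):
    """Size in bytes of one dense (seq_len, batch_size, num_classes) tensor (float32 by default)."""
    return seq_len * batch_size * num_classes * bytes_per_elem

small = activation_bytes(35, 1024, 60)    # Cases 1 and 2
large = activation_bytes(35, 1024, 6000)  # Cases 3 and 4
print(f"num_classes=60:   {small / 1e6:.1f} MB")
print(f"num_classes=6000: {large / 1e6:.1f} MB")
```

This prints 8.6 MB and 860.2 MB. The input alone is 100 times larger in Case 3, so a per-class cost anywhere in the kernel scales by the same factor.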
   
   ### Case 4 - WarpCTC works with large num_classes and integer label types
   ```python
   batch_size = 1024
   seq_len = 35
   label_len = 10
   num_classes = 6000
   
   x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
   y = npr.randint(0, num_classes, size=(batch_size, label_len))
   Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32) # int32 label type
   
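   x = mx.nd.Reshape(x, shape=(-3, -2)) # (seq_len, batch_size, num_classes) -> (seq_len*batch_size, num_classes)
   Y = mx.nd.Reshape(Y, shape=(-1,)) # (batch_size, label_len) -> (batch_size*label_len,)
   # mx.nd.WarpCTC is only available when MXNet is built with the warpctc plugin
   loss = mx.nd.WarpCTC(data=x, label=Y, label_length=label_len, input_length=seq_len)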
   print(loss)
   ```
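The two `Reshape` calls before `WarpCTC` only merge dimensions. `shape=(-3, -2)` merges the first two axes and keeps the rest, and `shape=(-1,)` flattens. The following NumPy sketch shows the resulting layout, assuming C (row-major) order as in NumPy. Row `t * batch_size + n` of the merged activations is time step `t` of sequence `n`, so the data stays time-major. `to_warpctc_layout` is a hypothetical helper name.

```python
import numpy as np

def to_warpctc_layout(x, y):
    """NumPy equivalent of the Reshape calls used before WarpCTC (sketch).

    x: (seq_len, batch_size, num_classes) -> (seq_len * batch_size, num_classes)
    y: (batch_size, label_len)            -> (batch_size * label_len,)
    """
    seq_len, batch_size, num_classes = x.shape
    x2 = x.reshape(seq_len * batch_size, num_classes)  # like shape=(-3, -2)
    y2 = y.reshape(-1)                                 # like shape=(-1,)
    return x2, y2

rng = np.random.default_rng(0)
x = rng.uniform(size=(3, 4, 5)).astype(np.float32)  # seq_len=3, batch_size=4, num_classes=5
y = rng.integers(1, 5, size=(4, 2))                 # batch_size=4, label_len=2
x2, y2 = to_warpctc_layout(x, y)
# Time step t=2 of sequence n=1 sits at row 2 * 4 + 1
assert np.array_equal(x2[2 * 4 + 1], x[2, 1])
# Labels of sequence n=1 occupy y2[1 * 2 : 2 * 2]
assert np.array_equal(y2[2:4], y[1])
```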
