You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by gi...@git.apache.org on 2017/08/30 05:24:56 UTC

[GitHub] szha commented on a change in pull request #7660: fix ctc on softmax grad and req option

szha commented on a change in pull request #7660: fix ctc on softmax grad and req option
URL: https://github.com/apache/incubator-mxnet/pull/7660#discussion_r135972464
 
 

 ##########
 File path: tests/python/gpu/test_operator_gpu.py
 ##########
 @@ -1357,6 +1357,28 @@ def test_autograd_save_memory():
             x.wait_to_read()
     x.backward()
 
+def test_gluon_ctc_consistency():
+    loss = mx.gluon.loss.CTCLoss(padding_mask=0)
+    data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0)
+    cpu_label = mx.nd.array([[2,1,0,0],[3,2,2,0]], ctx=mx.cpu(0))
+    gpu_label = mx.nd.array([[2,1,0,0],[3,2,2,0]], ctx=mx.gpu(0))
+
+    cpu_data = data.copy().as_in_context(mx.cpu(0))
+    cpu_data.attach_grad()
+    with mx.autograd.record():
+        l_cpu = loss(cpu_data, cpu_label)
+        l_cpu.backward()
+    cpu_data.detach()
+
+    gpu_data = data.copyto(mx.gpu(0))
+    gpu_data.attach_grad()
+    with mx.autograd.record():
+        l_gpu = loss(gpu_data, gpu_label)
+        l_gpu.backward()
+    gpu_data.detach()
+
+    assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3)
 
 Review comment:
   ```
   In [1]: from __future__ import print_function
      ...: from mxnet import gluon
      ...: import mxnet as mx
      ...: loss = gluon.loss.CTCLoss(padding_mask=0)
      ...: data = mx.nd.arange(0, 4, repeat=40).reshape((2,20,4)).flip(axis=0)
      ...: label = mx.nd.array([[2,1,0,0],[3,2,2,0]])
      ...: data.attach_grad()
      ...: with mx.autograd.record():
      ...:     l = loss(data, label)
      ...:     print(l)
      ...:     l.backward()
      ...: data.detach()
      ...: print(data.grad)
      ...:
   
   [ 18.82820702  16.50581741]
   <NDArray 2 @cpu(0)>
   
   [[[-0.56818217  0.25        0.06818183  0.25      ]
     [-0.43571669  0.24740258 -0.06168866  0.25      ]
     [-0.34275508  0.24015717 -0.14740321  0.25      ]
     [-0.28055429  0.22675993 -0.19620776  0.25      ]
     [-0.24145621  0.20625414 -0.21479931  0.25      ]
     [-0.21889889  0.17822932 -0.20933169  0.25      ]
     [-0.20741701  0.14282268 -0.18540773  0.25      ]
     [-0.2026315   0.10071747 -0.1480875   0.25      ]
     [-0.20126608  0.05314353 -0.101881    0.25      ]
     [-0.20112836  0.00187904 -0.05075252  0.25      ]
     [-0.20112836 -0.05075252  0.00187904  0.25      ]
     [-0.20126608 -0.101881    0.05314353  0.25      ]
     [-0.2026315  -0.1480875   0.10071747  0.25      ]
     [-0.20741701 -0.18540773  0.14282268  0.25      ]
     [-0.21890068 -0.20933169  0.17822932  0.25      ]
     [-0.24145621 -0.21479931  0.20625414  0.25      ]
     [-0.28055429 -0.19620776  0.22675993  0.25      ]
     [-0.34275508 -0.14740321  0.24015717  0.25      ]
     [-0.43571669 -0.06168866  0.24740258  0.25      ]
     [-0.56818217  0.06818183  0.25        0.25      ]]
   
    [[-0.47727478  0.25        0.25       -0.02272752]
     [-0.32142842  0.25        0.23701297 -0.16558546]
     [-0.238722    0.25        0.20625421 -0.21753165]
     [-0.1993053   0.25        0.15863517 -0.2093308 ]
     [-0.18393114  0.25        0.0986056  -0.1646733 ]
     [-0.18109006  0.25        0.03234358 -0.1012527 ]
     [-0.1845623   0.25       -0.03370428 -0.0317339 ]
     [-0.19141069  0.25       -0.09393495  0.03534578]
     [-0.2004036   0.25       -0.14435497  0.09475859]
     [-0.21085775  0.25       -0.18299362  0.14385229]
     [-0.22191447  0.25       -0.20997432  0.18188837]
     [-0.23224759  0.25       -0.22722232  0.20947087]
     [-0.24019703  0.25       -0.23785028  0.22804667]
     [-0.24432448  0.25       -0.24516809  0.23949245]
     [-0.24441877  0.25       -0.2513606   0.2457782 ]
     [-0.24290282  0.25       -0.25580966  0.24871336]
     [-0.24669671  0.25       -0.25307524  0.24977216]
     [-0.26948035  0.25       -0.23051959  0.25      ]
     [-0.33441514  0.25       -0.16558546  0.25      ]
     [-0.47727478  0.25       -0.02272752  0.25      ]]]
   <NDArray 2x20x4 @cpu(0)>
   ```
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Service