You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/12 19:02:06 UTC

[GitHub] szha closed pull request #12468: [MXNET-807] Support integer label type in ctc_loss operator

szha closed pull request #12468: [MXNET-807] Support integer label type in ctc_loss operator
URL: https://github.com/apache/incubator-mxnet/pull/12468
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub does not otherwise show it once the pull request is merged):

diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h
index 9380be47451..c8a8b263740 100644
--- a/src/operator/contrib/ctc_loss-inl.h
+++ b/src/operator/contrib/ctc_loss-inl.h
@@ -256,66 +256,69 @@ class CTCLossOp : public Operator {
     exceed_cudnn_limit = false;
     Stream<xpu> *s = ctx.get_stream<xpu>();
 
-    Tensor<xpu, 3, real_t> data =
+    MSHADOW_TYPE_SWITCH(in_data[ctc_loss::kLabel].type_flag_, DType, {
+      Tensor<xpu, 3, real_t> data =
         in_data[ctc_loss::kData].get<xpu, 3, real_t>(s);
-    Tensor<xpu, 2, real_t> labels =
-        in_data[ctc_loss::kLabel].get<xpu, 2, real_t>(s);
+      Tensor<xpu, 2, DType> labels =
+        in_data[ctc_loss::kLabel].get<xpu, 2, DType>(s);
 
-    Tensor<xpu, 1, real_t> costs =
+      Tensor<xpu, 1, real_t> costs =
         out_data[ctc_loss::kOut].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 3, real_t> grad =
+      Tensor<xpu, 3, real_t> grad =
         out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);
 
-    int max_seq_len = data.size(0);
-    int batch_size = data.size(1);
-    int alphabet_size = data.size(2);
-
-    // data_lengths
-    std::vector<int> data_lengths(batch_size, max_seq_len);
-    if (param_.use_data_lengths) {
-      int kInputLength = 2;
-      IndexTensorToVector(in_data[kInputLength].get<xpu, 1, real_t>(s), &data_lengths);
-    }
-
-    // label_lengths
-    std::vector<int> packed_labels;
-    std::vector<int> label_lengths(batch_size);
-
-    if (param_.use_label_lengths) {
-      int kLabelLength = 2+param_.use_data_lengths;
-      exceed_cudnn_limit = PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, real_t>(s),
-                                             &packed_labels, &label_lengths);
-    } else {
-      exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0?0:-1,
-                                                     &packed_labels, &label_lengths);
-    }
-
-// CUDNN is disabled due to lack of support for input lengths
-/* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */
-/*     if (!exceed_cudnn_limit) { */
-/*       cudnn_forward(ctx, s, data, costs, grad, */
-/*                     &data_lengths, &label_lengths, &packed_labels, */
-/*                     max_seq_len, batch_size, alphabet_size, */
-/*                     req[ctc_loss::kGrad] != mxnet::kNullOp); */
-/*     } else { */
-/*       baidu_forward(ctx, s, data, costs, grad, */
-/*                     &data_lengths, &label_lengths, &packed_labels, */
-/*                     batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp); */
-/*     } */
-/* #else */
-
-    baidu_forward(ctx, s, data, costs, grad,
-                  &data_lengths, &label_lengths, &packed_labels,
-                  batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);
-
-    if (param_.use_data_lengths) {
-      // baidu warp CTC implementation sometimes includes undefined gradients
-      // for data outside of length mask. Setting to 0 to make it consistent
-      // with CPU implementation.
-      int kInputLength = 2;
-      mxnet_op::SequenceMask(grad, in_data[kInputLength].get<xpu, 1, real_t>(s),
-                             static_cast<real_t>(0));
-    }
+      int max_seq_len = data.size(0);
+      int batch_size = data.size(1);
+      int alphabet_size = data.size(2);
+
+      // data_lengths
+      std::vector<int> data_lengths(batch_size, max_seq_len);
+      if (param_.use_data_lengths) {
+        int kInputLength = 2;
+        IndexTensorToVector(in_data[kInputLength].get<xpu, 1, real_t>(s), &data_lengths);
+      }
+
+      // label_lengths
+      std::vector<int> packed_labels;
+      std::vector<int> label_lengths(batch_size);
+
+      if (param_.use_label_lengths) {
+        int kLabelLength = 2 + param_.use_data_lengths;
+        exceed_cudnn_limit =
+          PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, DType>(s),
+                           &packed_labels, &label_lengths);
+      } else {
+        exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0 ? 0 : -1,
+                                                      &packed_labels, &label_lengths);
+      }
+
+      // CUDNN is disabled due to lack of support for input lengths
+      /* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */
+      /*     if (!exceed_cudnn_limit) { */
+      /*       cudnn_forward(ctx, s, data, costs, grad, */
+      /*                     &data_lengths, &label_lengths, &packed_labels, */
+      /*                     max_seq_len, batch_size, alphabet_size, */
+      /*                     req[ctc_loss::kGrad] != mxnet::kNullOp); */
+      /*     } else { */
+      /*       baidu_forward(ctx, s, data, costs, grad, */
+      /*                     &data_lengths, &label_lengths, &packed_labels, */
+      /*                     batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);*/
+      /*     } */
+      /* #else */
+
+      baidu_forward(ctx, s, data, costs, grad,
+                    &data_lengths, &label_lengths, &packed_labels,
+                    batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);
+
+      if (param_.use_data_lengths) {
+        // baidu warp CTC implementation sometimes includes undefined gradients
+        // for data outside of length mask. Setting to 0 to make it consistent
+        // with CPU implementation.
+        int kInputLength = 2;
+        mxnet_op::SequenceMask(grad, in_data[kInputLength].get<xpu, 1, real_t>(s),
+                              static_cast<real_t>(0));
+      }
+    });
   }
 
   virtual void Backward(const OpContext &ctx,
@@ -434,17 +437,17 @@ class CTCLossOp : public Operator {
   }
 #endif  // __CUDACC__ && CUDNN
 
-  inline virtual void baidu_forward(const OpContext &ctx,
-                                    mshadow::Stream<xpu>* s,
-                                    mshadow::Tensor<xpu, 3, real_t> data,
-                                    mshadow::Tensor<xpu, 1, real_t> costs,
-                                    mshadow::Tensor<xpu, 3, real_t> grad,
-                                    std::vector<int>* data_lengths,
-                                    std::vector<int>* label_lengths,
-                                    std::vector<int>* packed_labels,
-                                    int batch_size,
-                                    int alphabet_size,
-                                    bool req_grad) {
+  inline void baidu_forward(const OpContext &ctx,
+                            mshadow::Stream<xpu>* s,
+                            mshadow::Tensor<xpu, 3, real_t> data,
+                            mshadow::Tensor<xpu, 1, real_t> costs,
+                            mshadow::Tensor<xpu, 3, real_t> grad,
+                            std::vector<int>* data_lengths,
+                            std::vector<int>* label_lengths,
+                            std::vector<int>* packed_labels,
+                            int batch_size,
+                            int alphabet_size,
+                            bool req_grad) {
     using namespace mshadow;
     // allocate temporary workspace
     size_t size_bytes;
@@ -461,7 +464,7 @@ class CTCLossOp : public Operator {
     compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels->data(),
                      label_lengths->data(), data_lengths->data(),
                      workspace.dptr_, req_grad,
-                     param_.blank_label == 0?0:(alphabet_size-1));
+                     param_.blank_label == 0 ? 0 : (alphabet_size-1));
   }
 };  // class CTCLossOp
 
@@ -534,11 +537,24 @@ class CTCLossProp : public OperatorProperty {
     TShape oshape(1);
     oshape[0] = dshape[1];  // batch size
     out_shape->clear();
-    out_shape->push_back(oshape);
+    out_shape->push_back(oshape);  // forward output
     out_shape->push_back(dshape);  // grad output
     return true;
   }
 
+  bool InferType(std::vector<int> *in_type,
+    std::vector<int> *out_type,
+    std::vector<int> *aux_type) const override {
+    CHECK_LE(in_type->size(), this->ListArguments().size());
+    int dtype = (*in_type)[ctc_loss::kData];
+    CHECK_NE(dtype, -1) << "Input data must have specified type";
+
+    out_type->clear();
+    out_type->push_back(dtype);  // forward output
+    out_type->push_back(dtype);  // grad output
+    return true;
+  }
+
   OperatorProperty *Copy() const override {
     auto ptr = new CTCLossProp();
     ptr->param_ = param_;
diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py
index fc6c1be9c3a..76efe305bce 100644
--- a/tests/python/unittest/test_contrib_operator.py
+++ b/tests/python/unittest/test_contrib_operator.py
@@ -244,6 +244,7 @@ def assert_match(inputs, x, y, threshold, is_ascend=False):
     assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False)
     assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True)
 
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 9842a69e18d..4ec4bf1b384 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4516,6 +4516,30 @@ def test_ctc_loss():
     true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch
     check_ctc_loss(acts2, labels2, true_loss)
 
+    # Test 3: check use integer type as label
+    labels3 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.int32)
+    true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch
+    check_ctc_loss(acts2, labels3, true_loss)
+
+@with_seed()
+def test_ctc_loss_with_large_classes():
+    ctx = default_context()
+    num_classes = 6000
+    seq_len = 8
+    batch_size = 2
+    data = np.empty((num_classes, 0))
+    for i in range(seq_len * batch_size) :
+        row = np.roll(np.arange(num_classes, dtype=np.float32), i).reshape(num_classes, 1)
+        data = np.append(data, row/13, axis=1)
+    data = data.reshape(seq_len, batch_size, num_classes)
+    label = np.array([
+        [100, 200, 300, 400, 500, 0, 0, 0],
+        [1000, 2000, 3000, 4000, 0, 5000, 0, 0]], dtype=np.int32)
+    nd_data = mx.nd.array(data)
+    nd_label = mx.nd.array(label)
+    loss = mx.nd.contrib.ctc_loss(data=nd_data, label=nd_label)
+    expected_loss = np.array([688.02826, 145.34462])
+    assert_almost_equal(loss.asnumpy(), expected_loss)
 
 @with_seed()
 def test_ctc_loss_grad():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services