Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/27 20:23:50 UTC

[GitHub] apeforest commented on a change in pull request #12637: [MXNET-912] Refactoring ctc loss operator

apeforest commented on a change in pull request #12637: [MXNET-912] Refactoring ctc loss operator
URL: https://github.com/apache/incubator-mxnet/pull/12637#discussion_r221064660
 
 

 ##########
 File path: src/operator/nn/ctc_loss-inl.h
 ##########
 @@ -0,0 +1,392 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file ctc_loss-inl.h
+ * \brief CTC Loss operator
+*/
+
+#ifndef MXNET_OPERATOR_NN_CTC_LOSS_INL_H_
+#define MXNET_OPERATOR_NN_CTC_LOSS_INL_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include <algorithm>
+#include <string>
+#include "../mshadow_op.h"
+#include "./sequence_mask-inl.h"
+#include "../sequence_op_common.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+namespace ctc_loss {
+enum CTCLossOpInputs { kData, kLabel };
+enum CTCLossOpOutputs { kOut, kGrad };
+}  // namespace ctc_loss
+
+template <typename T>
+inline void get_workspace_size(const std::vector<int> *label_lengths,
+                               const std::vector<int> *data_lengths,
+                               int alphabet_size, int minibatch, bool isGPU,
+                               size_t *size_bytes) {
+  // maxL and maxT are the largest label length and input sequence length
+  // over all examples in the minibatch.
+  int maxL = *std::max_element(label_lengths->data(),
+                               label_lengths->data() + minibatch);
+  int maxT = *std::max_element(data_lengths->data(),
+                               data_lengths->data() + minibatch);
+
+  const int S = 2 * maxL + 1;
+
+  *size_bytes = 0;
+
+  if (isGPU) {
+    // GPU storage
+    // nll_forward, nll_backward
+    *size_bytes += 2 * sizeof(T) * minibatch;
+
+    // repeats
+    *size_bytes += sizeof(int) * minibatch;
+
+    // label offsets
+    *size_bytes += sizeof(int) * minibatch;
+
+    // utt_length
+    *size_bytes += sizeof(int) * minibatch;
+
+    // label lengths
+    *size_bytes += sizeof(int) * minibatch;
+
+    // labels without blanks - overallocate for now
+    *size_bytes += sizeof(int) * maxL * minibatch;
+
+    // labels with blanks
+    *size_bytes += sizeof(int) * S * minibatch;
+
+    // alphas
+    *size_bytes += sizeof(T) * S * maxT * minibatch;
+
+    // denoms
+    *size_bytes += sizeof(T) * maxT * minibatch;
+
+    // probs (since we will pass in activations)
+    *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch;
+
+  } else {
+    // On the CPU, `minibatch` could eventually be replaced by the
+    // maximum number of concurrent threads if memory is really tight.
+
+    // per minibatch memory
+    size_t per_minibatch_bytes = 0;
+
+    // output
+    per_minibatch_bytes += sizeof(T) * alphabet_size;
+
+    // alphas
+    per_minibatch_bytes += sizeof(T) * S * maxT;
+
+    // betas
+    per_minibatch_bytes += sizeof(T) * S;
+
+    // labels w/blanks, e_inc, s_inc
+    per_minibatch_bytes += 3 * sizeof(int) * S;
+
+    *size_bytes = per_minibatch_bytes * minibatch;
+
+    // probs
+    *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch;
+  }
+}
+
+// Takes a tensor of labels and interprets elements equal to `padding_mask`
+// at the end of each label sequence as padding. The tensor is packed into an
+// std::vector without the padding characters, and each label sequence length
+// is inferred from the first occurrence of the padding value.
+template <typename DType, typename xpu>
+inline void LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels,
+                                      int padding_mask,
+                                      std::vector<int> *packed_labels,
+                                      std::vector<int> *label_lengths) {
+  int batch = labels.size(0);
+  int max_num_labels = labels.size(1);
+
+  std::vector<int> cpu_labels(max_num_labels * batch);
+  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
+  IndexTensorToVector(flat_labels, &cpu_labels);
+
+  for (int b = 0; b < batch; ++b) {
+    auto start = cpu_labels.data() + b * max_num_labels;
+    auto res = std::find(start, start + max_num_labels, padding_mask);
+    int len = std::distance(start, res);
+    std::copy(start, start + len,
+              std::back_inserter(*packed_labels));
+    label_lengths->at(b) = len;
+  }
+}
+
+// Takes a tensor of labels and a vector specifying the actual length of each
+// label sequence. The tensor is packed into an std::vector without padding
+// characters, and the length vector is copied into an std::vector.
+template <typename DType, typename xpu>
+inline void PackLabelByLength(mshadow::Tensor<xpu, 2, DType> labels,
+                              mshadow::Tensor<xpu, 1, DType> in_label_lengths,
+                              std::vector<int> *packed_labels,
+                              std::vector<int> *label_lengths) {
+  int batch = labels.size(0);
+  int max_num_labels = labels.size(1);
+
+  IndexTensorToVector(in_label_lengths, label_lengths);
+
+  std::vector<int> cpu_labels(max_num_labels * batch);
+  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
+  IndexTensorToVector(flat_labels, &cpu_labels);
+
+  for (int b = 0; b < batch; ++b) {
+    auto start = cpu_labels.data() + b * max_num_labels;
+    int len = label_lengths->at(b);
+    std::copy(start, start + len,
+              std::back_inserter(*packed_labels));
+  }
+}
+
+struct CTCLossOpParam : public dmlc::Parameter<CTCLossOpParam> {
+  bool use_data_lengths;
+  bool use_label_lengths;
+  int blank_label;
+  DMLC_DECLARE_PARAMETER(CTCLossOpParam) {
+    DMLC_DECLARE_FIELD(use_data_lengths).set_default(false)
+      .describe("Whether the data lenghts are decided by `data_lengths`. "
+                "If false, the lengths are equal to the max sequence length.");
+    DMLC_DECLARE_FIELD(use_label_lengths).set_default(false)
+      .describe("Whether the label lenghts are decided by "
+                "`label_lengths`, or derived from `padding_mask`. "
+                "If false, the lengths are derived from the "
+                "first occurrence of the value of `padding_mask`. "
+                "The value of `padding_mask` is ``0`` when first CTC label is reserved for blank, "
+                "and ``-1`` when last label is reserved for blank. See `blank_label`.");
+    DMLC_DECLARE_FIELD(blank_label)
+      .add_enum("first", 0)
+      .add_enum("last", 1)
+      .set_default(0)
+      .describe("Set the label that is reserved for blank label."
+                "If \"first\", 0-th label is reserved, and "
+                "label values for tokens in the vocabulary are "
+                "between ``1`` and ``alphabet_size-1``, and the padding mask is ``-1``. "
+                "If \"last\", last label value ``alphabet_size-1`` "
+                "is reserved for blank label instead, "
+                "and label values for tokens in the vocabulary are "
+                "between ``0`` and ``alphabet_size-2``, and the padding mask is ``0``.");
+  }
+};
+
+inline bool CTCLossOpShape(const nnvm::NodeAttrs &attrs,
+                           std::vector<TShape>* in_attrs,
+                           std::vector<TShape>* out_attrs) {
+    const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed);
+    CHECK_EQ(in_attrs->size(), 2U + param.use_data_lengths + param.use_label_lengths);
 
 Review comment:
   Changed.
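
For reference, the CPU-branch arithmetic in `get_workspace_size` can be
checked by hand with a small standalone program. The following is a minimal
sketch, not part of this diff; the toy batch sizes below are invented purely
for illustration:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      // Toy batch: two examples with label lengths {2, 3}, input sequence
      // lengths {5, 6}, and an alphabet of 5 symbols (including blank).
      std::vector<int> label_lengths = {2, 3};
      std::vector<int> data_lengths  = {5, 6};
      const int minibatch = 2, alphabet_size = 5;

      const int maxL = *std::max_element(label_lengths.begin(), label_lengths.end());
      const int maxT = *std::max_element(data_lengths.begin(), data_lengths.end());
      const int S = 2 * maxL + 1;  // labels interleaved with blanks: 7 here

      // Mirrors the CPU branch of get_workspace_size above.
      size_t per_minibatch = sizeof(float) * alphabet_size  // output
                           + sizeof(float) * S * maxT       // alphas
                           + sizeof(float) * S              // betas
                           + 3 * sizeof(int) * S;           // labels w/blanks, e_inc, s_inc
      size_t total = per_minibatch * minibatch
                   + sizeof(float) * alphabet_size * maxT * minibatch;  // probs

      std::printf("CPU workspace: %zu bytes\n", total);  // 300 * 2 + 240 = 840
      return 0;
    }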

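The padding scan in `LabelTensorToPackedVector` can likewise be exercised
with a plain array standing in for the mshadow tensor. A minimal sketch,
assuming `-1` as the padding value (the real `padding_mask` is chosen by the
caller according to `blank_label`):

    #include <algorithm>
    #include <cstdio>
    #include <iterator>
    #include <vector>

    int main() {
      // Two label rows padded to max_num_labels = 4 with padding_mask = -1.
      const int batch = 2, max_num_labels = 4, padding_mask = -1;
      std::vector<int> labels = {3, 1, -1, -1,   // effective length 2
                                 2, 2,  4, -1};  // effective length 3

      std::vector<int> packed_labels;
      std::vector<int> label_lengths(batch);
      for (int b = 0; b < batch; ++b) {
        // Same scan as the loop in LabelTensorToPackedVector: find the first
        // padding character, then copy everything before it.
        auto start = labels.data() + b * max_num_labels;
        auto res = std::find(start, start + max_num_labels, padding_mask);
        int len = std::distance(start, res);
        std::copy(start, start + len, std::back_inserter(packed_labels));
        label_lengths[b] = len;
      }

      // packed_labels = {3, 1, 2, 2, 4}; label_lengths = {2, 3}
      std::printf("lengths: %d %d, packed size: %zu\n",
                  label_lengths[0], label_lengths[1], packed_labels.size());
      return 0;
    }
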
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services