You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/09 01:37:00 UTC

[GitHub] anirudh2290 closed pull request #12637: [MXNET-912] Refactoring ctc loss operator

anirudh2290 closed pull request #12637: [MXNET-912] Refactoring ctc loss operator
URL: https://github.com/apache/incubator-mxnet/pull/12637
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/contrib/ctc_include/LICENSE b/3rdparty/ctc_include/LICENSE
similarity index 100%
rename from src/operator/contrib/ctc_include/LICENSE
rename to 3rdparty/ctc_include/LICENSE
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/LICENSE b/3rdparty/ctc_include/contrib/moderngpu/LICENSE
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/LICENSE
rename to 3rdparty/ctc_include/contrib/moderngpu/LICENSE
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctascan.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctascan.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctascan.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctascan.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/loadstore.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/loadstore.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/loadstore.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/loadstore.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/serialsets.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/serialsets.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/serialsets.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/serialsets.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpudevice.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/mgpudevice.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpudevice.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/mgpudevice.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpuenums.h b/3rdparty/ctc_include/contrib/moderngpu/include/mgpuenums.h
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpuenums.h
rename to 3rdparty/ctc_include/contrib/moderngpu/include/mgpuenums.h
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/util/static.h b/3rdparty/ctc_include/contrib/moderngpu/include/util/static.h
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/util/static.h
rename to 3rdparty/ctc_include/contrib/moderngpu/include/util/static.h
diff --git a/src/operator/contrib/ctc_include/detail/cpu_ctc.h b/3rdparty/ctc_include/detail/cpu_ctc.h
similarity index 100%
rename from src/operator/contrib/ctc_include/detail/cpu_ctc.h
rename to 3rdparty/ctc_include/detail/cpu_ctc.h
diff --git a/src/operator/contrib/ctc_include/detail/ctc_helper.h b/3rdparty/ctc_include/detail/ctc_helper.h
similarity index 100%
rename from src/operator/contrib/ctc_include/detail/ctc_helper.h
rename to 3rdparty/ctc_include/detail/ctc_helper.h
diff --git a/src/operator/contrib/ctc_include/detail/gpu_ctc.h b/3rdparty/ctc_include/detail/gpu_ctc.h
similarity index 100%
rename from src/operator/contrib/ctc_include/detail/gpu_ctc.h
rename to 3rdparty/ctc_include/detail/gpu_ctc.h
diff --git a/src/operator/contrib/ctc_include/detail/gpu_ctc_kernels.h b/3rdparty/ctc_include/detail/gpu_ctc_kernels.h
similarity index 100%
rename from src/operator/contrib/ctc_include/detail/gpu_ctc_kernels.h
rename to 3rdparty/ctc_include/detail/gpu_ctc_kernels.h
diff --git a/src/operator/contrib/ctc_include/detail/hostdevice.h b/3rdparty/ctc_include/detail/hostdevice.h
similarity index 100%
rename from src/operator/contrib/ctc_include/detail/hostdevice.h
rename to 3rdparty/ctc_include/detail/hostdevice.h
diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 2be43981a64..7e4d3457763 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -468,10 +468,10 @@ def hybrid_forward(self, F, pred, label,
             pred = F.swapaxes(pred, 0, 1)
         if self._batch_axis == 1:
             label = F.swapaxes(label, 0, 1)
-        loss = F.contrib.CTCLoss(pred, label, pred_lengths, label_lengths,
-                                 use_data_lengths=pred_lengths is not None,
-                                 use_label_lengths=label_lengths is not None,
-                                 blank_label='last')
+        loss = F.CTCLoss(pred, label, pred_lengths, label_lengths,
+                         use_data_lengths=pred_lengths is not None,
+                         use_label_lengths=label_lengths is not None,
+                         blank_label='last')
         return _apply_weighting(F, loss, self._weight, sample_weight)
 
 
diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h
deleted file mode 100644
index c8a8b263740..00000000000
--- a/src/operator/contrib/ctc_loss-inl.h
+++ /dev/null
@@ -1,591 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2016 by Contributors
- * \file ctc_loss-inl.h
- * \brief
- * \author Sebastian Bodenstien
-*/
-
-#ifndef MXNET_OPERATOR_CONTRIB_CTC_LOSS_INL_H_
-#define MXNET_OPERATOR_CONTRIB_CTC_LOSS_INL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <algorithm>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include <ctime>
-#include <cstring>
-#include <iostream>
-#include "../operator_common.h"
-#include "../sequence_op_common.h"
-#include "../mshadow_op.h"
-#include "../nn/sequence_mask-inl.h"
-
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-#define CUDNN_LABEL_LENGTH_LIMIT 256
-#include "../nn/softmax-inl.h"
-#endif  // CUDNN
-
-namespace mxnet {
-namespace op {
-
-namespace ctc_loss {
-enum CTCLossOpInputs { kData, kLabel };
-enum CTCLossOpOutputs { kOut, kGrad };
-enum CTCLossOpForwardResource { kTempSpace };
-}
-
-template <typename T>
-inline void get_workspace_size(std::vector<int> *label_lengths,
-                               std::vector<int> *data_lengths,
-                               int alphabet_size, int minibatch, bool gpu,
-                               size_t *size_bytes) {
-  // This is the max of all S and T for all examples in the minibatch.
-  int maxL = *std::max_element(label_lengths->data(),
-                               label_lengths->data() + minibatch);
-  int maxT = *std::max_element(data_lengths->data(),
-                               data_lengths->data() + minibatch);
-
-  const int S = 2 * maxL + 1;
-
-  *size_bytes = 0;
-
-  if (gpu) {
-    // GPU storage
-    // nll_forward, nll_backward
-    *size_bytes += 2 * sizeof(T) * minibatch;
-
-    // repeats
-    *size_bytes += sizeof(int) * minibatch;
-
-    // label offsets
-    *size_bytes += sizeof(int) * minibatch;
-
-    // utt_length
-    *size_bytes += sizeof(int) * minibatch;
-
-    // label lengths
-    *size_bytes += sizeof(int) * minibatch;
-
-    // labels without blanks - overallocate for now
-    *size_bytes += sizeof(int) * maxL * minibatch;
-
-    // labels with blanks
-    *size_bytes += sizeof(int) * S * minibatch;
-
-    // alphas
-    *size_bytes += sizeof(T) * S * maxT * minibatch;
-
-    // denoms
-    *size_bytes += sizeof(T) * maxT * minibatch;
-
-    // probs (since we will pass in activations)
-    *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch;
-
-  } else {
-    // cpu can eventually replace all minibatch with
-    // max number of concurrent threads if memory is
-    // really tight
-
-    // per minibatch memory
-    size_t per_minibatch_bytes = 0;
-
-    // output
-    per_minibatch_bytes += sizeof(T) * alphabet_size;
-
-    // alphas
-    per_minibatch_bytes += sizeof(T) * S * maxT;
-
-    // betas
-    per_minibatch_bytes += sizeof(T) * S;
-
-    // labels w/blanks, e_inc, s_inc
-    per_minibatch_bytes += 3 * sizeof(int) * S;
-
-    *size_bytes = per_minibatch_bytes * minibatch;
-
-    // probs
-    *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch;
-  }
-}
-
-// Takes a tensor of labels, and interprets 0-elements at the end of the vector
-// as padding. The tensor is packed into an std::vector without padding
-// characters. The label sequence lengths are also inferred from the padding chars.
-// When cudnn is enabled, the return value signifies whether the cudnn length limit is exceeded.
-template <typename DType, typename xpu>
-inline bool LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels,
-                                      int padding_mask,
-                                      std::vector<int> *packed_labels,
-                                      std::vector<int> *label_lengths) {
-  int batch = labels.size(0);
-  int max_num_labels = labels.size(1);
-  bool exceed_limit = false;
-
-  std::vector<int> cpu_labels(max_num_labels*batch);
-  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
-  IndexTensorToVector(flat_labels, &cpu_labels);
-
-  for (int b = 0; b < batch; ++b) {
-    auto start = cpu_labels.data()+b*max_num_labels;
-    auto res = std::find(start, start+max_num_labels, padding_mask);
-    int len = std::distance(start, res);
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-    exceed_limit = exceed_limit || len > CUDNN_LABEL_LENGTH_LIMIT;
-#endif
-    std::copy(start, start + len,
-              std::back_inserter(*packed_labels));
-    label_lengths->at(b) = len;
-  }
-  return exceed_limit;
-}
-
-// Takes a tensor of labels, and a vector which specifies the actual length of each label
-// The tensor is packed into an std::vector without padding characters.
-// The label length vector is copied into an std::vector.
-// When cudnn is enabled, the return value signifies whether the cudnn length limit is exceeded.
-template <typename DType, typename xpu>
-inline bool PackLabelByLength(mshadow::Tensor<xpu, 2, DType> labels,
-                              mshadow::Tensor<xpu, 1, DType> in_label_lengths,
-                              std::vector<int> *packed_labels,
-                              std::vector<int> *label_lengths) {
-  int batch = labels.size(0);
-  int max_num_labels = labels.size(1);
-  bool exceed_limit = false;
-
-  IndexTensorToVector(in_label_lengths, label_lengths);
-
-  std::vector<int> cpu_labels(max_num_labels*batch);
-  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
-  IndexTensorToVector(flat_labels, &cpu_labels);
-
-  for (int b = 0; b < batch; ++b) {
-    auto start = cpu_labels.data()+b*max_num_labels;
-    int len = label_lengths->at(b);
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-    exceed_limit = exceed_limit || len > CUDNN_LABEL_LENGTH_LIMIT;
-#endif
-    std::copy(start, start + len,
-              std::back_inserter(*packed_labels));
-  }
-  return exceed_limit;
-}
-
-struct CTCLossParam : public dmlc::Parameter<CTCLossParam> {
-  bool use_data_lengths;
-  bool use_label_lengths;
-  int blank_label;
-  DMLC_DECLARE_PARAMETER(CTCLossParam) {
-    DMLC_DECLARE_FIELD(use_data_lengths).set_default(false)
-      .describe("Whether the data lenghts are decided by `data_lengths`. "
-                "If false, the lengths are equal to the max sequence length.");
-    DMLC_DECLARE_FIELD(use_label_lengths).set_default(false)
-      .describe("Whether the label lenghts are decided by "
-                "`label_lengths`, or derived from `padding_mask`. "
-                "If false, the lengths are derived from the "
-                "first occurrence of the value of `padding_mask`. "
-                "The value of `padding_mask` is ``0`` when first CTC label is reserved for blank, "
-                "and ``-1`` when last label is reserved for blank. See `blank_label`.");
-    DMLC_DECLARE_FIELD(blank_label)
-      .add_enum("first", 0)
-      .add_enum("last", 1)
-      .set_default(0)
-      .describe("Set the label that is reserved for blank label."
-                "If \"first\", 0-th label is reserved, and "
-                "label values for tokens in the vocabulary are "
-                "between ``1`` and ``alphabet_size-1``, and the padding mask is ``-1``. "
-                "If \"last\", last label value ``alphabet_size-1`` "
-                "is reserved for blank label instead, "
-                "and label values for tokens in the vocabulary are "
-                "between ``0`` and ``alphabet_size-2``, and the padding mask is ``0``.");
-  }
-};
-
-template <typename xpu>
-class CTCLossOp : public Operator {
- public:
-  explicit CTCLossOp(CTCLossParam p) {
-    this->param_ = p;
-    exceed_cudnn_limit = false;
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-    CUDNN_CALL(cudnnCreateCTCLossDescriptor(&ctc_desc_));
-    CUDNN_CALL(cudnnSetCTCLossDescriptor(ctc_desc_, CUDNN_DATA_FLOAT));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&prob_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&grad_desc_));
-#endif
-  }
-
-  ~CTCLossOp() {
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-    CUDNN_CALL(cudnnDestroyCTCLossDescriptor(ctc_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(prob_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(grad_desc_));
-#endif
-  }
-
-  virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 2U+param_.use_data_lengths+param_.use_label_lengths);
-    CHECK_EQ(out_data.size(), 2U);
-    exceed_cudnn_limit = false;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-
-    MSHADOW_TYPE_SWITCH(in_data[ctc_loss::kLabel].type_flag_, DType, {
-      Tensor<xpu, 3, real_t> data =
-        in_data[ctc_loss::kData].get<xpu, 3, real_t>(s);
-      Tensor<xpu, 2, DType> labels =
-        in_data[ctc_loss::kLabel].get<xpu, 2, DType>(s);
-
-      Tensor<xpu, 1, real_t> costs =
-        out_data[ctc_loss::kOut].get<xpu, 1, real_t>(s);
-      Tensor<xpu, 3, real_t> grad =
-        out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);
-
-      int max_seq_len = data.size(0);
-      int batch_size = data.size(1);
-      int alphabet_size = data.size(2);
-
-      // data_lengths
-      std::vector<int> data_lengths(batch_size, max_seq_len);
-      if (param_.use_data_lengths) {
-        int kInputLength = 2;
-        IndexTensorToVector(in_data[kInputLength].get<xpu, 1, real_t>(s), &data_lengths);
-      }
-
-      // label_lengths
-      std::vector<int> packed_labels;
-      std::vector<int> label_lengths(batch_size);
-
-      if (param_.use_label_lengths) {
-        int kLabelLength = 2 + param_.use_data_lengths;
-        exceed_cudnn_limit =
-          PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, DType>(s),
-                           &packed_labels, &label_lengths);
-      } else {
-        exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0 ? 0 : -1,
-                                                      &packed_labels, &label_lengths);
-      }
-
-      // CUDNN is disabled due to lack of support for input lengths
-      /* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */
-      /*     if (!exceed_cudnn_limit) { */
-      /*       cudnn_forward(ctx, s, data, costs, grad, */
-      /*                     &data_lengths, &label_lengths, &packed_labels, */
-      /*                     max_seq_len, batch_size, alphabet_size, */
-      /*                     req[ctc_loss::kGrad] != mxnet::kNullOp); */
-      /*     } else { */
-      /*       baidu_forward(ctx, s, data, costs, grad, */
-      /*                     &data_lengths, &label_lengths, &packed_labels, */
-      /*                     batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);*/
-      /*     } */
-      /* #else */
-
-      baidu_forward(ctx, s, data, costs, grad,
-                    &data_lengths, &label_lengths, &packed_labels,
-                    batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);
-
-      if (param_.use_data_lengths) {
-        // baidu warp CTC implementation sometimes includes undefined gradients
-        // for data outside of length mask. Setting to 0 to make it consistent
-        // with CPU implementation.
-        int kInputLength = 2;
-        mxnet_op::SequenceMask(grad, in_data[kInputLength].get<xpu, 1, real_t>(s),
-                              static_cast<real_t>(0));
-      }
-    });
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-
-    Tensor<xpu, 3, real_t> data_grad =
-        in_grad[ctc_loss::kData].get<xpu, 3, real_t>(s);
-    Tensor<xpu, 1, real_t> output_grad =
-        out_grad[ctc_loss::kOut].get<xpu, 1, real_t>(s);
-
-    Tensor<xpu, 3, real_t> data_grad_computed =
-        out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);
-
-    Assign(data_grad, req[ctc_loss::kData],
-           mshadow::expr::broadcast<1>(output_grad, data_grad.shape_) * data_grad_computed);
-  }
-
- private:
-  CTCLossParam param_;
-  bool exceed_cudnn_limit;
-
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-  cudnnDataType_t dtype_;
-  cudnnCTCLossDescriptor_t ctc_desc_;
-  cudnnTensorDescriptor_t prob_desc_, grad_desc_;
-
-  inline virtual void cudnn_forward(const OpContext &ctx,
-                                    mshadow::Stream<xpu>* s,
-                                    mshadow::Tensor<xpu, 3, real_t> data,
-                                    mshadow::Tensor<xpu, 1, real_t> costs,
-                                    mshadow::Tensor<xpu, 3, real_t> grad,
-                                    std::vector<int>* data_lengths,
-                                    std::vector<int>* label_lengths,
-                                    std::vector<int>* packed_labels,
-                                    int max_seq_len,
-                                    int batch_size,
-                                    int alphabet_size,
-                                    bool req_grad) {
-    using namespace mshadow;
-
-    // call cudnn to calculate ctc loss
-    dtype_ = CUDNN_DATA_FLOAT;
-    int dims[3], strides[3];
-    size_t workspace_bytes;
-    int workspace_size;
-    dims[0] = max_seq_len;
-    dims[1] = batch_size;
-    dims[2] = alphabet_size;
-    strides[0] = batch_size*alphabet_size;
-    strides[1] = alphabet_size;
-    strides[2] = 1;
-    cudnnCTCLossAlgo_t ctc_algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(prob_desc_,
-                                          dtype_,
-                                          3,
-                                          dims,
-                                          strides));
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(grad_desc_,
-                                          dtype_,
-                                          3,
-                                          dims,
-                                          strides));
-    CUDNN_CALL(cudnnGetCTCLossWorkspaceSize(s->dnn_handle_,
-                                            prob_desc_,
-                                            req_grad?grad_desc_:NULL,
-                                            packed_labels->data(),
-                                            label_lengths->data(),
-                                            data_lengths->data(),
-                                            ctc_algo,
-                                            ctc_desc_,
-                                            &workspace_bytes));
-    workspace_size = (workspace_bytes + sizeof(real_t) - 1)/sizeof(real_t);
-
-    Tensor<xpu, 1, real_t> temp_space =
-      ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>(
-          mshadow::Shape1(workspace_size+data.shape_.FlatTo1D()[0]), s);
-
-    Tensor<gpu, 1, real_t> work_space(temp_space.dptr_,
-                                      mshadow::Shape1(workspace_size), s);
-    Tensor<xpu, 3, real_t> prob(temp_space.dptr_+workspace_size,
-                                data.shape_, s);
-
-    // since the input is activation before softmax and cudnn ctc takes softmax
-    // apply softmax to inputs first.
-    mxnet_op::Softmax<mxnet_op::softmax_fwd, false>(
-      s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0);
-
-    CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_,
-                            prob_desc_,
-                            prob.dptr_,
-                            packed_labels->data(),
-                            label_lengths->data(),
-                            data_lengths->data(),
-                            costs.dptr_,
-                            req_grad?grad_desc_:NULL,
-                            req_grad?grad.dptr_:NULL,
-                            ctc_algo,
-                            ctc_desc_,
-                            work_space.dptr_,
-                            workspace_bytes));
-
-    if (req_grad) {
-      mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd, kWriteTo, false>(
-        s, prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0);
-      Assign(grad, mxnet::kWriteInplace, grad * alphabet_size);
-    }
-  }
-#endif  // __CUDACC__ && CUDNN
-
-  inline void baidu_forward(const OpContext &ctx,
-                            mshadow::Stream<xpu>* s,
-                            mshadow::Tensor<xpu, 3, real_t> data,
-                            mshadow::Tensor<xpu, 1, real_t> costs,
-                            mshadow::Tensor<xpu, 3, real_t> grad,
-                            std::vector<int>* data_lengths,
-                            std::vector<int>* label_lengths,
-                            std::vector<int>* packed_labels,
-                            int batch_size,
-                            int alphabet_size,
-                            bool req_grad) {
-    using namespace mshadow;
-    // allocate temporary workspace
-    size_t size_bytes;
-    bool gpu = data.kDevCPU ? false : true;
-    get_workspace_size<real_t>(label_lengths, data_lengths, alphabet_size,
-                               batch_size, gpu, &size_bytes);
-
-    // round-up so there are enough elems in memory
-    int num_tmp_elems = (size_bytes + sizeof(real_t) - 1) / sizeof(real_t);
-    Tensor<xpu, 1, real_t> workspace =
-        ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>(
-            Shape1(num_tmp_elems), s);
-
-    compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels->data(),
-                     label_lengths->data(), data_lengths->data(),
-                     workspace.dptr_, req_grad,
-                     param_.blank_label == 0 ? 0 : (alphabet_size-1));
-  }
-};  // class CTCLossOp
-
-template <typename xpu>
-Operator *CreateOp(CTCLossParam param, int dtype);
-
-#if DMLC_USE_CXX11
-class CTCLossProp : public OperatorProperty {
- public:
-  int NumVisibleOutputs() const override { return 1; }
-
-  int NumOutputs() const override { return 2; }
-
-  std::vector<std::string> ListArguments() const override {
-    if (param_.use_data_lengths && param_.use_label_lengths) {
-      return {"data", "label", "data_lengths", "label_lengths"};
-    } else if (param_.use_data_lengths) {
-      return {"data", "label", "data_lengths"};
-    } else if (param_.use_label_lengths) {
-      return {"data", "label", "label_lengths"};
-    } else {
-      return {"data", "label"};
-    }
-  }
-
-  std::vector<std::string> ListOutputs() const override {
-    return {"output", "grad"};
-  }
-
-  void Init(const std::vector<std::pair<std::string, std::string>> &kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    index_t expected_inputs = 2+param_.use_data_lengths+param_.use_label_lengths;
-    CHECK_EQ(in_shape->size(), expected_inputs)
-        << "Expect " << expected_inputs << " inputs to the symbol.";
-
-    const TShape &dshape = (*in_shape)[ctc_loss::kData];
-    const TShape &lshape = (*in_shape)[ctc_loss::kLabel];
-    CHECK_EQ(dshape.ndim(), 3U) << "The data array must be of rank 3.";
-    CHECK_EQ(lshape.ndim(), 2U) << "The labels array must be of rank 2.";
-    CHECK_EQ(dshape[1], lshape[0])
-        << "The batch size for the labels and data arrays must be the same.";
-    if (param_.use_data_lengths) {
-      int kInputLength = 2;
-      const TShape &dlshape = (*in_shape)[kInputLength];
-      CHECK_EQ(dlshape.ndim(), 1U) << "Data length array must be a vector.";
-      CHECK_EQ(dlshape[0], dshape[1])
-          << "The batch size for the data and data lengths must be the same.";
-    }
-    if (param_.use_label_lengths) {
-      int kLabelLength = 2+param_.use_data_lengths;
-      const TShape &llshape = (*in_shape)[kLabelLength];
-      CHECK_EQ(llshape.ndim(), 1U) << "Label length array must be a vector.";
-      CHECK_EQ(llshape[0], lshape[0])
-          << "The batch size for the labels and label lengths must be the same.";
-    }
-
-    CHECK_GE(dshape[0], lshape[1]) << "The max number of labels cannot exceed "
-                                      "the maximum sequence length of the "
-                                      "data.";
-
-    TShape oshape(1);
-    oshape[0] = dshape[1];  // batch size
-    out_shape->clear();
-    out_shape->push_back(oshape);  // forward output
-    out_shape->push_back(dshape);  // grad output
-    return true;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-    std::vector<int> *out_type,
-    std::vector<int> *aux_type) const override {
-    CHECK_LE(in_type->size(), this->ListArguments().size());
-    int dtype = (*in_type)[ctc_loss::kData];
-    CHECK_NE(dtype, -1) << "Input data must have specified type";
-
-    out_type->clear();
-    out_type->push_back(dtype);  // forward output
-    out_type->push_back(dtype);  // grad output
-    return true;
-  }
-
-  OperatorProperty *Copy() const override {
-    auto ptr = new CTCLossProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override { return "_contrib_CTCLoss"; }
-
-  std::vector<ResourceRequest> ForwardResource(
-      const std::vector<TShape> &in_shape) const override {
-    return {ResourceRequest::kTempSpace};
-  }
-
-  std::vector<int> DeclareBackwardDependency(
-      const std::vector<int> &out_grad, const std::vector<int> &in_data,
-      const std::vector<int> &out_data) const override {
-    return {out_grad[ctc_loss::kOut], out_data[ctc_loss::kGrad]};
-  }
-
-  Operator *CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not Implemented.";
-    return NULL;
-  }
-
-  Operator *CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                             std::vector<int> *in_type) const override;
-
- private:
-  CTCLossParam param_;
-};      // class CTCLossProp
-#endif  // DMLC_USE_CXX11
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_CONTRIB_CTC_LOSS_INL_H_
diff --git a/src/operator/nn/ctc_loss-inl.h b/src/operator/nn/ctc_loss-inl.h
new file mode 100644
index 00000000000..754cf8471b5
--- /dev/null
+++ b/src/operator/nn/ctc_loss-inl.h
@@ -0,0 +1,397 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file ctc_loss-inl.h
+ * \brief CTC Loss operator
+*/
+
+#ifndef MXNET_OPERATOR_NN_CTC_LOSS_INL_H_
+#define MXNET_OPERATOR_NN_CTC_LOSS_INL_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include <algorithm>
+#include <string>
+#include "../mshadow_op.h"
+#include "./sequence_mask-inl.h"
+#include "../sequence_op_common.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+namespace ctc_loss {
+enum CTCLossOpInputs { kData, kLabel };
+enum CTCLossOpOutputs { kOut, kGrad };
+}
+
+template <typename T>
+inline void get_workspace_size(const std::vector<int> *label_lengths,
+                               const std::vector<int> *data_lengths,
+                               int alphabet_size, int minibatch, bool isGPU,
+                               size_t *size_bytes) {
+  // This is the max of all S and T for all examples in the minibatch.
+  int maxL = *std::max_element(label_lengths->data(),
+                               label_lengths->data() + minibatch);
+  int maxT = *std::max_element(data_lengths->data(),
+                               data_lengths->data() + minibatch);
+
+  const int S = 2 * maxL + 1;
+
+  *size_bytes = 0;
+
+  if (isGPU) {
+    // GPU storage
+    // nll_forward, nll_backward
+    *size_bytes += 2 * sizeof(T) * minibatch;
+
+    // repeats
+    *size_bytes += sizeof(int) * minibatch;
+
+    // label offsets
+    *size_bytes += sizeof(int) * minibatch;
+
+    // utt_length
+    *size_bytes += sizeof(int) * minibatch;
+
+    // label lengths
+    *size_bytes += sizeof(int) * minibatch;
+
+    // labels without blanks - overallocate for now
+    *size_bytes += sizeof(int) * maxL * minibatch;
+
+    // labels with blanks
+    *size_bytes += sizeof(int) * S * minibatch;
+
+    // alphas
+    *size_bytes += sizeof(T) * S * maxT * minibatch;
+
+    // denoms
+    *size_bytes += sizeof(T) * maxT * minibatch;
+
+    // probs (since we will pass in activations)
+    *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch;
+
+  } else {
+    // cpu can eventually replace all minibatch with
+    // max number of concurrent threads if memory is
+    // really tight
+
+    // per minibatch memory
+    size_t per_minibatch_bytes = 0;
+
+    // output
+    per_minibatch_bytes += sizeof(T) * alphabet_size;
+
+    // alphas
+    per_minibatch_bytes += sizeof(T) * S * maxT;
+
+    // betas
+    per_minibatch_bytes += sizeof(T) * S;
+
+    // labels w/blanks, e_inc, s_inc
+    per_minibatch_bytes += 3 * sizeof(int) * S;
+
+    *size_bytes = per_minibatch_bytes * minibatch;
+
+    // probs
+    *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch;
+  }
+}
+
+// Takes a tensor of labels, and interprets 0-elements at the end of the vector
+// as padding. The tensor is packed into an std::vector without padding
+// characters. The label sequence lengths are also inferred from the padding chars.
+template <typename DType, typename xpu>
+inline void LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels,
+                                      int padding_mask,
+                                      std::vector<int> *packed_labels,
+                                      std::vector<int> *label_lengths) {
+  int batch = labels.size(0);
+  int max_num_labels = labels.size(1);
+
+  std::vector<int> cpu_labels(max_num_labels * batch);
+  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
+  IndexTensorToVector(flat_labels, &cpu_labels);
+
+  for (int b = 0; b < batch; ++b) {
+    auto start = cpu_labels.data() + b * max_num_labels;
+    auto res = std::find(start, start + max_num_labels, padding_mask);
+    int len = std::distance(start, res);
+    std::copy(start, start + len,
+              std::back_inserter(*packed_labels));
+    label_lengths->at(b) = len;
+  }
+}
+
+// Takes a tensor of labels, and a vector which specifies the actual length of each label
+// The tensor is packed into an std::vector without padding characters.
+// The label length vector is copied into an std::vector.
+template <typename DType, typename xpu>
+inline void PackLabelByLength(mshadow::Tensor<xpu, 2, DType> labels,
+                              mshadow::Tensor<xpu, 1, DType> in_label_lengths,
+                              std::vector<int> *packed_labels,
+                              std::vector<int> *label_lengths) {
+  int batch = labels.size(0);
+  int max_num_labels = labels.size(1);
+
+  IndexTensorToVector(in_label_lengths, label_lengths);
+
+  std::vector<int> cpu_labels(max_num_labels * batch);
+  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
+  IndexTensorToVector(flat_labels, &cpu_labels);
+
+  for (int b = 0; b < batch; ++b) {
+    auto start = cpu_labels.data() + b * max_num_labels;
+    int len = label_lengths->at(b);
+    std::copy(start, start + len,
+              std::back_inserter(*packed_labels));
+  }
+}
+
+struct CTCLossOpParam : public dmlc::Parameter<CTCLossOpParam> {
+  bool use_data_lengths;
+  bool use_label_lengths;
+  int blank_label;
+  DMLC_DECLARE_PARAMETER(CTCLossOpParam) {
+    DMLC_DECLARE_FIELD(use_data_lengths).set_default(false)
+      .describe("Whether the data lenghts are decided by `data_lengths`. "
+                "If false, the lengths are equal to the max sequence length.");
+    DMLC_DECLARE_FIELD(use_label_lengths).set_default(false)
+      .describe("Whether the label lenghts are decided by "
+                "`label_lengths`, or derived from `padding_mask`. "
+                "If false, the lengths are derived from the "
+                "first occurrence of the value of `padding_mask`. "
+                "The value of `padding_mask` is ``0`` when first CTC label is reserved for blank, "
+                "and ``-1`` when last label is reserved for blank. See `blank_label`.");
+    DMLC_DECLARE_FIELD(blank_label)
+      .add_enum("first", 0)
+      .add_enum("last", 1)
+      .set_default(0)
+      .describe("Set the label that is reserved for blank label."
+                "If \"first\", 0-th label is reserved, and "
+                "label values for tokens in the vocabulary are "
+                "between ``1`` and ``alphabet_size-1``, and the padding mask is ``-1``. "
+                "If \"last\", last label value ``alphabet_size-1`` "
+                "is reserved for blank label instead, "
+                "and label values for tokens in the vocabulary are "
+                "between ``0`` and ``alphabet_size-2``, and the padding mask is ``0``.");
+  }
+};
+
+// By default, the inputs must include data array and label array
+// if use_data_lengths parameter is set, user should also supply
+// data_lengths array; if use_label_lengths parameter is set, user
+// should also specify label_lengths array.
+inline uint32_t CTCLossOpNumInputs(const NodeAttrs& attrs) {
+  const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed);
+  return 2U + param.use_data_lengths + param.use_label_lengths;
+}
+
+inline bool CTCLossOpShape(const nnvm::NodeAttrs &attrs,
+                           std::vector<TShape>* in_attrs,
+                           std::vector<TShape>* out_attrs) {
+    const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed);
+    CHECK_EQ(in_attrs->size(), CTCLossOpNumInputs(attrs));
+    CHECK_EQ(out_attrs->size(), 2U);
+
+    const TShape &dshape = (*in_attrs)[ctc_loss::kData];
+    const TShape &lshape = (*in_attrs)[ctc_loss::kLabel];
+    CHECK_EQ(dshape.ndim(), 3U) << "The number of dimensions of data array must be 3.";
+    CHECK_EQ(lshape.ndim(), 2U) << "The number of dimensions of labels array must be 2.";
+    CHECK_EQ(dshape[1], lshape[0])
+        << "The batch size for the labels and data arrays must be the same.";
+
+    if (param.use_data_lengths) {
+      int kInputLength = 2;
+      const TShape &dlshape = (*in_attrs)[kInputLength];
+      CHECK_EQ(dlshape.ndim(), 1U) << "Data length array must be a vector.";
+      CHECK_EQ(dlshape[0], dshape[1])
+          << "The batch size for the data and data lengths must be the same.";
+    }
+    if (param.use_label_lengths) {
+      int kLabelLength = 2 + param.use_data_lengths;
+      const TShape &llshape = (*in_attrs)[kLabelLength];
+      CHECK_EQ(llshape.ndim(), 1U) << "Label length array must be a vector.";
+      CHECK_EQ(llshape[0], lshape[0])
+          << "The batch size for the labels and label lengths must be the same.";
+    }
+    CHECK_GE(dshape[0], lshape[1]) << "The max number of labels cannot exceed "
+                                      "the maximum sequence length of the "
+                                      "data.";
+
+    TShape oshape(1);
+    oshape[0] = dshape[1];  // batch size
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);  // forward output
+    SHAPE_ASSIGN_CHECK(*out_attrs, 1, dshape);  // grad output
+    return true;
+}
+
+inline bool CTCLossOpType(const nnvm::NodeAttrs& attrs,
+                          std::vector<int>* in_attrs,
+                          std::vector<int>* out_attrs) {
+    CHECK_GE(in_attrs->size(), 2U);
+    CHECK_EQ(out_attrs->size(), 2U);
+    int dtype = (*in_attrs)[ctc_loss::kData];
+    CHECK_NE(dtype, -1) << "Input data must have specified type";
+
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));  // forward output
+    TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0));  // grad output
+    return true;
+}
+
+inline bool CTCLossOpStorageType(const nnvm::NodeAttrs& attrs,
+                                 const int dev_mask,
+                                 DispatchMode* dispatch_mode,
+                                 std::vector<int>* in_attrs,
+                                 std::vector<int>* out_attrs) {
+  CHECK_GE(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 2U);
+  const int in_stype = in_attrs->at(0);
+  bool dispatched = false;
+  if (!dispatched && in_stype == kDefaultStorage) {
+    // dns -> dns
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
+}
+
+
+inline std::vector<std::string> CTCLossOpListInputNames(const NodeAttrs& attrs) {
+  const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed);
+  if (param.use_data_lengths && param.use_label_lengths) {
+    return {"data", "label", "data_lengths", "label_lengths"};
+  } else if (param.use_data_lengths) {
+    return {"data", "label", "data_lengths"};
+  } else if (param.use_label_lengths) {
+    return {"data", "label", "label_lengths"};
+  } else {
+    return {"data", "label"};
+  }
+}
+
+template<typename xpu>
+void CTCLossOpForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), CTCLossOpNumInputs(attrs));
+  CHECK_EQ(outputs.size(), 2U);
+  CHECK_EQ(req.size(), 2U);
+
+  const TBlob& in_data = inputs[ctc_loss::kData];
+  const TBlob& in_label = inputs[ctc_loss::kLabel];
+  const TBlob& out_data = outputs[ctc_loss::kOut];
+  const TBlob& out_grad = outputs[ctc_loss::kGrad];
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(inputs[ctc_loss::kLabel].type_flag_, DType, {
+    Tensor<xpu, 3, real_t> data = in_data.get<xpu, 3, real_t>(s);
+    Tensor<xpu, 2, DType> labels = in_label.get<xpu, 2, DType>(s);
+    Tensor<xpu, 1, real_t> costs = out_data.get<xpu, 1, real_t>(s);
+    Tensor<xpu, 3, real_t> grad = out_grad.get<xpu, 3, real_t>(s);
+
+    int max_seq_len = data.size(0);
+    int batch_size = data.size(1);
+    int alphabet_size = data.size(2);
+
+    // data_lengths
+    std::vector<int> data_lengths(batch_size, max_seq_len);
+    if (param.use_data_lengths) {
+      int kInputLength = 2;
+      IndexTensorToVector(inputs[kInputLength].get<xpu, 1, real_t>(s), &data_lengths);
+    }
+
+    // label_lengths
+    std::vector<int> packed_labels;
+    std::vector<int> label_lengths(batch_size);
+
+    if (param.use_label_lengths) {
+      int kLabelLength = 2 + param.use_data_lengths;
+      PackLabelByLength(labels, inputs[kLabelLength].get<xpu, 1, DType>(s),
+                        &packed_labels, &label_lengths);
+    } else {
+      LabelTensorToPackedVector(labels, param.blank_label == 0 ? 0 : -1,
+                                &packed_labels, &label_lengths);
+    }
+
+    size_t size_bytes;
+    get_workspace_size<real_t>(&label_lengths, &data_lengths, alphabet_size,
+                               batch_size, data.kDevCPU ? false : true, &size_bytes);
+
+    // round-up so there are enough elems in memory
+    int num_tmp_elems = (size_bytes + sizeof(real_t) - 1) / sizeof(real_t);
+    Tensor<xpu, 1, real_t> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, real_t>(Shape1(num_tmp_elems), s);
+
+    compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels.data(),
+                     label_lengths.data(), data_lengths.data(),
+                     workspace.dptr_, req[ctc_loss::kGrad] != mxnet::kNullOp,
+                     param.blank_label == 0 ? 0 : (alphabet_size - 1));
+
+    if (param.use_data_lengths) {
+      // baidu warp CTC implementation sometimes includes undefined gradients
+      // for data outside of length mask. Setting to 0 to make it consistent
+      // with CPU implementation.
+      int kInputLength = 2;
+      mxnet_op::SequenceMask(grad, inputs[kInputLength].get<xpu, 1, real_t>(s),
+                             static_cast<real_t>(0));
+    }
+  });
+}
+
+template<typename xpu>
+void CTCLossOpBackward(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& in_grad = outputs[0];
+  const TBlob& out_grad = inputs[0];
+  const TBlob& grad_computed = inputs[3];  // grad computed in the forward step
+
+  Tensor<xpu, 3, real_t> igrad_data = in_grad.get<xpu, 3, real_t>(s);
+  Tensor<xpu, 1, real_t> ograd_data = out_grad.get<xpu, 1, real_t>(s);
+  Tensor<xpu, 3, real_t> computed_grad_data = grad_computed.get<xpu, 3, real_t>(s);
+
+  Assign(igrad_data, req[0],
+         mshadow::expr::broadcast<1>(ograd_data, computed_grad_data.shape_) * computed_grad_data);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NN_CTC_LOSS_INL_H_
+
diff --git a/src/operator/contrib/ctc_loss.cc b/src/operator/nn/ctc_loss.cc
similarity index 64%
rename from src/operator/contrib/ctc_loss.cc
rename to src/operator/nn/ctc_loss.cc
index 32e8e629f09..c381677b3ce 100644
--- a/src/operator/contrib/ctc_loss.cc
+++ b/src/operator/nn/ctc_loss.cc
@@ -18,26 +18,22 @@
  */
 
 /*!
- * Copyright (c) 2015 by Contributors
  * \file ctc_loss.cc
- * \brief
- * \author Sebastian Bodenstein
-*/
-
+ * \brief CPU Implementation of CTC Loss op
+ */
 #include "./ctc_loss-inl.h"
-#include "./ctc_include/detail/cpu_ctc.h"
+#include "../../../3rdparty/ctc_include/detail/cpu_ctc.h"
 
 namespace mshadow {
-
 template <typename DType>
 ctcStatus_t compute_ctc_cost(const Tensor<cpu, 3, DType> activations,
                              DType *costs, DType *grads, int *labels,
                              int *label_lengths, int *data_lengths,
-                             void *workspace, int train, int blank_label) {
+                             void *workspace, bool isTraining, int blank_label) {
   int minibatch = static_cast<int>(activations.size(1));
   int alphabet_size = static_cast<int>(activations.size(2));
   mxnet_warpctc::CpuCTC<DType> ctc(alphabet_size, minibatch, workspace, blank_label);
-  if (train) {
+  if (isTraining) {
     return ctc.cost_and_grad(activations.dptr_, grads, costs, labels,
                              label_lengths, data_lengths);
   } else {
@@ -45,32 +41,18 @@ ctcStatus_t compute_ctc_cost(const Tensor<cpu, 3, DType> activations,
                              data_lengths);
   }
 }
-
 }  // namespace mshadow
 
 namespace mxnet {
 namespace op {
-template <>
-Operator *CreateOp<cpu>(CTCLossParam param, int dtype) {
-  return new CTCLossOp<cpu>(param);
-}
-
-// DO_BIND_DISPATCH comes from operator_common.h
-Operator *CTCLossProp::CreateOperatorEx(Context ctx,
-                                        std::vector<TShape> *in_shape,
-                                        std::vector<int> *in_type) const {
-  std::vector<TShape> out_shape, aux_shape;
-  std::vector<int> out_type, aux_type;
-  CHECK(InferType(in_type, &out_type, &aux_type));
-  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
-}
-
-DMLC_REGISTER_PARAMETER(CTCLossParam);
 
-MXNET_REGISTER_OP_PROPERTY(_contrib_CTCLoss, CTCLossProp)
-    .describe(R"code(Connectionist Temporal Classification Loss.
+DMLC_REGISTER_PARAMETER(CTCLossOpParam);
 
+NNVM_REGISTER_OP(CTCLoss)
+.add_alias("ctc_loss")
+.add_alias("_contrib_CTCLoss")
+.add_alias("_contrib_ctc_loss")
+.describe(R"code(Connectionist Temporal Classification Loss.
 The shapes of the inputs and outputs:
 
 - **data**: `(sequence_length, batch_size, alphabet_size)`
@@ -113,18 +95,41 @@ Sequence Data with Recurrent Neural Networks*, A. Graves *et al*. for more
 information on the definition and the algorithm.
 
 )code" ADD_FILELINE)
-    .add_argument("data", "NDArray-or-Symbol", "Input data to the ctc_loss op.")
-    .add_argument("label", "NDArray-or-Symbol",
-                  "Ground-truth labels for the loss.")
-    .add_argument("data_lengths", "NDArray-or-Symbol",
-                  "Lengths of data for each of the samples. Only required "
-                  "when use_data_lengths is true.")
-    .add_argument("label_lengths", "NDArray-or-Symbol",
-                  "Lengths of labels for each of the samples. Only required "
-                  "when use_label_lengths is true.")
-    .add_arguments(CTCLossParam::__FIELDS__());
-
-NNVM_REGISTER_OP(_contrib_CTCLoss).add_alias("_contrib_ctc_loss");
+.set_attr_parser(ParamParser<CTCLossOpParam>)
+.set_num_inputs(CTCLossOpNumInputs)
+.set_num_outputs(2)
+.set_attr<nnvm::FListInputNames>("FListInputNames", CTCLossOpListInputNames)
+.set_attr<nnvm::FListOutputNames>("FListOutputNAmes",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"out", "grad"};
+  })
+.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
+  [](const NodeAttrs& attrs) {
+    return 1;
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", CTCLossOpShape)
+.set_attr<nnvm::FInferType>("FInferType", CTCLossOpType)
+.set_attr<FInferStorageType>("FInferStorageType", CTCLossOpStorageType)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
+  { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
+.set_attr<FCompute>("FCompute<cpu>", CTCLossOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_ctc_loss"})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_argument("label", "NDArray-or-Symbol", "Ground-truth labels for the loss.")
+.add_argument("data_lengths", "NDArray-or-Symbol",
+              "Lengths of data for each of the samples. Only required "
+              "when use_data_lengths is true.")
+.add_argument("label_lengths", "NDArray-or-Symbol",
+              "Lengths of labels for each of the samples. Only required "
+              "when use_label_lengths is true.")
+.add_arguments(CTCLossOpParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_ctc_loss)
+.set_attr_parser(ParamParser<CTCLossOpParam>)
+.set_num_inputs(1)
+.set_num_outputs(CTCLossOpNumInputs)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", CTCLossOpBackward<cpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/ctc_loss.cu b/src/operator/nn/ctc_loss.cu
similarity index 81%
rename from src/operator/contrib/ctc_loss.cu
rename to src/operator/nn/ctc_loss.cu
index 3f5f12ca439..a4491bf6986 100644
--- a/src/operator/contrib/ctc_loss.cu
+++ b/src/operator/nn/ctc_loss.cu
@@ -18,14 +18,13 @@
  */
 
 /*!
- * Copyright (c) 2015 by Contributors
+ * Copyright (c) 2018 by Contributors
  * \file ctc_loss.cu
- * \brief
- * \author Sebastian Bodenstein
-*/
-#include <algorithm>
+ * \brief GPU Implementation of ctc_loss op
+ */
+
 #include "./ctc_loss-inl.h"
-#include "./ctc_include/detail/gpu_ctc.h"
+#include "../../../3rdparty/ctc_include/detail/gpu_ctc.h"
 
 namespace mshadow {
 
@@ -45,17 +44,19 @@ ctcStatus_t compute_ctc_cost(const Tensor<gpu, 3, DType> activations,
     return ctc.score_forward(activations.dptr_, costs, labels,
                              label_lengths, input_lengths);
 }
-
 }  // namespace mshadow
 
-////////////////////////////////////////////////////////////////////////////////
-
 namespace mxnet {
 namespace op {
-template <>
-Operator *CreateOp<gpu>(CTCLossParam param, int dtype) {
-  return new CTCLossOp<gpu>(param);
-}
+
+NNVM_REGISTER_OP(CTCLoss)
+.add_alias("ctc_loss")
+.add_alias("_contrib_ctc_loss")
+.add_alias("_contrib_CTCLoss")
+.set_attr<FCompute>("FCompute<gpu>", CTCLossOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_ctc_loss)
+.set_attr<FCompute>("FCompute<gpu>", CTCLossOpBackward<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index b17562c1d94..5332517fa68 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4491,11 +4491,34 @@ def test_pick_helper(index_type=np.int32):
 def check_ctc_loss(acts, labels, loss_truth):
     in_var = mx.sym.Variable('input')
     labels_var = mx.sym.Variable('labels')
-    ctc = mx.sym.contrib.ctc_loss(in_var, labels_var)
+    ctc = mx.sym.ctc_loss(in_var, labels_var)
     acts_nd = mx.nd.array(acts, ctx=default_context())
     labels_nd = mx.nd.array(labels, ctx=default_context())
     exe = ctc.bind(ctx=default_context(), args=[acts_nd, labels_nd])
+    # test forward with grad calc
+    exe.forward(is_train=True)
+    outTest = exe.outputs[0]
     # test forward without grad calc
+    exe.forward(is_train=False)
+    outTrain = exe.outputs[0]
+    # make sure losses calculated with both modes are the same
+    assert_almost_equal(outTest.asnumpy(), outTrain.asnumpy())
+
+    # test against ground truth, if available
+    if loss_truth is not None:
+        assert_almost_equal(outTest.asnumpy(), loss_truth)
+    # test grad
+    check_numeric_gradient(ctc, [acts, labels], grad_nodes=['input'], rtol=0.05, atol=1e-3)
+
+# check contrib operator for backward compatibility
+def check_contrib_ctc_loss(acts, labels, loss_truth):
+    in_var = mx.sym.Variable('input')
+    labels_var = mx.sym.Variable('labels')
+    ctc = mx.sym.contrib.ctc_loss(in_var, labels_var)
+    acts_nd = mx.nd.array(acts, ctx=default_context())
+    labels_nd = mx.nd.array(labels, ctx=default_context())
+    exe = ctc.bind(ctx=default_context(), args=[acts_nd, labels_nd])
+    # test forward with grad calc
     exe.forward(is_train=True)
     outTest = exe.outputs[0]
     # test forward without grad calc
@@ -4503,6 +4526,7 @@ def check_ctc_loss(acts, labels, loss_truth):
     outTrain = exe.outputs[0]
     # make sure losses calculated with both modes are the same
     assert_almost_equal(outTest.asnumpy(), outTrain.asnumpy())
+
     # test against ground truth, if available
     if loss_truth is not None:
         assert_almost_equal(outTest.asnumpy(), loss_truth)
@@ -4520,6 +4544,8 @@ def test_ctc_loss():
     labels = np.array([[2, 3, 0], [2, 3, 0]])
     true_loss = np.array([4.04789, 4.04789], dtype=np.float32) # from Torch
     check_ctc_loss(acts, labels, true_loss)
+    check_contrib_ctc_loss(acts, labels, true_loss)
+
     # Test 2:
     acts2 = np.array([
         [[-5, -4, -3, -2, -1], [1.2, 3.4, 1.2, -0.1, -2.34]],
@@ -4528,11 +4554,13 @@ def test_ctc_loss():
     labels2 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.float32)
     true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch
     check_ctc_loss(acts2, labels2, true_loss)
+    check_contrib_ctc_loss(acts2, labels2, true_loss)
 
     # Test 3: check use integer type as label
     labels3 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.int32)
     true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch
     check_ctc_loss(acts2, labels3, true_loss)
+    check_contrib_ctc_loss(acts2, labels3, true_loss)
 
 @with_seed()
 def test_ctc_loss_with_large_classes():
@@ -4550,7 +4578,7 @@ def test_ctc_loss_with_large_classes():
         [1000, 2000, 3000, 4000, 0, 5000, 0, 0]], dtype=np.int32)
     nd_data = mx.nd.array(data)
     nd_label = mx.nd.array(label)
-    loss = mx.nd.contrib.ctc_loss(data=nd_data, label=nd_label)
+    loss = mx.nd.ctc_loss(data=nd_data, label=nd_label)
     expected_loss = np.array([688.02826, 145.34462])
     assert_almost_equal(loss.asnumpy(), expected_loss)
 
@@ -4619,6 +4647,85 @@ def check_ctc_loss_grad(blank_label): # from tf
         label_lens = np.array([5, 4], dtype=np.int32)
         loss_truth = np.array([-loss_log_prob_0, -loss_log_prob_1], np.float32)
 
+        with default_context():
+            data = mx.nd.array(inputs)
+            label = mx.nd.array(labels)
+            data.attach_grad()
+            with mx.autograd.record():
+                l = mx.ndarray.CTCLoss(data, label,
+                                       use_data_lengths=True,
+                                       use_label_lengths=True,
+                                       data_lengths=mx.nd.array(seq_lens),
+                                       label_lengths=mx.nd.array(label_lens),
+                                       blank_label=blank_label)
+                l.backward()
+            assert_almost_equal(l.asnumpy(), loss_truth, atol=1e-5, rtol=1e-5)
+            assert_almost_equal(data.grad.asnumpy(), grad_truth, atol=1e-5, rtol=1e-5)
+
+    # check contrib operator for backward compatibility
+    def check_contrib_ctc_loss_grad(blank_label): # from tf
+        vocab_size = 5
+        max_label_len = 5
+        padding_mask = -1+ (blank_label=='first')
+
+        targets_0 = [0, 1, 2, 1, 0]
+        loss_log_prob_0 = -3.34211
+        input_prob_matrix_0 = np.asarray(
+            [[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
+             [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
+             [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
+             [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
+             [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
+            dtype=np.float32)
+        gradient_log_prob_0 = np.asarray(
+            [[-0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
+             [0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436],
+             [0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688],
+             [0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533],
+             [-0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
+            dtype=np.float32)
+
+        targets_1 = [0, 1, 1, 0]
+        loss_log_prob_1 = -5.42262
+        input_prob_matrix_1 = np.asarray(
+            [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
+             [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
+             [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
+             [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
+             [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]],
+            dtype=np.float32)
+        gradient_log_prob_1 = np.asarray(
+            [[-0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
+             [0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549],
+             [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544],
+             [0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345],
+             [-0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]],
+            dtype=np.float32)
+
+        inputs = [
+            np.vstack(
+                [input_prob_matrix_0[t, :], input_prob_matrix_1[t, :]])
+            for t in range(5)
+        ] + 2 * [np.nan * np.ones((2, vocab_size+1), np.float32)]
+        inputs = np.log(np.asarray(inputs, dtype=np.float32))
+
+        grad_truth = np.array([
+            np.vstack(
+                [gradient_log_prob_0[t, :], gradient_log_prob_1[t, :]])
+            for t in range(5)
+        ] + 2 * [np.zeros((2, vocab_size+1), np.float32)])
+
+        if blank_label == 'first':
+            inputs = np.roll(inputs, 1, axis=2)
+            grad_truth = np.roll(grad_truth, 1, axis=2)
+
+        labels = (np.asarray([x + [padding_mask]*(max_label_len-len(x))
+                             for x in [targets_0, targets_1]])+(blank_label == 'first'))
+
+        seq_lens = np.array([5, 5], dtype=np.int32)
+        label_lens = np.array([5, 4], dtype=np.int32)
+        loss_truth = np.array([-loss_log_prob_0, -loss_log_prob_1], np.float32)
+
         with default_context():
             data = mx.nd.array(inputs)
             label = mx.nd.array(labels)
@@ -4634,8 +4741,11 @@ def check_ctc_loss_grad(blank_label): # from tf
             assert_almost_equal(l.asnumpy(), loss_truth, atol=1e-5, rtol=1e-5)
             assert_almost_equal(data.grad.asnumpy(), grad_truth, atol=1e-5, rtol=1e-5)
 
+
     check_ctc_loss_grad('first')
     check_ctc_loss_grad('last')
+    check_contrib_ctc_loss_grad('first')
+    check_contrib_ctc_loss_grad('last')
 
 
 @with_seed()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services