Posted to commits@mxnet.apache.org by an...@apache.org on 2018/10/09 01:37:15 UTC

[incubator-mxnet] branch master updated: [MXNET-912] Refactoring ctc loss operator (#12637)

This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a2bafe1  [MXNET-912] Refactoring ctc loss operator (#12637)
a2bafe1 is described below

commit a2bafe177940e9966949c6f63cda421f10b22b19
Author: Lin Yuan <ap...@gmail.com>
AuthorDate: Mon Oct 8 18:36:58 2018 -0700

    [MXNET-912] Refactoring ctc loss operator (#12637)
    
    * Implement ctc_loss as a normal operator
    
    * Update unit test
    
    * Update unit test and fix bug in backward
    
    * fix lint error
    
    * refactoring
    
    * Fix compilation error in CUDA
    
    * Fix CPU compilation error
    
    * Move ctc_include to nn folder and refactor
    
    * temporarily disable lint on 3rd party includes
    
    * move ctc_include to 3rdparty
    
    * remove contrib ctc_loss operator
    
    * revert a change by mistake
    
    * Fix a bug in kDevCPU
    
    * revert change by mistake
    
    * add alias to make it backward compatible
    
    * add unit test for backward compatibility
    
    * linting
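
The alias work in the last two items keeps the old contrib entry points usable while the operator now lives in the main namespace. A minimal sketch of the resulting call paths (hypothetical shapes; assumes a build containing this commit):

    import mxnet as mx

    # activations: (sequence_length, batch_size, alphabet_size); labels: (batch_size, label_length)
    acts = mx.nd.random.uniform(shape=(20, 2, 5))
    labels = mx.nd.array([[1, 2, 3], [3, 2, 1]])

    new_loss = mx.nd.ctc_loss(acts, labels)          # operator registered by this commit
    old_loss = mx.nd.contrib.ctc_loss(acts, labels)  # still works through the added alias
    print(new_loss.asnumpy(), old_loss.asnumpy())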
---
 .../contrib => 3rdparty}/ctc_include/LICENSE       |   0
 .../ctc_include/contrib/moderngpu/LICENSE          |   0
 .../moderngpu/include/device/ctaloadbalance.cuh    |   0
 .../contrib/moderngpu/include/device/ctamerge.cuh  |   0
 .../contrib/moderngpu/include/device/ctascan.cuh   |   0
 .../contrib/moderngpu/include/device/ctasearch.cuh |   0
 .../moderngpu/include/device/ctasegreduce.cuh      |   0
 .../moderngpu/include/device/ctasegscan.cuh        |   0
 .../moderngpu/include/device/ctasegsort.cuh        |   0
 .../moderngpu/include/device/ctasortedsearch.cuh   |   0
 .../moderngpu/include/device/devicetypes.cuh       |   0
 .../moderngpu/include/device/deviceutil.cuh        |   0
 .../moderngpu/include/device/intrinsics.cuh        |   0
 .../contrib/moderngpu/include/device/loadstore.cuh |   0
 .../moderngpu/include/device/serialsets.cuh        |   0
 .../moderngpu/include/device/sortnetwork.cuh       |   0
 .../contrib/moderngpu/include/mgpudevice.cuh       |   0
 .../contrib/moderngpu/include/mgpuenums.h          |   0
 .../contrib/moderngpu/include/util/static.h        |   0
 .../ctc_include/detail/cpu_ctc.h                   |   0
 .../ctc_include/detail/ctc_helper.h                |   0
 .../ctc_include/detail/gpu_ctc.h                   |   0
 .../ctc_include/detail/gpu_ctc_kernels.h           |   0
 .../ctc_include/detail/hostdevice.h                |   0
 python/mxnet/gluon/loss.py                         |   8 +-
 src/operator/contrib/ctc_loss-inl.h                | 591 ---------------------
 src/operator/nn/ctc_loss-inl.h                     | 397 ++++++++++++++
 src/operator/{contrib => nn}/ctc_loss.cc           |  87 +--
 src/operator/{contrib => nn}/ctc_loss.cu           |  27 +-
 tests/python/unittest/test_operator.py             | 114 +++-
 30 files changed, 573 insertions(+), 651 deletions(-)

diff --git a/src/operator/contrib/ctc_include/LICENSE b/3rdparty/ctc_include/LICENSE
similarity index 100%
rename from src/operator/contrib/ctc_include/LICENSE
rename to 3rdparty/ctc_include/LICENSE
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/LICENSE b/3rdparty/ctc_include/contrib/moderngpu/LICENSE
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/LICENSE
rename to 3rdparty/ctc_include/contrib/moderngpu/LICENSE
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctaloadbalance.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctamerge.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctascan.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctascan.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctascan.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctascan.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasearch.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegreduce.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegscan.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasegsort.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/ctasortedsearch.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/devicetypes.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/deviceutil.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/intrinsics.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/loadstore.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/loadstore.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/loadstore.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/loadstore.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/serialsets.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/serialsets.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/serialsets.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/serialsets.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/device/sortnetwork.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpudevice.cuh b/3rdparty/ctc_include/contrib/moderngpu/include/mgpudevice.cuh
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpudevice.cuh
rename to 3rdparty/ctc_include/contrib/moderngpu/include/mgpudevice.cuh
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpuenums.h b/3rdparty/ctc_include/contrib/moderngpu/include/mgpuenums.h
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/mgpuenums.h
rename to 3rdparty/ctc_include/contrib/moderngpu/include/mgpuenums.h
diff --git a/src/operator/contrib/ctc_include/contrib/moderngpu/include/util/static.h b/3rdparty/ctc_include/contrib/moderngpu/include/util/static.h
similarity index 100%
rename from src/operator/contrib/ctc_include/contrib/moderngpu/include/util/static.h
rename to 3rdparty/ctc_include/contrib/moderngpu/include/util/static.h
diff --git a/src/operator/contrib/ctc_include/detail/cpu_ctc.h b/3rdparty/ctc_include/detail/cpu_ctc.h
similarity index 100%
rename from src/operator/contrib/ctc_include/detail/cpu_ctc.h
rename to 3rdparty/ctc_include/detail/cpu_ctc.h
diff --git a/src/operator/contrib/ctc_include/detail/ctc_helper.h b/3rdparty/ctc_include/detail/ctc_helper.h
similarity index 100%
rename from src/operator/contrib/ctc_include/detail/ctc_helper.h
rename to 3rdparty/ctc_include/detail/ctc_helper.h
diff --git a/src/operator/contrib/ctc_include/detail/gpu_ctc.h b/3rdparty/ctc_include/detail/gpu_ctc.h
similarity index 100%
rename from src/operator/contrib/ctc_include/detail/gpu_ctc.h
rename to 3rdparty/ctc_include/detail/gpu_ctc.h
diff --git a/src/operator/contrib/ctc_include/detail/gpu_ctc_kernels.h b/3rdparty/ctc_include/detail/gpu_ctc_kernels.h
similarity index 100%
rename from src/operator/contrib/ctc_include/detail/gpu_ctc_kernels.h
rename to 3rdparty/ctc_include/detail/gpu_ctc_kernels.h
diff --git a/src/operator/contrib/ctc_include/detail/hostdevice.h b/3rdparty/ctc_include/detail/hostdevice.h
similarity index 100%
rename from src/operator/contrib/ctc_include/detail/hostdevice.h
rename to 3rdparty/ctc_include/detail/hostdevice.h
diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 2be4398..7e4d345 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -468,10 +468,10 @@ class CTCLoss(Loss):
             pred = F.swapaxes(pred, 0, 1)
         if self._batch_axis == 1:
             label = F.swapaxes(label, 0, 1)
-        loss = F.contrib.CTCLoss(pred, label, pred_lengths, label_lengths,
-                                 use_data_lengths=pred_lengths is not None,
-                                 use_label_lengths=label_lengths is not None,
-                                 blank_label='last')
+        loss = F.CTCLoss(pred, label, pred_lengths, label_lengths,
+                         use_data_lengths=pred_lengths is not None,
+                         use_label_lengths=label_lengths is not None,
+                         blank_label='last')
         return _apply_weighting(F, loss, self._weight, sample_weight)
 
 
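With the hunk above, gluon.loss.CTCLoss dispatches to the core CTCLoss operator instead of the contrib one; the user-facing Gluon API is unchanged. A usage sketch with made-up shapes, relying on the block's default NTC/NT layouts:

    from mxnet import nd
    from mxnet.gluon.loss import CTCLoss

    loss_fn = CTCLoss()                            # defaults: layout='NTC', label_layout='NT'
    pred = nd.random.uniform(shape=(2, 20, 5))     # (batch_size, sequence_length, alphabet_size)
    label = nd.array([[1, 2, 3], [3, 2, 1]])       # (batch_size, label_length)
    print(loss_fn(pred, label))                    # per-sample loss, shape (batch_size,)
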
diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h
deleted file mode 100644
index c8a8b26..0000000
--- a/src/operator/contrib/ctc_loss-inl.h
+++ /dev/null
@@ -1,591 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2016 by Contributors
- * \file ctc_loss-inl.h
- * \brief
- * \author Sebastian Bodenstien
-*/
-
-#ifndef MXNET_OPERATOR_CONTRIB_CTC_LOSS_INL_H_
-#define MXNET_OPERATOR_CONTRIB_CTC_LOSS_INL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <algorithm>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include <ctime>
-#include <cstring>
-#include <iostream>
-#include "../operator_common.h"
-#include "../sequence_op_common.h"
-#include "../mshadow_op.h"
-#include "../nn/sequence_mask-inl.h"
-
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-#define CUDNN_LABEL_LENGTH_LIMIT 256
-#include "../nn/softmax-inl.h"
-#endif  // CUDNN
-
-namespace mxnet {
-namespace op {
-
-namespace ctc_loss {
-enum CTCLossOpInputs { kData, kLabel };
-enum CTCLossOpOutputs { kOut, kGrad };
-enum CTCLossOpForwardResource { kTempSpace };
-}
-
-template <typename T>
-inline void get_workspace_size(std::vector<int> *label_lengths,
-                               std::vector<int> *data_lengths,
-                               int alphabet_size, int minibatch, bool gpu,
-                               size_t *size_bytes) {
-  // This is the max of all S and T for all examples in the minibatch.
-  int maxL = *std::max_element(label_lengths->data(),
-                               label_lengths->data() + minibatch);
-  int maxT = *std::max_element(data_lengths->data(),
-                               data_lengths->data() + minibatch);
-
-  const int S = 2 * maxL + 1;
-
-  *size_bytes = 0;
-
-  if (gpu) {
-    // GPU storage
-    // nll_forward, nll_backward
-    *size_bytes += 2 * sizeof(T) * minibatch;
-
-    // repeats
-    *size_bytes += sizeof(int) * minibatch;
-
-    // label offsets
-    *size_bytes += sizeof(int) * minibatch;
-
-    // utt_length
-    *size_bytes += sizeof(int) * minibatch;
-
-    // label lengths
-    *size_bytes += sizeof(int) * minibatch;
-
-    // labels without blanks - overallocate for now
-    *size_bytes += sizeof(int) * maxL * minibatch;
-
-    // labels with blanks
-    *size_bytes += sizeof(int) * S * minibatch;
-
-    // alphas
-    *size_bytes += sizeof(T) * S * maxT * minibatch;
-
-    // denoms
-    *size_bytes += sizeof(T) * maxT * minibatch;
-
-    // probs (since we will pass in activations)
-    *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch;
-
-  } else {
-    // cpu can eventually replace all minibatch with
-    // max number of concurrent threads if memory is
-    // really tight
-
-    // per minibatch memory
-    size_t per_minibatch_bytes = 0;
-
-    // output
-    per_minibatch_bytes += sizeof(T) * alphabet_size;
-
-    // alphas
-    per_minibatch_bytes += sizeof(T) * S * maxT;
-
-    // betas
-    per_minibatch_bytes += sizeof(T) * S;
-
-    // labels w/blanks, e_inc, s_inc
-    per_minibatch_bytes += 3 * sizeof(int) * S;
-
-    *size_bytes = per_minibatch_bytes * minibatch;
-
-    // probs
-    *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch;
-  }
-}
-
-// Takes a tensor of labels, and interprets 0-elements at the end of the vector
-// as padding. The tensor is packed into an std::vector without padding
-// characters. The label sequence lengths are also inferred from the padding chars.
-// When cudnn is enabled, the return value signifies whether the cudnn length limit is exceeded.
-template <typename DType, typename xpu>
-inline bool LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels,
-                                      int padding_mask,
-                                      std::vector<int> *packed_labels,
-                                      std::vector<int> *label_lengths) {
-  int batch = labels.size(0);
-  int max_num_labels = labels.size(1);
-  bool exceed_limit = false;
-
-  std::vector<int> cpu_labels(max_num_labels*batch);
-  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
-  IndexTensorToVector(flat_labels, &cpu_labels);
-
-  for (int b = 0; b < batch; ++b) {
-    auto start = cpu_labels.data()+b*max_num_labels;
-    auto res = std::find(start, start+max_num_labels, padding_mask);
-    int len = std::distance(start, res);
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-    exceed_limit = exceed_limit || len > CUDNN_LABEL_LENGTH_LIMIT;
-#endif
-    std::copy(start, start + len,
-              std::back_inserter(*packed_labels));
-    label_lengths->at(b) = len;
-  }
-  return exceed_limit;
-}
-
-// Takes a tensor of labels, and a vector which specifies the actual length of each label
-// The tensor is packed into an std::vector without padding characters.
-// The label length vector is copied into an std::vector.
-// When cudnn is enabled, the return value signifies whether the cudnn length limit is exceeded.
-template <typename DType, typename xpu>
-inline bool PackLabelByLength(mshadow::Tensor<xpu, 2, DType> labels,
-                              mshadow::Tensor<xpu, 1, DType> in_label_lengths,
-                              std::vector<int> *packed_labels,
-                              std::vector<int> *label_lengths) {
-  int batch = labels.size(0);
-  int max_num_labels = labels.size(1);
-  bool exceed_limit = false;
-
-  IndexTensorToVector(in_label_lengths, label_lengths);
-
-  std::vector<int> cpu_labels(max_num_labels*batch);
-  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
-  IndexTensorToVector(flat_labels, &cpu_labels);
-
-  for (int b = 0; b < batch; ++b) {
-    auto start = cpu_labels.data()+b*max_num_labels;
-    int len = label_lengths->at(b);
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-    exceed_limit = exceed_limit || len > CUDNN_LABEL_LENGTH_LIMIT;
-#endif
-    std::copy(start, start + len,
-              std::back_inserter(*packed_labels));
-  }
-  return exceed_limit;
-}
-
-struct CTCLossParam : public dmlc::Parameter<CTCLossParam> {
-  bool use_data_lengths;
-  bool use_label_lengths;
-  int blank_label;
-  DMLC_DECLARE_PARAMETER(CTCLossParam) {
-    DMLC_DECLARE_FIELD(use_data_lengths).set_default(false)
-      .describe("Whether the data lenghts are decided by `data_lengths`. "
-                "If false, the lengths are equal to the max sequence length.");
-    DMLC_DECLARE_FIELD(use_label_lengths).set_default(false)
-      .describe("Whether the label lenghts are decided by "
-                "`label_lengths`, or derived from `padding_mask`. "
-                "If false, the lengths are derived from the "
-                "first occurrence of the value of `padding_mask`. "
-                "The value of `padding_mask` is ``0`` when first CTC label is reserved for blank, "
-                "and ``-1`` when last label is reserved for blank. See `blank_label`.");
-    DMLC_DECLARE_FIELD(blank_label)
-      .add_enum("first", 0)
-      .add_enum("last", 1)
-      .set_default(0)
-      .describe("Set the label that is reserved for blank label."
-                "If \"first\", 0-th label is reserved, and "
-                "label values for tokens in the vocabulary are "
-                "between ``1`` and ``alphabet_size-1``, and the padding mask is ``-1``. "
-                "If \"last\", last label value ``alphabet_size-1`` "
-                "is reserved for blank label instead, "
-                "and label values for tokens in the vocabulary are "
-                "between ``0`` and ``alphabet_size-2``, and the padding mask is ``0``.");
-  }
-};
-
-template <typename xpu>
-class CTCLossOp : public Operator {
- public:
-  explicit CTCLossOp(CTCLossParam p) {
-    this->param_ = p;
-    exceed_cudnn_limit = false;
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-    CUDNN_CALL(cudnnCreateCTCLossDescriptor(&ctc_desc_));
-    CUDNN_CALL(cudnnSetCTCLossDescriptor(ctc_desc_, CUDNN_DATA_FLOAT));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&prob_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&grad_desc_));
-#endif
-  }
-
-  ~CTCLossOp() {
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-    CUDNN_CALL(cudnnDestroyCTCLossDescriptor(ctc_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(prob_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(grad_desc_));
-#endif
-  }
-
-  virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 2U+param_.use_data_lengths+param_.use_label_lengths);
-    CHECK_EQ(out_data.size(), 2U);
-    exceed_cudnn_limit = false;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-
-    MSHADOW_TYPE_SWITCH(in_data[ctc_loss::kLabel].type_flag_, DType, {
-      Tensor<xpu, 3, real_t> data =
-        in_data[ctc_loss::kData].get<xpu, 3, real_t>(s);
-      Tensor<xpu, 2, DType> labels =
-        in_data[ctc_loss::kLabel].get<xpu, 2, DType>(s);
-
-      Tensor<xpu, 1, real_t> costs =
-        out_data[ctc_loss::kOut].get<xpu, 1, real_t>(s);
-      Tensor<xpu, 3, real_t> grad =
-        out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);
-
-      int max_seq_len = data.size(0);
-      int batch_size = data.size(1);
-      int alphabet_size = data.size(2);
-
-      // data_lengths
-      std::vector<int> data_lengths(batch_size, max_seq_len);
-      if (param_.use_data_lengths) {
-        int kInputLength = 2;
-        IndexTensorToVector(in_data[kInputLength].get<xpu, 1, real_t>(s), &data_lengths);
-      }
-
-      // label_lengths
-      std::vector<int> packed_labels;
-      std::vector<int> label_lengths(batch_size);
-
-      if (param_.use_label_lengths) {
-        int kLabelLength = 2 + param_.use_data_lengths;
-        exceed_cudnn_limit =
-          PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, DType>(s),
-                           &packed_labels, &label_lengths);
-      } else {
-        exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0 ? 0 : -1,
-                                                      &packed_labels, &label_lengths);
-      }
-
-      // CUDNN is disabled due to lack of support for input lengths
-      /* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */
-      /*     if (!exceed_cudnn_limit) { */
-      /*       cudnn_forward(ctx, s, data, costs, grad, */
-      /*                     &data_lengths, &label_lengths, &packed_labels, */
-      /*                     max_seq_len, batch_size, alphabet_size, */
-      /*                     req[ctc_loss::kGrad] != mxnet::kNullOp); */
-      /*     } else { */
-      /*       baidu_forward(ctx, s, data, costs, grad, */
-      /*                     &data_lengths, &label_lengths, &packed_labels, */
-      /*                     batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);*/
-      /*     } */
-      /* #else */
-
-      baidu_forward(ctx, s, data, costs, grad,
-                    &data_lengths, &label_lengths, &packed_labels,
-                    batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);
-
-      if (param_.use_data_lengths) {
-        // baidu warp CTC implementation sometimes includes undefined gradients
-        // for data outside of length mask. Setting to 0 to make it consistent
-        // with CPU implementation.
-        int kInputLength = 2;
-        mxnet_op::SequenceMask(grad, in_data[kInputLength].get<xpu, 1, real_t>(s),
-                              static_cast<real_t>(0));
-      }
-    });
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-
-    Tensor<xpu, 3, real_t> data_grad =
-        in_grad[ctc_loss::kData].get<xpu, 3, real_t>(s);
-    Tensor<xpu, 1, real_t> output_grad =
-        out_grad[ctc_loss::kOut].get<xpu, 1, real_t>(s);
-
-    Tensor<xpu, 3, real_t> data_grad_computed =
-        out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);
-
-    Assign(data_grad, req[ctc_loss::kData],
-           mshadow::expr::broadcast<1>(output_grad, data_grad.shape_) * data_grad_computed);
-  }
-
- private:
-  CTCLossParam param_;
-  bool exceed_cudnn_limit;
-
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-  cudnnDataType_t dtype_;
-  cudnnCTCLossDescriptor_t ctc_desc_;
-  cudnnTensorDescriptor_t prob_desc_, grad_desc_;
-
-  inline virtual void cudnn_forward(const OpContext &ctx,
-                                    mshadow::Stream<xpu>* s,
-                                    mshadow::Tensor<xpu, 3, real_t> data,
-                                    mshadow::Tensor<xpu, 1, real_t> costs,
-                                    mshadow::Tensor<xpu, 3, real_t> grad,
-                                    std::vector<int>* data_lengths,
-                                    std::vector<int>* label_lengths,
-                                    std::vector<int>* packed_labels,
-                                    int max_seq_len,
-                                    int batch_size,
-                                    int alphabet_size,
-                                    bool req_grad) {
-    using namespace mshadow;
-
-    // call cudnn to calculate ctc loss
-    dtype_ = CUDNN_DATA_FLOAT;
-    int dims[3], strides[3];
-    size_t workspace_bytes;
-    int workspace_size;
-    dims[0] = max_seq_len;
-    dims[1] = batch_size;
-    dims[2] = alphabet_size;
-    strides[0] = batch_size*alphabet_size;
-    strides[1] = alphabet_size;
-    strides[2] = 1;
-    cudnnCTCLossAlgo_t ctc_algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(prob_desc_,
-                                          dtype_,
-                                          3,
-                                          dims,
-                                          strides));
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(grad_desc_,
-                                          dtype_,
-                                          3,
-                                          dims,
-                                          strides));
-    CUDNN_CALL(cudnnGetCTCLossWorkspaceSize(s->dnn_handle_,
-                                            prob_desc_,
-                                            req_grad?grad_desc_:NULL,
-                                            packed_labels->data(),
-                                            label_lengths->data(),
-                                            data_lengths->data(),
-                                            ctc_algo,
-                                            ctc_desc_,
-                                            &workspace_bytes));
-    workspace_size = (workspace_bytes + sizeof(real_t) - 1)/sizeof(real_t);
-
-    Tensor<xpu, 1, real_t> temp_space =
-      ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>(
-          mshadow::Shape1(workspace_size+data.shape_.FlatTo1D()[0]), s);
-
-    Tensor<gpu, 1, real_t> work_space(temp_space.dptr_,
-                                      mshadow::Shape1(workspace_size), s);
-    Tensor<xpu, 3, real_t> prob(temp_space.dptr_+workspace_size,
-                                data.shape_, s);
-
-    // since the input is activation before softmax and cudnn ctc takes softmax
-    // apply softmax to inputs first.
-    mxnet_op::Softmax<mxnet_op::softmax_fwd, false>(
-      s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0);
-
-    CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_,
-                            prob_desc_,
-                            prob.dptr_,
-                            packed_labels->data(),
-                            label_lengths->data(),
-                            data_lengths->data(),
-                            costs.dptr_,
-                            req_grad?grad_desc_:NULL,
-                            req_grad?grad.dptr_:NULL,
-                            ctc_algo,
-                            ctc_desc_,
-                            work_space.dptr_,
-                            workspace_bytes));
-
-    if (req_grad) {
-      mxnet_op::SoftmaxGrad<mshadow_op::mul, mxnet_op::softmax_bwd, kWriteTo, false>(
-        s, prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0);
-      Assign(grad, mxnet::kWriteInplace, grad * alphabet_size);
-    }
-  }
-#endif  // __CUDACC__ && CUDNN
-
-  inline void baidu_forward(const OpContext &ctx,
-                            mshadow::Stream<xpu>* s,
-                            mshadow::Tensor<xpu, 3, real_t> data,
-                            mshadow::Tensor<xpu, 1, real_t> costs,
-                            mshadow::Tensor<xpu, 3, real_t> grad,
-                            std::vector<int>* data_lengths,
-                            std::vector<int>* label_lengths,
-                            std::vector<int>* packed_labels,
-                            int batch_size,
-                            int alphabet_size,
-                            bool req_grad) {
-    using namespace mshadow;
-    // allocate temporary workspace
-    size_t size_bytes;
-    bool gpu = data.kDevCPU ? false : true;
-    get_workspace_size<real_t>(label_lengths, data_lengths, alphabet_size,
-                               batch_size, gpu, &size_bytes);
-
-    // round-up so there are enough elems in memory
-    int num_tmp_elems = (size_bytes + sizeof(real_t) - 1) / sizeof(real_t);
-    Tensor<xpu, 1, real_t> workspace =
-        ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>(
-            Shape1(num_tmp_elems), s);
-
-    compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels->data(),
-                     label_lengths->data(), data_lengths->data(),
-                     workspace.dptr_, req_grad,
-                     param_.blank_label == 0 ? 0 : (alphabet_size-1));
-  }
-};  // class CTCLossOp
-
-template <typename xpu>
-Operator *CreateOp(CTCLossParam param, int dtype);
-
-#if DMLC_USE_CXX11
-class CTCLossProp : public OperatorProperty {
- public:
-  int NumVisibleOutputs() const override { return 1; }
-
-  int NumOutputs() const override { return 2; }
-
-  std::vector<std::string> ListArguments() const override {
-    if (param_.use_data_lengths && param_.use_label_lengths) {
-      return {"data", "label", "data_lengths", "label_lengths"};
-    } else if (param_.use_data_lengths) {
-      return {"data", "label", "data_lengths"};
-    } else if (param_.use_label_lengths) {
-      return {"data", "label", "label_lengths"};
-    } else {
-      return {"data", "label"};
-    }
-  }
-
-  std::vector<std::string> ListOutputs() const override {
-    return {"output", "grad"};
-  }
-
-  void Init(const std::vector<std::pair<std::string, std::string>> &kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    index_t expected_inputs = 2+param_.use_data_lengths+param_.use_label_lengths;
-    CHECK_EQ(in_shape->size(), expected_inputs)
-        << "Expect " << expected_inputs << " inputs to the symbol.";
-
-    const TShape &dshape = (*in_shape)[ctc_loss::kData];
-    const TShape &lshape = (*in_shape)[ctc_loss::kLabel];
-    CHECK_EQ(dshape.ndim(), 3U) << "The data array must be of rank 3.";
-    CHECK_EQ(lshape.ndim(), 2U) << "The labels array must be of rank 2.";
-    CHECK_EQ(dshape[1], lshape[0])
-        << "The batch size for the labels and data arrays must be the same.";
-    if (param_.use_data_lengths) {
-      int kInputLength = 2;
-      const TShape &dlshape = (*in_shape)[kInputLength];
-      CHECK_EQ(dlshape.ndim(), 1U) << "Data length array must be a vector.";
-      CHECK_EQ(dlshape[0], dshape[1])
-          << "The batch size for the data and data lengths must be the same.";
-    }
-    if (param_.use_label_lengths) {
-      int kLabelLength = 2+param_.use_data_lengths;
-      const TShape &llshape = (*in_shape)[kLabelLength];
-      CHECK_EQ(llshape.ndim(), 1U) << "Label length array must be a vector.";
-      CHECK_EQ(llshape[0], lshape[0])
-          << "The batch size for the labels and label lengths must be the same.";
-    }
-
-    CHECK_GE(dshape[0], lshape[1]) << "The max number of labels cannot exceed "
-                                      "the maximum sequence length of the "
-                                      "data.";
-
-    TShape oshape(1);
-    oshape[0] = dshape[1];  // batch size
-    out_shape->clear();
-    out_shape->push_back(oshape);  // forward output
-    out_shape->push_back(dshape);  // grad output
-    return true;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-    std::vector<int> *out_type,
-    std::vector<int> *aux_type) const override {
-    CHECK_LE(in_type->size(), this->ListArguments().size());
-    int dtype = (*in_type)[ctc_loss::kData];
-    CHECK_NE(dtype, -1) << "Input data must have specified type";
-
-    out_type->clear();
-    out_type->push_back(dtype);  // forward output
-    out_type->push_back(dtype);  // grad output
-    return true;
-  }
-
-  OperatorProperty *Copy() const override {
-    auto ptr = new CTCLossProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override { return "_contrib_CTCLoss"; }
-
-  std::vector<ResourceRequest> ForwardResource(
-      const std::vector<TShape> &in_shape) const override {
-    return {ResourceRequest::kTempSpace};
-  }
-
-  std::vector<int> DeclareBackwardDependency(
-      const std::vector<int> &out_grad, const std::vector<int> &in_data,
-      const std::vector<int> &out_data) const override {
-    return {out_grad[ctc_loss::kOut], out_data[ctc_loss::kGrad]};
-  }
-
-  Operator *CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not Implemented.";
-    return NULL;
-  }
-
-  Operator *CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                             std::vector<int> *in_type) const override;
-
- private:
-  CTCLossParam param_;
-};      // class CTCLossProp
-#endif  // DMLC_USE_CXX11
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_CONTRIB_CTC_LOSS_INL_H_
diff --git a/src/operator/nn/ctc_loss-inl.h b/src/operator/nn/ctc_loss-inl.h
new file mode 100644
index 0000000..754cf84
--- /dev/null
+++ b/src/operator/nn/ctc_loss-inl.h
@@ -0,0 +1,397 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file ctc_loss-inl.h
+ * \brief CTC Loss operator
+*/
+
+#ifndef MXNET_OPERATOR_NN_CTC_LOSS_INL_H_
+#define MXNET_OPERATOR_NN_CTC_LOSS_INL_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include <algorithm>
+#include <string>
+#include "../mshadow_op.h"
+#include "./sequence_mask-inl.h"
+#include "../sequence_op_common.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+namespace ctc_loss {
+enum CTCLossOpInputs { kData, kLabel };
+enum CTCLossOpOutputs { kOut, kGrad };
+}
+
+template <typename T>
+inline void get_workspace_size(const std::vector<int> *label_lengths,
+                               const std::vector<int> *data_lengths,
+                               int alphabet_size, int minibatch, bool isGPU,
+                               size_t *size_bytes) {
+  // This is the max of all S and T for all examples in the minibatch.
+  int maxL = *std::max_element(label_lengths->data(),
+                               label_lengths->data() + minibatch);
+  int maxT = *std::max_element(data_lengths->data(),
+                               data_lengths->data() + minibatch);
+
+  const int S = 2 * maxL + 1;
+
+  *size_bytes = 0;
+
+  if (isGPU) {
+    // GPU storage
+    // nll_forward, nll_backward
+    *size_bytes += 2 * sizeof(T) * minibatch;
+
+    // repeats
+    *size_bytes += sizeof(int) * minibatch;
+
+    // label offsets
+    *size_bytes += sizeof(int) * minibatch;
+
+    // utt_length
+    *size_bytes += sizeof(int) * minibatch;
+
+    // label lengths
+    *size_bytes += sizeof(int) * minibatch;
+
+    // labels without blanks - overallocate for now
+    *size_bytes += sizeof(int) * maxL * minibatch;
+
+    // labels with blanks
+    *size_bytes += sizeof(int) * S * minibatch;
+
+    // alphas
+    *size_bytes += sizeof(T) * S * maxT * minibatch;
+
+    // denoms
+    *size_bytes += sizeof(T) * maxT * minibatch;
+
+    // probs (since we will pass in activations)
+    *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch;
+
+  } else {
+    // cpu can eventually replace all minibatch with
+    // max number of concurrent threads if memory is
+    // really tight
+
+    // per minibatch memory
+    size_t per_minibatch_bytes = 0;
+
+    // output
+    per_minibatch_bytes += sizeof(T) * alphabet_size;
+
+    // alphas
+    per_minibatch_bytes += sizeof(T) * S * maxT;
+
+    // betas
+    per_minibatch_bytes += sizeof(T) * S;
+
+    // labels w/blanks, e_inc, s_inc
+    per_minibatch_bytes += 3 * sizeof(int) * S;
+
+    *size_bytes = per_minibatch_bytes * minibatch;
+
+    // probs
+    *size_bytes += sizeof(T) * alphabet_size * maxT * minibatch;
+  }
+}
+
+// Takes a tensor of labels, and interprets 0-elements at the end of the vector
+// as padding. The tensor is packed into an std::vector without padding
+// characters. The label sequence lengths are also inferred from the padding chars.
+template <typename DType, typename xpu>
+inline void LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels,
+                                      int padding_mask,
+                                      std::vector<int> *packed_labels,
+                                      std::vector<int> *label_lengths) {
+  int batch = labels.size(0);
+  int max_num_labels = labels.size(1);
+
+  std::vector<int> cpu_labels(max_num_labels * batch);
+  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
+  IndexTensorToVector(flat_labels, &cpu_labels);
+
+  for (int b = 0; b < batch; ++b) {
+    auto start = cpu_labels.data() + b * max_num_labels;
+    auto res = std::find(start, start + max_num_labels, padding_mask);
+    int len = std::distance(start, res);
+    std::copy(start, start + len,
+              std::back_inserter(*packed_labels));
+    label_lengths->at(b) = len;
+  }
+}
+
+// Takes a tensor of labels, and a vector which specifies the actual length of each label
+// The tensor is packed into an std::vector without padding characters.
+// The label length vector is copied into an std::vector.
+template <typename DType, typename xpu>
+inline void PackLabelByLength(mshadow::Tensor<xpu, 2, DType> labels,
+                              mshadow::Tensor<xpu, 1, DType> in_label_lengths,
+                              std::vector<int> *packed_labels,
+                              std::vector<int> *label_lengths) {
+  int batch = labels.size(0);
+  int max_num_labels = labels.size(1);
+
+  IndexTensorToVector(in_label_lengths, label_lengths);
+
+  std::vector<int> cpu_labels(max_num_labels * batch);
+  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
+  IndexTensorToVector(flat_labels, &cpu_labels);
+
+  for (int b = 0; b < batch; ++b) {
+    auto start = cpu_labels.data() + b * max_num_labels;
+    int len = label_lengths->at(b);
+    std::copy(start, start + len,
+              std::back_inserter(*packed_labels));
+  }
+}
+
+struct CTCLossOpParam : public dmlc::Parameter<CTCLossOpParam> {
+  bool use_data_lengths;
+  bool use_label_lengths;
+  int blank_label;
+  DMLC_DECLARE_PARAMETER(CTCLossOpParam) {
+    DMLC_DECLARE_FIELD(use_data_lengths).set_default(false)
+      .describe("Whether the data lenghts are decided by `data_lengths`. "
+                "If false, the lengths are equal to the max sequence length.");
+    DMLC_DECLARE_FIELD(use_label_lengths).set_default(false)
+      .describe("Whether the label lenghts are decided by "
+                "`label_lengths`, or derived from `padding_mask`. "
+                "If false, the lengths are derived from the "
+                "first occurrence of the value of `padding_mask`. "
+                "The value of `padding_mask` is ``0`` when first CTC label is reserved for blank, "
+                "and ``-1`` when last label is reserved for blank. See `blank_label`.");
+    DMLC_DECLARE_FIELD(blank_label)
+      .add_enum("first", 0)
+      .add_enum("last", 1)
+      .set_default(0)
+      .describe("Set the label that is reserved for blank label."
+                "If \"first\", 0-th label is reserved, and "
+                "label values for tokens in the vocabulary are "
+                "between ``1`` and ``alphabet_size-1``, and the padding mask is ``-1``. "
+                "If \"last\", last label value ``alphabet_size-1`` "
+                "is reserved for blank label instead, "
+                "and label values for tokens in the vocabulary are "
+                "between ``0`` and ``alphabet_size-2``, and the padding mask is ``0``.");
+  }
+};
+
+// By default, the inputs must include data array and label array
+// if use_data_lengths parameter is set, user should also supply
+// data_lengths array; if use_label_lengths parameter is set, user
+// should also specify label_lengths array.
+inline uint32_t CTCLossOpNumInputs(const NodeAttrs& attrs) {
+  const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed);
+  return 2U + param.use_data_lengths + param.use_label_lengths;
+}
+
+inline bool CTCLossOpShape(const nnvm::NodeAttrs &attrs,
+                           std::vector<TShape>* in_attrs,
+                           std::vector<TShape>* out_attrs) {
+    const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed);
+    CHECK_EQ(in_attrs->size(), CTCLossOpNumInputs(attrs));
+    CHECK_EQ(out_attrs->size(), 2U);
+
+    const TShape &dshape = (*in_attrs)[ctc_loss::kData];
+    const TShape &lshape = (*in_attrs)[ctc_loss::kLabel];
+    CHECK_EQ(dshape.ndim(), 3U) << "The number of dimensions of data array must be 3.";
+    CHECK_EQ(lshape.ndim(), 2U) << "The number of dimensions of labels array must be 2.";
+    CHECK_EQ(dshape[1], lshape[0])
+        << "The batch size for the labels and data arrays must be the same.";
+
+    if (param.use_data_lengths) {
+      int kInputLength = 2;
+      const TShape &dlshape = (*in_attrs)[kInputLength];
+      CHECK_EQ(dlshape.ndim(), 1U) << "Data length array must be a vector.";
+      CHECK_EQ(dlshape[0], dshape[1])
+          << "The batch size for the data and data lengths must be the same.";
+    }
+    if (param.use_label_lengths) {
+      int kLabelLength = 2 + param.use_data_lengths;
+      const TShape &llshape = (*in_attrs)[kLabelLength];
+      CHECK_EQ(llshape.ndim(), 1U) << "Label length array must be a vector.";
+      CHECK_EQ(llshape[0], lshape[0])
+          << "The batch size for the labels and label lengths must be the same.";
+    }
+    CHECK_GE(dshape[0], lshape[1]) << "The max number of labels cannot exceed "
+                                      "the maximum sequence length of the "
+                                      "data.";
+
+    TShape oshape(1);
+    oshape[0] = dshape[1];  // batch size
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);  // forward output
+    SHAPE_ASSIGN_CHECK(*out_attrs, 1, dshape);  // grad output
+    return true;
+}
+
+inline bool CTCLossOpType(const nnvm::NodeAttrs& attrs,
+                          std::vector<int>* in_attrs,
+                          std::vector<int>* out_attrs) {
+    CHECK_GE(in_attrs->size(), 2U);
+    CHECK_EQ(out_attrs->size(), 2U);
+    int dtype = (*in_attrs)[ctc_loss::kData];
+    CHECK_NE(dtype, -1) << "Input data must have specified type";
+
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));  // forward output
+    TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0));  // grad output
+    return true;
+}
+
+inline bool CTCLossOpStorageType(const nnvm::NodeAttrs& attrs,
+                                 const int dev_mask,
+                                 DispatchMode* dispatch_mode,
+                                 std::vector<int>* in_attrs,
+                                 std::vector<int>* out_attrs) {
+  CHECK_GE(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 2U);
+  const int in_stype = in_attrs->at(0);
+  bool dispatched = false;
+  if (!dispatched && in_stype == kDefaultStorage) {
+    // dns -> dns
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
+}
+
+
+inline std::vector<std::string> CTCLossOpListInputNames(const NodeAttrs& attrs) {
+  const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed);
+  if (param.use_data_lengths && param.use_label_lengths) {
+    return {"data", "label", "data_lengths", "label_lengths"};
+  } else if (param.use_data_lengths) {
+    return {"data", "label", "data_lengths"};
+  } else if (param.use_label_lengths) {
+    return {"data", "label", "label_lengths"};
+  } else {
+    return {"data", "label"};
+  }
+}
+
+template<typename xpu>
+void CTCLossOpForward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  const CTCLossOpParam& param = nnvm::get<CTCLossOpParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), CTCLossOpNumInputs(attrs));
+  CHECK_EQ(outputs.size(), 2U);
+  CHECK_EQ(req.size(), 2U);
+
+  const TBlob& in_data = inputs[ctc_loss::kData];
+  const TBlob& in_label = inputs[ctc_loss::kLabel];
+  const TBlob& out_data = outputs[ctc_loss::kOut];
+  const TBlob& out_grad = outputs[ctc_loss::kGrad];
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(inputs[ctc_loss::kLabel].type_flag_, DType, {
+    Tensor<xpu, 3, real_t> data = in_data.get<xpu, 3, real_t>(s);
+    Tensor<xpu, 2, DType> labels = in_label.get<xpu, 2, DType>(s);
+    Tensor<xpu, 1, real_t> costs = out_data.get<xpu, 1, real_t>(s);
+    Tensor<xpu, 3, real_t> grad = out_grad.get<xpu, 3, real_t>(s);
+
+    int max_seq_len = data.size(0);
+    int batch_size = data.size(1);
+    int alphabet_size = data.size(2);
+
+    // data_lengths
+    std::vector<int> data_lengths(batch_size, max_seq_len);
+    if (param.use_data_lengths) {
+      int kInputLength = 2;
+      IndexTensorToVector(inputs[kInputLength].get<xpu, 1, real_t>(s), &data_lengths);
+    }
+
+    // label_lengths
+    std::vector<int> packed_labels;
+    std::vector<int> label_lengths(batch_size);
+
+    if (param.use_label_lengths) {
+      int kLabelLength = 2 + param.use_data_lengths;
+      PackLabelByLength(labels, inputs[kLabelLength].get<xpu, 1, DType>(s),
+                        &packed_labels, &label_lengths);
+    } else {
+      LabelTensorToPackedVector(labels, param.blank_label == 0 ? 0 : -1,
+                                &packed_labels, &label_lengths);
+    }
+
+    size_t size_bytes;
+    get_workspace_size<real_t>(&label_lengths, &data_lengths, alphabet_size,
+                               batch_size, data.kDevCPU ? false : true, &size_bytes);
+
+    // round-up so there are enough elems in memory
+    int num_tmp_elems = (size_bytes + sizeof(real_t) - 1) / sizeof(real_t);
+    Tensor<xpu, 1, real_t> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, real_t>(Shape1(num_tmp_elems), s);
+
+    compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels.data(),
+                     label_lengths.data(), data_lengths.data(),
+                     workspace.dptr_, req[ctc_loss::kGrad] != mxnet::kNullOp,
+                     param.blank_label == 0 ? 0 : (alphabet_size - 1));
+
+    if (param.use_data_lengths) {
+      // baidu warp CTC implementation sometimes includes undefined gradients
+      // for data outside of length mask. Setting to 0 to make it consistent
+      // with CPU implementation.
+      int kInputLength = 2;
+      mxnet_op::SequenceMask(grad, inputs[kInputLength].get<xpu, 1, real_t>(s),
+                             static_cast<real_t>(0));
+    }
+  });
+}
+
+template<typename xpu>
+void CTCLossOpBackward(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& in_grad = outputs[0];
+  const TBlob& out_grad = inputs[0];
+  const TBlob& grad_computed = inputs[3];  // grad computed in the forward step
+
+  Tensor<xpu, 3, real_t> igrad_data = in_grad.get<xpu, 3, real_t>(s);
+  Tensor<xpu, 1, real_t> ograd_data = out_grad.get<xpu, 1, real_t>(s);
+  Tensor<xpu, 3, real_t> computed_grad_data = grad_computed.get<xpu, 3, real_t>(s);
+
+  Assign(igrad_data, req[0],
+         mshadow::expr::broadcast<1>(ograd_data, computed_grad_data.shape_) * computed_grad_data);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NN_CTC_LOSS_INL_H_
+
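CTCLossOpBackward above only rescales the gradient that the forward pass already stored in out_data[kGrad]: the head gradient (one scalar per sample) is broadcast along the batch axis and multiplied element-wise. An illustrative numpy equivalent (shapes are made up):

    import numpy as np

    seq_len, batch, alphabet = 20, 2, 5
    grad_computed = np.random.uniform(size=(seq_len, batch, alphabet))  # cached by the forward pass
    ograd = np.random.uniform(size=(batch,))                            # d(loss)/d(output), one value per sample

    # broadcast<1>(ograd, grad.shape) * grad_computed
    data_grad = ograd[np.newaxis, :, np.newaxis] * grad_computed
    assert data_grad.shape == grad_computed.shape
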
diff --git a/src/operator/contrib/ctc_loss.cc b/src/operator/nn/ctc_loss.cc
similarity index 64%
rename from src/operator/contrib/ctc_loss.cc
rename to src/operator/nn/ctc_loss.cc
index 32e8e62..c381677 100644
--- a/src/operator/contrib/ctc_loss.cc
+++ b/src/operator/nn/ctc_loss.cc
@@ -18,26 +18,22 @@
  */
 
 /*!
- * Copyright (c) 2015 by Contributors
  * \file ctc_loss.cc
- * \brief
- * \author Sebastian Bodenstein
-*/
-
+ * \brief CPU Implementation of CTC Loss op
+ */
 #include "./ctc_loss-inl.h"
-#include "./ctc_include/detail/cpu_ctc.h"
+#include "../../../3rdparty/ctc_include/detail/cpu_ctc.h"
 
 namespace mshadow {
-
 template <typename DType>
 ctcStatus_t compute_ctc_cost(const Tensor<cpu, 3, DType> activations,
                              DType *costs, DType *grads, int *labels,
                              int *label_lengths, int *data_lengths,
-                             void *workspace, int train, int blank_label) {
+                             void *workspace, bool isTraining, int blank_label) {
   int minibatch = static_cast<int>(activations.size(1));
   int alphabet_size = static_cast<int>(activations.size(2));
   mxnet_warpctc::CpuCTC<DType> ctc(alphabet_size, minibatch, workspace, blank_label);
-  if (train) {
+  if (isTraining) {
     return ctc.cost_and_grad(activations.dptr_, grads, costs, labels,
                              label_lengths, data_lengths);
   } else {
@@ -45,32 +41,18 @@ ctcStatus_t compute_ctc_cost(const Tensor<cpu, 3, DType> activations,
                              data_lengths);
   }
 }
-
 }  // namespace mshadow
 
 namespace mxnet {
 namespace op {
-template <>
-Operator *CreateOp<cpu>(CTCLossParam param, int dtype) {
-  return new CTCLossOp<cpu>(param);
-}
-
-// DO_BIND_DISPATCH comes from operator_common.h
-Operator *CTCLossProp::CreateOperatorEx(Context ctx,
-                                        std::vector<TShape> *in_shape,
-                                        std::vector<int> *in_type) const {
-  std::vector<TShape> out_shape, aux_shape;
-  std::vector<int> out_type, aux_type;
-  CHECK(InferType(in_type, &out_type, &aux_type));
-  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
-}
-
-DMLC_REGISTER_PARAMETER(CTCLossParam);
 
-MXNET_REGISTER_OP_PROPERTY(_contrib_CTCLoss, CTCLossProp)
-    .describe(R"code(Connectionist Temporal Classification Loss.
+DMLC_REGISTER_PARAMETER(CTCLossOpParam);
 
+NNVM_REGISTER_OP(CTCLoss)
+.add_alias("ctc_loss")
+.add_alias("_contrib_CTCLoss")
+.add_alias("_contrib_ctc_loss")
+.describe(R"code(Connectionist Temporal Classification Loss.
 The shapes of the inputs and outputs:
 
 - **data**: `(sequence_length, batch_size, alphabet_size)`
@@ -113,18 +95,41 @@ Sequence Data with Recurrent Neural Networks*, A. Graves *et al*. for more
 information on the definition and the algorithm.
 
 )code" ADD_FILELINE)
-    .add_argument("data", "NDArray-or-Symbol", "Input data to the ctc_loss op.")
-    .add_argument("label", "NDArray-or-Symbol",
-                  "Ground-truth labels for the loss.")
-    .add_argument("data_lengths", "NDArray-or-Symbol",
-                  "Lengths of data for each of the samples. Only required "
-                  "when use_data_lengths is true.")
-    .add_argument("label_lengths", "NDArray-or-Symbol",
-                  "Lengths of labels for each of the samples. Only required "
-                  "when use_label_lengths is true.")
-    .add_arguments(CTCLossParam::__FIELDS__());
-
-NNVM_REGISTER_OP(_contrib_CTCLoss).add_alias("_contrib_ctc_loss");
+.set_attr_parser(ParamParser<CTCLossOpParam>)
+.set_num_inputs(CTCLossOpNumInputs)
+.set_num_outputs(2)
+.set_attr<nnvm::FListInputNames>("FListInputNames", CTCLossOpListInputNames)
+.set_attr<nnvm::FListOutputNames>("FListOutputNAmes",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"out", "grad"};
+  })
+.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
+  [](const NodeAttrs& attrs) {
+    return 1;
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", CTCLossOpShape)
+.set_attr<nnvm::FInferType>("FInferType", CTCLossOpType)
+.set_attr<FInferStorageType>("FInferStorageType", CTCLossOpStorageType)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
+  { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
+.set_attr<FCompute>("FCompute<cpu>", CTCLossOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_ctc_loss"})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_argument("label", "NDArray-or-Symbol", "Ground-truth labels for the loss.")
+.add_argument("data_lengths", "NDArray-or-Symbol",
+              "Lengths of data for each of the samples. Only required "
+              "when use_data_lengths is true.")
+.add_argument("label_lengths", "NDArray-or-Symbol",
+              "Lengths of labels for each of the samples. Only required "
+              "when use_label_lengths is true.")
+.add_arguments(CTCLossOpParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_ctc_loss)
+.set_attr_parser(ParamParser<CTCLossOpParam>)
+.set_num_inputs(1)
+.set_num_outputs(CTCLossOpNumInputs)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", CTCLossOpBackward<cpu>);
 
 }  // namespace op
 }  // namespace mxnet
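
For reference, a minimal sketch (not part of this diff) of calling the refactored
operator from Python once this commit is built. The toy shapes and label values are
illustrative assumptions; the only facts taken from the registration above are the op
name, its backward-compatible aliases, and the (sequence_length, batch_size,
alphabet_size) data layout.

    import mxnet as mx
    import numpy as np

    # data: (sequence_length=10, batch_size=2, alphabet_size=5)
    data = mx.nd.array(np.random.uniform(-1, 1, (10, 2, 5)))
    # label: (batch_size, label_seq_len); with the default blank_label='first',
    # index 0 is reserved for the blank symbol and doubles as the padding value
    label = mx.nd.array([[1, 2, 3], [2, 4, 0]])

    loss = mx.nd.ctc_loss(data, label)             # new first-class name
    legacy = mx.nd.contrib.ctc_loss(data, label)   # old contrib alias, kept for compatibility
    assert np.allclose(loss.asnumpy(), legacy.asnumpy())
    print(loss.shape)                              # (2,): one loss value per batch element
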
diff --git a/src/operator/contrib/ctc_loss.cu b/src/operator/nn/ctc_loss.cu
similarity index 81%
rename from src/operator/contrib/ctc_loss.cu
rename to src/operator/nn/ctc_loss.cu
index 3f5f12c..a4491bf 100644
--- a/src/operator/contrib/ctc_loss.cu
+++ b/src/operator/nn/ctc_loss.cu
@@ -18,14 +18,13 @@
  */
 
 /*!
- * Copyright (c) 2015 by Contributors
+ * Copyright (c) 2018 by Contributors
  * \file ctc_loss.cu
- * \brief
- * \author Sebastian Bodenstein
-*/
-#include <algorithm>
+ * \brief GPU Implementation of ctc_loss op
+ */
+
 #include "./ctc_loss-inl.h"
-#include "./ctc_include/detail/gpu_ctc.h"
+#include "../../../3rdparty/ctc_include/detail/gpu_ctc.h"
 
 namespace mshadow {
 
@@ -45,17 +44,19 @@ ctcStatus_t compute_ctc_cost(const Tensor<gpu, 3, DType> activations,
     return ctc.score_forward(activations.dptr_, costs, labels,
                              label_lengths, input_lengths);
 }
-
 }  // namespace mshadow
 
-////////////////////////////////////////////////////////////////////////////////
-
 namespace mxnet {
 namespace op {
-template <>
-Operator *CreateOp<gpu>(CTCLossParam param, int dtype) {
-  return new CTCLossOp<gpu>(param);
-}
+
+NNVM_REGISTER_OP(CTCLoss)
+.add_alias("ctc_loss")
+.add_alias("_contrib_ctc_loss")
+.add_alias("_contrib_CTCLoss")
+.set_attr<FCompute>("FCompute<gpu>", CTCLossOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_ctc_loss)
+.set_attr<FCompute>("FCompute<gpu>", CTCLossOpBackward<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
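
Since the forward and backward kernels are now separate NNVM registrations with their
own FCompute<cpu>/FCompute<gpu> entries, the gradient flows through the ordinary
autograd path. A minimal sketch under the same illustrative shapes as above (swap the
context for mx.gpu(0) to exercise the GPU registration):

    import mxnet as mx
    import numpy as np

    ctx = mx.cpu()
    data = mx.nd.array(np.random.uniform(-1, 1, (10, 2, 5)), ctx=ctx)
    label = mx.nd.array([[1, 2, 3], [2, 4, 0]], ctx=ctx)
    data.attach_grad()

    with mx.autograd.record():
        loss = mx.nd.ctc_loss(data, label)
    loss.backward()                    # dispatches to _backward_ctc_loss

    print(loss.asnumpy())              # per-sample CTC loss, shape (2,)
    print(data.grad.shape)             # gradient matches the data layout: (10, 2, 5)
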
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index b17562c..5332517 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4491,11 +4491,34 @@ def test_pick():
 def check_ctc_loss(acts, labels, loss_truth):
     in_var = mx.sym.Variable('input')
     labels_var = mx.sym.Variable('labels')
-    ctc = mx.sym.contrib.ctc_loss(in_var, labels_var)
+    ctc = mx.sym.ctc_loss(in_var, labels_var)
     acts_nd = mx.nd.array(acts, ctx=default_context())
     labels_nd = mx.nd.array(labels, ctx=default_context())
     exe = ctc.bind(ctx=default_context(), args=[acts_nd, labels_nd])
+    # test forward with grad calc
+    exe.forward(is_train=True)
+    outTest = exe.outputs[0]
     # test forward without grad calc
+    exe.forward(is_train=False)
+    outTrain = exe.outputs[0]
+    # make sure losses calculated with both modes are the same
+    assert_almost_equal(outTest.asnumpy(), outTrain.asnumpy())
+
+    # test against ground truth, if available
+    if loss_truth is not None:
+        assert_almost_equal(outTest.asnumpy(), loss_truth)
+    # test grad
+    check_numeric_gradient(ctc, [acts, labels], grad_nodes=['input'], rtol=0.05, atol=1e-3)
+
+# check contrib operator for backward compatibility
+def check_contrib_ctc_loss(acts, labels, loss_truth):
+    in_var = mx.sym.Variable('input')
+    labels_var = mx.sym.Variable('labels')
+    ctc = mx.sym.contrib.ctc_loss(in_var, labels_var)
+    acts_nd = mx.nd.array(acts, ctx=default_context())
+    labels_nd = mx.nd.array(labels, ctx=default_context())
+    exe = ctc.bind(ctx=default_context(), args=[acts_nd, labels_nd])
+    # test forward with grad calc
     exe.forward(is_train=True)
     outTest = exe.outputs[0]
     # test forward without grad calc
@@ -4503,6 +4526,7 @@ def check_ctc_loss(acts, labels, loss_truth):
     outTrain = exe.outputs[0]
     # make sure losses calculated with both modes are the same
     assert_almost_equal(outTest.asnumpy(), outTrain.asnumpy())
+
     # test against ground truth, if available
     if loss_truth is not None:
         assert_almost_equal(outTest.asnumpy(), loss_truth)
@@ -4520,6 +4544,8 @@ def test_ctc_loss():
     labels = np.array([[2, 3, 0], [2, 3, 0]])
     true_loss = np.array([4.04789, 4.04789], dtype=np.float32) # from Torch
     check_ctc_loss(acts, labels, true_loss)
+    check_contrib_ctc_loss(acts, labels, true_loss)
+
     # Test 2:
     acts2 = np.array([
         [[-5, -4, -3, -2, -1], [1.2, 3.4, 1.2, -0.1, -2.34]],
@@ -4528,11 +4554,13 @@ def test_ctc_loss():
     labels2 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.float32)
     true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch
     check_ctc_loss(acts2, labels2, true_loss)
+    check_contrib_ctc_loss(acts2, labels2, true_loss)
 
     # Test 3: check use integer type as label
     labels3 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.int32)
     true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch
     check_ctc_loss(acts2, labels3, true_loss)
+    check_contrib_ctc_loss(acts2, labels3, true_loss)
 
 @with_seed()
 def test_ctc_loss_with_large_classes():
@@ -4550,7 +4578,7 @@ def test_ctc_loss_with_large_classes():
         [1000, 2000, 3000, 4000, 0, 5000, 0, 0]], dtype=np.int32)
     nd_data = mx.nd.array(data)
     nd_label = mx.nd.array(label)
-    loss = mx.nd.contrib.ctc_loss(data=nd_data, label=nd_label)
+    loss = mx.nd.ctc_loss(data=nd_data, label=nd_label)
     expected_loss = np.array([688.02826, 145.34462])
     assert_almost_equal(loss.asnumpy(), expected_loss)
 
@@ -4624,6 +4652,85 @@ def test_ctc_loss_grad():
             label = mx.nd.array(labels)
             data.attach_grad()
             with mx.autograd.record():
+                l = mx.ndarray.CTCLoss(data, label,
+                                       use_data_lengths=True,
+                                       use_label_lengths=True,
+                                       data_lengths=mx.nd.array(seq_lens),
+                                       label_lengths=mx.nd.array(label_lens),
+                                       blank_label=blank_label)
+                l.backward()
+            assert_almost_equal(l.asnumpy(), loss_truth, atol=1e-5, rtol=1e-5)
+            assert_almost_equal(data.grad.asnumpy(), grad_truth, atol=1e-5, rtol=1e-5)
+
+    # check contrib operator for backward compatibility
+    def check_contrib_ctc_loss_grad(blank_label): # from tf
+        vocab_size = 5
+        max_label_len = 5
+        padding_mask = -1+ (blank_label=='first')
+
+        targets_0 = [0, 1, 2, 1, 0]
+        loss_log_prob_0 = -3.34211
+        input_prob_matrix_0 = np.asarray(
+            [[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
+             [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
+             [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
+             [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
+             [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
+            dtype=np.float32)
+        gradient_log_prob_0 = np.asarray(
+            [[-0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
+             [0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436],
+             [0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688],
+             [0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533],
+             [-0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
+            dtype=np.float32)
+
+        targets_1 = [0, 1, 1, 0]
+        loss_log_prob_1 = -5.42262
+        input_prob_matrix_1 = np.asarray(
+            [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
+             [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
+             [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
+             [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
+             [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]],
+            dtype=np.float32)
+        gradient_log_prob_1 = np.asarray(
+            [[-0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
+             [0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549],
+             [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544],
+             [0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345],
+             [-0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]],
+            dtype=np.float32)
+
+        inputs = [
+            np.vstack(
+                [input_prob_matrix_0[t, :], input_prob_matrix_1[t, :]])
+            for t in range(5)
+        ] + 2 * [np.nan * np.ones((2, vocab_size+1), np.float32)]
+        inputs = np.log(np.asarray(inputs, dtype=np.float32))
+
+        grad_truth = np.array([
+            np.vstack(
+                [gradient_log_prob_0[t, :], gradient_log_prob_1[t, :]])
+            for t in range(5)
+        ] + 2 * [np.zeros((2, vocab_size+1), np.float32)])
+
+        if blank_label == 'first':
+            inputs = np.roll(inputs, 1, axis=2)
+            grad_truth = np.roll(grad_truth, 1, axis=2)
+
+        labels = (np.asarray([x + [padding_mask]*(max_label_len-len(x))
+                             for x in [targets_0, targets_1]])+(blank_label == 'first'))
+
+        seq_lens = np.array([5, 5], dtype=np.int32)
+        label_lens = np.array([5, 4], dtype=np.int32)
+        loss_truth = np.array([-loss_log_prob_0, -loss_log_prob_1], np.float32)
+
+        with default_context():
+            data = mx.nd.array(inputs)
+            label = mx.nd.array(labels)
+            data.attach_grad()
+            with mx.autograd.record():
                 l = mx.contrib.ndarray.CTCLoss(data, label,
                                                use_data_lengths=True,
                                                use_label_lengths=True,
@@ -4634,8 +4741,11 @@ def test_ctc_loss_grad():
             assert_almost_equal(l.asnumpy(), loss_truth, atol=1e-5, rtol=1e-5)
             assert_almost_equal(data.grad.asnumpy(), grad_truth, atol=1e-5, rtol=1e-5)
 
+
     check_ctc_loss_grad('first')
     check_ctc_loss_grad('last')
+    check_contrib_ctc_loss_grad('first')
+    check_contrib_ctc_loss_grad('last')
 
 
 @with_seed()