You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@singa.apache.org by wa...@apache.org on 2018/07/11 08:29:34 UTC

[2/4] incubator-singa git commit: SINGA-380 Fix bugs from Reshape

SINGA-380 Fix bugs from Reshape

Add a SoftMaxCrossEntropy operation that accepts logits as input (softmax is applied inside the operation).
The existing CrossEntropy operation, which accepts probabilities as input, is renamed to NLL.

Fix the memory leak in Tensor::Reshape (C++).

Fix memory leak bug from Reshape.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/58e6640d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/58e6640d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/58e6640d

Branch: refs/heads/master
Commit: 58e6640df8d8f3572e15e818e59eb552ea314eaa
Parents: 81908a8
Author: Wang Wei <wa...@gmail.com>
Authored: Sun Jul 8 21:42:26 2018 +0800
Committer: wang wei <wa...@comp.nus.edu.sg>
Committed: Wed Jul 11 15:20:43 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py | 30 ++++++++++------
 include/singa/core/tensor.h    |  6 +++-
 python/singa/autograd.py       | 61 ++++++++++++++++++-------------
 src/api/core_tensor.i          |  3 ++
 src/core/tensor/tensor.cc      | 71 +++++++++++++++++++++++--------------
 src/model/layer/cudnn_rnn.cc   |  2 +-
 6 files changed, 108 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index db21485..43a22ba 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -77,7 +77,7 @@ if __name__ == '__main__':
     args = parser.parse_args()
 
     assert os.path.exists(args.file_path), \
-        'Pls download the MNIST dataset from '
+        'Pls download the MNIST dataset from https://s3.amazonaws.com/img-datasets/mnist.npz'
 
     if args.use_cpu:
         print('Using CPU')
@@ -92,7 +92,7 @@ if __name__ == '__main__':
     num_classes = 10
     epochs = 1
 
-    sgd = optimizer.SGD(0.05)
+    sgd = optimizer.SGD(0.001)
 
     x_train = preprocess(train[0])
     y_train = to_categorical(train[1], num_classes)
@@ -109,24 +109,26 @@ if __name__ == '__main__':
     conv2 = autograd.Conv2D(32, 32, 3, padding=1)
     linear = autograd.Linear(32 * 28 * 28, 10)
 
+
     def forward(x, t):
+        
         y = conv1(x)
         y = autograd.relu(y)
+        y = autograd.max_pool_2d(y)
         y = conv2(y)
         y = autograd.relu(y)
         y = autograd.max_pool_2d(y)
-        y = autograd.flatten(y)
+        y=autograd.flatten(y)
         y = linear(y)
-        y = autograd.soft_max(y)
-        loss = autograd.cross_entropy(y, t)
+        loss = autograd.softmax_cross_entropy(y, t)
         return loss, y
 
     autograd.training = True
-    for epoch in range(epochs):
+    for epoch in range(50):
         for i in range(batch_number):
-            inputs = tensor.Tensor(device=dev, data=x_train[i * 100:(1 + i) * 100])
-            targets = tensor.Tensor(device=dev, data=y_train[i * 100:(1 + i) * 100])
-
+            inputs = tensor.Tensor(device=dev, data=x_train[ i * 100:(1 + i) * 100], stores_grad=False)
+            targets = tensor.Tensor(device=dev, data=y_train[i * 100:(1 + i) * 100], requires_grad=False, stores_grad=False)
+            
             loss, y = forward(inputs, targets)
 
             accuracy_rate = accuracy(tensor.to_numpy(y),
@@ -134,6 +136,12 @@ if __name__ == '__main__':
             if (i % 5 == 0):
                 print('accuracy is:', accuracy_rate, 'loss is:',
                       tensor.to_numpy(loss)[0])
-
+            
             for p, gp in autograd.backward(loss):
-                sgd.apply(0, gp, p, '')
+                sgd.apply(epoch, gp, p, '')
+            
+            
+
+            
+            
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index a1badd4..5921762 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -292,7 +292,7 @@ Tensor Reshape(const Tensor &in, Shape &&s);
 void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
                     const size_t dst_offset = 0, const size_t src_offset = 0);
 
-void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis, 
+void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
                       Tensor *dst, const Tensor &in, const size_t num);
 
 // =============Element-wise operations====================================
@@ -508,6 +508,10 @@ void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss);
 
 void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p);
 
+/// To be called by pysinga autograd operations;
+/// swig ignores the const qualifier http://www.swig.org/Doc3.0/SWIGPlus.html#SWIGPlus_const
+const Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t);
+const Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t);
 
 /// Return a tensor consisting of rows ([start, end)) from 'in'. It copies the
 /// values from 'in'. 'in' ia a 2D Tensor.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 2ba3098..63698c2 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -40,14 +40,8 @@ class Operation(object):
 
     Steps to add a specific operation Xxxx:
     1. create a subclass of Operation, name it as Xxxx
-    2. if Xxxx is implemented using other Operations, then override
-       _do_forward() function;
-       if Xxxx is implemented using CTensor operations,
-       then override the forward() and backward(); The arguments of forward()
+    2. override the forward() and backward(); The arguments of forward()
        and backward() should only include CTensor;
-       if Xxxx is implemented by calling functions in layer.py, then override
-       __call__(), forward() and backward(). TODO(wangwei) avoid this complex
-       case.
     '''
 
     def __call__(self, *xs):
@@ -311,9 +305,9 @@ def soft_max(x, axis=0):
     return SoftMax(axis)(x)[0]
 
 
-class CrossEntropy(Operation):
+class NLL(Operation):
     '''
-    Calculte CrossEntropy loss for a batch of training data.
+    Calculte negative log likelihood loss for a batch of training data.
 
     '''
 
@@ -356,8 +350,27 @@ class CrossEntropy(Operation):
             pass  # TODO, broadcast elementwise multiply seems not support
 
 
-def cross_entropy(y, t):
-    return CrossEntropy()(y, t)[0]
+def nll(y, t):
+    return NLL()(y, t)[0]
+
+
+class SoftMaxCrossEntropy(Operation):
+
+    def forward(self, x, t):
+        self.p = singa.SoftMax(x)
+        self.t = t
+        loss = CTensor((1,), self.p.device())
+        ret = singa.CrossEntropyFwd(self.p, t)
+        loss.SetFloatValue(singa.SumAsFloat(ret) / x.shape()[0])
+        return loss
+
+    def backward(self, dy=1.0):
+        return singa.SoftmaxCrossEntropyBwd(self.p, self.t), None
+
+
+def softmax_cross_entropy(x, t):
+    # x is the logits and t is the ground truth; both are 2D.
+    return SoftMaxCrossEntropy()(x, t)[0]
 
 
 def ctensor2numpy(x):
@@ -427,23 +440,20 @@ def max_pool_2d(x, kernel_size=3, stride=1, padding=0, dilation=1,
 
 class Flatten(Operation):
 
-    def __init__(self):
-        self.PyLayer = layer.Flatten('flatten', 1)
+    def __init(self, start_axis=1):
+        # flatten all axis after (inclusive) start_axis
+        self.start_axis = start_axis
+        assert start_axis == 1, 'must flatten into 2d array not'
 
-    def __call__(self, x):
-        if training:
-            self.flag = model_pb2.kTrain
-        else:
-            self.flag = model_pb2.kEval
-        if not self.PyLayer.has_setup:
-            self.PyLayer.setup(x.shape[1:])
-        return self._do_forward(x)
-
-    def forward(self, *xs):
-        return self.PyLayer.layer.Forward(self.flag, xs[0])
+    def forward(self, x):
+        # TODO Do flatten start from axis != 1
+        self.shape = list(x.shape())
+        y = x.Reshape((x.shape()[0], x.Size() // x.shape()[0]))
+        return y
 
     def backward(self, dy):
-        return self.PyLayer.layer.Backward(0, dy)[0]
+        dx = dy.Reshape(self.shape)
+        return dx
 
 
 def flatten(x):
@@ -623,6 +633,7 @@ def backward(y, dy=None):
                     if not isinstance(src_op, Dummy):
                         ready.append((src_op, not_ready[src_op]))
                     del not_ready[src_op]
+        del op  # delete the operation to free all tensors from this op
 
 
 class Layer(object):

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/src/api/core_tensor.i
----------------------------------------------------------------------
diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i
index d94e506..cc72d21 100644
--- a/src/api/core_tensor.i
+++ b/src/api/core_tensor.i
@@ -325,4 +325,7 @@ namespace singa{
 
   Tensor SoftMax(const Tensor &in);
   void SoftMax(const Tensor &in, Tensor *out);
+
+  const Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t);
+  const Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
old mode 100644
new mode 100755
index 05db7cf..e5e8017
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -218,10 +218,10 @@ void Tensor::CopyData(const Tensor &src) {
 }
 
 void Tensor::RepeatData(vector<size_t> repeats, int axis, int total_repeats, const Tensor &src) {
-  if(repeats.size() == 1) {
+  if (repeats.size() == 1) {
     CHECK_EQ(Size(), src.Size()*total_repeats);
   } else {
-    CHECK_EQ(Size(), src.Size()*total_repeats/src.shape()[axis]);
+    CHECK_EQ(Size(), src.Size()*total_repeats / src.shape()[axis]);
   }
 
   CHECK(block_ != nullptr);
@@ -344,7 +344,7 @@ Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device>
     total_repeats = repeats[0];
     tshape.push_back(Product(shape_)*total_repeats);
   } else {
-    if (repeats.size() == 1){
+    if (repeats.size() == 1) {
       total_repeats = repeats[0];
       for (size_t i = 0; i < shape_.size(); i++) {
         if (i == axis) {
@@ -358,15 +358,15 @@ Tensor Tensor::Repeat(vector<size_t> repeats, int axis, std::shared_ptr<Device>
         LOG(FATAL) << "the repeats number doesn't match the axis";
       }
       for (size_t i = 0; i < shape_[axis]; i++) {
-        if(repeats[i] < 0) {
+        if (repeats[i] < 0) {
           LOG(FATAL) << "the repeats number is less than zero";
         }
         total_repeats += repeats[i];
       }
-      for (size_t i = 0; i < shape_.size(); i++){
+      for (size_t i = 0; i < shape_.size(); i++) {
         if (i == axis) {
           tshape.push_back(total_repeats);
-        } else{
+        } else {
           tshape.push_back(shape_[i]);
         }
       }
@@ -539,7 +539,7 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
   }
 }
 
-void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis, 
+void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
                       Tensor *dst, const Tensor &src, const size_t num) {
   if (repeats.size() == 1) {
     broadcast_flag = true;
@@ -548,7 +548,7 @@ void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
       LOG(FATAL) << "When repeats parameter is sequence, axis cannot be None";
     }
   }
-  for (size_t i = 0; i < repeats.size(); i++){
+  for (size_t i = 0; i < repeats.size(); i++) {
     CHECK_GE(repeats[i], 0);
   }
   auto width = SizeOf(src.data_type());
@@ -557,7 +557,7 @@ void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
   int chunk = width;
   int axis_shape = 1;
   int shape_outer = 1;
-  if (axis == Noaxis){
+  if (axis == Noaxis) {
     axis_shape = 1;
     shape_outer = Product(src.shape());
   } else {
@@ -565,7 +565,7 @@ void RepeatDataToFrom(bool broadcast_flag, vector<size_t> repeats, int axis,
       shape_outer *= src.shape()[i];
     }
     axis_shape = src.shape()[axis];
-    for(size_t i = axis + 1; i < src.nDim(); i++) {
+    for (size_t i = axis + 1; i < src.nDim(); i++) {
       chunk *= src.shape()[i];
     }
   }
@@ -693,7 +693,7 @@ void Tensor::SetValue(const SType x) {
   CHECK_EQ(sizeof(SType), SizeOf(data_type_));
   //auto size = Size();
   auto ptr = block_;
-  
+
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     // TODO(wangwei) cast x to DType
     device_->Exec([this, x, ptr](Context * ctx) {
@@ -1268,6 +1268,18 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
 // ************************
 // Misc.
 // ************************
+const Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t) {
+  Tensor loss({p.shape(0)}, p.device(), p.data_type());
+  ComputeCrossEntropy(p, t, &loss);
+  return loss;
+}
+
+const Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t) {
+  auto g = p.Clone();
+  SoftmaxCrossEntropyBwd(t, &g);
+  return g;
+}
+
 void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss) {
   CHECK_LE(p.nDim(), 2u);
   CHECK_LE(t.nDim(), 2u);
@@ -1302,6 +1314,7 @@ Tensor Tensor::Reshape(const Shape &shape) {
   if (strides_.size() == 0)
     strides_.push_back(1);
 
+  // TODO(wangwei) remove this condition and report error if size changes.
   if (Product(shape_) != Product(shape)) {
     if (block_ != nullptr && block_->DecRefCount() == 0)
       device_->FreeBlock(block_);
@@ -1316,14 +1329,16 @@ Tensor Tensor::Reshape(const Shape &shape) {
     singa::Transform(*this, &t);
     t.shape_ = shape;
     return t;
- }
-
-  shape_ = shape;
-  generate_strides();
-  Tensor t(shape, device_, data_type_);
-  t.block_ = block_;
-  t.block_->IncRefCount();
-  return t;
+  } else {
+    Tensor t;
+    t.shape_ = shape;
+    t.device_ = device_;
+    t.data_type_ = data_type_;
+    t.block_ = block_;  // be careful about the block inference (mem leaking)
+    t.block_->IncRefCount();
+    t.generate_strides();
+    return t;
+  }
 }
 
 Tensor Tensor::Reshape(Shape &&shape) {
@@ -1344,14 +1359,16 @@ Tensor Tensor::Reshape(Shape &&shape) {
     singa::Transform(*this, &t);
     t.shape_ = shape;
     return t;
- }
-
-  shape_ = shape;
-  generate_strides();
-  Tensor t(shape, device_, data_type_);
-  t.block_ = block_;
-  t.block_->IncRefCount();
-  return t;
+  } else {
+    Tensor t;
+    t.shape_ = shape;
+    t.device_ = device_;
+    t.data_type_ = data_type_;
+    t.block_ = block_;  // be careful about the block inference (mem leaking)
+    t.block_->IncRefCount();
+    t.generate_strides();
+    return t;
+  }
 }
 
 Tensor Reshape(const Tensor &in, const Shape &s) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/58e6640d/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
old mode 100644
new mode 100755
index 28a52c5..eb2bfd3
--- a/src/model/layer/cudnn_rnn.cc
+++ b/src/model/layer/cudnn_rnn.cc
@@ -144,7 +144,7 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
     rnn_mode = CUDNN_RNN_TANH;
   else if (rnn_mode_ == "gru")
     rnn_mode = CUDNN_GRU;
-#ifdef CUDNN_MAJOR == 5
+#if CUDNN_MAJOR <= 5
   CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_,
                                     dropout_desc_, input_mode, direction,
                                     rnn_mode, dtype_));