You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2016/06/13 13:20:17 UTC

[24/50] [abbrv] incubator-singa git commit: SINGA-184 Add Cross Entropy loss computation

SINGA-184 Add Cross Entropy loss computation

Implement Cross Entropy loss.
Passes cpplint.py; the test compiles.
TODO: check that the test passes.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/efd7b627
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/efd7b627
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/efd7b627

Branch: refs/heads/master
Commit: efd7b627bacb4acd6a3322468350f2b5399f725b
Parents: 3e2507b
Author: kaiping <ka...@comp.nus.edu.sg>
Authored: Fri May 27 12:09:30 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Tue May 31 22:14:09 2016 +0800

----------------------------------------------------------------------
 src/model/loss/cross_entropy.h   | 105 ++++++++++++++++++++++++++++++++++
 test/singa/test_cross_entropy.cc |  66 +++++++++++++++++++++
 2 files changed, 171 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/efd7b627/src/model/loss/cross_entropy.h
----------------------------------------------------------------------
diff --git a/src/model/loss/cross_entropy.h b/src/model/loss/cross_entropy.h
new file mode 100644
index 0000000..815b795
--- /dev/null
+++ b/src/model/loss/cross_entropy.h
@@ -0,0 +1,105 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_MODEL_LOSS_CROSS_ENTROPY_H_
+#define SRC_MODEL_LOSS_CROSS_ENTROPY_H_
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <stack>
+#include "singa/model/loss.h"
+
+namespace singa {
+
+/// Softmax cross-entropy loss for single-label classification.
+///
+/// Each row of the prediction holds the unnormalized scores of one sample
+/// (a prediction with nDim() <= 1 is treated as a single sample); the
+/// target holds one class index per sample, stored as a float.
+/// The expected calling pattern is [Forward|Evaluate] followed by Backward;
+/// Forward refuses to run while data from a previous call is still buffered.
+class CrossEntropy : public Loss<Tensor> {
+ public:
+  /// Compute the loss value of each sample/instance given the prediction
+  /// and the target, i.e., -log(prob_of_truth) per sample, where
+  /// prob_of_truth is the softmax probability of the ground-truth class.
+  /// Users can call Average(const Tensor&) to get the average
+  /// loss value over all samples in the batch.
+  Tensor Forward(const Tensor& prediction, const Tensor& target) override;
+
+  /// Compute the gradients of the loss values w.r.t. the prediction,
+  /// i.e., softmax(x) - 1 for the entry x of the ground-truth class and
+  /// softmax(x) for every other entry. Consumes the data buffered by the
+  /// preceding Forward call.
+  Tensor Backward() override;
+
+ private:
+  // Intermediate data from Forward for Backward: softmax(prediction) is
+  // pushed first, then the target (so the target is on top).
+  std::stack<Tensor> buf_;
+};
+
+/// Computes the per-sample loss -log(softmax(prediction)[label]).
+///
+/// @param prediction unnormalized class scores; the first dimension is the
+///        batch size (a tensor with nDim() <= 1 is one sample).
+/// @param target one class index per sample, stored as float, in [0, dim).
+/// @return a (batchsize, 1) tensor holding one loss value per sample.
+///
+/// softmax(prediction) and target are pushed onto buf_ for Backward.
+/// NOTE(review): the raw pointers below assume host-resident float data,
+/// as the rest of this file does.
+/// Defined in a header, hence inline (avoids duplicate-symbol link errors
+/// when several translation units include this file).
+inline Tensor CrossEntropy::Forward(const Tensor& prediction,
+                                    const Tensor& target) {
+  CHECK(buf_.empty()) << "Do not call Forward successively for more than twice."
+                      << " The calling pattern is [Forward|Evaluate] Backward";
+
+  size_t batchsize = 1;
+  if (prediction.nDim() > 1) batchsize = prediction.shape().at(0);
+  CHECK_GT(batchsize, 0u) << "Empty batch";
+  const size_t dim = prediction.Size() / batchsize;
+  CHECK_GT(dim, 0u) << "Prediction has no class scores";
+  CHECK_GE(target.Size(), batchsize) << "Need one label per sample";
+
+  // Row-wise softmax. Subtracting the row maximum before exp() keeps the
+  // exponentials from overflowing for large scores; the result is unchanged.
+  Tensor softmax(Shape{batchsize, dim});
+  const float* score_ptr = prediction.data<const float*>();
+  float* prob_ptr = reinterpret_cast<float*>(softmax.blob()->mutable_data());
+  for (size_t i = 0; i < batchsize; i++) {
+    const float* row = score_ptr + i * dim;
+    float* prob = prob_ptr + i * dim;
+    const float row_max = *std::max_element(row, row + dim);
+    float sum = 0.0f;
+    for (size_t j = 0; j < dim; j++) {
+      prob[j] = std::exp(row[j] - row_max);
+      sum += prob[j];
+    }
+    for (size_t j = 0; j < dim; j++) prob[j] /= sum;
+  }
+
+  // Buffer intermediate data for Backward (target ends up on top).
+  buf_.push(softmax);
+  buf_.push(target);
+
+  // Compute the loss for each sample.
+  Tensor loss(Shape{batchsize, 1});
+  const float* truth_ptr = target.data<const float*>();
+  float* loss_ptr = reinterpret_cast<float*>(loss.blob()->mutable_data());
+  for (size_t i = 0; i < batchsize; i++) {
+    const int ilabel = static_cast<int>(truth_ptr[i]);
+    CHECK_GE(ilabel, 0);
+    CHECK_LT(static_cast<size_t>(ilabel), dim) << "Label out of range";
+    // Clamp so a probability that underflowed to 0 yields a large but
+    // finite loss instead of inf.
+    const float prob_of_truth = std::max(prob_ptr[i * dim + ilabel],
+                                         std::numeric_limits<float>::min());
+    loss_ptr[i] = -std::log(prob_of_truth);
+  }
+  return loss;
+}
+
+/// Computes the gradient of the loss w.r.t. the prediction:
+/// softmax(x) - 1 for the entry x of the ground-truth class and softmax(x)
+/// for every other entry.
+///
+/// Consumes the softmax and target buffered by the preceding Forward call
+/// and returns a tensor with the same shape as that softmax.
+/// Defined in a header, hence inline.
+inline Tensor CrossEntropy::Backward() {
+  CHECK_GE(buf_.size(), 2u)
+      << "Backward must be preceded by a call to Forward";
+  // Copy the tensors out before pop(): pop() destroys the top element, so a
+  // reference obtained from top() would dangle afterwards.
+  Tensor target = buf_.top();
+  buf_.pop();
+  Tensor softmax = buf_.top();
+  buf_.pop();
+
+  size_t batchsize = 1;
+  if (softmax.nDim() > 1) batchsize = softmax.shape().at(0);
+  const size_t dim = softmax.Size() / batchsize;
+  const float* truth_ptr = target.data<const float*>();
+  float* grad_ptr = reinterpret_cast<float*>(softmax.blob()->mutable_data());
+  for (size_t i = 0; i < batchsize; i++) {
+    const int ilabel = static_cast<int>(truth_ptr[i]);
+    CHECK_GE(ilabel, 0);
+    CHECK_LT(static_cast<size_t>(ilabel), dim) << "Label out of range";
+    grad_ptr[i * dim + ilabel] -= 1.0f;
+  }
+  return softmax;
+}
+}  // namespace singa
+
+#endif  // SRC_MODEL_LOSS_CROSS_ENTROPY_H_
+
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/efd7b627/test/singa/test_cross_entropy.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cross_entropy.cc b/test/singa/test_cross_entropy.cc
new file mode 100644
index 0000000..9bb2321
--- /dev/null
+++ b/test/singa/test_cross_entropy.cc
@@ -0,0 +1,66 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "gtest/gtest.h"
+#include "singa/core/tensor.h"
+#include "singa/core/device.h"
+#include "../src/model/loss/cross_entropy.h"
+
+using singa::Tensor;
+class TestCrossEntropy : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    p.Reshape(singa::Shape{2, 4});
+    t.Reshape(singa::Shape{2, 1});
+    // Each element count comes from that array's own size; copying tdat
+    // with pdat's size would read past the end of tdat.
+    p.CopyDataFromHostPtr(pdat, sizeof(pdat) / sizeof(pdat[0]));
+    t.CopyDataFromHostPtr(tdat, sizeof(tdat) / sizeof(tdat[0]));
+  }
+  // Two samples with four classes each. The scores are uniform, so every
+  // class has softmax probability 0.25; the labels are class 0 and class 2.
+  const float pdat[8] = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f};
+  const float tdat[2] = {0.0f, 2.0f};
+
+  singa::Tensor p, t;
+};
+
+TEST_F(TestCrossEntropy, CppForward) {
+  singa::CrossEntropy loss_fn;
+  const Tensor& out = loss_fn.Forward(p, t);
+  const float* values = out.data<const float*>();
+
+  // Uniform scores over 4 classes give prob_of_truth = 1/4 for each sample.
+  const float expected = -log(0.25);
+  for (int i = 0; i < 2; ++i) EXPECT_FLOAT_EQ(values[i], expected);
+}
+
+TEST_F(TestCrossEntropy, CppBackward) {
+  singa::CrossEntropy loss_fn;
+  loss_fn.Forward(p, t);
+  const Tensor& grad = loss_fn.Backward();
+
+  // softmax - one_hot(label): 0.25 everywhere, -0.75 at the true classes
+  // (class 0 of sample 0 -> flat index 0, class 2 of sample 1 -> index 6).
+  const float expected[8] = {-0.75f, 0.25f, 0.25f, 0.25f,
+                             0.25f, 0.25f, -0.75f, 0.25f};
+  const float* actual = grad.data<const float*>();
+  for (int k = 0; k < 8; ++k) EXPECT_FLOAT_EQ(actual[k], expected[k]);
+}