You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text copy.
Posted to commits@singa.apache.org by sh...@apache.org on 2020/07/09 14:08:46 UTC
[singa] branch dev updated: fix training loss error
This is an automated email from the ASF dual-hosted git repository.
shicong pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/dev by this push:
new 731b89b fix training loss error
new 8366813 Merge pull request #762 from chrishkchris/fix_loss_error
731b89b is described below
commit 731b89b2121b9fd27ca68296cf85ef6bebc69a36
Author: Chris Yeung <ch...@yahoo.com.hk>
AuthorDate: Thu Jul 9 17:30:26 2020 +0800
fix training loss error
---
include/singa/core/device.h | 3 ++-
python/singa/autograd.py | 6 ++----
src/core/device/device.cc | 4 ++--
3 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index 670c01e..648cda2 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -311,7 +311,8 @@ class Platform {
public:
/// Return the default host device
static std::shared_ptr<Device> GetDefaultDevice() {
- defaultDevice->Reset();
+ // cannot reset cpu device, which leads to error
+ // defaultDevice->Reset();
return defaultDevice;
}
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index de4a6ec..0b2a2b7 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -1249,13 +1249,11 @@ class SoftMaxCrossEntropy(Operator):
def softmax_cross_entropy(x, t):
- assert x.shape == t.shape, "input and target shape different: %s, %s" % (
- x.shape, t.shape)
assert x.ndim() == 2, "1st arg required 2d tensor. got shape: %s" % (
x.shape)
- assert t.ndim() == 2, "2nd arg required 2d tensor. got shape: %s" % (
+ assert t.ndim() <= 2, "2nd arg required <=2d tensor. got shape: %s" % (
t.shape)
- # x is the logits and t is the ground truth; both are 2D.
+ # x is the logits and t is the ground truth.
return SoftMaxCrossEntropy(t)(x)[0]
diff --git a/src/core/device/device.cc b/src/core/device/device.cc
index a1bc3cd..d054672 100644
--- a/src/core/device/device.cc
+++ b/src/core/device/device.cc
@@ -41,8 +41,8 @@ void Device::Reset() {
Sync();
// Reset Seed
- seed_ = std::chrono::system_clock::now().time_since_epoch().count();
- SetRandSeed(seed_);
+ // seed_ = std::chrono::system_clock::now().time_since_epoch().count();
+ // SetRandSeed(seed_);
// Reset Graph
graph_->Reset();