Posted to commits@singa.apache.org by zh...@apache.org on 2016/06/13 13:20:02 UTC
[09/50] [abbrv] incubator-singa git commit: SINGA-171 - Create CppDevice and CudaDevice
SINGA-171 - Create CppDevice and CudaDevice
Rename Device subclasses based on the programming language and hardware,
e.g., CppCPU indicates the device is a CPU which runs cpp code, CudaGPU
indicates the device is an Nvidia GPU which runs cuda code, and CudaCPU
indicates the device is a CPU which uses cuda to malloc and free pinned
memory for the CudaGPU.
Correspondingly, we rename the lib namespace to lang, and Device::type()
to Device::lang().
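For readers tracking the rename, a minimal usage sketch of the new names (illustration only, assuming the constructor defaults and the lang() accessor shown in the device.h hunk below; not part of this commit):

// Sketch: renamed device classes after SINGA-171.
#include "singa/core/device.h"

int main() {
  singa::CppCPU host;                       // was CppDevice; runs cpp code
  bool ok = (host.lang() == singa::kCpp);   // was host.type() == kCpp
#ifdef USE_CUDA
  singa::CudaGPU gpu;                       // was CudaDevice; runs cuda code
  ok = ok && (gpu.lang() == singa::kCuda);
#endif
  return ok ? 0 : 1;
}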
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/9d1bcb42
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/9d1bcb42
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/9d1bcb42
Branch: refs/heads/master
Commit: 9d1bcb429a6f0a79426551a5fd42fdcadbf2f852
Parents: e3da6a5
Author: Wei Wang <wa...@comp.nus.edu.sg>
Authored: Thu May 19 17:00:01 2016 +0800
Committer: Wei Wang <wa...@comp.nus.edu.sg>
Committed: Thu May 19 17:08:36 2016 +0800
----------------------------------------------------------------------
include/singa/core/common.h | 5 +-
include/singa/core/device.h | 77 +++++++------
include/singa/model/layer.h | 9 +-
include/singa/utils/cuda.h | 94 ----------------
include/singa/utils/cuda_utils.h | 94 ++++++++++++++++
src/core/device/cpp_cpu.cc | 47 ++++++++
src/core/device/cpp_device.cc | 47 --------
src/core/device/cuda_device.cc | 157 ---------------------------
src/core/device/cuda_gpu.cc | 159 +++++++++++++++++++++++++++
src/core/device/device.cc | 2 +-
src/core/tensor/tensor.cc | 185 +++++++++++++++-----------------
src/core/tensor/tensor_math.h | 106 +++++++-----------
src/core/tensor/tensor_math_cpp.h | 56 +++++-----
src/core/tensor/tensor_math_cuda.h | 2 +-
src/proto/core.proto | 2 +-
test/singa/test_cpp_cpu.cc | 71 ++++++++++++
test/singa/test_cpp_device.cc | 71 ------------
test/singa/test_cudnn_dropout.cc | 8 +-
test/singa/test_tensor.cc | 6 +-
19 files changed, 588 insertions(+), 610 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/include/singa/core/common.h
----------------------------------------------------------------------
diff --git a/include/singa/core/common.h b/include/singa/core/common.h
index 0fa301a..61c1c41 100644
--- a/include/singa/core/common.h
+++ b/include/singa/core/common.h
@@ -32,16 +32,15 @@
#endif
namespace singa {
-namespace lib {
+namespace lang {
/// To implemente functions using cpp libraries
typedef struct _Cpp { } Cpp;
/// To implemente functions using cuda libraries
typedef struct _Cuda { } Cuda;
/// To implement function using opencl libraries
typedef struct _Opencl { } Opencl;
-} // namespace lib
+} // namespace lang
-typedef unsigned char Byte;
/// Blob reprent a chunk of memory (on device or host) managed by VirtualMemory.
class Blob {
public:
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/include/singa/core/device.h
----------------------------------------------------------------------
diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index 29b7677..a67b564 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -33,33 +33,12 @@ using std::vector;
using std::string;
using std::function;
namespace singa {
-/// The base type of callback argument structure.
-/// The specific arg should inherit from this one.
-class CallbackArg {
- public:
- template <typename T>
- T* CastTo() {
- static_assert(std::is_base_of<CallbackArg, T>::value,
- "The casted type must be a sub-class of CallbackArg");
- return static_cast<T*>(this);
- }
-};
-/// Type of callback functions for executing tensor ops.
-typedef function<void(CallbackArg*)> CallbackFn;
/// Allocate memory and execute Tensor operations.
/// There are three types of devices distinguished by their programming
/// languages, namely cpp, cuda and opencl.
class Device {
- public:
- /// Operation has a function, and read/write blobs.
- typedef struct _Operation {
- function<void(Context*)> fn;
- const vector<Blob*> read_blobs;
- const vector<Blob*> write_blobs;
- } Operation;
-
- public:
+ public:
Device() = default;
/// Constructor with device ID, num of executors (e.g., cuda streams),
/// max mem size to use (in MB), identifier of scheduler type (default
@@ -92,11 +71,14 @@ class Device {
/// wait for all operations submitted to this device.
void Sync();
- DeviceType type() const {
- return device_type_;
+ /// Return the programming language for this device.
+ LangType lang() const {
+ return lang_;
}
+ /// TODO(wangwei) remove it?
Device* host() const { return host_; }
+
int id() const { return id_; }
protected:
@@ -118,18 +100,19 @@ class Device {
unsigned seed_ = 0;
// Scheduler* scheduler_ = nullptr;
// VirtualMemory* vm_ = nullptr;
- /// could be kCpp, kCuda, kOpencl
- DeviceType device_type_;
+ /// Programming language type, could be kCpp, kCuda, kOpencl
+ LangType lang_;
// SafeQueue<Operation> op_queue_;
// SafeQueue<Operation> op_log_;
/// The host device
Device* host_;
};
-// Implement Device functions using cpp.
-class CppDevice : public Device {
+/// Represent a CPU device which may have multiple threads/executors.
+/// It runs cpp code.
+class CppCPU : public Device {
public:
- CppDevice(int id, int num_executors = 1,
+ CppCPU(int id = -1, int num_executors = 1,
string scheduler = "sync", string vm = "gc-only");
void SetRandSeed(unsigned seed) override;
@@ -150,17 +133,17 @@ class CppDevice : public Device {
};
/// a singleton CppDevice as the host for all devices.
-extern CppDevice hostDeviceSingleton;
+extern CppCPU defaultDevice;
// Implement Device using OpenCL libs.
// class OpenclDevice : public Device { };
#ifdef USE_CUDA
-// Implement Device using cuda.
-class CudaDevice : public Device {
+// Represent a Nvidia GPU which runs cuda code.
+class CudaGPU : public Device {
public:
- ~CudaDevice();
- CudaDevice(int id, int num_executors = 1, string scheduler = "sync",
+ ~CudaGPU();
+ CudaGPU(int id = -1, int num_executors = 1, string scheduler = "sync",
string vm = "gc-only");
void SetRandSeed(unsigned seed) override;
@@ -200,11 +183,37 @@ class CudaDevice : public Device {
Context ctx_;
};
+/// CudaCPU which uses cudaMallocHost to allocate pinned memory for host.
+
#endif // USE_CUDA
// Implement a CudaHost device, which used cuda functions for memory
// malloc/free.
// class CudaHost : public Device {}
+//
+/// The base type of callback argument structure.
+/// The specific arg should inherit from this one.
+/*
+class CallbackArg {
+ public:
+ template <typename T>
+ T* CastTo() {
+ static_assert(std::is_base_of<CallbackArg, T>::value,
+ "The casted type must be a sub-class of CallbackArg");
+ return static_cast<T*>(this);
+ }
+};
+/// Type of callback functions for executing tensor ops.
+typedef function<void(CallbackArg*)> CallbackFn;
+public:
+ /// Operation has a function, and read/write blobs.
+ typedef struct _Operation {
+ function<void(Context*)> fn;
+ const vector<Blob*> read_blobs;
+ const vector<Blob*> write_blobs;
+ } Operation;
+
+*/
} // namespace singa
#endif // SINGA_CORE_DEVICE_H_
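A note on the singleton rename above (hostDeviceSingleton to defaultDevice): host-side code now reaches the shared CppCPU through defaultDevice. A minimal sketch, assuming only the extern declaration visible in this hunk (not part of this commit):

#include "singa/core/device.h"

int main() {
  // defaultDevice is the process-wide CppCPU declared in device.h and
  // defined in cpp_cpu.cc further down; tensors created without an
  // explicit device attach to it (see tensor.cc in this commit).
  singa::Device* host = &singa::defaultDevice;
  return host->lang() == singa::kCpp ? 0 : 1;
}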
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index 050236a..084c42e 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -36,7 +36,7 @@ class Layer {
Layer() = default;
/// Set meta data fields from a string representing a proto message.
- void Setup(const string& proto_str) {
+ void Setup(const string& proto_str) {
LayerConf conf;
conf.ParseFromString(proto_str);
this->Setup(conf);
@@ -55,6 +55,13 @@ class Layer {
virtual const std::string layer_type() const { return "Unknown"; }
/// Set meta data fields configured in 'conf' (a proto message).
+ /// For some layers, which use input tensor shapes for setting their parameter
+ /// shapes (e.g., dense layer and convolution layer), users or wrapper
+ /// functions need to configure necessary fields inside LayerConf.
+ /// After calling Setup, the shape info of parameters should be accessed
+ /// correctly. All other info that depends on input tensors (e.g., batchsize)
+ /// should be set inside Forward(). Internal buffer/fields are set assuming
+ /// batchsize is 1.
virtual void Setup(const LayerConf& conf) {
name_ = conf.name();
for (const auto& spec : conf.param()) param_specs_.push_back(spec);
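The new doc comment spells out the Setup contract; the string overload simply parses a serialized LayerConf and forwards to Setup(conf). A minimal sketch of that call path (the field set here is illustrative and assumes layer.h pulls in the generated LayerConf; not part of this commit):

#include <string>
#include "singa/model/layer.h"

void SetupFromString(singa::Layer* layer) {
  singa::LayerConf conf;
  conf.set_name("hidden1");            // illustrative config; real layers need more fields
  std::string proto_str;
  conf.SerializeToString(&proto_str);  // what callers hand to Setup(proto_str)
  layer->Setup(proto_str);             // parses the string, then calls Setup(conf)
}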
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/include/singa/utils/cuda.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/cuda.h b/include/singa/utils/cuda.h
deleted file mode 100644
index b2bb5c5..0000000
--- a/include/singa/utils/cuda.h
+++ /dev/null
@@ -1,94 +0,0 @@
-// from caffe include/caffe/util/device_alternative.hpp
-
-#include <cublas_v2.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-
-//
-// CUDA macros
-//
-
-// CUDA: various checks for different function calls.
-#define CUDA_CHECK(condition) \
- /* Code block avoids redefinition of cudaError_t error */ \
- do { \
- cudaError_t error = condition; \
- CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
- } while (0)
-
-#define CUBLAS_CHECK(condition) \
- do { \
- cublasStatus_t status = condition; \
- CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << " " \
- << cublasGetErrorString(status); \
- } while (0)
-
-#define CURAND_CHECK(condition) \
- do { \
- curandStatus_t status = condition; \
- CHECK_EQ(status, CURAND_STATUS_SUCCESS) << " " \
- << curandGetErrorString(status); \
- } while (0)
-
-const char* cublasGetErrorString(cublasStatus_t error) {
- switch (error) {
- case CUBLAS_STATUS_SUCCESS:
- return "CUBLAS_STATUS_SUCCESS";
- case CUBLAS_STATUS_NOT_INITIALIZED:
- return "CUBLAS_STATUS_NOT_INITIALIZED";
- case CUBLAS_STATUS_ALLOC_FAILED:
- return "CUBLAS_STATUS_ALLOC_FAILED";
- case CUBLAS_STATUS_INVALID_VALUE:
- return "CUBLAS_STATUS_INVALID_VALUE";
- case CUBLAS_STATUS_ARCH_MISMATCH:
- return "CUBLAS_STATUS_ARCH_MISMATCH";
- case CUBLAS_STATUS_MAPPING_ERROR:
- return "CUBLAS_STATUS_MAPPING_ERROR";
- case CUBLAS_STATUS_EXECUTION_FAILED:
- return "CUBLAS_STATUS_EXECUTION_FAILED";
- case CUBLAS_STATUS_INTERNAL_ERROR:
- return "CUBLAS_STATUS_INTERNAL_ERROR";
-#if CUDA_VERSION >= 6000
- case CUBLAS_STATUS_NOT_SUPPORTED:
- return "CUBLAS_STATUS_NOT_SUPPORTED";
-#endif
-#if CUDA_VERSION >= 6050
- case CUBLAS_STATUS_LICENSE_ERROR:
- return "CUBLAS_STATUS_LICENSE_ERROR";
-#endif
- }
- return "Unknown cublas status";
-}
-
-const char* curandGetErrorString(curandStatus_t error) {
- switch (error) {
- case CURAND_STATUS_SUCCESS:
- return "CURAND_STATUS_SUCCESS";
- case CURAND_STATUS_VERSION_MISMATCH:
- return "CURAND_STATUS_VERSION_MISMATCH";
- case CURAND_STATUS_NOT_INITIALIZED:
- return "CURAND_STATUS_NOT_INITIALIZED";
- case CURAND_STATUS_ALLOCATION_FAILED:
- return "CURAND_STATUS_ALLOCATION_FAILED";
- case CURAND_STATUS_TYPE_ERROR:
- return "CURAND_STATUS_TYPE_ERROR";
- case CURAND_STATUS_OUT_OF_RANGE:
- return "CURAND_STATUS_OUT_OF_RANGE";
- case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
- return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
- case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
- return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
- case CURAND_STATUS_LAUNCH_FAILURE:
- return "CURAND_STATUS_LAUNCH_FAILURE";
- case CURAND_STATUS_PREEXISTING_FAILURE:
- return "CURAND_STATUS_PREEXISTING_FAILURE";
- case CURAND_STATUS_INITIALIZATION_FAILED:
- return "CURAND_STATUS_INITIALIZATION_FAILED";
- case CURAND_STATUS_ARCH_MISMATCH:
- return "CURAND_STATUS_ARCH_MISMATCH";
- case CURAND_STATUS_INTERNAL_ERROR:
- return "CURAND_STATUS_INTERNAL_ERROR";
- }
- return "Unknown curand status";
-}
-
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/include/singa/utils/cuda_utils.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/cuda_utils.h b/include/singa/utils/cuda_utils.h
new file mode 100644
index 0000000..b2bb5c5
--- /dev/null
+++ b/include/singa/utils/cuda_utils.h
@@ -0,0 +1,94 @@
+// from caffe include/caffe/util/device_alternative.hpp
+
+#include <cublas_v2.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+//
+// CUDA macros
+//
+
+// CUDA: various checks for different function calls.
+#define CUDA_CHECK(condition) \
+ /* Code block avoids redefinition of cudaError_t error */ \
+ do { \
+ cudaError_t error = condition; \
+ CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
+ } while (0)
+
+#define CUBLAS_CHECK(condition) \
+ do { \
+ cublasStatus_t status = condition; \
+ CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << " " \
+ << cublasGetErrorString(status); \
+ } while (0)
+
+#define CURAND_CHECK(condition) \
+ do { \
+ curandStatus_t status = condition; \
+ CHECK_EQ(status, CURAND_STATUS_SUCCESS) << " " \
+ << curandGetErrorString(status); \
+ } while (0)
+
+const char* cublasGetErrorString(cublasStatus_t error) {
+ switch (error) {
+ case CUBLAS_STATUS_SUCCESS:
+ return "CUBLAS_STATUS_SUCCESS";
+ case CUBLAS_STATUS_NOT_INITIALIZED:
+ return "CUBLAS_STATUS_NOT_INITIALIZED";
+ case CUBLAS_STATUS_ALLOC_FAILED:
+ return "CUBLAS_STATUS_ALLOC_FAILED";
+ case CUBLAS_STATUS_INVALID_VALUE:
+ return "CUBLAS_STATUS_INVALID_VALUE";
+ case CUBLAS_STATUS_ARCH_MISMATCH:
+ return "CUBLAS_STATUS_ARCH_MISMATCH";
+ case CUBLAS_STATUS_MAPPING_ERROR:
+ return "CUBLAS_STATUS_MAPPING_ERROR";
+ case CUBLAS_STATUS_EXECUTION_FAILED:
+ return "CUBLAS_STATUS_EXECUTION_FAILED";
+ case CUBLAS_STATUS_INTERNAL_ERROR:
+ return "CUBLAS_STATUS_INTERNAL_ERROR";
+#if CUDA_VERSION >= 6000
+ case CUBLAS_STATUS_NOT_SUPPORTED:
+ return "CUBLAS_STATUS_NOT_SUPPORTED";
+#endif
+#if CUDA_VERSION >= 6050
+ case CUBLAS_STATUS_LICENSE_ERROR:
+ return "CUBLAS_STATUS_LICENSE_ERROR";
+#endif
+ }
+ return "Unknown cublas status";
+}
+
+const char* curandGetErrorString(curandStatus_t error) {
+ switch (error) {
+ case CURAND_STATUS_SUCCESS:
+ return "CURAND_STATUS_SUCCESS";
+ case CURAND_STATUS_VERSION_MISMATCH:
+ return "CURAND_STATUS_VERSION_MISMATCH";
+ case CURAND_STATUS_NOT_INITIALIZED:
+ return "CURAND_STATUS_NOT_INITIALIZED";
+ case CURAND_STATUS_ALLOCATION_FAILED:
+ return "CURAND_STATUS_ALLOCATION_FAILED";
+ case CURAND_STATUS_TYPE_ERROR:
+ return "CURAND_STATUS_TYPE_ERROR";
+ case CURAND_STATUS_OUT_OF_RANGE:
+ return "CURAND_STATUS_OUT_OF_RANGE";
+ case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
+ return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
+ case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
+ return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
+ case CURAND_STATUS_LAUNCH_FAILURE:
+ return "CURAND_STATUS_LAUNCH_FAILURE";
+ case CURAND_STATUS_PREEXISTING_FAILURE:
+ return "CURAND_STATUS_PREEXISTING_FAILURE";
+ case CURAND_STATUS_INITIALIZATION_FAILED:
+ return "CURAND_STATUS_INITIALIZATION_FAILED";
+ case CURAND_STATUS_ARCH_MISMATCH:
+ return "CURAND_STATUS_ARCH_MISMATCH";
+ case CURAND_STATUS_INTERNAL_ERROR:
+ return "CURAND_STATUS_INTERNAL_ERROR";
+ }
+ return "Unknown curand status";
+}
+
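The move from cuda.h to cuda_utils.h keeps the macros unchanged. A short sketch of how they wrap CUDA/cuBLAS calls (assumes CHECK_EQ comes from glog, which the header itself does not include; not part of this commit):

#include <glog/logging.h>              // provides CHECK_EQ used by the macros (assumption)
#include "singa/utils/cuda_utils.h"

void AllocAndFreeOnGpu(size_t nbytes) {
  void* ptr = nullptr;
  CUDA_CHECK(cudaMalloc(&ptr, nbytes));  // aborts with cudaGetErrorString on failure
  cublasHandle_t handle;
  CUBLAS_CHECK(cublasCreate(&handle));   // same pattern for cuBLAS status codes
  CUBLAS_CHECK(cublasDestroy(handle));
  CUDA_CHECK(cudaFree(ptr));
}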
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/src/core/device/cpp_cpu.cc
----------------------------------------------------------------------
diff --git a/src/core/device/cpp_cpu.cc b/src/core/device/cpp_cpu.cc
new file mode 100644
index 0000000..3287911
--- /dev/null
+++ b/src/core/device/cpp_cpu.cc
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "singa/core/device.h"
+namespace singa {
+CppCPU defaultDevice(-1, 1);
+CppCPU::CppCPU(int id, int num_executors, string scheduler,
+ string vm) : Device(id, num_executors, scheduler, vm) {
+ lang_ = kCpp;
+ host_ = nullptr;
+}
+
+void CppCPU::SetRandSeed(unsigned seed) {
+ ctx_.random_generator.seed(seed);
+}
+void CppCPU::DoExec(function<void(Context*)>&& fn, int executor) {
+ CHECK_EQ(executor, 0);
+ fn(&ctx_);
+}
+
+void* CppCPU::Malloc(int size) {
+ return malloc(size);
+}
+
+void CppCPU::Free(void* ptr) {
+ free(ptr);
+}
+
+void CppCPU::CopyToFrom(void* dst, const void* src, size_t nBytes,
+ CopyDirection direction, Context* ctx) {
+ memcpy(dst, src, nBytes);
+}
+} // namespace singa
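As the implementation above shows, CppCPU maps Malloc, CopyToFrom and Free straight onto malloc, memcpy and free. A standalone sketch of the same behavior, with no singa types:

#include <cstdlib>
#include <cstring>

int main() {
  const char src[] = "host memory";
  void* dst = std::malloc(sizeof(src));   // what CppCPU::Malloc does
  std::memcpy(dst, src, sizeof(src));     // what CppCPU::CopyToFrom does (direction ignored)
  std::free(dst);                         // what CppCPU::Free does
  return 0;
}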
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/src/core/device/cpp_device.cc
----------------------------------------------------------------------
diff --git a/src/core/device/cpp_device.cc b/src/core/device/cpp_device.cc
deleted file mode 100644
index 763156c..0000000
--- a/src/core/device/cpp_device.cc
+++ /dev/null
@@ -1,47 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "singa/core/device.h"
-namespace singa {
-CppDevice hostDeviceSingleton(-1, 1);
-CppDevice::CppDevice(int id, int num_executors, string scheduler,
- string vm) : Device(id, num_executors, scheduler, vm) {
- device_type_ = kCpp;
- host_ = nullptr;
-}
-
-void CppDevice::SetRandSeed(unsigned seed) {
- ctx_.random_generator.seed(seed);
-}
-void CppDevice::DoExec(function<void(Context*)>&& fn, int executor) {
- CHECK_EQ(executor, 0);
- fn(&ctx_);
-}
-
-void* CppDevice::Malloc(int size) {
- return malloc(size);
-}
-
-void CppDevice::Free(void* ptr) {
- free(ptr);
-}
-
-void CppDevice::CopyToFrom(void* dst, const void* src, size_t nBytes,
- CopyDirection direction, Context* ctx) {
- memcpy(dst, src, nBytes);
-}
-} // namespace singa
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/src/core/device/cuda_device.cc
----------------------------------------------------------------------
diff --git a/src/core/device/cuda_device.cc b/src/core/device/cuda_device.cc
deleted file mode 100644
index 9be1a6e..0000000
--- a/src/core/device/cuda_device.cc
+++ /dev/null
@@ -1,157 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifdef USE_CUDA
-#include <cublas_v2.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <curand.h>
-#include <chrono>
-
-#include "singa/core/device.h"
-#include "singa/utils/cuda.h"
-namespace singa {
-
-const cudaMemcpyKind copyKind[] = {cudaMemcpyHostToHost, cudaMemcpyHostToDevice,
- cudaMemcpyDeviceToHost,
- cudaMemcpyDeviceToDevice};
-
-CudaDevice::~CudaDevice() {
- if (ctx_.cublas_handle)
- CUBLAS_CHECK(cublasDestroy(ctx_.cublas_handle));
- if (ctx_.curand_generator)
- CURAND_CHECK(curandDestroyGenerator(ctx_.curand_generator));
-#ifdef USE_CUDNN
- if (ctx_.cudnn_handle) {
- auto status = cudnnDestroy(ctx_.cudnn_handle);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
- }
-#endif
-}
-
-CudaDevice::CudaDevice(int id, int num_executors,
- string scheduler, string vm)
- : Device(id, num_executors, scheduler, vm) {
- device_type_ = kCuda;
- host_ = nullptr; // TODO(wangwei) add host device
- ctx_.stream = NULL; // use the default sync stream
- // TODO(wangwei) create one handle for each steam?
- CUDA_CHECK(cudaSetDevice(FindDevice(0)));
- // use curandCreateGeneratorHost for CudaHost device
- CURAND_CHECK(
- curandCreateGenerator(&ctx_.curand_generator, CURAND_RNG_PSEUDO_DEFAULT));
- auto seed = std::chrono::system_clock::now().time_since_epoch().count();
- SetRandSeed(seed);
- // TODO(wangwei) if one generator per stream, then need diff offset per gen?
- CURAND_CHECK(curandSetGeneratorOffset(ctx_.curand_generator, 0));
- CUBLAS_CHECK(cublasCreate(&(ctx_.cublas_handle)));
-
-#ifdef USE_CUDNN
- // TODO(wangwei) create one handle for each stream?
- auto status = cudnnCreate(&ctx_.cudnn_handle);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
-#endif // USE_CUDNN
-}
-
-void CudaDevice::SetRandSeed(unsigned seed) {
- CHECK(ctx_.curand_generator);
- CURAND_CHECK(
- curandSetPseudoRandomGeneratorSeed(ctx_.curand_generator, seed));
-}
-
-void CudaDevice::DoExec(function<void(Context*)>&& fn, int executor) {
- fn(&ctx_);
-}
-
-void CudaDevice::CopyToFrom(void* dst, const void* src, size_t nBytes,
- CopyDirection direction, Context* ctx) {
- cudaMemcpy(dst, src, nBytes, copyKind[direction]);
- // TODO(wangwei) use async copy
- // cudaMemcpyAsync(dst, src, nBytes,cudaMemcpyDefault, ctx_.stream);
-}
-
-/// Allocate cpu memory.
-void* CudaDevice::Malloc(int size) {
- void* ptr = nullptr;
- CUDA_CHECK(cudaMalloc(&ptr, size));
- return ptr;
-}
-
- /// Free cpu memory.
-void CudaDevice::Free(void* ptr) {
- CHECK_NE(ptr, nullptr);
- CUDA_CHECK(cudaFree(ptr));
-}
-
-
-// ==========Following code is from Caffe src/caffe/common.cpp=================
-
-void CudaDevice::DeviceQuery() {
- cudaDeviceProp prop;
- int device;
- if (cudaSuccess != cudaGetDevice(&device)) {
- printf("No cuda device present.\n");
- return;
- }
- CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
- LOG(INFO) << "Device id: " << device;
- LOG(INFO) << "Major revision number: " << prop.major;
- LOG(INFO) << "Minor revision number: " << prop.minor;
- LOG(INFO) << "Name: " << prop.name;
- LOG(INFO) << "Total global memory: " << prop.totalGlobalMem;
- LOG(INFO) << "Total shared memory per block: " << prop.sharedMemPerBlock;
- LOG(INFO) << "Total registers per block: " << prop.regsPerBlock;
- LOG(INFO) << "Warp size: " << prop.warpSize;
- LOG(INFO) << "Maximum memory pitch: " << prop.memPitch;
- LOG(INFO) << "Maximum threads per block: " << prop.maxThreadsPerBlock;
- LOG(INFO) << "Maximum dimension of block: "
- << prop.maxThreadsDim[0] << ", " << prop.maxThreadsDim[1] << ", "
- << prop.maxThreadsDim[2];
- LOG(INFO) << "Maximum dimension of grid: "
- << prop.maxGridSize[0] << ", " << prop.maxGridSize[1] << ", "
- << prop.maxGridSize[2];
- LOG(INFO) << "Clock rate: " << prop.clockRate;
- LOG(INFO) << "Total constant memory: " << prop.totalConstMem;
- LOG(INFO) << "Texture alignment: " << prop.textureAlignment;
- LOG(INFO) << "Concurrent copy and execution: "
- << (prop.deviceOverlap ? "Yes" : "No");
- LOG(INFO) << "Number of multiprocessors: " << prop.multiProcessorCount;
- LOG(INFO) << "Kernel execution timeout: "
- << (prop.kernelExecTimeoutEnabled ? "Yes" : "No");
- return;
-}
-
-bool CudaDevice::CheckDevice(const int device_id) {
- bool r = ((cudaSuccess == cudaSetDevice(device_id)) &&
- (cudaSuccess == cudaFree(0)));
- // reset any error that may have occurred.
- cudaGetLastError();
- return r;
-}
-
-int CudaDevice::FindDevice(const int start_id) {
- int count = 0;
- CUDA_CHECK(cudaGetDeviceCount(&count));
- for (int i = start_id; i < count; i++) {
- if (CheckDevice(i)) return i;
- }
- return -1;
-}
-
-
-} // namespace singa
-#endif // USE_CUDA
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/src/core/device/cuda_gpu.cc
----------------------------------------------------------------------
diff --git a/src/core/device/cuda_gpu.cc b/src/core/device/cuda_gpu.cc
new file mode 100644
index 0000000..8eafc4c
--- /dev/null
+++ b/src/core/device/cuda_gpu.cc
@@ -0,0 +1,159 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef USE_CUDA
+#include <cublas_v2.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <curand.h>
+#include <chrono>
+
+#include "singa/core/device.h"
+#include "singa/utils/cuda_utils.h"
+namespace singa {
+
+const cudaMemcpyKind copyKind[] = {cudaMemcpyHostToHost, cudaMemcpyHostToDevice,
+ cudaMemcpyDeviceToHost,
+ cudaMemcpyDeviceToDevice};
+
+CudaGPU::~CudaGPU() {
+ if (ctx_.cublas_handle)
+ CUBLAS_CHECK(cublasDestroy(ctx_.cublas_handle));
+ if (ctx_.curand_generator)
+ CURAND_CHECK(curandDestroyGenerator(ctx_.curand_generator));
+#ifdef USE_CUDNN
+ if (ctx_.cudnn_handle) {
+ auto status = cudnnDestroy(ctx_.cudnn_handle);
+ CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
+ }
+#endif
+}
+
+CudaGPU::CudaGPU(int id, int num_executors,
+ string scheduler, string vm)
+ : Device(id, num_executors, scheduler, vm) {
+ if (id == -1)
+ id = FindDevice(0);
+ lang_ = kCuda;
+ host_ = nullptr; // TODO(wangwei) add host device
+ ctx_.stream = NULL; // use the default sync stream
+ // TODO(wangwei) create one handle for each steam?
+ CUDA_CHECK(cudaSetDevice(FindDevice(0)));
+ // use curandCreateGeneratorHost for CudaHost device
+ CURAND_CHECK(
+ curandCreateGenerator(&ctx_.curand_generator, CURAND_RNG_PSEUDO_DEFAULT));
+ auto seed = std::chrono::system_clock::now().time_since_epoch().count();
+ SetRandSeed(seed);
+ // TODO(wangwei) if one generator per stream, then need diff offset per gen?
+ CURAND_CHECK(curandSetGeneratorOffset(ctx_.curand_generator, 0));
+ CUBLAS_CHECK(cublasCreate(&(ctx_.cublas_handle)));
+
+#ifdef USE_CUDNN
+ // TODO(wangwei) create one handle for each stream?
+ auto status = cudnnCreate(&ctx_.cudnn_handle);
+ CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
+#endif // USE_CUDNN
+}
+
+void CudaGPU::SetRandSeed(unsigned seed) {
+ CHECK(ctx_.curand_generator);
+ CURAND_CHECK(
+ curandSetPseudoRandomGeneratorSeed(ctx_.curand_generator, seed));
+}
+
+void CudaGPU::DoExec(function<void(Context*)>&& fn, int executor) {
+ fn(&ctx_);
+}
+
+void CudaGPU::CopyToFrom(void* dst, const void* src, size_t nBytes,
+ CopyDirection direction, Context* ctx) {
+ cudaMemcpy(dst, src, nBytes, copyKind[direction]);
+ // TODO(wangwei) use async copy
+ // cudaMemcpyAsync(dst, src, nBytes,cudaMemcpyDefault, ctx_.stream);
+}
+
+/// Allocate cpu memory.
+void* CudaGPU::Malloc(int size) {
+ void* ptr = nullptr;
+ CUDA_CHECK(cudaMalloc(&ptr, size));
+ return ptr;
+}
+
+ /// Free cpu memory.
+void CudaGPU::Free(void* ptr) {
+ CHECK_NE(ptr, nullptr);
+ CUDA_CHECK(cudaFree(ptr));
+}
+
+
+// ==========Following code is from Caffe src/caffe/common.cpp=================
+
+void CudaGPU::DeviceQuery() {
+ cudaDeviceProp prop;
+ int device;
+ if (cudaSuccess != cudaGetDevice(&device)) {
+ printf("No cuda device present.\n");
+ return;
+ }
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
+ LOG(INFO) << "Device id: " << device;
+ LOG(INFO) << "Major revision number: " << prop.major;
+ LOG(INFO) << "Minor revision number: " << prop.minor;
+ LOG(INFO) << "Name: " << prop.name;
+ LOG(INFO) << "Total global memory: " << prop.totalGlobalMem;
+ LOG(INFO) << "Total shared memory per block: " << prop.sharedMemPerBlock;
+ LOG(INFO) << "Total registers per block: " << prop.regsPerBlock;
+ LOG(INFO) << "Warp size: " << prop.warpSize;
+ LOG(INFO) << "Maximum memory pitch: " << prop.memPitch;
+ LOG(INFO) << "Maximum threads per block: " << prop.maxThreadsPerBlock;
+ LOG(INFO) << "Maximum dimension of block: "
+ << prop.maxThreadsDim[0] << ", " << prop.maxThreadsDim[1] << ", "
+ << prop.maxThreadsDim[2];
+ LOG(INFO) << "Maximum dimension of grid: "
+ << prop.maxGridSize[0] << ", " << prop.maxGridSize[1] << ", "
+ << prop.maxGridSize[2];
+ LOG(INFO) << "Clock rate: " << prop.clockRate;
+ LOG(INFO) << "Total constant memory: " << prop.totalConstMem;
+ LOG(INFO) << "Texture alignment: " << prop.textureAlignment;
+ LOG(INFO) << "Concurrent copy and execution: "
+ << (prop.deviceOverlap ? "Yes" : "No");
+ LOG(INFO) << "Number of multiprocessors: " << prop.multiProcessorCount;
+ LOG(INFO) << "Kernel execution timeout: "
+ << (prop.kernelExecTimeoutEnabled ? "Yes" : "No");
+ return;
+}
+
+bool CudaGPU::CheckDevice(const int device_id) {
+ bool r = ((cudaSuccess == cudaSetDevice(device_id)) &&
+ (cudaSuccess == cudaFree(0)));
+ // reset any error that may have occurred.
+ cudaGetLastError();
+ return r;
+}
+
+int CudaGPU::FindDevice(const int start_id) {
+ int count = 0;
+ CUDA_CHECK(cudaGetDeviceCount(&count));
+ for (int i = start_id; i < count; i++) {
+ if (CheckDevice(i)) return i;
+ }
+ return -1;
+}
+
+
+} // namespace singa
+#endif // USE_CUDA
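CheckDevice and FindDevice above probe each GPU by calling cudaSetDevice followed by cudaFree(0). A standalone sketch of the same probe loop, using only the CUDA runtime API (not part of this commit):

#include <cuda_runtime.h>

// Mirrors CudaGPU::FindDevice: return the first usable device id at or after
// start_id, or -1 if none responds to cudaSetDevice + cudaFree(0).
int FindFirstUsableGpu(int start_id) {
  int count = 0;
  if (cudaGetDeviceCount(&count) != cudaSuccess) return -1;
  for (int i = start_id; i < count; i++) {
    bool ok = (cudaSetDevice(i) == cudaSuccess) && (cudaFree(0) == cudaSuccess);
    cudaGetLastError();  // clear any sticky error before probing the next device
    if (ok) return i;
  }
  return -1;
}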
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/src/core/device/device.cc
----------------------------------------------------------------------
diff --git a/src/core/device/device.cc b/src/core/device/device.cc
index 205601b..cd860db 100644
--- a/src/core/device/device.cc
+++ b/src/core/device/device.cc
@@ -64,7 +64,7 @@ void Device::CopyDataToFrom(Blob* dst, Blob* src, size_t nBytes,
void Device::CopyDataFromHostPtr(Blob* dst, const void* src, size_t nBytes,
size_t dst_offset) {
- auto direct = device_type_ == kCpp ? kHostToHost : kHostToDevice;
+ auto direct = lang_ == kCpp ? kHostToHost : kHostToDevice;
void* dstptr = reinterpret_cast<char*>(dst->mutable_data()) + dst_offset;
Exec([this, dstptr, src, nBytes,
direct](Context* ctx) { CopyToFrom(dstptr, src, nBytes, direct, ctx); },
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index fac846c..185b1f9 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -25,23 +25,20 @@
namespace singa {
Tensor::~Tensor() {
- if (blob_ != nullptr && blob_->DecRefCount() == 0)
- device_->FreeBlob(blob_);
+ if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_);
blob_ = nullptr;
}
-Tensor::Tensor() {
- device_ = &hostDeviceSingleton;
-}
+Tensor::Tensor() { device_ = &defaultDevice; }
Tensor::Tensor(const Shape& shape, DataType dtype)
- : data_type_(dtype), device_(&hostDeviceSingleton), shape_(shape) {
- device_ = &hostDeviceSingleton;
+ : data_type_(dtype), device_(&defaultDevice), shape_(shape) {
+ device_ = &defaultDevice;
blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
}
Tensor::Tensor(Shape&& shape, DataType dtype)
- : data_type_(dtype), device_(&hostDeviceSingleton), shape_(shape) {
- device_ = &hostDeviceSingleton;
+ : data_type_(dtype), device_(&defaultDevice), shape_(shape) {
+ device_ = &defaultDevice;
blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
}
Tensor::Tensor(const Shape& shape, Device* device, DataType dtype)
@@ -82,8 +79,7 @@ void Tensor::ResetLike(const Tensor& t) {
void Tensor::ReShape(const Shape& shape) {
if (shape_ != shape) {
- if (blob_ != nullptr && blob_->DecRefCount() == 0)
- device_->FreeBlob(blob_);
+ if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_);
blob_ = device_->NewBlob(Product(shape) * SizeOf(data_type_));
shape_ = shape;
}
@@ -91,8 +87,7 @@ void Tensor::ReShape(const Shape& shape) {
void Tensor::AsType(DataType type) {
if (data_type_ != type) {
- if (blob_ != nullptr && blob_->DecRefCount() == 0)
- device_->FreeBlob(blob_);
+ if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_);
blob_ = device_->NewBlob(Product(shape_) * SizeOf(type));
data_type_ = type;
}
@@ -103,17 +98,14 @@ void Tensor::ToDevice(Device* dst) {
if (device_ != dst) {
Tensor tmp(shape_, dst, data_type_);
tmp.CopyData(*this);
- if (blob_ != nullptr && blob_->DecRefCount() == 0)
- device_->FreeBlob(blob_);
+ if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_);
blob_ = tmp.blob_;
tmp.blob_ = nullptr;
device_ = dst;
}
}
-void Tensor::ToHost() {
- ToDevice(device_->host());
-}
+void Tensor::ToHost() { ToDevice(device_->host()); }
template <typename DType>
void Tensor::CopyDataFromHostPtr(const DType* src, size_t num) {
@@ -153,8 +145,7 @@ Tensor Tensor::T() const {
}
Tensor& Tensor::operator=(const Tensor& t) {
- if (blob_ != nullptr && blob_->DecRefCount() == 0)
- device_->FreeBlob(blob_);
+ if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_);
transpose_ = t.transpose_;
data_type_ = t.data_type_;
shape_ = t.shape_;
@@ -165,8 +156,7 @@ Tensor& Tensor::operator=(const Tensor& t) {
}
Tensor& Tensor::operator=(Tensor&& t) {
- if (blob_ != nullptr && blob_->DecRefCount() == 0)
- device_->FreeBlob(blob_);
+ if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_);
transpose_ = t.transpose_;
data_type_ = t.data_type_;
shape_ = std::move(t.shape_);
@@ -177,7 +167,10 @@ Tensor& Tensor::operator=(Tensor&& t) {
}
#define GenUnaryTensorArgMemberFunction(op, fn) \
- Tensor& Tensor::op(const Tensor& t) { fn(*this, t, this); return *this; }
+ Tensor& Tensor::op(const Tensor& t) { \
+ fn(*this, t, this); \
+ return *this; \
+ }
GenUnaryTensorArgMemberFunction(operator+=, Add);
GenUnaryTensorArgMemberFunction(operator-=, Sub);
@@ -210,19 +203,19 @@ void CopyDataToFrom(Tensor* dst, const Tensor& src, size_t num,
Device *src_dev = src.device(), *dst_dev = dst->device();
Blob *from = src.blob(), *to = dst->blob();
- if (dst_dev->type() != src_dev->type()) {
+ if (dst_dev->lang() != src_dev->lang()) {
// let the none cpp device conduct copy op
- if (dst_dev->type() == kCpp) {
+ if (dst_dev->lang() == kCpp) {
src_dev->CopyDataToFrom(to, from, nBytes, kDeviceToHost, dst_offset,
src_offset);
- } else if (src_dev->type() == kCpp) {
+ } else if (src_dev->lang() == kCpp) {
dst_dev->CopyDataToFrom(to, from, nBytes, kHostToDevice, dst_offset,
src_offset);
} else {
LOG(FATAL) << "Not support mem copy betwee Cuda and OpenCL device";
}
} else {
- auto direct = src_dev->type() == kCpp ? kHostToHost : kDeviceToDevice;
+ auto direct = src_dev->lang() == kCpp ? kHostToHost : kDeviceToDevice;
src_dev->CopyDataToFrom(to, from, nBytes, direct, dst_offset, src_offset);
}
}
@@ -252,49 +245,49 @@ void CopyDataToFrom(Tensor* dst, const Tensor& src, size_t num,
} \
} while (0)
-/// typedef DType and Dev according to values of type and lib respectively.
-/// type is from DataType, and lib is from DevType.
-/// DType and Dev would be used in __VA_ARGS__.
-#define TYPE_LIB_SWITCH(dtype, DType, dev, Dev, ...) \
- do { \
- const int _SwitchShift = 3; \
- int _SwitchHash = ((dtype) << _SwitchShift) + (dev); \
- switch (_SwitchHash) { \
- case ((kFloat32 << _SwitchShift) + kCuda): { \
- typedef float DType; \
- typedef lib::Cuda Dev; \
- { __VA_ARGS__ } \
- break; \
- } \
- case ((kFloat32 << _SwitchShift) + kCpp): { \
- typedef float DType; \
- typedef lib::Cpp Dev; \
- { __VA_ARGS__ } \
- break; \
- } \
- case ((kFloat32 << _SwitchShift) + kOpencl): { \
- typedef float DType; \
- typedef lib::Opencl Dev; \
- { __VA_ARGS__ } \
- break; \
- } \
- default: \
- LOG(FATAL) << "Unknown combination of data type " \
- << DataType_Name(dtype) << " and library " \
- << DeviceType_Name(dev); \
- } \
+/// typedef DType and Lang according to data type and device programming
+/// language respectively.
+/// type is from DataType, and lang is from LangType.
+/// DType and Lang would be used in __VA_ARGS__.
+#define TYPE_LANG_SWITCH(dtype, DType, ltype, Lang, ...) \
+ do { \
+ const int _SwitchShift = 3; \
+ int _SwitchHash = ((dtype) << _SwitchShift) + (ltype); \
+ switch (_SwitchHash) { \
+ case ((kFloat32 << _SwitchShift) + kCuda): { \
+ typedef float DType; \
+ typedef lang::Cuda Lang; \
+ { __VA_ARGS__ } \
+ break; \
+ } \
+ case ((kFloat32 << _SwitchShift) + kCpp): { \
+ typedef float DType; \
+ typedef lang::Cpp Lang; \
+ { __VA_ARGS__ } \
+ break; \
+ } \
+ case ((kFloat32 << _SwitchShift) + kOpencl): { \
+ typedef float DType; \
+ typedef lang::Opencl Lang; \
+ { __VA_ARGS__ } \
+ break; \
+ } \
+ default: \
+ LOG(FATAL) << "Unknown combination of data type " \
+ << DataType_Name(dtype) << " and language " \
+ << LangType_Name(ltype); \
+ } \
} while (0)
-
-#define EltwiseUnaryTensorFn(fn, t, ret) \
- do { \
- TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->type(), Dev, { \
- ret->device()->Exec( \
- [t, ret](Context* ctx) { \
- fn<DType, Dev>(t.Size(), t.blob(), ret->blob(), ctx); \
- }, \
- {t.blob()}, {ret->blob()}); \
- }); \
+#define EltwiseUnaryTensorFn(fn, t, ret) \
+ do { \
+ TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { \
+ ret->device()->Exec( \
+ [t, ret](Context* ctx) { \
+ fn<DType, Lang>(t.Size(), t.blob(), ret->blob(), ctx); \
+ }, \
+ {t.blob()}, {ret->blob()}); \
+ }); \
} while (0)
#define GenUnaryTensorFunction(fn) \
@@ -329,26 +322,26 @@ void Softmax(const Tensor& t, Tensor* ret, int axis) {
CHECK_EQ(size % nrow, 0) << "Size = " << size << " nrow = " << nrow;
ncol = size / nrow;
}
- TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->type(), Dev, {
+ TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, {
ret->device()->Exec(
[nrow, ncol, t, ret](Context* ctx) {
- Softmax<DType, Dev>(nrow, ncol, t.blob(), ret->blob(), ctx);
+ Softmax<DType, Lang>(nrow, ncol, t.blob(), ret->blob(), ctx);
},
{t.blob()}, {ret->blob()});
- });
+ });
}
-#define EltwiseBinaryTensorFn(fn, lhs, rhs, ret) \
- do { \
- TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->type(), Dev, { \
- CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type())); \
- ret->device()->Exec( \
- [lhs, rhs, ret](Context* ctx) { \
- fn<DType, Dev>(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(), \
- ctx); \
- }, \
- {lhs.blob(), rhs.blob()}, {ret->blob()}); \
- }); \
+#define EltwiseBinaryTensorFn(fn, lhs, rhs, ret) \
+ do { \
+ TYPE_LANG_SWITCH(lhs.data_type(), DType, lhs.device()->lang(), Lang, { \
+ CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type())); \
+ ret->device()->Exec( \
+ [lhs, rhs, ret](Context* ctx) { \
+ fn<DType, Lang>(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(), \
+ ctx); \
+ }, \
+ {lhs.blob(), rhs.blob()}, {ret->blob()}); \
+ }); \
} while (0)
#define GenBinaryTensorFunction(op, fn) \
@@ -369,12 +362,12 @@ GenBinaryTensorFunction(Pow, Pow);
#define EltwiseTensorScalarFn(fn, t, x, ret) \
do { \
- TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->type(), Dev, { \
+ TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { \
static_assert(std::is_same<SType, DType>::value, \
"The Scalar type must match the Tensor data type"); \
ret->device()->Exec( \
[t, x, ret](Context* ctx) { \
- fn<DType, Dev>(t.Size(), t.blob(), x, ret->blob(), ctx); \
+ fn<DType, Lang>(t.Size(), t.blob(), x, ret->blob(), ctx); \
}, \
{t.blob()}, {ret->blob()}); \
}); \
@@ -424,11 +417,11 @@ void Mult(float alpha, const Tensor& A, float beta, const Tensor& B,
size_t m = transA ? A.shape()[1] : A.shape()[0], n = 0;
if (B.shape().size() == 1u) {
n = C->Size();
- TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->type(), Dev, {
+ TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
C->device()->Exec(
[transA, m, n, alpha, A, beta, B, C](Context* ctx) {
- GEMV<DType, Dev>(transA, m, n, alpha, A.blob(), B.blob(), beta,
- C->blob(), ctx);
+ GEMV<DType, Lang>(transA, m, n, alpha, A.blob(), B.blob(), beta,
+ C->blob(), ctx);
},
{A.blob(), B.blob()}, {C->blob()});
});
@@ -440,11 +433,11 @@ void Mult(float alpha, const Tensor& A, float beta, const Tensor& B,
CHECK_EQ(C->shape()[0], m);
CHECK_EQ(A.Size(), m * k);
CHECK_EQ(B.Size(), n * k);
- TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->type(), Dev, {
+ TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
C->device()->Exec(
[transA, transB, m, n, k, alpha, A, beta, B, C](Context* ctx) {
- GEMM<DType, Dev>(transA, transB, m, n, k, alpha, A.blob(), B.blob(),
- beta, C->blob(), ctx);
+ GEMM<DType, Lang>(transA, transB, m, n, k, alpha, A.blob(),
+ B.blob(), beta, C->blob(), ctx);
},
{A.blob(), B.blob()}, {C->blob()});
});
@@ -452,30 +445,30 @@ void Mult(float alpha, const Tensor& A, float beta, const Tensor& B,
}
void Bernoulli(float p, Tensor* t) {
- TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->type(), Dev, {
+ TYPE_LANG_SWITCH(t->data_type(), DType, t->device()->lang(), Lang, {
t->device()->Exec(
[p, t](Context* ctx) {
- Bernoulli<DType, Dev>(t->Size(), p, t->blob(), ctx);
+ Bernoulli<DType, Lang>(t->Size(), p, t->blob(), ctx);
},
{}, {t->blob()}, true);
});
}
void Uniform(float low, float high, Tensor* t) {
- TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->type(), Dev, {
+ TYPE_LANG_SWITCH(t->data_type(), DType, t->device()->lang(), Lang, {
t->device()->Exec(
[low, high, t](Context* ctx) {
- Uniform<DType, Dev>(t->Size(), low, high, t->blob(), ctx);
+ Uniform<DType, Lang>(t->Size(), low, high, t->blob(), ctx);
},
{}, {t->blob()}, true);
});
}
void Gaussian(float mean, float std, Tensor* t) {
- TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->type(), Dev, {
+ TYPE_LANG_SWITCH(t->data_type(), DType, t->device()->lang(), Lang, {
t->device()->Exec(
[mean, std, t](Context* ctx) {
- Gaussian<DType, Dev>(t->Size(), mean, std, t->blob(), ctx);
+ Gaussian<DType, Lang>(t->Size(), mean, std, t->blob(), ctx);
},
{}, {t->blob()}, true);
});
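TYPE_LANG_SWITCH above hashes a (data type, language) pair into a switch case and typedefs DType/Lang for the statement block passed in __VA_ARGS__. A stripped-down standalone sketch of the same trick (two-entry enums for brevity, not singa's real ones):

#include <cstdio>
#include <typeinfo>

enum DataType { kFloat32 = 0 };
enum LangType { kCpp = 0, kCuda = 1 };
namespace lang { struct Cpp {}; struct Cuda {}; }

#define TYPE_LANG_SWITCH(dtype, DType, ltype, Lang, ...)        \
  do {                                                          \
    const int _SwitchShift = 3;                                 \
    int _SwitchHash = ((dtype) << _SwitchShift) + (ltype);      \
    switch (_SwitchHash) {                                      \
      case ((kFloat32 << _SwitchShift) + kCpp): {               \
        typedef float DType; typedef lang::Cpp Lang;            \
        { __VA_ARGS__ } break;                                  \
      }                                                         \
      case ((kFloat32 << _SwitchShift) + kCuda): {              \
        typedef float DType; typedef lang::Cuda Lang;           \
        { __VA_ARGS__ } break;                                  \
      }                                                         \
      default: std::printf("unknown combination\n");            \
    }                                                           \
  } while (0)

int main() {
  TYPE_LANG_SWITCH(kFloat32, DType, kCpp, Lang, {
    std::printf("DType=%s Lang=%s\n", typeid(DType).name(), typeid(Lang).name());
  });
  return 0;
}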
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/src/core/tensor/tensor_math.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index aa520c9..53e979b 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -25,8 +25,8 @@ namespace singa {
/// \file math.h Math functions for linear algebra, neural net and random
/// operations.
-/// All functions have a template argument, DType for DataType, Lib for the
-/// backend library, e.g., lib::Cublas, lib::Cudnn, etc.
+/// All functions have a template argument, DType for DataType, Lang for the
+/// device programming language, e.g., lang::Cpp, lang::Cuda
/// Some operations would have many config/hyper-parameters, e.g., Conv, and
/// these config vary among diff implementations, e.g., cuda/cudnn/opencl.
@@ -45,133 +45,133 @@ class OpConf {
// ================Linear algebra functions====================================
/// ret[i] = |input[i]|
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Abs(int count, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// sum all elements of input into ret
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Sum(int count, const Blob* input, DType* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// ret[i] = sign(input[i])
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Sign(int count, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Base is e, Neper number. ret[i]=exp(input[i])
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Exp(int count, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Natual logarithm, the base is e, Neper number ret[i]=log(input[i]).
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Log(int count, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Element-wise operation, ret[i]=sqrt([input[i])
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Sqrt(int count, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Element-wise operation, ret[i]=tanh([input[i])
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Tanh(int count, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Element-wise operation, ret[i]=max(0, input[i])
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void ReLU(int count, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Element-wise operation, ret[i]=sigmoid([input[i])
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Sigmoid(int count, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Do softmax for each row invidually
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Softmax(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Element-wise operation, do v^x for every v from the input tensor
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Pow(int count, const Blob* input, DType x, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Element-wise operation, do v^x for every v from the lhs and every x from rhs
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Pow(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Element-wise operation, clamp every element into [low, high]
/// if x>high, then x=high; if x<low, then x=low.
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Clamp(int count, DType low, DType high, const Blob* input, Blob* ret,
Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// ret = input + x
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Add(int count, const Blob* input, DType x, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// ret = input - x
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Sub(int count, const Blob* input, DType x, Blob* ret, Context* ctx) {
- Add<DType, Lib>(count, input, -x, ret, ctx);
+ Add<DType, Lang>(count, input, -x, ret, ctx);
}
/// ret = input * x
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void EltwiseMult(int count, const Blob* input, DType x, Blob* ret, Context* ctx)
{
LOG(FATAL) << "Not Implemented";
}
/// ret = input / x
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Div(int count, const Blob* input, DType x, Blob* ret, Context* ctx) {
- EltwiseMult<DType, Lib>(count, input, DType(1) / x, ret, ctx);
+ EltwiseMult<DType, Lang>(count, input, DType(1) / x, ret, ctx);
}
/// ret = lhs + rhs
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Add(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// ret = lhs - rhs
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Sub(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// ret = lhs * rhs
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void EltwiseMult(int count, const Blob* lhs, const Blob* rhs, Blob* ret,
Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// ret = lhs / rhs
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Div(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// outer-product.
/// lhs and rhs are vectors of len m and n. ret is matrix of shape m * n
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Outer(int m, int n, const Blob* lhs, const Blob* rhs, Blob* ret,
Context* ctx) {
LOG(FATAL) << "Not Implemented";
@@ -179,26 +179,26 @@ void Outer(int m, int n, const Blob* lhs, const Blob* rhs, Blob* ret,
// TODO(wangwei) unify SumRow and SumCol.
/// Sum the rows of the input matrix into a vector
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void SumRow(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Sum the rows of the input matrix into a vector
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void SumCol(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
// TODO(wangwei) unify AddRow and AddCol.
/// Add the vector v to every row of A as the row of ret
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void AddRow(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret,
Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// Add the vector v to every column of A as the column of ret
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void AddCol(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret,
Context* ctx) {
LOG(FATAL) << "Not Implemented";
@@ -207,35 +207,35 @@ void AddCol(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret,
// ===== BLAS functions, ref to http://docs.nvidia.com/cuda/cublas
// ===== Level 1
/// return the index of the element with the max value.
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Amax(int count, const Blob* input, int* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// return the index of the element with the min value.
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Amin(int count, const Blob* input, int* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// ret = sum |x| for all x in input
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Asum(int count, const Blob* input, DType* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// ret = alpha * input + ret
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Axpy(int count, DType alpha, const Blob* input, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
/// ret *= x
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Scale(int count, DType x, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Dot(int count, const Blob* lhs, const Blob* rhs, DType* ret,
Context* ctx) {
LOG(FATAL) << "Not Implemented";
@@ -244,7 +244,7 @@ void Dot(int count, const Blob* lhs, const Blob* rhs, DType* ret,
// ===== Level 2
/// ret = alpha * op(A) * v + beta * ret.
/// op(A) = A if trans = false; A^T otherwise; rows(op(A)) = m, cols(op(A)) = n.
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void GEMV(bool trans, int m, int n, DType alpha, const Blob* A, const Blob* v,
DType beta, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
@@ -253,7 +253,7 @@ void GEMV(bool trans, int m, int n, DType alpha, const Blob* A, const Blob* v,
// ===== Level 3
/// ret = alpha * op(A) * op(B) + beta * ret.
/// op(A) = A if trans = false; A^T otherwise; rows(ret) = m, cols(ret) = n.
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void GEMM(bool transA, bool transB, int m, int n, int k, DType alpha,
const Blob* A, const Blob* B, DType beta, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
@@ -263,47 +263,23 @@ void GEMM(bool transA, bool transB, int m, int n, int k, DType alpha,
/// Each element of ret would be 1 with prob p and 0 with 1-p. 0<= p <= 1
// Get the random generator from 'ctx'
// If DType is not float, then convert the threshold to DType
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Bernoulli(int count, float p, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
// The random generator should be extracted from ctx.
// If DType is not float, then convert the low and high to DType
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Uniform(int count, float low, float high, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
// The random generator should be extracted from ctx.
// If DType is not float, then convert the mean and std to DType
-template <typename DType, typename Lib>
+template <typename DType, typename Lang>
void Gaussian(int count, float mean, float std, Blob* ret, Context* ctx) {
LOG(FATAL) << "Not Implemented";
}
-/* ================Neural net functions=======================================
-template <typename DType, typename Lib>
-void ConvFwd(ConvConf* conf, const Blob* x, const Blob* w, Blob* y,
- Context* ctx) {
- LOG(FATAL) << "Not Implemented";
-}
-
-template <typename DType, typename Lib>
-void ConvBwdBias(const ConvConf* conf, const Blob* dy, Blob* db, Context* ctx) {
- LOG(FATAL) << "Not Implemented";
-}
-
-template <typename DType, typename Lib>
-void PoolFwd(const PoolConf* conf, const Blob* x, Blob* y, Context* ctx) {
- LOG(FATAL) << "Not Implemented";
-}
-
-template <typename DType, typename Lib>
-void PoolBwd(const PoolConf* conf, const Blob* y, const Blob* dy, const Blob* x,
- Blob* dx, Context* ctx) {
- LOG(FATAL) << "Not Implemented";
-}
-*/
-
} // namespace singa
#endif // SINGA_CORE_MATH_H_
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/src/core/tensor/tensor_math_cpp.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 2cbc225..b58e3bd 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -25,64 +25,60 @@
#endif
namespace singa {
-template<>
-void Add<float, lib::Cpp>(int count,
- const Blob* lhs,
- const Blob* rhs,
- Blob* ret,
- Context* ctx) {
+template <>
+void Add<float, lang::Cpp>(int count, const Blob* lhs, const Blob* rhs,
+ Blob* ret, Context* ctx) {
// CHECK_EQ(ctx->stream, nullptr);
- float *dptr = static_cast<float*>(ret->mutable_data());
- const float *lptr = static_cast<const float*>(lhs->data());
- const float *rptr = static_cast<const float*>(rhs->data());
+ float* dptr = static_cast<float*>(ret->mutable_data());
+ const float* lptr = static_cast<const float*>(lhs->data());
+ const float* rptr = static_cast<const float*>(rhs->data());
for (int i = 0; i < count; i++) {
dptr[i] = lptr[i] + rptr[i];
}
}
template <>
-void EltwiseMult<float, lib::Cpp>(int count, const Blob* input, float x, Blob* ret, Context* ctx)
-{
- float *dptr = static_cast<float*>(ret->mutable_data());
- const float *lptr = static_cast<const float*>(input->data());
+void EltwiseMult<float, lang::Cpp>(int count, const Blob* input, float x,
+ Blob* ret, Context* ctx) {
+ float* dptr = static_cast<float*>(ret->mutable_data());
+ const float* lptr = static_cast<const float*>(input->data());
for (int i = 0; i < count; i++) {
dptr[i] = lptr[i] * x;
}
}
template <>
-void EltwiseMult<float, lib::Cpp>(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx)
-{
- float *dptr = static_cast<float*>(ret->mutable_data());
- const float *lptr = static_cast<const float*>(lhs->data());
- const float *rptr = static_cast<const float*>(rhs->data());
+void EltwiseMult<float, lang::Cpp>(int count, const Blob* lhs, const Blob* rhs,
+ Blob* ret, Context* ctx) {
+ float* dptr = static_cast<float*>(ret->mutable_data());
+ const float* lptr = static_cast<const float*>(lhs->data());
+ const float* rptr = static_cast<const float*>(rhs->data());
for (int i = 0; i < count; i++) {
dptr[i] = lptr[i] * rptr[i];
}
}
template <>
-void Bernoulli<float, lib::Cpp>(int count, float p, Blob* ret,
- Context* ctx) {
+void Bernoulli<float, lang::Cpp>(int count, float p, Blob* ret, Context* ctx) {
std::bernoulli_distribution distribution(p);
float* ptr = static_cast<float*>(ret->mutable_data());
- for (int i = 0; i < count; i ++) {
+ for (int i = 0; i < count; i++) {
ptr[i] = distribution(ctx->random_generator) ? 1.0f : 0.0f;
}
}
template <>
-void Uniform<float, lib::Cpp>(int count, float low, float high, Blob* ret,
+void Uniform<float, lang::Cpp>(int count, float low, float high, Blob* ret,
Context* ctx) {
std::uniform_real_distribution<float> distribution(low, high);
float* ptr = static_cast<float*>(ret->mutable_data());
- for (int i = 0; i < count; i ++) {
+ for (int i = 0; i < count; i++) {
ptr[i] = static_cast<float>(distribution(ctx->random_generator));
}
}
template <>
-void Gaussian<float, lib::Cpp>(int count, float mean, float std, Blob* ret,
- Context* ctx) {
+void Gaussian<float, lang::Cpp>(int count, float mean, float std, Blob* ret,
+ Context* ctx) {
std::normal_distribution<float> distribution(mean, std);
float* ptr = static_cast<float*>(ret->mutable_data());
for (int i = 0; i < count; i++) {
@@ -90,14 +86,10 @@ void Gaussian<float, lib::Cpp>(int count, float mean, float std, Blob* ret,
}
}
-
#ifdef USE_CBLAS
-template<>
-void Dot<float, lib::Cpp>(int count,
- const Blob* lhs,
- const Blob* rhs,
- float* ret,
- Context* ctx) {
+template <>
+void Dot<float, lang::Cpp>(int count, const Blob* lhs, const Blob* rhs,
+ float* ret, Context* ctx) {
const float* lptr = static_cast<const float*>(lhs->data());
const float* rptr = static_cast<const float*>(rhs->data());
*ret = cblas_sdot(count, lptr, 1, rptr, 1);
}
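
When USE_CBLAS is not defined, a plain-loop specialization with the same signature could serve as a fallback; the following is only a sketch along those lines, not code from this commit:

  template <>
  void Dot<float, lang::Cpp>(int count, const Blob* lhs, const Blob* rhs,
                             float* ret, Context* ctx) {
    // Naive dot product; fallback sketch for builds without CBLAS.
    const float* lptr = static_cast<const float*>(lhs->data());
    const float* rptr = static_cast<const float*>(rhs->data());
    float sum = 0.0f;
    for (int i = 0; i < count; i++) sum += lptr[i] * rptr[i];
    *ret = sum;
  }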
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index c5ea3c4..991e8bb 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -26,7 +26,7 @@ namespace singa {
#ifdef USE_CUDA
template<>
-void Add<float, lib::Cuda>(int count, const Blob* lhs, const Blob* rhs,
+void Add<float, lang::Cuda>(int count, const Blob* lhs, const Blob* rhs,
Blob* ret, Context* ctx) {
/*
cublasSetStream(ctx->cublas_handle, ctx->stream);
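The cuBLAS body is still commented out here; one way the specialization could eventually be filled in, using only the handle and stream fields already referenced in the commented code (a sketch, not the committed implementation):

  template <>
  void Add<float, lang::Cuda>(int count, const Blob* lhs, const Blob* rhs,
                              Blob* ret, Context* ctx) {
    cublasSetStream(ctx->cublas_handle, ctx->stream);
    const float* lptr = static_cast<const float*>(lhs->data());
    const float* rptr = static_cast<const float*>(rhs->data());
    float* dptr = static_cast<float*>(ret->mutable_data());
    cublasScopy(ctx->cublas_handle, count, lptr, 1, dptr, 1);          // ret = lhs
    const float alpha = 1.0f;
    cublasSaxpy(ctx->cublas_handle, count, &alpha, rptr, 1, dptr, 1);  // ret += rhs
  }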
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/src/proto/core.proto
----------------------------------------------------------------------
diff --git a/src/proto/core.proto b/src/proto/core.proto
index f99aba4..88d7f12 100644
--- a/src/proto/core.proto
+++ b/src/proto/core.proto
@@ -30,7 +30,7 @@ enum DataType {
kNumDataType = 5;
}
-enum DeviceType {
+enum LangType {
kCpp = 0;
kCuda = 1;
kOpencl = 2;
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/test/singa/test_cpp_cpu.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cpp_cpu.cc b/test/singa/test_cpp_cpu.cc
new file mode 100644
index 0000000..86654e1
--- /dev/null
+++ b/test/singa/test_cpp_cpu.cc
@@ -0,0 +1,71 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "gtest/gtest.h"
+#include "singa/core/device.h"
+#include "singa/proto/core.pb.h"
+
+using singa::CppCPU;
+using singa::Blob;
+TEST(CppCPU, Constructor) {
+ CppCPU dev(0, 1);
+ EXPECT_EQ(0, dev.id());
+}
+
+TEST(CppCPU, MemoryMallocFree) {
+ CppCPU dev(0, 1);
+ Blob* b = dev.NewBlob(4);
+ EXPECT_NE(nullptr, b);
+ EXPECT_EQ(4u, b->size());
+ dev.FreeBlob(b);
+}
+
+TEST(CppCPU, Exec) {
+ CppCPU dev(0, 1);
+ Blob* b = dev.NewBlob(4);
+ int x = 1, y = 3, z = 0;
+ dev.Exec([x, y, &z](singa::Context *ctx) {
+ z = x + y;
+ }, {b}, {b}, false);
+ EXPECT_EQ(x + y, z);
+}
+
+TEST(CppCPU, CopyData) {
+ CppCPU dev(0, 1);
+ Blob* b = dev.NewBlob(4);
+ char s[] = {'a', 'b', 'c', 'x'};
+ dev.CopyDataFromHostPtr(b, s, 4);
+ const char* bstr = static_cast<const char*>(b->data());
+ EXPECT_EQ('a', bstr[0]);
+ EXPECT_EQ('b', bstr[1]);
+ EXPECT_EQ('x', bstr[3]);
+
+ Blob* c = dev.NewBlob(4);
+ dev.CopyDataToFrom(c, b, 4, singa::kHostToHost, 0, 0);
+ const char* cstr = static_cast<const char*>(c->data());
+
+ EXPECT_EQ('a', cstr[0]);
+ EXPECT_EQ('b', cstr[1]);
+ EXPECT_EQ('x', cstr[3]);
+ dev.FreeBlob(b);
+ dev.FreeBlob(c);
+}
+
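The Exec test above only captures plain ints; when the scheduled lambda actually writes a Blob, the call shape stays the same. A short sketch using only the CppCPU API exercised by this test:

  // Sketch: fill a Blob inside a lambda scheduled on the device.
  singa::CppCPU dev(0, 1);
  singa::Blob* b = dev.NewBlob(4 * sizeof(float));
  dev.Exec([b](singa::Context* ctx) {
    float* p = static_cast<float*>(b->mutable_data());
    for (int i = 0; i < 4; i++) p[i] = 1.0f;
  }, {b}, {b}, false);
  dev.FreeBlob(b);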
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/test/singa/test_cpp_device.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cpp_device.cc b/test/singa/test_cpp_device.cc
deleted file mode 100644
index c302206..0000000
--- a/test/singa/test_cpp_device.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#include "gtest/gtest.h"
-#include "singa/core/device.h"
-#include "singa/proto/core.pb.h"
-
-using singa::CppDevice;
-using singa::Blob;
-TEST(CppDevice, Constructor) {
- CppDevice dev(0, 1);
- EXPECT_EQ(0, dev.id());
-}
-
-TEST(CppDevice, MemoryMallocFree) {
- CppDevice dev(0, 1);
- Blob* b = dev.NewBlob(4);
- EXPECT_NE(nullptr, b);
- EXPECT_EQ(4u, b->size());
- dev.FreeBlob(b);
-}
-
-TEST(CppDevice, Exec) {
- CppDevice dev(0, 1);
- Blob* b = dev.NewBlob(4);
- int x = 1, y =3, z = 0;
- dev.Exec([x, y, &z](singa::Context *ctx) {
- z = x + y;
- }, {b}, {b}, false);
- EXPECT_EQ(x + y, z);
-}
-
-TEST(CppDevice, CopyData) {
- CppDevice dev(0, 1);
- Blob* b = dev.NewBlob(4);
- char s[] = {'a', 'b', 'c', 'x'};
- dev.CopyDataFromHostPtr(b, s, 4);
- const char* bstr = static_cast<const char*>(b->data());
- EXPECT_EQ('a', bstr[0]);
- EXPECT_EQ('b', bstr[1]);
- EXPECT_EQ('x', bstr[3]);
-
- Blob* c = dev.NewBlob(4);
- dev.CopyDataToFrom(c, b, 4, singa::kHostToHost, 0, 0);
- const char* cstr = static_cast<const char*>(c->data());
-
- EXPECT_EQ('a', cstr[0]);
- EXPECT_EQ('b', cstr[1]);
- EXPECT_EQ('x', cstr[3]);
- dev.FreeBlob(b);
- dev.FreeBlob(c);
-}
-
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/test/singa/test_cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_dropout.cc b/test/singa/test_cudnn_dropout.cc
index 9913074..5fdc554 100644
--- a/test/singa/test_cudnn_dropout.cc
+++ b/test/singa/test_cudnn_dropout.cc
@@ -48,7 +48,7 @@ TEST(CudnnDropout, Setup) {
TEST(CudnnDropout, Forward) {
const float x[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
size_t n = sizeof(x) / sizeof(float);
- singa::CudaDevice cuda(0, 1);
+ singa::CudaGPU cuda(0, 1);
singa::Tensor in(singa::Shape{n}, &cuda);
in.CopyDataFromHostPtr(x, n);
@@ -67,7 +67,7 @@ TEST(CudnnDropout, Forward) {
for (size_t i = 0; i < n; i++)
EXPECT_FLOAT_EQ(0, GetBitValue(mptr, i) * (GetBitValue(mptr, i) - 1));
- singa::CppDevice host(0, 1);
+ singa::CppCPU host(0, 1);
out1.ToDevice(&host);
const float* outptr1 = out1.data<const float*>();
EXPECT_EQ(n, out1.Size());
@@ -90,7 +90,7 @@ TEST(CudnnDropout, Forward) {
TEST(CudnnDropout, Backward) {
const float x[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
size_t n = sizeof(x) / sizeof(float);
- singa::CudaDevice cuda(0, 1);
+ singa::CudaGPU cuda(0, 1);
singa::Tensor in(singa::Shape{n}, &cuda);
in.CopyDataFromHostPtr(x, n);
@@ -109,7 +109,7 @@ TEST(CudnnDropout, Backward) {
grad.CopyDataFromHostPtr(dy, n);
const auto ret = drop.Backward(singa::kTrain, grad);
- singa::CppDevice host(0, 1);
+ singa::CppCPU host(0, 1);
singa::Tensor in_grad = ret.first;
in_grad.ToDevice(&host);
const float* dx = in_grad.data<const float*>();
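These dropout tests all follow the same round trip with the renamed classes: build the input on a CudaGPU, run the layer, then ToDevice a CppCPU host before dereferencing the data pointer. In outline (same calls as the tests above, values illustrative):

  singa::CudaGPU cuda(0, 1);
  singa::CppCPU host(0, 1);
  singa::Tensor t(singa::Shape{8}, &cuda);
  float buf[8] = {0.0f};
  t.CopyDataFromHostPtr(buf, 8);
  // ... forward/backward on the GPU ...
  t.ToDevice(&host);                          // copy back before host-side checks
  const float* ptr = t.data<const float*>();  // safe to read on the host now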
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9d1bcb42/test/singa/test_tensor.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_tensor.cc b/test/singa/test_tensor.cc
index 8c3c901..b3f0c6b 100644
--- a/test/singa/test_tensor.cc
+++ b/test/singa/test_tensor.cc
@@ -59,10 +59,10 @@ TEST(TensorClass, AsType) {
TEST(TensorClass, ToDevice) {
Tensor t(Shape{2,3});
- EXPECT_EQ(static_cast<Device*>(&singa::hostDeviceSingleton), t.device());
- singa::CppDevice *dev = new singa::CppDevice(0, 1);
+ EXPECT_EQ(static_cast<Device*>(&singa::defaultDevice), t.device());
+ singa::CppCPU *dev = new singa::CppCPU(0, 1);
t.ToDevice(dev);
- EXPECT_NE(static_cast<Device*>(&singa::hostDeviceSingleton), t.device());
+ EXPECT_NE(static_cast<Device*>(&singa::defaultDevice), t.device());
}
TEST(TensorClass, CopyDataFromHostPtr) {