You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by ch...@apache.org on 2020/07/13 09:11:23 UTC

[singa] branch dev updated: fix kint issue, cast kint to kfloat for computation and cast back

This is an automated email from the ASF dual-hosted git repository.

chrishkchris pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/dev by this push:
     new 363cdd5  fix kint issue, cast kint to kfloat for computation and cast back
     new a05ef9c  Merge pull request #763 from dcslin/kint2
363cdd5 is described below

commit 363cdd5aa828200d992f2e960c0cb24d022596a0
Author: dcslin <13...@users.noreply.github.com>
AuthorDate: Sat Jul 11 05:16:45 2020 +0000

    fix kint issue, cast kint to kfloat for computation and cast back
---
 src/core/tensor/tensor.cc         | 133 +++++++++++++++++++++++++-------------
 src/core/tensor/tensor_math_cpp.h |   6 ++
 test/python/test_tensor.py        |  88 +++++++++++++++++++++++++
 3 files changed, 183 insertions(+), 44 deletions(-)

diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 39d190a..475aab5 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -681,6 +681,12 @@ void RepeatDataToFrom(bool broadcast_flag, const vector<size_t> &repeats,
         { __VA_ARGS__ }                                        \
         break;                                                 \
       }                                                        \
+      case ((kInt << _SwitchShift) + kCuda): {             \
+        typedef int DType;                                   \
+        typedef lang::Cuda Lang;                               \
+        { __VA_ARGS__ }                                        \
+        break;                                                 \
+      }                                                        \
       case ((kFloat32 << _SwitchShift) + kCpp): {              \
         typedef float DType;                                   \
         typedef lang::Cpp Lang;                                \
@@ -688,7 +694,7 @@ void RepeatDataToFrom(bool broadcast_flag, const vector<size_t> &repeats,
         break;                                                 \
       }                                                        \
       case ((kInt << _SwitchShift) + kCpp): {                  \
-        typedef float DType;                                   \
+        typedef int DType;                                   \
         typedef lang::Cpp Lang;                                \
         { __VA_ARGS__ }                                        \
         break;                                                 \
@@ -730,9 +736,7 @@ float Tensor::l2() const {
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     device_->Exec(
         [&nrm, this](Context *ctx) {
-          DType ret = DType(0);
-          Nrm2<DType, Lang>(*this, &ret, ctx);
-          nrm = TypeCast<DType, float>(ret);
+          Nrm2<DType, Lang>(*this, &nrm, ctx);
         },
         {this->block()}, {}, "L1");
   });
@@ -929,32 +933,58 @@ Tensor SoftMaxBackward(const Tensor &in, int axis, const Tensor &fdout) {
     });                                                                    \
   } while (0)
 
-#define GenBinaryTensorFn(op, fn)                              \
-  Tensor op(const Tensor &lhs, const Tensor &rhs) {            \
-    if (lhs.shape() != rhs.shape()) {                          \
-      auto lhs_ = Broadcast(lhs, rhs.shape());                 \
-      auto rhs_ = Broadcast(rhs, lhs.shape());                 \
-      Tensor ret(lhs_.shape(), lhs.device(), lhs.data_type()); \
-      fn(lhs_, rhs_, &ret);                                    \
-      return ret;                                              \
-    } else {                                                   \
-      Tensor ret(lhs.shape(), lhs.device(), lhs.data_type());  \
-      fn(lhs, rhs, &ret);                                      \
-      return ret;                                              \
-    }                                                          \
-  }                                                            \
-  void fn(const Tensor &lhs, const Tensor &rhs, Tensor *ret) { \
-    CHECK_EQ(lhs.device(), ret->device());                     \
-    CHECK_EQ(rhs.device(), ret->device());                     \
-    if (lhs.shape() != rhs.shape()) {                          \
-      auto lhs_ = Broadcast(lhs, rhs.shape());                 \
-      auto rhs_ = Broadcast(rhs, lhs.shape());                 \
-      CHECK(lhs_.shape() == ret->shape());                     \
-      EltwiseBinaryTensorFn(fn, lhs_, rhs_, ret);              \
-    } else {                                                   \
-      CHECK(lhs.shape() == ret->shape());                      \
-      EltwiseBinaryTensorFn(fn, lhs, rhs, ret);                \
-    }                                                          \
+#define GenBinaryTensorFn(op, fn)                                           \
+  Tensor op(const Tensor &lhs, const Tensor &rhs) {                         \
+    if (lhs.shape() != rhs.shape()) {                                       \
+      if (lhs.data_type() == kFloat32 && rhs.data_type() == kFloat32) {     \
+        auto lhs_ = Broadcast(lhs, rhs.shape());                            \
+        auto rhs_ = Broadcast(rhs, lhs.shape());                            \
+        Tensor ret(lhs_.shape(), lhs.device(), lhs.data_type());            \
+        fn(lhs_, rhs_, &ret);                                               \
+        return ret;                                                         \
+      } else {                                                              \
+        /* lhs tensor and rhs tensor are not both in float, cast to float */\
+        Tensor tmp_lhs = lhs.Clone().AsType(kFloat32);                      \
+        Tensor tmp_rhs = rhs.Clone().AsType(kFloat32);                      \
+        tmp_lhs = Broadcast(tmp_lhs, tmp_rhs.shape());                      \
+        tmp_rhs = Broadcast(tmp_rhs, tmp_lhs.shape());                      \
+        Tensor ret(tmp_lhs.shape(), tmp_lhs.device(), tmp_lhs.data_type()); \
+        fn(tmp_lhs, tmp_rhs, &ret);                                         \
+        /* if lhs and rhs are both int, cast back to int */                 \
+        if (lhs.data_type() == kInt && rhs.data_type() == kInt)             \
+          return ret.Clone().AsType(kInt);                                  \
+        return ret;                                                         \
+      }                                                                     \
+    } else {                                                                \
+      if (lhs.data_type() == kFloat32 && rhs.data_type() == kFloat32) {     \
+        Tensor ret(lhs.shape(), lhs.device(), lhs.data_type());             \
+        fn(lhs, rhs, &ret);                                                 \
+        return ret;                                                         \
+      } else {                                                              \
+        /* lhs tensor and rhs tensor are not both in float, cast to float */\
+        Tensor tmp_lhs = lhs.Clone().AsType(kFloat32);                      \
+        Tensor tmp_rhs = rhs.Clone().AsType(kFloat32);                      \
+        Tensor ret(tmp_lhs.shape(), tmp_lhs.device(), tmp_lhs.data_type()); \
+        fn(tmp_lhs, tmp_rhs, &ret);                                         \
+        /* if lhs and rhs are both int, cast back to int */                 \
+        if (lhs.data_type() == kInt && rhs.data_type() == kInt)             \
+          return ret.Clone().AsType(kInt);                                  \
+        return ret;                                                         \
+      }                                                                     \
+    }                                                                       \
+  }                                                                         \
+  void fn(const Tensor &lhs, const Tensor &rhs, Tensor *ret) {              \
+    CHECK_EQ(lhs.device(), ret->device());                                  \
+    CHECK_EQ(rhs.device(), ret->device());                                  \
+    if (lhs.shape() != rhs.shape()) {                                       \
+      auto lhs_ = Broadcast(lhs, rhs.shape());                              \
+      auto rhs_ = Broadcast(rhs, lhs.shape());                              \
+      CHECK(lhs_.shape() == ret->shape());                                  \
+      EltwiseBinaryTensorFn(fn, lhs_, rhs_, ret);                           \
+    } else {                                                                \
+      CHECK(lhs.shape() == ret->shape());                                   \
+      EltwiseBinaryTensorFn(fn, lhs, rhs, ret);                             \
+    }                                                                       \
   }  // namespace singa
 
 // boradcasting operations:
@@ -974,8 +1004,6 @@ GenBinaryTensorFn(ReLUBackward, ReLUBackward);
 #define EltwiseTensorScalarFn(fn, t, x, ret)                            \
   do {                                                                  \
     TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, {  \
-      static_assert(std::is_same<SType, DType>::value,                  \
-                    "The Scalar type must match the Tensor data type"); \
       Tensor &retRef = *ret;                                            \
       ret->device()->Exec(                                              \
           [t, x, retRef](Context *ctx) mutable {                        \
@@ -985,18 +1013,35 @@ GenBinaryTensorFn(ReLUBackward, ReLUBackward);
     });                                                                 \
   } while (0)
 
-#define GenTensorScalarFn(op, fn)                             \
-  template <typename SType>                                   \
-  Tensor op(const Tensor &in, const SType x) {                \
-    Tensor ret(in.shape(), in.device(), in.data_type());      \
-    fn(in, x, &ret);                                          \
-    return ret;                                               \
-  }                                                           \
-  template <typename SType>                                   \
-  void fn(const Tensor &in, const SType x, Tensor *ret) {     \
-    EltwiseTensorScalarFn(fn, in, x, ret);                    \
-  }                                                           \
-  template Tensor op<float>(const Tensor &in, const float x); \
+#define GenTensorScalarFn(op, fn)                                          \
+  template <typename SType>                                                \
+  Tensor op(const Tensor &in, const SType x) {                             \
+    if (in.data_type() == kFloat32 && std::is_same<SType, float>::value){  \
+      Tensor ret(in.shape(), in.device(), in.data_type());                 \
+      fn(in, x, &ret);                                                     \
+      return ret;                                                          \
+    } else if (in.data_type() == kFloat32) {                               \
+      Tensor ret(in.shape(), in.device(), in.data_type());                 \
+      float tmp_x = x;                                                     \
+      fn(in, tmp_x, &ret);                                                 \
+      return ret;                                                          \
+    } else {                                                               \
+      /* tensor and scalar are not both in float, cast to float */         \
+      Tensor tmp_in = in.Clone().AsType(kFloat32);                         \
+      float tmp_x = x;                                                     \
+      Tensor ret(tmp_in.shape(), tmp_in.device(), tmp_in.data_type());     \
+      fn(tmp_in, tmp_x, &ret);                                             \
+      /* if tensor and scalar are both int, cast back to int */            \
+      if (in.data_type() == kInt && std::is_same<SType, int>::value)       \
+        return ret.Clone().AsType(kInt);                                   \
+      return ret;                                                          \
+    }                                                                      \
+  }                                                                        \
+  template <typename SType>                                                \
+  void fn(const Tensor &in, const SType x, Tensor *ret) {                  \
+    EltwiseTensorScalarFn(fn, in, x, ret);                                 \
+  }                                                                        \
+  template Tensor op<float>(const Tensor &in, const float x);              \
   template void fn<float>(const Tensor &in, const float x, Tensor *ret)
 
 GenTensorScalarFn(operator+, Add);
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 724237f..2de7d0f 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -635,6 +635,12 @@ void Transform<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
 }
 
 template <>
+void Transform<int, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+  auto identity = [](int a) { return a; };
+  traverse_unary<int>(in, out, identity);
+}
+
+template <>
 void Bernoulli<float, lang::Cpp>(const float p, Tensor *out, Context *ctx) {
   std::bernoulli_distribution distribution(p);
   float *outPtr = static_cast<float *>(out->block()->mutable_data());
diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py
index 1beae86..ec989ee 100644
--- a/test/python/test_tensor.py
+++ b/test/python/test_tensor.py
@@ -478,6 +478,94 @@ class TestTensorMethods(unittest.TestCase):
         x = tensor.Tensor((4, 5, 3, 2), device=dev)
         x.gaussian(0, 1)
 
+    def _kfloat32_int(self, dev=gpu_dev):
+        np.random.seed(0)
+        x_val = np.random.random((2, 3)).astype(np.float32) * 10
+        x = tensor.from_numpy(x_val)
+        x.to_device(dev)
+        scalar = np.random.random((1,))[0] * 100
+        y = x + scalar
+        self.assertEqual(y.dtype, core_pb2.kFloat32)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), x_val + scalar)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_kfloat32_int_gpu(self):
+        self._kfloat32_int(gpu_dev)
+
+    def test_kfloat32_int_cpu(self):
+        self._kfloat32_int(cpu_dev)
+
+    def _kint_float(self, dev=gpu_dev):
+        np.random.seed(0)
+        x_val = np.random.randint(0, 10, (2, 3))
+        x = tensor.from_numpy(x_val)
+        x.to_device(dev)
+        scalar = np.random.random((1,))[0] * 100
+        y = x + scalar
+        self.assertEqual(y.dtype, core_pb2.kFloat32)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), x_val + scalar)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_kint_float_gpu(self):
+        self._kint_float(gpu_dev)
+
+    def test_kint_float_cpu(self):
+        self._kint_float(cpu_dev)
+
+    def _kint_kint(self, dev=gpu_dev):
+        a_np = np.array([[[17, 4, 9, 22, 18], [-9, 9, -1, -1, 4],
+                          [1, 14, 7, 1, 4], [3, 14, -2, 3, -8]],
+                         [[-25, 6, 8, -7, 22], [-14, 0, -1, 15, 14],
+                          [1, 3, -8, -19, -3], [1, 12, 12, -3, -3]],
+                         [[-10, -14, -17, 19, -5], [-4, -12, 7, -16, -2],
+                          [-8, 3, -5, -11, 0], [4, 0, 3, -6, -3]]],
+                        dtype=np.int32)
+        b_np = np.array([[[-6, -3, -8, -17, 1], [-4, -16, 4, -9, 0],
+                          [7, 1, 11, -12, 4], [-6, -8, -5, -3, 0]],
+                         [[-11, 9, 4, -15, 14], [18, 11, -1, -10, 10],
+                          [-4, 12, 2, 9, 3], [7, 0, 17, 1, 4]],
+                         [[18, -13, -12, 9, -11], [19, -4, -7, 19, 14],
+                          [18, 9, -8, 19, -2], [8, 9, -1, 6, 9]]],
+                        dtype=np.int32)
+        ta = tensor.from_numpy(a_np)
+        tb = tensor.from_numpy(b_np)
+        ta.to_device(dev)
+        tb.to_device(dev)
+        y = ta - tb
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), a_np - b_np)
+
+    def test_kint_kint_cpu(self, dev=cpu_dev):  # CPU run needs no CUDA guard
+        self._kint_kint(cpu_dev)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_kint_kint_gpu(self, dev=gpu_dev):  # skipped on non-CUDA builds
+        self._kint_kint(gpu_dev)
+
+    def _kint_kint_bc(self, dev=gpu_dev):
+        a_np = np.array([[[17, 4, 9, 22, 18], [-9, 9, -1, -1, 4],
+                          [1, 14, 7, 1, 4], [3, 14, -2, 3, -8]],
+                         [[-25, 6, 8, -7, 22], [-14, 0, -1, 15, 14],
+                          [1, 3, -8, -19, -3], [1, 12, 12, -3, -3]],
+                         [[-10, -14, -17, 19, -5], [-4, -12, 7, -16, -2],
+                          [-8, 3, -5, -11, 0], [4, 0, 3, -6, -3]]],
+                        dtype=np.int32)
+        b_np = np.array([[-6, -3, -8, -17, 1], [-4, -16, 4, -9, 0],
+                          [7, 1, 11, -12, 4], [-6, -8, -5, -3, 0]],
+                        dtype=np.int32)
+        ta = tensor.from_numpy(a_np)
+        tb = tensor.from_numpy(b_np)
+        ta.to_device(dev)
+        tb.to_device(dev)
+        y = ta - tb
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), a_np - b_np)
+
+    def test_kint_kint_bc_cpu(self, dev=cpu_dev):  # CPU run needs no CUDA guard
+        self._kint_kint_bc(cpu_dev)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_kint_kint_bc_gpu(self, dev=gpu_dev):  # skipped on non-CUDA builds
+        self._kint_kint_bc(gpu_dev)
+
 
 if __name__ == '__main__':
     unittest.main()