Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/05 18:22:06 UTC

[incubator-mxnet] branch master updated: rsp[:] = const_value, used in constant initializer for rsp weight (#8891)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6c51629  rsp[:] = const_value, used in constant initializer for rsp weight (#8891)
6c51629 is described below

commit 6c5162915655d7df93c8598b345a6bdfd053e600
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Wed Dec 6 02:21:59 2017 +0800

    rsp[:] = const_value, used in constant initializer for rsp weight (#8891)
    
    * rsp set to scalar
    
    * add test for const init of rsp
    
    * move set_rsp to ndarray_func
    
    * remove PopulateFullIdxRspImpl in init_op.h
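
    In Python terms, the change makes scalar assignment to a RowSparseNDArray work,
    where it previously raised ValueError. A minimal sketch mirroring the new unit
    tests (the 4x3 shape is illustrative, not from the commit):

        import mxnet as mx

        # a row_sparse array; before this patch, "w[:] = 2.0" raised ValueError
        w = mx.nd.zeros(shape=(4, 3), stype='row_sparse')
        w[:] = 2.0                 # new: scalar broadcast to every row
        print(w.asnumpy())         # 4x3 array filled with 2.0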
---
 python/mxnet/ndarray/sparse.py               |  3 +-
 src/kvstore/kvstore_dist_server.h            |  7 ++++-
 src/ndarray/ndarray.cc                       | 44 +++++++++++++++-------------
 src/ndarray/ndarray_function.cc              | 13 ++++++++
 src/ndarray/ndarray_function.cu              | 12 ++++++++
 src/ndarray/ndarray_function.h               | 25 ++++++++++++++++
 src/operator/random/sample_op.h              |  6 +++-
 src/operator/tensor/init_op.cu               |  1 +
 src/operator/tensor/init_op.h                | 13 --------
 tests/python/gpu/test_operator_gpu.py        |  1 +
 tests/python/unittest/test_init.py           | 16 ++++++++++
 tests/python/unittest/test_sparse_ndarray.py |  5 ++--
 12 files changed, 107 insertions(+), 39 deletions(-)

diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index 229044e..700dee0 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -658,8 +658,7 @@ class RowSparseNDArray(BaseSparseNDArray):
                 if value.handle is not self.handle:
                     value.copyto(self)
             elif isinstance(value, numeric_types):
-                raise ValueError("Assigning numeric types to RowSparseNDArray " \
-                                 "is not implemented yet.")
+                _internal._set_value(float(value), out=self)
             elif isinstance(value, (np.ndarray, np.generic)):
                 warnings.warn('Assigning non-NDArray object to RowSparseNDArray is not efficient',
                               RuntimeWarning)
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index 49aa001..f1637c4 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -247,7 +247,12 @@ class KVStoreDistServer {
             NDArray rsp = stored;
             stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
             mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-            op::PopulateFullIdxRspImpl(s, &rsp);
+            using namespace mxnet::op;
+            nnvm::dim_t nnr = rsp.shape()[0];
+            MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
+              IType* idx = rsp.aux_data(rowsparse::kIdx).dptr<IType>();
+              mxnet_op::Kernel<PopulateFullIdxRspKernel, cpu>::Launch(s, nnr, idx);
+            });
             mshadow::Copy(rsp.data().FlatTo1D<cpu, float>(),
                           recved.data().FlatTo1D<cpu, float>(), s);
             on_complete();
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 4a1963a..f09f168 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -309,30 +309,34 @@ void SetValueOp(const real_t &rhs, NDArray *out) {
   CHECK_NE(out->is_none(), true) << "Set value target must not be empty";
   // important: callback must always capture by value
   NDArray ret = *out;
-  switch (ret.ctx().dev_mask()) {
-    case cpu::kDevMask: {
-      Engine::Get()->PushSync([rhs, ret](RunContext ctx) {
-          CHECK(ret.storage_type() == kDefaultStorage);
-          TBlob tmp = ret.data();
-          ndarray::Eval<cpu>(rhs, &tmp, ctx);
-        }, ret.ctx(), {}, {ret.var()},
-        FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
-      break;
-    }
+  const NDArrayStorageType stype = ret.storage_type();
+  Engine::Get()->PushSync([rhs, ret, stype](RunContext ctx) {
+      TBlob tmp = ret.data();
+      switch (ret.ctx().dev_mask()) {
+        case cpu::kDevMask: {
+          if (stype == kDefaultStorage) {
+            ndarray::Eval<cpu>(rhs, &tmp, ctx);
+          } else {
+            ndarray::Eval(ctx.get_stream<cpu>(), rhs, ret);
+          }
+          break;
+        }
 #if MXNET_USE_CUDA
-    case gpu::kDevMask: {
-      Engine::Get()->PushSync([rhs, ret](RunContext ctx) {
-          TBlob tmp = ret.data();
-          ndarray::Eval<gpu>(rhs, &tmp, ctx);
+        case gpu::kDevMask: {
+          if (stype == kDefaultStorage) {
+            ndarray::Eval<gpu>(rhs, &tmp, ctx);
+          } else {
+            ndarray::Eval(ctx.get_stream<gpu>(), rhs, ret);
+          }
           // Wait GPU kernel to complete
           ctx.get_stream<gpu>()->Wait();
-        }, ret.ctx(), {}, {ret.var()},
-        FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
-      break;
-    }
+          break;
+        }
 #endif
-    default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
-  }
+        default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+      }
+    }, ret.ctx(), {}, {ret.var()},
+  FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
 }
 
 /*!
diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc
index ef0adbe..552555a 100644
--- a/src/ndarray/ndarray_function.cc
+++ b/src/ndarray/ndarray_function.cc
@@ -183,5 +183,18 @@ void ElementwiseSum<cpu>(mshadow::Stream<cpu>* s,
   }
 }
 
+
+template<>
+void Eval<cpu>(mshadow::Stream<cpu> *s,
+               const real_t val, const NDArray& dst) {
+  NDArray temp = dst;
+  const NDArrayStorageType stype = temp.storage_type();
+  if (stype == kRowSparseStorage) {
+    SetValueRspImpl(s, val, &temp);
+  } else {
+    LOG(FATAL) << "Not implemented for storage type" << stype;
+  }
+}
+
 }  // namespace ndarray
 }  // namespace mxnet
diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu
index 445f845..06b5ad4 100644
--- a/src/ndarray/ndarray_function.cu
+++ b/src/ndarray/ndarray_function.cu
@@ -203,5 +203,17 @@ void ElementwiseSum<gpu>(mshadow::Stream<gpu>* s,
   }
 }
 
+template<>
+void Eval<gpu>(mshadow::Stream<gpu> *s,
+               const real_t val, const NDArray& dst) {
+  NDArray temp = dst;
+  const NDArrayStorageType stype = temp.storage_type();
+  if (stype == kRowSparseStorage) {
+    SetValueRspImpl(s, val, &temp);
+  } else {
+    LOG(FATAL) << "Not implemented for storage type" << stype;
+  }
+}
+
 }  // namespace ndarray
 }  // namespace mxnet
diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h
index 6e6df39..518bb77 100644
--- a/src/ndarray/ndarray_function.h
+++ b/src/ndarray/ndarray_function.h
@@ -32,6 +32,7 @@
 #include <mxnet/ndarray.h>
 #include <vector>
 #include "../operator/mshadow_op.h"
+#include "../operator/tensor/init_op.h"
 
 namespace mxnet {
 /*! \brief namespace to support all possible Ndarray operator */
@@ -179,6 +180,30 @@ void ElementwiseSum(mshadow::Stream<xpu>* s,
                     const std::vector<NDArray>& nds,
                     NDArray* out);
 
+/*!
+ * \brief Set a row_sparse NDArray with val
+ * \param s - The device stream
+ * \param val - The value to be set
+ * \param dst - NDArray which is to be set to val
+ */
+template<typename xpu>
+void SetValueRspImpl(mshadow::Stream<xpu> *s,
+                     const real_t val, NDArray *dst) {
+  CHECK_EQ(dst->storage_type(), kRowSparseStorage);
+  using namespace mxnet::op;
+  nnvm::dim_t nnr = dst->shape()[0];
+  dst->CheckAndAlloc({mshadow::Shape1(nnr)});
+  MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(rowsparse::kIdx), IType, {
+    IType* idx = dst->aux_data(rowsparse::kIdx).dptr<IType>();
+    mxnet_op::Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, nnr, idx);
+  });
+  Fill<false>(s, dst->data(), kWriteTo, val);
+}
+
+template<typename xpu>
+void Eval(mshadow::Stream<xpu> *s,
+          const real_t val, const NDArray& dst);
+
 // broadcasting
 template <typename Device>
 void EvalBroadcast(TBlob const& src, TBlob* ret, int size, RunContext ctx);
diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h
index 240825b..84bb851 100644
--- a/src/operator/random/sample_op.h
+++ b/src/operator/random/sample_op.h
@@ -428,13 +428,17 @@ void SampleComputeEx_(const nnvm::NodeAttrs& attrs,
                       const std::vector<OpReqType>& req,
                       const std::vector<NDArray>& outputs,
                       SampleMaster<xpu, Sampler> sample_master) {
+  using namespace mxnet::op;
   NDArray output = outputs[0];
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   if (output.storage_type() == kRowSparseStorage) {
     // indices
     nnvm::dim_t nnr = output.shape()[0];
     output.CheckAndAlloc({mshadow::Shape1(nnr)});
-    PopulateFullIdxRspImpl(s, &output);
+    MSHADOW_IDX_TYPE_SWITCH(output.aux_type(rowsparse::kIdx), IType, {
+      IType* idx = output.aux_data(rowsparse::kIdx).dptr<IType>();
+      mxnet_op::Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, nnr, idx);
+    });
     // data
     TBlob out_blob = output.data();
     sample_master.op(attrs, ctx, req[0], &out_blob);
diff --git a/src/operator/tensor/init_op.cu b/src/operator/tensor/init_op.cu
index aeea289..37660e1 100644
--- a/src/operator/tensor/init_op.cu
+++ b/src/operator/tensor/init_op.cu
@@ -43,6 +43,7 @@ void FillZerosCsrImpl(mshadow::Stream<mshadow::gpu> *s, const NDArray& dst) {
   });
 }
 
+
 NNVM_REGISTER_OP(_zeros)
 .set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", FillComputeZerosEx<gpu>);
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 4d89970..3f5014d 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -291,19 +291,6 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
   });
 }
 
-// Fill full indices NDArray with zeros by updating the aux shape.
-template<typename xpu>
-void PopulateFullIdxRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
-  using namespace rowsparse;
-  CHECK_EQ(dst->storage_type(), kRowSparseStorage);
-  nnvm::dim_t nnr = dst->shape()[0];
-  dst->CheckAndAllocAuxData(kIdx, mshadow::Shape1(nnr));
-  MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
-    IType* idx = dst->aux_data(kIdx).dptr<IType>();
-    mxnet_op::Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, nnr, idx);
-  });
-}
-
 /*!
  * \brief Fill a rsp NDArray with zeros by updating the aux shape.
  * \tparam xpu - cpu or gpu
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index cecda21..ad18744 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -38,6 +38,7 @@ from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sp
 from test_sparse_ndarray import test_create_sparse_nd_empty, test_create_sparse_nd_from_sparse
 from test_sparse_ndarray import test_create_sparse_nd_from_dense, test_create_sparse_nd_infer_shape
 from test_sparse_ndarray import test_sparse_nd_check_format, test_sparse_nd_copy
+from test_sparse_ndarray import test_sparse_nd_setitem
 from test_sparse_operator import *
 from test_ndarray import *
 
diff --git a/tests/python/unittest/test_init.py b/tests/python/unittest/test_init.py
index e642e65..efd6ef3 100644
--- a/tests/python/unittest/test_init.py
+++ b/tests/python/unittest/test_init.py
@@ -44,8 +44,24 @@ def test_aux_init():
     assert (mod.get_params()[1]['bn_moving_var'].asnumpy() == 1).all()
     assert (mod.get_params()[1]['bn_moving_mean'].asnumpy() == 0).all()
 
+def test_rsp_const_init():
+    def check_rsp_const_init(init, val):
+        shape = (10, 10)
+        x = mx.symbol.Variable("data", stype='csr')
+        weight = mx.symbol.Variable("weight", shape=(shape[1], 2),
+                                    init=init, stype='row_sparse')
+        dot = mx.symbol.sparse.dot(x, weight)
+        mod = mx.mod.Module(dot, label_names=None)
+        mod.bind(data_shapes=[('data', shape)])
+        mod.init_params()
+        assert (list(mod.get_params()[0].values())[0].asnumpy() == val).all()
+
+    check_rsp_const_init(mx.initializer.Constant(value=2.), 2.)
+    check_rsp_const_init(mx.initializer.Zero(), 0.)
+    check_rsp_const_init(mx.initializer.One(), 1.)
 
 if __name__ == '__main__':
     test_variable_init()
     test_default_init()
     test_aux_init()
+    test_rsp_const_init()
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index e59e476..e404997 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -93,7 +93,7 @@ def test_sparse_nd_setitem():
         x = mx.nd.zeros(shape=shape, stype=stype)
         x[:] = dst
         dst_nd = mx.nd.array(dst) if isinstance(dst, (np.ndarray, np.generic)) else dst
-        assert same(x.asnumpy(), dst_nd.asnumpy())
+        assert np.all(x.asnumpy() == dst_nd.asnumpy() if isinstance(dst_nd, NDArray) else dst)
 
     shape = rand_shape_2d()
     for stype in ['row_sparse', 'csr']:
@@ -102,7 +102,8 @@ def test_sparse_nd_setitem():
         check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, stype))
         # numpy assignment
         check_sparse_nd_setitem(stype, shape, np.ones(shape))
-
+    # scalar assigned to row_sparse NDArray
+    check_sparse_nd_setitem('row_sparse', shape, 2)
 
 def test_sparse_nd_slice():
     shape = (rnd.randint(2, 10), rnd.randint(2, 10))    

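End to end, this scalar assignment is what lets a Constant (or Zero/One) initializer
fill a 'row_sparse' weight. A minimal sketch closely following the new
test_rsp_const_init test added in this commit (shapes and the csr/dot setup are taken
from that test):

    import mxnet as mx

    data = mx.symbol.Variable("data", stype='csr')
    weight = mx.symbol.Variable("weight", shape=(10, 2),
                                init=mx.initializer.Constant(value=2.),
                                stype='row_sparse')
    out = mx.symbol.sparse.dot(data, weight)

    mod = mx.mod.Module(out, label_names=None)
    mod.bind(data_shapes=[('data', (10, 10))])
    mod.init_params()   # internally performs rsp[:] = 2.0 on the row_sparse weight
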
-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].