You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/09 00:55:40 UTC

[GitHub] piiswrong closed pull request #8935: change workspace allocation in random samplers

piiswrong closed pull request #8935: change workspace allocation in random samplers
URL: https://github.com/apache/incubator-mxnet/pull/8935
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub does not show its diff
once it has been merged, so the diff is reproduced below:

diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h
index 84bb85145b..9fdff03886 100644
--- a/src/operator/random/sample_op.h
+++ b/src/operator/random/sample_op.h
@@ -242,44 +242,30 @@ using FSampleCompute = std::function<void (const nnvm::NodeAttrs& attrs,
 
 using mxnet::TBlob;
 
-// Convenience class that transfers a host based scalar into an
-// array on either the host or the device. Needed as
-// the core samplers expect parameters to be tensors located on the
-// appropriate device.
-template<typename xpu>
-Context AllocContext();
-template<>
-MSHADOW_FORCE_INLINE Context AllocContext<cpu>() { return Context::CPU(); }
-template<>
-MSHADOW_FORCE_INLINE Context AllocContext<gpu>() { return Context::GPU(); }
-
+// Allocates a single chunk of workspace memory and partitions it into three
+// workspace tensors that hold the seeds as well as the distribution parameters.
 template<typename xpu, typename DType>
-struct Scalar2Array {
-  Storage::Handle array;
-  Scalar2Array(DType scalar, const OpContext& ctx) {
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    array = Storage::Get()->Alloc(sizeof(DType), AllocContext<xpu>());
-    Tensor<xpu, 1, DType> src(Ref(), Shape1(1), s);
-    Copy(src, Tensor<cpu, 1, DType>(&scalar, Shape1(1)), s);
-  }
-  ~Scalar2Array() {
-    Storage::Get()->Free(array);
-  }
-  DType *Ref() { return static_cast<DType*>(array.dptr); }
-  Tensor<xpu, 1, DType> GetTensor() { return Tensor<xpu, 1, DType>(Ref(), Shape1(1)); }
-};
-
-// Convienience function to generate the required number of seeds for sampling
-template<typename xpu>
-MSHADOW_FORCE_INLINE Tensor<xpu, 1, unsigned int> GetSeeds(index_t N, const OpContext& ctx) {
+MSHADOW_FORCE_INLINE void GetSamplingTempData(index_t N, DType p1, DType p2, const OpContext& ctx,
+                                              Tensor<xpu, 1, unsigned int>* seeds,
+                                              Tensor<xpu, 1, DType>* parm1,
+                                              Tensor<xpu, 1, DType>* parm2) {
   Stream<xpu> *s = ctx.get_stream<xpu>();
   const index_t nSeeds(OptSampleSeedNum<xpu>(N));
-  Tensor<xpu, 1, unsigned int> seeds
-    = ctx.requested[1].get_space_typed<xpu, 1, unsigned int>(Shape1(nSeeds), ctx.get_stream<xpu>());
-  ctx.requested[0].get_random<xpu, float>(s)->GetRandInt(seeds);
-  return seeds;
+  // Combined memory requirement for the workspace data.
+  const index_t nInt(nSeeds + (2 * sizeof(DType) + sizeof(unsigned) - 1) / sizeof(unsigned));
+  Tensor<xpu, 1, unsigned> wspace
+    = ctx.requested[1].get_space_typed<xpu, 1, unsigned>(Shape1(nInt), s);
+  // Partition workspace into three chunks and initialize them.
+  *seeds = Tensor<xpu, 1, unsigned>(wspace.dptr_, Shape1(nSeeds), s);
+  ctx.requested[0].get_random<xpu, float>(s)->GetRandInt(*seeds);
+  DType *pspace = static_cast<DType*>(static_cast<void*>(wspace.dptr_+nSeeds));
+  *parm1 = Tensor<xpu, 1, DType>(pspace, Shape1(1), s);
+  Copy(*parm1, Tensor<cpu, 1, DType>(&p1, Shape1(1)), s);
+  *parm2 = Tensor<xpu, 1, DType>(pspace+1, Shape1(1), s);
+  Copy(*parm2, Tensor<cpu, 1, DType>(&p2, Shape1(1)), s);
 }
 
+
 template<typename xpu, typename Sampler>
 struct SampleMaster;
 
@@ -292,12 +278,14 @@ struct SampleMaster<xpu, UniformSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SampleUniformParam& param = nnvm::get<SampleUniformParam>(attrs.parsed);
     CHECK_GE(param.high, param.low) << "low must be less or equal to high in uniform distribution";
-    Scalar2Array<xpu, float> low(param.low, ctx), high(param.high, ctx);
-    Tensor<xpu, 1, unsigned int> seeds(GetSeeds<xpu>(outputs->Size(), ctx));
+    Tensor<xpu, 1, unsigned int> seeds;
+    Tensor<xpu, 1, float> low, high;
+    GetSamplingTempData<xpu, float>(outputs->Size(), param.low, param.high, ctx,
+                                    &seeds, &low, &high);
     UniformSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(low.GetTensor(), high.GetTensor(), out, seeds, s);
+      sampler.Sample(low, high, out, seeds, s);
     });
   }
 };
@@ -311,12 +299,14 @@ struct SampleMaster<xpu, NormalSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SampleNormalParam& param = nnvm::get<SampleNormalParam>(attrs.parsed);
     CHECK_GT(param.scale, 0) << "scale parameter in gaussian has to be positive";
-    Scalar2Array<xpu, float> loc(param.loc, ctx), scale(param.scale, ctx);
-    Tensor<xpu, 1, unsigned int> seeds(GetSeeds<xpu>(outputs->Size(), ctx));
+    Tensor<xpu, 1, unsigned int> seeds;
+    Tensor<xpu, 1, float> loc, scale;
+    GetSamplingTempData<xpu, float>(outputs->Size(), param.loc, param.scale, ctx,
+                                    &seeds, &loc, &scale);
     NormalSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(loc.GetTensor(), scale.GetTensor(), out, seeds, s);
+      sampler.Sample(loc, scale, out, seeds, s);
     });
   }
 };
@@ -331,12 +321,14 @@ struct SampleMaster<xpu, GammaSampler<xpu>> {
     const SampleGammaParam& param = nnvm::get<SampleGammaParam>(attrs.parsed);
     CHECK_GT(param.alpha, 0) << "alpha parameter in gamma distribution has to be positive";
     CHECK_GT(param.beta, 0) << "beta parameter in gamma distribution has to be positive";
-    Scalar2Array<xpu, float> alpha(param.alpha, ctx), beta(param.beta, ctx);
-    Tensor<xpu, 1, unsigned int> seeds(GetSeeds<xpu>(outputs->Size(), ctx));
+    Tensor<xpu, 1, unsigned int> seeds;
+    Tensor<xpu, 1, float> alpha, beta;
+    GetSamplingTempData<xpu, float>(outputs->Size(), param.alpha, param.beta, ctx,
+                                    &seeds, &alpha, &beta);
     GammaSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(alpha.GetTensor(), beta.GetTensor(), out, seeds, s);
+      sampler.Sample(alpha, beta, out, seeds, s);
     });
   }
 };
@@ -350,12 +342,13 @@ struct SampleMaster<xpu, ExponentialSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SampleExponentialParam& param = nnvm::get<SampleExponentialParam>(attrs.parsed);
     CHECK_GT(param.lam, 0) << "lambda parameter in exponential distribution has to be positive";
-    Scalar2Array<xpu, float> lam(param.lam, ctx);
-    Tensor<xpu, 1, unsigned int> seeds(GetSeeds<xpu>(outputs->Size(), ctx));
+    Tensor<xpu, 1, unsigned int> seeds;
+    Tensor<xpu, 1, float> lam, dummy;
+    GetSamplingTempData<xpu, float>(outputs->Size(), param.lam, 0, ctx, &seeds, &lam, &dummy);
     ExponentialSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(lam.GetTensor(), out, seeds, s);
+      sampler.Sample(lam, out, seeds, s);
     });
   }
 };
@@ -369,12 +362,13 @@ struct SampleMaster<xpu, PoissonSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SamplePoissonParam& param = nnvm::get<SamplePoissonParam>(attrs.parsed);
     CHECK_GE(param.lam, 0) << "lambda parameter in poisson distribution has to be non-negative";
-    Scalar2Array<xpu, float> lam(param.lam, ctx);
-    Tensor<xpu, 1, unsigned int> seeds(GetSeeds<xpu>(outputs->Size(), ctx));
+    Tensor<xpu, 1, unsigned int> seeds;
+    Tensor<xpu, 1, float> lam, dummy;
+    GetSamplingTempData<xpu, float>(outputs->Size(), param.lam, 0, ctx, &seeds, &lam, &dummy);
     PoissonSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(lam.GetTensor(), out, seeds, s);
+      sampler.Sample(lam, out, seeds, s);
     });
   }
 };
@@ -389,12 +383,13 @@ struct SampleMaster<xpu, NegativeBinomialSampler<xpu>> {
     const SampleNegBinomialParam& param = nnvm::get<SampleNegBinomialParam>(attrs.parsed);
     CHECK_GE(param.k, 0) << "k parameter in negative binomial distribution has to be non-negative";
     CHECK_GE(param.p, 0) << "p parameter in negative binomial distribution has to be non-negative";
-    Scalar2Array<xpu, float> k(param.k, ctx), p(param.p, ctx);
-    Tensor<xpu, 1, unsigned int> seeds(GetSeeds<xpu>(outputs->Size(), ctx));
+    Tensor<xpu, 1, unsigned int> seeds;
+    Tensor<xpu, 1, float> k, p;
+    GetSamplingTempData<xpu, float>(outputs->Size(), param.k, param.p, ctx, &seeds, &k, &p);
     NegativeBinomialSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(k.GetTensor(), p.GetTensor(), out, seeds, s);
+      sampler.Sample(k, p, out, seeds, s);
     });
   }
 };
@@ -411,12 +406,14 @@ struct SampleMaster<xpu, GeneralizedNegativeBinomialSampler<xpu>> {
       << "mu parameter in generalized negative binomial distribution has to be non-negative";
     CHECK_GE(param.alpha, 0)
       << "alpha parameter in generalized negative binomial distribution has to be non-negative";
-    Scalar2Array<xpu, float> mu(param.mu, ctx), alpha(param.alpha, ctx);
-    Tensor<xpu, 1, unsigned int> seeds(GetSeeds<xpu>(outputs->Size(), ctx));
+    Tensor<xpu, 1, unsigned int> seeds;
+    Tensor<xpu, 1, float> mu, alpha;
+    GetSamplingTempData<xpu, float>(outputs->Size(), param.mu, param.alpha, ctx,
+                                    &seeds, &mu, &alpha);
     GeneralizedNegativeBinomialSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(mu.GetTensor(), alpha.GetTensor(), out, seeds, s);
+      sampler.Sample(mu, alpha, out, seeds, s);
     });
   }
 };


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services