Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/08/03 16:45:13 UTC

[GitHub] [incubator-mxnet] xidulu opened a new pull request #18852: [WIP] Gamma reparameterization gradient

xidulu opened a new pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852


   ## Description ##
   
   Correctness test WIP.
   https://github.com/apache/incubator-mxnet/issues/18140
   
   ## Checklist ##
   ### Essentials ###
   Please feel free to remove inapplicable items for your PR.
   - [ ] The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant [JIRA issue](https://issues.apache.org/jira/projects/MXNET/issues) created (except PRs with tiny changes)
   - [ ] Changes are complete (i.e. I finished coding on this PR)
   - [ ] All changes have test coverage:
   - Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
   - Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
   - Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
   - [ ] Code is well-documented: 
   - For user-facing API changes, API doc string has been updated. 
   - For new C++ functions in header files, their functionalities and arguments are documented. 
   - For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
   - Check the API doc at https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
   - [ ] To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change
   
   ### Changes ###
   - [ ] Feature1, tests, (and when applicable, API doc)
   - [ ] Feature2, tests, (and when applicable, API doc)
   
   ## Comments ##
   - If this change is a backward incompatible change, why must this change be made.
   - Interesting edge cases to note here
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] szhengac commented on a change in pull request #18852: [WIP] Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
szhengac commented on a change in pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#discussion_r464700663



##########
File path: src/operator/numpy/random/np_gamma_op.h
##########
@@ -394,6 +401,76 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   }
 }
 
+template<typename xpu, int ndim, typename DType>
+inline void GammaReparamBackwardImpl(const OpContext& ctx,
+                                            const std::vector<TBlob>& inputs,
+                                            const std::vector<OpReqType>& req,
+                                            const std::vector<TBlob>& outputs,
+                                            const mxnet::TShape& new_ishape,
+                                            const mxnet::TShape& new_oshape,
+                                            const float scale) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace broadcast;
+  using namespace mxnet_op;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob igrad = outputs[0].reshape(new_ishape);
+  // inputs: [grad_from_samples, alpha_tensor, samples]
+  const TBlob ograd = inputs[0].reshape(new_oshape);
+  const TBlob alpha = inputs[1].reshape(new_ishape);
+  const TBlob samples = inputs[2].reshape(new_oshape);
+  size_t workspace_size =
+      ReduceWorkspaceSize<ndim, DType>(s, igrad.shape_, req[0], ograd.shape_);
+  // Convert samples to standard gamma
+  Tensor<xpu, 1, char> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::gamma_implicit_grad>(

Review comment:
       How does this line convert samples to standard gamma? I think we should multiply the sample by the scale first.
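Background on why the backward pass wants standard-gamma samples: the implicit reparameterization gradient (as in https://arxiv.org/pdf/1806.01851.pdf) differentiates the CDF F(z; α) of Gamma(α, 1), giving dz/dα = -(∂F/∂α)(z; α) / p(z; α). For a sample drawn with scale β, x = β·z, so mathematically z = x / β and dx/dα = β·dz/dα. The pure-Python sketch below illustrates that math under stated assumptions; the series CDF, the finite-difference ∂F/∂α, and all function names are hypothetical stand-ins, not MXNet's `gamma_implicit_grad`.

```python
import math


def gamma_cdf(x, alpha, max_terms=1000):
    """Regularized lower incomplete gamma P(alpha, x), i.e. the CDF of Gamma(alpha, 1).

    Power series: x^a e^-x / Gamma(a+1) * sum_n x^n / ((a+1)...(a+n)).
    Adequate for the moderate alpha and x in this sketch only.
    """
    if x <= 0.0:
        return 0.0
    term = math.exp(alpha * math.log(x) - x - math.lgamma(alpha + 1.0))
    total = term
    for n in range(1, max_terms):
        term *= x / (alpha + n)
        total += term
        if term < 1e-17 * total:
            break
    return total


def gamma_pdf(x, alpha):
    """Density of Gamma(alpha, 1) for x > 0."""
    return math.exp((alpha - 1.0) * math.log(x) - x - math.lgamma(alpha))


def implicit_grad_standard(z, alpha, h=1e-5):
    """dz/dalpha = -(dF/dalpha)(z; alpha) / p(z; alpha) for z ~ Gamma(alpha, 1).

    dF/dalpha is a central finite difference here, purely for illustration.
    """
    dcdf_dalpha = (gamma_cdf(z, alpha + h) - gamma_cdf(z, alpha - h)) / (2.0 * h)
    return -dcdf_dalpha / gamma_pdf(z, alpha)


def implicit_grad_scaled(x, alpha, scale):
    """For x ~ Gamma(alpha, scale): x = scale * z, so dx/dalpha = scale * dz/dalpha."""
    z = x / scale  # standardize the sample first
    return scale * implicit_grad_standard(z, alpha)


def gamma_quantile(u, alpha, hi=50.0, iters=100):
    """Inverse CDF of Gamma(alpha, 1) by bisection (used only as a cross-check)."""
    lo = 0.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if gamma_cdf(mid, alpha) < u:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


# Cross-check: with u = F(z; alpha) held fixed, dz/dalpha is the alpha-derivative
# of the quantile function, which we approximate by finite differences.
alpha, u, h = 5.0, 0.6, 1e-4
z = gamma_quantile(u, alpha)
fd = (gamma_quantile(u, alpha + h) - gamma_quantile(u, alpha - h)) / (2.0 * h)
print(implicit_grad_standard(z, alpha), fd)  # the two values should agree closely
```

Dividing by the scale before evaluating the standard-gamma gradient, and multiplying the result by the scale afterwards, is what the chain rule through x = β·z requires in this sketch.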






[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #18852: [WIP] Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#issuecomment-668125220


   Hey @xidulu , Thanks for submitting the PR 
   All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands: 
   - To trigger all jobs: @mxnet-bot run ci [all] 
   - To trigger specific jobs: @mxnet-bot run ci [job1, job2] 
   *** 
   **CI supported jobs**: [website, unix-cpu, edge, windows-cpu, centos-cpu, clang, windows-gpu, miscellaneous, sanity, unix-gpu, centos-gpu]
   *** 
   _Note_: 
    Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin. 
   All CI tests must pass before the PR can be merged. 
   




[GitHub] [incubator-mxnet] szhengac commented on a change in pull request #18852: Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
szhengac commented on a change in pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#discussion_r468195504



##########
File path: src/operator/numpy/random/np_gamma_op.h
##########
@@ -394,6 +401,83 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   }
 }
 
+template<typename xpu, int ndim, typename DType>
+inline void GammaReparamBackwardImpl(const OpContext& ctx,
+                                            const std::vector<TBlob>& inputs,
+                                            const std::vector<OpReqType>& req,
+                                            const std::vector<TBlob>& outputs,
+                                            const mxnet::TShape& new_ishape,
+                                            const mxnet::TShape& new_oshape,
+                                            const float scale) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace broadcast;
+  using namespace mxnet_op;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob igrad = outputs[0].reshape(new_ishape);
+  // inputs: [grad_from_samples, alpha_tensor, samples]
+  const TBlob ograd = inputs[0].reshape(new_oshape);
+  const TBlob alpha = inputs[1].reshape(new_ishape);
+  TBlob samples = inputs[2].reshape(new_oshape);
+  size_t workspace_size =
+      ReduceWorkspaceSize<ndim, DType>(s, igrad.shape_, req[0], ograd.shape_);
+  // Convert samples to standard gamma
+  // Kernel<StandarizeKernel<DType>, xpu>::Launch(
+  //       s, samples.Size(), samples.dptr<DType>(), scale);
+  Kernel<op_with_req<mshadow_op::div, kWriteTo>, xpu>::Launch(
+    s, samples.Size(), samples.dptr<DType>(), samples.dptr<DType>(), DType(scale));
+  Tensor<xpu, 1, char> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::gamma_implicit_grad>(

Review comment:
       Why do we need op::mshadow_op::mul here?






[GitHub] [incubator-mxnet] xidulu commented on a change in pull request #18852: Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
xidulu commented on a change in pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#discussion_r469003544



##########
File path: src/operator/operator_tune.cc
##########
@@ -417,6 +417,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_and);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or);  // NOLINT()
+// IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma_implicit_grad);  // NOLINT()

Review comment:
       Operator tuning is turned off for this op; otherwise, a floating-point exception would occur when importing MXNet. This line has now been removed.






[GitHub] [incubator-mxnet] szha commented on pull request #18852: Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
szha commented on pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#issuecomment-671757830


   @leandrolcampos you might be interested.




[GitHub] [incubator-mxnet] xidulu commented on pull request #18852: Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
xidulu commented on pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#issuecomment-672412389


   @szha 
   Any more comments?




[GitHub] [incubator-mxnet] szha merged pull request #18852: Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
szha merged pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852


   




[GitHub] [incubator-mxnet] szha commented on a change in pull request #18852: Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
szha commented on a change in pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#discussion_r468945939



##########
File path: tests/python/unittest/test_numpy_op.py
##########
@@ -4777,6 +4777,47 @@ def _test_gamma_exception(shape, scale):
         assertRaises(ValueError, _test_gamma_exception, shape, scale)
 
 
+@with_seed()
+@use_np
+def test_gamma_grad():
+    class TestRandomGamma(HybridBlock):
+        def __init__(self, size, beta):
+            super(TestRandomGamma, self).__init__()
+            self._size = size
+            self._beta = beta
+
+        def hybrid_forward(self, F, a):
+            return F.np.random.gamma(a, self._beta, size=self._size)
+
+    shapes = [(1,), (2, 2), (4, 2, 2)]
+    alpha = [2.0, 5.0, 10.0]
+    beta = [0.5, 1.0, 1.5]
+    for (shape, a, b) in itertools.product(shapes, alpha, beta):
+        for hybridize in [True, False]:

Review comment:
       Use pytest.mark.parametrize for simplicity and parallelism.
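The suggestion could look roughly like the following, assuming pytest is installed. Stacked `parametrize` decorators generate the same Cartesian product that the nested `itertools.product` / `hybridize` loops enumerate (3 shapes × 3 alphas × 3 betas × 2 flags = 54 cases), and each combination becomes a separately reported test that a runner such as pytest-xdist can distribute. The body is a hypothetical stand-in that only checks the sample mean with Python's `random.gammavariate`; it does not reproduce the PR's MXNet gradient test.

```python
import math
import random

import pytest


@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('beta', [0.5, 1.0, 1.5])
@pytest.mark.parametrize('alpha', [2.0, 5.0, 10.0])
@pytest.mark.parametrize('shape', [(1,), (2, 2), (4, 2, 2)])
def test_gamma_grad(shape, alpha, beta, hybridize):
    # Hypothetical stand-in body: the real test would construct TestRandomGamma,
    # call hybridize() when `hybridize` is True, and check the gradient of `a`.
    # Here we only check the sample mean, E[x] = alpha * beta (beta is a scale).
    rng = random.Random(0)
    n = math.prod(shape) * 2000
    samples = [rng.gammavariate(alpha, beta) for _ in range(n)]
    mean = sum(samples) / n
    assert abs(mean - alpha * beta) < 0.1 * alpha * beta
```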






[GitHub] [incubator-mxnet] szhengac commented on a change in pull request #18852: [WIP] Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
szhengac commented on a change in pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#discussion_r464699654



##########
File path: tests/python/unittest/test_numpy_op.py
##########
@@ -4777,6 +4777,57 @@ def _test_gamma_exception(shape, scale):
         assertRaises(ValueError, _test_gamma_exception, shape, scale)
 
 
+@with_seed()
+@use_np
+def test_gamma_grad():
+    class TestRandomGamma(HybridBlock):
+        def __init__(self, size, beta):
+            super(TestRandomGamma, self).__init__()
+            self._size = size
+            self._beta = beta
+
+        def hybrid_forward(self, F, a):
+            return F.np.random.gamma(a, 1.0, self._size) * self._beta

Review comment:
       See section 5.1.






[GitHub] [incubator-mxnet] szha commented on pull request #18852: Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
szha commented on pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#issuecomment-672686272


   @xidulu nice work!




[GitHub] [incubator-mxnet] szha commented on pull request #18852: [WIP] Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
szha commented on pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#issuecomment-668199540


   cc @szhengac @sxjscience 




[GitHub] [incubator-mxnet] xidulu commented on a change in pull request #18852: Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
xidulu commented on a change in pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#discussion_r468292891



##########
File path: src/operator/numpy/random/np_gamma_op.h
##########
@@ -394,6 +401,83 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   }
 }
 
+template<typename xpu, int ndim, typename DType>
+inline void GammaReparamBackwardImpl(const OpContext& ctx,
+                                            const std::vector<TBlob>& inputs,
+                                            const std::vector<OpReqType>& req,
+                                            const std::vector<TBlob>& outputs,
+                                            const mxnet::TShape& new_ishape,
+                                            const mxnet::TShape& new_oshape,
+                                            const float scale) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace broadcast;
+  using namespace mxnet_op;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob igrad = outputs[0].reshape(new_ishape);
+  // inputs: [grad_from_samples, alpha_tensor, samples]
+  const TBlob ograd = inputs[0].reshape(new_oshape);
+  const TBlob alpha = inputs[1].reshape(new_ishape);
+  TBlob samples = inputs[2].reshape(new_oshape);
+  size_t workspace_size =
+      ReduceWorkspaceSize<ndim, DType>(s, igrad.shape_, req[0], ograd.shape_);
+  // Convert samples to standard gamma
+  // Kernel<StandarizeKernel<DType>, xpu>::Launch(
+  //       s, samples.Size(), samples.dptr<DType>(), scale);
+  Kernel<op_with_req<mshadow_op::div, kWriteTo>, xpu>::Launch(
+    s, samples.Size(), samples.dptr<DType>(), samples.dptr<DType>(), DType(scale));
+  Tensor<xpu, 1, char> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::gamma_implicit_grad>(

Review comment:
       This stands for the multiplication between d(Gamma(x;\alpha, \beta)) and the gradient from downstream ops, similar to https://github.com/apache/incubator-mxnet/blob/743bbcbc7c8c85661a146d94ebd3196306650677/src/operator/tensor/elemwise_binary_broadcast_op.h#L745
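To illustrate the "multiply, then sum" pattern described above (without claiming this is how MXNet's `Reduce` template is implemented): when alpha is broadcast across the sample shape, each sample's derivative with respect to alpha is multiplied elementwise by the incoming gradient, and the products are summed over the broadcast axes. The plain-Python sketch below assumes, for brevity, alpha of shape (k,) and samples of shape (n, k); `reduce_mul_sum` is a hypothetical name.

```python
def reduce_mul_sum(ograd, dx_dalpha):
    """Gradient w.r.t. a broadcast alpha of shape (k,) from samples of shape (n, k).

    Each sample's derivative is multiplied by the incoming gradient (the "mul"),
    and the products are summed over the broadcast sample axis (the "sum").
    """
    k = len(ograd[0])
    igrad = [0.0] * k
    for g_row, d_row in zip(ograd, dx_dalpha):
        for j in range(k):
            igrad[j] += g_row[j] * d_row[j]
    return igrad


ograd = [[1.0, 2.0], [3.0, 4.0]]       # gradient arriving from downstream ops
dx_dalpha = [[0.5, 1.0], [2.0, 0.25]]  # per-sample d(sample)/d(alpha)
print(reduce_mul_sum(ograd, dx_dalpha))  # [6.5, 3.0]
```

Without the multiplication, the reduction would sum only the per-sample derivatives and ignore how the downstream loss depends on each sample.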






[GitHub] [incubator-mxnet] szha commented on a change in pull request #18852: Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
szha commented on a change in pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#discussion_r468945803



##########
File path: src/operator/operator_tune.cc
##########
@@ -417,6 +417,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_and);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or);  // NOLINT()
+// IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma_implicit_grad);  // NOLINT()

Review comment:
       Why is this line commented out?






[GitHub] [incubator-mxnet] szhengac commented on a change in pull request #18852: [WIP] Gamma reparameterization gradient

Posted by GitBox <gi...@apache.org>.
szhengac commented on a change in pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#discussion_r464699493



##########
File path: tests/python/unittest/test_numpy_op.py
##########
@@ -4777,6 +4777,57 @@ def _test_gamma_exception(shape, scale):
         assertRaises(ValueError, _test_gamma_exception, shape, scale)
 
 
+@with_seed()
+@use_np
+def test_gamma_grad():
+    class TestRandomGamma(HybridBlock):
+        def __init__(self, size, beta):
+            super(TestRandomGamma, self).__init__()
+            self._size = size
+            self._beta = beta
+
+        def hybrid_forward(self, F, a):
+            return F.np.random.gamma(a, 1.0, self._size) * self._beta

Review comment:
       According to https://arxiv.org/pdf/1806.01851.pdf, shouldn't this be a division?
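Whether to multiply or divide depends on whether β is a scale or a rate. `numpy.random.gamma(shape, scale)` uses the scale form, where x = β·z has mean αβ; under a rate parameterization, x = z/β has mean α/β. The correspondingly reparameterized gradients are β·dz/dα and (dz/dα)/β. A small standard-library sketch (Python's `random.gammavariate(alpha, beta)` also takes a scale) makes the difference concrete; the helper names are hypothetical.

```python
import random


def gamma_scale_param(rng, alpha, beta):
    """Shape/scale form (as in numpy.random.gamma): x = beta * z, E[x] = alpha * beta."""
    return beta * rng.gammavariate(alpha, 1.0)


def gamma_rate_param(rng, alpha, beta):
    """Shape/rate form: x = z / beta, E[x] = alpha / beta."""
    return rng.gammavariate(alpha, 1.0) / beta


rng = random.Random(42)
alpha, beta, n = 3.0, 2.0, 100_000
mean_scale = sum(gamma_scale_param(rng, alpha, beta) for _ in range(n)) / n
mean_rate = sum(gamma_rate_param(rng, alpha, beta) for _ in range(n)) / n
print(mean_scale, mean_rate)  # close to 6.0 and 1.5 respectively
```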



