Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/10/17 22:20:37 UTC
[GitHub] [incubator-mxnet] anirudhacharya commented on issue #16303: LAMB optimizer
URL: https://github.com/apache/incubator-mxnet/pull/16303#issuecomment-543385490
Sorry for the delay; I was away.
These are the changes I will make to the existing PR:
Create a multi-tensor kernel that performs the following part of the optimizer:
```python
grad *= self.rescale_grad
if self.clip_gradient is not None:
    grad = clip(grad, -self.clip_gradient, self.clip_gradient)
mean, var = state
mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
var[:] = self.beta2 * var + (1. - self.beta2) * mx.nd.square(grad)
if not self.bias_correction:
    g = mean / (mx.nd.sqrt(var) + self.epsilon) + wd * weight
else:
    mean_hat = mean / (1. - mx.nd.power(self.beta1, t))
    var_hat = var / (1. - mx.nd.power(self.beta2, t))
    g = mean_hat / mx.nd.sqrt(var_hat + self.epsilon) + wd * weight
```
Then use the existing `multi_sum_sq` operator to calculate the norms:
```python
r1 = weight.norm()
r2 = g.norm()
```
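As a rough illustration of how the norms relate to a sum-of-squares kernel, here is a NumPy stand-in (not the actual MXNet `multi_sum_sq` implementation, which is a fused GPU kernel over many tensors at once); the assumption is that the operator returns the per-tensor sum of squares, from which the layer-wise norms follow by a square root:

```python
import numpy as np

def multi_sum_sq(arrays):
    # Stand-in for a fused multi_sum_sq kernel: return the sum of
    # squares of each tensor. The real kernel computes this for many
    # tensors in a single launch instead of one Python loop.
    return np.array([np.sum(np.square(a)) for a in arrays])

# Layer-wise norms (r1 for the weights, r2 for the update direction g)
# are then just square roots of the per-tensor sums of squares.
weights = [np.ones((2, 2)), np.full(3, 2.0)]
r1 = np.sqrt(multi_sum_sq(weights))
```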
Create another multi-tensor kernel to perform the clipping operation on the r1 tensor and to calculate the trust ratio:
```python
if not self.bias_correction:
    r1 = mx.nd.minimum(mx.nd.maximum(r1, self.lower_bound), self.upper_bound)
# calculate the LAMB trust ratio
r = 1. if r1 == 0. or r2 == 0. else r1 / r2
lr *= r
```
And finally, call the existing `multi_sgd_update` rule to update the weight matrices:
```python
# update weight
weight[:] -= lr * g
```
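Putting the four stages together, here is a hedged end-to-end sketch of one LAMB step in plain NumPy. It substitutes NumPy ops for the `mx.nd` calls and fused kernels above; the function name and parameter defaults are illustrative assumptions, not the PR's actual signatures:

```python
import numpy as np

def lamb_update(weight, grad, mean, var, t, lr, wd=0.0, beta1=0.9,
                beta2=0.999, epsilon=1e-6, rescale_grad=1.0,
                clip_gradient=None, bias_correction=True,
                lower_bound=1e-3, upper_bound=10.0):
    """Single LAMB step mirroring the four stages above (NumPy sketch)."""
    # Stage 1: rescale/clip the gradient and update the moments.
    grad = grad * rescale_grad
    if clip_gradient is not None:
        grad = np.clip(grad, -clip_gradient, clip_gradient)
    mean[:] = beta1 * mean + (1. - beta1) * grad
    var[:] = beta2 * var + (1. - beta2) * np.square(grad)
    if not bias_correction:
        g = mean / (np.sqrt(var) + epsilon) + wd * weight
    else:
        mean_hat = mean / (1. - beta1 ** t)
        var_hat = var / (1. - beta2 ** t)
        g = mean_hat / np.sqrt(var_hat + epsilon) + wd * weight
    # Stage 2: layer-wise norms (multi_sum_sq + sqrt in the real kernels).
    r1 = np.linalg.norm(weight)
    r2 = np.linalg.norm(g)
    # Stage 3: clip r1 and form the trust ratio.
    if not bias_correction:
        r1 = min(max(r1, lower_bound), upper_bound)
    r = 1. if r1 == 0. or r2 == 0. else r1 / r2
    # Stage 4: SGD-style weight update scaled by the trust ratio.
    weight[:] -= lr * r * g
    return weight
```

One operator call per stage maps naturally onto the multi-tensor kernels proposed above, since each stage touches every layer's tensors independently.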