You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/06 04:21:57 UTC

[GitHub] tlby commented on a change in pull request #9988: [Perl] Sparse feature.

tlby commented on a change in pull request #9988: [Perl] Sparse feature.
URL: https://github.com/apache/incubator-mxnet/pull/9988#discussion_r172401798
 
 

 ##########
 File path: perl-package/AI-MXNet/lib/AI/MXNet/Optimizer.pm
 ##########
 @@ -398,6 +423,461 @@ method update(
 
 __PACKAGE__->register;
 
+=head1 NAME
+
+    AI::MXNet::Signum - The Signum optimizer that takes the sign of gradient or momentum.
+
+=cut
+
+=head1 DESCRIPTION
+
+    The optimizer updates the weight by:
+
+        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
+        state = momentum * state + (1-momentum)*rescaled_grad
+        weight = (1 - lr * wd_lh) * weight - lr * sign(state)
+
+    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
+
+    For details of the update algorithm see
+    signsgd_update and signum_update in AI::MXNet::NDArray.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by AI::MXNet::Optimizer.
+
+    Parameters
+    ----------
+    momentum : float, optional
+       The momentum value.
+    wd_lh : float, optional
+       The amount of decoupled weight decay regularization, see details in the original paper at:
+       https://arxiv.org/abs/1711.05101
+=cut
+
+package AI::MXNet::Signum;
+use Mouse;
+extends 'AI::MXNet::Optimizer';
+
+# Momentum coefficient; 0 disables the momentum state entirely, in which
+# case update() falls back to the sign-SGD variant.
+has 'momentum' => (is => "rw", isa => "Num", default => 0.9);
+# Decoupled weight decay coefficient (see the DESCRIPTION POD above);
+# only forwarded to the update op when non-zero.
+has 'wd_lh'    => (is => "rw", isa => "Num", default => 0);
+
+# Creates the per-weight optimizer state: a zero-initialized momentum
+# buffer matching the weight's shape, context, dtype and storage type,
+# or undef when momentum == 0 (no state needed for sign-SGD).
+method create_state(Index $index, AI::MXNet::NDArray $weight)
+{
+
+    my $momentum;
+    if($self->momentum != 0)
+    {
+        $momentum = AI::MXNet::NDArray->zeros(
+            $weight->shape,
+            ctx => $weight->context,
+            dtype=>$weight->dtype,
+            stype=>$weight->stype
+        );
+    }
+    return $momentum;
+}
+
+# Performs one optimization step on $weight given $grad.
+# Dispatches to signum_update when a momentum state exists, otherwise to
+# signsgd_update. Optional kwargs (momentum, clip_gradient, wd_lh) are
+# only passed when they are in effect, so the backend ops see their own
+# defaults otherwise.
+method update(
+    Index                     $index,
+    AI::MXNet::NDArray        $weight,
+    AI::MXNet::NDArray        $grad,
+    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
+)
+{
+    $self->_update_count($index);
+    my $lr = $self->_get_lr($index);
+    my $wd = $self->_get_wd($index);
+    my %kwargs = (
+        out => $weight,
+        lr  => $lr,
+        wd  => $wd,
+        rescale_grad => $self->rescale_grad,
+    );
+    if($self->momentum > 0)
+    {
+        $kwargs{momentum} = $self->momentum;
+    }
+    if($self->clip_gradient)
+    {
+        $kwargs{clip_gradient} = $self->clip_gradient;
+    }
+    if($self->wd_lh)
+    {
+        $kwargs{wd_lh} = $self->wd_lh;
+    }
+    # A defined $state implies momentum was enabled in create_state().
+    if(defined $state)
+    {
+        AI::MXNet::NDArray->signum_update(
+            $weight, $grad, $state, %kwargs
+        );
+    }
+    else
+    {
+        AI::MXNet::NDArray->signsgd_update(
+            $weight, $grad, %kwargs
+        );
+    }
+}
+
+__PACKAGE__->register;
+
+=head1 NAME
+
+    AI::MXNet::FTML - The FTML optimizer.
+=cut
+
+=head1 DESCRIPTION
+
+    This class implements the optimizer described in
+    *FTML - Follow the Moving Leader in Deep Learning*,
+    available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.
+
+    This optimizer accepts the following parameters in addition to those accepted
+    by AI::MXNet::Optimizer
+
+    Parameters
+    ----------
+    beta1 : float, optional
+        0 < beta1 < 1. Generally close to 0.5.
+    beta2 : float, optional
+        0 < beta2 < 1. Generally close to 1.
+    epsilon : float, optional
+        Small value to avoid division by 0.
+=cut
+
+package AI::MXNet::FTML;
+use Mouse;
+extends 'AI::MXNet::Optimizer';
+
+# Exponential decay rates and stability term for the FTML update
+# (see the DESCRIPTION POD above; 0 < beta1, beta2 < 1).
+has 'beta1'   => (is => "rw", isa => "Num", default => 0.6);
+has 'beta2'   => (is => "rw", isa => "Num", default => 0.999);
+# Small constant added to avoid division by zero.
+has 'epsilon' => (is => "rw", isa => "Num", default => 1e-8);
+
+# Allocates the three zero-initialized FTML state buffers (d_0, v_0, z_0),
+# each matching the weight's shape, context and dtype, and returns them
+# as an array reference.
+method create_state(Index $index, AI::MXNet::NDArray $weight)
+{
+    my @state = map {
+        AI::MXNet::NDArray->zeros(
+            $weight->shape,
+            ctx   => $weight->context,
+            dtype => $weight->dtype
+        )
+    } qw(d_0 v_0 z_0);
+    return \@state;
+}
+
+method update(
+    Index                     $index,
+    AI::MXNet::NDArray        $weight,
+    AI::MXNet::NDArray        $grad,
+    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
+)
+{
+    my $lr = $self->_get_lr($index);
+    my $wd = $self->_get_wd($index);
+    my $t = $self->_update_count($index);
+    my %kwargs = (
+        out => $weight,
+        lr  => $lr,
+        wd  => $wd,
+        t   => $t,
+        beta1 => $self->beta1,
+        beta2 => $self->beta2,
+        epsilon => $self->epsilon,
+        rescale_grad => $self->rescale_grad
+    );
+use Data::Dumper;
 
 Review comment:
   debugging cruft?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services