Posted to commits@mxnet.apache.org by th...@apache.org on 2019/02/01 00:51:07 UTC

[incubator-mxnet] branch master updated: add NAG optimizer to r api (#14023)

This is an automated email from the ASF dual-hosted git repository.

the pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 439377d  add NAG optimizer to r api (#14023)
439377d is described below

commit 439377d18e57e54cd576ac462ba35d92e19f0fb7
Author: Anirudh <an...@gmail.com>
AuthorDate: Thu Jan 31 16:50:49 2019 -0800

    add NAG optimizer to r api (#14023)
---
 R-package/R/optimizer.R                   | 105 ++++++++++++++++++++++++++++++
 R-package/tests/testthat/test_optimizer.R |  88 +++++++++++++++++++++----
 2 files changed, 182 insertions(+), 11 deletions(-)

diff --git a/R-package/R/optimizer.R b/R-package/R/optimizer.R
index 9a858d5..6f13f7b 100644
--- a/R-package/R/optimizer.R
+++ b/R-package/R/optimizer.R
@@ -453,6 +453,110 @@ mx.opt.adadelta <- function(rho = 0.90,
 }
 
 
+#' Create a Nesterov Accelerated SGD (NAG) optimizer.
+#'
+#' The NAG optimizer is described in Aleksandar Botev et al. (2016):
+#' *NAG: A Nesterov accelerated SGD.*
+#' https://arxiv.org/pdf/1607.01981.pdf
+#'
+#' @param learning.rate float, default=0.01
+#'      The initial learning rate.
+#' @param momentum float, default=0
+#'      The momentum value.
+#' @param wd float, default=0.0
+#'      L2 regularization coefficient added to all the weights.
+#' @param rescale.grad float, default=1.0
+#'      Rescaling factor applied to the gradient.
+#' @param clip_gradient float, optional, default=-1 (no clipping if < 0)
+#'      Clip the gradient to the range [-clip_gradient, clip_gradient].
+#' @param lr_scheduler function, optional
+#'      The learning rate scheduler.
+#'
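+#' @examples
+#' \dontrun{
+#' # Usage sketch (illustrative only, not part of this commit). It assumes an
+#' # executor `exec` has already been bound for the model, as in the package tests.
+#' optimizer <- mx.opt.create("nag", learning.rate = 0.01, momentum = 0.9,
+#'                            wd = 1e-4, rescale.grad = 1, clip_gradient = 5)
+#' updater <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
+#' }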
+mx.opt.nag <- function(learning.rate = 0.01,
+                       momentum = 0,
+                       wd = 0,
+                       rescale.grad = 1,
+                       clip_gradient = -1,
+                       lr_scheduler = NULL) {
+
+  lr <- learning.rate
+  count <- 0
+  num_update <- 0
+
+  nag <- new.env()
+  nag$lr <- learning.rate
+  nag$count <- 0
+  nag$num_update <- 0
+
+  create_exec <- function(index, weight_dim, ctx) {
+
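+    # create_exec() builds and binds a small symbolic graph that, given the
+    # current weight, gradient, and (optionally) momentum state, computes the
+    # updated weight "w" (and updated momentum "m" when momentum > 0) for a
+    # single parameter of shape weight_dim on device ctx.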
+    weight <- mx.symbol.Variable("weight")
+    grad <- mx.symbol.Variable("grad")
+    mom <- mx.symbol.Variable("mom")
+    grad <- grad * rescale.grad
+
+    if (!is.null(clip_gradient)) {
+      if (clip_gradient >= 0) {
+        grad <- mx.symbol.clip(data = grad, a.min = -clip_gradient, a.max = clip_gradient)
+      }
+    }
+
+    if (momentum == 0) {
+
+      weight <- weight - lr * (grad + (wd * weight))
+      w <- mx.symbol.identity(weight, name = "w")
+      sym <- mx.symbol.Group(c(w))
+
+    } else {
+
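+      # Nesterov look-ahead: refresh the momentum buffer first, then step the
+      # weight along (momentum * mom + grad) rather than the raw gradient.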
+      mom <- momentum * mom + grad + wd * weight
+      grad <- momentum * mom + grad
+      weight <- weight - lr * grad
+
+      w <- mx.symbol.identity(weight, name = "w")
+      m <- mx.symbol.identity(mom, name = "m")
+      sym <- mx.symbol.Group(c(w, m))
+
+    }
+
+    exec <- mx.simple.bind(symbol = sym, weight = weight_dim, ctx = ctx, grad.req = "null")
+    return(exec)
+  }
+
+  update <- function(index, exec_w, weight, grad) {
+
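+    # update() feeds the current weight/gradient into the bound executor, runs
+    # a forward pass, carries the momentum output over as the next call's
+    # `mom` input, and returns the updated weight.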
+    if (!is.null(lr_scheduler)){
+      lr_scheduler(nag) ## changing lr
+      lr <- nag$lr
+      ## update count
+      indexKey <- paste0('ik', index)
+      if (!exists(envir = nag, x = indexKey, inherits = FALSE)){
+        nag[[indexKey]] <- 0
+      } else {
+        indexValue <- nag[[indexKey]]
+        nag[[indexKey]] <- indexValue + 1
+        nag$num_update <- max(nag$num_update, nag[[indexKey]])
+      }
+    }
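+    # The per-index counters above let a scheduler see how many updates each
+    # weight has received (via nag$num_update).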
+
+    mx.exec.update.arg.arrays(exec_w,
+                              arg.arrays = list(weight = weight, grad = grad),
+                              match.name = T)
+    mx.exec.forward(exec_w, is.train = F)
+
+    # update state
+    if (!is.null(exec_w$ref.outputs$m_output)){
+      mx.exec.update.arg.arrays(exec_w,
+                                arg.arrays = list(mom = exec_w$ref.outputs$m_output),
+                                match.name = T) 
+    }
+
+    return(exec_w$ref.outputs$w_output)
+  }
+  return(list(create_exec = create_exec, update = update))
+}
+
+
 #' Create an optimizer by name and parameters
 #'
 #' @param name The name of the optimizer
@@ -466,6 +570,7 @@ mx.opt.create <- function(name, ...) {
          "adam" = mx.opt.adam(...),
          "adagrad" = mx.opt.adagrad(...),
          "adadelta" = mx.opt.adadelta(...),
+         "nag" = mx.opt.nag(...),
          stop("Unknown optimizer ", name))
 }
 
diff --git a/R-package/tests/testthat/test_optimizer.R b/R-package/tests/testthat/test_optimizer.R
index 1ae7bc2..1eec83f 100644
--- a/R-package/tests/testthat/test_optimizer.R
+++ b/R-package/tests/testthat/test_optimizer.R
@@ -17,6 +17,12 @@
 
 context("optimizer")
 
+if (Sys.getenv("R_GPU_ENABLE") != "" &&
+    as.integer(Sys.getenv("R_GPU_ENABLE")) == 1) {
+  mx.ctx.default(new = mx.gpu())
+  message("Using GPU for testing.")
+}
+
 test_that("sgd", {
   
   data <- mx.symbol.Variable("data")
@@ -30,14 +36,14 @@ test_that("sgd", {
   y <- mx.nd.array(c(5, 11, 16))
   w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))
   
-  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x, 
+  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x, 
     fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", 
     "null"))
   
   optimizer <- mx.opt.create("sgd", learning.rate = 1, momentum = 0, wd = 0, rescale.grad = 1, 
     clip_gradient = -1)
   
-  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())
   
   mx.exec.forward(exec, is.train = T)
   mx.exec.backward(exec)
@@ -63,14 +69,14 @@ test_that("rmsprop", {
   y <- mx.nd.array(c(5, 11, 16))
   w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))
   
-  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x, 
+  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x, 
     fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", 
     "null"))
   
   optimizer <- mx.opt.create("rmsprop", learning.rate = 1, centered = TRUE, gamma1 = 0.95, 
     gamma2 = 0.9, epsilon = 1e-04, wd = 0, rescale.grad = 1, clip_gradient = -1)
   
-  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())
   
   mx.exec.forward(exec, is.train = T)
   mx.exec.backward(exec)
@@ -97,14 +103,14 @@ test_that("adam", {
   y <- mx.nd.array(c(5, 11, 16))
   w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))
   
-  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x, 
+  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x, 
     fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", 
     "null"))
   
   optimizer <- mx.opt.create("adam", learning.rate = 1, beta1 = 0.9, beta2 = 0.999, 
     epsilon = 1e-08, wd = 0, rescale.grad = 1, clip_gradient = -1)
   
-  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())
   
   mx.exec.forward(exec, is.train = T)
   mx.exec.backward(exec)
@@ -131,14 +137,14 @@ test_that("adagrad", {
   y <- mx.nd.array(c(5, 11, 16))
   w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))
   
-  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x, 
+  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x, 
     fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", 
     "null"))
   
   optimizer <- mx.opt.create("adagrad", learning.rate = 1, epsilon = 1e-08, wd = 0, 
     rescale.grad = 1, clip_gradient = -1)
   
-  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())
   
   mx.exec.forward(exec, is.train = T)
   mx.exec.backward(exec)
@@ -164,22 +170,82 @@ test_that("adadelta", {
   y <- mx.nd.array(c(5, 11, 16))
   w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))
   
-  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(), arg.arrays = list(data = x, 
+  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x, 
     fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", 
     "null"))
   
   optimizer <- mx.opt.create("adadelta", rho = 0.9, epsilon = 1e-05, wd = 0, rescale.grad = 1, 
     clip_gradient = -1)
   
-  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())
   
   mx.exec.forward(exec, is.train = T)
   mx.exec.backward(exec)
   
   arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
   mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
-  
+
   expect_equal(as.array(arg.blocks[[2]]), array(c(1.11, 1.81), dim = c(2, 1)), 
     tolerance = 0.1)
   
 })
+
+
+test_that("nag_no_momentum", {
+  data <- mx.symbol.Variable("data")
+  label <- mx.symbol.Variable("label")
+  fc_weight <- mx.symbol.Variable("fc_weight")
+  fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
+                                 name = "fc1", num_hidden = 1)
+  loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = "loss")
+
+  x <- mx.nd.array(array(1:6, dim = 2:3))
+  y <- mx.nd.array(c(5, 11, 16))
+  w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))
+
+  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
+    fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", "null"))
+
+  optimizer <- mx.opt.create("nag", learning.rate = 1, momentum = 0, wd = 0, rescale.grad = 1,
+    clip_gradient = -1)
+
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())
+
+  mx.exec.forward(exec, is.train = T)
+  mx.exec.backward(exec)
+
+  arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
+
+  expect_equal(as.array(arg.blocks[[2]]), array(c(1.4, 2.6), dim = c(2, 1)), tolerance = 0.05)
+})
+
+
+test_that("nag_momentum", {
+  data <- mx.symbol.Variable("data")
+  label <- mx.symbol.Variable("label")
+  fc_weight <- mx.symbol.Variable("fc_weight")
+  fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = T,
+                                 name = "fc1", num_hidden = 1)
+  loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = "loss")
+  
+  x <- mx.nd.array(array(1:6, dim = 2:3))
+  y <- mx.nd.array(c(5, 11, 16))
+  w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))
+  
+  exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.ctx.default(), arg.arrays = list(data = x,
+    fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write", "null"))
+  
+  optimizer <- mx.opt.create("nag", learning.rate = 1, momentum = 0.1, wd = 0, rescale.grad = 1,
+                             clip_gradient = 5)
+  
+  updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())
+  
+  mx.exec.forward(exec, is.train = T)
+  mx.exec.backward(exec)
+  
+  arg.blocks <- updaters(exec$ref.arg.arrays, exec$ref.grad.arrays)
+  mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
+  
+  expect_equal(as.array(arg.blocks[[2]]), array(c(1.45, 2.65), dim = c(2, 1)), tolerance = 0.1)
+})
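
For reference, a minimal end-to-end sketch of using the new optimizer outside the test
harness. It is illustrative only and not part of this commit; it simply mirrors the
pattern used in the tests above, including the internal mxnet:::mx.symbol.bind call.

library(mxnet)

# Tiny linear-regression model: one FullyConnected layer, no bias.
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc_weight <- mx.symbol.Variable("fc_weight")
fc <- mx.symbol.FullyConnected(data = data, weight = fc_weight, no.bias = TRUE,
                               name = "fc1", num_hidden = 1)
loss <- mx.symbol.LinearRegressionOutput(data = fc, label = label, name = "loss")

# Three training samples with two features each, plus an initial weight.
x <- mx.nd.array(array(1:6, dim = 2:3))
y <- mx.nd.array(c(5, 11, 16))
w1 <- mx.nd.array(array(c(1.1, 1.8), dim = c(2, 1)))

exec <- mxnet:::mx.symbol.bind(symbol = loss, ctx = mx.cpu(),
  arg.arrays = list(data = x, fc1_weight = w1, label = y),
  aux.arrays = NULL, grad.reqs = c("null", "write", "null"))

optimizer <- mx.opt.create("nag", learning.rate = 0.1, momentum = 0.9, wd = 0,
  rescale.grad = 1, clip_gradient = 5)
updater <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.cpu())

# A few training steps: forward/backward, then apply the NAG update in place.
for (i in 1:5) {
  mx.exec.forward(exec, is.train = TRUE)
  mx.exec.backward(exec)
  arg.blocks <- updater(exec$ref.arg.arrays, exec$ref.grad.arrays)
  mx.exec.update.arg.arrays(exec, arg.blocks, skip.null = TRUE)
}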