You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/31 05:43:36 UTC

[GitHub] hetong007 closed pull request #9803: R Metrics

hetong007 closed pull request #9803: R Metrics
URL: https://github.com/apache/incubator-mxnet/pull/9803
 
 
   

This pull request was merged from a forked repository.
Because GitHub hides the original diff once such a pull request is merged,
the diff is reproduced below for the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it:

diff --git a/R-package/R/metric.R b/R-package/R/metric.R
index f8d9c33a726..9e1913a0877 100644
--- a/R-package/R/metric.R
+++ b/R-package/R/metric.R
@@ -6,7 +6,7 @@ mx.metric.custom <- function(name, feval) {
     c(0, 0)
   }
   update <- function(label, pred, state) {
-    m <- feval(as.array(label), as.array(pred))
+    m <- feval(label, pred)
     state <- c(state[[1]] + 1, state[[2]] + m)
     return(state)
   }
@@ -22,69 +22,90 @@ mx.metric.custom <- function(name, feval) {
 #'
 #' @export
 mx.metric.accuracy <- mx.metric.custom("accuracy", function(label, pred) {
-  ypred = max.col(t(as.array(pred)), tie="first")
-  return(sum((as.array(label) + 1) == ypred) / length(label))
+  pred <- mx.nd.argmax(data = pred, axis = 1, keepdims = F)
+  res <- mx.nd.mean(label == pred)
+  return(as.array(res))
 })
 
-#' Helper function for top-k accuracy
-is.num.in.vect <- function(vect, num){
-  resp <- any(is.element(vect, num))
-  return(resp)
-}
-
 #' Top-k accuracy metric for classification
 #'
 #' @export
 mx.metric.top_k_accuracy <- mx.metric.custom("top_k_accuracy", function(label, pred, top_k = 5) {
-  if(top_k == 1){
-    return(mx.metric.accuracy(label,pred))
-  } else{
-    ypred <- apply(pred,2,function(x) order(x, decreasing=TRUE)[seq_len(top_k)])
-    ans <- apply(ypred, 2, is.num.in.vect, num = as.array(label + 1))
-    acc <- sum(ans)/length(label)  
-    return(acc)
-  }
+  label <- mx.nd.reshape(data = label, shape = c(1,0))
+  pred <- mx.nd.topk(data = pred, axis = 1, k = top_k, ret_typ = "indices")
+  pred <- mx.nd.broadcast.equal(lhs = pred, rhs = label)
+  res <- mx.nd.mean(mx.nd.sum(data = pred, axis = 1, keepdims = F))
+  return(as.array(res))
 })
 
 #' MSE (Mean Squared Error) metric for regression
 #'
 #' @export
 mx.metric.mse <- mx.metric.custom("mse", function(label, pred) {
-  res <- mean((label-pred)^2)
-  return(res)
+  pred <- mx.nd.reshape(pred, shape = 0)
+  res <- mx.nd.mean(mx.nd.square(label-pred))
+  return(as.array(res))
 })
-    
+
 #' RMSE (Root Mean Squared Error) metric for regression
 #'
 #' @export
 mx.metric.rmse <- mx.metric.custom("rmse", function(label, pred) {
-  res <- sqrt(mean((label-pred)^2))
-  return(res)
+  pred <- mx.nd.reshape(pred, shape = 0)
+  res <- mx.nd.sqrt(mx.nd.mean(mx.nd.square(label-pred)))
+  return(as.array(res))
 })
 
 #' MAE (Mean Absolute Error) metric for regression
 #'
 #' @export
 mx.metric.mae <- mx.metric.custom("mae", function(label, pred) {
-  res <- mean(abs(label-pred))
-  return(res)
+  pred <- mx.nd.reshape(pred, shape = 0)
+  res <- mx.nd.mean(mx.nd.abs(label-pred))
+  return(as.array(res))
 })
 
 #' RMSLE (Root Mean Squared Logarithmic Error) metric for regression
 #'
 #' @export
 mx.metric.rmsle <- mx.metric.custom("rmsle", function(label, pred) {
-  res <- sqrt(mean((log(pred + 1) - log(label + 1))^2))
-  return(res)
+  pred <- mx.nd.reshape(pred, shape = 0)
+  res <- mx.nd.sqrt(mx.nd.mean(mx.nd.square(mx.nd.log1p(pred) - mx.nd.log1p(label))))
+  return(as.array(res))
 })
 
 #' Perplexity metric for language model
 #'
 #' @export
-mx.metric.Perplexity <- mx.metric.custom("Perplexity", function(label, pred) {
-  label_probs <- as.array(mx.nd.choose.element.0index(pred, label))
-  batch <- length(label_probs)
-  NLL <- -sum(log(pmax(1e-15, as.array(label_probs)))) / batch
-  Perplexity <- exp(NLL)
-  return(Perplexity)
+mx.metric.Perplexity <- mx.metric.custom("Perplexity", function(label, pred, mask_element = -1) {
+  
+  label <- mx.nd.reshape(label, shape = -1)
+  pred_probs <- mx.nd.pick(data = pred, index = label, axis = 1)
+  
+  mask <- label != mask_element
+  mask_length <- mx.nd.sum(mask)
+  
+  NLL <- -mx.nd.sum(mx.nd.log(pred_probs) * mask) / mask_length
+  res <- mx.nd.exp(NLL)
+  return(as.array(res))
 })
+
+#' LogLoss metric for logistic regression
+#'
+#' @export
+mx.metric.logloss <- mx.metric.custom("logloss", function(label, pred) {
+  pred <- mx.nd.reshape(pred, shape = 0)
+  pred <- mx.nd.clip(pred, a_min = 1e-15, a_max = 1-1e-15)
+  res <- -mx.nd.mean(label * mx.nd.log(pred) + (1-label) * mx.nd.log(1-pred))
+  return(as.array(res))
+})
+
+#' Accuracy metric for logistic regression
+#'
+#' @export
+mx.metric.logistic_acc <- mx.metric.custom("accuracy", function(label, pred) {
+  pred <- mx.nd.reshape(pred, shape = 0) > 0.5
+  res <- mx.nd.mean(label == pred)
+  return(as.array(res))
+})
+
diff --git a/R-package/R/model.R b/R-package/R/model.R
index 01b5ed72835..b461f7973f6 100644
--- a/R-package/R/model.R
+++ b/R-package/R/model.R
@@ -112,13 +112,14 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
                            begin.round, end.round, optimizer,
                            train.data, eval.data, metric,
                            epoch.end.callback, batch.end.callback,
-                           kvstore, fixed.param = NULL, verbose = TRUE) {
+                           kvstore, fixed.param, verbose,
+                           metric_cpu) {
   ndevice <- length(ctx)
   if(verbose) message("Start training with ", ndevice, " devices")
   # create the executors
   input_slice <- mx.model.slice.shape(input.shape, ndevice)
   output_slice <- mx.model.slice.shape(output.shape, ndevice)
-
+  
   arg_names <- arguments(symbol)
   output.names <- names(output.shape)
   #label_name <- arg_names[endsWith(arg_names, "label")]
@@ -126,7 +127,7 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
     arg_lst <- list(symbol = symbol, ctx = ctx[[i]], grad.req = "write")
     arg_lst <- append(arg_lst, input_slice[[i]]$shape)
     arg_lst <- append(arg_lst, output_slice[[i]]$shape)
-    arg_lst[["fixed.param"]] = fixed.param
+    arg_lst[["fixed.param"]] = unique(c(fixed.param, names(input.shape), names(output.shape)))
     do.call(mx.simple.bind, arg_lst)
   })
   # set the parameters into executors
@@ -153,8 +154,10 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
     kvstore$init(params.index, train.execs[[1]]$ref.arg.arrays[params.index])
   }
   # Get the input names
-
+  
   for (iteration in begin.round:end.round) {
+    # reset training data
+    train.data$reset()
     nbatch <- 0
     if (!is.null(metric)) {
       train.metric <- metric$init()
@@ -175,13 +178,22 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
         }
         mx.exec.update.arg.arrays(train.execs[[i]], s, match.name=TRUE)
       }
+      
+      # forward pass
       for (texec in train.execs) {
         mx.exec.forward(texec, is.train=TRUE)
       }
-      # copy outputs to CPU
-      out.preds <- lapply(train.execs, function(texec) {
-        mx.nd.copyto(texec$ref.outputs[[1]], mx.cpu())
-      })
+      
+      # copy of preds and labels for metric
+      if (!is.null(metric)) {
+        preds <- lapply(train.execs, function(texec) {texec$ref.outputs[[1]]})
+        labels <- lapply(train.execs, function(texec) {texec$ref.arg.arrays[[output.names[length(output.names)]]]})
+        if (metric_cpu) {
+          preds <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(preds[[i]], mx.cpu())})
+          labels <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(labels[[i]], mx.cpu())})
+        }
+      }
+      
       # backward pass
       for (texec in train.execs) {
         mx.exec.backward(texec)
@@ -215,7 +227,9 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
       # Update the evaluation metrics
       if (!is.null(metric)) {
         for (i in seq_len(ndevice)) {
-          train.metric <- metric$update(slices[[i]][[length(slices[[i]])]], out.preds[[i]], train.metric)
+          train.metric <- metric$update(label = labels[[i]],
+                                        pred = preds[[i]],
+                                        state = train.metric)
         }
       }
       nbatch <- nbatch + 1
@@ -223,13 +237,14 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
         batch.end.callback(iteration, nbatch, environment())
       }
     }
-    # reset training data
-    train.data$reset()
+    
     if (!is.null(metric)) {
       result <- metric$get(train.metric)
       if(verbose) message("[", iteration, "] Train-", result$name, "=", result$value)
     }
     if (!is.null(eval.data)) {
+      # reset eval data
+      eval.data$reset()
       if (!is.null(metric)) {
         eval.metric <- metric$init()
       }
@@ -250,16 +265,22 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
         for (texec in train.execs) {
           mx.exec.forward(texec, is.train=FALSE)
         }
-        out.preds <- lapply(train.execs, function(texec) {
-          mx.nd.copyto(texec$ref.outputs[[1]], mx.cpu())
-        })
+        
+        # copy of preds and labels for metric and update metric
         if (!is.null(metric)) {
+          preds <- lapply(train.execs, function(texec) {texec$ref.outputs[[1]]})
+          labels <- lapply(train.execs, function(texec) {texec$ref.arg.arrays[[output.names[length(output.names)]]]})
+          if (metric_cpu) {
+            preds <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(preds[[i]], mx.cpu())})
+            labels <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(labels[[i]], mx.cpu())})
+          }
           for (i in seq_len(ndevice)) {
-            eval.metric <- metric$update(slices[[i]][[length(slices[[i]])]] , out.preds[[i]], eval.metric)
+            eval.metric <- metric$update(label = labels[[i]], 
+                                         pred = preds[[i]], 
+                                         state = eval.metric)
           }
         }
       }
-      eval.data$reset()
       if (!is.null(metric)) {
         result <- metric$get(eval.metric)
         if(verbose) message("[", iteration, "] Validation-", result$name, "=", result$value)
@@ -269,12 +290,12 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
     }
     # get the model out
     model <- mx.model.extract.model(symbol, train.execs)
-
+    
     epoch_continue <- TRUE
     if (!is.null(epoch.end.callback)) {
       epoch_continue <- epoch.end.callback(iteration, 0, environment(), verbose = verbose)
     }
-
+    
     if (!epoch_continue) {
       break
     }
@@ -291,11 +312,11 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
 #' @export
 mx.model.init.params <- function(symbol, input.shape, output.shape, initializer, ctx) {
   if (!is.MXSymbol(symbol)) stop("symbol needs to be MXSymbol")
-
+  
   arg_lst <- list(symbol = symbol)
   arg_lst <- append(arg_lst, input.shape)
   arg_lst <- append(arg_lst, output.shape)
-
+  
   slist <- do.call(mx.symbol.infer.shape, arg_lst)
   if (is.null(slist)) stop("Not enough information to get shapes")
   arg.params <- mx.init.create(initializer, slist$arg.shapes, ctx, skip.unknown=TRUE)
@@ -430,93 +451,95 @@ mx.model.select.layout.predict <- function(X, model) {
 #' @export
 
 mx.model.FeedForward.create <-
-function(symbol, X, y=NULL, ctx=NULL, begin.round=1,
-         num.round=10, optimizer="sgd",
-         initializer=mx.init.uniform(0.01),
-         eval.data=NULL, eval.metric=NULL,
-         epoch.end.callback=NULL, batch.end.callback=NULL,
-         array.batch.size=128, array.layout="auto",
-         kvstore = "local", verbose = TRUE,
-         arg.params = NULL, aux.params = NULL,
-         input.names=NULL, output.names = NULL,
-         fixed.param = NULL, allow.extra.params = FALSE,
-         ...) {
-  if (is.array(X) || is.matrix(X)) {
-    if (array.layout == "auto") {
-      array.layout <- mx.model.select.layout.train(X, y)
+  function(symbol, X, y=NULL, ctx=NULL, begin.round=1,
+           num.round=10, optimizer="sgd",
+           initializer=mx.init.uniform(0.01),
+           eval.data=NULL, eval.metric=NULL,
+           epoch.end.callback=NULL, batch.end.callback=NULL,
+           array.batch.size=128, array.layout="auto",
+           kvstore = "local", verbose = TRUE,
+           arg.params = NULL, aux.params = NULL,
+           input.names=NULL, output.names = NULL,
+           fixed.param = NULL, allow.extra.params = FALSE,
+           metric_cpu = TRUE,
+           ...) {
+    if (is.array(X) || is.matrix(X)) {
+      if (array.layout == "auto") {
+        array.layout <- mx.model.select.layout.train(X, y)
+      }
+      if (array.layout == "rowmajor") {
+        X <- t(X)
+      }
     }
-    if (array.layout == "rowmajor") {
-      X <- t(X)
+    X <- mx.model.init.iter(X, y, batch.size=array.batch.size, is.train=TRUE)
+    if (!X$iter.next()) {
+      X$reset()
+      if (!X$iter.next()) stop("Empty input")
     }
-  }
-  X <- mx.model.init.iter(X, y, batch.size=array.batch.size, is.train=TRUE)
-  if (!X$iter.next()) {
-    X$reset()
-    if (!X$iter.next()) stop("Empty input")
-  }
-  if (is.null(input.names)) {
-    input.names <- "data"
-  }
-  input.shape <- sapply(input.names, function(n){dim(X$value()[[n]])}, simplify = FALSE)
-  if (is.null(output.names)) {
-    arg_names <- arguments(symbol)
-    output.names <- arg_names[endsWith(arg_names, "label")]
-    output.shape <- list()
-    output.shape[[output.names]] <- dim((X$value())$label)
-  } else {
-    output.shape <- sapply(output.names, function(n){dim(X$value()[[n]])}, simplify = FALSE)  
-  }
-  params <- mx.model.init.params(symbol, input.shape, output.shape, initializer, mx.cpu())
-  if (!is.null(arg.params)) params$arg.params <- arg.params
-  if (!is.null(aux.params)) params$aux.params <- aux.params
-  if (allow.extra.params) {
-    params$arg.params[!names(params$arg.params) %in% arguments(symbol)] <- NULL
-  }
-  if (is.null(ctx)) ctx <- mx.ctx.default()
-  if (is.mx.context(ctx)) {
-    ctx <- list(ctx)
-  }
-  if (!is.list(ctx)) stop("ctx must be mx.context or list of mx.context")
-  if (is.character(optimizer)) {
-    if (is.numeric(input.shape)) {
-      ndim <- length(input.shape)
-      batchsize = input.shape[[ndim]]      
+    if (is.null(input.names)) {
+      input.names <- "data"
+    }
+    input.shape <- sapply(input.names, function(n){dim(X$value()[[n]])}, simplify = FALSE)
+    if (is.null(output.names)) {
+      arg_names <- arguments(symbol)
+      output.names <- arg_names[endsWith(arg_names, "label")]
+      output.shape <- list()
+      output.shape[[output.names]] <- dim((X$value())$label)
     } else {
-      ndim <- length(input.shape[[1]])
-      batchsize = input.shape[[1]][[ndim]]
+      output.shape <- sapply(output.names, function(n){dim(X$value()[[n]])}, simplify = FALSE)  
     }
-    optimizer <- mx.opt.create(optimizer, rescale.grad=(1/batchsize), ...)
-  }
-  if (!is.null(eval.data) && !is.list(eval.data) && !is.mx.dataiter(eval.data)) {
-    stop("The validation set should be either a mx.io.DataIter or a R list")
-  }
-  if (is.list(eval.data)) {
-    if (is.null(eval.data$data) || is.null(eval.data$label)){
-      stop("Please provide the validation set as list(data=R.array, label=R.array)")
+    params <- mx.model.init.params(symbol, input.shape, output.shape, initializer, mx.cpu())
+    if (!is.null(arg.params)) params$arg.params <- arg.params
+    if (!is.null(aux.params)) params$aux.params <- aux.params
+    if (allow.extra.params) {
+      params$arg.params[!names(params$arg.params) %in% arguments(symbol)] <- NULL
     }
-    if (is.array(eval.data$data) || is.matrix(eval.data$data)) {
-      if (array.layout == "auto") {
-        array.layout <- mx.model.select.layout.train(eval.data$data, eval.data$label)
+    if (is.null(ctx)) ctx <- mx.ctx.default()
+    if (is.mx.context(ctx)) {
+      ctx <- list(ctx)
+    }
+    if (!is.list(ctx)) stop("ctx must be mx.context or list of mx.context")
+    if (is.character(optimizer)) {
+      if (is.numeric(input.shape)) {
+        ndim <- length(input.shape)
+        batchsize = input.shape[[ndim]]      
+      } else {
+        ndim <- length(input.shape[[1]])
+        batchsize = input.shape[[1]][[ndim]]
       }
-      if (array.layout == "rowmajor") {
-        eval.data$data <- t(eval.data$data)
+      optimizer <- mx.opt.create(optimizer, rescale.grad=(1/batchsize), ...)
+    }
+    if (!is.null(eval.data) && !is.list(eval.data) && !is.mx.dataiter(eval.data)) {
+      stop("The validation set should be either a mx.io.DataIter or a R list")
+    }
+    if (is.list(eval.data)) {
+      if (is.null(eval.data$data) || is.null(eval.data$label)){
+        stop("Please provide the validation set as list(data=R.array, label=R.array)")
+      }
+      if (is.array(eval.data$data) || is.matrix(eval.data$data)) {
+        if (array.layout == "auto") {
+          array.layout <- mx.model.select.layout.train(eval.data$data, eval.data$label)
+        }
+        if (array.layout == "rowmajor") {
+          eval.data$data <- t(eval.data$data)
+        }
       }
+      eval.data <- mx.model.init.iter(eval.data$data, eval.data$label, batch.size=array.batch.size, is.train = TRUE)
     }
-    eval.data <- mx.model.init.iter(eval.data$data, eval.data$label, batch.size=array.batch.size, is.train = TRUE)
+    kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx), verbose=verbose)
+    model <- mx.model.train(symbol, ctx, input.shape, output.shape,
+                            params$arg.params, params$aux.params,
+                            begin.round, num.round, optimizer=optimizer,
+                            train.data=X, eval.data=eval.data,
+                            metric=eval.metric,
+                            epoch.end.callback=epoch.end.callback,
+                            batch.end.callback=batch.end.callback,
+                            kvstore=kvstore,
+                            fixed.param = fixed.param,
+                            verbose=verbose,
+                            metric_cpu = metric_cpu)
+    return (model)
   }
-  kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx), verbose=verbose)
-  model <- mx.model.train(symbol, ctx, input.shape, output.shape,
-                          params$arg.params, params$aux.params,
-                          begin.round, num.round, optimizer=optimizer,
-                          train.data=X, eval.data=eval.data,
-                          metric=eval.metric,
-                          epoch.end.callback=epoch.end.callback,
-                          batch.end.callback=batch.end.callback,
-                          kvstore=kvstore,
-                          fixed.param = fixed.param,
-                          verbose=verbose)
-  return (model)
-}
 
 #' Predict the outputs given a model and dataset.
 #'
@@ -552,7 +575,7 @@ predict.MXFeedForwardModel <- function(model, X, ctx = NULL, array.batch.size =
   if (!X$iter.next()) stop("Cannot predict on empty iterator")
   dlist = X$value()
   arg_lst <- list(symbol = model$symbol, ctx = ctx, data = dim(dlist$data), grad.req="null")
-
+  
   pexec <- do.call(mx.simple.bind, arg_lst)
   if (allow.extra.params) {
     model$arg.params[!names(model$arg.params) %in% arguments(model$symbol)] <- NULL
@@ -588,7 +611,7 @@ mx.model.load <- function(prefix, iteration) {
   
   arg.index <- startsWith(nms, "arg:")
   aux.index <- startsWith(nms, "aux:")
-
+  
   if (any(arg.index)) {
     arg.params <- save.dict[arg.index]
     names(arg.params) <- substr(nms[arg.index], 5, nchar(nms[arg.index]))
diff --git a/R-package/R/model.rnn.R b/R-package/R/model.rnn.R
index 3bf1d96dceb..f328d1ba6b7 100644
--- a/R-package/R/model.rnn.R
+++ b/R-package/R/model.rnn.R
@@ -2,8 +2,9 @@
 mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data, 
                                    dlist, arg.params, aux.params, 
                                    grad.req, arg.update.idx, 
-                                   begin.round, end.round, optimizer, metric, 
-                                   epoch.end.callback, batch.end.callback, kvstore, verbose = TRUE) {
+                                   begin.round, end.round, optimizer, metric, metric_cpu,
+                                   epoch.end.callback, batch.end.callback, kvstore, verbose,
+                                   gc_freq) {
   
   ndevice <- length(ctx)
   if (verbose) 
@@ -21,7 +22,7 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
   train.execs <- lapply(seq_len(ndevice), function(i) {
     s <- slices[[i]]
     mx.symbol.bind(symbol = sym_ini, arg.arrays = c(s, arg.params)[arg.update.idx], 
-                           aux.arrays = aux.params, ctx = ctx[[i]], grad.req = grad.req)
+                   aux.arrays = aux.params, ctx = ctx[[i]], grad.req = grad.req)
   })
   
   # KVStore related stuffs
@@ -48,6 +49,7 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
   # train over specified number of epochs
   for (iteration in begin.round:end.round) {
     nbatch <- 0
+    gc()
     if (!is.null(metric)) {
       train.metric <- metric$init()
     }
@@ -77,14 +79,22 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
         }
       }
       
+      # forward pass
       for (texec in train.execs) {
         mx.exec.forward(texec, is.train = TRUE)
       }
       
-      out.preds <- lapply(train.execs, function(texec) {
-        mx.nd.copyto(texec$ref.outputs[[1]], mx.cpu())
-      })
+      # copy of preds and labels for metric
+      if (!is.null(metric)) {
+        preds <- lapply(train.execs, function(texec) {texec$ref.outputs[[1]]})
+        labels <- lapply(train.execs, function(texec) {texec$ref.arg.arrays[[input.names[length(input.names)]]]})
+        if (metric_cpu) {
+          preds <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(preds[[i]], mx.cpu())})
+          labels <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(labels[[i]], mx.cpu())})
+        }
+      }
       
+      # backward pass
       for (texec in train.execs) {
         mx.exec.backward(texec)
       }
@@ -118,12 +128,16 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
       # Update the evaluation metrics
       if (!is.null(metric)) {
         for (i in seq_len(ndevice)) {
-          train.metric <- metric$update(label = slices[[i]][[length(slices[[i]])]], 
-                                        pred = out.preds[[i]], state = train.metric)
+          train.metric <- metric$update(label = labels[[i]],
+                                        pred = preds[[i]],
+                                        state = train.metric)
         }
       }
       
       nbatch <- nbatch + 1
+      if (!is.null(gc_freq)) {
+        if (nbatch %% gc_freq == 0) gc()
+      }
       
       if (!is.null(batch.end.callback)) {
         batch.end.callback(iteration, nbatch, environment())
@@ -156,8 +170,8 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
           train.execs <- lapply(seq_len(ndevice), function(i) {
             s <- slices[[i]]
             mx.symbol.bind(symbol = symbol[[names(eval.data$bucketID)]], 
-                                   arg.arrays = c(s, train.execs[[i]]$arg.arrays[arg.params.names])[arg.update.idx],
-                                   aux.arrays = train.execs[[i]]$aux.arrays, ctx = ctx[[i]], grad.req = grad.req)
+                           arg.arrays = c(s, train.execs[[i]]$arg.arrays[arg.params.names])[arg.update.idx],
+                           aux.arrays = train.execs[[i]]$aux.arrays, ctx = ctx[[i]], grad.req = grad.req)
           })
         } else {
           for (i in seq_len(ndevice)) {
@@ -166,19 +180,23 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
           }
         }
         
+        # forward pass
         for (texec in train.execs) {
           mx.exec.forward(texec, is.train = FALSE)
         }
         
-        # copy outputs to CPU
-        out.preds <- lapply(train.execs, function(texec) {
-          mx.nd.copyto(texec$ref.outputs[[1]], mx.cpu())
-        })
-        
+        # copy of preds and labels for metric and update metric
         if (!is.null(metric)) {
+          preds <- lapply(train.execs, function(texec) {texec$ref.outputs[[1]]})
+          labels <- lapply(train.execs, function(texec) {texec$ref.arg.arrays[[input.names[length(input.names)]]]})
+          if (metric_cpu) {
+            preds <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(preds[[i]], mx.cpu())})
+            labels <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(labels[[i]], mx.cpu())})
+          }
           for (i in seq_len(ndevice)) {
-            eval.metric <- metric$update(slices[[i]][[length(slices[[i]])]], 
-                                         out.preds[[i]], eval.metric)
+            eval.metric <- metric$update(label = labels[[i]], 
+                                         pred = preds[[i]], 
+                                         state = eval.metric)
           }
         }
       }
@@ -187,7 +205,7 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
         result <- metric$get(eval.metric)
         if (verbose) {
           message("[", iteration, "] Validation-", result$name, "=", 
-                         result$value)
+                  result$value)
         }
       }
     } else {
@@ -232,7 +250,7 @@ mx.model.buckets <- function(symbol, train.data, eval.data = NULL, metric = NULL
                              num.round = 1, begin.round = 1, 
                              initializer = mx.init.uniform(0.01), optimizer = "sgd", ctx = NULL, 
                              batch.end.callback = NULL, epoch.end.callback = NULL, 
-                             kvstore = "local", verbose = TRUE) {
+                             kvstore = "local", verbose = TRUE, metric_cpu = TRUE, gc_freq = NULL) {
   
   if (!train.data$iter.next()) {
     train.data$reset()
@@ -324,7 +342,7 @@ mx.model.buckets <- function(symbol, train.data, eval.data = NULL, metric = NULL
   
   # kvstore initialization
   kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx), 
-                                             verbose = verbose)
+                                     verbose = verbose)
   
   ### Execute training
   model <- mx.model.train.buckets(symbol = symbol, ctx = ctx,  train.data = train.data, eval.data = eval.data, 
@@ -333,7 +351,7 @@ mx.model.buckets <- function(symbol, train.data, eval.data = NULL, metric = NULL
                                   optimizer = optimizer, metric = metric, 
                                   begin.round = begin.round, end.round = num.round, 
                                   batch.end.callback = batch.end.callback, epoch.end.callback = epoch.end.callback, 
-                                  kvstore = kvstore, verbose = verbose)
+                                  kvstore = kvstore, verbose = verbose, metric_cpu = metric_cpu, gc_freq = gc_freq)
   
   return(model)
 }
diff --git a/R-package/R/rnn.infer.R b/R-package/R/rnn.infer.R
index 588056f2eb8..57201e721c0 100644
--- a/R-package/R/rnn.infer.R
+++ b/R-package/R/rnn.infer.R
@@ -1,7 +1,7 @@
 
 #' Inference of RNN model
 #'
-#' @param infer.data Data iterator created by mx.io.bucket.iter
+#' @param infer.data DataIter
 #' @param model Model used for inference
 #' @param ctx
 #'
@@ -37,7 +37,7 @@ mx.infer.rnn <- function(infer.data, model, ctx = mx.cpu()) {
   arguments.ini <- lapply(shapes$arg.shapes, function(shape) {
     mx.nd.zeros(shape = shape, ctx = mx.cpu())
   })
-
+  
   arg.params <- model$arg.params
   arg.params.names <- names(arg.params)
   aux.params <- model$aux.params
@@ -59,7 +59,7 @@ mx.infer.rnn <- function(infer.data, model, ctx = mx.cpu()) {
   arg_update_idx <- match(arguments, update_names)
   
   execs <- mx.symbol.bind(symbol = symbol, arg.arrays = c(dlist, arg.params.fix, arg.params)[arg_update_idx], 
-                                  aux.arrays = aux.params, ctx = ctx[[1]], grad.req = grad.req)
+                          aux.arrays = aux.params, ctx = ctx[[1]], grad.req = grad.req)
   
   # Initial input shapes - need to be adapted for multi-devices - divide highest
   # dimension by device nb
@@ -69,10 +69,10 @@ mx.infer.rnn <- function(infer.data, model, ctx = mx.cpu()) {
   while (infer.data$iter.next()) {
     
     # Get input data slice
-    dlist <- infer.data$value()  #[input.names]
+    dlist <- infer.data$value()[input.names]
     
     execs <- mx.symbol.bind(symbol = symbol, arg.arrays = c(dlist, execs$arg.arrays[arg.params.fix.names], execs$arg.arrays[arg.params.names])[arg_update_idx], 
-                                    aux.arrays = execs$aux.arrays, ctx = ctx[[1]], grad.req = grad.req)
+                            aux.arrays = execs$aux.arrays, ctx = ctx[[1]], grad.req = grad.req)
     
     mx.exec.forward(execs, is.train = FALSE)
     
@@ -225,7 +225,7 @@ mx.infer.rnn.one.unroll <- function(infer.data,
   
   # init_state_shapes
   init_states_names <- arguments[startsWith(arguments, "init_")]
-  init_states_shapes = lapply(init_states_names, function(x) c(num_hidden, tail(input.shape[[1]], 1)))
+  init_states_shapes <- lapply(init_states_names, function(x) c(num_hidden, tail(input.shape[[1]], 1)))
   names(init_states_shapes) <- init_states_names
   
   shapes <- symbol$infer.shape(c(input.shape, init_states_shapes))
diff --git a/R-package/tests/testthat/test_model.R b/R-package/tests/testthat/test_model.R
index 13e54ff9ce4..6167ed66c41 100644
--- a/R-package/tests/testthat/test_model.R
+++ b/R-package/tests/testthat/test_model.R
@@ -80,8 +80,9 @@ test_that("Regression", {
   lro <- mx.symbol.LinearRegressionOutput(fc1)
   
   demo.metric.mae <- mx.metric.custom("mae", function(label, pred) {
-    res <- mean(abs(label - pred))
-    return(res)
+    pred <- mx.nd.reshape(pred, shape = 0)
+    res <- mx.nd.mean(mx.nd.abs(label-pred))
+    return(as.array(res))
   })
   mx.set.seed(0)
   model <- mx.model.FeedForward.create(lro, X = train.x, y = train.y,
@@ -291,6 +292,8 @@ test_that("Captcha", {
   captcha_net <- mx.symbol.SoftmaxOutput(data = fc2, label = label, name = "softmax")
   
   mx.metric.acc2 <- mx.metric.custom("accuracy", function(label, pred) {
+    label = as.array(label)
+    pred = as.array(pred)
     ypred <- max.col(t(pred)) - 1
     ypred <- matrix(ypred, nrow = nrow(label), ncol = ncol(label), byrow = TRUE)
     return(sum(colSums(label == ypred) == 4)/ncol(label))


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific pull request.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services