You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/31 05:43:36 UTC
[GitHub] hetong007 closed pull request #9803: R Metrics
hetong007 closed pull request #9803: R Metrics
URL: https://github.com/apache/incubator-mxnet/pull/9803
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
Because this pull request comes from a fork, GitHub will not show its
diff after the merge, so the diff is reproduced below:
diff --git a/R-package/R/metric.R b/R-package/R/metric.R
index f8d9c33a726..9e1913a0877 100644
--- a/R-package/R/metric.R
+++ b/R-package/R/metric.R
@@ -6,7 +6,7 @@ mx.metric.custom <- function(name, feval) {
c(0, 0)
}
update <- function(label, pred, state) {
- m <- feval(as.array(label), as.array(pred))
+ m <- feval(label, pred)
state <- c(state[[1]] + 1, state[[2]] + m)
return(state)
}
@@ -22,69 +22,90 @@ mx.metric.custom <- function(name, feval) {
#'
#' @export
mx.metric.accuracy <- mx.metric.custom("accuracy", function(label, pred) {
- ypred = max.col(t(as.array(pred)), tie="first")
- return(sum((as.array(label) + 1) == ypred) / length(label))
+ pred <- mx.nd.argmax(data = pred, axis = 1, keepdims = F)
+ res <- mx.nd.mean(label == pred)
+ return(as.array(res))
})
-#' Helper function for top-k accuracy
-is.num.in.vect <- function(vect, num){
- resp <- any(is.element(vect, num))
- return(resp)
-}
-
#' Top-k accuracy metric for classification
#'
#' @export
mx.metric.top_k_accuracy <- mx.metric.custom("top_k_accuracy", function(label, pred, top_k = 5) {
- if(top_k == 1){
- return(mx.metric.accuracy(label,pred))
- } else{
- ypred <- apply(pred,2,function(x) order(x, decreasing=TRUE)[seq_len(top_k)])
- ans <- apply(ypred, 2, is.num.in.vect, num = as.array(label + 1))
- acc <- sum(ans)/length(label)
- return(acc)
- }
+ label <- mx.nd.reshape(data = label, shape = c(1,0))
+ pred <- mx.nd.topk(data = pred, axis = 1, k = top_k, ret_typ = "indices")
+ pred <- mx.nd.broadcast.equal(lhs = pred, rhs = label)
+ res <- mx.nd.mean(mx.nd.sum(data = pred, axis = 1, keepdims = F))
+ return(as.array(res))
})
#' MSE (Mean Squared Error) metric for regression
#'
#' @export
mx.metric.mse <- mx.metric.custom("mse", function(label, pred) {
- res <- mean((label-pred)^2)
- return(res)
+ pred <- mx.nd.reshape(pred, shape = 0)
+ res <- mx.nd.mean(mx.nd.square(label-pred))
+ return(as.array(res))
})
-
+
#' RMSE (Root Mean Squared Error) metric for regression
#'
#' @export
mx.metric.rmse <- mx.metric.custom("rmse", function(label, pred) {
- res <- sqrt(mean((label-pred)^2))
- return(res)
+ pred <- mx.nd.reshape(pred, shape = 0)
+ res <- mx.nd.sqrt(mx.nd.mean(mx.nd.square(label-pred)))
+ return(as.array(res))
})
#' MAE (Mean Absolute Error) metric for regression
#'
#' @export
mx.metric.mae <- mx.metric.custom("mae", function(label, pred) {
- res <- mean(abs(label-pred))
- return(res)
+ pred <- mx.nd.reshape(pred, shape = 0)
+ res <- mx.nd.mean(mx.nd.abs(label-pred))
+ return(as.array(res))
})
#' RMSLE (Root Mean Squared Logarithmic Error) metric for regression
#'
#' @export
mx.metric.rmsle <- mx.metric.custom("rmsle", function(label, pred) {
- res <- sqrt(mean((log(pred + 1) - log(label + 1))^2))
- return(res)
+ pred <- mx.nd.reshape(pred, shape = 0)
+ res <- mx.nd.sqrt(mx.nd.mean(mx.nd.square(mx.nd.log1p(pred) - mx.nd.log1p(label))))
+ return(as.array(res))
})
#' Perplexity metric for language model
#'
#' @export
-mx.metric.Perplexity <- mx.metric.custom("Perplexity", function(label, pred) {
- label_probs <- as.array(mx.nd.choose.element.0index(pred, label))
- batch <- length(label_probs)
- NLL <- -sum(log(pmax(1e-15, as.array(label_probs)))) / batch
- Perplexity <- exp(NLL)
- return(Perplexity)
+mx.metric.Perplexity <- mx.metric.custom("Perplexity", function(label, pred, mask_element = -1) {
+
+ label <- mx.nd.reshape(label, shape = -1)
+ pred_probs <- mx.nd.pick(data = pred, index = label, axis = 1)
+
+ mask <- label != mask_element
+ mask_length <- mx.nd.sum(mask)
+
+ NLL <- -mx.nd.sum(mx.nd.log(pred_probs) * mask) / mask_length
+ res <- mx.nd.exp(NLL)
+ return(as.array(res))
})
+
+#' LogLoss metric for logistic regression
+#'
+#' @export
+mx.metric.logloss <- mx.metric.custom("logloss", function(label, pred) {
+ pred <- mx.nd.reshape(pred, shape = 0)
+ pred <- mx.nd.clip(pred, a_min = 1e-15, a_max = 1-1e-15)
+ res <- -mx.nd.mean(label * mx.nd.log(pred) + (1-label) * mx.nd.log(1-pred))
+ return(as.array(res))
+})
+
+#' Accuracy metric for logistic regression
+#'
+#' @export
+mx.metric.logistic_acc <- mx.metric.custom("accuracy", function(label, pred) {
+ pred <- mx.nd.reshape(pred, shape = 0) > 0.5
+ res <- mx.nd.mean(label == pred)
+ return(as.array(res))
+})
+
diff --git a/R-package/R/model.R b/R-package/R/model.R
index 01b5ed72835..b461f7973f6 100644
--- a/R-package/R/model.R
+++ b/R-package/R/model.R
@@ -112,13 +112,14 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
begin.round, end.round, optimizer,
train.data, eval.data, metric,
epoch.end.callback, batch.end.callback,
- kvstore, fixed.param = NULL, verbose = TRUE) {
+ kvstore, fixed.param, verbose,
+ metric_cpu) {
ndevice <- length(ctx)
if(verbose) message("Start training with ", ndevice, " devices")
# create the executors
input_slice <- mx.model.slice.shape(input.shape, ndevice)
output_slice <- mx.model.slice.shape(output.shape, ndevice)
-
+
arg_names <- arguments(symbol)
output.names <- names(output.shape)
#label_name <- arg_names[endsWith(arg_names, "label")]
@@ -126,7 +127,7 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
arg_lst <- list(symbol = symbol, ctx = ctx[[i]], grad.req = "write")
arg_lst <- append(arg_lst, input_slice[[i]]$shape)
arg_lst <- append(arg_lst, output_slice[[i]]$shape)
- arg_lst[["fixed.param"]] = fixed.param
+ arg_lst[["fixed.param"]] = unique(c(fixed.param, names(input.shape), names(output.shape)))
do.call(mx.simple.bind, arg_lst)
})
# set the parameters into executors
@@ -153,8 +154,10 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
kvstore$init(params.index, train.execs[[1]]$ref.arg.arrays[params.index])
}
# Get the input names
-
+
for (iteration in begin.round:end.round) {
+ # reset training data
+ train.data$reset()
nbatch <- 0
if (!is.null(metric)) {
train.metric <- metric$init()
@@ -175,13 +178,22 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
}
mx.exec.update.arg.arrays(train.execs[[i]], s, match.name=TRUE)
}
+
+ # forward pass
for (texec in train.execs) {
mx.exec.forward(texec, is.train=TRUE)
}
- # copy outputs to CPU
- out.preds <- lapply(train.execs, function(texec) {
- mx.nd.copyto(texec$ref.outputs[[1]], mx.cpu())
- })
+
+ # copy of preds and labels for metric
+ if (!is.null(metric)) {
+ preds <- lapply(train.execs, function(texec) {texec$ref.outputs[[1]]})
+ labels <- lapply(train.execs, function(texec) {texec$ref.arg.arrays[[output.names[length(output.names)]]]})
+ if (metric_cpu) {
+ preds <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(preds[[i]], mx.cpu())})
+ labels <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(labels[[i]], mx.cpu())})
+ }
+ }
+
# backward pass
for (texec in train.execs) {
mx.exec.backward(texec)
@@ -215,7 +227,9 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
# Update the evaluation metrics
if (!is.null(metric)) {
for (i in seq_len(ndevice)) {
- train.metric <- metric$update(slices[[i]][[length(slices[[i]])]], out.preds[[i]], train.metric)
+ train.metric <- metric$update(label = labels[[i]],
+ pred = preds[[i]],
+ state = train.metric)
}
}
nbatch <- nbatch + 1
@@ -223,13 +237,14 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
batch.end.callback(iteration, nbatch, environment())
}
}
- # reset training data
- train.data$reset()
+
if (!is.null(metric)) {
result <- metric$get(train.metric)
if(verbose) message("[", iteration, "] Train-", result$name, "=", result$value)
}
if (!is.null(eval.data)) {
+ # reset eval data
+ eval.data$reset()
if (!is.null(metric)) {
eval.metric <- metric$init()
}
@@ -250,16 +265,22 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
for (texec in train.execs) {
mx.exec.forward(texec, is.train=FALSE)
}
- out.preds <- lapply(train.execs, function(texec) {
- mx.nd.copyto(texec$ref.outputs[[1]], mx.cpu())
- })
+
+ # copy of preds and labels for metric and update metric
if (!is.null(metric)) {
+ preds <- lapply(train.execs, function(texec) {texec$ref.outputs[[1]]})
+ labels <- lapply(train.execs, function(texec) {texec$ref.arg.arrays[[output.names[length(output.names)]]]})
+ if (metric_cpu) {
+ preds <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(preds[[i]], mx.cpu())})
+ labels <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(labels[[i]], mx.cpu())})
+ }
for (i in seq_len(ndevice)) {
- eval.metric <- metric$update(slices[[i]][[length(slices[[i]])]] , out.preds[[i]], eval.metric)
+ eval.metric <- metric$update(label = labels[[i]],
+ pred = preds[[i]],
+ state = eval.metric)
}
}
}
- eval.data$reset()
if (!is.null(metric)) {
result <- metric$get(eval.metric)
if(verbose) message("[", iteration, "] Validation-", result$name, "=", result$value)
@@ -269,12 +290,12 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
}
# get the model out
model <- mx.model.extract.model(symbol, train.execs)
-
+
epoch_continue <- TRUE
if (!is.null(epoch.end.callback)) {
epoch_continue <- epoch.end.callback(iteration, 0, environment(), verbose = verbose)
}
-
+
if (!epoch_continue) {
break
}
@@ -291,11 +312,11 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
#' @export
mx.model.init.params <- function(symbol, input.shape, output.shape, initializer, ctx) {
if (!is.MXSymbol(symbol)) stop("symbol needs to be MXSymbol")
-
+
arg_lst <- list(symbol = symbol)
arg_lst <- append(arg_lst, input.shape)
arg_lst <- append(arg_lst, output.shape)
-
+
slist <- do.call(mx.symbol.infer.shape, arg_lst)
if (is.null(slist)) stop("Not enough information to get shapes")
arg.params <- mx.init.create(initializer, slist$arg.shapes, ctx, skip.unknown=TRUE)
@@ -430,93 +451,95 @@ mx.model.select.layout.predict <- function(X, model) {
#' @export
mx.model.FeedForward.create <-
-function(symbol, X, y=NULL, ctx=NULL, begin.round=1,
- num.round=10, optimizer="sgd",
- initializer=mx.init.uniform(0.01),
- eval.data=NULL, eval.metric=NULL,
- epoch.end.callback=NULL, batch.end.callback=NULL,
- array.batch.size=128, array.layout="auto",
- kvstore = "local", verbose = TRUE,
- arg.params = NULL, aux.params = NULL,
- input.names=NULL, output.names = NULL,
- fixed.param = NULL, allow.extra.params = FALSE,
- ...) {
- if (is.array(X) || is.matrix(X)) {
- if (array.layout == "auto") {
- array.layout <- mx.model.select.layout.train(X, y)
+ function(symbol, X, y=NULL, ctx=NULL, begin.round=1,
+ num.round=10, optimizer="sgd",
+ initializer=mx.init.uniform(0.01),
+ eval.data=NULL, eval.metric=NULL,
+ epoch.end.callback=NULL, batch.end.callback=NULL,
+ array.batch.size=128, array.layout="auto",
+ kvstore = "local", verbose = TRUE,
+ arg.params = NULL, aux.params = NULL,
+ input.names=NULL, output.names = NULL,
+ fixed.param = NULL, allow.extra.params = FALSE,
+ metric_cpu = TRUE,
+ ...) {
+ if (is.array(X) || is.matrix(X)) {
+ if (array.layout == "auto") {
+ array.layout <- mx.model.select.layout.train(X, y)
+ }
+ if (array.layout == "rowmajor") {
+ X <- t(X)
+ }
}
- if (array.layout == "rowmajor") {
- X <- t(X)
+ X <- mx.model.init.iter(X, y, batch.size=array.batch.size, is.train=TRUE)
+ if (!X$iter.next()) {
+ X$reset()
+ if (!X$iter.next()) stop("Empty input")
}
- }
- X <- mx.model.init.iter(X, y, batch.size=array.batch.size, is.train=TRUE)
- if (!X$iter.next()) {
- X$reset()
- if (!X$iter.next()) stop("Empty input")
- }
- if (is.null(input.names)) {
- input.names <- "data"
- }
- input.shape <- sapply(input.names, function(n){dim(X$value()[[n]])}, simplify = FALSE)
- if (is.null(output.names)) {
- arg_names <- arguments(symbol)
- output.names <- arg_names[endsWith(arg_names, "label")]
- output.shape <- list()
- output.shape[[output.names]] <- dim((X$value())$label)
- } else {
- output.shape <- sapply(output.names, function(n){dim(X$value()[[n]])}, simplify = FALSE)
- }
- params <- mx.model.init.params(symbol, input.shape, output.shape, initializer, mx.cpu())
- if (!is.null(arg.params)) params$arg.params <- arg.params
- if (!is.null(aux.params)) params$aux.params <- aux.params
- if (allow.extra.params) {
- params$arg.params[!names(params$arg.params) %in% arguments(symbol)] <- NULL
- }
- if (is.null(ctx)) ctx <- mx.ctx.default()
- if (is.mx.context(ctx)) {
- ctx <- list(ctx)
- }
- if (!is.list(ctx)) stop("ctx must be mx.context or list of mx.context")
- if (is.character(optimizer)) {
- if (is.numeric(input.shape)) {
- ndim <- length(input.shape)
- batchsize = input.shape[[ndim]]
+ if (is.null(input.names)) {
+ input.names <- "data"
+ }
+ input.shape <- sapply(input.names, function(n){dim(X$value()[[n]])}, simplify = FALSE)
+ if (is.null(output.names)) {
+ arg_names <- arguments(symbol)
+ output.names <- arg_names[endsWith(arg_names, "label")]
+ output.shape <- list()
+ output.shape[[output.names]] <- dim((X$value())$label)
} else {
- ndim <- length(input.shape[[1]])
- batchsize = input.shape[[1]][[ndim]]
+ output.shape <- sapply(output.names, function(n){dim(X$value()[[n]])}, simplify = FALSE)
}
- optimizer <- mx.opt.create(optimizer, rescale.grad=(1/batchsize), ...)
- }
- if (!is.null(eval.data) && !is.list(eval.data) && !is.mx.dataiter(eval.data)) {
- stop("The validation set should be either a mx.io.DataIter or a R list")
- }
- if (is.list(eval.data)) {
- if (is.null(eval.data$data) || is.null(eval.data$label)){
- stop("Please provide the validation set as list(data=R.array, label=R.array)")
+ params <- mx.model.init.params(symbol, input.shape, output.shape, initializer, mx.cpu())
+ if (!is.null(arg.params)) params$arg.params <- arg.params
+ if (!is.null(aux.params)) params$aux.params <- aux.params
+ if (allow.extra.params) {
+ params$arg.params[!names(params$arg.params) %in% arguments(symbol)] <- NULL
}
- if (is.array(eval.data$data) || is.matrix(eval.data$data)) {
- if (array.layout == "auto") {
- array.layout <- mx.model.select.layout.train(eval.data$data, eval.data$label)
+ if (is.null(ctx)) ctx <- mx.ctx.default()
+ if (is.mx.context(ctx)) {
+ ctx <- list(ctx)
+ }
+ if (!is.list(ctx)) stop("ctx must be mx.context or list of mx.context")
+ if (is.character(optimizer)) {
+ if (is.numeric(input.shape)) {
+ ndim <- length(input.shape)
+ batchsize = input.shape[[ndim]]
+ } else {
+ ndim <- length(input.shape[[1]])
+ batchsize = input.shape[[1]][[ndim]]
}
- if (array.layout == "rowmajor") {
- eval.data$data <- t(eval.data$data)
+ optimizer <- mx.opt.create(optimizer, rescale.grad=(1/batchsize), ...)
+ }
+ if (!is.null(eval.data) && !is.list(eval.data) && !is.mx.dataiter(eval.data)) {
+ stop("The validation set should be either a mx.io.DataIter or a R list")
+ }
+ if (is.list(eval.data)) {
+ if (is.null(eval.data$data) || is.null(eval.data$label)){
+ stop("Please provide the validation set as list(data=R.array, label=R.array)")
+ }
+ if (is.array(eval.data$data) || is.matrix(eval.data$data)) {
+ if (array.layout == "auto") {
+ array.layout <- mx.model.select.layout.train(eval.data$data, eval.data$label)
+ }
+ if (array.layout == "rowmajor") {
+ eval.data$data <- t(eval.data$data)
+ }
}
+ eval.data <- mx.model.init.iter(eval.data$data, eval.data$label, batch.size=array.batch.size, is.train = TRUE)
}
- eval.data <- mx.model.init.iter(eval.data$data, eval.data$label, batch.size=array.batch.size, is.train = TRUE)
+ kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx), verbose=verbose)
+ model <- mx.model.train(symbol, ctx, input.shape, output.shape,
+ params$arg.params, params$aux.params,
+ begin.round, num.round, optimizer=optimizer,
+ train.data=X, eval.data=eval.data,
+ metric=eval.metric,
+ epoch.end.callback=epoch.end.callback,
+ batch.end.callback=batch.end.callback,
+ kvstore=kvstore,
+ fixed.param = fixed.param,
+ verbose=verbose,
+ metric_cpu = metric_cpu)
+ return (model)
}
- kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx), verbose=verbose)
- model <- mx.model.train(symbol, ctx, input.shape, output.shape,
- params$arg.params, params$aux.params,
- begin.round, num.round, optimizer=optimizer,
- train.data=X, eval.data=eval.data,
- metric=eval.metric,
- epoch.end.callback=epoch.end.callback,
- batch.end.callback=batch.end.callback,
- kvstore=kvstore,
- fixed.param = fixed.param,
- verbose=verbose)
- return (model)
-}
#' Predict the outputs given a model and dataset.
#'
@@ -552,7 +575,7 @@ predict.MXFeedForwardModel <- function(model, X, ctx = NULL, array.batch.size =
if (!X$iter.next()) stop("Cannot predict on empty iterator")
dlist = X$value()
arg_lst <- list(symbol = model$symbol, ctx = ctx, data = dim(dlist$data), grad.req="null")
-
+
pexec <- do.call(mx.simple.bind, arg_lst)
if (allow.extra.params) {
model$arg.params[!names(model$arg.params) %in% arguments(model$symbol)] <- NULL
@@ -588,7 +611,7 @@ mx.model.load <- function(prefix, iteration) {
arg.index <- startsWith(nms, "arg:")
aux.index <- startsWith(nms, "aux:")
-
+
if (any(arg.index)) {
arg.params <- save.dict[arg.index]
names(arg.params) <- substr(nms[arg.index], 5, nchar(nms[arg.index]))
diff --git a/R-package/R/model.rnn.R b/R-package/R/model.rnn.R
index 3bf1d96dceb..f328d1ba6b7 100644
--- a/R-package/R/model.rnn.R
+++ b/R-package/R/model.rnn.R
@@ -2,8 +2,9 @@
mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
dlist, arg.params, aux.params,
grad.req, arg.update.idx,
- begin.round, end.round, optimizer, metric,
- epoch.end.callback, batch.end.callback, kvstore, verbose = TRUE) {
+ begin.round, end.round, optimizer, metric, metric_cpu,
+ epoch.end.callback, batch.end.callback, kvstore, verbose,
+ gc_freq) {
ndevice <- length(ctx)
if (verbose)
@@ -21,7 +22,7 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
train.execs <- lapply(seq_len(ndevice), function(i) {
s <- slices[[i]]
mx.symbol.bind(symbol = sym_ini, arg.arrays = c(s, arg.params)[arg.update.idx],
- aux.arrays = aux.params, ctx = ctx[[i]], grad.req = grad.req)
+ aux.arrays = aux.params, ctx = ctx[[i]], grad.req = grad.req)
})
# KVStore related stuffs
@@ -48,6 +49,7 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
# train over specified number of epochs
for (iteration in begin.round:end.round) {
nbatch <- 0
+ gc()
if (!is.null(metric)) {
train.metric <- metric$init()
}
@@ -77,14 +79,22 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
}
}
+ # forward pass
for (texec in train.execs) {
mx.exec.forward(texec, is.train = TRUE)
}
- out.preds <- lapply(train.execs, function(texec) {
- mx.nd.copyto(texec$ref.outputs[[1]], mx.cpu())
- })
+ # copy of preds and labels for metric
+ if (!is.null(metric)) {
+ preds <- lapply(train.execs, function(texec) {texec$ref.outputs[[1]]})
+ labels <- lapply(train.execs, function(texec) {texec$ref.arg.arrays[[input.names[length(input.names)]]]})
+ if (metric_cpu) {
+ preds <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(preds[[i]], mx.cpu())})
+ labels <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(labels[[i]], mx.cpu())})
+ }
+ }
+ # backward pass
for (texec in train.execs) {
mx.exec.backward(texec)
}
@@ -118,12 +128,16 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
# Update the evaluation metrics
if (!is.null(metric)) {
for (i in seq_len(ndevice)) {
- train.metric <- metric$update(label = slices[[i]][[length(slices[[i]])]],
- pred = out.preds[[i]], state = train.metric)
+ train.metric <- metric$update(label = labels[[i]],
+ pred = preds[[i]],
+ state = train.metric)
}
}
nbatch <- nbatch + 1
+ if (!is.null(gc_freq)) {
+ if (nbatch %% gc_freq == 0) gc()
+ }
if (!is.null(batch.end.callback)) {
batch.end.callback(iteration, nbatch, environment())
@@ -156,8 +170,8 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
train.execs <- lapply(seq_len(ndevice), function(i) {
s <- slices[[i]]
mx.symbol.bind(symbol = symbol[[names(eval.data$bucketID)]],
- arg.arrays = c(s, train.execs[[i]]$arg.arrays[arg.params.names])[arg.update.idx],
- aux.arrays = train.execs[[i]]$aux.arrays, ctx = ctx[[i]], grad.req = grad.req)
+ arg.arrays = c(s, train.execs[[i]]$arg.arrays[arg.params.names])[arg.update.idx],
+ aux.arrays = train.execs[[i]]$aux.arrays, ctx = ctx[[i]], grad.req = grad.req)
})
} else {
for (i in seq_len(ndevice)) {
@@ -166,19 +180,23 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
}
}
+ # forward pass
for (texec in train.execs) {
mx.exec.forward(texec, is.train = FALSE)
}
- # copy outputs to CPU
- out.preds <- lapply(train.execs, function(texec) {
- mx.nd.copyto(texec$ref.outputs[[1]], mx.cpu())
- })
-
+ # copy of preds and labels for metric and update metric
if (!is.null(metric)) {
+ preds <- lapply(train.execs, function(texec) {texec$ref.outputs[[1]]})
+ labels <- lapply(train.execs, function(texec) {texec$ref.arg.arrays[[input.names[length(input.names)]]]})
+ if (metric_cpu) {
+ preds <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(preds[[i]], mx.cpu())})
+ labels <- lapply(seq_along(train.execs), function(i) {mx.nd.copyto(labels[[i]], mx.cpu())})
+ }
for (i in seq_len(ndevice)) {
- eval.metric <- metric$update(slices[[i]][[length(slices[[i]])]],
- out.preds[[i]], eval.metric)
+ eval.metric <- metric$update(label = labels[[i]],
+ pred = preds[[i]],
+ state = eval.metric)
}
}
}
@@ -187,7 +205,7 @@ mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
result <- metric$get(eval.metric)
if (verbose) {
message("[", iteration, "] Validation-", result$name, "=",
- result$value)
+ result$value)
}
}
} else {
@@ -232,7 +250,7 @@ mx.model.buckets <- function(symbol, train.data, eval.data = NULL, metric = NULL
num.round = 1, begin.round = 1,
initializer = mx.init.uniform(0.01), optimizer = "sgd", ctx = NULL,
batch.end.callback = NULL, epoch.end.callback = NULL,
- kvstore = "local", verbose = TRUE) {
+ kvstore = "local", verbose = TRUE, metric_cpu = TRUE, gc_freq = NULL) {
if (!train.data$iter.next()) {
train.data$reset()
@@ -324,7 +342,7 @@ mx.model.buckets <- function(symbol, train.data, eval.data = NULL, metric = NULL
# kvstore initialization
kvstore <- mx.model.create.kvstore(kvstore, params$arg.params, length(ctx),
- verbose = verbose)
+ verbose = verbose)
### Execute training
model <- mx.model.train.buckets(symbol = symbol, ctx = ctx, train.data = train.data, eval.data = eval.data,
@@ -333,7 +351,7 @@ mx.model.buckets <- function(symbol, train.data, eval.data = NULL, metric = NULL
optimizer = optimizer, metric = metric,
begin.round = begin.round, end.round = num.round,
batch.end.callback = batch.end.callback, epoch.end.callback = epoch.end.callback,
- kvstore = kvstore, verbose = verbose)
+ kvstore = kvstore, verbose = verbose, metric_cpu = metric_cpu, gc_freq = gc_freq)
return(model)
}
diff --git a/R-package/R/rnn.infer.R b/R-package/R/rnn.infer.R
index 588056f2eb8..57201e721c0 100644
--- a/R-package/R/rnn.infer.R
+++ b/R-package/R/rnn.infer.R
@@ -1,7 +1,7 @@
#' Inference of RNN model
#'
-#' @param infer.data Data iterator created by mx.io.bucket.iter
+#' @param infer.data DataIter
#' @param model Model used for inference
#' @param ctx
#'
@@ -37,7 +37,7 @@ mx.infer.rnn <- function(infer.data, model, ctx = mx.cpu()) {
arguments.ini <- lapply(shapes$arg.shapes, function(shape) {
mx.nd.zeros(shape = shape, ctx = mx.cpu())
})
-
+
arg.params <- model$arg.params
arg.params.names <- names(arg.params)
aux.params <- model$aux.params
@@ -59,7 +59,7 @@ mx.infer.rnn <- function(infer.data, model, ctx = mx.cpu()) {
arg_update_idx <- match(arguments, update_names)
execs <- mx.symbol.bind(symbol = symbol, arg.arrays = c(dlist, arg.params.fix, arg.params)[arg_update_idx],
- aux.arrays = aux.params, ctx = ctx[[1]], grad.req = grad.req)
+ aux.arrays = aux.params, ctx = ctx[[1]], grad.req = grad.req)
# Initial input shapes - need to be adapted for multi-devices - divide highest
# dimension by device nb
@@ -69,10 +69,10 @@ mx.infer.rnn <- function(infer.data, model, ctx = mx.cpu()) {
while (infer.data$iter.next()) {
# Get input data slice
- dlist <- infer.data$value() #[input.names]
+ dlist <- infer.data$value()[input.names]
execs <- mx.symbol.bind(symbol = symbol, arg.arrays = c(dlist, execs$arg.arrays[arg.params.fix.names], execs$arg.arrays[arg.params.names])[arg_update_idx],
- aux.arrays = execs$aux.arrays, ctx = ctx[[1]], grad.req = grad.req)
+ aux.arrays = execs$aux.arrays, ctx = ctx[[1]], grad.req = grad.req)
mx.exec.forward(execs, is.train = FALSE)
@@ -225,7 +225,7 @@ mx.infer.rnn.one.unroll <- function(infer.data,
# init_state_shapes
init_states_names <- arguments[startsWith(arguments, "init_")]
- init_states_shapes = lapply(init_states_names, function(x) c(num_hidden, tail(input.shape[[1]], 1)))
+ init_states_shapes <- lapply(init_states_names, function(x) c(num_hidden, tail(input.shape[[1]], 1)))
names(init_states_shapes) <- init_states_names
shapes <- symbol$infer.shape(c(input.shape, init_states_shapes))
diff --git a/R-package/tests/testthat/test_model.R b/R-package/tests/testthat/test_model.R
index 13e54ff9ce4..6167ed66c41 100644
--- a/R-package/tests/testthat/test_model.R
+++ b/R-package/tests/testthat/test_model.R
@@ -80,8 +80,9 @@ test_that("Regression", {
lro <- mx.symbol.LinearRegressionOutput(fc1)
demo.metric.mae <- mx.metric.custom("mae", function(label, pred) {
- res <- mean(abs(label - pred))
- return(res)
+ pred <- mx.nd.reshape(pred, shape = 0)
+ res <- mx.nd.mean(mx.nd.abs(label-pred))
+ return(as.array(res))
})
mx.set.seed(0)
model <- mx.model.FeedForward.create(lro, X = train.x, y = train.y,
@@ -291,6 +292,8 @@ test_that("Captcha", {
captcha_net <- mx.symbol.SoftmaxOutput(data = fc2, label = label, name = "softmax")
mx.metric.acc2 <- mx.metric.custom("accuracy", function(label, pred) {
+ label = as.array(label)
+ pred = as.array(pred)
ypred <- max.col(t(pred)) - 1
ypred <- matrix(ypred, nrow = nrow(label), ncol = ncol(label), byrow = TRUE)
return(sum(colSums(label == ypred) == 4)/ncol(label))
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services