You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by qk...@apache.org on 2017/07/28 23:57:57 UTC
[incubator-mxnet] branch master updated: [R] allow users to use other names than "label". close #7126 (#7232)
This is an automated email from the ASF dual-hosted git repository.
qkou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 424143a [R] allow users to use other names than "label". close #7126 (#7232)
424143a is described below
commit 424143ac47ab3a38ae8aedaeb3319379887de0bc
Author: Qiang Kou (KK) <qk...@qkou.info>
AuthorDate: Fri Jul 28 23:57:52 2017 +0000
[R] allow users to use other names than "label". close #7126 (#7232)
---
R-package/R/model.R | 31 +++++++++++++++++--------------
R-package/tests/testthat/test_model.R | 11 +++++------
2 files changed, 22 insertions(+), 20 deletions(-)
diff --git a/R-package/R/model.R b/R-package/R/model.R
index 64cc816..2ee6624 100644
--- a/R-package/R/model.R
+++ b/R-package/R/model.R
@@ -116,15 +116,16 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
ndevice <- length(ctx)
if(verbose) message(paste0("Start training with ", ndevice, " devices"))
# create the executors
- sliceinfo <- mx.model.slice.shape(input.shape, ndevice)
- sliceinfo2 <- mx.model.slice.shape(output.shape, ndevice)
+ input_slice <- mx.model.slice.shape(input.shape, ndevice)
+ output_slice <- mx.model.slice.shape(output.shape, ndevice)
arg_names <- arguments(symbol)
- label_name <- arg_names[endsWith(arg_names, "label")]
+ output.names <- names(output.shape)
+ #label_name <- arg_names[endsWith(arg_names, "label")]
train.execs <- lapply(1:ndevice, function(i) {
arg_lst <- list(symbol = symbol, ctx = ctx[[i]], grad.req = "write")
- arg_lst <- append(arg_lst, sliceinfo[[i]]$shape)
- arg_lst <- append(arg_lst, sliceinfo2[[i]]$shape)
+ arg_lst <- append(arg_lst, input_slice[[i]]$shape)
+ arg_lst <- append(arg_lst, output_slice[[i]]$shape)
arg_lst[["fixed.param"]] = fixed.param
do.call(mx.simple.bind, arg_lst)
})
@@ -152,9 +153,6 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
kvstore$init(params.index, train.execs[[1]]$ref.arg.arrays[params.index])
}
# Get the input names
- # input.names <- mx.model.check.arguments(symbol)
- arg_names <- arguments(symbol)
- label_name <- arg_names[endsWith(arg_names, "label")]
for (iteration in begin.round:end.round) {
nbatch <- 0
@@ -165,14 +163,16 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
# Get input data slice
dlist <- train.data$value()
slices <- lapply(1:ndevice, function(i) {
- s <- sliceinfo[[i]]
+ s <- input_slice[[i]]
ret <- sapply(names(dlist), function(n) {mx.nd.slice(dlist[[n]], s$begin, s$end)})
return(ret)
})
# copy data to executor
for (i in 1:ndevice) {
s <- slices[[i]]
- names(s)[endsWith(names(s), "label")] = label_name
+ if (endsWith(output.names, "label")) {
+ names(s)[endsWith(names(s), "label")] = output.names
+ }
mx.exec.update.arg.arrays(train.execs[[i]], s, match.name=TRUE)
}
for (texec in train.execs) {
@@ -186,6 +186,7 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
for (texec in train.execs) {
mx.exec.backward(texec)
}
+
if (!is.null(kvstore)) {
# push the gradient
kvstore$push(params.index, lapply(train.execs, function(texec) {
@@ -214,7 +215,7 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
# Update the evaluation metrics
if (!is.null(metric)) {
for (i in 1 : ndevice) {
- train.metric <- metric$update(slices[[i]]$label, out.preds[[i]], train.metric)
+ train.metric <- metric$update(slices[[i]][[length(slices[[i]])]], out.preds[[i]], train.metric)
}
}
nbatch <- nbatch + 1
@@ -235,13 +236,15 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
while (eval.data$iter.next()) {
dlist <- eval.data$value()
slices <- lapply(1:ndevice, function(i) {
- s <- sliceinfo[[i]]
+ s <- input_slice[[i]]
ret <- sapply(names(dlist), function(n) {mx.nd.slice(dlist[[n]], s$begin, s$end)})
return(ret)
})
for (i in 1:ndevice) {
s <- slices[[i]]
- names(s)[endsWith(names(s), "label")] = label_name
+ if (endsWith(output.names, "label")) {
+ names(s)[endsWith(names(s), "label")] = output.names
+ }
mx.exec.update.arg.arrays(train.execs[[i]], s, match.name=TRUE)
}
for (texec in train.execs) {
@@ -252,7 +255,7 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
})
if (!is.null(metric)) {
for (i in 1 : ndevice) {
- eval.metric <- metric$update(slices[[i]]$label, out.preds[[i]], eval.metric)
+ eval.metric <- metric$update(slices[[i]][[length(slices[[i]])]] , out.preds[[i]], eval.metric)
}
}
}
diff --git a/R-package/tests/testthat/test_model.R b/R-package/tests/testthat/test_model.R
index 4cf2a8c..73a2127 100644
--- a/R-package/tests/testthat/test_model.R
+++ b/R-package/tests/testthat/test_model.R
@@ -162,12 +162,11 @@ test_that("Matrix Factorization", {
k <- 64
user <- mx.symbol.Variable("user")
item <- mx.symbol.Variable("item")
- score <- mx.symbol.Variable("label")
+ score <- mx.symbol.Variable("score")
user1 <- mx.symbol.Embedding(data = mx.symbol.BlockGrad(user), input_dim = max_user,
output_dim = k, name = "user1")
item1 <- mx.symbol.Embedding(data = mx.symbol.BlockGrad(item), input_dim = max_item,
- output_dim = k, name = "item1"
- )
+ output_dim = k, name = "item1")
pred <- user1 * item1
pred1 <- mx.symbol.sum_axis(pred, axis = 1, name = "pred1")
pred2 <- mx.symbol.Flatten(pred1, name = "pred2")
@@ -188,10 +187,10 @@ test_that("Matrix Factorization", {
value = function() {
user <- .self$iter1$value()$data
item <- .self$iter2$value()$data
- label <- .self$iter1$value()$label
+ score <- .self$iter1$value()$label
list(user = user,
item = item,
- label = label)
+ score = score)
},
iter.next = function() {
.self$iter1$iter.next()
@@ -224,5 +223,5 @@ test_that("Matrix Factorization", {
momentum = 0.9,
epoch.end.callback = mx.callback.log.train.metric(1),
input.names = c("user", "item"),
- output.names = "label")
+ output.names = "score")
})
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.