You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by qk...@apache.org on 2017/07/28 23:57:57 UTC

[incubator-mxnet] branch master updated: [R] allow users to use other names than "label". close #7126 (#7232)

This is an automated email from the ASF dual-hosted git repository.

qkou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 424143a  [R] allow users to use other names than "label". close #7126 (#7232)
424143a is described below

commit 424143ac47ab3a38ae8aedaeb3319379887de0bc
Author: Qiang Kou (KK) <qk...@qkou.info>
AuthorDate: Fri Jul 28 23:57:52 2017 +0000

    [R] allow users to use other names than "label". close #7126 (#7232)
---
 R-package/R/model.R                   | 31 +++++++++++++++++--------------
 R-package/tests/testthat/test_model.R | 11 +++++------
 2 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/R-package/R/model.R b/R-package/R/model.R
index 64cc816..2ee6624 100644
--- a/R-package/R/model.R
+++ b/R-package/R/model.R
@@ -116,15 +116,16 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
   ndevice <- length(ctx)
   if(verbose) message(paste0("Start training with ", ndevice, " devices"))
   # create the executors
-  sliceinfo <- mx.model.slice.shape(input.shape, ndevice)
-  sliceinfo2 <- mx.model.slice.shape(output.shape, ndevice)
+  input_slice <- mx.model.slice.shape(input.shape, ndevice)
+  output_slice <- mx.model.slice.shape(output.shape, ndevice)
 
   arg_names <- arguments(symbol)
-  label_name <- arg_names[endsWith(arg_names, "label")]
+  output.names <- names(output.shape)
+  #label_name <- arg_names[endsWith(arg_names, "label")]
   train.execs <- lapply(1:ndevice, function(i) {
     arg_lst <- list(symbol = symbol, ctx = ctx[[i]], grad.req = "write")
-    arg_lst <- append(arg_lst, sliceinfo[[i]]$shape)
-    arg_lst <- append(arg_lst, sliceinfo2[[i]]$shape)
+    arg_lst <- append(arg_lst, input_slice[[i]]$shape)
+    arg_lst <- append(arg_lst, output_slice[[i]]$shape)
     arg_lst[["fixed.param"]] = fixed.param
     do.call(mx.simple.bind, arg_lst)
   })
@@ -152,9 +153,6 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
     kvstore$init(params.index, train.execs[[1]]$ref.arg.arrays[params.index])
   }
   # Get the input names
-  # input.names <- mx.model.check.arguments(symbol)
-  arg_names <- arguments(symbol)
-  label_name <- arg_names[endsWith(arg_names, "label")]
 
   for (iteration in begin.round:end.round) {
     nbatch <- 0
@@ -165,14 +163,16 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
       # Get input data slice
       dlist <- train.data$value()
       slices <- lapply(1:ndevice, function(i) {
-        s <- sliceinfo[[i]]
+        s <- input_slice[[i]]
         ret <- sapply(names(dlist), function(n) {mx.nd.slice(dlist[[n]], s$begin, s$end)})
         return(ret)
       })
       # copy data to executor
       for (i in 1:ndevice) {
         s <- slices[[i]]
-        names(s)[endsWith(names(s), "label")] = label_name
+        if (endsWith(output.names, "label")) {
+          names(s)[endsWith(names(s), "label")] = output.names 
+        }
         mx.exec.update.arg.arrays(train.execs[[i]], s, match.name=TRUE)
       }
       for (texec in train.execs) {
@@ -186,6 +186,7 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
       for (texec in train.execs) {
         mx.exec.backward(texec)
       }
+      
       if (!is.null(kvstore)) {
         # push the gradient
         kvstore$push(params.index, lapply(train.execs, function(texec) {
@@ -214,7 +215,7 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
       # Update the evaluation metrics
       if (!is.null(metric)) {
         for (i in 1 : ndevice) {
-          train.metric <- metric$update(slices[[i]]$label, out.preds[[i]], train.metric)
+          train.metric <- metric$update(slices[[i]][[length(slices[[i]])]], out.preds[[i]], train.metric)
         }
       }
       nbatch <- nbatch + 1
@@ -235,13 +236,15 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
       while (eval.data$iter.next()) {
         dlist <- eval.data$value()
         slices <- lapply(1:ndevice, function(i) {
-          s <- sliceinfo[[i]]
+          s <- input_slice[[i]]
           ret <- sapply(names(dlist), function(n) {mx.nd.slice(dlist[[n]], s$begin, s$end)})
           return(ret)
         })
         for (i in 1:ndevice) {
           s <- slices[[i]]
-          names(s)[endsWith(names(s), "label")] = label_name
+          if (endsWith(output.names, "label")) {
+            names(s)[endsWith(names(s), "label")] = output.names 
+          }
           mx.exec.update.arg.arrays(train.execs[[i]], s, match.name=TRUE)
         }
         for (texec in train.execs) {
@@ -252,7 +255,7 @@ mx.model.train <- function(symbol, ctx, input.shape, output.shape,
         })
         if (!is.null(metric)) {
           for (i in 1 : ndevice) {
-            eval.metric <- metric$update(slices[[i]]$label, out.preds[[i]], eval.metric)
+            eval.metric <- metric$update(slices[[i]][[length(slices[[i]])]] , out.preds[[i]], eval.metric)
           }
         }
       }
diff --git a/R-package/tests/testthat/test_model.R b/R-package/tests/testthat/test_model.R
index 4cf2a8c..73a2127 100644
--- a/R-package/tests/testthat/test_model.R
+++ b/R-package/tests/testthat/test_model.R
@@ -162,12 +162,11 @@ test_that("Matrix Factorization", {
   k <- 64
   user <- mx.symbol.Variable("user")
   item <- mx.symbol.Variable("item")
-  score <- mx.symbol.Variable("label")
+  score <- mx.symbol.Variable("score")
   user1 <- mx.symbol.Embedding(data = mx.symbol.BlockGrad(user), input_dim = max_user,
                                output_dim = k, name = "user1")
   item1 <- mx.symbol.Embedding(data = mx.symbol.BlockGrad(item), input_dim = max_item,
-                               output_dim = k, name = "item1"
-    )
+                               output_dim = k, name = "item1")
   pred <- user1 * item1
   pred1 <- mx.symbol.sum_axis(pred, axis = 1, name = "pred1")
   pred2 <- mx.symbol.Flatten(pred1, name = "pred2")
@@ -188,10 +187,10 @@ test_that("Matrix Factorization", {
         value = function() {
           user <- .self$iter1$value()$data
           item <- .self$iter2$value()$data
-          label <- .self$iter1$value()$label
+          score <- .self$iter1$value()$label
           list(user = user,
                item = item,
-               label = label)
+               score = score)
         },
         iter.next = function() {
           .self$iter1$iter.next()
@@ -224,5 +223,5 @@ test_that("Matrix Factorization", {
                                        momentum = 0.9,
                                        epoch.end.callback = mx.callback.log.train.metric(1),
                                        input.names = c("user", "item"),
-                                       output.names = "label")
+                                       output.names = "score")
 })

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.