You are viewing a plain-text version of this content; the canonical link to it was a hyperlink that is not preserved in this copy.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/10/10 22:36:19 UTC

[incubator-mxnet] branch master updated: R fix metric shape (#12776)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 822e59f  R fix metric shape (#12776)
822e59f is described below

commit 822e59f819f098e1745cf92b5e9351d7455fe5fc
Author: jeremiedb <je...@users.noreply.github.com>
AuthorDate: Wed Oct 10 18:36:05 2018 -0400

    R fix metric shape (#12776)
---
 R-package/R/metric.R | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/R-package/R/metric.R b/R-package/R/metric.R
index b29d0ea..8715ccf 100644
--- a/R-package/R/metric.R
+++ b/R-package/R/metric.R
@@ -59,7 +59,8 @@ mx.metric.top_k_accuracy <- mx.metric.custom("top_k_accuracy", function(label, p
 #'
 #' @export
 mx.metric.mse <- mx.metric.custom("mse", function(label, pred) {
-  pred <- mx.nd.reshape(pred, shape = 0)
+  label <- mx.nd.reshape(label, shape = -1)
+  pred <- mx.nd.reshape(pred, shape = -1)
   res <- mx.nd.mean(mx.nd.square(label-pred))
   return(as.array(res))
 })
@@ -68,7 +69,8 @@ mx.metric.mse <- mx.metric.custom("mse", function(label, pred) {
 #'
 #' @export
 mx.metric.rmse <- mx.metric.custom("rmse", function(label, pred) {
-  pred <- mx.nd.reshape(pred, shape = 0)
+  label <- mx.nd.reshape(label, shape = -1)
+  pred <- mx.nd.reshape(pred, shape = -1)
   res <- mx.nd.sqrt(mx.nd.mean(mx.nd.square(label-pred)))
   return(as.array(res))
 })
@@ -77,7 +79,8 @@ mx.metric.rmse <- mx.metric.custom("rmse", function(label, pred) {
 #'
 #' @export
 mx.metric.mae <- mx.metric.custom("mae", function(label, pred) {
-  pred <- mx.nd.reshape(pred, shape = 0)
+  label <- mx.nd.reshape(label, shape = -1)
+  pred <- mx.nd.reshape(pred, shape = -1)
   res <- mx.nd.mean(mx.nd.abs(label-pred))
   return(as.array(res))
 })
@@ -86,7 +89,8 @@ mx.metric.mae <- mx.metric.custom("mae", function(label, pred) {
 #'
 #' @export
 mx.metric.rmsle <- mx.metric.custom("rmsle", function(label, pred) {
-  pred <- mx.nd.reshape(pred, shape = 0)
+  label <- mx.nd.reshape(label, shape = -1)
+  pred <- mx.nd.reshape(pred, shape = -1)
   res <- mx.nd.sqrt(mx.nd.mean(mx.nd.square(mx.nd.log1p(pred) - mx.nd.log1p(label))))
   return(as.array(res))
 })
@@ -95,13 +99,13 @@ mx.metric.rmsle <- mx.metric.custom("rmsle", function(label, pred) {
 #'
 #' @export
 mx.metric.Perplexity <- mx.metric.custom("Perplexity", function(label, pred, mask_element = -1) {
-  
+
   label <- mx.nd.reshape(label, shape = -1)
   pred_probs <- mx.nd.pick(data = pred, index = label, axis = 1)
-  
+
   mask <- label != mask_element
   mask_length <- mx.nd.sum(mask)
-  
+
   NLL <- -mx.nd.sum(mx.nd.log(pred_probs) * mask) / mask_length
   res <- mx.nd.exp(NLL)
   return(as.array(res))
@@ -111,7 +115,8 @@ mx.metric.Perplexity <- mx.metric.custom("Perplexity", function(label, pred, mas
 #'
 #' @export
 mx.metric.logloss <- mx.metric.custom("logloss", function(label, pred) {
-  pred <- mx.nd.reshape(pred, shape = 0)
+  label <- mx.nd.reshape(label, shape = -1)
+  pred <- mx.nd.reshape(pred, shape = -1)
   pred <- mx.nd.clip(pred, a_min = 1e-15, a_max = 1-1e-15)
   res <- -mx.nd.mean(label * mx.nd.log(pred) + (1-label) * mx.nd.log(1-pred))
   return(as.array(res))
@@ -121,7 +126,8 @@ mx.metric.logloss <- mx.metric.custom("logloss", function(label, pred) {
 #'
 #' @export
 mx.metric.logistic_acc <- mx.metric.custom("accuracy", function(label, pred) {
-  pred <- mx.nd.reshape(pred, shape = 0) > 0.5
+  label <- mx.nd.reshape(label, shape = -1)
+  pred <- mx.nd.reshape(pred, shape = -1) > 0.5
   res <- mx.nd.mean(label == pred)
   return(as.array(res))
 })