You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/10/10 22:36:19 UTC
[incubator-mxnet] branch master updated: R fix metric shape (#12776)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 822e59f R fix metric shape (#12776)
822e59f is described below
commit 822e59f819f098e1745cf92b5e9351d7455fe5fc
Author: jeremiedb <je...@users.noreply.github.com>
AuthorDate: Wed Oct 10 18:36:05 2018 -0400
R fix metric shape (#12776)
---
R-package/R/metric.R | 24 +++++++++++++++---------
1 file changed, 15 insertions(+), 9 deletions(-)
diff --git a/R-package/R/metric.R b/R-package/R/metric.R
index b29d0ea..8715ccf 100644
--- a/R-package/R/metric.R
+++ b/R-package/R/metric.R
@@ -59,7 +59,8 @@ mx.metric.top_k_accuracy <- mx.metric.custom("top_k_accuracy", function(label, p
#'
#' @export
mx.metric.mse <- mx.metric.custom("mse", function(label, pred) {
- pred <- mx.nd.reshape(pred, shape = 0)
+ label <- mx.nd.reshape(label, shape = -1)
+ pred <- mx.nd.reshape(pred, shape = -1)
res <- mx.nd.mean(mx.nd.square(label-pred))
return(as.array(res))
})
@@ -68,7 +69,8 @@ mx.metric.mse <- mx.metric.custom("mse", function(label, pred) {
#'
#' @export
mx.metric.rmse <- mx.metric.custom("rmse", function(label, pred) {
- pred <- mx.nd.reshape(pred, shape = 0)
+ label <- mx.nd.reshape(label, shape = -1)
+ pred <- mx.nd.reshape(pred, shape = -1)
res <- mx.nd.sqrt(mx.nd.mean(mx.nd.square(label-pred)))
return(as.array(res))
})
@@ -77,7 +79,8 @@ mx.metric.rmse <- mx.metric.custom("rmse", function(label, pred) {
#'
#' @export
mx.metric.mae <- mx.metric.custom("mae", function(label, pred) {
- pred <- mx.nd.reshape(pred, shape = 0)
+ label <- mx.nd.reshape(label, shape = -1)
+ pred <- mx.nd.reshape(pred, shape = -1)
res <- mx.nd.mean(mx.nd.abs(label-pred))
return(as.array(res))
})
@@ -86,7 +89,8 @@ mx.metric.mae <- mx.metric.custom("mae", function(label, pred) {
#'
#' @export
mx.metric.rmsle <- mx.metric.custom("rmsle", function(label, pred) {
- pred <- mx.nd.reshape(pred, shape = 0)
+ label <- mx.nd.reshape(label, shape = -1)
+ pred <- mx.nd.reshape(pred, shape = -1)
res <- mx.nd.sqrt(mx.nd.mean(mx.nd.square(mx.nd.log1p(pred) - mx.nd.log1p(label))))
return(as.array(res))
})
@@ -95,13 +99,13 @@ mx.metric.rmsle <- mx.metric.custom("rmsle", function(label, pred) {
#'
#' @export
mx.metric.Perplexity <- mx.metric.custom("Perplexity", function(label, pred, mask_element = -1) {
-
+
label <- mx.nd.reshape(label, shape = -1)
pred_probs <- mx.nd.pick(data = pred, index = label, axis = 1)
-
+
mask <- label != mask_element
mask_length <- mx.nd.sum(mask)
-
+
NLL <- -mx.nd.sum(mx.nd.log(pred_probs) * mask) / mask_length
res <- mx.nd.exp(NLL)
return(as.array(res))
@@ -111,7 +115,8 @@ mx.metric.Perplexity <- mx.metric.custom("Perplexity", function(label, pred, mas
#'
#' @export
mx.metric.logloss <- mx.metric.custom("logloss", function(label, pred) {
- pred <- mx.nd.reshape(pred, shape = 0)
+ label <- mx.nd.reshape(label, shape = -1)
+ pred <- mx.nd.reshape(pred, shape = -1)
pred <- mx.nd.clip(pred, a_min = 1e-15, a_max = 1-1e-15)
res <- -mx.nd.mean(label * mx.nd.log(pred) + (1-label) * mx.nd.log(1-pred))
return(as.array(res))
@@ -121,7 +126,8 @@ mx.metric.logloss <- mx.metric.custom("logloss", function(label, pred) {
#'
#' @export
mx.metric.logistic_acc <- mx.metric.custom("accuracy", function(label, pred) {
- pred <- mx.nd.reshape(pred, shape = 0) > 0.5
+ label <- mx.nd.reshape(label, shape = -1)
+ pred <- mx.nd.reshape(pred, shape = -1) > 0.5
res <- mx.nd.mean(label == pred)
return(as.array(res))
})