You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ss...@apache.org on 2022/02/15 15:59:53 UTC
[systemds] branch main updated: [MINOR] Handling null (NaN) values in input matrix
This is an automated email from the ASF dual-hosted git repository.
ssiddiqi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new cb6f87b [MINOR] Handling null (NaN) values in input matrix
cb6f87b is described below
commit cb6f87b54a6f10e24c7098675b802573f1087ea3
Author: Shafaq Siddiqi <sh...@tugraz.at>
AuthorDate: Tue Feb 15 16:49:34 2022 +0100
[MINOR] Handling null (NaN) values in input matrix
---
scripts/builtin/multiLogRegPredict.dml | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/scripts/builtin/multiLogRegPredict.dml b/scripts/builtin/multiLogRegPredict.dml
index 2b1f80c..b17d8ca 100644
--- a/scripts/builtin/multiLogRegPredict.dml
+++ b/scripts/builtin/multiLogRegPredict.dml
@@ -51,12 +51,19 @@ m_multiLogRegPredict = function(Matrix[Double] X, Matrix[Double] B, Matrix[Doubl
}
if(ncol(X) < nrow(B)-1)
stop("multiLogRegPredict: mismatching ncol(X) and nrow(B): "+ncol(X)+" "+nrow(B));
+
+ # Robustness for datasets with missing values (causing NaN probabilities)
+ numNaNs = sum(isNaN(X))
+ if( numNaNs > 0 ) {
+ print("multiLogRegPredict: matrix X contains "+numNaNs+" missing values, replacing with 0.")
+ X = replace(target=X, pattern=NaN, replacement=0);
+ }
accuracy = 0.0 # initialize variable
beta = B[1:ncol(X), ];
intercept = ifelse(ncol(X)==nrow(B), matrix(0,1,ncol(B)), B[nrow(B),]);
linear_terms = X %*% beta + matrix(1,nrow(X),1) %*% intercept;
- M = probabilities(linear_terms); # compute the probablitites on unknown data
+ M = probabilities(linear_terms); # compute the probabilities on unknown data
predicted_Y = rowIndexMax(M); # extract the class labels
if(nrow(Y) != 0)