You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ss...@apache.org on 2022/02/15 15:59:53 UTC

[systemds] branch main updated: [MINOR] Handling null (NaN) values in input matrix

This is an automated email from the ASF dual-hosted git repository.

ssiddiqi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new cb6f87b  [MINOR] Handling null (NaN) values in input matrix
cb6f87b is described below

commit cb6f87b54a6f10e24c7098675b802573f1087ea3
Author: Shafaq Siddiqi <sh...@tugraz.at>
AuthorDate: Tue Feb 15 16:49:34 2022 +0100

    [MINOR] Handling null (NaN) values in input matrix
---
 scripts/builtin/multiLogRegPredict.dml | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/scripts/builtin/multiLogRegPredict.dml b/scripts/builtin/multiLogRegPredict.dml
index 2b1f80c..b17d8ca 100644
--- a/scripts/builtin/multiLogRegPredict.dml
+++ b/scripts/builtin/multiLogRegPredict.dml
@@ -51,12 +51,19 @@ m_multiLogRegPredict = function(Matrix[Double] X, Matrix[Double] B, Matrix[Doubl
   }
   if(ncol(X) < nrow(B)-1)
     stop("multiLogRegPredict: mismatching ncol(X) and nrow(B): "+ncol(X)+" "+nrow(B));
+  
+  # Robustness for datasets with missing values (causing NaN probabilities)
+  numNaNs = sum(isNaN(X))
+  if( numNaNs > 0 ) {
+    print("multiLogRegPredict: matrix X contains "+numNaNs+" missing values, replacing with 0.")
+    X = replace(target=X, pattern=NaN, replacement=0);
+  }
   accuracy = 0.0 # initialize variable 
   beta = B[1:ncol(X), ];
   intercept = ifelse(ncol(X)==nrow(B), matrix(0,1,ncol(B)), B[nrow(B),]);
   linear_terms = X %*% beta + matrix(1,nrow(X),1) %*% intercept;
 
-  M = probabilities(linear_terms); # compute the probablitites on unknown data
+  M = probabilities(linear_terms); # compute the probabilities on unknown data
   predicted_Y = rowIndexMax(M); # extract the class labels
 
   if(nrow(Y) != 0)