You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/10/10 16:27:31 UTC

[systemds] branch master updated: [SYSTEMDS-3156] Fix performance naiveBayes on distributed data

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 72188a7  [SYSTEMDS-3156] Fix performance naiveBayes on distributed data
72188a7 is described below

commit 72188a7a52990150066a1d57fa415dd7551d4f85
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sun Oct 10 18:27:00 2021 +0200

    [SYSTEMDS-3156] Fix performance naiveBayes on distributed data
    
    This patch fixes the performance of the new naiveBayes builtin function,
    which performed poorly on the perftest multinomial scenario (5 classes,
    80GB, 10Mx1K, dense). Compared to the original algorithm, the removed
    stop conditions led to a scenario where, during recompilation, the number
    of classes could not be inferred. This made the aggregate output
    distributed, and that output was then consumed by fallback cpmm matrix
    multiplications.
    
    We now properly split DAGs after grouped aggregates with an unknown
    number of groups (which alone already solved the issue), and restored
    the additional sanity checks. On the perftest scenario, runtime improved
    from 876s to 43s (20x).
---
 scripts/builtin/naiveBayes.dml                     | 38 +++++++++++++++++-----
 .../apache/sysds/hops/ParameterizedBuiltinOp.java  | 12 +++++++
 .../RewriteSplitDagDataDependentOperators.java     |  2 ++
 3 files changed, 44 insertions(+), 8 deletions(-)

diff --git a/scripts/builtin/naiveBayes.dml b/scripts/builtin/naiveBayes.dml
index 313ec09..7911cd0 100644
--- a/scripts/builtin/naiveBayes.dml
+++ b/scripts/builtin/naiveBayes.dml
@@ -19,19 +19,41 @@
 #
 #-------------------------------------------------------------
 
-m_naiveBayes = function(Matrix[Double] D, Matrix[Double] C, Double laplace = 1, Boolean verbose = TRUE)
-  return (Matrix[Double] prior, Matrix[Double] classConditionals) 
+m_naiveBayes = function(Matrix[Double] D,
+  Matrix[Double] C, Double laplace = 1, Boolean verbose = TRUE)
+  return (Matrix[Double] prior, Matrix[Double] classConditionals)
 {
   laplaceCorrection = laplace;
-  numRows = nrow(D);
-  numFeatures = ncol(D);
-  numClasses = max(C);
+  numRows = nrow(D)
+  numFeatures = ncol(D)
+  minFeatureVal = min(D)
+  numClasses = as.integer(max(C))
+  minLabelVal = min(C)
+
+  # sanity checks of data and arguments
+  if(minFeatureVal < 0)
+    stop("naiveBayes: Stopping due to invalid argument: Multinomial naive Bayes "
+       + " is meant for count-based feature values, minimum value in X is negative")
+  if(numRows < 2)
+    stop("naiveBayes: Stopping due to invalid inputs: "
+       + "Not possible to learn a classifier without at least 2 rows")
+  if(minLabelVal < 1)
+    stop("naiveBayes: Stopping due to invalid argument: Label vector (Y) must be recoded")
+  if(numClasses == 1)
+    stop("naiveBayes: Stopping due to invalid argument: "
+       + "Maximum label value is 1, need more than one class to learn a multi-class classifier")	
+  if(sum(abs(C%%1 == 0)) != numRows)
+    stop("naiveBayes: Stopping due to invalid argument: " 
+       + "Please ensure that Y contains (positive) integral labels")
+  if(laplaceCorrection < 0)
+    stop("naiveBayes: Stopping due to invalid argument: "
+       + "Laplacian correction (laplace) must be non-negative")
 
   # Compute conditionals
   # Compute the feature counts for each class
-  classFeatureCounts = aggregate(target=D, groups=C, fn="sum", ngroups=as.integer(numClasses));
+  classFeatureCounts = aggregate(target=D, groups=C, fn="sum", ngroups=numClasses);
 
-  # Compute the total feature count for each class 
+  # Compute the total feature count for each class
   # and add the number of features to this sum
   # for subsequent regularization (Laplace's rule)
   classSums = rowSums(classFeatureCounts) + numFeatures*laplaceCorrection;
@@ -40,7 +62,7 @@ m_naiveBayes = function(Matrix[Double] D, Matrix[Double] C, Double laplace = 1,
   classConditionals = (classFeatureCounts + laplaceCorrection) / classSums;
 
   # Compute class priors
-  classCounts = aggregate(target=C, groups=C, fn="count", ngroups=as.integer(numClasses));
+  classCounts = aggregate(target=C, groups=C, fn="count", ngroups=numClasses);
   prior = classCounts / numRows;
 
   # Compute accuracy on training set
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 8c9666b..7f9d71e 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -967,6 +967,18 @@ public class ParameterizedBuiltinOp extends MultiThreadedHop {
 		return ret;
 	}
 
+	public boolean isKnownNGroups() {
+		try {
+			Hop ngroups = getParameterHop(Statement.GAGG_NUM_GROUPS);
+			return (ngroups != null 
+				&& (ngroups instanceof LiteralOp | ngroups instanceof DataOp));
+		}
+		catch(Exception ex) {
+			LOG.warn("Known groups check exception: " + ex.getMessage());
+		}
+		return false;
+	}
+	
 	public boolean isTargetDiagInput() {
 		Hop targetHop = getTargetHop();
 		//input vector (guarantees diagV2M), implies remove rows
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index d0a464a..25dbd78 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -318,6 +318,8 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
 	private static boolean isBasicDataDependentOperator(Hop hop, boolean noSplitRequired) {
 		return (HopRewriteUtils.isNary(hop, OpOpN.EVAL) & !noSplitRequired)
 			|| (HopRewriteUtils.isData(hop, OpOpData.SQLREAD) & !noSplitRequired)
+			|| (HopRewriteUtils.isParameterBuiltinOp(hop, ParamBuiltinOp.GROUPEDAGG) 
+				&& !((ParameterizedBuiltinOp)hop).isKnownNGroups() && !noSplitRequired)
 			|| ((HopRewriteUtils.isUnary(hop, OpOp1.COMPRESS) || hop.requiresCompression()) &&
 				(!HopRewriteUtils.hasOnlyWriteParents(hop, true, true)));
 		//note: for compression we probe for write parents (part of noSplitRequired) directly