You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/10/17 16:30:21 UTC

systemml git commit: [SYSTEMML-2329] Extended sampling-based sparsity estimator

Repository: systemml
Updated Branches:
  refs/heads/master f1b9d1c08 -> ca24ec564


[SYSTEMML-2329] Extended sampling-based sparsity estimator

This patch extends the existing sampling-based estimator with an optional
mode that removes its bias, using an approach similar to the element-wise
addition used in other estimators.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ca24ec56
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ca24ec56
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ca24ec56

Branch: refs/heads/master
Commit: ca24ec5647dedbf6eb50bbc630ccee673b1b3320
Parents: f1b9d1c
Author: Matthias Boehm <mb...@gmail.com>
Authored: Wed Oct 17 18:29:56 2018 +0200
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Wed Oct 17 18:29:56 2018 +0200

----------------------------------------------------------------------
 .../sysml/hops/estim/EstimatorSample.java       | 36 ++++++++++++++++----
 1 file changed, 29 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ca24ec56/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
index faf7d0e..ec624f0 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
@@ -37,22 +37,29 @@ import org.apache.sysml.runtime.util.UtilFunctions;
  * The basic idea is to draw random samples of aligned columns SA and rows SB,
  * and compute the output nnz as max(nnz(SA_i)*nnz(SB_i)). However, this estimator is
  * biased toward underestimation as the maximum is unlikely sampled and collisions are
- * not accounted for.
+ * not accounted for. Accordingly, we also support an extended estimator that relies
+ * on similar ideas for element-wise addition as the other estimators.
  */
 public class EstimatorSample extends SparsityEstimator
 {
 	private static final double SAMPLE_FRACTION = 0.1; //10%
 	
 	private final double _frac;
+	private final boolean _extended;
 	
 	public EstimatorSample() {
-		this(SAMPLE_FRACTION);
+		this(SAMPLE_FRACTION, false);
 	}
 	
 	public EstimatorSample(double sampleFrac) {
+		this(sampleFrac, false);
+	}
+	
+	public EstimatorSample(double sampleFrac, boolean extended) {
 		if( sampleFrac < 0 || sampleFrac > 1.0 )
 			throw new DMLRuntimeException("Invalid sample fraction: "+sampleFrac);
 		_frac = sampleFrac;
+		_extended = extended;
 	}
 	
 	@Override
@@ -73,12 +80,27 @@ public class EstimatorSample extends SparsityEstimator
 				int k =  m1.getNumColumns();
 				int[] ix = UtilFunctions.getSortedSampleIndexes(
 					k, (int)Math.max(k*_frac, 1));
+				int p = ix.length;
 				int[] cnnz = computeColumnNnz(m1, ix);
-				long nnzOut = 0;
-				for(int i=0; i<ix.length; i++)
-					nnzOut = Math.max(nnzOut, cnnz[i] * m2.recomputeNonZeros(ix[i], ix[i]));
-				return OptimizerUtils.getSparsity( 
-					m1.getNumRows(), m2.getNumColumns(), nnzOut);
+				if( _extended ) {
+					double ml = (long)m1.getNumRows()*m2.getNumColumns();
+					double sumS = 0, prodS = 1;
+					for(int i=0; i<ix.length; i++) {
+						long rnnz = m2.recomputeNonZeros(ix[i], ix[i]);
+						double v = (double)cnnz[i] * rnnz /ml;
+						sumS += v;
+						prodS *= 1-v;
+					}
+					return 1-Math.pow(1-1d/p * sumS, k - p) * prodS;
+				}
+				else {
+					//biased sampling-based estimator
+					long nnzOut = 0;
+					for(int i=0; i<p; i++)
+						nnzOut = Math.max(nnzOut, cnnz[i] * m2.recomputeNonZeros(ix[i], ix[i]));
+					return OptimizerUtils.getSparsity( 
+						m1.getNumRows(), m2.getNumColumns(), nnzOut);
+				}
 			}
 			case MULT: {
 				int k = Math.max(m1.getNumColumns(), m1.getNumRows());