You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2022/02/26 15:41:49 UTC

[systemds] branch main updated: [SYSTEMDS-3298] Performance ultra-sparse ctableexpand via CSR outputs

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f541900  [SYSTEMDS-3298] Performance ultra-sparse ctableexpand via CSR outputs
f541900 is described below

commit f54190013cdd151653377a49c45c07065c3ebb80
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Feb 26 16:41:24 2022 +0100

    [SYSTEMDS-3298] Performance ultra-sparse ctableexpand via CSR outputs
    
    This patch continues with further performance improvements for the NLP
    embedding use case of encoded abstracts with concatenated word
    embeddings, padded to the maximum abstract length (transformapply,
    ctable, ba*+, reshape). For ctableexpand such as table(seq, vect), we now
    directly output the ultra-sparse result in CSR.
    
    For 10 iterations of 20K abstracts, max 1000 words/abstract,
    2.5M distinct words, and 300dim word embeddings, this patch improved
    performance as follows: 10x ctable: 13.6s->2.4s (82s->67s overall).
---
 .../runtime/compress/colgroup/ColGroupDeltaDDC.java  |  2 ++
 .../colgroup/dictionary/DeltaDictionary.java         |  3 ++-
 .../runtime/controlprogram/WhileProgramBlock.java    |  4 ----
 .../apache/sysds/runtime/data/SparseBlockCSR.java    | 18 ++++++++++++++++++
 .../apache/sysds/runtime/functionobjects/CTable.java |  5 +++--
 .../sysds/runtime/matrix/data/MatrixBlock.java       | 20 ++++++++++++++++----
 .../transform/encode/ColumnEncoderPassThrough.java   |  2 --
 .../org/apache/sysds/runtime/util/UtilFunctions.java |  4 +++-
 .../functions/pipelines/BuiltinTopkEvaluateTest.java |  1 +
 9 files changed, 45 insertions(+), 14 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDeltaDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDeltaDDC.java
index 0949dae..56b945e 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDeltaDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDeltaDDC.java
@@ -32,6 +32,8 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
  */
 public class ColGroupDeltaDDC extends ColGroupDDC {
 
+	private static final long serialVersionUID = -1045556313148564147L;
+
 	/**
 	 * Constructor for serialization
 	 *
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
index eb0af85..ba38772 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java
@@ -31,7 +31,8 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
  * encoded values are implemented.
  */
 public class DeltaDictionary extends Dictionary {
-
+	private static final long serialVersionUID = -5700139221491143705L;
+	
 	private final int _numCols;
 
 	public DeltaDictionary(double[] values, int numCols) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
index 4695b94..7dea91b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
@@ -20,12 +20,8 @@
 package org.apache.sysds.runtime.controlprogram;
 
 import java.util.ArrayList;
-import java.util.List;
 
 import org.apache.sysds.hops.Hop;
-import org.apache.sysds.parser.ForStatement;
-import org.apache.sysds.parser.Statement;
-import org.apache.sysds.parser.StatementBlock;
 import org.apache.sysds.parser.WhileStatementBlock;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ValueType;
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
index d4921e9..4093ed9 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
@@ -334,6 +334,24 @@ public class SparseBlockCSR extends SparseBlock
 	public void compact(int r) {
 		//do nothing everything preallocated
 	}
+	
+	public void compact() {
+		int pos = 0;
+		for(int i=0; i<numRows(); i++) {
+			int apos = pos(i);
+			int alen = size(i);
+			_ptr[i] = pos;
+			for(int j=apos; j<apos+alen; j++) {
+				if( _values[j] != 0 ){
+					_values[pos] = _values[j];
+					_indexes[pos] = _indexes[j];
+					pos++;
+				}
+			}
+		}
+		_ptr[numRows()] = pos;
+		_size = pos; //adjust logical size
+	}
 
 	@Override
 	public int numRows() {
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java b/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
index 291effa..0437a8c 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
@@ -115,7 +115,7 @@ public class CTable extends ValueFunction
 				ctableResult.quickGetValue((int)row-1, (int)col-1) + w);
 	}
 
-	public int execute(int row, double v2, double w, int maxCol, MatrixBlock ctableResult) 
+	public int execute(int row, double v2, double w, int maxCol, int[] retIx, double[] retVals) 
 	{
 		// If any of the values are NaN (i.e., missing) then 
 		// we skip this tuple, proceed to the next tuple
@@ -130,7 +130,8 @@ public class CTable extends ValueFunction
 		} 
 		
 		//set weight as value (expand is guaranteed to address different cells)
-		ctableResult.quickSetValue(row-1, col-1, w);
+		retIx[row - 1] = col - 1;
+		retVals[row - 1] = w;
 		
 		//maintain max seen col 
 		return Math.max(maxCol, col);
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index bc7f05f..863775d 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -5465,11 +5465,16 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 	 * @param updateClen when this matrix already has the desired number of columns updateClen can be set to false
 	 * @return resultBlock
 	 */
-	public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar, MatrixBlock resultBlock, boolean updateClen) {
+	public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar, MatrixBlock ret, boolean updateClen) {
 		MatrixBlock that = checkType(thatMatrix);
 		CTable ctable = CTable.getCTableFnObject();
 		double w = thatScalar;
 		
+		//prepare allocation of CSR sparse block
+		int[] rptr = new int[rlen+1];
+		int[] indexes = new int[rlen];
+		double[] values = new double[rlen];
+		
 		//sparse-unsafe ctable execution
 		//(because input values of 0 are invalid and have to result in errors)
 		//resultBlock guaranteed to be allocated for ctableexpand
@@ -5477,16 +5482,23 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		int maxCol = 0;
 		for( int i=0; i<rlen; i++ ) {
 			double v2 = that.quickGetValue(i, 0);
-			maxCol = ctable.execute(i+1, v2, w, maxCol, resultBlock);
+			maxCol = ctable.execute(i+1, v2, w, maxCol, indexes, values);
+			rptr[i] = i;
 		}
+		rptr[rlen] = rlen;
 
+		//construct sparse CSR block from filled arrays
+		ret.sparseBlock = new SparseBlockCSR(rptr, indexes, values, rlen);
+		((SparseBlockCSR)ret.sparseBlock).compact();
+		ret.setNonZeros(ret.sparseBlock.size());
+		
 		//update meta data (initially unknown number of columns)
 		//note: nnz maintained in ctable (via quickset)
 		if(updateClen) {
-			resultBlock.clen = maxCol;
+			ret.clen = maxCol;
 		}
 
-		return resultBlock;
+		return ret;
 	}
 
 	/**
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
index a134dfd..04df3d5 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
@@ -22,9 +22,7 @@ package org.apache.sysds.runtime.transform.encode;
 import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
 
 import java.util.ArrayList;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Set;
 
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index 6ca83be..17db8f6 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -878,6 +878,7 @@ public class UtilFunctions {
 			.map(DATE_FORMATS::get).orElseThrow(() -> new NullPointerException("Unknown date format."));
 	}
 
+	@SuppressWarnings("unused")
 	private static int findDateCol (FrameBlock block) {
 		int cols = block.getNumColumns();
 		int[] match_counter = new int[cols];
@@ -946,7 +947,8 @@ public class UtilFunctions {
 				if (!currentFormat.equals(dominantFormat)){
 					curr.applyPattern(dominantFormat);
 				}
-				String newDate = curr.format(date); //convert date to dominant date format
+				//FIME: unused newDate
+				//String newDate = curr.format(date); //convert date to dominant date format
 				output[i] =  curr.format(date); //convert back to datestring
 			} catch (ParseException e) {
 				throw new DMLRuntimeException(e);
diff --git a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
index 71160b7..7034335 100644
--- a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
@@ -46,6 +46,7 @@ public class BuiltinTopkEvaluateTest extends AutomatedTestBase {
 
 	//TODO: debug test failure in git actions
 	@Ignore
+	@Test
 	public void testEvalPipClass() {
 		evalPip(0.8, "FALSE", INPUT+"/classification/", Types.ExecMode.SINGLE_NODE);
 	}