You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2022/05/25 06:40:18 UTC

[systemds] branch main updated: [SYSTEMDS-3367] Integrate UDF encoders in task graph

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 038864e1ba [SYSTEMDS-3367] Integrate UDF encoders in task graph
038864e1ba is described below

commit 038864e1ba3bf0f288b426a25e2b0da9f7bcf928
Author: arnabp <ar...@tugraz.at>
AuthorDate: Wed May 25 10:26:35 2022 +0530

    [SYSTEMDS-3367] Integrate UDF encoders in task graph
    
    This patch integrates UDF-based transform encoders into the
    task graph to allow concurrent execution. As we cannot estimate
    the sparsity of an arbitrary UDF output, we always allocate a
    dense output matrix if at least one UDF is present. If a UDF
    comes after dummy coding, we slice the expanded columns and
    apply the UDF to them. Moreover, we disable row partitioning
    of the apply phase for UDFs.
    
    Closes #1623
---
 .../transform/encode/ColumnEncoderComposite.java   |  8 +++-
 .../transform/encode/ColumnEncoderFeatureHash.java |  7 +++-
 .../runtime/transform/encode/ColumnEncoderUDF.java | 47 ++++++++++++++++++----
 .../transform/encode/MultiColumnEncoder.java       | 17 ++++----
 .../sysds/utils/stats/TransformStatistics.java     | 13 +++++-
 5 files changed, 71 insertions(+), 21 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 243bfe7caa..608b1de3f1 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -269,6 +269,9 @@ public class ColumnEncoderComposite extends ColumnEncoder {
 		ColumnEncoderDummycode dc = getEncoder(ColumnEncoderDummycode.class);
 		if(dc != null)
 			dc.updateDomainSizes(_columnEncoders);
+		ColumnEncoderUDF udf = getEncoder(ColumnEncoderUDF.class);
+		if (udf != null && dc != null)
+			udf.updateDomainSizes(_columnEncoders);
 	}
 
 	public void addEncoder(ColumnEncoder other) {
@@ -385,7 +388,10 @@ public class ColumnEncoderComposite extends ColumnEncoder {
 	public void setNumPartitions(int nBuild, int nApply) {
 			_columnEncoders.forEach(e -> {
 				e.setBuildRowBlocksPerColumn(nBuild);
-				e.setApplyRowBlocksPerColumn(nApply);
+				if (e.getClass().equals(ColumnEncoderUDF.class))
+					e.setApplyRowBlocksPerColumn(1);
+				else
+					e.setApplyRowBlocksPerColumn(nApply);
 			});
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
index 161959441b..fec65dd93d 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
@@ -83,8 +83,11 @@ public class ColumnEncoderFeatureHash extends ColumnEncoder {
 				codes[i-startInd] = Double.NaN;
 			else {
 				// Calculate non-negative modulo
-				double mod = key.hashCode() % _K > 0 ? key.hashCode() % _K : _K + key.hashCode() % _K;
-				codes[i - startInd] = mod + 1;
+				//double mod = key.hashCode() % _K > 0 ? key.hashCode() % _K : _K + key.hashCode() % _K;
+				double mod = (key.hashCode() % _K) + 1;
+				if (mod < 0)
+					mod += _K;
+				codes[i - startInd] = mod;
 			}
 		}
 		return codes;
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
index a3f76623f2..00e588ee17 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.transform.encode;
 
 import java.util.List;
 
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.parser.DMLProgram;
@@ -28,7 +29,6 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.Program;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
@@ -39,15 +39,16 @@ import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.DependencyTask;
+import org.apache.sysds.utils.stats.TransformStatistics;
 
 public class ColumnEncoderUDF extends ColumnEncoder {
 
 	//TODO pass execution context through encoder factory for arbitrary functions not just builtin
-	//TODO handling udf after dummy coding
 	//TODO integration into IPA to ensure existence of unoptimized functions
 	
 	private final String _fName;
-	
+	public int _domainSize = 1;
+
 	protected ColumnEncoderUDF(int ptCols, String name) {
 		super(ptCols); // 1-based
 		_fName = name;
@@ -73,10 +74,12 @@ public class ColumnEncoderUDF extends ColumnEncoder {
 	}
 	
 	@Override
-	public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
+	public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
+		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
 		//create execution context and input
 		ExecutionContext ec = ExecutionContextFactory.createContext(new Program(new DMLProgram()));
-		MatrixBlock col = out.slice(0, in.getNumRows()-1, _colID-1, _colID-1, new MatrixBlock());
+		//MatrixBlock col = out.slice(0, in.getNumRows()-1, _colID-1, _colID-1, new MatrixBlock());
+		MatrixBlock col = out.slice(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, new MatrixBlock());
 		ec.setVariable("I", new ListObject(new Data[] {ParamservUtils.newMatrixObject(col, true)}));
 		ec.setVariable("O", ParamservUtils.newMatrixObject(col, true));
 		
@@ -87,11 +90,39 @@ public class ColumnEncoderUDF extends ColumnEncoder {
 				new CPOperand(_fName, ValueType.STRING, DataType.SCALAR, true),
 				new CPOperand("I", ValueType.UNKNOWN, DataType.LIST)});
 		fun.processInstruction(ec);
-		
+
 		//obtain result and in-place write back
 		MatrixBlock ret = ((MatrixObject)ec.getCacheableData("O")).acquireReadAndRelease();
-		out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1, _colID-1, ret, UpdateType.INPLACE);
-		return out;
+		//out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1, _colID-1, ret, UpdateType.INPLACE);
+		//out.leftIndexingOperations(ret, 0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, UpdateType.INPLACE);
+		//out.copy(0, in.getNumRows()-1, _colID-1, _colID-1, ret, true);
+		out.copy(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, true);
+
+		if (DMLScript.STATISTICS)
+			TransformStatistics.incUDFApplyTime(System.nanoTime() - t0);
+	}
+
+	public void updateDomainSizes(List<ColumnEncoder> columnEncoders) {
+		if(_colID == -1)
+			return;
+		for(ColumnEncoder columnEncoder : columnEncoders) {
+			int distinct = -1;
+			if(columnEncoder instanceof ColumnEncoderRecode) {
+				ColumnEncoderRecode columnEncoderRecode = (ColumnEncoderRecode) columnEncoder;
+				distinct = columnEncoderRecode.getNumDistinctValues();
+			}
+			else if(columnEncoder instanceof ColumnEncoderBin) {
+				distinct = ((ColumnEncoderBin) columnEncoder)._numBin;
+			}
+			else if(columnEncoder instanceof ColumnEncoderFeatureHash){
+				distinct = (int) ((ColumnEncoderFeatureHash) columnEncoder).getK();
+			}
+
+			if(distinct != -1) {
+				_domainSize = distinct;
+				LOG.debug("DummyCoder for column: " + _colID + " has domain size: " + _domainSize);
+			}
+		}
 	}
 	
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index d84d00e531..07d1aa7f2a 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -314,11 +314,12 @@ public class MultiColumnEncoder implements Encoder {
 
 	public MatrixBlock apply(CacheBlock in, int k) {
 		// domain sizes are not updated if called from transformapply
+		boolean hasUDF = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class));
 		for(ColumnEncoderComposite columnEncoder : _columnEncoders)
 			columnEncoder.updateAllDCEncoders();
 		int numCols = in.getNumColumns() + getNumExtraCols();
-		long estNNz = (long) in.getNumColumns() * (long) in.getNumRows();
-		boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz);
+		long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : (long) in.getNumColumns());
+		boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && !hasUDF;
 		MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, sparse, estNNz);
 		return apply(in, out, 0, k);
 	}
@@ -379,16 +380,15 @@ public class MultiColumnEncoder implements Encoder {
 	private void applyMT(CacheBlock in, MatrixBlock out, int outputCol, int k) {
 		DependencyThreadPool pool = new DependencyThreadPool(k);
 		try {
-			if(APPLY_ENCODER_SEPARATE_STAGES){
+			if(APPLY_ENCODER_SEPARATE_STAGES) {
 				int offset = outputCol;
 				for (ColumnEncoderComposite e : _columnEncoders) {
 					pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset));
 					if (e.hasEncoder(ColumnEncoderDummycode.class))
 						offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
 				}
-			}else{
+			} else
 				pool.submitAllAndWait(getApplyTasks(in, out, outputCol));
-			}
 		}
 		catch(ExecutionException | InterruptedException e) {
 			LOG.error("MT Column apply failed");
@@ -455,7 +455,7 @@ public class MultiColumnEncoder implements Encoder {
 			long memBudget = (long) (OptimizerUtils.getLocalMemBudget() - in.getInMemorySize());
 			// Worst case scenario: all partial maps contain all distinct values (if < #rows)
 			long totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders);
-			// Reduce recode build blocks count till they fit int the memory budget
+			// Reduce recode build blocks count till they fit in the memory budget
 			while (rcdNumBuildBlks > 1 && totMemOverhead > memBudget) {
 				rcdNumBuildBlks--;
 				totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders);
@@ -1078,10 +1078,11 @@ public class MultiColumnEncoder implements Encoder {
 
 		@Override
 		public Object call() throws Exception {
+			boolean hasUDF = _encoder.getColumnEncoders().stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class));
 			int numCols = _input.getNumColumns() + _encoder.getNumExtraCols();
 			boolean hasDC = _encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
-			long estNNz = (long) _input.getNumColumns() * (long) _input.getNumRows();
-			boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz);
+			long estNNz = (long) _input.getNumRows() * (hasUDF ? numCols : (long) _input.getNumColumns());
+			boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && !hasUDF;
 			_output.reset(_input.getNumRows(), numCols, sparse, estNNz);
 			outputMatrixPreProcessing(_output, _input, hasDC);
 			return null;
diff --git a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
index 05f06b065c..b7779e4ee1 100644
--- a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
@@ -35,6 +35,7 @@ public class TransformStatistics {
 	private static final LongAdder passThroughApplyTime = new LongAdder();
 	private static final LongAdder featureHashingApplyTime = new LongAdder();
 	private static final LongAdder binningApplyTime = new LongAdder();
+	private static final LongAdder UDFApplyTime = new LongAdder();
 	private static final LongAdder omitApplyTime = new LongAdder();
 	private static final LongAdder imputeApplyTime = new LongAdder();
 
@@ -58,6 +59,10 @@ public class TransformStatistics {
 		binningApplyTime.add(t);
 	}
 
+	public static void incUDFApplyTime(long t) {
+		UDFApplyTime.add(t);
+	}
+
 	public static void incPassThroughApplyTime(long t) {
 		passThroughApplyTime.add(t);
 	}
@@ -106,8 +111,8 @@ public class TransformStatistics {
 	public static long getEncodeApplyTime() {
 		return dummyCodeApplyTime.longValue() + binningApplyTime.longValue() +
 				featureHashingApplyTime.longValue() + passThroughApplyTime.longValue() +
-				recodeApplyTime.longValue() + omitApplyTime.longValue() +
-				imputeApplyTime.longValue();
+				recodeApplyTime.longValue() + UDFApplyTime.longValue() +
+				omitApplyTime.longValue() + imputeApplyTime.longValue();
 	}
 
 	public static void reset() {
@@ -122,6 +127,7 @@ public class TransformStatistics {
 		passThroughApplyTime.reset();
 		featureHashingApplyTime.reset();
 		binningApplyTime.reset();
+		UDFApplyTime.reset();
 		omitApplyTime.reset();
 		imputeApplyTime.reset();
 		outMatrixPreProcessingTime.reset();
@@ -163,6 +169,9 @@ public class TransformStatistics {
 			if(passThroughApplyTime.longValue() > 0)
 				sb.append("\tPassThrough apply time:\t").append(String.format("%.3f",
 					passThroughApplyTime.longValue()*1e-9)).append(" sec.\n");
+			if(UDFApplyTime.longValue() > 0)
+				sb.append("\tUDF apply time:\t").append(String.format("%.3f",
+					UDFApplyTime.longValue()*1e-9)).append(" sec.\n");
 			if(omitApplyTime.longValue() > 0)
 				sb.append("\tOmit apply time:\t").append(String.format("%.3f",
 					omitApplyTime.longValue()*1e-9)).append(" sec.\n");