You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/07/17 11:09:37 UTC

[systemds] branch master updated: [SYSTEMDS-3066] CLA Spark Decompress

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 14bb8a5  [SYSTEMDS-3066] CLA Spark Decompress
14bb8a5 is described below

commit 14bb8a5ba37fa2a6b81028797f24223a54297fc7
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Sat Jul 17 13:05:33 2021 +0200

    [SYSTEMDS-3066] CLA Spark Decompress
    
    This commit adds and fixes Spark decompression.
    It also makes the compression sizes visible when the logging level is
    TRACE while compressing with Spark instructions (the CP compression
    instruction now logs its compression statistics at TRACE as well).
---
 .../runtime/instructions/SPInstructionParser.java  |  4 +++
 .../instructions/cp/CompressionCPInstruction.java  | 13 +++++++--
 .../spark/CompressionSPInstruction.java            | 34 ++++++++++++++++++++++
 .../spark/DeCompressionSPInstruction.java          | 12 ++++----
 .../sysds/utils/DMLCompressionStatistics.java      | 23 +++++++++++----
 .../compress/CompressInstructionRewrite.java       |  4 +--
 .../compress/configuration/CompressBase.java       |  4 +--
 .../compress/workload/WorkloadAlgorithmTest.java   | 20 ++++++++++---
 .../workload/SystemDS-config-compress-workload.xml |  3 +-
 .../compress/workload/WorkloadAnalysisMLogReg.dml  |  2 +-
 10 files changed, 93 insertions(+), 26 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 08a6998..d1eb4a7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -59,6 +59,7 @@ import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.CtableSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.CumulativeAggregateSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.DeCompressionSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.DnnSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.MapmmChainSPInstruction;
@@ -500,6 +501,9 @@ public class SPInstructionParser extends InstructionParser
 			case Compression:
 				return CompressionSPInstruction.parseInstruction(str);
 
+			case DeCompression:
+				return DeCompressionSPInstruction.parseInstruction(str);
+
 			case SpoofFused:
 				return SpoofSPInstruction.parseInstruction(str);
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
index 5ccbc41..b3acc26 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
@@ -19,8 +19,12 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
+import org.apache.sysds.runtime.compress.CompressionStatistics;
 import org.apache.sysds.runtime.compress.SingletonLookupHashMap;
 import org.apache.sysds.runtime.compress.workload.WTreeRoot;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -29,9 +33,11 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public class CompressionCPInstruction extends ComputationCPInstruction {
+	private static final Log LOG = LogFactory.getLog(CompressionCPInstruction.class.getName());
 
 	private final int _singletonLookupID;
 
+
 	private CompressionCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr,
 		int singletonLookupID) {
 		super(CPType.Compression, op, in, null, null, out, opcode, istr);
@@ -61,9 +67,12 @@ public class CompressionCPInstruction extends ComputationCPInstruction {
 
 		WTreeRoot root = (_singletonLookupID != 0) ? (WTreeRoot) m.get(_singletonLookupID) : null;
 		// Compress the matrix block
-		MatrixBlock out = CompressedMatrixBlockFactory.compress(in, OptimizerUtils.getConstrainedNumThreads(-1), root)
-			.getLeft();
+		Pair<MatrixBlock, CompressionStatistics> compResult = CompressedMatrixBlockFactory.compress(in, OptimizerUtils.getConstrainedNumThreads(-1), root);
 
+		if(LOG.isTraceEnabled())
+			LOG.trace(compResult.getRight());
+		MatrixBlock out = compResult.getLeft();
+		
 		m.removeKey(_singletonLookupID);
 		// Set output and release input
 		ec.releaseMatrixInput(input1.getName());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java
index 64809cf..e6b62ee 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java
@@ -19,6 +19,10 @@
 
 package org.apache.sysds.runtime.instructions.spark;
 
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
@@ -35,7 +39,10 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
+import scala.Tuple2;
+
 public class CompressionSPInstruction extends UnarySPInstruction {
+	private static final Log LOG = LogFactory.getLog(CompressionSPInstruction.class.getName());
 
 	private final int _singletonLookupID;
 
@@ -79,6 +86,11 @@ public class CompressionSPInstruction extends UnarySPInstruction {
 
 		// execute compression
 		JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapValues(mappingFunction);
+		if(LOG.isTraceEnabled()) {
+			// Trace only: the collect() calls run extra Spark jobs to report compressed vs. uncompressed sizes.
+			LOG.trace("\nSpark compressed    : " + reduceSizes(out.mapValues(new SizeFunction()).collect())
+				+ "\nSpark uncompressed  : " + reduceSizes(in.mapValues(new SizeFunction()).collect()));
+		}
 
 		// set outputs
 		sec.setRDDHandleForVariable(output.getName(), out);
@@ -110,4 +122,26 @@ public class CompressionSPInstruction extends UnarySPInstruction {
 				.getLeft();
 		}
 	}
+
+	public static class SizeFunction implements Function<MatrixBlock, Double> {
+		private static final long serialVersionUID = 1L;
+
+		public SizeFunction() {
+			// Stateless: there are no fields besides serialVersionUID to initialize.
+		}
+		// Maps a block to arg0.getInMemorySize() as a Double so reduceSizes can sum it.
+		@Override
+		public Double call(MatrixBlock arg0) throws Exception {
+			return (double) arg0.getInMemorySize();
+		}
+	}
+
+	public static String reduceSizes(List<Tuple2<MatrixIndexes, Double>> in) {
+		double sum = 0;
+		for(Tuple2<MatrixIndexes, Double> e : in)
+			sum += e._2();
+		// An empty block list would otherwise log a NaN mean (0/0).
+		final double mean = in.isEmpty() ? 0 : sum / in.size();
+		return "sum: " + sum + " mean: " + mean;
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/DeCompressionSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/DeCompressionSPInstruction.java
index bd64775..d002d55 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/DeCompressionSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/DeCompressionSPInstruction.java
@@ -27,10 +27,10 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.spark.CompressionSPInstruction.CompressionFunction;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.utils.DMLCompressionStatistics;
 
 public class DeCompressionSPInstruction extends UnarySPInstruction {
 
@@ -51,9 +51,10 @@ public class DeCompressionSPInstruction extends UnarySPInstruction {
 		// get input rdd handle
 		JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
 
-		// execute compression
-		JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapValues(new CompressionFunction());
+		// execute decompression
+		JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapValues(new DeCompressionFunction());
 
+		DMLCompressionStatistics.addDecompressSparkCount();
 		// set outputs
 		sec.setRDDHandleForVariable(output.getName(), out);
 		sec.addLineageRDD(input1.getName(), output.getName());
@@ -64,11 +65,10 @@ public class DeCompressionSPInstruction extends UnarySPInstruction {
 
 		@Override
 		public MatrixBlock call(MatrixBlock arg0) throws Exception {
-			if(arg0 instanceof CompressedMatrixBlock){
+			if(arg0 instanceof CompressedMatrixBlock) 
 				return ((CompressedMatrixBlock) arg0).decompress(OptimizerUtils.getConstrainedNumThreads(-1));
-			}else{
+			else 
 				return arg0;
-			}
 		}
 	}
 }
diff --git a/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java b/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java
index 0f7fda5..92130e8 100644
--- a/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java
@@ -33,6 +33,9 @@ public class DMLCompressionStatistics {
 	private static int DecompressMTCount = 0;
 	private static double DecompressMT = 0.0;
 
+	private static int DecompressSparkCount = 0;
+	private static int DecompressCacheCount = 0;
+
 	public static void reset() {
 		Phase0 = 0.0;
 		Phase1 = 0.0;
@@ -44,6 +47,8 @@ public class DMLCompressionStatistics {
 		DecompressST = 0.0;
 		DecompressMTCount = 0;
 		DecompressMT = 0.0;
+		DecompressSparkCount = 0;
+		DecompressCacheCount = 0;
 	}
 
 	public static boolean haveCompressed(){
@@ -85,12 +90,16 @@ public class DMLCompressionStatistics {
 		}
 	}
 
-	public static int getDecompressionCount() {
-		return DecompressMTCount;
+	public static void addDecompressSparkCount(){
+		DecompressSparkCount++;
 	}
 
-	public static int getDecompressionSTCount() {
-		return DecompressSTCount;
+	public static void addDecompressCacheCount(){
+		DecompressCacheCount++; // reported as the "Cache" column by display()
+	}
+
+	public static int getDecompressionCount() {
+		return DecompressMTCount + DecompressSTCount + DecompressSparkCount + DecompressCacheCount; // total over all decompression counters
 	}
 
 	public static void display(StringBuilder sb) {
@@ -102,9 +111,11 @@ public class DMLCompressionStatistics {
 				Phase3 / 1000,
 				Phase4 / 1000,
 				Phase5 / 1000));
-			sb.append(String.format("Decompression Counts (Single , Multi) thread                     :\t%d/%d\n",
+			sb.append(String.format("Decompression Counts (Single , Multi, Spark, Cache) thread       :\t%d/%d/%d/%d\n",
 				DecompressSTCount,
-				DecompressMTCount));
+				DecompressMTCount,
+				DecompressSparkCount,
+				DecompressCacheCount));
 			sb.append(String.format("Dedicated Decompression Time (Single , Multi) thread             :\t%.3f/%.3f\n",
 				DecompressST / 1000,
 				DecompressMT / 1000));
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/CompressInstructionRewrite.java b/src/test/java/org/apache/sysds/test/functions/compress/CompressInstructionRewrite.java
index ccc6d79..c005115 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/CompressInstructionRewrite.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/CompressInstructionRewrite.java
@@ -127,9 +127,7 @@ public class CompressInstructionRewrite extends AutomatedTestBase {
 			if(LOG.isDebugEnabled())
 				LOG.debug(stdout);
 
-			int decompressCount = 0;
-			decompressCount += DMLCompressionStatistics.getDecompressionCount();
-			decompressCount += DMLCompressionStatistics.getDecompressionSTCount();
+			int decompressCount = DMLCompressionStatistics.getDecompressionCount();
 			long compressionCount = Statistics.getCPHeavyHitterCount("compress");
 
 			Assert.assertEquals(compressionCountsExpected, compressionCount);
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java
index 07b0441..5e1f3f5 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java
@@ -70,9 +70,7 @@ public abstract class CompressBase extends AutomatedTestBase {
 
 			LOG.debug(runTest(null));
 
-			int decompressCount = 0;
-			decompressCount += DMLCompressionStatistics.getDecompressionCount();
-			decompressCount += DMLCompressionStatistics.getDecompressionSTCount();
+			int decompressCount = DMLCompressionStatistics.getDecompressionCount();
 			long compressionCount = (instType == ExecType.SPARK) ? Statistics
 				.getCPHeavyHitterCount("sp_compress") : Statistics.getCPHeavyHitterCount("compress");
 			DMLCompressionStatistics.reset();
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index e94a4ab..c257a57 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -55,12 +55,23 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 		runWorkloadAnalysisTest(TEST_NAME1, ExecMode.HYBRID, 2);
 	}
 
+
+	@Test
+	public void testLmSP() {
+		runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SPARK, 2);
+	}
+
 	@Test
 	public void testLmCP() {
 		runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2);
 	}
 
 	@Test
+	public void testPCASP() {
+		runWorkloadAnalysisTest(TEST_NAME3, ExecMode.SPARK, 1);
+	}
+
+	@Test
 	public void testPCACP() {
 		runWorkloadAnalysisTest(TEST_NAME3, ExecMode.HYBRID, 1);
 	}
@@ -85,18 +96,19 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 			writeInputMatrixWithMTD("y", y, false);
 
 			String ret = runTest(null).toString();
-
 			if(ret.contains("ERROR:"))
 				fail(ret);
 
 			// check various additional expectations
-			long actualCompressionCount = Statistics.getCPHeavyHitterCount("compress");
+			long actualCompressionCount = mode == ExecMode.HYBRID ? Statistics
+				.getCPHeavyHitterCount("compress") : Statistics.getCPHeavyHitterCount("sp_compress");
+
 			Assert.assertEquals(compressionCount, actualCompressionCount);
-			Assert.assertTrue(heavyHittersContainsString("compress"));
+			Assert.assertTrue( mode == ExecMode.HYBRID ? heavyHittersContainsString("compress") : heavyHittersContainsString("sp_compress"));
 			Assert.assertFalse(heavyHittersContainsString("m_scale"));
 
 		}
-		catch(Exception e){
+		catch(Exception e) {
 			resetExecMode(oldPlatform);
 			fail("Failed workload test");
 		}
diff --git a/src/test/scripts/functions/compress/workload/SystemDS-config-compress-workload.xml b/src/test/scripts/functions/compress/workload/SystemDS-config-compress-workload.xml
index 4e735c6..ed2ab68 100644
--- a/src/test/scripts/functions/compress/workload/SystemDS-config-compress-workload.xml
+++ b/src/test/scripts/functions/compress/workload/SystemDS-config-compress-workload.xml
@@ -19,6 +19,7 @@
 
 <root>
 	<sysds.compressed.linalg>workload</sysds.compressed.linalg>
+	<sysds.defaultblocksize>8000</sysds.defaultblocksize>
 	<sysds.cp.parallel.ops>true</sysds.cp.parallel.ops>
 	<sysds.scratch>target/force_comp_scratch_space</sysds.scratch>
-</root>
+</root>
\ No newline at end of file
diff --git a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
index 78e62c1..77b2959 100644
--- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
+++ b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
@@ -27,7 +27,7 @@ print("")
 print("MLogReg")
 
 X = scale(X=X, scale=TRUE, center=TRUE);
-B = multiLogReg(X=X, Y=Y, verbose=TRUE, maxi = 10, maxii=10);
+B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2);
 
 [nn, P, acc] = multiLogRegPredict(X=X, B=B, Y=Y)