You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/08/26 19:29:21 UTC
[systemds] branch master updated: [SYSTEMDS-3106] Fix performance
dense-sparse matrix multiplication
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 25f99b7 [SYSTEMDS-3106] Fix performance dense-sparse matrix multiplication
25f99b7 is described below
commit 25f99b76db6d53db82555c638144a799d6efade5
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Thu Aug 26 21:18:44 2021 +0200
[SYSTEMDS-3106] Fix performance dense-sparse matrix multiplication
This patch improves the performance of dense-sparse matrix
multiplications with small dense left-hand sides by making the
dense-sparse kernel (in addition to the already working vector-sparse
case) amenable to parallelization over the rows of the right-hand side.
Additionally, this patch cleans up warnings in the compression package (e.g., missing serial version UIDs).
---
.../sysds/runtime/compress/colgroup/ColGroupFactory.java | 2 +-
.../runtime/compress/colgroup/dictionary/ADictionary.java | 2 ++
.../runtime/compress/colgroup/dictionary/Dictionary.java | 2 ++
.../compress/colgroup/dictionary/MatrixBlockDictionary.java | 2 ++
.../runtime/compress/colgroup/dictionary/QDictionary.java | 2 ++
.../sysds/runtime/compress/colgroup/mapping/AMapToData.java | 2 ++
.../sysds/runtime/compress/colgroup/mapping/MapToBit.java | 2 ++
.../sysds/runtime/compress/colgroup/mapping/MapToByte.java | 2 ++
.../sysds/runtime/compress/colgroup/mapping/MapToChar.java | 2 ++
.../sysds/runtime/compress/colgroup/mapping/MapToInt.java | 2 ++
.../sysds/runtime/compress/colgroup/offset/AOffset.java | 3 ++-
.../sysds/runtime/compress/colgroup/offset/OffsetByte.java | 2 ++
.../sysds/runtime/compress/colgroup/offset/OffsetChar.java | 2 ++
.../org/apache/sysds/runtime/matrix/data/LibMatrixMult.java | 13 +++++++++----
14 files changed, 34 insertions(+), 6 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index e62f8d5..1d7b97d 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -152,7 +152,7 @@ public final class ColGroupFactory {
@Override
public Collection<AColGroup> call() {
- ArrayList<AColGroup> res = new ArrayList<AColGroup>();
+ ArrayList<AColGroup> res = new ArrayList<>();
Tmp tmpMap = new Tmp();
for(CompressedSizeInfoColGroup g : _groups)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index 4405ece..15c74b0 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -34,6 +34,8 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
*/
public abstract class ADictionary implements Serializable {
+ private static final long serialVersionUID = 9118692576356558592L;
+
protected static final Log LOG = LogFactory.getLog(ADictionary.class.getName());
/**
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
index bad6620..5b9d834 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
@@ -42,6 +42,8 @@ import org.apache.sysds.utils.MemoryEstimates;
*/
public class Dictionary extends ADictionary {
+ private static final long serialVersionUID = -6517136537249507753L;
+
private final double[] _values;
public Dictionary(double[] values) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
index 73f1288..fb4d701 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
@@ -37,6 +37,8 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
public class MatrixBlockDictionary extends ADictionary {
+ private static final long serialVersionUID = 2535887782150955098L;
+
private MatrixBlock _data;
public MatrixBlockDictionary(MatrixBlock data) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
index 614c4b6..13f17cd 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
@@ -42,6 +42,8 @@ import org.apache.sysds.utils.MemoryEstimates;
*/
public class QDictionary extends ADictionary {
+ private static final long serialVersionUID = 2100501253343438897L;
+
protected double _scale;
protected byte[] _values;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java
index 735c0a9..4421efa 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java
@@ -28,6 +28,8 @@ import org.apache.commons.logging.LogFactory;
public abstract class AMapToData implements Serializable {
+ private static final long serialVersionUID = 100512759972844714L;
+
protected static final Log LOG = LogFactory.getLog(AMapToData.class.getName());
private int nUnique;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java
index 48abe9e..faedc37 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java
@@ -29,6 +29,8 @@ import org.apache.sysds.utils.MemoryEstimates;
public class MapToBit extends AMapToData {
+ private static final long serialVersionUID = -8065234231282619923L;
+
private final BitSet _data;
private final int _size;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java
index 8d651bd..d8ee6cb 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java
@@ -29,6 +29,8 @@ import org.apache.sysds.utils.MemoryEstimates;
public class MapToByte extends AMapToData {
+ private static final long serialVersionUID = -2498505439667351828L;
+
private final byte[] _data;
public MapToByte(int unique, int size) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java
index 9d6ad5f..f94232c 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java
@@ -29,6 +29,8 @@ import org.apache.sysds.utils.MemoryEstimates;
public class MapToChar extends AMapToData {
+ private static final long serialVersionUID = 6315708056775476541L;
+
private final char[] _data;
public MapToChar(int unique, int size) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java
index 56a3c67..513c2a7 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java
@@ -29,6 +29,8 @@ import org.apache.sysds.utils.MemoryEstimates;
public class MapToInt extends AMapToData {
+ private static final long serialVersionUID = -5557070920888782274L;
+
private final int[] _data;
public MapToInt(int unique, int size) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
index d21f2a1..7f42240 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
@@ -39,6 +39,7 @@ import org.apache.commons.logging.LogFactory;
*/
public abstract class AOffset implements Serializable {
+ private static final long serialVersionUID = -4143271285905723425L;
protected static final Log LOG = LogFactory.getLog(AOffset.class.getName());
protected SoftReference<Map<Integer, AIterator>> skipIterators;
@@ -89,7 +90,7 @@ public abstract class AOffset implements Serializable {
sk.put(row, it);
}
else {
- Map<Integer, AIterator> nsk = new HashMap<Integer, AIterator>();
+ Map<Integer, AIterator> nsk = new HashMap<>();
nsk.put(row, it.clone());
skipIterators = new SoftReference<>(nsk);
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
index aea8515..c89b673 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
@@ -28,6 +28,8 @@ import org.apache.sysds.utils.MemoryEstimates;
public class OffsetByte extends AOffset {
+ private static final long serialVersionUID = -4716104973912491790L;
+
private final static int maxV = 255;
private final byte[] offsets;
private final int offsetToFirst;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetChar.java
index f98539c..c1c2930 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetChar.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetChar.java
@@ -28,6 +28,8 @@ import org.apache.sysds.utils.MemoryEstimates;
public class OffsetChar extends AOffset {
+ private static final long serialVersionUID = -1192266421395964882L;
+
private final static int maxV = (int) Character.MAX_VALUE;
private final char[] offsets;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index dd88ec9..a503085 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -1154,10 +1154,15 @@ public class LibMatrixMult
final int blocksizeK = 32;
final int blocksizeI = 32;
+ int rl1 = pm2 ? 0 : rl;
+ int ru1 = pm2 ? m : ru;
+ int rl2 = pm2 ? rl : 0;
+ int ru2 = pm2 ? ru : cd;
+
//blocked execution
- for( int bi = rl; bi < ru; bi+=blocksizeI )
- for( int bk = 0, bimin = Math.min(ru, bi+blocksizeI); bk < cd; bk+=blocksizeK ) {
- int bkmin = Math.min(cd, bk+blocksizeK);
+ for( int bi = rl1; bi < ru1; bi+=blocksizeI )
+ for( int bk = rl2, bimin = Math.min(ru1, bi+blocksizeI); bk < ru2; bk+=blocksizeK ) {
+ int bkmin = Math.min(ru2, bk+blocksizeK);
//core sub block matrix multiplication
for(int i = bi; i < bimin; i++) {
double[] avals = a.values(i), cvals = c.values(i);
@@ -3883,7 +3888,7 @@ public class LibMatrixMult
double jvmMem = InfrastructureAnalyzer.getLocalMaxMemory();
return (m1.rlen==1 && LOW_LEVEL_OPTIMIZATION && m2.clen>1 && !(m1.isUltraSparse()||m2.isUltraSparse()))
|| (m1.rlen<=16 && LOW_LEVEL_OPTIMIZATION && m2.clen>1 && m2.rlen > m1.rlen
- && ( !m1.isUltraSparse() && !m2.sparse ) //dense-dense / sparse/dense
+ && ( !m1.isUltraSparse() && !(m1.sparse & m2.sparse) ) //dense-dense / sparse-dense / dense-sparse
&& (long)k * 8 * m1.rlen * m2.clen < Math.max(MEM_OVERHEAD_THRESHOLD,0.01*jvmMem) );
}