You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2023/05/17 10:39:50 UTC
[systemds] branch main updated: [SYSTEMDS-3490] Compressed Transform Tests
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new d0597c0f62 [SYSTEMDS-3490] Compressed Transform Tests
d0597c0f62 is described below
commit d0597c0f62ca35dc6f99235bd7cfffa2421c6ab4
Author: baunsgaard <ba...@tu-berlin.de>
AuthorDate: Wed May 17 10:24:38 2023 +0200
[SYSTEMDS-3490] Compressed Transform Tests
This commit updates the compressed tests to reach 100% coverage and fixes
some edge cases in binning and hashing.
Closes #1826
---
.../sysds/runtime/compress/lib/CLALibUtils.java | 11 +---
.../runtime/transform/encode/ColumnEncoderBin.java | 15 ++---
.../runtime/transform/encode/CompressedEncode.java | 66 +++++++++-------------
.../runtime/transform/encode/EncoderFactory.java | 4 +-
src/test/java/org/apache/sysds/test/TestUtils.java | 33 +----------
.../transform/TransformCompressedTestMultiCol.java | 51 ++++++++++-------
.../TransformCompressedTestSingleCol.java | 30 +++++++---
.../frame/transform/TransformCustomTest.java | 22 ++++++++
8 files changed, 112 insertions(+), 120 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
index 3e6837d490..1dfe2a0575 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
@@ -20,7 +20,6 @@
package org.apache.sysds.runtime.compress.lib;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -112,7 +111,7 @@ public final class CLALibUtils {
else
filteredGroups.add(g);
}
- return returnGroupIfFiniteNumbers(groups, filteredGroups, constV);
+ return filteredGroups;
}
protected static void filterGroupsAndSplitPreAgg(List<AColGroup> groups, double[] constV,
@@ -150,14 +149,6 @@ public final class CLALibUtils {
}
}
- private static List<AColGroup> returnGroupIfFiniteNumbers(List<AColGroup> groups, List<AColGroup> filteredGroups,
- double[] constV) {
- for(double v : constV)
- if(!Double.isFinite(v))
- throw new NotImplementedException("Not handling if the values are not finite: " + Arrays.toString(constV));
- return filteredGroups;
- }
-
private static AColGroup combineEmpty(List<AColGroup> e) {
return new ColGroupEmpty(combineColIndexes(e));
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index 895141db07..b2c530a3e4 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -43,10 +43,6 @@ public class ColumnEncoderBin extends ColumnEncoder {
public static final String NBINS_PREFIX = "nbins";
private static final long serialVersionUID = 1917445005206076078L;
- public int getNumBin() {
- return _numBin;
- }
-
protected int _numBin = -1;
private BinMethod _binMethod = BinMethod.EQUI_WIDTH;
@@ -75,6 +71,10 @@ public class ColumnEncoderBin extends ColumnEncoder {
_binMaxs = binMaxs;
}
+ public int getNumBin() {
+ return _numBin;
+ }
+
public double getColMins() {
return _colMins;
}
@@ -404,15 +404,8 @@ public class ColumnEncoderBin extends ColumnEncoder {
sb.append(": ");
sb.append(_colID);
sb.append(" --- Method: " + _binMethod + " num Bin: " + _numBin);
- // if(_binMethod == BinMethod.EQUI_WIDTH) {
sb.append("\n---- BinMin: " + Arrays.toString(_binMins));
sb.append("\n---- BinMax: " + Arrays.toString(_binMaxs));
- // }
- // else {
- // // sb.append(" --- MinMax: "+ _colMins + " " + _colMaxs);
- // sb.append("\n---- BinMin: " + Arrays.toString(_binMins));
- // sb.append("\n---- BinMax: " + Arrays.toString(_binMaxs));
- // }
return sb.toString();
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
index 63eb81e008..150133c469 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
@@ -33,11 +33,11 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
@@ -69,11 +69,12 @@ public class CompressedEncode {
this.k = k;
}
- public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in, int k) {
+ public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in, int k)
+ throws InterruptedException, ExecutionException {
return new CompressedEncode(enc, in, k).apply();
}
- private MatrixBlock apply() {
+ private MatrixBlock apply() throws InterruptedException, ExecutionException {
final List<ColumnEncoderComposite> encoders = enc.getColumnEncoders();
final List<AColGroup> groups = isParallel() ? multiThread(encoders) : singleThread(encoders);
final int cols = shiftGroups(groups);
@@ -94,7 +95,8 @@ public class CompressedEncode {
return groups;
}
- private List<AColGroup> multiThread(List<ColumnEncoderComposite> encoders) {
+ private List<AColGroup> multiThread(List<ColumnEncoderComposite> encoders)
+ throws InterruptedException, ExecutionException {
final ExecutorService pool = CommonThreadPool.get(k);
try {
@@ -106,13 +108,10 @@ public class CompressedEncode {
List<AColGroup> groups = new ArrayList<>(encoders.size());
for(Future<AColGroup> t : pool.invokeAll(tasks))
groups.add(t.get());
-
- pool.shutdown();
return groups;
}
- catch(InterruptedException | ExecutionException ex) {
+ finally {
pool.shutdown();
- throw new DMLRuntimeException("Failed parallel compressed transform encode", ex);
}
}
@@ -157,8 +156,10 @@ public class CompressedEncode {
boolean containsNull = a.containsNull();
HashMap<?, Long> map = a.getRecodeMap();
int domain = map.size();
+ if(containsNull && domain == 0)
+ return new ColGroupEmpty(ColIndexFactory.create(1));
IColIndex colIndexes = ColIndexFactory.create(0, domain);
- if(domain == 1)
+ if(domain == 1 && !containsNull)
return ColGroupConst.create(colIndexes, new double[] {1});
ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull);
AMapToData m = createMappingAMapToData(a, map, containsNull);
@@ -180,12 +181,6 @@ public class CompressedEncode {
AMapToData m = binEncode(a, b, containsNull);
AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
- try {
- ret.getNumberNonZeros(a.size());
- }
- catch(Exception e) {
- throw new DMLRuntimeException("Failed binning \n\n" + a + "\n" + b + "\n" + d + "\n" + m, e);
- }
return ret;
}
@@ -230,7 +225,6 @@ public class CompressedEncode {
ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull);
AMapToData m = binEncode(a, b, containsNull);
AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
- ret.getNumberNonZeros(a.size());
return ret;
}
@@ -246,11 +240,11 @@ public class CompressedEncode {
IColIndex colIndexes = ColIndexFactory.create(1);
if(domain == 1)
return ColGroupConst.create(colIndexes, new double[] {1});
- MatrixBlock incrementing = new MatrixBlock(domain + (containsNull ? 1 : 0) , 1, false);
+ MatrixBlock incrementing = new MatrixBlock(domain + (containsNull ? 1 : 0), 1, false);
for(int i = 0; i < domain; i++)
incrementing.quickSetValue(i, 0, i + 1);
if(containsNull)
- incrementing.quickSetValue(domain, 0 , Double.NaN);
+ incrementing.quickSetValue(domain, 0, Double.NaN);
ADictionary d = MatrixBlockDictionary.create(incrementing);
@@ -258,7 +252,6 @@ public class CompressedEncode {
List<ColumnEncoder> r = c.getEncoders();
r.set(0, new ColumnEncoderRecode(colId, (HashMap<Object, Long>) map));
-
return ColGroupDDC.create(colIndexes, d, m, null);
}
@@ -283,7 +276,7 @@ public class CompressedEncode {
if(containsNull)
vals[map.size()] = Double.NaN;
ValueType t = a.getValueType();
- map.forEach((k,v) -> vals[v.intValue()] = UtilFunctions.objectToDouble(t,k));
+ map.forEach((k, v) -> vals[v.intValue()] = UtilFunctions.objectToDouble(t, k));
ADictionary d = Dictionary.create(vals);
AMapToData m = createMappingAMapToData(a, map, containsNull);
return ColGroupDDC.create(colIndexes, d, m, null);
@@ -295,17 +288,17 @@ public class CompressedEncode {
final int si = map.size();
AMapToData m = MapToFactory.create(in.getNumRows(), si + (containsNull ? 1 : 0));
Array<?>.ArrayIterator it = a.getIterator();
- if(containsNull){
+ if(containsNull) {
while(it.hasNext()) {
Object v = it.next();
if(v != null)
m.set(it.getIndex(), map.get(v).intValue());
else
- m.set(it.getIndex(),si);
+ m.set(it.getIndex(), si);
}
}
- else{
+ else {
while(it.hasNext()) {
Object v = it.next();
m.set(it.getIndex(), map.get(v).intValue());
@@ -340,25 +333,22 @@ public class CompressedEncode {
int colId = c._colID;
Array<?> a = in.getColumn(colId - 1);
ColumnEncoderFeatureHash CEHash = (ColumnEncoderFeatureHash) c.getEncoders().get(0);
-
- // HashMap<?, Long> map = a.getRecodeMap();
int domain = (int) CEHash.getK();
boolean nulls = a.containsNull();
IColIndex colIndexes = ColIndexFactory.create(0, 1);
- if(domain == 1)
+ if(domain == 1 && ! nulls)
return ColGroupConst.create(colIndexes, new double[] {1});
- MatrixBlock incrementing = new MatrixBlock(domain + (nulls ? 1 : 0), 1, false);
- for(int i = 0; i < domain; i++)
- incrementing.quickSetValue(i, 0, i + 1);
- if(nulls)
- incrementing.quickSetValue(domain, 0, Double.NaN);
-
- ADictionary d = MatrixBlockDictionary.create(incrementing);
-
- AMapToData m = createHashMappingAMapToData(a, domain , nulls);
- AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
- ret.getNumberNonZeros(a.size());
+ MatrixBlock incrementing = new MatrixBlock(domain + (nulls ? 1 : 0), 1, false);
+ for(int i = 0; i < domain; i++)
+ incrementing.quickSetValue(i, 0, i + 1);
+ if(nulls)
+ incrementing.quickSetValue(domain, 0, Double.NaN);
+
+ ADictionary d = MatrixBlockDictionary.create(incrementing);
+
+ AMapToData m = createHashMappingAMapToData(a, domain, nulls);
+ AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
return ret;
}
@@ -369,7 +359,7 @@ public class CompressedEncode {
int domain = (int) CEHash.getK();
boolean nulls = a.containsNull();
IColIndex colIndexes = ColIndexFactory.create(0, domain);
- if(domain == 1)
+ if(domain == 1 && !nulls)
return ColGroupConst.create(colIndexes, new double[] {1});
ADictionary d = new IdentityDictionary(colIndexes.size(), nulls);
AMapToData m = createHashMappingAMapToData(a, domain, nulls);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 41e16d6e6e..075b6fbdd4 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -44,8 +44,8 @@ import org.apache.sysds.utils.stats.TransformStatistics;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONObject;
-public class EncoderFactory {
- protected static final Log LOG = LogFactory.getLog(EncoderFactory.class.getName());
+public interface EncoderFactory {
+ final static Log LOG = LogFactory.getLog(EncoderFactory.class.getName());
public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta) {
return createEncoder(spec, colnames, UtilFunctions.nCopies(clen, ValueType.STRING), meta);
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index bc5d48b8a2..23da1a8fc3 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -91,9 +91,6 @@ import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.junit.Assert;
-//import jcuda.runtime.JCuda;
-
-
/**
* <p>
* Provides methods to easily create tests. Implemented methods can be used for
@@ -106,8 +103,7 @@ import org.junit.Assert;
* <li>clean up</li>
* </ul>
*/
-public class TestUtils
-{
+public class TestUtils {
private static final Log LOG = LogFactory.getLog(TestUtils.class.getName());
@@ -1604,16 +1600,6 @@ public class TestUtils
return false;
}
-
- /**
- *
- * @param vt
- * @param in1
- * @param in2
- * @param tolerance
- *
- * @return
- */
public static int compareTo(ValueType vt, Object in1, Object in2, double tolerance) {
if(in1 == null && in2 == null) return 0;
else if(in1 == null) return -1;
@@ -1659,12 +1645,6 @@ public class TestUtils
}
}
- /**
- * Converts a 2D array into a sparse hashmap matrix.
- *
- * @param matrix
- * @return
- */
public static HashMap<CellIndex, Double> convert2DDoubleArrayToHashMap(double[][] matrix) {
HashMap<CellIndex, Double> hmMatrix = new HashMap<>();
for (int i = 0; i < matrix.length; i++) {
@@ -1677,11 +1657,6 @@ public class TestUtils
return hmMatrix;
}
- /**
- * Method to convert a hashmap of matrix entries into a double array
- * @param matrix
- * @return
- */
public static double[][] convertHashMapToDoubleArray(HashMap <CellIndex, Double> matrix) {
int max_rows = -1, max_cols= -1;
for(CellIndex ix : matrix.keySet()) {
@@ -1701,12 +1676,6 @@ public class TestUtils
return ret_arr;
}
- /**
- * Converts a 2D double array into a 1D double array.
- *
- * @param array
- * @return
- */
public static double[] convert2Dto1DDoubleArray(double[][] array) {
double[] ret = new double[array.length * array[0].length];
int c = 0;
diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
index 2cea03489d..a592fc7477 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
@@ -22,6 +22,7 @@ package org.apache.sysds.test.component.frame.transform;
import static org.junit.Assert.fail;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
import org.apache.commons.logging.Log;
@@ -55,24 +56,37 @@ public class TransformCompressedTestMultiCol {
final int[] threads = new int[] {1, 4};
try {
- FrameBlock data = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4, ValueType.UINT8, ValueType.UINT4}, 231);
- data.setSchema(new ValueType[] {ValueType.INT32, ValueType.INT32, ValueType.INT32});
- for(int k : threads) {
- tests.add(new Object[] {data, k});
+ ValueType[] kPlusCols = new ValueType[1002];
+
+ Arrays.fill(kPlusCols, ValueType.BOOLEAN);
+
+ FrameBlock[] blocks = new FrameBlock[] {//
+ TestUtils.generateRandomFrameBlock(100, //
+ new ValueType[] {ValueType.UINT4, ValueType.UINT8, ValueType.UINT4}, 231), //
+ TestUtils.generateRandomFrameBlock(100, //
+ new ValueType[] {ValueType.BOOLEAN, ValueType.UINT8, ValueType.UINT4}, 231), //
+ new FrameBlock(new ValueType[] {ValueType.BOOLEAN, ValueType.INT32, ValueType.INT32}, 100), //
+ TestUtils.generateRandomFrameBlock(100, //
+ new ValueType[] {ValueType.UINT4, ValueType.BOOLEAN, ValueType.FP32}, 231, 0.2),
+ TestUtils.generateRandomFrameBlock(432, //
+ new ValueType[] {ValueType.UINT4, ValueType.BOOLEAN, ValueType.FP32}, 231, 0.2),
+ TestUtils.generateRandomFrameBlock(100, //
+ new ValueType[] {ValueType.UINT4, ValueType.BOOLEAN, ValueType.FP32}, 231, 0.9),
+ TestUtils.generateRandomFrameBlock(100, //
+ new ValueType[] {ValueType.UINT4, ValueType.BOOLEAN, ValueType.FP32}, 231, 0.99),
+
+ TestUtils.generateRandomFrameBlock(5, kPlusCols, 322),
+ TestUtils.generateRandomFrameBlock(1020, kPlusCols, 322),
+
+ };
+ blocks[2].ensureAllocatedColumns(100);
+
+ for(FrameBlock block : blocks) {
+ for(int k : threads) {
+ tests.add(new Object[] {block, k});
+ }
}
- FrameBlock data2 = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.BOOLEAN, ValueType.UINT8, ValueType.UINT4}, 231);
- data2.setSchema(new ValueType[] {ValueType.BOOLEAN, ValueType.INT32, ValueType.INT32});
- for(int k : threads) {
- tests.add(new Object[] {data2, k});
- }
-
- FrameBlock data3 = new FrameBlock(
- new ValueType[] {ValueType.BOOLEAN, ValueType.INT32, ValueType.INT32}, 100) ;
- data3.ensureAllocatedColumns(100);
- for(int k : threads)
- tests.add(new Object[] {data3, k});
-
}
catch(Exception e) {
e.printStackTrace();
@@ -114,12 +128,12 @@ public class TransformCompressedTestMultiCol {
}
@Test
- public void testHash(){
+ public void testHash() {
test("{ids:true, hash:[1,2,3], K:10}");
}
@Test
- public void testHashToDummy(){
+ public void testHashToDummy() {
test("{ids:true, hash:[1,2,3], K:10, dummycode:[1,2]}");
}
@@ -137,7 +151,6 @@ public class TransformCompressedTestMultiCol {
MatrixBlock outNormal = encoderNormal.encode(data, k);
FrameBlock outNormalMD = encoderNormal.getMetaData(null);
-
TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after apply");
TestUtils.compareFrames(outNormalMD, outCompressedMD, true);
}
diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
index a573783f6e..7b37ba1413 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
@@ -54,14 +54,18 @@ public class TransformCompressedTestSingleCol {
final ArrayList<Object[]> tests = new ArrayList<>();
final int[] threads = new int[] {1, 4};
try {
-
- FrameBlock data = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231);
- for(int k : threads)
- tests.add(new Object[] {data, k});
-
- data = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231, 0.2);
- for(int k : threads)
- tests.add(new Object[] {data, k});
+ FrameBlock[] blocks = new FrameBlock[] {
+ TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231),
+ TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231, 0.2),
+ TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231, 1.0),
+ TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231, 1.0),
+ // Above block size of number of unique elements
+ TestUtils.generateRandomFrameBlock(1200, new ValueType[] {ValueType.FP32}, 231, 0.1),};
+
+ blocks[3].set(40, 0, "14");
+ for(FrameBlock block : blocks)
+ for(int k : threads)
+ tests.add(new Object[] {block, k});
}
catch(Exception e) {
e.printStackTrace();
@@ -120,11 +124,21 @@ public class TransformCompressedTestSingleCol {
test("{ids:true, hash:[1], K:10}");
}
+ @Test
+ public void testHashDomain1() {
+ test("{ids:true, hash:[1], K:1}");
+ }
+
@Test
public void testHashToDummy() {
test("{ids:true, hash:[1], K:10, dummycode:[1]}");
}
+ @Test
+ public void testHashToDummyDomain1() {
+ test("{ids:true, hash:[1], K:1, dummycode:[1]}");
+ }
+
public void test(String spec) {
try {
diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCustomTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCustomTest.java
index ce7c5d17d9..d1b2479375 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCustomTest.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCustomTest.java
@@ -21,11 +21,20 @@ package org.apache.sysds.test.component.frame.transform;
import static org.junit.Assert.fail;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderDummycode;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderPassThrough;
+import org.apache.sysds.runtime.transform.encode.CompressedEncode;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.test.TestUtils;
@@ -71,6 +80,19 @@ public class TransformCustomTest {
test("{ids:true, bin:[{id:1, method:equi-height, numbins:10}]}");
}
+ @Test(expected = NotImplementedException.class)
+ public void testInvalidEncodeCompressed() throws Exception {
+ List<ColumnEncoderComposite> columnEncoders = new ArrayList<>();
+ List<ColumnEncoder> encoders = new ArrayList<>();
+ // create a nonsense sequence of encoders.
+ encoders.add(new ColumnEncoderDummycode());
+ encoders.add(new ColumnEncoderPassThrough());
+ encoders.add(new ColumnEncoderDummycode());
+ columnEncoders.add(new ColumnEncoderComposite(encoders));
+ MultiColumnEncoder enc = new MultiColumnEncoder(columnEncoders);
+ CompressedEncode.encode(enc, data, 1);
+ }
+
public void test(String spec) {
try {