You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this copy.
Posted to commits@systemds.apache.org by ba...@apache.org on 2022/12/09 11:24:37 UTC

[systemds] branch main updated: [SYSTEMDS-3463] Add unique() built-in function

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 8095f4167f [SYSTEMDS-3463] Add unique() built-in function
8095f4167f is described below

commit 8095f4167f21983bedc024f1aab54bfa837e5992
Author: Badrul Chowdhury <ba...@gmail.com>
AuthorDate: Sun Nov 27 19:20:52 2022 -0800

    [SYSTEMDS-3463] Add unique() built-in function
    
    This patch converts the existing unique() function from a script to a
    built-in. The script-based approach is based on sorting, which is very
    expensive computationally, especially for large multiblock inputs.
    The new approach, on the other hand, is based on a new data sketch for
    the unique() function. This first patch creates the framework for the
    new unique sketch and implements the CP RowCol case; other cases -
    CP Row/Col and Spark RowCol/Row/Col - will be implemented in subsequent
    patches.
    
    Closes #1740
---
 scripts/builtin/unique.dml                         |  45 -------
 .../java/org/apache/sysds/common/Builtins.java     |   2 +-
 src/main/java/org/apache/sysds/common/Types.java   |   3 +-
 .../org/apache/sysds/lops/PartialAggregate.java    |  11 ++
 .../org/apache/sysds/parser/DMLTranslator.java     |  23 +++-
 .../ParameterizedBuiltinFunctionExpression.java    |  71 +++++++++-
 .../sysds/runtime/functionobjects/Builtin.java     |   2 +-
 .../runtime/instructions/CPInstructionParser.java  |   5 +-
 .../runtime/instructions/InstructionUtils.java     |  13 ++
 .../cp/AggregateUnaryCPInstruction.java            |  60 ++++++---
 .../matrix/data/LibMatrixCountDistinct.java        |   6 +-
 .../sysds/runtime/matrix/data/LibMatrixSketch.java | 117 +++++++++++++++++
 .../matrix/operators/CountDistinctOperator.java    |  13 +-
 .../matrix/operators/UnarySketchOperator.java      |  44 +++++++
 .../systemds/operator/algorithm/builtin/unique.py  |  45 -------
 .../test/functions/builtin/BuiltinUniqueTest.java  | 114 ----------------
 .../sysds/test/functions/unique/UniqueBase.java    |  64 +++++++++
 .../sysds/test/functions/unique/UniqueRowCol.java  | 145 +++++++++++++++++++++
 src/test/scripts/functions/builtin/unique.R        |  27 ----
 .../unique.dml => unique/uniqueRowCol.dml}         |   6 +-
 20 files changed, 543 insertions(+), 273 deletions(-)

diff --git a/scripts/builtin/unique.dml b/scripts/builtin/unique.dml
deleted file mode 100644
index 57e01949b6..0000000000
--- a/scripts/builtin/unique.dml
+++ /dev/null
@@ -1,45 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-# Builtin function that implements unique operation on vectors
-#
-# INPUT:
-# -------------------------------------------------------
-# X     input vector
-# -------------------------------------------------------
-#
-# OUTPUT:
-# -------------------------------------------------------------------
-# R     matrix with only unique rows
-# -------------------------------------------------------------------
-
-m_unique = function(matrix[double] X)
-  return (matrix[double] R)
-{
-  R = X
-  if(nrow(X) > 1) {
-    # sort-based approach (a generic alternative would be transformencode)
-    X_sorted = order(target=X, by=1, decreasing=FALSE, index.return=FALSE);
-    temp = X_sorted[1:nrow(X_sorted)-1,] != X_sorted[2:nrow(X_sorted),];
-    mask = rbind(matrix(1, 1, 1), temp);
-    R = removeEmpty(target = X_sorted, margin = "rows", select = mask);
-  }
-}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 5afef9c308..4a0d045367 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -290,7 +290,6 @@ public enum Builtins {
 	TRANS("t", false),
 	TSNE("tSNE", true),
 	TYPEOF("typeof", false),
-	UNIQUE("unique", true),
 	UNIVAR("univar", true),
 	UNION("union", true),
 	VAR("var", false),
@@ -344,6 +343,7 @@ public enum Builtins {
 	TRANSFORMENCODE("transformencode", false, true),
 	TRANSFORMMETA("transformmeta", false, true),
 	UNDER_SAMPLING("underSampling", true),
+	UNIQUE("unique", false, true),
 	UPPER_TRI("upper.tri", false, true),
 	XDUMMY1("xdummy1", true), //error handling test
 	XDUMMY2("xdummy2", true); //error handling test
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index 7c3a3f1e53..b441f7263b 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -199,7 +199,8 @@ public class Types
 		TRACE(6), MEAN(7), VAR(8),
 		MAXINDEX(9), MININDEX(10),
 		COUNT_DISTINCT(11), ROW_COUNT_DISTINCT(12), COL_COUNT_DISTINCT(13),
-		COUNT_DISTINCT_APPROX(14), COUNT_DISTINCT_APPROX_ROW(15), COUNT_DISTINCT_APPROX_COL(16);
+		COUNT_DISTINCT_APPROX(14), COUNT_DISTINCT_APPROX_ROW(15), COUNT_DISTINCT_APPROX_COL(16),
+		UNIQUE(17);
 
 		@Override
 		public String toString() {
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 1a7d22b989..467c7c69b0 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -374,6 +374,17 @@ public class PartialAggregate extends Lop
 
 			case COUNT_DISTINCT_APPROX_COL:
 				return "uacdapc";
+
+			case UNIQUE: {
+				switch (dir) {
+					case RowCol: return "unique";
+					case Row: return "uniquer";
+					case Col: return "uniquec";
+					default:
+						throw new LopsException("PartialAggregate.getOpcode() - "
+								+ "Unknown aggregate direction: " + dir);
+				}
+			}
 		}
 		
 		//should never come here for normal compilation
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 0c3a6dfd8f..06deb8ad7b 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2041,7 +2041,7 @@ public class DMLTranslator
 				break;
 
 			case COUNT_DISTINCT:
-			case COUNT_DISTINCT_APPROX:
+			case COUNT_DISTINCT_APPROX: {
 				Direction dir = Direction.RowCol;  // Default direction
 				DataType dataType = DataType.SCALAR;  // Default output data type
 
@@ -2063,6 +2063,7 @@ public class DMLTranslator
 				currBuiltinOp = new AggUnaryOp(target.getName(), dataType, target.getValueType(),
 						AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
 				break;
+			}
 
 			case COUNT_DISTINCT_APPROX_ROW:
 				currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
@@ -2074,6 +2075,26 @@ public class DMLTranslator
 						AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data"));
 				break;
 
+			case UNIQUE:
+				Direction dir = Direction.RowCol;
+				DataType dataType = DataType.MATRIX;
+
+				LiteralOp dirOp = (LiteralOp) paramHops.get("dir");
+				if (dirOp != null) {
+					String dirString = dirOp.getStringValue().toUpperCase();
+					if (dirString.equals(Direction.RowCol.toString())) {
+						dir = Direction.RowCol;
+					} else if (dirString.equals(Direction.Row.toString())) {
+						dir = Direction.Row;
+					} else if (dirString.equals(Direction.Col.toString())) {
+						dir = Direction.Col;
+					}
+				}
+
+				currBuiltinOp = new AggUnaryOp(target.getName(), dataType, target.getValueType(),
+						AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
+				break;
+
 			default:
 				throw new ParseException(source.printErrorLocation() + 
 					"processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + source.getOpCode());
diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 7ef19badde..293ca7312e 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -259,6 +259,10 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 			validateCountDistinctApprox(output, conditional, true);
 			break;
 
+		case UNIQUE:
+			validateUnique(output, conditional);
+			break;
+
 		default: //always unconditional (because unsupported operation)
 			//handle common issue of transformencode
 			if( getOpCode()==Builtins.TRANSFORMENCODE )
@@ -398,7 +402,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 
 		checkStringParam(true, fname, "dir", conditional);
 		// Check data value of "dir" parameter
-		validateAggregationDirection(dataId, output);
+		validateCountDistinctAggregationDirection(dataId, output);
 	}
 
 	private void validateCountDistinctApprox(DataIdentifier output, boolean conditional, boolean isDirectionAlias) {
@@ -464,11 +468,11 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 		if (!isDirectionAlias) {
 			checkStringParam(true, fname, "dir", conditional);
 			// Check data value of "dir" parameter
-			validateAggregationDirection(dataId, output);
+			validateCountDistinctAggregationDirection(dataId, output);
 		}
 	}
 
-	private void validateAggregationDirection(Identifier dataId, DataIdentifier output) {
+	private void validateCountDistinctAggregationDirection(Identifier dataId, DataIdentifier output) {
 		HashMap<String, Expression> varParams = getVarParams();
 		if (varParams.containsKey("dir")) {
 			String inputDirectionString = varParams.get("dir").toString().toUpperCase();
@@ -512,6 +516,67 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 		}
 	}
 
+	private void validateUnique(DataIdentifier output, boolean conditional) {
+		HashMap<String, Expression> varParams = getVarParams();
+
+		// A positional (unnamed) argument is bound to "data"; "dir" must always be passed by name
+		if (varParams.containsKey(null)) {
+			varParams.put("data", varParams.remove(null));
+		}
+
+		// Validate the number of parameters: "data" plus the optional "dir"
+		String fname = getOpCode().getName();
+		String usageMessage = "function " + fname + " takes at least 1 and at most 2 parameters";
+		if (varParams.size() < 1) {
+			raiseValidateError("Too few parameters: " + usageMessage, conditional);
+		}
+
+		if (varParams.size() > 2) {
+			raiseValidateError("Too many parameters: " + usageMessage, conditional);
+		}
+
+		// Reject any parameter name other than "data" and "dir"
+		Set<String> validParameterNames = CollectionUtils.asSet("data", "dir");
+		checkInvalidParameters(getOpCode(), varParams, validParameterNames);
+
+		// "data" must be a MATRIX with value type FP64
+		checkDataType(false, fname, "data", DataType.MATRIX, conditional);
+		checkDataValueType(false, fname, "data", DataType.MATRIX, ValueType.FP64, conditional);
+
+		// Resolve "data" to an identifier; its dims drive the output's worst-case characteristics.
+		// NOTE(review): get("data") is null if "data" was omitted; confirm checkDataType(false, ...) rejects that first
+		Identifier dataId = varParams.get("data").getOutput();
+		if (dataId == null) {
+			raiseValidateError("Cannot parse input parameter \"data\" to function " + fname, conditional);
+		}
+
+		checkStringParam(true, fname, "dir", conditional);
+		// Validate the value of the optional "dir" parameter and set the output metadata
+		validateUniqueAggregationDirection(dataId, output);
+	}
+
+	private void validateUniqueAggregationDirection(Identifier dataId, DataIdentifier output) {
+		HashMap<String, Expression> varParams = getVarParams();
+		if (varParams.containsKey("dir")) {
+			String inputDirectionString = varParams.get("dir").toString().toUpperCase();
+
+			// Upper-cased "dir" must match Row, Col or RowCol; NOTE(review): raised without the conditional flag
+			if (!inputDirectionString.equals(Types.Direction.Row.toString())
+					&& !inputDirectionString.equals(Types.Direction.Col.toString())
+					&& !inputDirectionString.equals(Types.Direction.RowCol.toString())) {
+				raiseValidateError("Invalid argument: " + inputDirectionString + " is not recognized");
+			}
+		}
+
+		// Worst case for every direction (all cells distinct): advertise the input's dims, blocksize and nnz.
+		// The same metadata is used whether or not "dir" is given; the runtime kernel decides the actual shape.
+		output.setDataType(DataType.MATRIX);
+		output.setDimensions(dataId.getDim1(), dataId.getDim2());
+		output.setBlocksize(dataId.getBlocksize());
+		output.setValueType(ValueType.FP64);
+		output.setNnz(dataId.getNnz());
+	}
+
 	private void checkStringParam(boolean optional, String fname, String pname, boolean conditional) {
 		Expression param = getVarParam(pname);
 		if (param == null) {
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index b904e60d9c..30114eff18 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -51,7 +51,7 @@ public class Builtin extends ValueFunction
 		MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
 		STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
 		TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE,
-		MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX}
+		MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE}
 
 
 	public BuiltinCode bFunc;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 78dec00c24..eeeed7c5a1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -121,6 +121,9 @@ public class CPInstructionParser extends InstructionParser
 		String2CPInstructionType.put( "uacdap"  , CPType.AggregateUnary);
 		String2CPInstructionType.put( "uacdapr" , CPType.AggregateUnary);
 		String2CPInstructionType.put( "uacdapc" , CPType.AggregateUnary);
+		String2CPInstructionType.put( "unique"  , CPType.AggregateUnary);
+		String2CPInstructionType.put( "uniquer" , CPType.AggregateUnary);
+		String2CPInstructionType.put( "uniquec" , CPType.AggregateUnary);
 
 		String2CPInstructionType.put( "uaggouterchain", CPType.UaggOuterChain);
 		
@@ -215,7 +218,7 @@ public class CPInstructionParser extends InstructionParser
 		String2CPInstructionType.put( "list",   CPType.BuiltinNary);
 		
 		// Parameterized Builtin Functions
-		String2CPInstructionType.put( "autoDiff" , CPType.ParameterizedBuiltin);
+		String2CPInstructionType.put( "autoDiff" ,      CPType.ParameterizedBuiltin);
 		String2CPInstructionType.put("paramserv",       CPType.ParameterizedBuiltin);
 		String2CPInstructionType.put( "nvlist",         CPType.ParameterizedBuiltin);
 		String2CPInstructionType.put( "cdf",            CPType.ParameterizedBuiltin);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 4f4fac1e38..5026175b59 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -104,6 +104,7 @@ import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
 import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
 import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
 import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.UnarySketchOperator;
 
 
 public class InstructionUtils 
@@ -453,6 +454,18 @@ public class InstructionUtils
 			aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX,
 					Direction.Col, ReduceRow.getReduceRowFnObject());
 		}
+		else if ( opcode.equalsIgnoreCase("unique") ) {
+			AggregateOperator agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("unique"));
+			aggun = new UnarySketchOperator(agg, ReduceAll.getReduceAllFnObject(), Direction.RowCol, numThreads);
+		}
+		else if ( opcode.equalsIgnoreCase("uniquer") ) {
+			AggregateOperator agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("unique"));
+			aggun = new UnarySketchOperator(agg, ReduceCol.getReduceColFnObject(), Direction.Row, numThreads);
+		}
+		else if ( opcode.equalsIgnoreCase("uniquec") ) {
+			AggregateOperator agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("unique"));
+			aggun = new UnarySketchOperator(agg, ReduceRow.getReduceRowFnObject(), Direction.Col, numThreads);
+		}
 
 		return aggun;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index 6fc0107520..030fe5f5cf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -20,7 +20,6 @@
 package org.apache.sysds.runtime.instructions.cp;
 
 import org.apache.sysds.api.DMLScript;
-import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -29,33 +28,27 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.data.BasicTensorBlock;
 import org.apache.sysds.runtime.data.TensorBlock;
 import org.apache.sysds.runtime.functionobjects.Builtin;
-import org.apache.sysds.runtime.functionobjects.ReduceAll;
-import org.apache.sysds.runtime.functionobjects.ReduceCol;
-import org.apache.sysds.runtime.functionobjects.ReduceRow;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
 import org.apache.sysds.runtime.lineage.LineageDedupUtils;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
+import org.apache.sysds.runtime.matrix.data.LibMatrixSketch;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.SmallestPriorityQueue;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+import org.apache.sysds.runtime.matrix.operators.UnarySketchOperator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.utils.Explain;
 
-import java.util.HashSet;
-import java.util.Set;
-
 public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 	// private static final Log LOG = LogFactory.getLog(AggregateUnaryCPInstruction.class.getName());
 
 	public enum AUType {
 		NROW, NCOL, LENGTH, EXISTS, LINEAGE, 
-		COUNT_DISTINCT, COUNT_DISTINCT_APPROX,
+		COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE,
 		DEFAULT;
 		public boolean isMeta() {
 			return this != DEFAULT;
@@ -107,6 +100,13 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 				.parseAggregateUnaryRowIndexOperator(opcode, Integer.parseInt(parts[4]), Integer.parseInt(parts[3]));
 			return new AggregateUnaryCPInstruction(aggun, in1, out, AUType.DEFAULT, opcode, str);
 		}
+		else if(opcode.equalsIgnoreCase("unique")
+				|| opcode.equalsIgnoreCase("uniquer")
+				|| opcode.equalsIgnoreCase("uniquec")){
+			AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode,
+					Integer.parseInt(parts[3]));
+			return new AggregateUnaryCPInstruction(aggun, in1, out, AUType.UNIQUE, opcode, str);
+		}
 		else { //DEFAULT BEHAVIOR
 			AggregateUnaryOperator aggun = InstructionUtils
 				.parseBasicAggregateUnaryOperator(opcode, Integer.parseInt(parts[3]));
@@ -116,7 +116,7 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 	
 	@Override
 	public void processInstruction( ExecutionContext ec ) {
-		String output_name = output.getName();
+		String outputName = output.getName();
 		String opcode = getOpcode();
 		
 		switch( _type ) {
@@ -163,7 +163,7 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 				}
 				
 				//create and set output scalar
-				ec.setScalarOutput(output_name, new IntObject(rval));
+				ec.setScalarOutput(outputName, new IntObject(rval));
 				break;
 			}
 			case EXISTS: {
@@ -172,7 +172,7 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 					ec.getScalarInput(input1).getStringValue();
 				boolean rval = ec.getVariables().keySet().contains(varName);
 				//create and set output scalar
-				ec.setScalarOutput(output_name, new BooleanObject(rval));
+				ec.setScalarOutput(outputName, new BooleanObject(rval));
 				break;
 			}
 			case LINEAGE: {
@@ -184,7 +184,7 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 				LineageItem li = ec.getLineageItem(input1);
 				String out = !DMLScript.LINEAGE_DEDUP ? Explain.explain(li) :
 					Explain.explain(li) + LineageDedupUtils.mergeExplainDedupBlocks(ec);
-				ec.setScalarOutput(output_name, new StringObject(out));
+				ec.setScalarOutput(outputName, new StringObject(out));
 				break;
 			}
 			case COUNT_DISTINCT:
@@ -203,18 +203,38 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 				if (op.getDirection().isRowCol()) {
 					long res = (long) LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0);
 					ec.releaseMatrixInput(input1.getName());
-					ec.setScalarOutput(output_name, new IntObject(res));
+					ec.setScalarOutput(outputName, new IntObject(res));
 				} else {  // Row/Col
 					// Note that for each row, the max number of distinct values < NNZ < max number of columns = 1000:
 					// Since count distinct approximate estimates are unreliable for values < 1024,
 					// we will force a naive count.
 					MatrixBlock res = LibMatrixCountDistinct.estimateDistinctValues(input, op);
 					ec.releaseMatrixInput(input1.getName());
-					ec.setMatrixOutput(output_name, res);
+					ec.setMatrixOutput(outputName, res);
+				}
+
+				break;
+			}
+
+			case UNIQUE: {
+				if(!ec.getVariables().keySet().contains(input1.getName())) {
+					throw new DMLRuntimeException("Variable '" + input1.getName() + "' does not exist.");
+				}
+				MatrixBlock input = ec.getMatrixInput(input1.getName());
+
+				// Operator type: test and cast
+				if (!(_optr instanceof UnarySketchOperator)) {
+					throw new DMLRuntimeException("Operator should be instance of "
+							+ UnarySketchOperator.class.getSimpleName());
 				}
+				UnarySketchOperator op = (UnarySketchOperator) _optr;
 
+				MatrixBlock res = LibMatrixSketch.getUniqueValues(input, op.getDirection());
+				ec.releaseMatrixInput(input1.getName());
+				ec.setMatrixOutput(outputName, res);
 				break;
 			}
+
 			default: {
 				AggregateUnaryOperator au_op = (AggregateUnaryOperator) _optr;
 				if (input1.getDataType() == DataType.MATRIX) {
@@ -226,10 +246,10 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 					ec.releaseMatrixInput(input1.getName());
 					if (output.getDataType() == DataType.SCALAR) {
 						DoubleObject ret = new DoubleObject(resultBlock.getValue(0, 0));
-						ec.setScalarOutput(output_name, ret);
+						ec.setScalarOutput(outputName, ret);
 					} else {
 						// since the computed value is a scalar, allocate a "temp" output matrix
-						ec.setMatrixOutput(output_name, resultBlock);
+						ec.setMatrixOutput(outputName, resultBlock);
 					}
 				} 
 				else if (input1.getDataType() == DataType.TENSOR) {
@@ -240,10 +260,10 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 
 					ec.releaseTensorInput(input1.getName());
 					if(output.getDataType() == DataType.SCALAR)
-						ec.setScalarOutput(output_name, ScalarObjectFactory.createScalarObject(
+						ec.setScalarOutput(outputName, ScalarObjectFactory.createScalarObject(
 							input1.getValueType(), resultBlock.get(new int[]{0, 0})));
 					else
-						ec.setTensorOutput(output_name, new TensorBlock(resultBlock));
+						ec.setTensorOutput(outputName, new TensorBlock(resultBlock));
 				}
 				else {
 					throw new DMLRuntimeException(opcode + " only supported on matrix or tensor.");
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index 72bcd64b43..c70f72ad3f 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -29,7 +29,11 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
-import org.apache.sysds.runtime.data.*;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockCOO;
+import org.apache.sysds.runtime.data.SparseBlockCSR;
+import org.apache.sysds.runtime.data.SparseBlockFactory;
 import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
 import org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch;
 import org.apache.sysds.runtime.matrix.data.sketch.SketchFactory;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java
new file mode 100644
index 0000000000..3793564dbe
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.common.Types;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Sketch-based kernels for data-dependent unary operations such as unique().
+ */
+public class LibMatrixSketch {
+
+	/** Orientation of the input matrix; decides the orientation of the output vector. */
+	private enum MatrixShape {
+		SKINNY,  // rows >= cols
+		WIDE,    // rows < cols
+	}
+
+	/**
+	 * Computes the distinct cell values of the given matrix.
+	 *
+	 * @param blkIn input matrix block
+	 * @param dir   aggregation direction; only RowCol is implemented so far
+	 * @return for RowCol, a column vector ({@code rows >= cols}) or a row vector ({@code rows < cols})
+	 *         holding each distinct value exactly once, in unspecified (hash) order
+	 * @throws NotImplementedException for the Row and Col directions
+	 * @throws IllegalArgumentException for any other unrecognized direction
+	 */
+	public static MatrixBlock getUniqueValues(MatrixBlock blkIn, Types.Direction dir) {
+
+		int R = blkIn.getNumRows();
+		int C = blkIn.getNumColumns();
+		List<HashSet<Double>> hashSets = new ArrayList<>();
+
+		MatrixShape matrixShape = (R >= C)? MatrixShape.SKINNY : MatrixShape.WIDE;
+		MatrixBlock blkOut;
+		switch (dir)
+		{
+			case RowCol:
+				HashSet<Double> hashSet = new HashSet<>();
+				// TODO optimize for sparse and compressed inputs
+				for (int i=0; i<R; ++i) {
+					for (int j=0; j<C; ++j) {
+						// "+ 0.0" maps -0.0 to 0.0, which Double.equals would otherwise treat as distinct
+						hashSet.add(blkIn.getValue(i, j) + 0.0);
+					}
+				}
+				hashSets.add(hashSet);
+				blkOut = serializeRowCol(hashSets, dir, matrixShape);
+				break;
+
+			case Row:
+			case Col:
+				throw new NotImplementedException("Unique Row/Col has not been implemented yet");
+
+			default:
+				throw new IllegalArgumentException("Unrecognized direction: " + dir);
+		}
+
+		return blkOut;
+	}
+
+	/**
+	 * Materializes the RowCol unique sketch as a vector with exactly one cell per
+	 * distinct value. A single row/column is used (instead of wrapping values into
+	 * fixed-width rows) so that no padding cells, which would read as spurious 0s,
+	 * are created and the value iterator is never advanced past its end.
+	 */
+	private static MatrixBlock serializeRowCol(List<HashSet<Double>> hashSets, Types.Direction dir, MatrixShape matrixShape) {
+		if (dir != Types.Direction.RowCol) {
+			throw new IllegalArgumentException("Unrecognized direction: " + dir);
+		}
+
+		if (hashSets.isEmpty()) {
+			throw new IllegalArgumentException("Corrupt sketch: metadata cannot be empty");
+		}
+
+		HashSet<Double> hashSet = hashSets.get(0);
+		int numDistinct = hashSet.size();
+
+		// Rx1 column vector for skinny inputs, 1xC row vector for wide inputs
+		boolean asColumn = (matrixShape == MatrixShape.SKINNY);
+		int R = asColumn ? numDistinct : 1;
+		int C = asColumn ? 1 : numDistinct;
+
+		MatrixBlock blkOut = new MatrixBlock(R, C, false);
+		Iterator<Double> iter = hashSet.iterator();
+		for (int k=0; k<numDistinct; ++k) {
+			double value = iter.next();
+			blkOut.setValue(asColumn ? k : 0, asColumn ? 0 : k, value);
+		}
+
+		return blkOut;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
index c33accf943..d3deba14de 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
@@ -26,15 +26,14 @@ import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction.AUType;
 import org.apache.sysds.utils.Hash.HashType;
 
-public class CountDistinctOperator extends AggregateUnaryOperator {
+public class CountDistinctOperator extends UnarySketchOperator {
 	private static final long serialVersionUID = 7615123453265129670L;
 
 	private final CountDistinctOperatorTypes operatorType;
-	private final Types.Direction direction;
 	private final HashType hashType;
 
 	public CountDistinctOperator(AUType opType, Types.Direction direction, IndexFunction indexFunction) {
-		super(new AggregateOperator(0, Plus.getPlusFnObject()), indexFunction, 1);
+		super(new AggregateOperator(0, Plus.getPlusFnObject()), indexFunction, direction, 1);
 
 		switch(opType) {
 			case COUNT_DISTINCT:
@@ -47,15 +46,13 @@ public class CountDistinctOperator extends AggregateUnaryOperator {
 				throw new DMLRuntimeException(opType + " not supported for CountDistinct Operator");
 		}
 		this.hashType = HashType.LinearHash;
-		this.direction = direction;
 	}
 
 	public CountDistinctOperator(CountDistinctOperatorTypes operatorType, Types.Direction direction,
 								 IndexFunction indexFunction, HashType hashType) {
-		super(new AggregateOperator(0, Plus.getPlusFnObject()), indexFunction, 1);
+		super(new AggregateOperator(0, Plus.getPlusFnObject()), indexFunction, direction, 1);
 
 		this.operatorType = operatorType;
-		this.direction = direction;
 		this.hashType = hashType;
 	}
 
@@ -66,8 +63,4 @@ public class CountDistinctOperator extends AggregateUnaryOperator {
 	public HashType getHashType() {
 		return hashType;
 	}
-
-	public Types.Direction getDirection() {
-		return direction;
-	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnarySketchOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnarySketchOperator.java
new file mode 100644
index 0000000000..0716c45afd
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnarySketchOperator.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.operators;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.functionobjects.IndexFunction;
+
+/**
+ * Aggregate-unary operator that additionally records the {@link Types.Direction}
+ * along which a data sketch is computed. Subclasses such as
+ * CountDistinctOperator retrieve it through {@link #getDirection()}.
+ */
+public class UnarySketchOperator extends AggregateUnaryOperator {
+	private static final long serialVersionUID = 7615123453265129671L;
+
+	/** Sketch direction; assigned once by the constructor. */
+	private final Types.Direction direction;
+
+	/**
+	 * Builds a sketch operator with an explicit thread count.
+	 *
+	 * @param aop           aggregate operator forwarded to the base operator
+	 * @param indexFunction index function forwarded to the base operator
+	 * @param direction     direction along which the sketch is applied
+	 * @param numThreads    thread count forwarded to the base operator
+	 */
+	public UnarySketchOperator(AggregateOperator aop, IndexFunction indexFunction,
+							   Types.Direction direction, int numThreads) {
+		super(aop, indexFunction, numThreads);
+		this.direction = direction;
+	}
+
+	/**
+	 * Builds a sketch operator without an explicit thread count; the
+	 * two-argument base constructor is left to decide the parallelism.
+	 *
+	 * @param aop           aggregate operator forwarded to the base operator
+	 * @param indexFunction index function forwarded to the base operator
+	 * @param direction     direction along which the sketch is applied
+	 */
+	public UnarySketchOperator(AggregateOperator aop, IndexFunction indexFunction,
+							   Types.Direction direction) {
+		super(aop, indexFunction);
+		this.direction = direction;
+	}
+
+	/** @return the direction supplied at construction time */
+	public Types.Direction getDirection() {
+		return direction;
+	}
+}
diff --git a/src/main/python/systemds/operator/algorithm/builtin/unique.py b/src/main/python/systemds/operator/algorithm/builtin/unique.py
deleted file mode 100644
index fd77b1fd55..0000000000
--- a/src/main/python/systemds/operator/algorithm/builtin/unique.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# -------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-# -------------------------------------------------------------
-
-# Autogenerated By   : src/main/python/generator/generator.py
-# Autogenerated From : scripts/builtin/unique.dml
-
-from typing import Dict, Iterable
-
-from systemds.operator import OperationNode, Matrix, Frame, List, MultiReturn, Scalar
-from systemds.script_building.dag import OutputType
-from systemds.utils.consts import VALID_INPUT_TYPES
-
-
-def unique(X: Matrix):
-    """
-     Builtin function that implements unique operation on vectors
-    
-    
-    
-    :param X: input vector
-    :return: matrix with only unique rows
-    """
-
-    params_dict = {'X': X}
-    return Matrix(X.sds_context,
-        'unique',
-        named_input_nodes=params_dict)
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUniqueTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUniqueTest.java
deleted file mode 100644
index 7d36d79b08..0000000000
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUniqueTest.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.test.functions.builtin;
-
-import org.apache.sysds.common.Types;
-import org.apache.sysds.common.Types.ExecType;
-import org.apache.sysds.runtime.matrix.data.MatrixValue;
-import org.apache.sysds.test.AutomatedTestBase;
-import org.apache.sysds.test.TestConfiguration;
-import org.apache.sysds.test.TestUtils;
-import org.junit.Test;
-
-import java.util.HashMap;
-
-public class BuiltinUniqueTest extends AutomatedTestBase {
-	private final static String TEST_NAME = "unique";
-	private final static String TEST_DIR = "functions/builtin/";
-	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinUniqueTest.class.getSimpleName() + "/";
-
-	@Override
-	public void setUp() {
-		TestUtils.clearAssertionInformation();
-		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
-	}
-
-	@Test
-	public void testUnique1CP() {
-		double[][] X = {{1},{1},{6},{9},{4},{2},{0},{9},{0},{0},{4},{4}};
-		runUniqueTest(X, ExecType.CP);
-	}
-
-	@Test
-	public void testUnique1SP() {
-		double[][] X = {{1},{1},{6},{9},{4},{2},{0},{9},{0},{0},{4},{4}};
-		runUniqueTest(X,ExecType.SPARK);
-	}
-
-	@Test
-	public void testUnique2CP() {
-		double[][] X = {{0}};
-		runUniqueTest(X, ExecType.CP);
-	}
-
-	@Test
-	public void testUnique2SP() {
-		double[][] X = {{0}};
-		runUniqueTest(X, ExecType.SPARK);
-	}
-
-	@Test
-	public void testUnique3CP() {
-		double[][] X = {{1, 2, 3}, {2, 3, 4}, {1, 2, 3}};
-		runUniqueTest(X, ExecType.CP);
-	}
-
-//	@Test
-//	public void testUnique3SP() { //This fails?
-//		double[][] X = {{1, 2, 3}, {2, 3, 4}, {1, 2, 3}};
-//		runUniqueTest(X, ExecType.SPARK);
-//	}
-
-	@Test
-	public void testUnique4CP() {
-		double[][] X = {{1.5, 2}, {7, 3}, {1, 3}, {1.5, 2}, {-1, -2.32}, {-1, 0.1}, {1, 3}, {-1, 0.1}};
-		runUniqueTest(X, ExecType.CP);
-	}
-
-//	@Test
-//	public void testUnique4SP() { //This fails?
-//		double[][] X = {{1.5, 2}, {7, 3}, {1, 3}, {1.5, 2}, {-1, -2.32}, {-1, 0.1}, {1, 3}, {-1, 0.1}};
-//		runUniqueTest(X, ExecType.SPARK);
-//	}
-
-	private void runUniqueTest(double[][] X, ExecType instType) {
-		Types.ExecMode platformOld = setExecMode(instType);
-		try {
-			loadTestConfiguration(getTestConfiguration(TEST_NAME));
-			String HOME = SCRIPT_DIR + TEST_DIR;
-			fullDMLScriptName = HOME + TEST_NAME + ".dml";
-			programArgs = new String[]{ "-args", input("X"), output("R")};
-			fullRScriptName = HOME + TEST_NAME + ".R";
-			rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
-
-			writeInputMatrixWithMTD("X", X, true);
-
-			runTest(true, false, null, -1);
-			runRScript(true);
-
-			HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
-			HashMap<MatrixValue.CellIndex, Double> rfile  = readRMatrixFromExpectedDir("R");
-			TestUtils.compareMatrices(dmlfile, rfile, 1e-10, "dml", "expected");
-		}
-		finally {
-			rtplatform = platformOld;
-		}
-	}
-}
diff --git a/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java b/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java
new file mode 100644
index 0000000000..6b78a60290
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.unique;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.HashMap;
+
+/**
+ * Common harness for the unique() tests: a subclass supplies its test name and
+ * directories, and {@link #uniqueTest} runs the matching DML script on an input
+ * matrix and checks the result against an expected matrix.
+ */
+public abstract class UniqueBase extends AutomatedTestBase {
+
+	/** @return test name, also used as the DML script file name (without ".dml") */
+	protected abstract String getTestName();
+
+	/** @return directory, relative to SCRIPT_DIR, that contains the DML script */
+	protected abstract String getTestDir();
+
+	/** @return class-specific directory used for the test configuration */
+	protected abstract String getTestClassDir();
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		final String name = getTestName();
+		final String[] outputs = new String[] {"A"};
+		addTestConfiguration(name, new TestConfiguration(getTestClassDir(), name, outputs));
+	}
+
+	/**
+	 * Runs the DML script on {@code inputMatrix} and compares the produced
+	 * output with {@code expectedMatrix} via compareResultsRowsOutOfOrder.
+	 *
+	 * @param inputMatrix    matrix written as input "I" and passed to the script as $1
+	 * @param expectedMatrix expected content of output "A", which the script writes via $2
+	 * @param instType       execution type for the run; the previous mode is restored afterwards
+	 * @param epsilon        tolerance passed to compareResultsRowsOutOfOrder
+	 */
+	protected void uniqueTest(double[][] inputMatrix, double[][] expectedMatrix,
+							Types.ExecType instType, double epsilon) {
+		final Types.ExecMode previousMode = setExecMode(instType);
+		try {
+			final String name = getTestName();
+			loadTestConfiguration(getTestConfiguration(name));
+			fullDMLScriptName = SCRIPT_DIR + getTestDir() + name + ".dml";
+			programArgs = new String[] {"-args", input("I"), output("A")};
+
+			writeInputMatrixWithMTD("I", inputMatrix, true);
+			runTest(true, false, null, -1);
+
+			// the tests do not pin down the order of unique()'s output,
+			// hence the row-order-insensitive comparison
+			writeExpectedMatrix("A", expectedMatrix);
+			compareResultsRowsOutOfOrder(epsilon);
+		}
+		finally {
+			rtplatform = previousMode;
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/unique/UniqueRowCol.java b/src/test/java/org/apache/sysds/test/functions/unique/UniqueRowCol.java
new file mode 100644
index 0000000000..a8b2fc1ba7
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/unique/UniqueRowCol.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.unique;
+
+import org.apache.sysds.common.Types;
+import org.junit.Test;
+
+/**
+ * CP tests for unique() over all cells of a matrix (the RowCol case): each
+ * expected result holds every distinct input value exactly once.
+ */
+public class UniqueRowCol extends UniqueBase {
+	private final static String TEST_NAME = "uniqueRowCol";
+	private final static String TEST_DIR = "functions/unique/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + UniqueRowCol.class.getSimpleName() + "/";
+
+	@Override
+	protected String getTestName() {
+		return TEST_NAME;
+	}
+
+	@Override
+	protected String getTestDir() {
+		return TEST_DIR;
+	}
+
+	@Override
+	protected String getTestClassDir() {
+		return TEST_CLASS_DIR;
+	}
+
+	@Test
+	public void testBaseCase1CP() {
+		runExactCP(new double[][] {{0}}, new double[][] {{0}});
+	}
+
+	@Test
+	public void testBaseCase2CP() {
+		runExactCP(new double[][] {{1}}, new double[][] {{1}});
+	}
+
+	@Test
+	public void testSkinnySmallCP() {
+		double[][] column = {{1},{1},{6},{9},{4},{2},{0},{9},{0},{0},{4},{4}};
+		runExactCP(column, new double[][] {{1},{6},{9},{4},{2},{0}});
+	}
+
+	@Test
+	public void testWideSmallCP() {
+		double[][] row = {{1,1,6,9,4,2,0,9,0,0,4,4}};
+		runExactCP(row, new double[][] {{1,6,9,4,2,0}});
+	}
+
+	@Test
+	public void testSquareLargeCP() {
+		// 1000 x 1000 input built from four 500 x 500 quadrants:
+		//   top-left = 1, top-right = 2, bottom-left = 2, bottom-right = 1
+		double[][] square = new double[1000][1000];
+		fillRegion(square, 0, 500, 0, 500, 1);
+		fillRegion(square, 0, 500, 500, 1000, 2);
+		fillRegion(square, 500, 1000, 0, 500, 2);
+		fillRegion(square, 500, 1000, 500, 1000, 1);
+		// A square input satisfies R >= C, which the sketch classifies as SKINNY,
+		// so the distinct values are expected as a column vector.
+		runExactCP(square, new double[][] {{1},{2}});
+	}
+
+	@Test
+	public void testSkinnyLargeCP() {
+		// 2000 x 2 input: the first 1000 rows are [1, 2], the last 1000 rows are [2, 1]
+		double[][] tall = new double[2000][2];
+		fillRegion(tall, 0, 1000, 0, 1, 1);
+		fillRegion(tall, 0, 1000, 1, 2, 2);
+		fillRegion(tall, 1000, 2000, 0, 1, 2);
+		fillRegion(tall, 1000, 2000, 1, 2, 1);
+		runExactCP(tall, new double[][] {{1},{2}});
+	}
+
+	@Test
+	public void testWideLargeCP() {
+		// 2 x 2000 input: row 0 is 1000 ones then 1000 twos, row 1 is the reverse
+		double[][] wide = new double[2][2000];
+		fillRegion(wide, 0, 1, 0, 1000, 1);
+		fillRegion(wide, 0, 1, 1000, 2000, 2);
+		fillRegion(wide, 1, 2, 0, 1000, 2);
+		fillRegion(wide, 1, 2, 1000, 2000, 1);
+		runExactCP(wide, new double[][] {{1,2}});
+	}
+
+	/** Runs the unique test in CP mode with a zero comparison tolerance. */
+	private void runExactCP(double[][] input, double[][] expected) {
+		uniqueTest(input, expected, Types.ExecType.CP, 0.0);
+	}
+
+	/** Sets every cell of {@code m} in rows [r0, r1) and columns [c0, c1) to {@code value}. */
+	private static void fillRegion(double[][] m, int r0, int r1, int c0, int c1, double value) {
+		for (int r = r0; r < r1; ++r) {
+			for (int c = c0; c < c1; ++c) {
+				m[r][c] = value;
+			}
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/unique.R b/src/test/scripts/functions/builtin/unique.R
deleted file mode 100644
index 6f4c17895e..0000000000
--- a/src/test/scripts/functions/builtin/unique.R
+++ /dev/null
@@ -1,27 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-args<-commandArgs(TRUE)
-options(digits=22)
-library("Matrix")
-
-X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")));
-R = unique(X[order(X[,1]),]);
-writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
\ No newline at end of file
diff --git a/src/test/scripts/functions/builtin/unique.dml b/src/test/scripts/functions/unique/uniqueRowCol.dml
similarity index 92%
rename from src/test/scripts/functions/builtin/unique.dml
rename to src/test/scripts/functions/unique/uniqueRowCol.dml
index 55b5aab378..2022342418 100644
--- a/src/test/scripts/functions/builtin/unique.dml
+++ b/src/test/scripts/functions/unique/uniqueRowCol.dml
@@ -19,6 +19,6 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-R = unique(X = X);
-write(R, $2);
+X = read($1);                 # input matrix
+U = unique(X);                # distinct values of X
+write(U, $2, format="text");