Posted to commits@spark.apache.org by ue...@apache.org on 2018/05/17 04:31:20 UTC
spark git commit: [SPARK-23925][SQL] Add array_repeat collection function
Repository: spark
Updated Branches:
refs/heads/master 9a641e7f7 -> 3e66350c2
[SPARK-23925][SQL] Add array_repeat collection function
## What changes were proposed in this pull request?
The PR adds a new collection function, array_repeat. Since a repeat function with the same signature already exists, differing only in its return type (String instead of Array), the new function is named array_repeat to distinguish the two.
The behaviour of the function is modelled on Presto's implementation.
The function creates an array containing a given element repeated the requested number of times; a zero or negative count yields an empty array, and a null count yields null.
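For illustration, a minimal usage sketch from the DataFrame API (assuming a spark-shell session with spark.implicits._ in scope; the results in comments are the expected values, not captured shell output):

    import org.apache.spark.sql.functions.array_repeat

    val df = Seq(("ab", 3)).toDF("data", "n")
    df.select(array_repeat($"data", 3)).collect()
    // expected: a single Row containing Seq("ab", "ab", "ab")
    df.select(array_repeat($"data", $"n")).collect()
    // same result, with the count read from column n per row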
## How was this patch tested?
New unit tests were added to:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite
Author: Florent Pépin <fl...@gmail.com>
Author: Florent Pépin <fl...@imperial.ac.uk>
Closes #21208 from pepinoflo/SPARK-23925.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3e66350c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3e66350c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3e66350c
Branch: refs/heads/master
Commit: 3e66350c2477a456560302b7738c9d122d5d9c43
Parents: 9a641e7
Author: Florent Pépin <fl...@gmail.com>
Authored: Thu May 17 13:31:14 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Thu May 17 13:31:14 2018 +0900
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 14 ++
.../catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/collectionOperations.scala | 149 +++++++++++++++++++
.../CollectionExpressionsSuite.scala | 18 +++
.../scala/org/apache/spark/sql/functions.scala | 20 +++
.../spark/sql/DataFrameFunctionsSuite.scala | 76 ++++++++++
6 files changed, 278 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6866c1c..925ac34 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2329,6 +2329,20 @@ def map_values(col):
return Column(sc._jvm.functions.map_values(_to_java_column(col)))
+@ignore_unicode_prefix
+@since(2.4)
+def array_repeat(col, count):
+ """
+ Collection function: creates an array containing a column repeated count times.
+
+ >>> df = spark.createDataFrame([('ab',)], ['data'])
+ >>> df.select(array_repeat(df.data, 3).alias('r')).collect()
+ [Row(r=[u'ab', u'ab', u'ab'])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count))
+
+
# ---------------------------- User Defined Function ----------------------------------
class PandasUDFType(object):
http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 087d000..9c37059 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -427,6 +427,7 @@ object FunctionRegistry {
expression[Reverse]("reverse"),
expression[Concat]("concat"),
expression[Flatten]("flatten"),
+ expression[ArrayRepeat]("array_repeat"),
CreateStruct.registryEntry,
// misc functions
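With this registry entry in place, the expression also resolves by name in SQL. A quick sketch (Scala, assuming an active SparkSession named spark):

    // The registry maps the SQL name "array_repeat" to the ArrayRepeat
    // expression, so the function is now callable from plain SQL:
    spark.sql("SELECT array_repeat('123', 2)").collect()
    // expected: a single Row containing Seq("123", "123")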
http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 12b9ab2..2a4e42d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1468,3 +1468,152 @@ case class Flatten(child: Expression) extends UnaryExpression {
override def prettyName: String = "flatten"
}
+
+/**
+ * Returns the array containing the given input value (left) count (right) times.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(element, count) - Returns the array containing element count times.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('123', 2);
+ ['123', '123']
+ """,
+ since = "2.4.0")
+case class ArrayRepeat(left: Expression, right: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+ override def dataType: ArrayType = ArrayType(left.dataType, left.nullable)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
+
+ override def nullable: Boolean = right.nullable
+
+ override def eval(input: InternalRow): Any = {
+ val count = right.eval(input)
+ if (count == null) {
+ null
+ } else {
+ if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) {
+ throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
+ s"due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ }
+ val element = left.eval(input)
+ new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element))
+ }
+ }
+
+ override def prettyName: String = "array_repeat"
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val leftGen = left.genCode(ctx)
+ val rightGen = right.genCode(ctx)
+ val element = leftGen.value
+ val count = rightGen.value
+ val et = dataType.elementType
+
+ val coreLogic = if (CodeGenerator.isPrimitiveType(et)) {
+ genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value)
+ } else {
+ genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value)
+ }
+ val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic)
+
+ ev.copy(code =
+ s"""
+ |boolean ${ev.isNull} = false;
+ |${leftGen.code}
+ |${rightGen.code}
+ |${CodeGenerator.javaType(dataType)} ${ev.value} =
+ | ${CodeGenerator.defaultValue(dataType)};
+ |$resultCode
+ """.stripMargin)
+ }
+
+ private def nullElementsProtection(
+ ev: ExprCode,
+ rightIsNull: String,
+ coreLogic: String): String = {
+ if (nullable) {
+ s"""
+ |if ($rightIsNull) {
+ | ${ev.isNull} = true;
+ |} else {
+ | ${coreLogic}
+ |}
+ """.stripMargin
+ } else {
+ coreLogic
+ }
+ }
+
+ private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = {
+ val numElements = ctx.freshName("numElements")
+ val numElementsCode =
+ s"""
+ |int $numElements = 0;
+ |if ($count > 0) {
+ | $numElements = $count;
+ |}
+ |if ($numElements > $MAX_ARRAY_LENGTH) {
+ | throw new RuntimeException("Unsuccessful try to create array with " + $numElements +
+ | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ |}
+ """.stripMargin
+
+ (numElements, numElementsCode)
+ }
+
+ private def genCodeForPrimitiveElement(
+ ctx: CodegenContext,
+ elementType: DataType,
+ element: String,
+ count: String,
+ leftIsNull: String,
+ arrayDataName: String): String = {
+ val tempArrayDataName = ctx.freshName("tempArrayData")
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+ val errorMessage = s" $prettyName failed."
+ val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
+
+ s"""
+ |$numElemCode
+ |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)}
+ |if (!$leftIsNull) {
+ | for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
+ | $tempArrayDataName.set$primitiveValueTypeName(k, $element);
+ | }
+ |} else {
+ | for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
+ | $tempArrayDataName.setNullAt(k);
+ | }
+ |}
+ |$arrayDataName = $tempArrayDataName;
+ """.stripMargin
+ }
+
+ private def genCodeForNonPrimitiveElement(
+ ctx: CodegenContext,
+ element: String,
+ count: String,
+ leftIsNull: String,
+ arrayDataName: String): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val arrayName = ctx.freshName("arrayObject")
+ val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
+
+ s"""
+ |$numElemCode
+ |Object[] $arrayName = new Object[(int)$numElemName];
+ |if (!$leftIsNull) {
+ | for (int k = 0; k < $numElemName; k++) {
+ | $arrayName[k] = $element;
+ | }
+ |}
+ |$arrayDataName = new $genericArrayClass($arrayName);
+ """.stripMargin
+ }
+
+}
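Setting the codegen machinery aside, the interpreted semantics implemented by eval above amount to the following standalone sketch (plain Scala; repeatArray is a hypothetical helper written here only for illustration):

    // Hypothetical helper mirroring ArrayRepeat.eval:
    //  - a null count yields a null result (nullable follows the right child)
    //  - a zero or negative count yields an empty array
    //  - the element itself may be null and is repeated as-is
    def repeatArray[T](element: T, count: Integer): Seq[T] =
      if (count == null) null
      else Seq.fill(math.max(count, 0))(element)

    repeatArray("hi", 2)    // Seq("hi", "hi")
    repeatArray("hi", -1)   // Seq()
    repeatArray(null, 2)    // Seq(null, null)
    repeatArray("hi", null) // null

The real implementation additionally rejects counts above ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH with a RuntimeException, in both the interpreted and the generated code paths.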
http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index a2851d0..57fc5f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -468,4 +468,22 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Flatten(asa3), null)
checkEvaluation(Flatten(asa4), null)
}
+
+ test("ArrayRepeat") {
+ val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType))
+ val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType))
+
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq())
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq())
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi"))
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi"))
+ checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true))
+ checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1))
+ checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2))
+ checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null))
+ checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null))
+ checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2)))
+ checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola")))
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null)
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b71dfda..550571a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3448,6 +3448,26 @@ object functions {
def flatten(e: Column): Column = withExpr { Flatten(e.expr) }
/**
+ * Creates an array containing the left argument repeated the number of times given by the
+ * right argument.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_repeat(left: Column, right: Column): Column = withExpr {
+ ArrayRepeat(left.expr, right.expr)
+ }
+
+ /**
+ * Creates an array containing the left argument repeated the number of times given by the
+ * right argument.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count))
+
+ /**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
* @since 2.3.0
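The two Scala overloads added above are equivalent when the count is a literal, since the Int variant simply wraps its argument in lit(). A brief sketch (again assuming spark.implicits._ in scope):

    import org.apache.spark.sql.functions.array_repeat

    val df = Seq(("hi", 2)).toDF("a", "b")
    df.select(array_repeat($"a", 2))    // Int overload: count wrapped via lit()
    df.select(array_repeat($"a", $"b")) // Column overload: count evaluated per row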
http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index ecce06f..e26565c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -843,6 +843,82 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
}
}
+ test("array_repeat function") {
+ val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on
+ val strDF = Seq(
+ ("hi", 2),
+ (null, 2)
+ ).toDF("a", "b")
+
+ val strDFTwiceResult = Seq(
+ Row(Seq("hi", "hi")),
+ Row(Seq(null, null))
+ )
+
+ checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult)
+ checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), strDFTwiceResult)
+ checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult)
+ checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), strDFTwiceResult)
+ checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult)
+ checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult)
+
+ val intDF = {
+ val schema = StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", IntegerType)))
+ val data = Seq(
+ Row(3, 2),
+ Row(null, 2)
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ }
+
+ val intDFTwiceResult = Seq(
+ Row(Seq(3, 3)),
+ Row(Seq(null, null))
+ )
+
+ checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult)
+ checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), intDFTwiceResult)
+ checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult)
+ checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), intDFTwiceResult)
+ checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult)
+ checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult)
+
+ val nullCountDF = {
+ val schema = StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", IntegerType)))
+ val data = Seq(
+ Row("hi", null),
+ Row(null, null)
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ }
+
+ checkAnswer(
+ nullCountDF.select(array_repeat($"a", $"b")),
+ Seq(
+ Row(null),
+ Row(null)
+ )
+ )
+
+ // Error test cases
+ val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b")
+
+ intercept[AnalysisException] {
+ invalidTypeDF.select(array_repeat($"a", $"b"))
+ }
+ intercept[AnalysisException] {
+ invalidTypeDF.select(array_repeat($"a", lit("1")))
+ }
+ intercept[AnalysisException] {
+ invalidTypeDF.selectExpr("array_repeat(a, 1.0)")
+ }
+
+ }
+
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {