You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ue...@apache.org on 2018/05/17 04:31:20 UTC

spark git commit: [SPARK-23925][SQL] Add array_repeat collection function

Repository: spark
Updated Branches:
  refs/heads/master 9a641e7f7 -> 3e66350c2


[SPARK-23925][SQL] Add array_repeat collection function

## What changes were proposed in this pull request?

The PR adds a new collection function, array_repeat. A function named repeat already exists with the same signature, differing only in its return type (String instead of Array), so the new function is named array_repeat to distinguish the two.
The behaviour of the function is based on Presto's equivalent function.

The function creates an array containing a given element repeated the requested number of times.

## How was this patch tested?

New unit tests added into:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite

Author: Florent Pépin <fl...@gmail.com>
Author: Florent Pépin <fl...@imperial.ac.uk>

Closes #21208 from pepinoflo/SPARK-23925.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3e66350c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3e66350c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3e66350c

Branch: refs/heads/master
Commit: 3e66350c2477a456560302b7738c9d122d5d9c43
Parents: 9a641e7
Author: Florent Pépin <fl...@gmail.com>
Authored: Thu May 17 13:31:14 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Thu May 17 13:31:14 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  14 ++
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/collectionOperations.scala      | 149 +++++++++++++++++++
 .../CollectionExpressionsSuite.scala            |  18 +++
 .../scala/org/apache/spark/sql/functions.scala  |  20 +++
 .../spark/sql/DataFrameFunctionsSuite.scala     |  76 ++++++++++
 6 files changed, 278 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6866c1c..925ac34 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2329,6 +2329,28 @@ def map_values(col):
     return Column(sc._jvm.functions.map_values(_to_java_column(col)))
 
 
+@ignore_unicode_prefix
+@since(2.4)
+def array_repeat(col, count):
+    """
+    Collection function: creates an array containing a column repeated count times.
+
+    :param col: column (or column name) whose value is repeated
+    :param count: number of repetitions, given either as an int or as an integer
+        :class:`Column`. A null count produces a null array and a count of zero or
+        less produces an empty array.
+
+    >>> df = spark.createDataFrame([('ab',)], ['data'])
+    >>> df.select(array_repeat(df.data, 3).alias('r')).collect()
+    [Row(r=[u'ab', u'ab', u'ab'])]
+    """
+    sc = SparkContext._active_spark_context
+    # An int selects the JVM (Column, Int) overload; a Column has to be converted
+    # to its Java counterpart so that the (Column, Column) overload is picked.
+    jcount = _to_java_column(count) if isinstance(count, Column) else count
+    return Column(sc._jvm.functions.array_repeat(_to_java_column(col), jcount))
+
+
 # ---------------------------- User Defined Function ----------------------------------
 
 class PandasUDFType(object):

http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 087d000..9c37059 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -427,6 +427,7 @@ object FunctionRegistry {
     expression[Reverse]("reverse"),
     expression[Concat]("concat"),
     expression[Flatten]("flatten"),
+    expression[ArrayRepeat]("array_repeat"),
     CreateStruct.registryEntry,
 
     // misc functions

http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 12b9ab2..2a4e42d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1468,3 +1468,181 @@ case class Flatten(child: Expression) extends UnaryExpression {
 
   override def prettyName: String = "flatten"
 }
+
+/**
+ * Returns the array containing the given input value (left) count (right) times.
+ *
+ * A null count (right) yields a null array, and a count of zero or less yields an empty
+ * array. A null input value (left) is repeated like any other value.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(element, count) - Returns the array containing element count times.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('123', 2);
+       ["123","123"]
+  """,
+  since = "2.4.0")
+case class ArrayRepeat(left: Expression, right: Expression)
+  extends BinaryExpression with ExpectsInputTypes {
+
+  // Size limit enforced by both the interpreted and the generated code paths.
+  private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+  // The nullability of the repeated value carries over to the array elements.
+  override def dataType: ArrayType = ArrayType(left.dataType, left.nullable)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
+
+  // Only a null count makes the result null; a null value is still repeated.
+  override def nullable: Boolean = right.nullable
+
+  override def eval(input: InternalRow): Any = {
+    val count = right.eval(input)
+    if (count == null) {
+      null
+    } else {
+      val numElements = count.asInstanceOf[Int]
+      if (numElements > MAX_ARRAY_LENGTH) {
+        throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
+          s"due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+      }
+      val element = left.eval(input)
+      // Array.fill produces an empty array for a non-positive count.
+      new GenericArrayData(Array.fill(numElements)(element))
+    }
+  }
+
+  override def prettyName: String = "array_repeat"
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val leftGen = left.genCode(ctx)
+    val rightGen = right.genCode(ctx)
+    val element = leftGen.value
+    val count = rightGen.value
+    val et = dataType.elementType
+
+    // Primitive element types use an unsafe array, others a GenericArrayData.
+    val coreLogic = if (CodeGenerator.isPrimitiveType(et)) {
+      genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value)
+    } else {
+      genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value)
+    }
+    val resultCode = nullCountProtection(ev, rightGen.isNull, coreLogic)
+
+    ev.copy(code =
+      s"""
+         |boolean ${ev.isNull} = false;
+         |${leftGen.code}
+         |${rightGen.code}
+         |${CodeGenerator.javaType(dataType)} ${ev.value} =
+         |  ${CodeGenerator.defaultValue(dataType)};
+         |$resultCode
+       """.stripMargin)
+  }
+
+  /**
+   * Guards `coreLogic` so it only runs for a non-null count; a null count marks the result
+   * as null instead. If the count is not nullable, `coreLogic` is returned unguarded.
+   */
+  private def nullCountProtection(
+      ev: ExprCode,
+      rightIsNull: String,
+      coreLogic: String): String = {
+    if (nullable) {
+      s"""
+         |if ($rightIsNull) {
+         |  ${ev.isNull} = true;
+         |} else {
+         |  ${coreLogic}
+         |}
+       """.stripMargin
+    } else {
+      coreLogic
+    }
+  }
+
+  /**
+   * Generates Java code declaring an `int` with the number of elements to create: `count`
+   * clamped at zero, throwing if it exceeds MAX_ARRAY_LENGTH. Returns the variable name and
+   * the declaring code.
+   */
+  private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = {
+    val numElements = ctx.freshName("numElements")
+    val numElementsCode =
+      s"""
+         |int $numElements = 0;
+         |if ($count > 0) {
+         |  $numElements = $count;
+         |}
+         |if ($numElements > $MAX_ARRAY_LENGTH) {
+         |  throw new RuntimeException("Unsuccessful try to create array with " + $numElements +
+         |    " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+         |}
+       """.stripMargin
+
+    (numElements, numElementsCode)
+  }
+
+  /**
+   * Generates code that fills an array created by `ctx.createUnsafeArray` with `element` (or
+   * with nulls when `leftIsNull` holds at runtime) and assigns it to `arrayDataName`.
+   */
+  private def genCodeForPrimitiveElement(
+      ctx: CodegenContext,
+      elementType: DataType,
+      element: String,
+      count: String,
+      leftIsNull: String,
+      arrayDataName: String): String = {
+    val tempArrayDataName = ctx.freshName("tempArrayData")
+    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+    val errorMessage = s" $prettyName failed."
+    val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
+    // Fresh loop index so the generated loops cannot clash with an enclosing variable.
+    val idx = ctx.freshName("k")
+
+    s"""
+       |$numElemCode
+       |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)}
+       |if (!$leftIsNull) {
+       |  for (int $idx = 0; $idx < $tempArrayDataName.numElements(); $idx++) {
+       |    $tempArrayDataName.set$primitiveValueTypeName($idx, $element);
+       |  }
+       |} else {
+       |  for (int $idx = 0; $idx < $tempArrayDataName.numElements(); $idx++) {
+       |    $tempArrayDataName.setNullAt($idx);
+       |  }
+       |}
+       |$arrayDataName = $tempArrayDataName;
+     """.stripMargin
+  }
+
+  /**
+   * Generates code that fills an Object[] with `element` (leaving nulls when `leftIsNull`
+   * holds at runtime), wraps it in a GenericArrayData and assigns it to `arrayDataName`.
+   */
+  private def genCodeForNonPrimitiveElement(
+      ctx: CodegenContext,
+      element: String,
+      count: String,
+      leftIsNull: String,
+      arrayDataName: String): String = {
+    val genericArrayClass = classOf[GenericArrayData].getName
+    val arrayName = ctx.freshName("arrayObject")
+    val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
+    val idx = ctx.freshName("k")
+
+    s"""
+       |$numElemCode
+       |Object[] $arrayName = new Object[$numElemName];
+       |if (!$leftIsNull) {
+       |  for (int $idx = 0; $idx < $numElemName; $idx++) {
+       |    $arrayName[$idx] = $element;
+       |  }
+       |}
+       |$arrayDataName = new $genericArrayClass($arrayName);
+     """.stripMargin
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index a2851d0..57fc5f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -468,4 +468,27 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Flatten(asa3), null)
     checkEvaluation(Flatten(asa4), null)
   }
+
+  test("ArrayRepeat") {
+    val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType))
+    val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType))
+
+    // A count of zero or less yields an empty array.
+    checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq())
+    checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq())
+    // A single non-null string or primitive value, repeated.
+    checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi"))
+    checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi"))
+    checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true))
+    checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1))
+    checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2))
+    // A null element is repeated, producing an array of nulls.
+    checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null))
+    checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null))
+    // Complex (array-typed) elements.
+    checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2)))
+    checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola")))
+    // A null count yields a null result.
+    checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b71dfda..550571a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3448,6 +3448,30 @@ object functions {
   def flatten(e: Column): Column = withExpr { Flatten(e.expr) }
 
   /**
+   * Creates an array containing the left argument repeated the number of times given by the
+   * right argument.
+   *
+   * @param left the value to repeat
+   * @param right the number of repetitions; a null count yields a null array and a count of
+   *              zero or less yields an empty array
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_repeat(left: Column, right: Column): Column = withExpr {
+    ArrayRepeat(left.expr, right.expr)
+  }
+
+  /**
+   * Creates an array containing the column `e` repeated `count` times.
+   *
+   * @param e the value to repeat
+   * @param count the number of repetitions; a count of zero or less yields an empty array
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count))
+
+  /**
    * Returns an unordered array containing the keys of the map.
    * @group collection_funcs
    * @since 2.3.0

http://git-wip-us.apache.org/repos/asf/spark/blob/3e66350c/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index ecce06f..e26565c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -843,6 +843,87 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("array_repeat function") {
+    // Always-true predicate; filtering with it first is meant to switch codegen on, so the
+    // `select` cases below run both with and without it.
+    val dummyFilter = (c: Column) => c.isNull || c.isNotNull
+    // String element, including a null element (repeated as an array of nulls).
+    val strDF = Seq(
+      ("hi", 2),
+      (null, 2)
+    ).toDF("a", "b")
+
+    val strDFTwiceResult = Seq(
+      Row(Seq("hi", "hi")),
+      Row(Seq(null, null))
+    )
+
+    checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult)
+    checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), strDFTwiceResult)
+    checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult)
+    checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), strDFTwiceResult)
+    checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult)
+    checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult)
+
+    // Integer element (a primitive type), including a null element.
+    val intDF = {
+      val schema = StructType(Seq(
+        StructField("a", IntegerType),
+        StructField("b", IntegerType)))
+      val data = Seq(
+        Row(3, 2),
+        Row(null, 2)
+      )
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+
+    val intDFTwiceResult = Seq(
+      Row(Seq(3, 3)),
+      Row(Seq(null, null))
+    )
+
+    checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult)
+    checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), intDFTwiceResult)
+    checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult)
+    checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), intDFTwiceResult)
+    checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult)
+    checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult)
+
+    // A null count yields a null array, whatever the element.
+    val nullCountDF = {
+      val schema = StructType(Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType)))
+      val data = Seq(
+        Row("hi", null),
+        Row(null, null)
+      )
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+
+    checkAnswer(
+      nullCountDF.select(array_repeat($"a", $"b")),
+      Seq(
+        Row(null),
+        Row(null)
+      )
+    )
+
+    // Error test cases: a non-integer count fails analysis.
+    val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b")
+
+    intercept[AnalysisException] {
+      invalidTypeDF.select(array_repeat($"a", $"b"))
+    }
+    intercept[AnalysisException] {
+      invalidTypeDF.select(array_repeat($"a", lit("1")))
+    }
+    intercept[AnalysisException] {
+      invalidTypeDF.selectExpr("array_repeat(a, 1.0)")
+    }
+
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org