You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by DylanGuedes <gi...@git.apache.org> on 2018/06/11 20:05:55 UTC
[GitHub] spark pull request #21045: [SPARK-23931][SQL] Adds arrays_zip function to sp...
Github user DylanGuedes commented on a diff in the pull request:
https://github.com/apache/spark/pull/21045#discussion_r194530372
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
@@ -128,6 +128,170 @@ case class MapKeys(child: Expression)
override def prettyName: String = "map_keys"
}
+@ExpressionDescription(
+ usage = """
+ _FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all
+ N-th values of input arrays.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4));
+ [[1, 2], [2, 3], [3, 4]]
+ > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4));
+ [[1, 2, 3], [2, 3, 4]]
+ """,
+ since = "2.4.0")
+case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {
+
+  /** Every argument must be an array; the element types may differ between arguments. */
+  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)
+
+  /**
+   * An array of structs with one field per input array. Both evaluation paths store a row at
+   * every position, so the array itself never contains null elements (individual struct fields
+   * may still be null).
+   */
+  override def dataType: DataType = ArrayType(mountSchema, containsNull = false)
+
+  /** The result is null whenever any input array is null. */
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  // Assumes every child has already been resolved to an ArrayType (guaranteed by inputTypes once
+  // type checking has run); calling this earlier would throw a ClassCastException.
+  @transient private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])
+
+  /** Element type of each input array, in argument order. */
+  @transient private lazy val arrayElementTypes = arrayTypes.map(_.elementType)
+
+  /**
+   * Struct schema of each result element: one nullable field per input array. A named input
+   * lends its name to its field; any other input is named after its argument position.
+   */
+  @transient private lazy val mountSchema: StructType = {
+    val fields = children.zip(arrayElementTypes).zipWithIndex.map {
+      case ((child, elementType), position) =>
+        val fieldName = child match {
+          case named: NamedExpression => named.name
+          case _ => position.toString
+        }
+        StructField(fieldName, elementType, nullable = true)
+    }
+    StructType(fields)
+  }
+
+  /** How many arrays are being zipped together (one per child). */
+  @transient lazy val numberOfArrays: Int = children.size
+
+  /** Fully qualified name of GenericArrayData, spliced into the generated Java code. */
+  @transient lazy val genericArrayData: String = classOf[GenericArrayData].getName
+
+  /**
+   * Generated code for the zero-argument case: the result is always a non-null, empty array.
+   */
+  def emptyInputGenCode(ev: ExprCode): ExprCode = {
+    val resultJavaType = CodeGenerator.javaType(dataType)
+    val emptyArrayCode =
+      code"""
+        |$resultJavaType ${ev.value} = new $genericArrayData(new Object[0]);
+        |boolean ${ev.isNull} = false;
+      """.stripMargin
+    ev.copy(code = emptyArrayCode)
+  }
+
+  /**
+   * Generated code for one or more input arrays.
+   *
+   * The children are evaluated in order into `arrVals`, and `biggestCardinality` tracks the
+   * largest length seen. The first null child sets it to -1, which skips the remaining children
+   * and makes the whole result null. Otherwise one GenericInternalRow is built for every index
+   * below the largest length, holding null wherever an input is too short or is null there.
+   */
+  def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val genericInternalRow = classOf[GenericInternalRow].getName
+    val arrVals = ctx.freshName("arrVals")
+    val biggestCardinality = ctx.freshName("biggestCardinality")
+
+    val currentRow = ctx.freshName("currentRow")
+    val i = ctx.freshName("i")
+    val args = ctx.freshName("args")
+
+    val evals = children.map(_.genCode(ctx))
+    val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) =>
+      s"""
+         |if ($biggestCardinality != -1) {
+         |  ${eval.code}
+         |  if (!${eval.isNull}) {
+         |    $arrVals[$index] = ${eval.value};
+         |    $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements());
+         |  } else {
+         |    $biggestCardinality = -1;
+         |  }
+         |}
+       """.stripMargin
+    }
+
+    // The children's code reads the current input row (ctx.INPUT_ROW) and, under whole-stage
+    // codegen, ctx.currentVars. Plain splitExpressions would move that code into methods that
+    // receive neither, which breaks compilation of the generated Java once splitting kicks in.
+    // splitExpressionsWithCurrentInputs forwards the input row to every split method and keeps
+    // the code inline when the inputs cannot be passed along.
+    val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs(
+      expressions = getValuesAndCardinalities,
+      funcName = "getValuesAndCardinalities",
+      returnType = "int",
+      makeSplitFunction = body =>
+        s"""
+           |$body
+           |return $biggestCardinality;
+         """.stripMargin,
+      foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"),
+      extraArguments =
+        ("ArrayData[]", arrVals) ::
+        ("int", biggestCardinality) :: Nil)
+
+    // Copies element `i` of every input into `currentRow`, or null when that input has no
+    // element `i` or holds null there. This code only touches the locals passed as arguments,
+    // so plain splitExpressions is sufficient here.
+    val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) =>
+      val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i)
+      s"""
+         |if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) {
+         |  $currentRow[$idx] = $g;
+         |} else {
+         |  $currentRow[$idx] = null;
+         |}
+       """.stripMargin
+    }
+
+    val getValueForTypeSplitted = ctx.splitExpressions(
+      expressions = getValueForType,
+      funcName = "extractValue",
+      arguments =
+        ("int", i) ::
+        ("Object[]", currentRow) ::
+        ("ArrayData[]", arrVals) :: Nil)
+
+    val initVariables =
+      s"""
+         |ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
+         |int $biggestCardinality = 0;
+         |${CodeGenerator.javaType(dataType)} ${ev.value} = null;
+       """.stripMargin
+
+    ev.copy(code"""
+      |$initVariables
+      |$splittedGetValuesAndCardinalities
+      |boolean ${ev.isNull} = $biggestCardinality == -1;
+      |if (!${ev.isNull}) {
+      |  Object[] $args = new Object[$biggestCardinality];
+      |  for (int $i = 0; $i < $biggestCardinality; $i ++) {
+      |    Object[] $currentRow = new Object[$numberOfArrays];
+      |    $getValueForTypeSplitted
+      |    $args[$i] = new $genericInternalRow($currentRow);
+      |  }
+      |  ${ev.value} = new $genericArrayData($args);
+      |}
+      """.stripMargin)
+  }
+
+  /** Chooses the generator by arity: zipping zero arrays needs no child evaluation at all. */
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    if (numberOfArrays > 0) nonEmptyInputGenCode(ctx, ev) else emptyInputGenCode(ev)
+
+  /**
+   * Interpreted evaluation. Returns null if any input array is null. Otherwise returns an array
+   * as long as the longest input, whose k-th element is a row holding the k-th element of each
+   * input (null where that input is too short or holds null at k).
+   */
+  override def eval(input: InternalRow): Any = {
+    val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData])
+    if (inputArrays.contains(null)) {
+      null
+    } else {
+      // Lengths are non-negative, so folding from 0 yields 0 for no inputs and the max otherwise.
+      val longest = inputArrays.foldLeft(0)((acc, arr) => math.max(acc, arr.numElements()))
+      val arraysWithTypes = inputArrays.zip(arrayElementTypes)
+
+      val rows: Array[InternalRow] = Array.tabulate(longest) { pos =>
+        val values: Seq[Object] = arraysWithTypes.map { case (arr, elementType) =>
+          if (pos < arr.numElements() && !arr.isNullAt(pos)) arr.get(pos, elementType) else null
+        }
+        InternalRow(values: _*)
+      }
+      new GenericArrayData(rows)
+    }
+  }
--- End diff --
Done!
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org