You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by DylanGuedes <gi...@git.apache.org> on 2018/06/11 20:05:55 UTC

[GitHub] spark pull request #21045: [SPARK-23931][SQL] Adds arrays_zip function to sp...

Github user DylanGuedes commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21045#discussion_r194530372
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
    @@ -128,6 +128,170 @@ case class MapKeys(child: Expression)
       override def prettyName: String = "map_keys"
     }
     
    +@ExpressionDescription(
    +  usage = """
    +    _FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all
    +    N-th values of input arrays.
    +  """,
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4));
    +        [[1, 2], [2, 3], [3, 4]]
    +      > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4));
    +        [[1, 2, 3], [2, 3, 4]]
    +  """,
    +  since = "2.4.0")
    +case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {
    +
    +  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) // one ArrayType entry per child
    +
    +  override def dataType: DataType = ArrayType(mountSchema) // array whose elements follow the struct built by mountSchema
    +
    +  override def nullable: Boolean = children.exists(_.nullable) // nullable as soon as one child is nullable
    +
    +  private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) // each child's type, cast to ArrayType
    +
    +  private lazy val arrayElementTypes = arrayTypes.map(_.elementType) // element type of each input array, in child order
    +
    +  // Struct schema of one zipped element: one nullable field per input array. A field is named
    +  // after its child when the child is a NamedExpression, and after its position otherwise.
    +  @transient private lazy val mountSchema: StructType = {
    +    val elementTypes = arrayElementTypes
    +    val structFields = children.indices.map { position =>
    +      val fieldName = children(position) match {
    +        case named: NamedExpression => named.name
    +        case _ => position.toString
    +      }
    +      StructField(fieldName, elementTypes(position), nullable = true)
    +    }
    +    StructType(structFields)
    +  }
    +
    +  @transient lazy val numberOfArrays: Int = children.length // number of input arrays being zipped
    +
    +  @transient lazy val genericArrayData = classOf[GenericArrayData].getName // class name spliced into the generated Java code
    +
    +  /**
    +   * Generates the code used when there are no input arrays: the result variable is declared
    +   * as an empty array and the null flag is declared as false.
    +   *
    +   * @param ev expression code providing the names of the result and null-flag variables
    +   * @return a copy of `ev` whose code declares and initialises those two variables
    +   */
    +  def emptyInputGenCode(ev: ExprCode): ExprCode = {
    +    val resultJavaType = CodeGenerator.javaType(dataType)
    +    val emptyArrayCode = code"""
    +      |$resultJavaType ${ev.value} = new $genericArrayData(new Object[0]);
    +      |boolean ${ev.isNull} = false;
    +    """.stripMargin
    +    ev.copy(code = emptyArrayCode)
    +  }
    +
    +  /**
    +   * Generates the Java code used when at least one input array is given.
    +   *
    +   * The emitted code works in two phases:
    +   *  1. every child is evaluated and stored in a local `ArrayData[]`, while the largest
    +   *     `numElements()` seen so far is tracked; a null child sets that counter to -1, which
    +   *     makes the guards of the remaining children false and marks the result as null;
    +   *  2. unless the result is null, one `GenericInternalRow` is built per position, taking the
    +   *     value from each array, or null when the array is shorter or holds a null there.
    +   *
    +   * @param ctx codegen context supplying fresh variable names and function splitting
    +   * @param ev  expression code whose `isNull` / `value` names are declared by the emitted code
    +   * @return a copy of `ev` carrying the generated code
    +   */
    +  def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    val genericInternalRow = classOf[GenericInternalRow].getName
    +    val arrVals = ctx.freshName("arrVals")
    +    val biggestCardinality = ctx.freshName("biggestCardinality")
    +
    +    val currentRow = ctx.freshName("currentRow")
    +    val j = ctx.freshName("j")
    +    val i = ctx.freshName("i")
    +    val args = ctx.freshName("args")
    +
    +    val evals = children.map(_.genCode(ctx))
    +    // One snippet per child: evaluate it, remember its array and update the running maximum
    +    // length. -1 is the "some input was null" sentinel and short-circuits later snippets.
    +    val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) =>
    +      s"""
    +        |if ($biggestCardinality != -1) {
    +        |  ${eval.code}
    +        |  if (!${eval.isNull}) {
    +        |    $arrVals[$index] = ${eval.value};
    +        |    $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements());
    +        |  } else {
    +        |    $biggestCardinality = -1;
    +        |  }
    +        |}
    +      """.stripMargin
    +    }
    +
    +    // These snippets embed the children's generated code, which may refer to the current input
    +    // row or input variables. Splitting must therefore forward the current inputs on top of
    +    // our own locals; passing only `arrVals` and `biggestCardinality` would leave those
    +    // references undefined inside a split-out function.
    +    val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs(
    +      expressions = getValuesAndCardinalities,
    +      funcName = "getValuesAndCardinalities",
    +      returnType = "int",
    +      makeSplitFunction = body =>
    +        s"""
    +          |$body
    +          |return $biggestCardinality;
    +        """.stripMargin,
    +      foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"),
    +      extraArguments =
    +        ("ArrayData[]", arrVals) ::
    +        ("int", biggestCardinality) :: Nil)
    +
    +    // One snippet per input array: copy the value at position `i` into the row being built,
    +    // or null when the array is too short or the slot itself is null.
    +    val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) =>
    +      val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i)
    +      s"""
    +        |if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) {
    +        |  $currentRow[$idx] = $g;
    +        |} else {
    +        |  $currentRow[$idx] = null;
    +        |}
    +      """.stripMargin
    +    }
    +
    +    // These snippets only touch the three locals listed below, so a plain split is enough.
    +    val getValueForTypeSplitted = ctx.splitExpressions(
    +      expressions = getValueForType,
    +      funcName = "extractValue",
    +      arguments =
    +        ("int", i) ::
    +        ("Object[]", currentRow) ::
    +        ("ArrayData[]", arrVals) :: Nil)
    +
    +    val initVariables = s"""
    +      |ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
    +      |int $biggestCardinality = 0;
    +      |${CodeGenerator.javaType(dataType)} ${ev.value} = null;
    +    """.stripMargin
    +
    +    ev.copy(code"""
    +      |$initVariables
    +      |$splittedGetValuesAndCardinalities
    +      |boolean ${ev.isNull} = $biggestCardinality == -1;
    +      |if (!${ev.isNull}) {
    +      |  Object[] $args = new Object[$biggestCardinality];
    +      |  for (int $i = 0; $i < $biggestCardinality; $i ++) {
    +      |    Object[] $currentRow = new Object[$numberOfArrays];
    +      |    $getValueForTypeSplitted
    +      |    $args[$i] = new $genericInternalRow($currentRow);
    +      |  }
    +      |  ${ev.value} = new $genericArrayData($args);
    +      |}
    +    """.stripMargin)
    +  }
    +
    +  /**
    +   * Picks the code generator matching the number of input arrays: the empty-input variant
    +   * when there are none, the general variant otherwise.
    +   */
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    +    numberOfArrays match {
    +      case 0 => emptyInputGenCode(ev)
    +      case _ => nonEmptyInputGenCode(ctx, ev)
    +    }
    +
    +  /**
    +   * Interpreted evaluation. Returns null when any input array is null; otherwise returns an
    +   * array with as many rows as the longest input, where row N holds the N-th value of every
    +   * input (null for inputs that are shorter or contain a null at that position).
    +   */
    +  override def eval(input: InternalRow): Any = {
    +    val evaluated = children.map(_.eval(input).asInstanceOf[ArrayData])
    +    if (evaluated.contains(null)) {
    +      null
    +    } else {
    +      // Lengths are never negative, so folding from 0 also covers the no-input case.
    +      val longest = evaluated.foldLeft(0)((soFar, arr) => math.max(soFar, arr.numElements()))
    +      val indexedArrays: Seq[(ArrayData, Int)] = evaluated.zipWithIndex
    +
    +      val rows = Array.tabulate[InternalRow](longest) { pos =>
    +        val fields: Seq[Object] = indexedArrays.map { case (arr, idx) =>
    +          if (pos >= arr.numElements() || arr.isNullAt(pos)) {
    +            null
    +          } else {
    +            arr.get(pos, arrayElementTypes(idx))
    +          }
    +        }
    +        InternalRow.apply(fields: _*)
    +      }
    +      new GenericArrayData(rows)
    +    }
    +  }
    --- End diff --
    
    Done!


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org