You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by kiszk <gi...@git.apache.org> on 2018/05/04 14:19:55 UTC
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/21040#discussion_r186097518
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
@@ -378,6 +378,138 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}
+/**
+ * Slices an array according to the requested start index and length
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2);
+ [2,3]
+ > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2);
+ [3,4]
+ """, since = "2.4.0")
+// scalastyle:on line.size.limit
+case class Slice(x: Expression, start: Expression, length: Expression)
+ extends TernaryExpression with ImplicitCastInputTypes {
+
+ override def dataType: DataType = x.dataType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)
+
+ override def children: Seq[Expression] = Seq(x, start, length)
+
+ lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
+
+ override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
+ val startInt = startVal.asInstanceOf[Int]
+ val lengthInt = lengthVal.asInstanceOf[Int]
+ val arr = xVal.asInstanceOf[ArrayData]
+ val startIndex = if (startInt == 0) {
+ throw new RuntimeException(
+ s"Unexpected value for start in function $prettyName: SQL array indices start at 1.")
+ } else if (startInt < 0) {
+ startInt + arr.numElements()
+ } else {
+ startInt - 1
+ }
+ if (lengthInt < 0) {
+ throw new RuntimeException(s"Unexpected value for length in function $prettyName: " +
+ "length must be greater than or equal to 0.")
+ }
+ // startIndex can be negative if start is negative and its absolute value is greater than the
+ // number of elements in the array
+ if (startIndex < 0 || startIndex >= arr.numElements()) {
+ return new GenericArrayData(Array.empty[AnyRef])
+ }
+ val data = arr.toSeq[AnyRef](elementType)
+ new GenericArrayData(data.slice(startIndex, startIndex + lengthInt))
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (x, start, length) => {
+ val startIdx = ctx.freshName("startIdx")
+ val resLength = ctx.freshName("resLength")
+ val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
+ s"""
+ |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue;
+ |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue;
+ |if ($start == 0) {
+ | throw new RuntimeException("Unexpected value for start in function $prettyName: "
+ | + "SQL array indices start at 1.");
+ |} else if ($start < 0) {
+ | $startIdx = $start + $x.numElements();
+ |} else {
+ | // arrays in SQL are 1-based instead of 0-based
+ | $startIdx = $start - 1;
+ |}
+ |if ($length < 0) {
+ | throw new RuntimeException("Unexpected value for length in function $prettyName: "
+ | + "length must be greater than or equal to 0.");
+ |} else if ($length > $x.numElements() - $startIdx) {
+ | $resLength = $x.numElements() - $startIdx;
+ |} else {
+ | $resLength = $length;
+ |}
+ |${genCodeForResult(ctx, ev, x, startIdx, resLength)}
+ """.stripMargin
+ })
+ }
+
+ def genCodeForResult(
+ ctx: CodegenContext,
+ ev: ExprCode,
+ inputArray: String,
+ startIdx: String,
+ resLength: String): String = {
+ val values = ctx.freshName("values")
+ val i = ctx.freshName("i")
+ val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx")
+ if (!CodeGenerator.isPrimitiveType(elementType)) {
+ val arrayClass = classOf[GenericArrayData].getName
+ s"""
+ |Object[] $values;
+ |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
+ | $values = new Object[0];
+ |} else {
+ | $values = new Object[$resLength];
+ | for (int $i = 0; $i < $resLength; $i ++) {
+ | $values[$i] = $getValue;
+ | }
+ |}
+ |${ev.value} = new $arrayClass($values);
+ """.stripMargin
+ } else {
+ val sizeInBytes = ctx.freshName("sizeInBytes")
+ val bytesArray = ctx.freshName("bytesArray")
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+ s"""
+ |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
+ | $resLength = 0;
+ |}
+ |${CodeGenerator.JAVA_INT} $sizeInBytes =
+ | UnsafeArrayData.calculateHeaderPortionInBytes($resLength) +
+ | ${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
+ | ${elementType.defaultSize} * $resLength);
+ |byte[] $bytesArray = new byte[$sizeInBytes];
--- End diff --
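What happens if `sizeInBytes` is larger than `Integer.MAX_VALUE`? For example, with `0x7000_0000` long elements the element count still fits in an `int`, but the byte size (8 bytes per element plus the header) overflows the `int` used for `sizeInBytes`. In this case, `GenericArrayData` or `long[]` can still hold these elements. WDYT?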
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org