You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/01/05 06:02:49 UTC
spark git commit: [SPARK-22825][SQL] Fix incorrect results of Casting
Array to String
Repository: spark
Updated Branches:
refs/heads/master df7fc3ef3 -> 52fc5c17d
[SPARK-22825][SQL] Fix incorrect results of Casting Array to String
## What changes were proposed in this pull request?
This PR fixes incorrect results when casting arrays to strings. Before this change:
```
scala> val df = spark.range(10).select('id.cast("integer")).agg(collect_list('id).as('ids))
scala> df.write.saveAsTable("t")
scala> sql("SELECT cast(ids as String) FROM t").show(false)
+------------------------------------------------------------------+
|ids                                                               |
+------------------------------------------------------------------+
|org.apache.spark.sql.catalyst.expressions.UnsafeArrayData@8bc285df|
+------------------------------------------------------------------+
```
With this PR, the result becomes:
```
+------------------------------+
|ids                           |
+------------------------------+
|[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]|
+------------------------------+
```
## How was this patch tested?
Added tests in `CastSuite` and `SQLQuerySuite`.
Author: Takeshi Yamamuro <ya...@apache.org>
Closes #20024 from maropu/SPARK-22825.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/52fc5c17
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/52fc5c17
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/52fc5c17
Branch: refs/heads/master
Commit: 52fc5c17d9d784b846149771b398e741621c0b5c
Parents: df7fc3e
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Fri Jan 5 14:02:21 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Jan 5 14:02:21 2018 +0800
----------------------------------------------------------------------
.../expressions/codegen/UTF8StringBuilder.java | 78 ++++++++++++++++++++
.../spark/sql/catalyst/expressions/Cast.scala | 68 +++++++++++++++++
.../sql/catalyst/expressions/CastSuite.scala | 25 +++++++
.../org/apache/spark/sql/SQLQuerySuite.scala | 2 -
4 files changed, 171 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/52fc5c17/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
new file mode 100644
index 0000000..f0f66ba
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated
+ * {@link UTF8String} at the end.
+ */
+public class UTF8StringBuilder {
+
+ private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
+
+ private byte[] buffer;
+ private int cursor = Platform.BYTE_ARRAY_OFFSET;
+
+ public UTF8StringBuilder() {
+ // Since initial buffer size is 16 in `StringBuilder`, we set the same size here
+ this.buffer = new byte[16];
+ }
+
+ // Grows the buffer by at least `neededSize`
+ private void grow(int neededSize) {
+ if (neededSize > ARRAY_MAX - totalSize()) {
+ throw new UnsupportedOperationException(
+ "Cannot grow internal buffer by size " + neededSize + " because the size after growing " +
+ "exceeds size limitation " + ARRAY_MAX);
+ }
+ final int length = totalSize() + neededSize;
+ if (buffer.length < length) {
+ int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX;
+ final byte[] tmp = new byte[newLength];
+ Platform.copyMemory(
+ buffer,
+ Platform.BYTE_ARRAY_OFFSET,
+ tmp,
+ Platform.BYTE_ARRAY_OFFSET,
+ totalSize());
+ buffer = tmp;
+ }
+ }
+
+ private int totalSize() {
+ return cursor - Platform.BYTE_ARRAY_OFFSET;
+ }
+
+ public void append(UTF8String value) {
+ grow(value.numBytes());
+ value.writeToMemory(buffer, cursor);
+ cursor += value.numBytes();
+ }
+
+ public void append(String value) {
+ append(UTF8String.fromString(value));
+ }
+
+ public UTF8String build() {
+ return UTF8String.fromBytes(buffer, 0, totalSize());
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/52fc5c17/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 274d881..d4fc5e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -206,6 +206,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
case TimestampType => buildCast[Long](_,
t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
+ case ArrayType(et, _) =>
+ buildCast[ArrayData](_, array => {
+ val builder = new UTF8StringBuilder
+ builder.append("[")
+ if (array.numElements > 0) {
+ val toUTF8String = castToString(et)
+ if (!array.isNullAt(0)) {
+ builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String])
+ }
+ var i = 1
+ while (i < array.numElements) {
+ builder.append(",")
+ if (!array.isNullAt(i)) {
+ builder.append(" ")
+ builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String])
+ }
+ i += 1
+ }
+ }
+ builder.append("]")
+ builder.build()
+ })
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
}
@@ -597,6 +619,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
"""
}
+ private def writeArrayToStringBuilder(
+ et: DataType,
+ array: String,
+ buffer: String,
+ ctx: CodegenContext): String = {
+ val elementToStringCode = castToStringCode(et, ctx)
+ val funcName = ctx.freshName("elementToString")
+ val elementToStringFunc = ctx.addNewFunction(funcName,
+ s"""
+ |private UTF8String $funcName(${ctx.javaType(et)} element) {
+ | UTF8String elementStr = null;
+ | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
+ | return elementStr;
+ |}
+ """.stripMargin)
+
+ val loopIndex = ctx.freshName("loopIndex")
+ s"""
+ |$buffer.append("[");
+ |if ($array.numElements() > 0) {
+ | if (!$array.isNullAt(0)) {
+ | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
+ | }
+ | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
+ | $buffer.append(",");
+ | if (!$array.isNullAt($loopIndex)) {
+ | $buffer.append(" ");
+ | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
+ | }
+ | }
+ |}
+ |$buffer.append("]");
+ """.stripMargin
+ }
+
private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case BinaryType =>
@@ -608,6 +665,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val tz = ctx.addReferenceObj("timeZone", timeZone)
(c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
+ case ArrayType(et, _) =>
+ (c, evPrim, evNull) => {
+ val buffer = ctx.freshName("buffer")
+ val bufferClass = classOf[UTF8StringBuilder].getName
+ val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx)
+ s"""
+ |$bufferClass $buffer = new $bufferClass();
+ |$writeArrayElemCode;
+ |$evPrim = $buffer.build();
+ """.stripMargin
+ }
case _ =>
(c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
}
http://git-wip-us.apache.org/repos/asf/spark/blob/52fc5c17/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 1dd040e..e3ed717 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -853,4 +853,29 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
cast("2", LongType).genCode(ctx)
assert(ctx.inlinedMutableStates.length == 0)
}
+
+ test("SPARK-22825 Cast array to string") {
+ val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
+ checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
+ val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
+ checkEvaluation(ret2, "[ab, cde, f]")
+ val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType)
+ checkEvaluation(ret3, "[ab,, c]")
+ val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
+ checkEvaluation(ret4, "[ab, cde, f]")
+ val ret5 = cast(
+ Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
+ StringType)
+ checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]")
+ val ret6 = cast(
+ Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)),
+ StringType)
+ checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
+ val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
+ checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]")
+ val ret8 = cast(
+ Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
+ StringType)
+ checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/52fc5c17/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5e07728..96bf65f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org