You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/01/05 06:02:49 UTC

spark git commit: [SPARK-22825][SQL] Fix incorrect results of Casting Array to String

Repository: spark
Updated Branches:
  refs/heads/master df7fc3ef3 -> 52fc5c17d


[SPARK-22825][SQL] Fix incorrect results of Casting Array to String

## What changes were proposed in this pull request?
This PR fixes incorrect results when casting arrays to strings. Before this change:
```
scala> val df = spark.range(10).select('id.cast("integer")).agg(collect_list('id).as('ids))
scala> df.write.saveAsTable("t")
scala> sql("SELECT cast(ids as String) FROM t").show(false)
+------------------------------------------------------------------+
|ids                                                               |
+------------------------------------------------------------------+
|org.apache.spark.sql.catalyst.expressions.UnsafeArrayData@8bc285df|
+------------------------------------------------------------------+
```

With this PR, the result becomes:
```
+------------------------------+
|ids                           |
+------------------------------+
|[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]|
+------------------------------+
```

## How was this patch tested?
Added tests in `CastSuite` and `SQLQuerySuite`.

Author: Takeshi Yamamuro <ya...@apache.org>

Closes #20024 from maropu/SPARK-22825.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/52fc5c17
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/52fc5c17
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/52fc5c17

Branch: refs/heads/master
Commit: 52fc5c17d9d784b846149771b398e741621c0b5c
Parents: df7fc3e
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Fri Jan 5 14:02:21 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Jan 5 14:02:21 2018 +0800

----------------------------------------------------------------------
 .../expressions/codegen/UTF8StringBuilder.java  | 78 ++++++++++++++++++++
 .../spark/sql/catalyst/expressions/Cast.scala   | 68 +++++++++++++++++
 .../sql/catalyst/expressions/CastSuite.scala    | 25 +++++++
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  2 -
 4 files changed, 171 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/52fc5c17/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
new file mode 100644
index 0000000..f0f66ba
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated
+ * {@link UTF8String} at the end.
+ */
+public class UTF8StringBuilder {
+
+  private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
+
+  private byte[] buffer;
+  private int cursor = Platform.BYTE_ARRAY_OFFSET;
+
+  public UTF8StringBuilder() {
+    // Since initial buffer size is 16 in `StringBuilder`, we set the same size here
+    this.buffer = new byte[16];
+  }
+
+  // Grows the buffer by at least `neededSize`
+  private void grow(int neededSize) {
+    if (neededSize > ARRAY_MAX - totalSize()) {
+      throw new UnsupportedOperationException(
+        "Cannot grow internal buffer by size " + neededSize + " because the size after growing " +
+          "exceeds size limitation " + ARRAY_MAX);
+    }
+    final int length = totalSize() + neededSize;
+    if (buffer.length < length) {
+      int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX;
+      final byte[] tmp = new byte[newLength];
+      Platform.copyMemory(
+        buffer,
+        Platform.BYTE_ARRAY_OFFSET,
+        tmp,
+        Platform.BYTE_ARRAY_OFFSET,
+        totalSize());
+      buffer = tmp;
+    }
+  }
+
+  private int totalSize() {
+    return cursor - Platform.BYTE_ARRAY_OFFSET;
+  }
+
+  public void append(UTF8String value) {
+    grow(value.numBytes());
+    value.writeToMemory(buffer, cursor);
+    cursor += value.numBytes();
+  }
+
+  public void append(String value) {
+    append(UTF8String.fromString(value));
+  }
+
+  public UTF8String build() {
+    return UTF8String.fromBytes(buffer, 0, totalSize());
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/52fc5c17/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 274d881..d4fc5e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -206,6 +206,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
     case TimestampType => buildCast[Long](_,
       t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
+    case ArrayType(et, _) =>
+      buildCast[ArrayData](_, array => {
+        val builder = new UTF8StringBuilder
+        builder.append("[")
+        if (array.numElements > 0) {
+          val toUTF8String = castToString(et)
+          if (!array.isNullAt(0)) {
+            builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String])
+          }
+          var i = 1
+          while (i < array.numElements) {
+            builder.append(",")
+            if (!array.isNullAt(i)) {
+              builder.append(" ")
+              builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String])
+            }
+            i += 1
+          }
+        }
+        builder.append("]")
+        builder.build()
+      })
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
   }
 
@@ -597,6 +619,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     """
   }
 
+  private def writeArrayToStringBuilder(
+      et: DataType,
+      array: String,
+      buffer: String,
+      ctx: CodegenContext): String = {
+    val elementToStringCode = castToStringCode(et, ctx)
+    val funcName = ctx.freshName("elementToString")
+    val elementToStringFunc = ctx.addNewFunction(funcName,
+      s"""
+         |private UTF8String $funcName(${ctx.javaType(et)} element) {
+         |  UTF8String elementStr = null;
+         |  ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
+         |  return elementStr;
+         |}
+       """.stripMargin)
+
+    val loopIndex = ctx.freshName("loopIndex")
+    s"""
+       |$buffer.append("[");
+       |if ($array.numElements() > 0) {
+       |  if (!$array.isNullAt(0)) {
+       |    $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
+       |  }
+       |  for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
+       |    $buffer.append(",");
+       |    if (!$array.isNullAt($loopIndex)) {
+       |      $buffer.append(" ");
+       |      $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
+       |    }
+       |  }
+       |}
+       |$buffer.append("]");
+     """.stripMargin
+  }
+
   private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
     from match {
       case BinaryType =>
@@ -608,6 +665,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         val tz = ctx.addReferenceObj("timeZone", timeZone)
         (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
           org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
+      case ArrayType(et, _) =>
+        (c, evPrim, evNull) => {
+          val buffer = ctx.freshName("buffer")
+          val bufferClass = classOf[UTF8StringBuilder].getName
+          val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx)
+          s"""
+             |$bufferClass $buffer = new $bufferClass();
+             |$writeArrayElemCode;
+             |$evPrim = $buffer.build();
+           """.stripMargin
+        }
       case _ =>
         (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/52fc5c17/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 1dd040e..e3ed717 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -853,4 +853,29 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     cast("2", LongType).genCode(ctx)
     assert(ctx.inlinedMutableStates.length == 0)
   }
+
+  test("SPARK-22825 Cast array to string") {
+    val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
+    checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
+    val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
+    checkEvaluation(ret2, "[ab, cde, f]")
+    val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType)
+    checkEvaluation(ret3, "[ab,, c]")
+    val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
+    checkEvaluation(ret4, "[ab, cde, f]")
+    val ret5 = cast(
+      Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
+      StringType)
+    checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]")
+    val ret6 = cast(
+      Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)),
+      StringType)
+    checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
+    val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
+    checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]")
+    val ret8 = cast(
+      Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
+      StringType)
+    checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/52fc5c17/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5e07728..96bf65f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org