You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2018/04/17 13:09:41 UTC
spark git commit: [SPARK-23875][SQL] Add IndexedSeq wrapper for
ArrayData
Repository: spark
Updated Branches:
refs/heads/master 05ae74778 -> 30ffb53ca
[SPARK-23875][SQL] Add IndexedSeq wrapper for ArrayData
## What changes were proposed in this pull request?
We don't have a good way to sequentially access `UnsafeArrayData` with a common interface such as `Seq`. An example is `MapObject` where we need to access several sequence collection types together. But `UnsafeArrayData` doesn't implement `ArrayData.array`. Calling `toArray` will copy the entire array. We can provide an `IndexedSeq` wrapper for `ArrayData`, so we can avoid copying the entire array.
## How was this patch tested?
Added test.
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #20984 from viirya/SPARK-23875.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/30ffb53c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/30ffb53c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/30ffb53c
Branch: refs/heads/master
Commit: 30ffb53cad84283b4f7694bfd60bdd7e1101b04e
Parents: 05ae747
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Tue Apr 17 15:09:36 2018 +0200
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Tue Apr 17 15:09:36 2018 +0200
----------------------------------------------------------------------
.../catalyst/expressions/objects/objects.scala | 2 +-
.../spark/sql/catalyst/util/ArrayData.scala | 30 +++++-
.../util/ArrayDataIndexedSeqSuite.scala | 100 +++++++++++++++++++
3 files changed, 130 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/30ffb53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 77802e8..72b202b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -708,7 +708,7 @@ case class MapObjects private(
}
}
case ArrayType(et, _) =>
- _.asInstanceOf[ArrayData].array
+ _.asInstanceOf[ArrayData].toSeq[Any](et)
}
private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
http://git-wip-us.apache.org/repos/asf/spark/blob/30ffb53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index 9beef41..2cf59d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.util
import scala.reflect.ClassTag
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
object ArrayData {
def toArrayData(input: Any): ArrayData = input match {
@@ -42,6 +43,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
def array: Array[Any]
+ def toSeq[T](dataType: DataType): IndexedSeq[T] =
+ new ArrayDataIndexedSeq[T](this, dataType)
+
def setNullAt(i: Int): Unit
def update(i: Int, value: Any): Unit
@@ -164,3 +168,27 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
}
}
}
+
/**
 * An `IndexedSeq` view over an `ArrayData`. Elements are read lazily through the
 * type-specialized accessor, so wrapping never copies the underlying array. Notice that
 * if the original `ArrayData` is a primitive array and contains null elements, it is
 * better to ask for `IndexedSeq[Any]`, instead of `IndexedSeq[Int]`, in order to keep
 * the null elements.
 */
class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] {

  // Resolved once per wrapper so each `apply` call skips re-dispatching on the data type.
  private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType)

  override def length: Int = arrayData.numElements()

  override def apply(idx: Int): T = {
    // Guard clause: reject out-of-range indices up front.
    if (idx < 0 || idx >= arrayData.numElements()) {
      throw new IndexOutOfBoundsException(
        s"Index $idx must be between 0 and the length of the ArrayData.")
    }
    // Report nulls as null rather than the accessor's per-type default value.
    if (arrayData.isNullAt(idx)) {
      null.asInstanceOf[T]
    } else {
      accessor(arrayData, idx).asInstanceOf[T]
    }
  }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/30ffb53c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
new file mode 100644
index 0000000..6400898
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection}
+import org.apache.spark.sql.types._
+
class ArrayDataIndexedSeqSuite extends SparkFunSuite {

  /**
   * Verifies that `arrayData` and its `toSeq` wrapper both agree element-wise with the
   * expected `array` (including nulls), and that the wrapper rejects out-of-range indices
   * with the documented message.
   */
  private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = {
    assert(arrayData.numElements == array.length)
    // `foreach` instead of `map`: the traversal is purely for its assertion side effects.
    array.zipWithIndex.foreach { case (e, i) =>
      if (e != null) {
        elementDt match {
          // For NaN, etc.: `equals` treats NaN as equal to itself, unlike `===`.
          case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e))
          case _ => assert(arrayData.get(i, elementDt) === e)
        }
      } else {
        assert(arrayData.isNullAt(i))
      }
    }

    val seq = arrayData.toSeq[Any](elementDt)
    array.zipWithIndex.foreach { case (e, i) =>
      if (e != null) {
        elementDt match {
          // For NaN, etc.
          case FloatType | DoubleType => assert(seq(i).equals(e))
          case _ => assert(seq(i) === e)
        }
      } else {
        assert(seq(i) == null)
      }
    }

    // Previously the `contains(...)` results were discarded Booleans, so a wrong message
    // could never fail the test. Wrap the checks in `assert` to actually enforce them.
    assert(intercept[IndexOutOfBoundsException] {
      seq(-1)
    }.getMessage().contains("must be between 0 and the length of the ArrayData."))

    assert(intercept[IndexOutOfBoundsException] {
      seq(seq.length)
    }.getMessage().contains("must be between 0 and the length of the ArrayData."))
  }

  /**
   * Registers, for every supported element type (nullable and non-nullable arrays),
   * one test against `UnsafeArrayData` and one against `GenericArrayData`, comparing
   * each to the fully-copied `toArray` result. Seeded `Random` keeps runs deterministic.
   */
  private def testArrayData(): Unit = {
    val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
      DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType,
      CalendarIntervalType, new ExamplePointUDT())
    val arrayTypes = elementTypes.flatMap { elementType =>
      Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true))
    }
    val random = new Random(100)
    arrayTypes.foreach { dt =>
      val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil)
      val row = RandomDataGenerator.randomRow(random, schema)
      val rowConverter = RowEncoder(schema)
      val internalRow = rowConverter.toRow(row)

      val unsafeRowConverter = UnsafeProjection.create(schema)
      val safeRowConverter = FromUnsafeProjection(schema)

      // Round-trip through the unsafe format so both representations hold the same data.
      val unsafeRow = unsafeRowConverter(internalRow)
      val safeRow = safeRowConverter(unsafeRow)

      val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData]
      val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]

      val elementType = dt.elementType
      test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) {
        compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType))
      }

      test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) {
        compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType))
      }
    }
  }

  testArrayData()
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org