You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/03/28 02:09:57 UTC
spark git commit: [SPARK-19088][SQL] Optimize sequence type
deserialization codegen
Repository: spark
Updated Branches:
refs/heads/master ea361165e -> 6c70a38c2
[SPARK-19088][SQL] Optimize sequence type deserialization codegen
## What changes were proposed in this pull request?
Optimization of the arbitrary Scala sequence deserialization that was introduced by #16240.
The previous implementation constructed an intermediate array, which was then converted to the target collection type via the `to` method. This required two passes over the data in most cases.
This implementation attempts to remedy that by using `Builder`s provided by the `newBuilder` method on every Scala collection's companion object to build the resulting collection directly.
Example codegen for simple `List` (obtained using `Seq(List(1)).toDS().map(identity).queryExecution.debug.codegen`):
Before:
```
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private scala.collection.Iterator inputadapter_input;
/* 009 */ private boolean deserializetoobject_resultIsNull;
/* 010 */ private java.lang.Object[] deserializetoobject_argValue;
/* 011 */ private boolean MapObjects_loopIsNull1;
/* 012 */ private int MapObjects_loopValue0;
/* 013 */ private boolean deserializetoobject_resultIsNull1;
/* 014 */ private scala.collection.generic.CanBuildFrom deserializetoobject_argValue1;
/* 015 */ private UnsafeRow deserializetoobject_result;
/* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 018 */ private scala.collection.immutable.List mapelements_argValue;
/* 019 */ private UnsafeRow mapelements_result;
/* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 022 */ private scala.collection.immutable.List serializefromobject_argValue;
/* 023 */ private UnsafeRow serializefromobject_result;
/* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 025 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 027 */
/* 028 */ public GeneratedIterator(Object[] references) {
/* 029 */ this.references = references;
/* 030 */ }
/* 031 */
/* 032 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 033 */ partitionIndex = index;
/* 034 */ this.inputs = inputs;
/* 035 */ inputadapter_input = inputs[0];
/* 036 */
/* 037 */ deserializetoobject_result = new UnsafeRow(1);
/* 038 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 039 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 040 */
/* 041 */ mapelements_result = new UnsafeRow(1);
/* 042 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 043 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 044 */
/* 045 */ serializefromobject_result = new UnsafeRow(1);
/* 046 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 047 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 048 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 049 */
/* 050 */ }
/* 051 */
/* 052 */ protected void processNext() throws java.io.IOException {
/* 053 */ while (inputadapter_input.hasNext() && !stopEarly()) {
/* 054 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 055 */ ArrayData inputadapter_value = inputadapter_row.getArray(0);
/* 056 */
/* 057 */ deserializetoobject_resultIsNull = false;
/* 058 */
/* 059 */ if (!deserializetoobject_resultIsNull) {
/* 060 */ ArrayData deserializetoobject_value3 = null;
/* 061 */
/* 062 */ if (!false) {
/* 063 */ Integer[] deserializetoobject_convertedArray = null;
/* 064 */ int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 065 */ deserializetoobject_convertedArray = new Integer[deserializetoobject_dataLength];
/* 066 */
/* 067 */ int deserializetoobject_loopIndex = 0;
/* 068 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 069 */ MapObjects_loopValue0 = (int) (inputadapter_value.getInt(deserializetoobject_loopIndex));
/* 070 */ MapObjects_loopIsNull1 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 071 */
/* 072 */ if (MapObjects_loopIsNull1) {
/* 073 */ throw new RuntimeException(((java.lang.String) references[0]));
/* 074 */ }
/* 075 */ if (false) {
/* 076 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null;
/* 077 */ } else {
/* 078 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue0;
/* 079 */ }
/* 080 */
/* 081 */ deserializetoobject_loopIndex += 1;
/* 082 */ }
/* 083 */
/* 084 */ deserializetoobject_value3 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray);
/* 085 */ }
/* 086 */ boolean deserializetoobject_isNull2 = true;
/* 087 */ java.lang.Object[] deserializetoobject_value2 = null;
/* 088 */ if (!false) {
/* 089 */ deserializetoobject_isNull2 = false;
/* 090 */ if (!deserializetoobject_isNull2) {
/* 091 */ Object deserializetoobject_funcResult = null;
/* 092 */ deserializetoobject_funcResult = deserializetoobject_value3.array();
/* 093 */ if (deserializetoobject_funcResult == null) {
/* 094 */ deserializetoobject_isNull2 = true;
/* 095 */ } else {
/* 096 */ deserializetoobject_value2 = (java.lang.Object[]) deserializetoobject_funcResult;
/* 097 */ }
/* 098 */
/* 099 */ }
/* 100 */ deserializetoobject_isNull2 = deserializetoobject_value2 == null;
/* 101 */ }
/* 102 */ deserializetoobject_resultIsNull = deserializetoobject_isNull2;
/* 103 */ deserializetoobject_argValue = deserializetoobject_value2;
/* 104 */ }
/* 105 */
/* 106 */ boolean deserializetoobject_isNull1 = deserializetoobject_resultIsNull;
/* 107 */ final scala.collection.Seq deserializetoobject_value1 = deserializetoobject_resultIsNull ? null : scala.collection.mutable.WrappedArray.make(deserializetoobject_argValue);
/* 108 */ deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 109 */ boolean deserializetoobject_isNull = true;
/* 110 */ scala.collection.immutable.List deserializetoobject_value = null;
/* 111 */ if (!deserializetoobject_isNull1) {
/* 112 */ deserializetoobject_resultIsNull1 = false;
/* 113 */
/* 114 */ if (!deserializetoobject_resultIsNull1) {
/* 115 */ boolean deserializetoobject_isNull6 = false;
/* 116 */ final scala.collection.generic.CanBuildFrom deserializetoobject_value6 = false ? null : scala.collection.immutable.List.canBuildFrom();
/* 117 */ deserializetoobject_isNull6 = deserializetoobject_value6 == null;
/* 118 */ deserializetoobject_resultIsNull1 = deserializetoobject_isNull6;
/* 119 */ deserializetoobject_argValue1 = deserializetoobject_value6;
/* 120 */ }
/* 121 */
/* 122 */ deserializetoobject_isNull = deserializetoobject_resultIsNull1;
/* 123 */ if (!deserializetoobject_isNull) {
/* 124 */ Object deserializetoobject_funcResult1 = null;
/* 125 */ deserializetoobject_funcResult1 = deserializetoobject_value1.to(deserializetoobject_argValue1);
/* 126 */ if (deserializetoobject_funcResult1 == null) {
/* 127 */ deserializetoobject_isNull = true;
/* 128 */ } else {
/* 129 */ deserializetoobject_value = (scala.collection.immutable.List) deserializetoobject_funcResult1;
/* 130 */ }
/* 131 */
/* 132 */ }
/* 133 */ deserializetoobject_isNull = deserializetoobject_value == null;
/* 134 */ }
/* 135 */
/* 136 */ boolean mapelements_isNull = true;
/* 137 */ scala.collection.immutable.List mapelements_value = null;
/* 138 */ if (!false) {
/* 139 */ mapelements_argValue = deserializetoobject_value;
/* 140 */
/* 141 */ mapelements_isNull = false;
/* 142 */ if (!mapelements_isNull) {
/* 143 */ Object mapelements_funcResult = null;
/* 144 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 145 */ if (mapelements_funcResult == null) {
/* 146 */ mapelements_isNull = true;
/* 147 */ } else {
/* 148 */ mapelements_value = (scala.collection.immutable.List) mapelements_funcResult;
/* 149 */ }
/* 150 */
/* 151 */ }
/* 152 */ mapelements_isNull = mapelements_value == null;
/* 153 */ }
/* 154 */
/* 155 */ if (mapelements_isNull) {
/* 156 */ throw new RuntimeException(((java.lang.String) references[2]));
/* 157 */ }
/* 158 */ serializefromobject_argValue = mapelements_value;
/* 159 */
/* 160 */ final ArrayData serializefromobject_value = false ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue);
/* 161 */ serializefromobject_holder.reset();
/* 162 */
/* 163 */ // Remember the current cursor so that we can calculate how many bytes are
/* 164 */ // written later.
/* 165 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 166 */
/* 167 */ if (serializefromobject_value instanceof UnsafeArrayData) {
/* 168 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 169 */ // grow the global buffer before writing data.
/* 170 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 171 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 172 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 173 */
/* 174 */ } else {
/* 175 */ final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 176 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 177 */
/* 178 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 179 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 180 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 181 */ } else {
/* 182 */ final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 183 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 184 */ }
/* 185 */ }
/* 186 */ }
/* 187 */
/* 188 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 189 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 190 */ append(serializefromobject_result);
/* 191 */ if (shouldStop()) return;
/* 192 */ }
/* 193 */ }
/* 194 */ }
```
After:
```
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private scala.collection.Iterator inputadapter_input;
/* 009 */ private boolean CollectObjects_loopIsNull1;
/* 010 */ private int CollectObjects_loopValue0;
/* 011 */ private UnsafeRow deserializetoobject_result;
/* 012 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 013 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 014 */ private scala.collection.immutable.List mapelements_argValue;
/* 015 */ private UnsafeRow mapelements_result;
/* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 018 */ private scala.collection.immutable.List serializefromobject_argValue;
/* 019 */ private UnsafeRow serializefromobject_result;
/* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 023 */
/* 024 */ public GeneratedIterator(Object[] references) {
/* 025 */ this.references = references;
/* 026 */ }
/* 027 */
/* 028 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 029 */ partitionIndex = index;
/* 030 */ this.inputs = inputs;
/* 031 */ inputadapter_input = inputs[0];
/* 032 */
/* 033 */ deserializetoobject_result = new UnsafeRow(1);
/* 034 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 035 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 036 */
/* 037 */ mapelements_result = new UnsafeRow(1);
/* 038 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 039 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 040 */
/* 041 */ serializefromobject_result = new UnsafeRow(1);
/* 042 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 043 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 044 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 045 */
/* 046 */ }
/* 047 */
/* 048 */ protected void processNext() throws java.io.IOException {
/* 049 */ while (inputadapter_input.hasNext() && !stopEarly()) {
/* 050 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 051 */ ArrayData inputadapter_value = inputadapter_row.getArray(0);
/* 052 */
/* 053 */ scala.collection.immutable.List deserializetoobject_value = null;
/* 054 */
/* 055 */ if (!false) {
/* 056 */ int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 057 */ scala.collection.mutable.Builder CollectObjects_builderValue2 = scala.collection.immutable.List$.MODULE$.newBuilder();
/* 058 */ CollectObjects_builderValue2.sizeHint(deserializetoobject_dataLength);
/* 059 */
/* 060 */ int deserializetoobject_loopIndex = 0;
/* 061 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 062 */ CollectObjects_loopValue0 = (int) (inputadapter_value.getInt(deserializetoobject_loopIndex));
/* 063 */ CollectObjects_loopIsNull1 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 064 */
/* 065 */ if (CollectObjects_loopIsNull1) {
/* 066 */ throw new RuntimeException(((java.lang.String) references[0]));
/* 067 */ }
/* 068 */ if (false) {
/* 069 */ CollectObjects_builderValue2.$plus$eq(null);
/* 070 */ } else {
/* 071 */ CollectObjects_builderValue2.$plus$eq(CollectObjects_loopValue0);
/* 072 */ }
/* 073 */
/* 074 */ deserializetoobject_loopIndex += 1;
/* 075 */ }
/* 076 */
/* 077 */ deserializetoobject_value = (scala.collection.immutable.List) CollectObjects_builderValue2.result();
/* 078 */ }
/* 079 */
/* 080 */ boolean mapelements_isNull = true;
/* 081 */ scala.collection.immutable.List mapelements_value = null;
/* 082 */ if (!false) {
/* 083 */ mapelements_argValue = deserializetoobject_value;
/* 084 */
/* 085 */ mapelements_isNull = false;
/* 086 */ if (!mapelements_isNull) {
/* 087 */ Object mapelements_funcResult = null;
/* 088 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 089 */ if (mapelements_funcResult == null) {
/* 090 */ mapelements_isNull = true;
/* 091 */ } else {
/* 092 */ mapelements_value = (scala.collection.immutable.List) mapelements_funcResult;
/* 093 */ }
/* 094 */
/* 095 */ }
/* 096 */ mapelements_isNull = mapelements_value == null;
/* 097 */ }
/* 098 */
/* 099 */ if (mapelements_isNull) {
/* 100 */ throw new RuntimeException(((java.lang.String) references[2]));
/* 101 */ }
/* 102 */ serializefromobject_argValue = mapelements_value;
/* 103 */
/* 104 */ final ArrayData serializefromobject_value = false ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue);
/* 105 */ serializefromobject_holder.reset();
/* 106 */
/* 107 */ // Remember the current cursor so that we can calculate how many bytes are
/* 108 */ // written later.
/* 109 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 110 */
/* 111 */ if (serializefromobject_value instanceof UnsafeArrayData) {
/* 112 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 113 */ // grow the global buffer before writing data.
/* 114 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 115 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 116 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 117 */
/* 118 */ } else {
/* 119 */ final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 120 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 121 */
/* 122 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 123 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 124 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 125 */ } else {
/* 126 */ final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 127 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 128 */ }
/* 129 */ }
/* 130 */ }
/* 131 */
/* 132 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 133 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 134 */ append(serializefromobject_result);
/* 135 */ if (shouldStop()) return;
/* 136 */ }
/* 137 */ }
/* 138 */ }
```
Benchmark results before:
```
OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH
AMD A10-4600M APU with Radeon(tm) HD Graphics
collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Seq 269 / 370 0.0 269125.8 1.0X
List 154 / 176 0.0 154453.5 1.7X
mutable.Queue 210 / 233 0.0 209691.6 1.3X
```
Benchmark results after:
```
OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH
AMD A10-4600M APU with Radeon(tm) HD Graphics
collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Seq 255 / 316 0.0 254697.3 1.0X
List 152 / 177 0.0 152410.0 1.7X
mutable.Queue 213 / 235 0.0 213470.0 1.2X
```
## How was this patch tested?
```bash
./build/mvn -DskipTests clean package && ./dev/run-tests
```
Additionally in Spark Shell:
```scala
case class QueueClass(q: scala.collection.immutable.Queue[Int])
spark.createDataset(Seq(List(1,2,3))).map(x => QueueClass(scala.collection.immutable.Queue(x: _*))).map(_.q.dequeue).collect
```
Author: Michal Senkyr <mi...@gmail.com>
Closes #16541 from michalsenkyr/dataset-seq-builder.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6c70a38c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6c70a38c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6c70a38c
Branch: refs/heads/master
Commit: 6c70a38c2e60e1b69a310aee1a92ee0b3815c02d
Parents: ea36116
Author: Michal Senkyr <mi...@gmail.com>
Authored: Tue Mar 28 10:09:49 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Mar 28 10:09:49 2017 +0800
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 51 ++--------------
.../catalyst/expressions/objects/objects.scala | 64 +++++++++++++++-----
.../sql/catalyst/ScalaReflectionSuite.scala | 8 ---
3 files changed, 54 insertions(+), 69 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/6c70a38c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c4af284..1c7720a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -307,54 +307,11 @@ object ScalaReflection extends ScalaReflection {
}
}
- val array = Invoke(
- MapObjects(mapFunction, getPath, dataType),
- "array",
- ObjectType(classOf[Array[Any]]))
-
- val wrappedArray = StaticInvoke(
- scala.collection.mutable.WrappedArray.getClass,
- ObjectType(classOf[Seq[_]]),
- "make",
- array :: Nil)
-
- if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
- wrappedArray
- } else {
- // Convert to another type using `to`
- val cls = mirror.runtimeClass(t.typeSymbol.asClass)
- import scala.collection.generic.CanBuildFrom
- import scala.reflect.ClassTag
-
- // Some canBuildFrom methods take an implicit ClassTag parameter
- val cbfParams = try {
- cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
- StaticInvoke(
- ClassTag.getClass,
- ObjectType(classOf[ClassTag[_]]),
- "apply",
- StaticInvoke(
- cls,
- ObjectType(classOf[Class[_]]),
- "getClass"
- ) :: Nil
- ) :: Nil
- } catch {
- case _: NoSuchMethodException => Nil
- }
-
- Invoke(
- wrappedArray,
- "to",
- ObjectType(cls),
- StaticInvoke(
- cls,
- ObjectType(classOf[CanBuildFrom[_, _, _]]),
- "canBuildFrom",
- cbfParams
- ) :: Nil
- )
+ val cls = t.dealias.companion.decl(TermName("newBuilder")) match {
+ case NoSymbol => classOf[Seq[_]]
+ case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
+ MapObjects(mapFunction, getPath, dataType, Some(cls))
case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
http://git-wip-us.apache.org/repos/asf/spark/blob/6c70a38c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 771ac28..bb584f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
import java.lang.reflect.Modifier
+import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag
@@ -429,24 +430,34 @@ object MapObjects {
* @param function The function applied on the collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param elementType The data type of elements in the collection.
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType)
+ * or None (returning ArrayType)
*/
def apply(
function: Expression => Expression,
inputData: Expression,
- elementType: DataType): MapObjects = {
- val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
- val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
+ elementType: DataType,
+ customCollectionCls: Option[Class[_]] = None): MapObjects = {
+ val id = curId.getAndIncrement()
+ val loopValue = s"MapObjects_loopValue$id"
+ val loopIsNull = s"MapObjects_loopIsNull$id"
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
- MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
+ val builderValue = s"MapObjects_builderValue$id"
+ MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
+ customCollectionCls, builderValue)
}
}
/**
* Applies the given expression to every element of a collection of items, returning the result
- * as an ArrayType. This is similar to a typical map operation, but where the lambda function
- * is expressed using catalyst expressions.
+ * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
+ * function is expressed using catalyst expressions.
+ *
+ * The type of the result is determined as follows:
+ * - ArrayType - when customCollectionCls is None
+ * - ObjectType(collection) - when customCollectionCls contains a collection class
*
- * The following collection ObjectTypes are currently supported:
+ * The following collection ObjectTypes are currently supported on input:
* Seq, Array, ArrayData, java.util.List
*
* @param loopValue the name of the loop variable that used when iterate the collection, and used
@@ -458,13 +469,19 @@ object MapObjects {
* @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
* to handle collection elements.
* @param inputData An expression that when evaluated returns a collection object.
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType)
+ * or None (returning ArrayType)
+ * @param builderValue The name of the builder variable used to construct the resulting collection
+ * (used only when returning ObjectType)
*/
case class MapObjects private(
loopValue: String,
loopIsNull: String,
loopVarDataType: DataType,
lambdaFunction: Expression,
- inputData: Expression) extends Expression with NonSQLExpression {
+ inputData: Expression,
+ customCollectionCls: Option[Class[_]],
+ builderValue: String) extends Expression with NonSQLExpression {
override def nullable: Boolean = inputData.nullable
@@ -474,7 +491,8 @@ case class MapObjects private(
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
override def dataType: DataType =
- ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
+ customCollectionCls.map(ObjectType.apply).getOrElse(
+ ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val elementJavaType = ctx.javaType(loopVarDataType)
@@ -557,15 +575,33 @@ case class MapObjects private(
case _ => s"$loopIsNull = $loopValue == null;"
}
+ val (initCollection, addElement, getResult): (String, String => String, String) =
+ customCollectionCls match {
+ case Some(cls) =>
+ // collection
+ val collObjectName = s"${cls.getName}$$.MODULE$$"
+ val getBuilderVar = s"$collObjectName.newBuilder()"
+
+ (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
+ $builderValue.sizeHint($dataLength);""",
+ genValue => s"$builderValue.$$plus$$eq($genValue);",
+ s"(${cls.getName}) $builderValue.result();")
+ case None =>
+ // array
+ (s"""$convertedType[] $convertedArray = null;
+ $convertedArray = $arrayConstructor;""",
+ genValue => s"$convertedArray[$loopIndex] = $genValue;",
+ s"new ${classOf[GenericArrayData].getName}($convertedArray);")
+ }
+
val code = s"""
${genInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${genInputData.isNull}) {
$determineCollectionType
- $convertedType[] $convertedArray = null;
int $dataLength = $getLength;
- $convertedArray = $arrayConstructor;
+ $initCollection
int $loopIndex = 0;
while ($loopIndex < $dataLength) {
@@ -574,15 +610,15 @@ case class MapObjects private(
${genFunction.code}
if (${genFunction.isNull}) {
- $convertedArray[$loopIndex] = null;
+ ${addElement("null")}
} else {
- $convertedArray[$loopIndex] = $genFunctionValue;
+ ${addElement(genFunctionValue)}
}
$loopIndex += 1;
}
- ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
+ ${ev.value} = $getResult
}
"""
ev.copy(code = code, isNull = genInputData.isNull)
http://git-wip-us.apache.org/repos/asf/spark/blob/6c70a38c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 650a353..70ad064 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
-
- // Check whether conversion is skipped when using WrappedArray[_] supertype
- // (would otherwise needlessly add overhead)
- import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
- val seqDeserializer = deserializerFor[Seq[Int]]
- assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
- scala.collection.mutable.WrappedArray.getClass)
- assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}
private val dataTypeForComplexData = dataTypeFor[ComplexData]
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org