You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/03/28 02:09:57 UTC

spark git commit: [SPARK-19088][SQL] Optimize sequence type deserialization codegen

Repository: spark
Updated Branches:
  refs/heads/master ea361165e -> 6c70a38c2


[SPARK-19088][SQL] Optimize sequence type deserialization codegen

## What changes were proposed in this pull request?

Optimization of arbitrary Scala sequence deserialization introduced by #16240.

The previous implementation first constructed an intermediate array, which was then converted to the target collection type by calling `to`. This required two passes over the data in most cases.

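This implementation remedies that by obtaining a `Builder` from the `newBuilder` method on the target collection's companion object and using it to build the resulting collection directly, in a single pass. If the companion object declares no `newBuilder` method, the deserializer falls back to building a `Seq`.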

Example codegen for simple `List` (obtained using `Seq(List(1)).toDS().map(identity).queryExecution.debug.codegen`):

Before:

```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean deserializetoobject_resultIsNull;
/* 010 */   private java.lang.Object[] deserializetoobject_argValue;
/* 011 */   private boolean MapObjects_loopIsNull1;
/* 012 */   private int MapObjects_loopValue0;
/* 013 */   private boolean deserializetoobject_resultIsNull1;
/* 014 */   private scala.collection.generic.CanBuildFrom deserializetoobject_argValue1;
/* 015 */   private UnsafeRow deserializetoobject_result;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 017 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 018 */   private scala.collection.immutable.List mapelements_argValue;
/* 019 */   private UnsafeRow mapelements_result;
/* 020 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 022 */   private scala.collection.immutable.List serializefromobject_argValue;
/* 023 */   private UnsafeRow serializefromobject_result;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 025 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 027 */
/* 028 */   public GeneratedIterator(Object[] references) {
/* 029 */     this.references = references;
/* 030 */   }
/* 031 */
/* 032 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 033 */     partitionIndex = index;
/* 034 */     this.inputs = inputs;
/* 035 */     inputadapter_input = inputs[0];
/* 036 */
/* 037 */     deserializetoobject_result = new UnsafeRow(1);
/* 038 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 039 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 040 */
/* 041 */     mapelements_result = new UnsafeRow(1);
/* 042 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 043 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 044 */
/* 045 */     serializefromobject_result = new UnsafeRow(1);
/* 046 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 047 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 048 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 049 */
/* 050 */   }
/* 051 */
/* 052 */   protected void processNext() throws java.io.IOException {
/* 053 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 054 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 055 */       ArrayData inputadapter_value = inputadapter_row.getArray(0);
/* 056 */
/* 057 */       deserializetoobject_resultIsNull = false;
/* 058 */
/* 059 */       if (!deserializetoobject_resultIsNull) {
/* 060 */         ArrayData deserializetoobject_value3 = null;
/* 061 */
/* 062 */         if (!false) {
/* 063 */           Integer[] deserializetoobject_convertedArray = null;
/* 064 */           int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 065 */           deserializetoobject_convertedArray = new Integer[deserializetoobject_dataLength];
/* 066 */
/* 067 */           int deserializetoobject_loopIndex = 0;
/* 068 */           while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 069 */             MapObjects_loopValue0 = (int) (inputadapter_value.getInt(deserializetoobject_loopIndex));
/* 070 */             MapObjects_loopIsNull1 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 071 */
/* 072 */             if (MapObjects_loopIsNull1) {
/* 073 */               throw new RuntimeException(((java.lang.String) references[0]));
/* 074 */             }
/* 075 */             if (false) {
/* 076 */               deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null;
/* 077 */             } else {
/* 078 */               deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue0;
/* 079 */             }
/* 080 */
/* 081 */             deserializetoobject_loopIndex += 1;
/* 082 */           }
/* 083 */
/* 084 */           deserializetoobject_value3 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray);
/* 085 */         }
/* 086 */         boolean deserializetoobject_isNull2 = true;
/* 087 */         java.lang.Object[] deserializetoobject_value2 = null;
/* 088 */         if (!false) {
/* 089 */           deserializetoobject_isNull2 = false;
/* 090 */           if (!deserializetoobject_isNull2) {
/* 091 */             Object deserializetoobject_funcResult = null;
/* 092 */             deserializetoobject_funcResult = deserializetoobject_value3.array();
/* 093 */             if (deserializetoobject_funcResult == null) {
/* 094 */               deserializetoobject_isNull2 = true;
/* 095 */             } else {
/* 096 */               deserializetoobject_value2 = (java.lang.Object[]) deserializetoobject_funcResult;
/* 097 */             }
/* 098 */
/* 099 */           }
/* 100 */           deserializetoobject_isNull2 = deserializetoobject_value2 == null;
/* 101 */         }
/* 102 */         deserializetoobject_resultIsNull = deserializetoobject_isNull2;
/* 103 */         deserializetoobject_argValue = deserializetoobject_value2;
/* 104 */       }
/* 105 */
/* 106 */       boolean deserializetoobject_isNull1 = deserializetoobject_resultIsNull;
/* 107 */       final scala.collection.Seq deserializetoobject_value1 = deserializetoobject_resultIsNull ? null : scala.collection.mutable.WrappedArray.make(deserializetoobject_argValue);
/* 108 */       deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 109 */       boolean deserializetoobject_isNull = true;
/* 110 */       scala.collection.immutable.List deserializetoobject_value = null;
/* 111 */       if (!deserializetoobject_isNull1) {
/* 112 */         deserializetoobject_resultIsNull1 = false;
/* 113 */
/* 114 */         if (!deserializetoobject_resultIsNull1) {
/* 115 */           boolean deserializetoobject_isNull6 = false;
/* 116 */           final scala.collection.generic.CanBuildFrom deserializetoobject_value6 = false ? null : scala.collection.immutable.List.canBuildFrom();
/* 117 */           deserializetoobject_isNull6 = deserializetoobject_value6 == null;
/* 118 */           deserializetoobject_resultIsNull1 = deserializetoobject_isNull6;
/* 119 */           deserializetoobject_argValue1 = deserializetoobject_value6;
/* 120 */         }
/* 121 */
/* 122 */         deserializetoobject_isNull = deserializetoobject_resultIsNull1;
/* 123 */         if (!deserializetoobject_isNull) {
/* 124 */           Object deserializetoobject_funcResult1 = null;
/* 125 */           deserializetoobject_funcResult1 = deserializetoobject_value1.to(deserializetoobject_argValue1);
/* 126 */           if (deserializetoobject_funcResult1 == null) {
/* 127 */             deserializetoobject_isNull = true;
/* 128 */           } else {
/* 129 */             deserializetoobject_value = (scala.collection.immutable.List) deserializetoobject_funcResult1;
/* 130 */           }
/* 131 */
/* 132 */         }
/* 133 */         deserializetoobject_isNull = deserializetoobject_value == null;
/* 134 */       }
/* 135 */
/* 136 */       boolean mapelements_isNull = true;
/* 137 */       scala.collection.immutable.List mapelements_value = null;
/* 138 */       if (!false) {
/* 139 */         mapelements_argValue = deserializetoobject_value;
/* 140 */
/* 141 */         mapelements_isNull = false;
/* 142 */         if (!mapelements_isNull) {
/* 143 */           Object mapelements_funcResult = null;
/* 144 */           mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 145 */           if (mapelements_funcResult == null) {
/* 146 */             mapelements_isNull = true;
/* 147 */           } else {
/* 148 */             mapelements_value = (scala.collection.immutable.List) mapelements_funcResult;
/* 149 */           }
/* 150 */
/* 151 */         }
/* 152 */         mapelements_isNull = mapelements_value == null;
/* 153 */       }
/* 154 */
/* 155 */       if (mapelements_isNull) {
/* 156 */         throw new RuntimeException(((java.lang.String) references[2]));
/* 157 */       }
/* 158 */       serializefromobject_argValue = mapelements_value;
/* 159 */
/* 160 */       final ArrayData serializefromobject_value = false ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue);
/* 161 */       serializefromobject_holder.reset();
/* 162 */
/* 163 */       // Remember the current cursor so that we can calculate how many bytes are
/* 164 */       // written later.
/* 165 */       final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 166 */
/* 167 */       if (serializefromobject_value instanceof UnsafeArrayData) {
/* 168 */         final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 169 */         // grow the global buffer before writing data.
/* 170 */         serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 171 */         ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 172 */         serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 173 */
/* 174 */       } else {
/* 175 */         final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 176 */         serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 177 */
/* 178 */         for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 179 */           if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 180 */             serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 181 */           } else {
/* 182 */             final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 183 */             serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 184 */           }
/* 185 */         }
/* 186 */       }
/* 187 */
/* 188 */       serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 189 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 190 */       append(serializefromobject_result);
/* 191 */       if (shouldStop()) return;
/* 192 */     }
/* 193 */   }
/* 194 */ }
```

After:

```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean CollectObjects_loopIsNull1;
/* 010 */   private int CollectObjects_loopValue0;
/* 011 */   private UnsafeRow deserializetoobject_result;
/* 012 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 013 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 014 */   private scala.collection.immutable.List mapelements_argValue;
/* 015 */   private UnsafeRow mapelements_result;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 017 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 018 */   private scala.collection.immutable.List serializefromobject_argValue;
/* 019 */   private UnsafeRow serializefromobject_result;
/* 020 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 023 */
/* 024 */   public GeneratedIterator(Object[] references) {
/* 025 */     this.references = references;
/* 026 */   }
/* 027 */
/* 028 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 029 */     partitionIndex = index;
/* 030 */     this.inputs = inputs;
/* 031 */     inputadapter_input = inputs[0];
/* 032 */
/* 033 */     deserializetoobject_result = new UnsafeRow(1);
/* 034 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 035 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 036 */
/* 037 */     mapelements_result = new UnsafeRow(1);
/* 038 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 039 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 040 */
/* 041 */     serializefromobject_result = new UnsafeRow(1);
/* 042 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 043 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 044 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 045 */
/* 046 */   }
/* 047 */
/* 048 */   protected void processNext() throws java.io.IOException {
/* 049 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 050 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 051 */       ArrayData inputadapter_value = inputadapter_row.getArray(0);
/* 052 */
/* 053 */       scala.collection.immutable.List deserializetoobject_value = null;
/* 054 */
/* 055 */       if (!false) {
/* 056 */         int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 057 */         scala.collection.mutable.Builder CollectObjects_builderValue2 = scala.collection.immutable.List$.MODULE$.newBuilder();
/* 058 */         CollectObjects_builderValue2.sizeHint(deserializetoobject_dataLength);
/* 059 */
/* 060 */         int deserializetoobject_loopIndex = 0;
/* 061 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 062 */           CollectObjects_loopValue0 = (int) (inputadapter_value.getInt(deserializetoobject_loopIndex));
/* 063 */           CollectObjects_loopIsNull1 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 064 */
/* 065 */           if (CollectObjects_loopIsNull1) {
/* 066 */             throw new RuntimeException(((java.lang.String) references[0]));
/* 067 */           }
/* 068 */           if (false) {
/* 069 */             CollectObjects_builderValue2.$plus$eq(null);
/* 070 */           } else {
/* 071 */             CollectObjects_builderValue2.$plus$eq(CollectObjects_loopValue0);
/* 072 */           }
/* 073 */
/* 074 */           deserializetoobject_loopIndex += 1;
/* 075 */         }
/* 076 */
/* 077 */         deserializetoobject_value = (scala.collection.immutable.List) CollectObjects_builderValue2.result();
/* 078 */       }
/* 079 */
/* 080 */       boolean mapelements_isNull = true;
/* 081 */       scala.collection.immutable.List mapelements_value = null;
/* 082 */       if (!false) {
/* 083 */         mapelements_argValue = deserializetoobject_value;
/* 084 */
/* 085 */         mapelements_isNull = false;
/* 086 */         if (!mapelements_isNull) {
/* 087 */           Object mapelements_funcResult = null;
/* 088 */           mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 089 */           if (mapelements_funcResult == null) {
/* 090 */             mapelements_isNull = true;
/* 091 */           } else {
/* 092 */             mapelements_value = (scala.collection.immutable.List) mapelements_funcResult;
/* 093 */           }
/* 094 */
/* 095 */         }
/* 096 */         mapelements_isNull = mapelements_value == null;
/* 097 */       }
/* 098 */
/* 099 */       if (mapelements_isNull) {
/* 100 */         throw new RuntimeException(((java.lang.String) references[2]));
/* 101 */       }
/* 102 */       serializefromobject_argValue = mapelements_value;
/* 103 */
/* 104 */       final ArrayData serializefromobject_value = false ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue);
/* 105 */       serializefromobject_holder.reset();
/* 106 */
/* 107 */       // Remember the current cursor so that we can calculate how many bytes are
/* 108 */       // written later.
/* 109 */       final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 110 */
/* 111 */       if (serializefromobject_value instanceof UnsafeArrayData) {
/* 112 */         final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 113 */         // grow the global buffer before writing data.
/* 114 */         serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 115 */         ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 116 */         serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 117 */
/* 118 */       } else {
/* 119 */         final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 120 */         serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 121 */
/* 122 */         for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 123 */           if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 124 */             serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 125 */           } else {
/* 126 */             final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 127 */             serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 128 */           }
/* 129 */         }
/* 130 */       }
/* 131 */
/* 132 */       serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 133 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 134 */       append(serializefromobject_result);
/* 135 */       if (shouldStop()) return;
/* 136 */     }
/* 137 */   }
/* 138 */ }
```

Benchmark results before:

```
OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH
AMD A10-4600M APU with Radeon(tm) HD Graphics
collect:                                 Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Seq                                            269 /  370          0.0      269125.8       1.0X
List                                           154 /  176          0.0      154453.5       1.7X
mutable.Queue                                  210 /  233          0.0      209691.6       1.3X
```

Benchmark results after:

```
OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH
AMD A10-4600M APU with Radeon(tm) HD Graphics
collect:                                 Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Seq                                            255 /  316          0.0      254697.3       1.0X
List                                           152 /  177          0.0      152410.0       1.7X
mutable.Queue                                  213 /  235          0.0      213470.0       1.2X
```

## How was this patch tested?

```bash
./build/mvn -DskipTests clean package && ./dev/run-tests
```

Additionally, the change was checked manually in the Spark shell:

```scala
case class QueueClass(q: scala.collection.immutable.Queue[Int])

spark.createDataset(Seq(List(1,2,3))).map(x => QueueClass(scala.collection.immutable.Queue(x: _*))).map(_.q.dequeue).collect
```

Author: Michal Senkyr <mi...@gmail.com>

Closes #16541 from michalsenkyr/dataset-seq-builder.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6c70a38c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6c70a38c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6c70a38c

Branch: refs/heads/master
Commit: 6c70a38c2e60e1b69a310aee1a92ee0b3815c02d
Parents: ea36116
Author: Michal Senkyr <mi...@gmail.com>
Authored: Tue Mar 28 10:09:49 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Mar 28 10:09:49 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 51 ++--------------
 .../catalyst/expressions/objects/objects.scala  | 64 +++++++++++++++-----
 .../sql/catalyst/ScalaReflectionSuite.scala     |  8 ---
 3 files changed, 54 insertions(+), 69 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6c70a38c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c4af284..1c7720a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -307,54 +307,11 @@ object ScalaReflection extends ScalaReflection {
           }
         }
 
-        val array = Invoke(
-          MapObjects(mapFunction, getPath, dataType),
-          "array",
-          ObjectType(classOf[Array[Any]]))
-
-        val wrappedArray = StaticInvoke(
-          scala.collection.mutable.WrappedArray.getClass,
-          ObjectType(classOf[Seq[_]]),
-          "make",
-          array :: Nil)
-
-        if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
-          wrappedArray
-        } else {
-          // Convert to another type using `to`
-          val cls = mirror.runtimeClass(t.typeSymbol.asClass)
-          import scala.collection.generic.CanBuildFrom
-          import scala.reflect.ClassTag
-
-          // Some canBuildFrom methods take an implicit ClassTag parameter
-          val cbfParams = try {
-            cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
-            StaticInvoke(
-              ClassTag.getClass,
-              ObjectType(classOf[ClassTag[_]]),
-              "apply",
-              StaticInvoke(
-                cls,
-                ObjectType(classOf[Class[_]]),
-                "getClass"
-              ) :: Nil
-            ) :: Nil
-          } catch {
-            case _: NoSuchMethodException => Nil
-          }
-
-          Invoke(
-            wrappedArray,
-            "to",
-            ObjectType(cls),
-            StaticInvoke(
-              cls,
-              ObjectType(classOf[CanBuildFrom[_, _, _]]),
-              "canBuildFrom",
-              cbfParams
-            ) :: Nil
-          )
+        val cls = t.dealias.companion.decl(TermName("newBuilder")) match {
+          case NoSymbol => classOf[Seq[_]]
+          case _ => mirror.runtimeClass(t.typeSymbol.asClass)
         }
+        MapObjects(mapFunction, getPath, dataType, Some(cls))
 
       case t if t <:< localTypeOf[Map[_, _]] =>
         // TODO: add walked type path for map

http://git-wip-us.apache.org/repos/asf/spark/blob/6c70a38c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 771ac28..bb584f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
 
 import java.lang.reflect.Modifier
 
+import scala.collection.mutable.Builder
 import scala.language.existentials
 import scala.reflect.ClassTag
 
@@ -429,24 +430,34 @@ object MapObjects {
    * @param function The function applied on the collection elements.
    * @param inputData An expression that when evaluated returns a collection object.
    * @param elementType The data type of elements in the collection.
+   * @param customCollectionCls Class of the resulting collection (returning ObjectType)
+   *                            or None (returning ArrayType)
    */
   def apply(
       function: Expression => Expression,
       inputData: Expression,
-      elementType: DataType): MapObjects = {
-    val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
-    val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
+      elementType: DataType,
+      customCollectionCls: Option[Class[_]] = None): MapObjects = {
+    val id = curId.getAndIncrement()
+    val loopValue = s"MapObjects_loopValue$id"
+    val loopIsNull = s"MapObjects_loopIsNull$id"
     val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
-    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
+    val builderValue = s"MapObjects_builderValue$id"
+    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
+      customCollectionCls, builderValue)
   }
 }
 
 /**
  * Applies the given expression to every element of a collection of items, returning the result
- * as an ArrayType.  This is similar to a typical map operation, but where the lambda function
- * is expressed using catalyst expressions.
+ * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
+ * function is expressed using catalyst expressions.
+ *
+ * The type of the result is determined as follows:
+ * - ArrayType - when customCollectionCls is None
+ * - ObjectType(collection) - when customCollectionCls contains a collection class
  *
- * The following collection ObjectTypes are currently supported:
+ * The following collection ObjectTypes are currently supported on input:
  *   Seq, Array, ArrayData, java.util.List
  *
  * @param loopValue the name of the loop variable that used when iterate the collection, and used
@@ -458,13 +469,19 @@ object MapObjects {
  * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
  *                       to handle collection elements.
  * @param inputData An expression that when evaluated returns a collection object.
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType)
+ *                            or None (returning ArrayType)
+ * @param builderValue The name of the builder variable used to construct the resulting collection
+ *                     (used only when returning ObjectType)
  */
 case class MapObjects private(
     loopValue: String,
     loopIsNull: String,
     loopVarDataType: DataType,
     lambdaFunction: Expression,
-    inputData: Expression) extends Expression with NonSQLExpression {
+    inputData: Expression,
+    customCollectionCls: Option[Class[_]],
+    builderValue: String) extends Expression with NonSQLExpression {
 
   override def nullable: Boolean = inputData.nullable
 
@@ -474,7 +491,8 @@ case class MapObjects private(
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
   override def dataType: DataType =
-    ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
+    customCollectionCls.map(ObjectType.apply).getOrElse(
+      ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val elementJavaType = ctx.javaType(loopVarDataType)
@@ -557,15 +575,33 @@ case class MapObjects private(
       case _ => s"$loopIsNull = $loopValue == null;"
     }
 
+    val (initCollection, addElement, getResult): (String, String => String, String) =
+      customCollectionCls match {
+        case Some(cls) =>
+          // collection
+          val collObjectName = s"${cls.getName}$$.MODULE$$"
+          val getBuilderVar = s"$collObjectName.newBuilder()"
+
+          (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
+        $builderValue.sizeHint($dataLength);""",
+            genValue => s"$builderValue.$$plus$$eq($genValue);",
+            s"(${cls.getName}) $builderValue.result();")
+        case None =>
+          // array
+          (s"""$convertedType[] $convertedArray = null;
+        $convertedArray = $arrayConstructor;""",
+            genValue => s"$convertedArray[$loopIndex] = $genValue;",
+            s"new ${classOf[GenericArrayData].getName}($convertedArray);")
+      }
+
     val code = s"""
       ${genInputData.code}
       ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
 
       if (!${genInputData.isNull}) {
         $determineCollectionType
-        $convertedType[] $convertedArray = null;
         int $dataLength = $getLength;
-        $convertedArray = $arrayConstructor;
+        $initCollection
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
@@ -574,15 +610,15 @@ case class MapObjects private(
 
           ${genFunction.code}
           if (${genFunction.isNull}) {
-            $convertedArray[$loopIndex] = null;
+            ${addElement("null")}
           } else {
-            $convertedArray[$loopIndex] = $genFunctionValue;
+            ${addElement(genFunctionValue)}
           }
 
           $loopIndex += 1;
         }
 
-        ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
+        ${ev.value} = $getResult
       }
     """
     ev.copy(code = code, isNull = genInputData.isNull)

http://git-wip-us.apache.org/repos/asf/spark/blob/6c70a38c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 650a353..70ad064 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
       ArrayType(IntegerType, containsNull = false))
     val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
     assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
-
-    // Check whether conversion is skipped when using WrappedArray[_] supertype
-    // (would otherwise needlessly add overhead)
-    import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-    val seqDeserializer = deserializerFor[Seq[Int]]
-    assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
-      scala.collection.mutable.WrappedArray.getClass)
-    assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
   }
 
   private val dataTypeForComplexData = dataTypeFor[ComplexData]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org