Posted to commits@spark.apache.org by we...@apache.org on 2017/06/12 00:47:10 UTC
spark git commit: [SPARK-18891][SQL] Support for Scala Map collection types
Repository: spark
Updated Branches:
refs/heads/master a7c61c100 -> 0538f3b0a
[SPARK-18891][SQL] Support for Scala Map collection types
## What changes were proposed in this pull request?
Adds support for arbitrary Scala `Map` types in deserialization, as well as a generic implicit encoder for them.
Uses the builder approach from #16541 to construct whatever `Map` type is requested upon deserialization.
Note that this PR also adds (ignored) tests for [SPARK-19104 CompileException with Map and Case Class in Spark 2.1.0](https://issues.apache.org/jira/browse/SPARK-19104) but does not fix that issue.
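To illustrate the builder approach, here is a minimal sketch (illustrative only, not the actual Spark code): the deserializer rebuilds the requested collection from parallel key/value arrays via a builder, which is what the generated `CollectObjectsToMap` code shown below does.
```
// A minimal sketch of the builder approach (illustrative, not Spark's code):
// rebuild an arbitrary Map collection from parallel key/value sequences via a
// builder, mirroring what the generated CollectObjectsToMap code does.
import scala.collection.mutable.Builder
import scala.collection.immutable.HashMap

def rebuild[K, V, M <: scala.collection.Map[K, V]](
    keys: Seq[K],
    values: Seq[V],
    newBuilder: () => Builder[(K, V), M]): M = {
  val builder = newBuilder()
  builder.sizeHint(keys.length)
  keys.zip(values).foreach(builder += _)  // append each (key, value) pair
  builder.result()
}

// The same routine can target any Map implementation that provides a builder:
val asHashMap = rebuild(Seq(1, 2), Seq("a", "b"), () => HashMap.newBuilder[Int, String])
```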
Added support for Java Maps in codegen code (encoders will be added in a different PR) with the following default implementations for interfaces/abstract classes (see the sketch after this list):
* `java.util.Map`, `java.util.AbstractMap` => `java.util.HashMap`
* `java.util.SortedMap`, `java.util.NavigableMap` => `java.util.TreeMap`
* `java.util.concurrent.ConcurrentMap` => `java.util.concurrent.ConcurrentHashMap`
* `java.util.concurrent.ConcurrentNavigableMap` => `java.util.concurrent.ConcurrentSkipListMap`
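A hedged sketch of the interface-to-implementation mapping above (the helper is hypothetical; it only restates the table, not Spark's actual resolution code):
```
// Hypothetical helper restating the default-implementation mapping above;
// not Spark's API, just the lookup the generated code effectively performs.
def defaultJavaMapImpl(requested: Class[_]): Class[_] = requested match {
  case c if c == classOf[java.util.concurrent.ConcurrentNavigableMap[_, _]] =>
    classOf[java.util.concurrent.ConcurrentSkipListMap[_, _]]
  case c if c == classOf[java.util.concurrent.ConcurrentMap[_, _]] =>
    classOf[java.util.concurrent.ConcurrentHashMap[_, _]]
  case c if c == classOf[java.util.SortedMap[_, _]] || c == classOf[java.util.NavigableMap[_, _]] =>
    classOf[java.util.TreeMap[_, _]]
  case c if c == classOf[java.util.Map[_, _]] || c == classOf[java.util.AbstractMap[_, _]] =>
    classOf[java.util.HashMap[_, _]]
  case c => c // already concrete; instantiate as-is
}
```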
Resulting codegen for `Seq(Map(1 -> 2)).toDS().map(identity).queryExecution.debug.codegen`:
```
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private scala.collection.Iterator inputadapter_input;
/* 009 */ private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */ private int CollectObjectsToMap_loopValue0;
/* 011 */ private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */ private int CollectObjectsToMap_loopValue2;
/* 013 */ private UnsafeRow deserializetoobject_result;
/* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */ private scala.collection.immutable.Map mapelements_argValue;
/* 017 */ private UnsafeRow mapelements_result;
/* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */ private UnsafeRow serializefromobject_result;
/* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */ public GeneratedIterator(Object[] references) {
/* 027 */ this.references = references;
/* 028 */ }
/* 029 */
/* 030 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */ partitionIndex = index;
/* 032 */ this.inputs = inputs;
/* 033 */ wholestagecodegen_init_0();
/* 034 */ wholestagecodegen_init_1();
/* 035 */
/* 036 */ }
/* 037 */
/* 038 */ private void wholestagecodegen_init_0() {
/* 039 */ inputadapter_input = inputs[0];
/* 040 */
/* 041 */ deserializetoobject_result = new UnsafeRow(1);
/* 042 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */ mapelements_result = new UnsafeRow(1);
/* 046 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */ serializefromobject_result = new UnsafeRow(1);
/* 049 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */ }
/* 054 */
/* 055 */ private void wholestagecodegen_init_1() {
/* 056 */ this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */ }
/* 059 */
/* 060 */ protected void processNext() throws java.io.IOException {
/* 061 */ while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */ MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */ boolean deserializetoobject_isNull1 = true;
/* 067 */ ArrayData deserializetoobject_value1 = null;
/* 068 */ if (!inputadapter_isNull) {
/* 069 */ deserializetoobject_isNull1 = false;
/* 070 */ if (!deserializetoobject_isNull1) {
/* 071 */ Object deserializetoobject_funcResult = null;
/* 072 */ deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */ if (deserializetoobject_funcResult == null) {
/* 074 */ deserializetoobject_isNull1 = true;
/* 075 */ } else {
/* 076 */ deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */ }
/* 078 */
/* 079 */ }
/* 080 */ deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */ }
/* 082 */
/* 083 */ boolean deserializetoobject_isNull3 = true;
/* 084 */ ArrayData deserializetoobject_value3 = null;
/* 085 */ if (!inputadapter_isNull) {
/* 086 */ deserializetoobject_isNull3 = false;
/* 087 */ if (!deserializetoobject_isNull3) {
/* 088 */ Object deserializetoobject_funcResult1 = null;
/* 089 */ deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */ if (deserializetoobject_funcResult1 == null) {
/* 091 */ deserializetoobject_isNull3 = true;
/* 092 */ } else {
/* 093 */ deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */ }
/* 095 */
/* 096 */ }
/* 097 */ deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */ }
/* 099 */ scala.collection.immutable.Map deserializetoobject_value = null;
/* 100 */
/* 101 */ if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */ (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */ throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */ }
/* 105 */
/* 106 */ if (!deserializetoobject_isNull1) {
/* 107 */ if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */ throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */ }
/* 110 */ int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */
/* 112 */ scala.collection.mutable.Builder CollectObjectsToMap_builderValue5 = scala.collection.immutable.Map$.MODULE$.newBuilder();
/* 113 */ CollectObjectsToMap_builderValue5.sizeHint(deserializetoobject_dataLength);
/* 114 */
/* 115 */ int deserializetoobject_loopIndex = 0;
/* 116 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 117 */ CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 118 */ CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 119 */ CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 120 */ CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 121 */
/* 122 */ if (CollectObjectsToMap_loopIsNull1) {
/* 123 */ throw new RuntimeException("Found null in map key!");
/* 124 */ }
/* 125 */
/* 126 */ scala.Tuple2 CollectObjectsToMap_loopValue4;
/* 127 */
/* 128 */ if (CollectObjectsToMap_loopIsNull3) {
/* 129 */ CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, null);
/* 130 */ } else {
/* 131 */ CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 132 */ }
/* 133 */
/* 134 */ CollectObjectsToMap_builderValue5.$plus$eq(CollectObjectsToMap_loopValue4);
/* 135 */
/* 136 */ deserializetoobject_loopIndex += 1;
/* 137 */ }
/* 138 */
/* 139 */ deserializetoobject_value = (scala.collection.immutable.Map) CollectObjectsToMap_builderValue5.result();
/* 140 */ }
/* 141 */
/* 142 */ boolean mapelements_isNull = true;
/* 143 */ scala.collection.immutable.Map mapelements_value = null;
/* 144 */ if (!false) {
/* 145 */ mapelements_argValue = deserializetoobject_value;
/* 146 */
/* 147 */ mapelements_isNull = false;
/* 148 */ if (!mapelements_isNull) {
/* 149 */ Object mapelements_funcResult = null;
/* 150 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 151 */ if (mapelements_funcResult == null) {
/* 152 */ mapelements_isNull = true;
/* 153 */ } else {
/* 154 */ mapelements_value = (scala.collection.immutable.Map) mapelements_funcResult;
/* 155 */ }
/* 156 */
/* 157 */ }
/* 158 */ mapelements_isNull = mapelements_value == null;
/* 159 */ }
/* 160 */
/* 161 */ MapData serializefromobject_value = null;
/* 162 */ if (!mapelements_isNull) {
/* 163 */ final int serializefromobject_length = mapelements_value.size();
/* 164 */ final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 165 */ final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 166 */ int serializefromobject_index = 0;
/* 167 */ final scala.collection.Iterator serializefromobject_entries = mapelements_value.iterator();
/* 168 */ while(serializefromobject_entries.hasNext()) {
/* 169 */ final scala.Tuple2 serializefromobject_entry = (scala.Tuple2) serializefromobject_entries.next();
/* 170 */ int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry._1();
/* 171 */ int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry._2();
/* 172 */
/* 173 */ boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 174 */
/* 175 */ if (false) {
/* 176 */ throw new RuntimeException("Cannot use null as map key!");
/* 177 */ } else {
/* 178 */ serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 179 */ }
/* 180 */
/* 181 */ if (false) {
/* 182 */ serializefromobject_convertedValues[serializefromobject_index] = null;
/* 183 */ } else {
/* 184 */ serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 185 */ }
/* 186 */
/* 187 */ serializefromobject_index++;
/* 188 */ }
/* 189 */
/* 190 */ serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 191 */ }
/* 192 */ serializefromobject_holder.reset();
/* 193 */
/* 194 */ serializefromobject_rowWriter.zeroOutNullBytes();
/* 195 */
/* 196 */ if (mapelements_isNull) {
/* 197 */ serializefromobject_rowWriter.setNullAt(0);
/* 198 */ } else {
/* 199 */ // Remember the current cursor so that we can calculate how many bytes are
/* 200 */ // written later.
/* 201 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 202 */
/* 203 */ if (serializefromobject_value instanceof UnsafeMapData) {
/* 204 */ final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 205 */ // grow the global buffer before writing data.
/* 206 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 207 */ ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 208 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 209 */
/* 210 */ } else {
/* 211 */ final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 212 */ final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 213 */
/* 214 */ // preserve 8 bytes to write the key array numBytes later.
/* 215 */ serializefromobject_holder.grow(8);
/* 216 */ serializefromobject_holder.cursor += 8;
/* 217 */
/* 218 */ // Remember the current cursor so that we can write numBytes of key array later.
/* 219 */ final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 220 */
/* 221 */ if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 222 */ final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 223 */ // grow the global buffer before writing data.
/* 224 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 225 */ ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 226 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 227 */
/* 228 */ } else {
/* 229 */ final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 230 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 231 */
/* 232 */ for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 233 */ if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 234 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 235 */ } else {
/* 236 */ final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 237 */ serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 238 */ }
/* 239 */ }
/* 240 */ }
/* 241 */
/* 242 */ // Write the numBytes of key array into the first 8 bytes.
/* 243 */ Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 244 */
/* 245 */ if (serializefromobject_values instanceof UnsafeArrayData) {
/* 246 */ final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 247 */ // grow the global buffer before writing data.
/* 248 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 249 */ ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 250 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 251 */
/* 252 */ } else {
/* 253 */ final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 254 */ serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 255 */
/* 256 */ for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 257 */ if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 258 */ serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 259 */ } else {
/* 260 */ final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 261 */ serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 262 */ }
/* 263 */ }
/* 264 */ }
/* 265 */
/* 266 */ }
/* 267 */
/* 268 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 269 */ }
/* 270 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 271 */ append(serializefromobject_result);
/* 272 */ if (shouldStop()) return;
/* 273 */ }
/* 274 */ }
/* 275 */ }
```
Codegen for `java.util.Map`:
```
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private scala.collection.Iterator inputadapter_input;
/* 009 */ private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */ private int CollectObjectsToMap_loopValue0;
/* 011 */ private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */ private int CollectObjectsToMap_loopValue2;
/* 013 */ private UnsafeRow deserializetoobject_result;
/* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */ private java.util.HashMap mapelements_argValue;
/* 017 */ private UnsafeRow mapelements_result;
/* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */ private UnsafeRow serializefromobject_result;
/* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */ public GeneratedIterator(Object[] references) {
/* 027 */ this.references = references;
/* 028 */ }
/* 029 */
/* 030 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */ partitionIndex = index;
/* 032 */ this.inputs = inputs;
/* 033 */ wholestagecodegen_init_0();
/* 034 */ wholestagecodegen_init_1();
/* 035 */
/* 036 */ }
/* 037 */
/* 038 */ private void wholestagecodegen_init_0() {
/* 039 */ inputadapter_input = inputs[0];
/* 040 */
/* 041 */ deserializetoobject_result = new UnsafeRow(1);
/* 042 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */ mapelements_result = new UnsafeRow(1);
/* 046 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */ serializefromobject_result = new UnsafeRow(1);
/* 049 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */ }
/* 054 */
/* 055 */ private void wholestagecodegen_init_1() {
/* 056 */ this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */ }
/* 059 */
/* 060 */ protected void processNext() throws java.io.IOException {
/* 061 */ while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */ MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */ boolean deserializetoobject_isNull1 = true;
/* 067 */ ArrayData deserializetoobject_value1 = null;
/* 068 */ if (!inputadapter_isNull) {
/* 069 */ deserializetoobject_isNull1 = false;
/* 070 */ if (!deserializetoobject_isNull1) {
/* 071 */ Object deserializetoobject_funcResult = null;
/* 072 */ deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */ if (deserializetoobject_funcResult == null) {
/* 074 */ deserializetoobject_isNull1 = true;
/* 075 */ } else {
/* 076 */ deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */ }
/* 078 */
/* 079 */ }
/* 080 */ deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */ }
/* 082 */
/* 083 */ boolean deserializetoobject_isNull3 = true;
/* 084 */ ArrayData deserializetoobject_value3 = null;
/* 085 */ if (!inputadapter_isNull) {
/* 086 */ deserializetoobject_isNull3 = false;
/* 087 */ if (!deserializetoobject_isNull3) {
/* 088 */ Object deserializetoobject_funcResult1 = null;
/* 089 */ deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */ if (deserializetoobject_funcResult1 == null) {
/* 091 */ deserializetoobject_isNull3 = true;
/* 092 */ } else {
/* 093 */ deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */ }
/* 095 */
/* 096 */ }
/* 097 */ deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */ }
/* 099 */ java.util.HashMap deserializetoobject_value = null;
/* 100 */
/* 101 */ if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */ (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */ throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */ }
/* 105 */
/* 106 */ if (!deserializetoobject_isNull1) {
/* 107 */ if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */ throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */ }
/* 110 */ int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */ java.util.Map CollectObjectsToMap_builderValue5 = new java.util.HashMap(deserializetoobject_dataLength);
/* 112 */
/* 113 */ int deserializetoobject_loopIndex = 0;
/* 114 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 115 */ CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 116 */ CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 117 */ CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 118 */ CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 119 */
/* 120 */ if (CollectObjectsToMap_loopIsNull1) {
/* 121 */ throw new RuntimeException("Found null in map key!");
/* 122 */ }
/* 123 */
/* 124 */ CollectObjectsToMap_builderValue5.put(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 125 */
/* 126 */ deserializetoobject_loopIndex += 1;
/* 127 */ }
/* 128 */
/* 129 */ deserializetoobject_value = (java.util.HashMap) CollectObjectsToMap_builderValue5;
/* 130 */ }
/* 131 */
/* 132 */ boolean mapelements_isNull = true;
/* 133 */ java.util.HashMap mapelements_value = null;
/* 134 */ if (!false) {
/* 135 */ mapelements_argValue = deserializetoobject_value;
/* 136 */
/* 137 */ mapelements_isNull = false;
/* 138 */ if (!mapelements_isNull) {
/* 139 */ Object mapelements_funcResult = null;
/* 140 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 141 */ if (mapelements_funcResult == null) {
/* 142 */ mapelements_isNull = true;
/* 143 */ } else {
/* 144 */ mapelements_value = (java.util.HashMap) mapelements_funcResult;
/* 145 */ }
/* 146 */
/* 147 */ }
/* 148 */ mapelements_isNull = mapelements_value == null;
/* 149 */ }
/* 150 */
/* 151 */ MapData serializefromobject_value = null;
/* 152 */ if (!mapelements_isNull) {
/* 153 */ final int serializefromobject_length = mapelements_value.size();
/* 154 */ final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 155 */ final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 156 */ int serializefromobject_index = 0;
/* 157 */ final java.util.Iterator serializefromobject_entries = mapelements_value.entrySet().iterator();
/* 158 */ while(serializefromobject_entries.hasNext()) {
/* 159 */ final java.util.Map$Entry serializefromobject_entry = (java.util.Map$Entry) serializefromobject_entries.next();
/* 160 */ int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry.getKey();
/* 161 */ int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry.getValue();
/* 162 */
/* 163 */ boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 164 */
/* 165 */ if (false) {
/* 166 */ throw new RuntimeException("Cannot use null as map key!");
/* 167 */ } else {
/* 168 */ serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 169 */ }
/* 170 */
/* 171 */ if (false) {
/* 172 */ serializefromobject_convertedValues[serializefromobject_index] = null;
/* 173 */ } else {
/* 174 */ serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 175 */ }
/* 176 */
/* 177 */ serializefromobject_index++;
/* 178 */ }
/* 179 */
/* 180 */ serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 181 */ }
/* 182 */ serializefromobject_holder.reset();
/* 183 */
/* 184 */ serializefromobject_rowWriter.zeroOutNullBytes();
/* 185 */
/* 186 */ if (mapelements_isNull) {
/* 187 */ serializefromobject_rowWriter.setNullAt(0);
/* 188 */ } else {
/* 189 */ // Remember the current cursor so that we can calculate how many bytes are
/* 190 */ // written later.
/* 191 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 192 */
/* 193 */ if (serializefromobject_value instanceof UnsafeMapData) {
/* 194 */ final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 195 */ // grow the global buffer before writing data.
/* 196 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 197 */ ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 198 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 199 */
/* 200 */ } else {
/* 201 */ final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 202 */ final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 203 */
/* 204 */ // preserve 8 bytes to write the key array numBytes later.
/* 205 */ serializefromobject_holder.grow(8);
/* 206 */ serializefromobject_holder.cursor += 8;
/* 207 */
/* 208 */ // Remember the current cursor so that we can write numBytes of key array later.
/* 209 */ final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 210 */
/* 211 */ if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 212 */ final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 213 */ // grow the global buffer before writing data.
/* 214 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 215 */ ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 216 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 217 */
/* 218 */ } else {
/* 219 */ final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 220 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 221 */
/* 222 */ for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 223 */ if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 224 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 225 */ } else {
/* 226 */ final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 227 */ serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 228 */ }
/* 229 */ }
/* 230 */ }
/* 231 */
/* 232 */ // Write the numBytes of key array into the first 8 bytes.
/* 233 */ Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 234 */
/* 235 */ if (serializefromobject_values instanceof UnsafeArrayData) {
/* 236 */ final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 237 */ // grow the global buffer before writing data.
/* 238 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 239 */ ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 240 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 241 */
/* 242 */ } else {
/* 243 */ final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 244 */ serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 245 */
/* 246 */ for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 247 */ if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 248 */ serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 249 */ } else {
/* 250 */ final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 251 */ serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 252 */ }
/* 253 */ }
/* 254 */ }
/* 255 */
/* 256 */ }
/* 257 */
/* 258 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 259 */ }
/* 260 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 261 */ append(serializefromobject_result);
/* 262 */ if (shouldStop()) return;
/* 263 */ }
/* 264 */ }
/* 265 */ }
```
## How was this patch tested?
```
build/mvn -DskipTests clean package && dev/run-tests
```
Additionally, in the Spark shell:
```
scala> Seq(collection.mutable.HashMap(1 -> 2, 2 -> 3)).toDS().map(_ += (3 -> 4)).collect()
res0: Array[scala.collection.mutable.HashMap[Int,Int]] = Array(Map(2 -> 3, 1 -> 2, 3 -> 4))
```
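A further usage sketch, adapted from the `DatasetPrimitiveSuite` tests in the diff below (assumes `spark.implicits._` is in scope): the same generic encoder covers non-default implementations such as `mutable.LinkedHashMap`, which is rebuilt through its own builder on deserialization.
```
import scala.collection.mutable.LinkedHashMap

// LinkedHashMap is reconstructed via its own builder on deserialization,
// so insertion order should survive the round trip.
val ds = Seq(LinkedHashMap("a" -> 1, "b" -> 2)).toDS()
val updated = ds.map(_ += ("c" -> 3)).collect()
// expected: Array(Map(a -> 1, b -> 2, c -> 3)) typed as LinkedHashMap[String,Int]
```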
Author: Michal Senkyr <mi...@gmail.com>
Author: Michal Šenkýř <mi...@gmail.com>
Closes #16986 from michalsenkyr/dataset-map-builder.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0538f3b0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0538f3b0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0538f3b0
Branch: refs/heads/master
Commit: 0538f3b0ae4b80750ab81b210ad6fe77178337bf
Parents: a7c61c1
Author: Michal Senkyr <mi...@gmail.com>
Authored: Mon Jun 12 08:47:01 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Jun 12 08:47:01 2017 +0800
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 33 +---
.../catalyst/expressions/objects/objects.scala | 169 ++++++++++++++++++-
.../sql/catalyst/ScalaReflectionSuite.scala | 25 +++
.../org/apache/spark/sql/SQLImplicits.scala | 5 +
.../spark/sql/DatasetPrimitiveSuite.scala | 86 ++++++++++
5 files changed, 291 insertions(+), 27 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8713053..d580cf4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -335,31 +335,12 @@ object ScalaReflection extends ScalaReflection {
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
- val keyData =
- Invoke(
- MapObjects(
- p => deserializerFor(keyType, Some(p), walkedTypePath),
- Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
- returnNullable = false),
- schemaFor(keyType).dataType),
- "array",
- ObjectType(classOf[Array[Any]]), returnNullable = false)
-
- val valueData =
- Invoke(
- MapObjects(
- p => deserializerFor(valueType, Some(p), walkedTypePath),
- Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType),
- returnNullable = false),
- schemaFor(valueType).dataType),
- "array",
- ObjectType(classOf[Array[Any]]), returnNullable = false)
-
- StaticInvoke(
- ArrayBasedMapData.getClass,
- ObjectType(classOf[scala.collection.immutable.Map[_, _]]),
- "toScalaMap",
- keyData :: valueData :: Nil)
+ CollectObjectsToMap(
+ p => deserializerFor(keyType, Some(p), walkedTypePath),
+ p => deserializerFor(valueType, Some(p), walkedTypePath),
+ getPath,
+ mirror.runtimeClass(t.typeSymbol.asClass)
+ )
case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 1a202ec..79b7b9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
/**
@@ -652,6 +652,173 @@ case class MapObjects private(
}
}
+object CollectObjectsToMap {
+ private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+ /**
+ * Construct an instance of CollectObjectsToMap case class.
+ *
+ * @param keyFunction The function applied on the key collection elements.
+ * @param valueFunction The function applied on the value collection elements.
+ * @param inputData An expression that when evaluated returns a map object.
+ * @param collClass The type of the resulting collection.
+ */
+ def apply(
+ keyFunction: Expression => Expression,
+ valueFunction: Expression => Expression,
+ inputData: Expression,
+ collClass: Class[_]): CollectObjectsToMap = {
+ val id = curId.getAndIncrement()
+ val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id"
+ val mapType = inputData.dataType.asInstanceOf[MapType]
+ val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
+ val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id"
+ val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id"
+ val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
+ CollectObjectsToMap(
+ keyLoopValue, keyFunction(keyLoopVar),
+ valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar),
+ inputData, collClass)
+ }
+}
+
+/**
+ * Expression used to convert a Catalyst Map to an external Scala Map.
+ * The collection is constructed using the associated builder, obtained by calling `newBuilder`
+ * on the collection's companion object.
+ *
+ * @param keyLoopValue the name of the loop variable that is used when iterating over the key
+ * collection, and which is used as input for the `keyLambdaFunction`
+ * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as
+ * a lambda function to handle collection elements.
+ * @param valueLoopValue the name of the loop variable that is used when iterating over the value
+ * collection, and which is used as input for the `valueLambdaFunction`
+ * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over
+ * the value collection, and which is used as input for the
+ * `valueLambdaFunction`
+ * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as
+ * a lambda function to handle collection elements.
+ * @param inputData An expression that when evaluated returns a map object.
+ * @param collClass The type of the resulting collection.
+ */
+case class CollectObjectsToMap private(
+ keyLoopValue: String,
+ keyLambdaFunction: Expression,
+ valueLoopValue: String,
+ valueLoopIsNull: String,
+ valueLambdaFunction: Expression,
+ inputData: Expression,
+ collClass: Class[_]) extends Expression with NonSQLExpression {
+
+ override def nullable: Boolean = inputData.nullable
+
+ override def children: Seq[Expression] =
+ keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+ override def dataType: DataType = ObjectType(collClass)
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
+ // When we want to apply MapObjects on it, we have to use it.
+ def inputDataType(dataType: DataType) = dataType match {
+ case p: PythonUserDefinedType => p.sqlType
+ case _ => dataType
+ }
+
+ val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
+ val keyElementJavaType = ctx.javaType(mapType.keyType)
+ ctx.addMutableState(keyElementJavaType, keyLoopValue, "")
+ val genKeyFunction = keyLambdaFunction.genCode(ctx)
+ val valueElementJavaType = ctx.javaType(mapType.valueType)
+ ctx.addMutableState("boolean", valueLoopIsNull, "")
+ ctx.addMutableState(valueElementJavaType, valueLoopValue, "")
+ val genValueFunction = valueLambdaFunction.genCode(ctx)
+ val genInputData = inputData.genCode(ctx)
+ val dataLength = ctx.freshName("dataLength")
+ val loopIndex = ctx.freshName("loopIndex")
+ val tupleLoopValue = ctx.freshName("tupleLoopValue")
+ val builderValue = ctx.freshName("builderValue")
+
+ val getLength = s"${genInputData.value}.numElements()"
+
+ val keyArray = ctx.freshName("keyArray")
+ val valueArray = ctx.freshName("valueArray")
+ val getKeyArray =
+ s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
+ val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
+ val getValueArray =
+ s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
+ val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex)
+
+ // Make a copy of the data if it's unsafe-backed
+ def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
+ s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value"
+ def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) =
+ lambdaFunction.dataType match {
+ case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
+ case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
+ case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
+ case _ => genFunction.value
+ }
+ val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction)
+ val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)
+
+ val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
+
+ val builderClass = classOf[Builder[_, _]].getName
+ val constructBuilder = s"""
+ $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder();
+ $builderValue.sizeHint($dataLength);
+ """
+
+ val tupleClass = classOf[(_, _)].getName
+ val appendToBuilder = s"""
+ $tupleClass $tupleLoopValue;
+
+ if (${genValueFunction.isNull}) {
+ $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null);
+ } else {
+ $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue);
+ }
+
+ $builderValue.$$plus$$eq($tupleLoopValue);
+ """
+ val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"
+
+ val code = s"""
+ ${genInputData.code}
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+
+ if (!${genInputData.isNull}) {
+ int $dataLength = $getLength;
+ $constructBuilder
+ $getKeyArray
+ $getValueArray
+
+ int $loopIndex = 0;
+ while ($loopIndex < $dataLength) {
+ $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar);
+ $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar);
+ $valueLoopNullCheck
+
+ ${genKeyFunction.code}
+ ${genValueFunction.code}
+
+ $appendToBuilder
+
+ $loopIndex += 1;
+ }
+
+ $getBuilderResult
+ }
+ """
+ ev.copy(code = code, isNull = genInputData.isNull)
+ }
+}
+
object ExternalMapToCatalyst {
private val curId = new java.util.concurrent.atomic.AtomicInteger()
http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 70ad064..ff2414b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -314,6 +314,31 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
}
+ test("serialize and deserialize arbitrary map types") {
+ val mapSerializer = serializerFor[Map[Int, Int]](BoundReference(
+ 0, ObjectType(classOf[Map[Int, Int]]), nullable = false))
+ assert(mapSerializer.dataType.head.dataType ==
+ MapType(IntegerType, IntegerType, valueContainsNull = false))
+ val mapDeserializer = deserializerFor[Map[Int, Int]]
+ assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]]))
+
+ import scala.collection.immutable.HashMap
+ val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference(
+ 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false))
+ assert(hashMapSerializer.dataType.head.dataType ==
+ MapType(IntegerType, IntegerType, valueContainsNull = false))
+ val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]
+ assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]]))
+
+ import scala.collection.mutable.{LinkedHashMap => LHMap}
+ val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference(
+ 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false))
+ assert(linkedHashMapSerializer.dataType.head.dataType ==
+ MapType(LongType, StringType, valueContainsNull = true))
+ val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]
+ assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
+ }
+
private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]
http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 17671ea..86574e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import scala.collection.Map
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
@@ -166,6 +167,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/** @since 2.2.0 */
implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
+ // Maps
+ /** @since 2.3.0 */
+ implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder()
+
// Arrays
/** @since 1.6.1 */
http://git-wip-us.apache.org/repos/asf/spark/blob/0538f3b0/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 7e2949a..4126660 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import scala.collection.immutable.Queue
+import scala.collection.mutable.{LinkedHashMap => LHMap}
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.test.SharedSQLContext
@@ -30,8 +31,14 @@ case class ListClass(l: List[Int])
case class QueueClass(q: Queue[Int])
+case class MapClass(m: Map[Int, Int])
+
+case class LHMapClass(m: LHMap[Int, Int])
+
case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)
+case class ComplexMapClass(map: MapClass, lhmap: LHMapClass)
+
package object packageobject {
case class PackageClass(value: Int)
}
@@ -258,11 +265,90 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
}
+ test("arbitrary maps") {
+ checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2))
+ checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong))
+ checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble))
+ checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat))
+ checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte))
+ checkDataset(Seq(Map(1.toShort -> 2.toShort)).toDS(), Map(1.toShort -> 2.toShort))
+ checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false))
+ checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2"))
+ checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2)))
+ checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2)))
+ checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong))
+
+ checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2))
+ checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 2.toLong))
+ checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble))
+ checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat))
+ checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte))
+ checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort))
+ checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false))
+ checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2"))
+ checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2)))
+ checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2)))
+ checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
+ }
+
+ ignore("SPARK-19104: map and product combinations") {
+ // Case classes
+ checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2)))
+ checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3))))
+ checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3))
+ checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(),
+ Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4))))
+ checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3))))
+ checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3))
+ checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(),
+ LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4))))
+
+ checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2)))
+ checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(),
+ Map(1 -> LHMapClass(LHMap(2 -> 3))))
+ checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(),
+ Map(LHMapClass(LHMap(1 -> 2)) -> 3))
+ checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(),
+ Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4))))
+ checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(),
+ LHMap(1 -> LHMapClass(LHMap(2 -> 3))))
+ checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(),
+ LHMap(LHMapClass(LHMap(1 -> 2)) -> 3))
+ checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(),
+ LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4))))
+
+ val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4)))
+ checkDataset(Seq(complex).toDS(), complex)
+ checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex))
+ checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5))
+ checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex))
+ checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex))
+ checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5))
+ checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex))
+
+ // Tuples
+ checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 4)).toDS(), Map(1 -> 2) -> Map(3 -> 4))
+ checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4))
+ checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4))
+ checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4))
+ checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(),
+ LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2"))))
+
+ // Complex
+ checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))).toDS(),
+ LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4))))
+ }
+
test("nested sequences") {
checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1)))
checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1)))
}
+ test("nested maps") {
+ checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3)))
+ checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3))
+ }
+
test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))