You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/06/12 00:53:28 UTC

spark git commit: [SPARK-18891][SQL] Support for specific Java List subtypes

Repository: spark
Updated Branches:
  refs/heads/master 0538f3b0a -> f48273c13


[SPARK-18891][SQL] Support for specific Java List subtypes

## What changes were proposed in this pull request?

Add support for specific Java `List` subtypes in deserialization as well as a generic implicit encoder.

Every concrete `List` subtype is supported: it is instantiated through its size-specifying constructor (a single `int` parameter) when one exists, and through its default constructor otherwise.

Interfaces/abstract classes use the following implementations:

* `java.util.List`, `java.util.AbstractList` or `java.util.AbstractSequentialList` => `java.util.ArrayList`

## How was this patch tested?

```bash
build/mvn -DskipTests clean package && dev/run-tests
```

Additionally, verified manually in the Spark shell:

```
scala> val jlist = new java.util.LinkedList[Int]; jlist.add(1)
jlist: java.util.LinkedList[Int] = [1]
res0: Boolean = true

scala> Seq(jlist).toDS().map(_.element()).collect()
res1: Array[Int] = Array(1)
```

Author: Michal Senkyr <mi...@gmail.com>

Closes #18009 from michalsenkyr/dataset-java-lists.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f48273c1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f48273c1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f48273c1

Branch: refs/heads/master
Commit: f48273c13c9e9fea2d9bb6dda10fcaaaaa50c588
Parents: 0538f3b
Author: Michal Senkyr <mi...@gmail.com>
Authored: Mon Jun 12 08:53:23 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Jun 12 08:53:23 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  | 15 ++---
 .../catalyst/expressions/objects/objects.scala  | 19 +++++-
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 61 ++++++++++++++++++++
 3 files changed, 83 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f48273c1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 86a73a3..7683ee7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -267,16 +267,11 @@ object JavaTypeInference {
 
       case c if listType.isAssignableFrom(typeToken) =>
         val et = elementType(typeToken)
-        val array =
-          Invoke(
-            MapObjects(
-              p => deserializerFor(et, Some(p)),
-              getPath,
-              inferDataType(et)._1),
-            "array",
-            ObjectType(classOf[Array[Any]]))
-
-        StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil)
+        MapObjects(
+          p => deserializerFor(et, Some(p)),
+          getPath,
+          inferDataType(et)._1,
+          customCollectionCls = Some(c))
 
       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)

http://git-wip-us.apache.org/repos/asf/spark/blob/f48273c1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 79b7b9f..5bb0feb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -22,6 +22,7 @@ import java.lang.reflect.Modifier
 import scala.collection.mutable.Builder
 import scala.language.existentials
 import scala.reflect.ClassTag
+import scala.util.Try
 
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
@@ -597,8 +598,8 @@ case class MapObjects private(
 
     val (initCollection, addElement, getResult): (String, String => String, String) =
       customCollectionCls match {
-        case Some(cls) =>
-          // collection
+        case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+          // Scala sequence
           val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
           val builder = ctx.freshName("collectionBuilder")
           (
@@ -609,6 +610,20 @@ case class MapObjects private(
             genValue => s"$builder.$$plus$$eq($genValue);",
             s"(${cls.getName}) $builder.result();"
           )
+        case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+          // Java list
+          val builder = ctx.freshName("collectionBuilder")
+          (
+            if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
+              cls == classOf[java.util.AbstractSequentialList[_]]) {
+              s"${cls.getName} $builder = new java.util.ArrayList($dataLength);"
+            } else {
+              val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("")
+              s"${cls.getName} $builder = new ${cls.getName}($param);"
+            },
+            genValue => s"$builder.add($genValue);",
+            s"$builder;"
+          )
         case None =>
           // array
           (

http://git-wip-us.apache.org/repos/asf/spark/blob/f48273c1/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 3ba37ad..4ca3b64 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1399,4 +1399,65 @@ public class JavaDatasetSuite implements Serializable {
       ds1.map((MapFunction<NestedSmallBean, NestedSmallBean>) b -> b, encoder);
     Assert.assertEquals(beans, ds2.collectAsList());
   }
+
+  @Test
+  public void testSpecificLists() {
+    SpecificListsBean bean = new SpecificListsBean();
+    ArrayList<Integer> arrayList = new ArrayList<>();
+    arrayList.add(1);
+    bean.setArrayList(arrayList);
+    LinkedList<Integer> linkedList = new LinkedList<>();
+    linkedList.add(1);
+    bean.setLinkedList(linkedList);
+    bean.setList(Collections.singletonList(1));
+    List<SpecificListsBean> beans = Collections.singletonList(bean);
+    Dataset<SpecificListsBean> dataset =
+      spark.createDataset(beans, Encoders.bean(SpecificListsBean.class));
+    Assert.assertEquals(beans, dataset.collectAsList());
+  }
+
+  public static class SpecificListsBean implements Serializable {
+    private ArrayList<Integer> arrayList;
+    private LinkedList<Integer> linkedList;
+    private List<Integer> list;
+
+    public ArrayList<Integer> getArrayList() {
+      return arrayList;
+    }
+
+    public void setArrayList(ArrayList<Integer> arrayList) {
+      this.arrayList = arrayList;
+    }
+
+    public LinkedList<Integer> getLinkedList() {
+      return linkedList;
+    }
+
+    public void setLinkedList(LinkedList<Integer> linkedList) {
+      this.linkedList = linkedList;
+    }
+
+    public List<Integer> getList() {
+      return list;
+    }
+
+    public void setList(List<Integer> list) {
+      this.list = list;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      SpecificListsBean that = (SpecificListsBean) o;
+      return Objects.equal(arrayList, that.arrayList) &&
+        Objects.equal(linkedList, that.linkedList) &&
+        Objects.equal(list, that.list);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(arrayList, linkedList, list);
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org