You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/04/21 23:50:11 UTC

spark git commit: [SPARK-6996][SQL] Support map types in java beans

Repository: spark
Updated Branches:
  refs/heads/master 6265cba00 -> 2a24bf92e


[SPARK-6996][SQL] Support map types in java beans

liancheng, mengxr: this is similar to #5146.

Author: Punya Biswal <pb...@palantir.com>

Closes #5578 from punya/feature/SPARK-6996 and squashes the following commits:

d56c3e0 [Punya Biswal] Fix imports
c7e308b [Punya Biswal] Support java iterable types in POJOs
5e00685 [Punya Biswal] Support map types in java beans


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a24bf92
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a24bf92
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a24bf92

Branch: refs/heads/master
Commit: 2a24bf92e6d36e876bad6a8b4e0ff12c407ebb8a
Parents: 6265cba
Author: Punya Biswal <pb...@palantir.com>
Authored: Tue Apr 21 14:50:02 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Apr 21 14:50:02 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/CatalystTypeConverters.scala   |  20 ++++
 .../apache/spark/sql/JavaTypeInference.scala    | 110 +++++++++++++++++++
 .../scala/org/apache/spark/sql/SQLContext.scala |  52 +--------
 .../apache/spark/sql/JavaDataFrameSuite.java    |  57 ++++++++--
 4 files changed, 180 insertions(+), 59 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2a24bf92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index d4f9fda..a13e2f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
+import java.lang.{Iterable => JavaIterable}
 import java.util.{Map => JavaMap}
 
 import scala.collection.mutable.HashMap
@@ -49,6 +50,16 @@ object CatalystTypeConverters {
     case (s: Seq[_], arrayType: ArrayType) =>
       s.map(convertToCatalyst(_, arrayType.elementType))
 
+    case (jit: JavaIterable[_], arrayType: ArrayType) => {
+      // Use a builder: `:+` on an immutable List copies it, making the loop quadratic.
+      val builder = List.newBuilder[Any]
+      val iter = jit.iterator
+      while (iter.hasNext) {
+        builder += convertToCatalyst(iter.next(), arrayType.elementType)
+      }
+      builder.result()
+    }
+
     case (s: Array[_], arrayType: ArrayType) =>
       s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
 
@@ -124,6 +135,15 @@ object CatalystTypeConverters {
           extractOption(item) match {
             case a: Array[_] => a.toSeq.map(elementConverter)
             case s: Seq[_] => s.map(elementConverter)
+            case i: JavaIterable[_] => {
+              // Use a builder: `:+` on an immutable List copies it, making the loop quadratic.
+              val builder = List.newBuilder[Any]
+              val iter = i.iterator
+              while (iter.hasNext) {
+                builder += elementConverter(iter.next())
+              }
+              builder.result()
+            }
             case null => null
           }
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/2a24bf92/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
new file mode 100644
index 0000000..db484c5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.beans.Introspector
+import java.lang.{Iterable => JIterable}
+import java.util.{Iterator => JIterator, Map => JMap}
+
+import com.google.common.reflect.TypeToken
+
+import org.apache.spark.sql.types._
+
+import scala.language.existentials
+
+/**
+ * Type-inference utilities for POJOs and Java collections.
+ */
+private [sql] object JavaTypeInference {
+
+  private val iterableType = TypeToken.of(classOf[JIterable[_]])
+  private val mapType = TypeToken.of(classOf[JMap[_, _]])
+  private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
+  private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
+  private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
+  private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
+
+  /**
+   * Infers the corresponding SQL data type of a Java type.
+   *
+   * Primitive types map to non-nullable SQL types; boxed primitives, strings, decimals, dates,
+   * timestamps, arrays, collections, maps, UDTs and beans map to nullable ones. Arrays and
+   * `java.lang.Iterable`s become `ArrayType`, `java.util.Map`s become `MapType`, classes
+   * annotated with `SQLUserDefinedType` use their UDT, and any other class is introspected
+   * as a Java bean and becomes a `StructType` with one field per property (except `class`).
+   *
+   * @param typeToken Java type
+   * @return (SQL data type, nullable)
+   * @throws UnsupportedOperationException if a bean type (transitively) refers to itself
+   */
+  private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) =
+    inferDataType(typeToken, Set.empty[Class[_]])
+
+  /**
+   * Recursive worker for the public `inferDataType`.
+   *
+   * @param seenBeans bean classes being expanded on the path from the root type down to
+   *                  `typeToken`; used to reject self-referential beans, which would
+   *                  otherwise recurse until the stack overflows
+   */
+  private def inferDataType(
+      typeToken: TypeToken[_],
+      seenBeans: Set[Class[_]]): (DataType, Boolean) = {
+    // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
+    typeToken.getRawType match {
+      case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+        (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
+
+      case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
+      case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
+      case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
+      case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
+      case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
+      case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
+      case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
+      case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
+
+      case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
+      case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
+      case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
+      case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
+      case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
+      case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
+      case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
+
+      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
+      case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
+      case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+
+      case _ if typeToken.isArray =>
+        val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenBeans)
+        (ArrayType(dataType, nullable), true)
+
+      case _ if iterableType.isAssignableFrom(typeToken) =>
+        val (dataType, nullable) = inferDataType(elementType(typeToken), seenBeans)
+        (ArrayType(dataType, nullable), true)
+
+      case _ if mapType.isAssignableFrom(typeToken) =>
+        val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
+        val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
+        val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
+        val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
+        // `MapType` only records value nullability, so the key's nullability is dropped.
+        val (keyDataType, _) = inferDataType(keyType, seenBeans)
+        val (valueDataType, nullable) = inferDataType(valueType, seenBeans)
+        (MapType(keyDataType, valueDataType, nullable), true)
+
+      case _ =>
+        // Any other class is treated as a Java bean. A bean already being expanded higher up
+        // the type tree would make the recursion below never terminate, so reject it.
+        val beanClass: Class[_] = typeToken.getRawType
+        if (seenBeans.contains(beanClass)) {
+          throw new UnsupportedOperationException(
+            s"Cannot have circular references in bean class, but got $beanClass")
+        }
+        val beanInfo = Introspector.getBeanInfo(beanClass)
+        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+        val fields = properties.map { property =>
+          // Resolving the getter through `typeToken` keeps generic return types such as
+          // List<String>, which the collection and map cases above depend on.
+          val returnType = typeToken.method(property.getReadMethod).getReturnType
+          val (dataType, nullable) = inferDataType(returnType, seenBeans + beanClass)
+          new StructField(property.getName, dataType, nullable)
+        }
+        (new StructType(fields), true)
+    }
+  }
+
+  /**
+   * Returns the element type `T` of a type that implements `java.lang.Iterable<T>`,
+   * resolved via `Iterable.iterator()` and `Iterator.next()` so that type arguments bound
+   * by the concrete type (e.g. `ArrayList<String>`) are taken into account.
+   */
+  private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
+    val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
+    val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
+    val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
+    val itemType = iteratorType.resolveType(nextReturnType)
+    itemType
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2a24bf92/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index f9f3eb2..bcd20c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,6 +25,8 @@ import scala.collection.immutable
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 
+import com.google.common.reflect.TypeToken
+
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDD
@@ -1222,56 +1224,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * Returns a Catalyst Schema for the given java bean class.
    */
   protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
-    val (dataType, _) = inferDataType(beanClass)
+    val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass))
     dataType.asInstanceOf[StructType].fields.map { f =>
       AttributeReference(f.name, f.dataType, f.nullable)()
     }
   }
 
-  /**
-   * Infers the corresponding SQL data type of a Java class.
-   * @param clazz Java class
-   * @return (SQL data type, nullable)
-   */
-  private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
-    // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
-    clazz match {
-      case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
-        (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
-
-      case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
-      case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
-      case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
-      case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
-      case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
-      case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
-      case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
-      case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
-
-      case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
-      case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
-      case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
-      case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
-      case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
-      case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
-      case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
-
-      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
-      case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
-      case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
-
-      case c: Class[_] if c.isArray =>
-        val (dataType, nullable) = inferDataType(c.getComponentType)
-        (ArrayType(dataType, nullable), true)
-
-      case _ =>
-        val beanInfo = Introspector.getBeanInfo(clazz)
-        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
-        val fields = properties.map { property =>
-          val (dataType, nullable) = inferDataType(property.getPropertyType)
-          new StructField(property.getName, dataType, nullable)
-        }
-        (new StructType(fields), true)
-    }
-  }
 }
+
+

http://git-wip-us.apache.org/repos/asf/spark/blob/2a24bf92/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 6d0fbe8..fc3ed4a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -17,23 +17,28 @@
 
 package test.org.apache.spark.sql;
 
-import java.io.Serializable;
-import java.util.Arrays;
-
-import scala.collection.Seq;
-
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Ignore;
-import org.junit.Test;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.primitives.Ints;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.*;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.TestData$;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.test.TestSQLContext$;
 import org.apache.spark.sql.types.*;
+import org.junit.*;
+
+import scala.collection.JavaConversions;
+import scala.collection.Seq;
+import scala.collection.mutable.Buffer;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
 
 import static org.apache.spark.sql.functions.*;
 
@@ -106,6 +111,8 @@ public class JavaDataFrameSuite {
   public static class Bean implements Serializable {
     private double a = 0.0;
     private Integer[] b = new Integer[]{0, 1};
+    private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
+    private List<String> d = Arrays.asList("floppy", "disk");
 
     public double getA() {
       return a;
@@ -114,6 +121,14 @@ public class JavaDataFrameSuite {
     public Integer[] getB() {
       return b;
     }
+
+    public Map<String, int[]> getC() {
+      return c;
+    }
+
+    public List<String> getD() {
+      return d;
+    }
   }
 
   @Test
@@ -127,7 +142,15 @@ public class JavaDataFrameSuite {
     Assert.assertEquals(
       new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
       schema.apply("b"));
-    Row first = df.select("a", "b").first();
+    ArrayType valueType = new ArrayType(DataTypes.IntegerType, false);
+    MapType mapType = new MapType(DataTypes.StringType, valueType, true);
+    Assert.assertEquals(
+      new StructField("c", mapType, true, Metadata.empty()),
+      schema.apply("c"));
+    Assert.assertEquals(
+      new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
+      schema.apply("d"));
+    Row first = df.select("a", "b", "c", "d").first();
     Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
     // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below,
     // verify that it has the expected length, and contains expected elements.
@@ -136,5 +159,15 @@ public class JavaDataFrameSuite {
     for (int i = 0; i < result.length(); i++) {
       Assert.assertEquals(bean.getB()[i], result.apply(i));
     }
+    Buffer<Integer> outputBuffer = (Buffer<Integer>) first.getJavaMap(2).get("hello");
+    Assert.assertArrayEquals(
+      bean.getC().get("hello"),
+      Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
+    Seq<String> d = first.getAs(3);
+    Assert.assertEquals(bean.getD().size(), d.length());
+    for (int i = 0; i < d.length(); i++) {
+      Assert.assertEquals(bean.getD().get(i), d.apply(i));
+    }
   }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org