You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/12/18 23:05:09 UTC
spark git commit: [SPARK-12404][SQL] Ensure objects passed to
StaticInvoke is Serializable
Repository: spark
Updated Branches:
refs/heads/master 41ee7c57a -> 6eba65525
[SPARK-12404][SQL] Ensure objects passed to StaticInvoke is Serializable
Currently, `StaticInvoke` accepts `Any` as its static object. `StaticInvoke` itself can be serialized, but the object passed to it is sometimes not serializable.
For example, the following code raises an exception because `RowEncoder#extractorsFor`, which is invoked indirectly, creates a `StaticInvoke`.
```
case class TimestampContainer(timestamp: java.sql.Timestamp)
val rdd = sc.parallelize(1 to 2).map(_ => TimestampContainer(System.currentTimeMillis))
val df = rdd.toDF
val ds = df.as[TimestampContainer]
val rdd2 = ds.rdd <----------------- invokes extractorsFor indirectly
```
I'll add test cases.
Author: Kousuke Saruta <sa...@oss.nttdata.co.jp>
Author: Michael Armbrust <mi...@databricks.com>
Closes #10357 from sarutak/SPARK-12404.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6eba6552
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6eba6552
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6eba6552
Branch: refs/heads/master
Commit: 6eba655259d2bcea27d0147b37d5d1e476e85422
Parents: 41ee7c5
Author: Kousuke Saruta <sa...@oss.nttdata.co.jp>
Authored: Fri Dec 18 14:05:06 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Dec 18 14:05:06 2015 -0800
----------------------------------------------------------------------
.../spark/sql/catalyst/JavaTypeInference.scala | 12 ++---
.../spark/sql/catalyst/ScalaReflection.scala | 16 +++---
.../sql/catalyst/encoders/RowEncoder.scala | 14 +++---
.../sql/catalyst/expressions/objects.scala | 8 ++-
.../org/apache/spark/sql/JavaDatasetSuite.java | 52 ++++++++++++++++++++
.../org/apache/spark/sql/DatasetSuite.scala | 12 +++++
6 files changed, 88 insertions(+), 26 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index c8ee87e..f566d1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -194,7 +194,7 @@ object JavaTypeInference {
case c if c == classOf[java.sql.Date] =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
ObjectType(c),
"toJavaDate",
getPath :: Nil,
@@ -202,7 +202,7 @@ object JavaTypeInference {
case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
ObjectType(c),
"toJavaTimestamp",
getPath :: Nil,
@@ -276,7 +276,7 @@ object JavaTypeInference {
ObjectType(classOf[Array[Any]]))
StaticInvoke(
- ArrayBasedMapData,
+ ArrayBasedMapData.getClass,
ObjectType(classOf[JMap[_, _]]),
"toJavaMap",
keyData :: valueData :: Nil)
@@ -341,21 +341,21 @@ object JavaTypeInference {
case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil)
case c if c == classOf[java.sql.Date] =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
DateType,
"fromJavaDate",
inputObject :: Nil)
case c if c == classOf[java.math.BigDecimal] =>
StaticInvoke(
- Decimal,
+ Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)
http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index ecff860..c1b1d5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -223,7 +223,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
getPath :: Nil,
@@ -231,7 +231,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
getPath :: Nil,
@@ -287,7 +287,7 @@ object ScalaReflection extends ScalaReflection {
ObjectType(classOf[Array[Any]]))
StaticInvoke(
- scala.collection.mutable.WrappedArray,
+ scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
arrayData :: Nil)
@@ -315,7 +315,7 @@ object ScalaReflection extends ScalaReflection {
ObjectType(classOf[Array[Any]]))
StaticInvoke(
- ArrayBasedMapData,
+ ArrayBasedMapData.getClass,
ObjectType(classOf[Map[_, _]]),
"toScalaMap",
keyData :: valueData :: Nil)
@@ -548,28 +548,28 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil)
case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
DateType,
"fromJavaDate",
inputObject :: Nil)
case t if t <:< localTypeOf[BigDecimal] =>
StaticInvoke(
- Decimal,
+ Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
StaticInvoke(
- Decimal,
+ Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)
http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index d34ec94..63bdf05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -61,21 +61,21 @@ object RowEncoder {
case TimestampType =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil)
case DateType =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
DateType,
"fromJavaDate",
inputObject :: Nil)
case _: DecimalType =>
StaticInvoke(
- Decimal,
+ Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)
@@ -172,14 +172,14 @@ object RowEncoder {
case TimestampType =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
input :: Nil)
case DateType =>
StaticInvoke(
- DateTimeUtils,
+ DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
input :: Nil)
@@ -197,7 +197,7 @@ object RowEncoder {
"array",
ObjectType(classOf[Array[_]]))
StaticInvoke(
- scala.collection.mutable.WrappedArray,
+ scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
arrayData :: Nil)
@@ -210,7 +210,7 @@ object RowEncoder {
val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
StaticInvoke(
- ArrayBasedMapData,
+ ArrayBasedMapData.getClass,
ObjectType(classOf[Map[_, _]]),
"toScalaMap",
keyData :: valueData :: Nil)
http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 10ec75e..492cc9b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -42,16 +42,14 @@ import org.apache.spark.sql.types._
* of calling the function.
*/
case class StaticInvoke(
- staticObject: Any,
+ staticObject: Class[_],
dataType: DataType,
functionName: String,
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends Expression {
- val objectName = staticObject match {
- case c: Class[_] => c.getName
- case other => other.getClass.getName.stripSuffix("$")
- }
+ val objectName = staticObject.getName.stripSuffix("$")
+
override def nullable: Boolean = true
override def children: Seq[Expression] = arguments
http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 383a2d0..0dbaeb8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -39,6 +39,7 @@ import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.functions.*;
@@ -608,6 +609,44 @@ public class JavaDatasetSuite implements Serializable {
}
}
+ public class SimpleJavaBean2 implements Serializable {
+ private Timestamp a;
+ private Date b;
+ private java.math.BigDecimal c;
+
+ public Timestamp getA() { return a; }
+
+ public void setA(Timestamp a) { this.a = a; }
+
+ public Date getB() { return b; }
+
+ public void setB(Date b) { this.b = b; }
+
+ public java.math.BigDecimal getC() { return c; }
+
+ public void setC(java.math.BigDecimal c) { this.c = c; }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ SimpleJavaBean that = (SimpleJavaBean) o;
+
+ if (!a.equals(that.a)) return false;
+ if (!b.equals(that.b)) return false;
+ return c.equals(that.c);
+ }
+
+ @Override
+ public int hashCode() {
+ int result = a.hashCode();
+ result = 31 * result + b.hashCode();
+ result = 31 * result + c.hashCode();
+ return result;
+ }
+ }
+
public class NestedJavaBean implements Serializable {
private SimpleJavaBean a;
@@ -689,4 +728,17 @@ public class JavaDatasetSuite implements Serializable {
.as(Encoders.bean(SimpleJavaBean.class));
Assert.assertEquals(data, ds3.collectAsList());
}
+
+ @Test
+ public void testJavaBeanEncoder2() {
+ // This is a regression test of SPARK-12404
+ OuterScopes.addOuterScope(this);
+ SimpleJavaBean2 obj = new SimpleJavaBean2();
+ obj.setA(new Timestamp(0));
+ obj.setB(new Date(0));
+ obj.setC(java.math.BigDecimal.valueOf(1));
+ Dataset<SimpleJavaBean2> ds =
+ context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
+ ds.collect();
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index f1b6b98..de012a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import java.sql.{Date, Timestamp}
import scala.language.postfixOps
@@ -42,6 +43,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
1, 1, 1)
}
+
+ test("SPARK-12404: Datatype Helper Serializablity") {
+ val ds = sparkContext.parallelize((
+ new Timestamp(0),
+ new Date(0),
+ java.math.BigDecimal.valueOf(1),
+ scala.math.BigDecimal(1)) :: Nil).toDS()
+
+ ds.collect()
+ }
+
test("collect, first, and take should use encoders for serialization") {
val item = NonSerializableCaseClass("abcd")
val ds = Seq(item).toDS()
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org