You are viewing a plain-text version of this message. The canonical version is in the Apache mailing-list archive; its hyperlink was lost in this plain-text rendering.
Posted to commits@flink.apache.org by al...@apache.org on 2019/05/13 14:28:35 UTC

[flink] branch master updated: [FLINK-12301] Fix ScalaCaseClassSerializer to support value types

This is an automated email from the ASF dual-hosted git repository.

aljoscha pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 9caf2c4  [FLINK-12301] Fix ScalaCaseClassSerializer to support value types
9caf2c4 is described below

commit 9caf2c4355f851c7a8ca2b1fe9a1c6dab7bd95e3
Author: Igal Shilman <ig...@data-artisans.com>
AuthorDate: Mon May 13 11:00:37 2019 +0200

    [FLINK-12301] Fix ScalaCaseClassSerializer to support value types
    
    We now invoke the primary constructor through Scala runtime reflection,
    because it correctly handles Scala language features such as value classes (`extends AnyVal`).
---
 .../scala/typeutils/ScalaCaseClassSerializer.scala | 65 +++++++---------------
 .../ScalaCaseClassSerializerReflectionTest.scala   | 41 ++++++++++----
 2 files changed, 50 insertions(+), 56 deletions(-)

diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializer.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializer.scala
index 7ff1427..fbaa2ac 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializer.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializer.scala
@@ -18,15 +18,14 @@
 
 package org.apache.flink.api.scala.typeutils
 
+import java.io.ObjectInputStream
+
 import org.apache.flink.api.common.typeutils.CompositeTypeSerializerUtil.delegateCompatibilityCheckToNewSnapshot
 import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot.SelfResolvingTypeSerializer
 import org.apache.flink.api.common.typeutils._
 import org.apache.flink.api.java.typeutils.runtime.TupleSerializerConfigSnapshot
 import org.apache.flink.api.scala.typeutils.ScalaCaseClassSerializer.lookupConstructor
 
-import java.io.ObjectInputStream
-import java.lang.invoke.{MethodHandle, MethodHandles}
-
 import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe
 
@@ -38,16 +37,16 @@ import scala.reflect.runtime.universe
   */
 @SerialVersionUID(1L)
 class ScalaCaseClassSerializer[T <: Product](
-  clazz: Class[T],
-  scalaFieldSerializers: Array[TypeSerializer[_]]
-) extends CaseClassSerializer[T](clazz, scalaFieldSerializers)
-    with SelfResolvingTypeSerializer[T] {
+    clazz: Class[T],
+    scalaFieldSerializers: Array[TypeSerializer[_]]
+    ) extends CaseClassSerializer[T](clazz, scalaFieldSerializers)
+  with SelfResolvingTypeSerializer[T] {
 
   @transient
   private var constructor = lookupConstructor(clazz)
 
   override def createInstance(fields: Array[AnyRef]): T = {
-    constructor.invoke(fields).asInstanceOf[T]
+    constructor(fields)
   }
 
   override def snapshotConfiguration(): TypeSerializerSnapshot[T] = {
@@ -55,8 +54,7 @@ class ScalaCaseClassSerializer[T <: Product](
   }
 
   override def resolveSchemaCompatibilityViaRedirectingToNewSnapshotClass(
-    s: TypeSerializerConfigSnapshot[T]
-  ): TypeSerializerSchemaCompatibility[T] = {
+      s: TypeSerializerConfigSnapshot[T]): TypeSerializerSchemaCompatibility[T] = {
 
     require(s.isInstanceOf[TupleSerializerConfigSnapshot[_]])
 
@@ -85,22 +83,8 @@ class ScalaCaseClassSerializer[T <: Product](
 
 object ScalaCaseClassSerializer {
 
-  def lookupConstructor[T](clazz: Class[_]): MethodHandle = {
-    val types = findPrimaryConstructorParameterTypes(clazz, clazz.getClassLoader)
-
-    val constructor = clazz.getConstructor(types: _*)
-
-    val handle = MethodHandles
-      .lookup()
-      .unreflectConstructor(constructor)
-      .asSpreader(classOf[Array[AnyRef]], types.length)
-
-    handle
-  }
-
-  private def findPrimaryConstructorParameterTypes(cls: Class[_], cl: ClassLoader):
-  List[Class[_]] = {
-    val rootMirror = universe.runtimeMirror(cl)
+  def lookupConstructor[T](cls: Class[T]): Array[AnyRef] => T = {
+    val rootMirror = universe.runtimeMirror(cls.getClassLoader)
     val classSymbol = rootMirror.classSymbol(cls)
 
     require(
@@ -113,30 +97,21 @@ object ScalaCaseClassSerializer {
          |""".stripMargin
     )
 
-    val primaryConstructorSymbol = findPrimaryConstructorMethodSymbol(classSymbol)
-    val scalaTypes = getArgumentsTypes(primaryConstructorSymbol)
-    scalaTypes.map(tpe => scalaTypeToJavaClass(rootMirror)(tpe))
-  }
-
-  private def findPrimaryConstructorMethodSymbol(classSymbol: universe.ClassSymbol):
-  universe.MethodSymbol = {
-    classSymbol.toType
+    val primaryConstructorSymbol = classSymbol.toType
       .decl(universe.termNames.CONSTRUCTOR)
       .alternatives
+      .collectFirst({
+        case constructorSymbol: universe.MethodSymbol if constructorSymbol.isPrimaryConstructor =>
+          constructorSymbol
+      })
       .head
       .asMethod
-  }
 
-  private def getArgumentsTypes(primaryConstructorSymbol: universe.MethodSymbol):
-  List[universe.Type] = {
-    primaryConstructorSymbol.typeSignature
-      .paramLists
-      .head
-      .map(symbol => symbol.typeSignature)
-  }
+    val classMirror = rootMirror.reflectClass(classSymbol)
+    val constructorMethodMirror = classMirror.reflectConstructor(primaryConstructorSymbol)
 
-  private def scalaTypeToJavaClass(mirror: universe.Mirror)(scalaType: universe.Type): Class[_] = {
-    val erasure = scalaType.erasure
-    mirror.runtimeClass(erasure)
+    arr: Array[AnyRef] => {
+      constructorMethodMirror.apply(arr: _*).asInstanceOf[T]
+    }
   }
 }
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializerReflectionTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializerReflectionTest.scala
index 222dc57..47eab50 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializerReflectionTest.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializerReflectionTest.scala
@@ -18,13 +18,11 @@
 
 package org.apache.flink.api.scala.typeutils
 
-import org.apache.flink.api.scala.typeutils.ScalaCaseClassSerializerReflectionTest.{Generic, HigherKind, SimpleCaseClass}
+import org.apache.flink.api.scala.typeutils.ScalaCaseClassSerializerReflectionTest._
 import org.apache.flink.util.TestLogger
-
 import org.junit.Assert.assertEquals
 import org.junit.Test
 
-import java.lang.invoke.MethodHandle
 
 /**
   * Test obtaining the primary constructor of a case class
@@ -34,40 +32,40 @@ class ScalaCaseClassSerializerReflectionTest extends TestLogger {
 
   @Test
   def usageExample(): Unit = {
-    val constructor: MethodHandle = ScalaCaseClassSerializer
+    val constructor = ScalaCaseClassSerializer
       .lookupConstructor(classOf[SimpleCaseClass])
 
-    val actual = constructor.invoke(Array("hi", 1.asInstanceOf[Any]))
+    val actual = constructor(Array("hi", 1.asInstanceOf[AnyRef]))
 
     assertEquals(SimpleCaseClass("hi", 1), actual)
   }
 
   @Test
   def genericCaseClass(): Unit = {
-    val constructor: MethodHandle = ScalaCaseClassSerializer
+    val constructor = ScalaCaseClassSerializer
       .lookupConstructor(classOf[Generic[_]])
 
-    val actual = constructor.invoke(Array(1.asInstanceOf[AnyRef]))
+    val actual = constructor(Array(1.asInstanceOf[AnyRef]))
 
     assertEquals(Generic[Int](1), actual)
   }
 
   @Test
   def caseClassWithParameterizedList(): Unit = {
-    val constructor: MethodHandle = ScalaCaseClassSerializer
+    val constructor = ScalaCaseClassSerializer
       .lookupConstructor(classOf[HigherKind])
 
-    val actual = constructor.invoke(Array(List(1, 2, 3), "hey"))
+    val actual = constructor(Array(List(1, 2, 3), "hey"))
 
     assertEquals(HigherKind(List(1, 2, 3), "hey"), actual)
   }
 
   @Test
   def tupleType(): Unit = {
-    val constructor: MethodHandle = ScalaCaseClassSerializer
+    val constructor = ScalaCaseClassSerializer
       .lookupConstructor(classOf[(String, String, Int)])
 
-    val actual = constructor.invoke(Array("a", "b", 7))
+    val actual = constructor(Array("a", "b", 7.asInstanceOf[AnyRef]))
 
     assertEquals(("a", "b", 7), actual)
   }
@@ -80,6 +78,21 @@ class ScalaCaseClassSerializerReflectionTest extends TestLogger {
     ScalaCaseClassSerializer
       .lookupConstructor(classOf[outerInstance.InnerCaseClass])
   }
+
+  @Test
+  def valueClass(): Unit = {
+    val constructor = ScalaCaseClassSerializer
+      .lookupConstructor(classOf[Measurement])
+
+    val arguments = Array(
+      1.asInstanceOf[AnyRef],
+      new DegreeCelsius(0.5f).asInstanceOf[AnyRef]
+    )
+    
+    val actual = constructor(arguments)
+
+    assertEquals(Measurement(1, new DegreeCelsius(0.5f)), actual)
+  }
 }
 
 object ScalaCaseClassSerializerReflectionTest {
@@ -94,6 +107,12 @@ object ScalaCaseClassSerializerReflectionTest {
 
   case class Generic[T](item: T)
 
+  class DegreeCelsius(val value: Float) extends AnyVal {
+    override def toString: String = s"$value °C"
+  }
+
+  case class Measurement(i: Int, temperature: DegreeCelsius)
+
 }
 
 class OuterClass {