Posted to commits@spark.apache.org by sr...@apache.org on 2021/08/09 13:48:19 UTC

[spark] branch master updated: [SPARK-20384][SQL] Support value class in nested schema for Dataset

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 33c6d11  [SPARK-20384][SQL] Support value class in nested schema for Dataset
33c6d11 is described below

commit 33c6d1168c077630a8f81c1a4e153f862162f257
Author: Mick Jermsurawong <mi...@stripe.com>
AuthorDate: Mon Aug 9 08:47:35 2021 -0500

    [SPARK-20384][SQL] Support value class in nested schema for Dataset
    
    ### What changes were proposed in this pull request?
    
    - This PR revisits https://github.com/apache/spark/pull/22309 and [SPARK-20384](https://issues.apache.org/jira/browse/SPARK-20384). It solves the original problem, and additionally prevents a backward-compatibility break in the schema of top-level `AnyVal` value classes.
    - Why did the previous attempt break? We currently support top-level value classes just like any other case class: the field of the underlying type is present in the schema, so any DataFrame SQL filter on it expects that field name to exist. The previous PR changed this schema, which would have broken current usage (see test `"schema for case class that is a value class"`). This PR keeps that schema unchanged.
    - Collections of value classes are already supported prior to this change, but a case class with a nested value class field is not. The schema of collections therefore shouldn't change either, to avoid breaking existing usage.
    - What we can change without breaking anything is the schema of a nested value class, which currently fails with a compile error in the generated code, so its schema isn't actually usable today. After this change, the schema of a nested value class is flattened to its underlying type.
    - With this PR, flattening applies only to nested value classes (new behavior), not to top-level value classes or collections of value classes (existing behavior); see the schema sketch after this list.
    - This PR also revisits https://github.com/apache/spark/pull/27153 by handling tuples such as `Tuple2[AnyVal, AnyVal]`: a tuple does contain the value class ("nested"), but as a generic type argument, so it should not be flattened, behaving like `Seq[AnyVal]`.
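    
    A minimal sketch of the intended schema behavior (the class names here are illustrative only; the actual assertions are in `ScalaReflectionSuite`):
    ```
    case class IntWrapper(i: Int) extends AnyVal
    case class Outer(wrappedInt: IntWrapper)
    
    // schemaFor[IntWrapper]      -> struct<i: int>           (top-level: unchanged)
    // schemaFor[Outer]           -> struct<wrappedInt: int>  (nested: flattened to the underlying type)
    // schemaFor[Seq[IntWrapper]] -> array<struct<i: int>>    (collection: unchanged)
    ```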
    
    ### Why are the changes needed?
    
    - Currently, nested value classes aren't supported. The generated code treats the `AnyVal` class in its unwrapped (underlying) form, while we encode the type as the wrapped case class, so the generated code fails to compile.
    For example, given an `AnyVal` wrapper and a root-level class containing it:
    ```
    case class IntWrapper(i: Int) extends AnyVal
    case class ComplexValueClassContainer(c: IntWrapper)
    ```
    The problematic part of the generated code:
    ```
        private InternalRow If_1(InternalRow i) {
            boolean isNull_42 = i.isNullAt(0);
            // 1) ******** The root-level case class we care about
            org.apache.spark.sql.catalyst.encoders.ComplexValueClassContainer value_46 = isNull_42 ?
                null : ((org.apache.spark.sql.catalyst.encoders.ComplexValueClassContainer) i.get(0, null));
            if (isNull_42) {
                throw new NullPointerException(((java.lang.String) references[5] /* errMsg */ ));
            }
            boolean isNull_39 = true;
            // 2) ******** We specify its member to be unwrapped case class extending `AnyVal`
            org.apache.spark.sql.catalyst.encoders.IntWrapper value_43 = null;
            if (!false) {
    
                isNull_39 = false;
                if (!isNull_39) {
                    // 3) ******** ERROR: `c()` compiled however is of type `int` and thus we see error
                    value_43 = value_46.c();
                }
            }
    ```
    We get this error: Assignment conversion not possible from type "int" to type "org.apache.spark.sql.catalyst.encoders.IntWrapper"
    ```
    java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException:
    File 'generated.java', Line 159, Column 1: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 159, Column 1: Assignment conversion not possible from type "int" to type "org.apache.spark.sql.catalyst.encoders.IntWrapper"
    ```
    
    From the [doc](https://docs.scala-lang.org/overviews/core/value-classes.html) on value classes, given `class Wrapper(val underlying: Int) extends AnyVal`:
    1) "The type at compile time is `Wrapper`, but at runtime, the representation is an `Int`." This implies that when our struct has a value class field, the generated code should work with the underlying type at runtime.
    2) `Wrapper` "must be instantiated... when a value class is used as a type argument". This implies that `scala.Tuple[Wrapper, ...]`, `Seq[Wrapper]`, `Map[String, Wrapper]`, and `Option[Wrapper]` still contain `Wrapper` instances at runtime instead of `Int`, as illustrated in the sketch below.
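    
    A brief sketch of these two points (a standalone illustration, not code from this patch):
    ```
    class Wrapper(val underlying: Int) extends AnyVal
    
    // 1) Used directly, the runtime representation is the underlying Int:
    def direct(w: Wrapper): Int = w.underlying
    
    // 2) Used as a type argument, elements remain boxed Wrapper instances at runtime:
    def asElement(xs: Seq[Wrapper]): Wrapper = xs.head
    ```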
    
    ### Does this PR introduce _any_ user-facing change?
    
    - Yes, nested value classes in Datasets are now supported; a sketch of the new behavior follows.
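    
    A sketch of the new behavior (class names are illustrative, not from the patch; assumes `spark.implicits._` is in scope):
    ```
    case class Amount(cents: Long) extends AnyVal
    case class Order(id: String, amount: Amount)
    
    val ds = Seq(Order("o1", Amount(100L))).toDS()
    ds.printSchema()
    // root
    //  |-- id: string (nullable = true)
    //  |-- amount: long (nullable = false)
    
    // The flattened column can be filtered on directly:
    ds.where("amount = 100").show()
    ```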
    
    ### How was this patch tested?
    
    - Added unit tests to illustrate
      - raw schema
      - projection
      - round-trip encode/decode
    
    Closes #33205 from mickjermsurawong-stripe/SPARK-20384-2.
    
    Lead-authored-by: Mick Jermsurawong <mi...@stripe.com>
    Co-authored-by: Emil Ejbyfeldt <ee...@liveintent.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../spark/sql/catalyst/ScalaReflection.scala       | 33 +++++---
 .../spark/sql/catalyst/ScalaReflectionSuite.scala  | 94 ++++++++++++++++++++++
 .../catalyst/encoders/ExpressionEncoderSuite.scala | 57 +++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala      | 52 +++++++++++-
 .../org/apache/spark/sql/test/SQLTestData.scala    |  3 +
 5 files changed, 229 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 4de7e5c..bab407b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -956,6 +956,19 @@ trait ScalaReflection extends Logging {
     tag.in(mirror).tpe.dealias
   }
 
+  private def isValueClass(tpe: Type): Boolean = {
+    tpe.typeSymbol.asClass.isDerivedValueClass
+  }
+
+  private def isTypeParameter(tpe: Type): Boolean = {
+    tpe.typeSymbol.isParameter
+  }
+
+  /** Returns the name and type of the underlying parameter of value class `tpe`. */
+  private def getUnderlyingTypeOfValueClass(tpe: `Type`): Type = {
+    getConstructorParameters(tpe).head._2
+  }
+
   /**
    * Returns the parameter names and types for the primary constructor of this type.
    *
@@ -967,15 +980,17 @@ trait ScalaReflection extends Logging {
     val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
     val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
     val params = constructParams(dealiasedTpe)
-    // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
-    if (actualTypeArgs.nonEmpty) {
-      params.map { p =>
-        p.name.decodedName.toString ->
-          p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
-      }
-    } else {
-      params.map { p =>
-        p.name.decodedName.toString -> p.typeSignature
+    params.map { p =>
+      val paramTpe = p.typeSignature
+      if (isTypeParameter(paramTpe)) {
+        // if there are type variables to fill in, do the substitution
+        // (SomeClass[T] -> SomeClass[Int])
+        p.name.decodedName.toString -> paramTpe.substituteTypes(formalTypeArgs, actualTypeArgs)
+      } else if (isValueClass(paramTpe)) {
+        // Replace value class with underlying type
+        p.name.decodedName.toString -> getUnderlyingTypeOfValueClass(paramTpe)
+      } else {
+        p.name.decodedName.toString -> paramTpe
       }
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 164bbd7..d86f986 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -156,8 +156,19 @@ object TraitProductWithNoConstructorCompanion {}
 
 trait TraitProductWithNoConstructorCompanion extends Product1[Int] {}
 
+object TestingValueClass {
+  case class IntWrapper(val i: Int) extends AnyVal
+  case class StrWrapper(s: String) extends AnyVal
+
+  case class ValueClassData(intField: Int,
+                            wrappedInt: IntWrapper, // an int column
+                            strField: String,
+                            wrappedStr: StrWrapper) // a string column
+}
+
 class ScalaReflectionSuite extends SparkFunSuite {
   import org.apache.spark.sql.catalyst.ScalaReflection._
+  import TestingValueClass._
 
   // A helper method used to test `ScalaReflection.serializerForType`.
   private def serializerFor[T: TypeTag]: Expression =
@@ -451,4 +462,87 @@ class ScalaReflectionSuite extends SparkFunSuite {
       StructField("e", StringType, true))))
     assert(deserializerFor[FooClassWithEnum].dataType == ObjectType(classOf[FooClassWithEnum]))
   }
+
+  test("schema for case class that is a value class") {
+    val schema = schemaFor[IntWrapper]
+    assert(
+      schema === Schema(StructType(Seq(StructField("i", IntegerType, false))), nullable = true))
+  }
+
+  test("SPARK-20384: schema for case class that contains value class fields") {
+    val schema = schemaFor[ValueClassData]
+    assert(
+      schema === Schema(
+        StructType(Seq(
+          StructField("intField", IntegerType, nullable = false),
+          StructField("wrappedInt", IntegerType, nullable = false),
+          StructField("strField", StringType),
+          StructField("wrappedStr", StringType)
+        )),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for array of value class") {
+    val schema = schemaFor[Array[IntWrapper]]
+    assert(
+      schema === Schema(
+        ArrayType(StructType(Seq(StructField("i", IntegerType, false))), containsNull = true),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for map of value class") {
+    val schema = schemaFor[Map[IntWrapper, StrWrapper]]
+    assert(
+      schema === Schema(
+        MapType(
+          StructType(Seq(StructField("i", IntegerType, false))),
+          StructType(Seq(StructField("s", StringType))),
+          valueContainsNull = true),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for tuple_2 of value class") {
+    val schema = schemaFor[(IntWrapper, StrWrapper)]
+    assert(
+      schema === Schema(
+        StructType(
+          Seq(
+            StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
+            StructField("_2", StructType(Seq(StructField("s", StringType))))
+          )
+        ),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for tuple_3 of value class") {
+    val schema = schemaFor[(IntWrapper, StrWrapper, StrWrapper)]
+    assert(
+      schema === Schema(
+        StructType(
+          Seq(
+            StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
+            StructField("_2", StructType(Seq(StructField("s", StringType)))),
+            StructField("_3", StructType(Seq(StructField("s", StringType))))
+          )
+        ),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for nested tuple of value class") {
+    val schema = schemaFor[(IntWrapper, (StrWrapper, StrWrapper))]
+    assert(
+      schema === Schema(
+        StructType(
+          Seq(
+            StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
+            StructField("_2", StructType(
+              Seq(
+                StructField("_1", StructType(Seq(StructField("s", StringType)))),
+                StructField("_2", StructType(Seq(StructField("s", StringType)))))
+              )
+            )
+          )
+        ),
+        nullable = true))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index bf4afac..ae5ce60 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -116,6 +116,21 @@ object ReferenceValueClass {
 }
 case class IntAndString(i: Int, s: String)
 
+case class StringWrapper(s: String) extends AnyVal
+case class ValueContainer(
+                           a: Int,
+                           b: StringWrapper) // a string column
+case class IntWrapper(i: Int) extends AnyVal
+case class ComplexValueClassContainer(
+                                       a: Int,
+                                       b: ValueContainer,
+                                       c: IntWrapper)
+case class SeqOfValueClass(s: Seq[StringWrapper])
+case class MapOfValueClassKey(m: Map[IntWrapper, String])
+case class MapOfValueClassValue(m: Map[String, StringWrapper])
+case class OptionOfValueClassValue(o: Option[StringWrapper])
+case class CaseClassWithGeneric[T](generic: T, value: IntWrapper)
+
 class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest {
   OuterScopes.addOuterScope(this)
 
@@ -391,12 +406,54 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
     ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
   }
 
+  // test for value classes
   encodeDecodeTest(
     PrimitiveValueClass(42), "primitive value class")
 
   encodeDecodeTest(
     ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")
 
+  encodeDecodeTest(StringWrapper("a"), "string value class")
+  encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "nested value class")
+  encodeDecodeTest(ValueContainer(1, StringWrapper(null)), "nested value class with null")
+  encodeDecodeTest(ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")),
+    IntWrapper(3)), "complex value class")
+  encodeDecodeTest(
+    Array(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
+    "array of value class")
+  encodeDecodeTest(Array.empty[IntWrapper], "empty array of value class")
+  encodeDecodeTest(
+    Seq(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
+    "seq of value class")
+  encodeDecodeTest(Seq.empty[IntWrapper], "empty seq of value class")
+  encodeDecodeTest(
+    Map(IntWrapper(1) -> StringWrapper("a"), IntWrapper(2) -> StringWrapper("b")),
+    "map with value class")
+
+  // test for nested value class collections
+  encodeDecodeTest(
+    MapOfValueClassKey(Map(IntWrapper(1)-> "a")),
+    "case class with map of value class key")
+  encodeDecodeTest(
+    MapOfValueClassValue(Map("a"-> StringWrapper("b"))),
+    "case class with map of value class value")
+  encodeDecodeTest(
+    SeqOfValueClass(Seq(StringWrapper("a"))),
+    "case class with seq of class value")
+  encodeDecodeTest(
+    OptionOfValueClassValue(Some(StringWrapper("a"))),
+    "case class with option of class value")
+  encodeDecodeTest((StringWrapper("a_1"), StringWrapper("a_2")),
+    "tuple2 of class value")
+  encodeDecodeTest((StringWrapper("a_1"), StringWrapper("a_2"), StringWrapper("a_3")),
+    "tuple3 of class value")
+  encodeDecodeTest(((StringWrapper("a_1"), StringWrapper("a_2")), StringWrapper("b_2")),
+    "nested tuple._1 of class value")
+  encodeDecodeTest((StringWrapper("a_1"), (StringWrapper("b_1"), StringWrapper("b_2"))),
+    "nested tuple._2 of class value")
+  encodeDecodeTest(CaseClassWithGeneric(IntWrapper(1), IntWrapper(2)),
+    "case class with value class in generic parameter")
+
   encodeDecodeTest(Option(31), "option of int")
   encodeDecodeTest(Option.empty[Int], "empty option of int")
   encodeDecodeTest(Option("abc"), "option of string")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2fd5993..f2d0b60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.{Aggregator, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.Utils
@@ -834,6 +834,56 @@ class DataFrameSuite extends QueryTest
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("SPARK-20384: Value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
+      .toDF()
+    val filtered = df.where("s = \"a\"")
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF)
+  }
+
+  test("SPARK-20384: Tuple2 of value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(
+        (StringWrapper("a1"), StringWrapper("a2")),
+        (StringWrapper("b1"), StringWrapper("b2"))))
+      .toDF()
+    val filtered = df.where("_2.s = \"a2\"")
+    checkAnswer(filtered,
+      spark.sparkContext.parallelize(Seq((StringWrapper("a1"), StringWrapper("a2")))).toDF)
+  }
+
+  test("SPARK-20384: Tuple3 of value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(
+        (StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")),
+        (StringWrapper("b1"), StringWrapper("b2"), StringWrapper("b3"))))
+      .toDF()
+    val filtered = df.where("_3.s = \"a3\"")
+    checkAnswer(filtered,
+      spark.sparkContext.parallelize(
+        Seq((StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")))).toDF)
+  }
+
+  test("SPARK-20384: Array value class filter") {
+    val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b")))
+    val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d")))
+
+    val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF
+    val filtered = df.where(array_contains(col("wrappers.s"), "b"))
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF)
+  }
+
+  test("SPARK-20384: Nested value class filter") {
+    val a = ContainerStringWrapper(StringWrapper("a"))
+    val b = ContainerStringWrapper(StringWrapper("b"))
+
+    val df = spark.sparkContext.parallelize(Seq(a, b)).toDF
+    // flat value class, `s` field is not in schema
+    val filtered = df.where("wrapper = \"a\"")
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(a)).toDF)
+  }
+
   private lazy val person2: DataFrame = Seq(
     ("Bob", 16, 176),
     ("Alice", 32, 164),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 307c4f3..21064b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -450,4 +450,7 @@ private[sql] object SQLTestData {
   case class CourseSales(course: String, year: Int, earnings: Double)
   case class TrainingSales(training: String, sales: CourseSales)
   case class IntervalData(data: CalendarInterval)
+  case class StringWrapper(s: String) extends AnyVal
+  case class ArrayStringWrapper(wrappers: Seq[StringWrapper])
+  case class ContainerStringWrapper(wrapper: StringWrapper)
 }
