You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/07/17 15:07:24 UTC

spark git commit: [SPARK-24305][SQL][FOLLOWUP] Avoid serialization of private fields in collection expressions.

Repository: spark
Updated Branches:
  refs/heads/master 0ca16f6e1 -> 4cf1bec4d


[SPARK-24305][SQL][FOLLOWUP] Avoid serialization of private fields in collection expressions.

## What changes were proposed in this pull request?

This PR avoids serializing private fields of the previously added collection functions and addresses follow-up review comments from [SPARK-23922](https://github.com/apache/spark/pull/21028) and [SPARK-23935](https://github.com/apache/spark/pull/21236).

## How was this patch tested?

Ran the tests from:
- CollectionExpressionSuite.scala
- DataFrameFunctionsSuite.scala

Author: Marek Novotny <mn...@gmail.com>

Closes #21352 from mn-mikke/SPARK-24305.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4cf1bec4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4cf1bec4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4cf1bec4

Branch: refs/heads/master
Commit: 4cf1bec4dc574c541d03ea2f49db4de8b76ef6d2
Parents: 0ca16f6
Author: Marek Novotny <mn...@gmail.com>
Authored: Tue Jul 17 23:07:18 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Jul 17 23:07:18 2018 +0800

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 132 +++++++++----------
 1 file changed, 64 insertions(+), 68 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4cf1bec4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 972bc6e..d60f4c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -168,27 +168,22 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
 
   override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)
 
-  override def dataType: DataType = ArrayType(mountSchema)
-
-  override def nullable: Boolean = children.exists(_.nullable)
-
-  private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])
-
-  private lazy val arrayElementTypes = arrayTypes.map(_.elementType)
-
-  @transient private lazy val mountSchema: StructType = {
+  @transient override lazy val dataType: DataType = {
     val fields = children.zip(arrayElementTypes).zipWithIndex.map {
       case ((expr: NamedExpression, elementType), _) =>
         StructField(expr.name, elementType, nullable = true)
       case ((_, elementType), idx) =>
         StructField(idx.toString, elementType, nullable = true)
     }
-    StructType(fields)
+    ArrayType(StructType(fields), containsNull = false)
   }
 
-  @transient lazy val numberOfArrays: Int = children.length
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  @transient private lazy val arrayElementTypes =
+    children.map(_.dataType.asInstanceOf[ArrayType].elementType)
 
-  @transient lazy val genericArrayData = classOf[GenericArrayData].getName
+  private def genericArrayData = classOf[GenericArrayData].getName
 
   def emptyInputGenCode(ev: ExprCode): ExprCode = {
     ev.copy(code"""
@@ -256,7 +251,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
         ("ArrayData[]", arrVals) :: Nil)
 
     val initVariables = s"""
-      |ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
+      |ArrayData[] $arrVals = new ArrayData[${children.length}];
       |int $biggestCardinality = 0;
       |${CodeGenerator.javaType(dataType)} ${ev.value} = null;
     """.stripMargin
@@ -268,7 +263,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
       |if (!${ev.isNull}) {
       |  Object[] $args = new Object[$biggestCardinality];
       |  for (int $i = 0; $i < $biggestCardinality; $i ++) {
-      |    Object[] $currentRow = new Object[$numberOfArrays];
+      |    Object[] $currentRow = new Object[${children.length}];
       |    $getValueForTypeSplitted
       |    $args[$i] = new $genericInternalRow($currentRow);
       |  }
@@ -278,7 +273,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    if (numberOfArrays == 0) {
+    if (children.length == 0) {
       emptyInputGenCode(ev)
     } else {
       nonEmptyInputGenCode(ctx, ev)
@@ -360,7 +355,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
 
   override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
 
-  lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
+  @transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
 
   override def dataType: DataType = {
     ArrayType(
@@ -520,7 +515,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
     }
   }
 
-  override def dataType: MapType = {
+  @transient override lazy val dataType: MapType = {
     if (children.isEmpty) {
       MapType(StringType, StringType)
     } else {
@@ -747,11 +742,11 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
     case _ => None
   }
 
-  private def nullEntries: Boolean = dataTypeDetails.get._3
+  @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3
 
   override def nullable: Boolean = child.nullable || nullEntries
 
-  override def dataType: MapType = dataTypeDetails.get._1
+  @transient override lazy val dataType: MapType = dataTypeDetails.get._1
 
   override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
     case Some(_) => TypeCheckResult.TypeCheckSuccess
@@ -949,8 +944,7 @@ trait ArraySortLike extends ExpectsInputTypes {
 
   protected def nullOrder: NullOrder
 
-  @transient
-  private lazy val lt: Comparator[Any] = {
+  @transient private lazy val lt: Comparator[Any] = {
     val ordering = arrayExpression.dataType match {
       case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
       case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
@@ -972,8 +966,7 @@ trait ArraySortLike extends ExpectsInputTypes {
     }
   }
 
-  @transient
-  private lazy val gt: Comparator[Any] = {
+  @transient private lazy val gt: Comparator[Any] = {
     val ordering = arrayExpression.dataType match {
       case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
       case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
@@ -995,7 +988,9 @@ trait ArraySortLike extends ExpectsInputTypes {
     }
   }
 
-  def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType
+  @transient lazy val elementType: DataType =
+    arrayExpression.dataType.asInstanceOf[ArrayType].elementType
+
   def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull
 
   def sortEval(array: Any, ascending: Boolean): Any = {
@@ -1211,7 +1206,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
 
   override def dataType: DataType = child.dataType
 
-  lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+  @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
 
   override def nullSafeEval(input: Any): Any = input match {
     case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
@@ -1601,9 +1596,9 @@ case class Slice(x: Expression, start: Expression, length: Expression)
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)
 
-  override def children: Seq[Expression] = Seq(x, start, length)
+  @transient override lazy val children: Seq[Expression] = Seq(x, start, length) // called from eval
 
-  lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
+  @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
 
   override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
     val startInt = startVal.asInstanceOf[Int]
@@ -1889,7 +1884,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
 
-  private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
+  @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     val typeCheckResult = super.checkInputDataTypes()
@@ -1930,7 +1925,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
     min
   }
 
-  override def dataType: DataType = child.dataType match {
+  @transient override lazy val dataType: DataType = child.dataType match {
     case ArrayType(dt, _) => dt
     case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
   }
@@ -1954,7 +1949,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
 
-  private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
+  @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     val typeCheckResult = super.checkInputDataTypes()
@@ -1995,7 +1990,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
     max
   }
 
-  override def dataType: DataType = child.dataType match {
+  @transient override lazy val dataType: DataType = child.dataType match {
     case ArrayType(dt, _) => dt
     case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
   }
@@ -2097,10 +2092,13 @@ case class ArrayPosition(left: Expression, right: Expression)
   since = "2.4.0")
 case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
 
-  @transient private lazy val ordering: Ordering[Any] =
-    TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType)
+  @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
+
+  @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
+
+  @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType)
 
-  override def dataType: DataType = left.dataType match {
+  @transient override lazy val dataType: DataType = left.dataType match {
     case ArrayType(elementType, _) => elementType
     case MapType(_, valueType, _) => valueType
   }
@@ -2109,7 +2107,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
     Seq(TypeCollection(ArrayType, MapType),
       left.dataType match {
         case _: ArrayType => IntegerType
-        case _: MapType => left.dataType.asInstanceOf[MapType].keyType
+        case _: MapType => mapKeyType
         case _ => AnyDataType // no match for a wrong 'left' expression type
       }
     )
@@ -2119,8 +2117,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
     super.checkInputDataTypes() match {
       case f: TypeCheckResult.TypeCheckFailure => f
       case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
-        TypeUtils.checkForOrderingExpr(
-          left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName")
+        TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName")
       case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
     }
   }
@@ -2142,14 +2139,14 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
           } else {
             array.numElements() + index
           }
-          if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) {
+          if (arrayContainsNull && array.isNullAt(idx)) {
             null
           } else {
             array.get(idx, dataType)
           }
         }
       case _: MapType =>
-        getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering)
+        getValueEval(value, ordinal, mapKeyType, ordering)
     }
   }
 
@@ -2158,7 +2155,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
       case _: ArrayType =>
         nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
           val index = ctx.freshName("elementAtIndex")
-          val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
+          val nullCheck = if (arrayContainsNull) {
             s"""
                |if ($eval1.isNullAt($index)) {
                |  ${ev.isNull} = true;
@@ -2209,9 +2206,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
   """)
 case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression {
 
-  private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
-
-  val allowedTypes = Seq(StringType, BinaryType, ArrayType)
+  private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.isEmpty) {
@@ -2228,7 +2223,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
     }
   }
 
-  override def dataType: DataType = {
+  @transient override lazy val dataType: DataType = {
     if (children.isEmpty) {
       StringType
     } else {
@@ -2236,7 +2231,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
     }
   }
 
-  lazy val javaType: String = CodeGenerator.javaType(dataType)
+  private def javaType: String = CodeGenerator.javaType(dataType)
 
   override def nullable: Boolean = children.exists(_.nullable)
 
@@ -2256,9 +2251,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
       } else {
         val arrayData = inputs.map(_.asInstanceOf[ArrayData])
         val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
-        if (numberOfElements > MAX_ARRAY_LENGTH) {
+        if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
           throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
-            s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+            " elements due to exceeding the array size limit " +
+            ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
         }
         val finalData = new Array[AnyRef](numberOfElements.toInt)
         var position = 0
@@ -2316,9 +2312,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
         |for (int z = 0; z < ${children.length}; z++) {
         |  $numElements += args[z].numElements();
         |}
-        |if ($numElements > $MAX_ARRAY_LENGTH) {
+        |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
         |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
-        |    " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+        |    " elements due to exceeding the array size limit" +
+        |    " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
         |}
       """.stripMargin
 
@@ -2413,15 +2410,13 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
   since = "2.4.0")
 case class Flatten(child: Expression) extends UnaryExpression {
 
-  private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
-
-  private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
+  private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
 
   override def nullable: Boolean = child.nullable || childDataType.containsNull
 
-  override def dataType: DataType = childDataType.elementType
+  @transient override lazy val dataType: DataType = childDataType.elementType
 
-  lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+  @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
 
   override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
     case ArrayType(_: ArrayType, _) =>
@@ -2441,9 +2436,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
     } else {
       val arrayData = elements.map(_.asInstanceOf[ArrayData])
       val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements())
-      if (numberOfElements > MAX_ARRAY_LENGTH) {
+      if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
         throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
-          s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+          s"$numberOfElements elements due to exceeding the array size limit " +
+          ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
       }
       val flattenedData = new Array(numberOfElements.toInt)
       var position = 0
@@ -2476,9 +2472,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
       |for (int z = 0; z < $childVariableName.numElements(); z++) {
       |  $variableName += $childVariableName.getArray(z).numElements();
       |}
-      |if ($variableName > $MAX_ARRAY_LENGTH) {
+      |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
       |  throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
-      |    $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+      |    $variableName + " elements due to exceeding the array size limit" +
+      |    " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
       |}
       """.stripMargin
     (code, variableName)
@@ -2602,7 +2599,7 @@ case class Sequence(
 
   override def nullable: Boolean = children.exists(_.nullable)
 
-  override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
+  override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     val startType = start.dataType
@@ -2633,7 +2630,7 @@ case class Sequence(
     stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step),
     timeZoneId)
 
-  private lazy val impl: SequenceImpl = dataType.elementType match {
+  @transient private lazy val impl: SequenceImpl = dataType.elementType match {
     case iType: IntegralType =>
       type T = iType.InternalType
       val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe))
@@ -2953,8 +2950,6 @@ object Sequence {
 case class ArrayRepeat(left: Expression, right: Expression)
   extends BinaryExpression with ExpectsInputTypes {
 
-  private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
-
   override def dataType: ArrayType = ArrayType(left.dataType, left.nullable)
 
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
@@ -2966,9 +2961,9 @@ case class ArrayRepeat(left: Expression, right: Expression)
     if (count == null) {
       null
     } else {
-      if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) {
+      if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
         throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
-          s"due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+          s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
       }
       val element = left.eval(input)
       new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element))
@@ -3027,9 +3022,10 @@ case class ArrayRepeat(left: Expression, right: Expression)
          |if ($count > 0) {
          |  $numElements = $count;
          |}
-         |if ($numElements > $MAX_ARRAY_LENGTH) {
+         |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
          |  throw new RuntimeException("Unsuccessful try to create array with " + $numElements +
-         |    " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+         |    " elements due to exceeding the array size limit" +
+         |    " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
          |}
        """.stripMargin
 
@@ -3111,7 +3107,7 @@ case class ArrayRemove(left: Expression, right: Expression)
     Seq(ArrayType, elementType)
   }
 
-  lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
+  private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
 
   @transient private lazy val ordering: Ordering[Any] =
     TypeUtils.getInterpretedOrdering(right.dataType)
@@ -3228,7 +3224,7 @@ case class ArrayDistinct(child: Expression)
 
   override def dataType: DataType = child.dataType
 
-  @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+  @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
 
   @transient private lazy val ordering: Ordering[Any] =
     TypeUtils.getInterpretedOrdering(elementType)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org