You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/07/17 15:07:24 UTC
spark git commit: [SPARK-24305][SQL][FOLLOWUP] Avoid serialization of
private fields in collection expressions.
Repository: spark
Updated Branches:
refs/heads/master 0ca16f6e1 -> 4cf1bec4d
[SPARK-24305][SQL][FOLLOWUP] Avoid serialization of private fields in collection expressions.
## What changes were proposed in this pull request?
This PR avoids serializing private fields of the collection functions that have already been added, and follows up on review comments in [SPARK-23922](https://github.com/apache/spark/pull/21028) and [SPARK-23935](https://github.com/apache/spark/pull/21236).
## How was this patch tested?
Run tests from:
- CollectionExpressionsSuite.scala
- DataFrameFunctionsSuite.scala
Author: Marek Novotny <mn...@gmail.com>
Closes #21352 from mn-mikke/SPARK-24305.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4cf1bec4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4cf1bec4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4cf1bec4
Branch: refs/heads/master
Commit: 4cf1bec4dc574c541d03ea2f49db4de8b76ef6d2
Parents: 0ca16f6
Author: Marek Novotny <mn...@gmail.com>
Authored: Tue Jul 17 23:07:18 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Jul 17 23:07:18 2018 +0800
----------------------------------------------------------------------
.../expressions/collectionOperations.scala | 132 +++++++++----------
1 file changed, 64 insertions(+), 68 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/4cf1bec4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 972bc6e..d60f4c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -168,27 +168,22 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)
- override def dataType: DataType = ArrayType(mountSchema)
-
- override def nullable: Boolean = children.exists(_.nullable)
-
- private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])
-
- private lazy val arrayElementTypes = arrayTypes.map(_.elementType)
-
- @transient private lazy val mountSchema: StructType = {
+ @transient override lazy val dataType: DataType = {
val fields = children.zip(arrayElementTypes).zipWithIndex.map {
case ((expr: NamedExpression, elementType), _) =>
StructField(expr.name, elementType, nullable = true)
case ((_, elementType), idx) =>
StructField(idx.toString, elementType, nullable = true)
}
- StructType(fields)
+ ArrayType(StructType(fields), containsNull = false)
}
- @transient lazy val numberOfArrays: Int = children.length
+ override def nullable: Boolean = children.exists(_.nullable)
+
+ @transient private lazy val arrayElementTypes =
+ children.map(_.dataType.asInstanceOf[ArrayType].elementType)
- @transient lazy val genericArrayData = classOf[GenericArrayData].getName
+ private def genericArrayData = classOf[GenericArrayData].getName
def emptyInputGenCode(ev: ExprCode): ExprCode = {
ev.copy(code"""
@@ -256,7 +251,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
("ArrayData[]", arrVals) :: Nil)
val initVariables = s"""
- |ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
+ |ArrayData[] $arrVals = new ArrayData[${children.length}];
|int $biggestCardinality = 0;
|${CodeGenerator.javaType(dataType)} ${ev.value} = null;
""".stripMargin
@@ -268,7 +263,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
|if (!${ev.isNull}) {
| Object[] $args = new Object[$biggestCardinality];
| for (int $i = 0; $i < $biggestCardinality; $i ++) {
- | Object[] $currentRow = new Object[$numberOfArrays];
+ | Object[] $currentRow = new Object[${children.length}];
| $getValueForTypeSplitted
| $args[$i] = new $genericInternalRow($currentRow);
| }
@@ -278,7 +273,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- if (numberOfArrays == 0) {
+ if (children.length == 0) {
emptyInputGenCode(ev)
} else {
nonEmptyInputGenCode(ctx, ev)
@@ -360,7 +355,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
- lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
+ @transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
override def dataType: DataType = {
ArrayType(
@@ -520,7 +515,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
}
}
- override def dataType: MapType = {
+ @transient override lazy val dataType: MapType = {
if (children.isEmpty) {
MapType(StringType, StringType)
} else {
@@ -747,11 +742,11 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
case _ => None
}
- private def nullEntries: Boolean = dataTypeDetails.get._3
+ @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3
override def nullable: Boolean = child.nullable || nullEntries
- override def dataType: MapType = dataTypeDetails.get._1
+ @transient override lazy val dataType: MapType = dataTypeDetails.get._1
override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
case Some(_) => TypeCheckResult.TypeCheckSuccess
@@ -949,8 +944,7 @@ trait ArraySortLike extends ExpectsInputTypes {
protected def nullOrder: NullOrder
- @transient
- private lazy val lt: Comparator[Any] = {
+ @transient private lazy val lt: Comparator[Any] = {
val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
@@ -972,8 +966,7 @@ trait ArraySortLike extends ExpectsInputTypes {
}
}
- @transient
- private lazy val gt: Comparator[Any] = {
+ @transient private lazy val gt: Comparator[Any] = {
val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
@@ -995,7 +988,9 @@ trait ArraySortLike extends ExpectsInputTypes {
}
}
- def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType
+ @transient lazy val elementType: DataType =
+ arrayExpression.dataType.asInstanceOf[ArrayType].elementType
+
def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull
def sortEval(array: Any, ascending: Boolean): Any = {
@@ -1211,7 +1206,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
override def dataType: DataType = child.dataType
- lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+ @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
override def nullSafeEval(input: Any): Any = input match {
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
@@ -1601,9 +1596,9 @@ case class Slice(x: Expression, start: Expression, length: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)
- override def children: Seq[Expression] = Seq(x, start, length)
+ @transient override lazy val children: Seq[Expression] = Seq(x, start, length) // called from eval
- lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
+ @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
val startInt = startVal.asInstanceOf[Int]
@@ -1889,7 +1884,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
- private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
+ @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
override def checkInputDataTypes(): TypeCheckResult = {
val typeCheckResult = super.checkInputDataTypes()
@@ -1930,7 +1925,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
min
}
- override def dataType: DataType = child.dataType match {
+ @transient override lazy val dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
}
@@ -1954,7 +1949,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
- private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
+ @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
override def checkInputDataTypes(): TypeCheckResult = {
val typeCheckResult = super.checkInputDataTypes()
@@ -1995,7 +1990,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
max
}
- override def dataType: DataType = child.dataType match {
+ @transient override lazy val dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
}
@@ -2097,10 +2092,13 @@ case class ArrayPosition(left: Expression, right: Expression)
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
- @transient private lazy val ordering: Ordering[Any] =
- TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType)
+ @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
+
+ @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
+
+ @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType)
- override def dataType: DataType = left.dataType match {
+ @transient override lazy val dataType: DataType = left.dataType match {
case ArrayType(elementType, _) => elementType
case MapType(_, valueType, _) => valueType
}
@@ -2109,7 +2107,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
Seq(TypeCollection(ArrayType, MapType),
left.dataType match {
case _: ArrayType => IntegerType
- case _: MapType => left.dataType.asInstanceOf[MapType].keyType
+ case _: MapType => mapKeyType
case _ => AnyDataType // no match for a wrong 'left' expression type
}
)
@@ -2119,8 +2117,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
super.checkInputDataTypes() match {
case f: TypeCheckResult.TypeCheckFailure => f
case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
- TypeUtils.checkForOrderingExpr(
- left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName")
+ TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName")
case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
}
}
@@ -2142,14 +2139,14 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
} else {
array.numElements() + index
}
- if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) {
+ if (arrayContainsNull && array.isNullAt(idx)) {
null
} else {
array.get(idx, dataType)
}
}
case _: MapType =>
- getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering)
+ getValueEval(value, ordinal, mapKeyType, ordering)
}
}
@@ -2158,7 +2155,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
case _: ArrayType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("elementAtIndex")
- val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
+ val nullCheck = if (arrayContainsNull) {
s"""
|if ($eval1.isNullAt($index)) {
| ${ev.isNull} = true;
@@ -2209,9 +2206,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
""")
case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression {
- private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
-
- val allowedTypes = Seq(StringType, BinaryType, ArrayType)
+ private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)
override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
@@ -2228,7 +2223,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
}
}
- override def dataType: DataType = {
+ @transient override lazy val dataType: DataType = {
if (children.isEmpty) {
StringType
} else {
@@ -2236,7 +2231,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
}
}
- lazy val javaType: String = CodeGenerator.javaType(dataType)
+ private def javaType: String = CodeGenerator.javaType(dataType)
override def nullable: Boolean = children.exists(_.nullable)
@@ -2256,9 +2251,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
} else {
val arrayData = inputs.map(_.asInstanceOf[ArrayData])
val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
- if (numberOfElements > MAX_ARRAY_LENGTH) {
+ if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
- s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+ " elements due to exceeding the array size limit " +
+ ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
}
val finalData = new Array[AnyRef](numberOfElements.toInt)
var position = 0
@@ -2316,9 +2312,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
|for (int z = 0; z < ${children.length}; z++) {
| $numElements += args[z].numElements();
|}
- |if ($numElements > $MAX_ARRAY_LENGTH) {
+ |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
- | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ | " elements due to exceeding the array size limit" +
+ | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|}
""".stripMargin
@@ -2413,15 +2410,13 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
since = "2.4.0")
case class Flatten(child: Expression) extends UnaryExpression {
- private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
-
- private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
+ private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
override def nullable: Boolean = child.nullable || childDataType.containsNull
- override def dataType: DataType = childDataType.elementType
+ @transient override lazy val dataType: DataType = childDataType.elementType
- lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+ @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(_: ArrayType, _) =>
@@ -2441,9 +2436,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
} else {
val arrayData = elements.map(_.asInstanceOf[ArrayData])
val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements())
- if (numberOfElements > MAX_ARRAY_LENGTH) {
+ if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
- s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+ s"$numberOfElements elements due to exceeding the array size limit " +
+ ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
}
val flattenedData = new Array(numberOfElements.toInt)
var position = 0
@@ -2476,9 +2472,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
|for (int z = 0; z < $childVariableName.numElements(); z++) {
| $variableName += $childVariableName.getArray(z).numElements();
|}
- |if ($variableName > $MAX_ARRAY_LENGTH) {
+ |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
- | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ | $variableName + " elements due to exceeding the array size limit" +
+ | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|}
""".stripMargin
(code, variableName)
@@ -2602,7 +2599,7 @@ case class Sequence(
override def nullable: Boolean = children.exists(_.nullable)
- override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
+ override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
override def checkInputDataTypes(): TypeCheckResult = {
val startType = start.dataType
@@ -2633,7 +2630,7 @@ case class Sequence(
stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step),
timeZoneId)
- private lazy val impl: SequenceImpl = dataType.elementType match {
+ @transient private lazy val impl: SequenceImpl = dataType.elementType match {
case iType: IntegralType =>
type T = iType.InternalType
val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe))
@@ -2953,8 +2950,6 @@ object Sequence {
case class ArrayRepeat(left: Expression, right: Expression)
extends BinaryExpression with ExpectsInputTypes {
- private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
-
override def dataType: ArrayType = ArrayType(left.dataType, left.nullable)
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
@@ -2966,9 +2961,9 @@ case class ArrayRepeat(left: Expression, right: Expression)
if (count == null) {
null
} else {
- if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) {
+ if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
- s"due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
}
val element = left.eval(input)
new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element))
@@ -3027,9 +3022,10 @@ case class ArrayRepeat(left: Expression, right: Expression)
|if ($count > 0) {
| $numElements = $count;
|}
- |if ($numElements > $MAX_ARRAY_LENGTH) {
+ |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw new RuntimeException("Unsuccessful try to create array with " + $numElements +
- | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ | " elements due to exceeding the array size limit" +
+ | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|}
""".stripMargin
@@ -3111,7 +3107,7 @@ case class ArrayRemove(left: Expression, right: Expression)
Seq(ArrayType, elementType)
}
- lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
+ private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)
@@ -3228,7 +3224,7 @@ case class ArrayDistinct(child: Expression)
override def dataType: DataType = child.dataType
- @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+ @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(elementType)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org