You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/09/17 02:00:01 UTC

[GitHub] [spark] HyukjinKwon commented on a change in pull request #34025: [SPARK-36673][SQL] Fix incorrect schema of nested types of union

HyukjinKwon commented on a change in pull request #34025:
URL: https://github.com/apache/spark/pull/34025#discussion_r710665992



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
##########
@@ -559,6 +559,42 @@ object StructType extends AbstractDataType {
       case _ => dt
     }
 
+  /**
+   * This works a little similarly to `merge`, but it does not actually merge two DataTypes.
+   * This method just merges nullability.
+   *
+   * Array, map and struct types are walked recursively: the trailing contains-null flags of
+   * `ArrayType` and `MapType`, and the `nullable` flags of struct fields, are OR-ed. Struct
+   * fields are paired by position (not by name), and the left field is `copy`-ed, so everything
+   * other than its data type and nullability (e.g. its name) comes from the left side. For any
+   * other pair of types the left type is returned as-is, so this presumably relies on callers
+   * passing types that differ only in nullability (TODO: confirm at call sites).
+   *
+   * @throws IllegalArgumentException if two structs being merged (at any nesting level) have a
+   *                                  different number of fields.
+   */
+  private[sql] def mergeNullability(left: DataType, right: DataType): DataType =
+    (left, right) match {
+      case (ArrayType(leftElementType, leftContainsNull),
+          ArrayType(rightElementType, rightContainsNull)) =>
+        ArrayType(
+          mergeNullability(leftElementType, rightElementType),
+          leftContainsNull || rightContainsNull)
+
+      case (MapType(leftKeyType, leftValueType, leftContainsNull),
+          MapType(rightKeyType, rightValueType, rightContainsNull)) =>
+        MapType(
+          mergeNullability(leftKeyType, rightKeyType),
+          mergeNullability(leftValueType, rightValueType),
+          leftContainsNull || rightContainsNull)
+
+      case (StructType(leftFields), StructType(rightFields)) =>
+        require(leftFields.size == rightFields.size, "To merge nullability, " +
+          "two structs must have same number of fields.")
+
+        // Fields are paired positionally via zip; right-side field names are ignored.
+        val newFields = leftFields.zip(rightFields).map {
+          case (leftField @ StructField(_, leftType, leftNullable, _),
+              _ @ StructField(_, rightType, rightNullable, _)) =>
+            leftField.copy(
+              dataType = mergeNullability(leftType, rightType),
+              nullable = leftNullable || rightNullable)
+        }.toSeq
+        StructType(newFields)
+
+      // Any other combination (leaf types or mismatched shapes): keep the left type unchanged.
+      case (leftType, _) =>
+        leftType
+    }
+
   private[sql] def merge(left: DataType, right: DataType): DataType =

Review comment:
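       What about something like this? It moves the shared Array/Map/Decimal handling into a private `mergeInternal`, so `merge` and `unionLikeMerge` differ only in how two structs are merged.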
   
   ```diff
   diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
   index 5be328e0486..50e8d64feba 100644
   --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
   +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
   @@ -307,7 +307,7 @@ case class Union(
        children.map(_.output).transpose.map { attrs =>
          val firstAttr = attrs.head
          val nullable = attrs.exists(_.nullable)
   -      val newDt = attrs.map(_.dataType).reduce(StructType.mergeNullability)
   +      val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
          if (firstAttr.dataType == newDt) {
            firstAttr.withNullability(nullable)
          } else {
   diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
   index 7c0f9cc51e9..7d571bf9b1d 100644
   --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
   +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
   @@ -560,25 +560,12 @@ object StructType extends AbstractDataType {
        }
   
      /**
   -   * This works a little similarly to `merge`, but it does not actually merge two DataTypes.
   -   * This method just merges nullability.
   +   * Merges two DataTypes with UNION semantics (positional struct fields, OR-ed nullability).
       */
   -  private[sql] def mergeNullability(left: DataType, right: DataType): DataType =
   -    (left, right) match {
   -      case (ArrayType(leftElementType, leftContainsNull),
   -          ArrayType(rightElementType, rightContainsNull)) =>
   -        ArrayType(
   -          mergeNullability(leftElementType, rightElementType),
   -          leftContainsNull || rightContainsNull)
   -
   -      case (MapType(leftKeyType, leftValueType, leftContainsNull),
   -          MapType(rightKeyType, rightValueType, rightContainsNull)) =>
   -        MapType(
   -          mergeNullability(leftKeyType, rightKeyType),
   -          mergeNullability(leftValueType, rightValueType),
   -          leftContainsNull || rightContainsNull)
   -
   -      case (StructType(leftFields), StructType(rightFields)) =>
   +  private[sql] def unionLikeMerge(left: DataType, right: DataType): DataType =
   +    mergeInternal(left, right, (s1: StructType, s2: StructType) => {
   +      val leftFields = s1.fields
   +      val rightFields = s2.fields
            require(leftFields.size == rightFields.size, "To merge nullability, " +
              "two structs must have same number of fields.")
   
   @@ -586,32 +573,17 @@ object StructType extends AbstractDataType {
              case (leftField @ StructField(_, leftType, leftNullable, _),
                  _ @ StructField(_, rightType, rightNullable, _)) =>
                leftField.copy(
   -              dataType = mergeNullability(leftType, rightType),
   +              dataType = unionLikeMerge(leftType, rightType),
                  nullable = leftNullable || rightNullable)
            }.toSeq
            StructType(newFields)
   -
   -      case (leftType, _) =>
   -        leftType
   -    }
   +    })
   
      private[sql] def merge(left: DataType, right: DataType): DataType =
   -    (left, right) match {
   -      case (ArrayType(leftElementType, leftContainsNull),
   -      ArrayType(rightElementType, rightContainsNull)) =>
   -        ArrayType(
   -          merge(leftElementType, rightElementType),
   -          leftContainsNull || rightContainsNull)
   -
   -      case (MapType(leftKeyType, leftValueType, leftContainsNull),
   -      MapType(rightKeyType, rightValueType, rightContainsNull)) =>
   -        MapType(
   -          merge(leftKeyType, rightKeyType),
   -          merge(leftValueType, rightValueType),
   -          leftContainsNull || rightContainsNull)
   -
   -      case (StructType(leftFields), StructType(rightFields)) =>
   -        val newFields = mutable.ArrayBuffer.empty[StructField]
   +    mergeInternal(left, right, (s1: StructType, s2: StructType) => {
   +      val leftFields = s1.fields
   +      val rightFields = s2.fields
   +      val newFields = mutable.ArrayBuffer.empty[StructField]
   
            val rightMapped = fieldsMap(rightFields)
            leftFields.foreach {
   @@ -641,6 +613,31 @@ object StructType extends AbstractDataType {
              }
   
            StructType(newFields.toSeq)
   +    })
   +
   +  /**
   +   * Shared recursion behind `merge` and `unionLikeMerge`. Array and map types recurse through
   +   * this method with the same `mergeStruct`, so nested structs keep the caller's semantics.
   +   */
   +  private def mergeInternal(
   +      left: DataType,
   +      right: DataType,
   +      mergeStruct: (StructType, StructType) => StructType): DataType = {
   +    (left, right) match {
   +      case (ArrayType(leftElementType, leftContainsNull),
   +      ArrayType(rightElementType, rightContainsNull)) =>
   +        ArrayType(
   +          mergeInternal(leftElementType, rightElementType, mergeStruct),
   +          leftContainsNull || rightContainsNull)
   +
   +      case (MapType(leftKeyType, leftValueType, leftContainsNull),
   +      MapType(rightKeyType, rightValueType, rightContainsNull)) =>
   +        MapType(
   +          mergeInternal(leftKeyType, rightKeyType, mergeStruct),
   +          mergeInternal(leftValueType, rightValueType, mergeStruct),
   +          leftContainsNull || rightContainsNull)
   +
   +      case (s1: StructType, s2: StructType) => mergeStruct(s1, s2)
   
          case (DecimalType.Fixed(leftPrecision, leftScale),
            DecimalType.Fixed(rightPrecision, rightScale)) =>
   @@ -666,6 +663,7 @@ object StructType extends AbstractDataType {
          case _ =>
            throw QueryExecutionErrors.cannotMergeIncompatibleDataTypesError(left, right)
        }
   +  }
   
      private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = {
        // Mimics the optimization of breakOut, not present in Scala 2.13, while working in 2.12
   diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
   index b5ae4b3fac8..8e0080a2469 100644
   --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
   +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
   @@ -684,7 +684,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
        children.map(_.output).transpose.map { attrs =>
          val firstAttr = attrs.head
          val nullable = attrs.exists(_.nullable)
   -      val newDt = attrs.map(_.dataType).reduce(StructType.mergeNullability)
   +      val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
          if (firstAttr.dataType == newDt) {
            firstAttr.withNullability(nullable)
          } else {
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org