You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2020/10/05 02:50:36 UTC

[GitHub] [spark] maropu commented on a change in pull request #29587: [SPARK-32376][SQL] Make unionByName null-filling behavior work with struct columns

maropu commented on a change in pull request #29587:
URL: https://github.com/apache/spark/pull/29587#discussion_r499321425



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
##########
@@ -17,29 +17,202 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Resolves different children of Union to a common set of columns.
  */
 object ResolveUnion extends Rule[LogicalPlan] {
-  private def unionTwoSides(
+  /**
+   * This method sorts recursively columns in a struct expression based on column names.
+   */
+  private def sortStructFields(expr: Expression): Expression = {
+    val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) =>
+        val fieldExpr = GetStructField(KnownNotNull(expr), i)
+        if (fieldExpr.dataType.isInstanceOf[StructType]) {
+          (name, sortStructFields(fieldExpr))
+        } else {
+          (name, fieldExpr)
+        }
+    }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))
+
+    val newExpr = CreateNamedStruct(existingExprs)
+    if (expr.nullable) {
+      If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
+    } else {
+      newExpr
+    }
+  }
+
+  /**
+   * Assumes input expressions are field expression of `CreateNamedStruct`. This method
+   * sorts the expressions based on field names.
+   */
+  private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
+    fieldExprs.grouped(2).map { e =>
+      Seq(e.head, e.last)
+    }.toSeq.sortBy { pair =>
+      assert(pair.head.isInstanceOf[Literal])
+      pair.head.eval().asInstanceOf[UTF8String].toString
+    }.flatten
+  }
+
+  /**
+   * This helper method sorts fields in a `WithFields` expression by field name.
+   */
+  private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp {
+    case w: WithFields if w.resolved =>
+      w.evalExpr match {
+        case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct)
+        case CreateNamedStruct(fieldExprs) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          newStruct
+        case other =>

Review comment:
       If this case indicates a bug in Spark rather than a user error, shouldn't the message include `Please file a bug report ...` like the others?
   https://github.com/apache/spark/blob/fab53212cb110a81696cee8546c35095332f6e09/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala#L2747-L2748

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
##########
@@ -17,29 +17,202 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Resolves different children of Union to a common set of columns.
  */
 object ResolveUnion extends Rule[LogicalPlan] {
-  private def unionTwoSides(
+  /**
+   * This method sorts recursively columns in a struct expression based on column names.
+   */
+  private def sortStructFields(expr: Expression): Expression = {
+    val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) =>
+        val fieldExpr = GetStructField(KnownNotNull(expr), i)
+        if (fieldExpr.dataType.isInstanceOf[StructType]) {
+          (name, sortStructFields(fieldExpr))
+        } else {
+          (name, fieldExpr)
+        }
+    }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))
+
+    val newExpr = CreateNamedStruct(existingExprs)
+    if (expr.nullable) {
+      If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
+    } else {
+      newExpr
+    }
+  }
+
+  /**
+   * Assumes input expressions are field expression of `CreateNamedStruct`. This method
+   * sorts the expressions based on field names.
+   */
+  private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
+    fieldExprs.grouped(2).map { e =>
+      Seq(e.head, e.last)
+    }.toSeq.sortBy { pair =>
+      assert(pair.head.isInstanceOf[Literal])
+      pair.head.eval().asInstanceOf[UTF8String].toString
+    }.flatten
+  }
+
+  /**
+   * This helper method sorts fields in a `WithFields` expression by field name.
+   */
+  private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp {
+    case w: WithFields if w.resolved =>
+      w.evalExpr match {
+        case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct)
+        case CreateNamedStruct(fieldExprs) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          newStruct
+        case other =>
+          throw new AnalysisException(s"`WithFields` has incorrect eval expression: $other")
+      }
+  }
+
+  def simplifyWithFields(expr: Expression): Expression = {

Review comment:
       nit: this method should be `private`. By the way, will all the transformations in this method be moved into an optimizer rule in a follow-up? We normally add tests when adding a new rule, but this PR does not have any tests for them.

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
##########
@@ -17,29 +17,202 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Resolves different children of Union to a common set of columns.
  */
 object ResolveUnion extends Rule[LogicalPlan] {
-  private def unionTwoSides(
+  /**
+   * This method sorts recursively columns in a struct expression based on column names.
+   */
+  private def sortStructFields(expr: Expression): Expression = {
+    val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) =>
+        val fieldExpr = GetStructField(KnownNotNull(expr), i)
+        if (fieldExpr.dataType.isInstanceOf[StructType]) {
+          (name, sortStructFields(fieldExpr))
+        } else {
+          (name, fieldExpr)
+        }
+    }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))
+
+    val newExpr = CreateNamedStruct(existingExprs)
+    if (expr.nullable) {
+      If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
+    } else {
+      newExpr
+    }
+  }
+
+  /**
+   * Assumes input expressions are field expression of `CreateNamedStruct`. This method
+   * sorts the expressions based on field names.
+   */
+  private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
+    fieldExprs.grouped(2).map { e =>
+      Seq(e.head, e.last)
+    }.toSeq.sortBy { pair =>
+      assert(pair.head.isInstanceOf[Literal])
+      pair.head.eval().asInstanceOf[UTF8String].toString
+    }.flatten
+  }
+
+  /**
+   * This helper method sorts fields in a `WithFields` expression by field name.
+   */
+  private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp {
+    case w: WithFields if w.resolved =>
+      w.evalExpr match {
+        case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct)
+        case CreateNamedStruct(fieldExprs) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          newStruct
+        case other =>
+          throw new AnalysisException(s"`WithFields` has incorrect eval expression: $other")
+      }
+  }
+
+  def simplifyWithFields(expr: Expression): Expression = {
+    expr.transformUp {
+      case WithFields(structExpr, names, values) if names.distinct.length != names.length =>
+        val newNames = mutable.ArrayBuffer.empty[String]
+        val newValues = mutable.ArrayBuffer.empty[Expression]
+        names.zip(values).reverse.foreach { case (name, value) =>
+          if (!newNames.contains(name)) {
+            newNames += name
+            newValues += value
+          }
+        }
+        WithFields(structExpr, names = newNames.reverse, valExprs = newValues.reverse)
+      case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
+        WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)

Review comment:
       Is this duplicated? Similar logic already exists in the optimizer's `WithFields.scala`: https://github.com/apache/spark/blob/fab53212cb110a81696cee8546c35095332f6e09/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala#L30-L31

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -2721,6 +2721,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val UNION_BYNAME_STRUCT_SUPPORT_ENABLED =
+    buildConf("spark.sql.unionByName.structSupport.enabled")
+      .doc("When true, the `allowMissingColumns` feature of `Dataset.unionByName` supports " +
+        "nested column in struct types. Missing nested columns of struct columns with same " +
+        "name will also be filled with null values. This currently does not support nested " +
+        "columns in array and map types.")

Review comment:
       How about documenting that nested struct fields are sorted by name when they are merged? In the example below, the field order only changes when the two struct schemas differ. I'm a bit worried that users might be surprised by this behavior.
   ```
   scala> val df1 = spark.range(1).selectExpr("id c0", "named_struct('c', id + 1, 'b', id + 2, 'a', id + 3) c1")
   scala> val df2 = spark.range(1).selectExpr("id c0", "named_struct('c', id + 1, 'b', id + 2) c1")
   
   scala> df1.unionByName(df1, true).printSchema()
   root
    |-- c0: long (nullable = false)
    |-- c1: struct (nullable = false)
    |    |-- c: long (nullable = false)
    |    |-- b: long (nullable = false)
    |    |-- a: long (nullable = false)
   
   scala> df1.unionByName(df2, true).printSchema()
   root
    |-- c0: long (nullable = false)
    |-- c1: struct (nullable = false)
    |    |-- a: long (nullable = true)
    |    |-- b: long (nullable = false)
    |    |-- c: long (nullable = false)
   ```
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org