Posted to commits@spark.apache.org by we...@apache.org on 2020/07/07 16:35:27 UTC

[spark] branch master updated: [SPARK-31317][SQL] Add withField method to Column

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4bbc343  [SPARK-31317][SQL] Add withField method to Column
4bbc343 is described below

commit 4bbc343a4c0ff3d4b5443bd65adf2df55b9245ee
Author: fqaiser94@gmail.com <fq...@gmail.com>
AuthorDate: Tue Jul 7 16:34:03 2020 +0000

    [SPARK-31317][SQL] Add withField method to Column
    
    ### What changes were proposed in this pull request?
    
    Added a new `withField` method to the `Column` class. This method allows users to add or replace a `StructField` in a `StructType` column (with semantics very similar to those of the `withColumn` method on `Dataset`).
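    
    For example, a minimal sketch mirroring the Scaladoc examples included in this patch (assumes a running `SparkSession` named `spark`):
    ```
    import org.apache.spark.sql.functions._
    import spark.implicits._
    
    val df = spark.sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
    
    // add a new field to the struct
    df.select($"struct_col".withField("c", lit(3)))  // {"a":1,"b":2,"c":3}
    
    // replace an existing field in the struct
    df.select($"struct_col".withField("b", lit(3)))  // {"a":1,"b":3}
    ```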
    
    ### Why are the changes needed?
    
    Spark users often have to work with deeply nested data, e.g. to fix a data quality issue in an existing `StructField`. With the existing Spark APIs, doing so requires rebuilding the entire struct column.
    
    For example, let's say you have the following deeply nested data structure, which has a data quality issue (the value `5` is missing):
    ```
    import org.apache.spark.sql._
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._
    
    val data = spark.createDataFrame(sc.parallelize(
          Seq(Row(Row(Row(1, 2, 3), Row(Row(4, null, 6), Row(7, 8, 9), Row(10, 11, 12)), Row(13, 14, 15))))),
          StructType(Seq(
            StructField("a", StructType(Seq(
              StructField("a", StructType(Seq(
                StructField("a", IntegerType),
                StructField("b", IntegerType),
                StructField("c", IntegerType)))),
              StructField("b", StructType(Seq(
                StructField("a", StructType(Seq(
                  StructField("a", IntegerType),
                  StructField("b", IntegerType),
                  StructField("c", IntegerType)))),
                StructField("b", StructType(Seq(
                  StructField("a", IntegerType),
                  StructField("b", IntegerType),
                  StructField("c", IntegerType)))),
                StructField("c", StructType(Seq(
                  StructField("a", IntegerType),
                  StructField("b", IntegerType),
                  StructField("c", IntegerType))))
              ))),
              StructField("c", StructType(Seq(
                StructField("a", IntegerType),
                StructField("b", IntegerType),
                StructField("c", IntegerType))))
            )))))).cache
    
    data.show(false)
    +-------------------------------------------------------------+
    |a                                                            |
    +-------------------------------------------------------------+
    |[[1, 2, 3], [[4,, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
    +-------------------------------------------------------------+
    ```
    Currently, to replace the missing value, users would have to do something like this:
    ```
    val result = data.withColumn("a",
      struct(
        $"a.a",
        struct(
          struct(
            $"a.b.a.a",
            lit(5).as("b"),
            $"a.b.a.c"
          ).as("a"),
          $"a.b.b",
          $"a.b.c"
        ).as("b"),
        $"a.c"
      ))
    
    result.show(false)
    +---------------------------------------------------------------+
    |a                                                              |
    +---------------------------------------------------------------+
    |[[1, 2, 3], [[4, 5, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
    +---------------------------------------------------------------+
    ```
    As you can see above, with the existing methods users must call the `struct` function and list every field, including the fields they don't want to change. This is not ideal; as [SPARK-16483](https://issues.apache.org/jira/browse/SPARK-16483) puts it:
    > this leads to complex, fragile code that cannot survive schema evolution.
    
    In contrast, with the method added in this PR, a user could simply do something like this:
    ```
    val result = data.withColumn("a", 'a.withField("b.a.b", lit(5)))
    result.show(false)
    +---------------------------------------------------------------+
    |a                                                              |
    +---------------------------------------------------------------+
    |[[1, 2, 3], [[4, 5, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
    +---------------------------------------------------------------+
    
    ```
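    
    Beyond simple add/replace, the method handles null structs and nested paths; a short sketch based on the Scaladoc examples included in this patch:
    ```
    // a null struct stays null, but its schema still gains the new field
    val nullDf = spark.sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
    nullDf.select($"struct_col".withField("c", lit(3)))
    // result: null of type struct<a:int,b:int,c:int>
    
    // a dotted path descends into nested structs
    val nested = spark.sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
    nested.select($"struct_col".withField("a.c", lit(3)))
    // result: {"a":{"a":1,"b":2,"c":3}}
    ```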
    
    This is the first of possibly several methods that could be added to the `Column` class to make it easier to manipulate nested data. Other methods under discussion in [SPARK-22231](https://issues.apache.org/jira/browse/SPARK-22231) include `drop` and `renameField`; those are left to separate PRs.
    
    ### Does this PR introduce any user-facing change?
    
    No changes to existing behavior; the only user-facing addition is the new `withField` method described above.
    
    ### How was this patch tested?
    
    New unit tests were added in `CombineWithFieldsSuite`, `ComplexTypesSuite`, and `ColumnExpressionSuite`.
    
    ### Related JIRAs:
    - https://issues.apache.org/jira/browse/SPARK-22231
    - https://issues.apache.org/jira/browse/SPARK-16483
    
    Closes #27066 from fqaiser94/SPARK-22231-withField.
    
    Lead-authored-by: fqaiser94@gmail.com <fq...@gmail.com>
    Co-authored-by: fqaiser94 <fq...@gmail.com>
    Co-authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../catalyst/expressions/complexTypeCreator.scala  |  58 +++
 .../sql/catalyst/optimizer/ComplexTypes.scala      |  13 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   7 +-
 .../spark/sql/catalyst/optimizer/WithFields.scala  |  42 ++
 .../optimizer/CombineWithFieldsSuite.scala         |  76 ++++
 .../sql/catalyst/optimizer/complexTypesSuite.scala |  57 +++
 .../main/scala/org/apache/spark/sql/Column.scala   |  66 +++
 .../apache/spark/sql/ColumnExpressionSuite.scala   | 499 +++++++++++++++++++++
 8 files changed, 815 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1b4a705..cf7cc3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -539,3 +539,61 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
 
   override def prettyName: String = "str_to_map"
 }
+
+/**
+ * Adds/replaces field in struct by name.
+ */
+case class WithFields(
+    structExpr: Expression,
+    names: Seq[String],
+    valExprs: Seq[Expression]) extends Unevaluable {
+
+  assert(names.length == valExprs.length)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!structExpr.dataType.isInstanceOf[StructType]) {
+      TypeCheckResult.TypeCheckFailure(
+        "struct argument should be struct type, got: " + structExpr.dataType.catalogString)
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def children: Seq[Expression] = structExpr +: valExprs
+
+  override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]
+
+  override def foldable: Boolean = structExpr.foldable && valExprs.forall(_.foldable)
+
+  override def nullable: Boolean = structExpr.nullable
+
+  override def prettyName: String = "with_fields"
+
+  lazy val evalExpr: Expression = {
+    val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
+    }
+
+    val addOrReplaceExprs = names.zip(valExprs)
+
+    val resolver = SQLConf.get.resolver
+    val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
+      case (resultExprs, newExpr @ (newExprName, _)) =>
+        if (resultExprs.exists(x => resolver(x._1, newExprName))) {
+          resultExprs.map {
+            case (name, _) if resolver(name, newExprName) => newExpr
+            case x => x
+          }
+        } else {
+          resultExprs :+ newExpr
+        }
+    }.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+
+    val expr = CreateNamedStruct(newExprs)
+    if (structExpr.nullable) {
+      If(IsNull(structExpr), Literal(null, expr.dataType), expr)
+    } else {
+      expr
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index f79dabf..1c33a2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -39,7 +39,18 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
-
+      case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
+        val name = w.dataType(ordinal).name
+        val matches = names.zip(valExprs).filter(_._1 == name)
+        if (matches.nonEmpty) {
+          // return last matching element as that is the final value for the field being extracted.
+          // For example, if a user submits a query like this:
+          // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
+          // we want to return `lit(2)` (and not `lit(1)`).
+          matches.last._2
+        } else {
+          GetStructField(struct, ordinal, maybeName)
+        }
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
         // Instead of selecting the field on the entire array, select it from each member
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e800ee3..1b14157 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -107,6 +107,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         EliminateSerialization,
         RemoveRedundantAliases,
         RemoveNoopOperators,
+        CombineWithFields,
         SimplifyExtractValueOps,
         CombineConcats) ++
         extendedOperatorOptimizationRules
@@ -207,7 +208,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       CollapseProject,
       RemoveNoopOperators) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
-    Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers)
+    Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
+    Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)
 
     // remove any batches with no rules. this may happen when subclasses do not add optional rules.
     batches.filter(_.rules.nonEmpty)
@@ -240,7 +242,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       PullupCorrelatedPredicates.ruleName ::
       RewriteCorrelatedScalarSubquery.ruleName ::
       RewritePredicateSubquery.ruleName ::
-      NormalizeFloatingNumbers.ruleName :: Nil
+      NormalizeFloatingNumbers.ruleName ::
+      ReplaceWithFieldsExpression.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
new file mode 100644
index 0000000..05c9086
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.WithFields
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+
+/**
+ * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
+ */
+object CombineWithFields extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
+      WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
+  }
+}
+
+/**
+ * Replaces [[WithFields]] expression with an evaluable expression.
+ */
+object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case w: WithFields => w.evalExpr
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala
new file mode 100644
index 0000000..a3e0bbc
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+
+class CombineWithFieldsSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
+  }
+
+  private val testRelation = LocalRelation('a.struct('a1.int))
+
+  test("combines two WithFields") {
+    val originalQuery = testRelation
+      .select(Alias(
+        WithFields(
+          WithFields(
+            'a,
+            Seq("b1"),
+            Seq(Literal(4))),
+          Seq("c1"),
+          Seq(Literal(5))), "out")())
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")())
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("combines three WithFields") {
+    val originalQuery = testRelation
+      .select(Alias(
+        WithFields(
+          WithFields(
+            WithFields(
+              'a,
+              Seq("b1"),
+              Seq(Literal(4))),
+            Seq("c1"),
+            Seq(Literal(5))),
+          Seq("d1"),
+          Seq(Literal(6))), "out")())
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")())
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index d557460..c71e7db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -452,4 +452,61 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
     checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
   }
+
+  private val structAttr = 'struct1.struct('a.int)
+  private val testStructRelation = LocalRelation(structAttr)
+
+  test("simplify GetStructField on WithFields that is not changing the attribute being extracted") {
+    val query = testStructRelation.select(
+      GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 0, Some("a")) as "outerAtt")
+    val expected = testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAtt")
+    checkRule(query, expected)
+  }
+
+  test("simplify GetStructField on WithFields that is changing the attribute being extracted") {
+    val query = testStructRelation.select(
+      GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 1, Some("b")) as "outerAtt")
+    val expected = testStructRelation.select(Literal(1) as "outerAtt")
+    checkRule(query, expected)
+  }
+
+  test(
+    "simplify GetStructField on WithFields that is changing the attribute being extracted twice") {
+    val query = testStructRelation
+      .select(GetStructField(WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2))), 1,
+        Some("b")) as "outerAtt")
+    val expected = testStructRelation.select(Literal(2) as "outerAtt")
+    checkRule(query, expected)
+  }
+
+  test("collapse multiple GetStructField on the same WithFields") {
+    val query = testStructRelation
+      .select(WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2")
+      .select(
+        GetStructField('struct2, 0, Some("a")) as "struct1A",
+        GetStructField('struct2, 1, Some("b")) as "struct1B")
+    val expected = testStructRelation.select(
+      GetStructField('struct1, 0, Some("a")) as "struct1A",
+      Literal(2) as "struct1B")
+    checkRule(query, expected)
+  }
+
+  test("collapse multiple GetStructField on different WithFields") {
+    val query = testStructRelation
+      .select(
+        WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2",
+        WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3")
+      .select(
+        GetStructField('struct2, 0, Some("a")) as "struct2A",
+        GetStructField('struct2, 1, Some("b")) as "struct2B",
+        GetStructField('struct3, 0, Some("a")) as "struct3A",
+        GetStructField('struct3, 1, Some("b")) as "struct3B")
+    val expected = testStructRelation
+      .select(
+        GetStructField('struct1, 0, Some("a")) as "struct2A",
+        Literal(2) as "struct2B",
+        GetStructField('struct1, 0, Some("a")) as "struct3A",
+        Literal(3) as "struct3B")
+    checkRule(query, expected)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e6f7b1d..da542c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -871,6 +871,72 @@ class Column(val expr: Expression) extends Logging {
    */
   def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) }
 
+  // scalastyle:off line.size.limit
+  /**
+   * An expression that adds/replaces field in `StructType` by name.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".withField("c", lit(3)))
+   *   // result: {"a":1,"b":2,"c":3}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".withField("b", lit(3)))
+   *   // result: {"a":1,"b":3}
+   *
+   *   val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+   *   df.select($"struct_col".withField("c", lit(3)))
+   *   // result: null of type struct<a:int,b:int,c:int>
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+   *   df.select($"struct_col".withField("b", lit(100)))
+   *   // result: {"a":1,"b":100,"b":100}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+   *   df.select($"struct_col".withField("a.c", lit(3)))
+   *   // result: {"a":{"a":1,"b":2,"c":3}}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+   *   df.select($"struct_col".withField("a.c", lit(3)))
+   *   // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+   * }}}
+   *
+   * @group expr_ops
+   * @since 3.1.0
+   */
+  // scalastyle:on line.size.limit
+  def withField(fieldName: String, col: Column): Column = withExpr {
+    require(fieldName != null, "fieldName cannot be null")
+    require(col != null, "col cannot be null")
+
+    val nameParts = if (fieldName.isEmpty) {
+      fieldName :: Nil
+    } else {
+      CatalystSqlParser.parseMultipartIdentifier(fieldName)
+    }
+    withFieldHelper(expr, nameParts, Nil, col.expr)
+  }
+
+  private def withFieldHelper(
+      struct: Expression,
+      namePartsRemaining: Seq[String],
+      namePartsDone: Seq[String],
+      value: Expression) : WithFields = {
+    val name = namePartsRemaining.head
+    if (namePartsRemaining.length == 1) {
+      WithFields(struct, name :: Nil, value :: Nil)
+    } else {
+      val newNamesRemaining = namePartsRemaining.tail
+      val newNamesDone = namePartsDone :+ name
+      val newValue = withFieldHelper(
+        struct = UnresolvedExtractValue(struct, Literal(name)),
+        namePartsRemaining = newNamesRemaining,
+        namePartsDone = newNamesDone,
+        value = value)
+      WithFields(struct, name :: Nil, newValue :: Nil)
+    }
+  }
+
   /**
    * An expression that gets a field by name in a `StructType`.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index fa06484..131ab1b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -923,4 +923,503 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
     val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString))
     assert(inSet.sql === "('a' IN ('a', 'b'))")
   }
+
+  def checkAnswerAndSchema(
+      df: => DataFrame,
+      expectedAnswer: Seq[Row],
+      expectedSchema: StructType): Unit = {
+
+    checkAnswer(df, expectedAnswer)
+    assert(df.schema == expectedSchema)
+  }
+
+  private lazy val structType = StructType(Seq(
+    StructField("a", IntegerType, nullable = false),
+    StructField("b", IntegerType, nullable = true),
+    StructField("c", IntegerType, nullable = false)))
+
+  private lazy val structLevel1: DataFrame = spark.createDataFrame(
+    sparkContext.parallelize(Row(Row(1, null, 3)) :: Nil),
+    StructType(Seq(StructField("a", structType, nullable = false))))
+
+  private lazy val nullStructLevel1: DataFrame = spark.createDataFrame(
+    sparkContext.parallelize(Row(null) :: Nil),
+    StructType(Seq(StructField("a", structType, nullable = true))))
+
+  private lazy val structLevel2: DataFrame = spark.createDataFrame(
+    sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil),
+    StructType(Seq(
+      StructField("a", StructType(Seq(
+        StructField("a", structType, nullable = false))),
+        nullable = false))))
+
+  private lazy val nullStructLevel2: DataFrame = spark.createDataFrame(
+    sparkContext.parallelize(Row(Row(null)) :: Nil),
+    StructType(Seq(
+      StructField("a", StructType(Seq(
+        StructField("a", structType, nullable = true))),
+        nullable = false))))
+
+  private lazy val structLevel3: DataFrame = spark.createDataFrame(
+    sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil),
+    StructType(Seq(
+      StructField("a", StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", structType, nullable = false))),
+          nullable = false))),
+        nullable = false))))
+
+  test("withField should throw an exception if called on a non-StructType column") {
+    intercept[AnalysisException] {
+      testData.withColumn("key", $"key".withField("a", lit(2)))
+    }.getMessage should include("struct argument should be struct type, got: int")
+  }
+
+  test("withField should throw an exception if either fieldName or col argument are null") {
+    intercept[IllegalArgumentException] {
+      structLevel1.withColumn("a", $"a".withField(null, lit(2)))
+    }.getMessage should include("fieldName cannot be null")
+
+    intercept[IllegalArgumentException] {
+      structLevel1.withColumn("a", $"a".withField("b", null))
+    }.getMessage should include("col cannot be null")
+
+    intercept[IllegalArgumentException] {
+      structLevel1.withColumn("a", $"a".withField(null, null))
+    }.getMessage should include("fieldName cannot be null")
+  }
+
+  test("withField should throw an exception if any intermediate structs don't exist") {
+    intercept[AnalysisException] {
+      structLevel2.withColumn("a", 'a.withField("x.b", lit(2)))
+    }.getMessage should include("No such struct field x in a")
+
+    intercept[AnalysisException] {
+      structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2)))
+    }.getMessage should include("No such struct field x in a")
+  }
+
+  test("withField should throw an exception if intermediate field is not a struct") {
+    intercept[AnalysisException] {
+      structLevel1.withColumn("a", 'a.withField("b.a", lit(2)))
+    }.getMessage should include("struct argument should be struct type, got: int")
+  }
+
+  test("withField should throw an exception if intermediate field reference is ambiguous") {
+    intercept[AnalysisException] {
+      val structLevel2: DataFrame = spark.createDataFrame(
+        sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", structType, nullable = false),
+            StructField("a", structType, nullable = false))),
+            nullable = false))))
+
+      structLevel2.withColumn("a", 'a.withField("a.b", lit(2)))
+    }.getMessage should include("Ambiguous reference to fields")
+  }
+
+  test("withField should add field with no name") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", $"a".withField("", lit(4))),
+      Row(Row(1, null, 3, 4)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should add field to struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("d", lit(4))),
+      Row(Row(1, null, 3, 4)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should add field to null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel1.withColumn("a", $"a".withField("d", lit(4))),
+      Row(null) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = false))),
+          nullable = true))))
+  }
+
+  test("withField should add field to nested null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))),
+      Row(Row(null)) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = true),
+            StructField("c", IntegerType, nullable = false),
+            StructField("d", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = false))))
+  }
+
+  test("withField should add null field to struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))),
+      Row(Row(1, null, 3, null)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = true))),
+          nullable = false))))
+  }
+
+  test("withField should add multiple fields to struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))),
+      Row(Row(1, null, 3, 4, 5)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = false),
+          StructField("e", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should add field to nested struct") {
+    Seq(
+      structLevel2.withColumn("a", 'a.withField("a.d", lit(4))),
+      structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4))))
+    ).foreach { df =>
+      checkAnswerAndSchema(
+        df,
+        Row(Row(Row(1, null, 3, 4))) :: Nil,
+        StructType(
+          Seq(StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = true),
+              StructField("c", IntegerType, nullable = false),
+              StructField("d", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("withField should add field to deeply nested struct") {
+    checkAnswerAndSchema(
+      structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))),
+      Row(Row(Row(Row(1, null, 3, 4)))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = true),
+              StructField("c", IntegerType, nullable = false),
+              StructField("d", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should replace field in struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("b", lit(2))),
+      Row(Row(1, 2, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should replace field in null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))),
+      Row(null) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", StringType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = true))))
+  }
+
+  test("withField should replace field in nested null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))),
+      Row(Row(null)) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", StringType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = false))))
+  }
+
+  test("withField should replace field with null value in struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))),
+      Row(Row(1, null, null)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = true))),
+          nullable = false))))
+  }
+
+  test("withField should replace multiple fields in struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))),
+      Row(Row(10, 20, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should replace field in nested struct") {
+    Seq(
+      structLevel2.withColumn("a", $"a".withField("a.b", lit(2))),
+      structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2))))
+    ).foreach { df =>
+      checkAnswerAndSchema(
+        df,
+        Row(Row(Row(1, 2, 3))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false),
+              StructField("c", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("withField should replace field in deeply nested struct") {
+    checkAnswerAndSchema(
+      structLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))),
+      Row(Row(Row(Row(1, 2, 3)))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false),
+              StructField("c", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should replace all fields with given name in struct") {
+    val structLevel1 = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil),
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("b", lit(100))),
+      Row(Row(1, 100, 100)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should replace fields in struct in given order") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))),
+      Row(Row(1, 20, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should add field and then replace same field in struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))),
+      Row(Row(1, null, 3, 5)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("withField should handle fields with dots in their name if correctly quoted") {
+    val df: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil),
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a.b", StructType(Seq(
+            StructField("c.d", IntegerType, nullable = false),
+            StructField("e.f", IntegerType, nullable = true),
+            StructField("g.h", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+
+    checkAnswerAndSchema(
+      df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))),
+      Row(Row(Row(1, 2, 3))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a.b", StructType(Seq(
+            StructField("c.d", IntegerType, nullable = false),
+            StructField("e.f", IntegerType, nullable = false),
+            StructField("g.h", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+
+    intercept[AnalysisException] {
+      df.withColumn("a", 'a.withField("a.b.e.f", lit(2)))
+    }.getMessage should include("No such struct field a in a.b")
+  }
+
+  private lazy val mixedCaseStructLevel1: DataFrame = spark.createDataFrame(
+    sparkContext.parallelize(Row(Row(1, 1)) :: Nil),
+    StructType(Seq(
+      StructField("a", StructType(Seq(
+        StructField("a", IntegerType, nullable = false),
+        StructField("B", IntegerType, nullable = false))),
+        nullable = false))))
+
+  test("withField should replace field in struct even if casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      checkAnswerAndSchema(
+        mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))),
+        Row(Row(2, 1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("A", IntegerType, nullable = false),
+            StructField("B", IntegerType, nullable = false))),
+            nullable = false))))
+
+      checkAnswerAndSchema(
+        mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
+        Row(Row(1, 2)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("withField should add field to struct because casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      checkAnswerAndSchema(
+        mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))),
+        Row(Row(1, 1, 2)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("B", IntegerType, nullable = false),
+            StructField("A", IntegerType, nullable = false))),
+            nullable = false))))
+
+      checkAnswerAndSchema(
+        mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
+        Row(Row(1, 1, 2)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("B", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  private lazy val mixedCaseStructLevel2: DataFrame = spark.createDataFrame(
+    sparkContext.parallelize(Row(Row(Row(1, 1), Row(1, 1))) :: Nil),
+    StructType(Seq(
+      StructField("a", StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false),
+        StructField("B", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false))),
+        nullable = false))))
+
+  test("withField should replace nested field in struct even if casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      checkAnswerAndSchema(
+        mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))),
+        Row(Row(Row(2, 1), Row(1, 1))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("A", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false),
+            StructField("B", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+
+      checkAnswerAndSchema(
+        mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))),
+        Row(Row(Row(1, 1), Row(2, 1))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false),
+            StructField("b", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("withField should throw an exception because casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2)))
+      }.getMessage should include("No such struct field A in a, B")
+
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2)))
+      }.getMessage should include("No such struct field b in a, B")
+    }
+  }
 }

