Posted to commits@spark.apache.org by we...@apache.org on 2020/07/07 16:35:27 UTC
[spark] branch master updated: [SPARK-31317][SQL] Add withField method to Column
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4bbc343 [SPARK-31317][SQL] Add withField method to Column
4bbc343 is described below
commit 4bbc343a4c0ff3d4b5443bd65adf2df55b9245ee
Author: fqaiser94@gmail.com <fq...@gmail.com>
AuthorDate: Tue Jul 7 16:34:03 2020 +0000
[SPARK-31317][SQL] Add withField method to Column
### What changes were proposed in this pull request?
Added a new `withField` method to the `Column` class. This method allows users to add or replace a `StructField` in a `StructType` column, with semantics very similar to those of the `withColumn` method on `Dataset`.
### Why are the changes needed?
Spark users often have to work with deeply nested data, e.g., to fix a data quality issue in an existing `StructField`. With the existing Spark APIs, doing this requires rebuilding the entire struct column.
For example, suppose you have the following deeply nested data structure, which has a data quality issue (the value `5` is missing):
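```
// Run in spark-shell, where `spark` and `sc` are predefined.
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spark.implicits._ // enables the $"..." and 'symbol column syntax used below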
val data = spark.createDataFrame(sc.parallelize(
Seq(Row(Row(Row(1, 2, 3), Row(Row(4, null, 6), Row(7, 8, 9), Row(10, 11, 12)), Row(13, 14, 15))))),
StructType(Seq(
StructField("a", StructType(Seq(
StructField("a", StructType(Seq(
StructField("a", IntegerType),
StructField("b", IntegerType),
StructField("c", IntegerType)))),
StructField("b", StructType(Seq(
StructField("a", StructType(Seq(
StructField("a", IntegerType),
StructField("b", IntegerType),
StructField("c", IntegerType)))),
StructField("b", StructType(Seq(
StructField("a", IntegerType),
StructField("b", IntegerType),
StructField("c", IntegerType)))),
StructField("c", StructType(Seq(
StructField("a", IntegerType),
StructField("b", IntegerType),
StructField("c", IntegerType))))
))),
StructField("c", StructType(Seq(
StructField("a", IntegerType),
StructField("b", IntegerType),
StructField("c", IntegerType))))
)))))).cache
data.show(false)
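+-------------------------------------------------------------+
|a                                                            |
+-------------------------------------------------------------+
|[[1, 2, 3], [[4,, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+-------------------------------------------------------------+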
```
Currently, to replace the missing value users would have to do something like this:
```
val result = data.withColumn("a",
struct(
$"a.a",
struct(
struct(
$"a.b.a.a",
lit(5).as("b"),
$"a.b.a.c"
).as("a"),
$"a.b.b",
$"a.b.c"
).as("b"),
$"a.c"
))
result.show(false)
+---------------------------------------------------------------+
|a                                                              |
+---------------------------------------------------------------+
|[[1, 2, 3], [[4, 5, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+---------------------------------------------------------------+
```
As you can see above, with the existing methods users must call the `struct` function and list every field, including the fields they don't want to change. As noted in [SPARK-16483](https://issues.apache.org/jira/browse/SPARK-16483), this is not ideal because:
> this leads to complex, fragile code that cannot survive schema evolution.
In contrast, with the method added in this PR, a user could simply do something like this:
```
val result = data.withColumn("a", 'a.withField("b.a.b", lit(5)))
result.show(false)
+---------------------------------------------------------------+
|a                                                              |
+---------------------------------------------------------------+
|[[1, 2, 3], [[4, 5, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+---------------------------------------------------------------+
```
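To make the path semantics concrete, here is a minimal, Spark-free sketch in plain Scala. It assumes a hypothetical value-level model (`Struct` as an ordered list of `(name, value)` pairs, which is not how Spark represents rows internally) and a hypothetical helper `withFieldPath`. Only the structs along the path `b.a.b` are rebuilt; every sibling field is carried over untouched, which is the boilerplate the `struct(...)` version above had to spell out by hand.

```scala
object WithFieldPathSketch {
  // Hypothetical value-level model of a struct: ordered (fieldName, value) pairs.
  // Illustration only; this is not Spark's InternalRow representation.
  final case class Struct(fields: Vector[(String, Any)]) {
    def get(name: String): Option[Any] = fields.collectFirst { case (n, v) if n == name => v }
  }

  // Convenience constructor for the three-field (a, b, c) structs used in the example.
  def abc(a: Any, b: Any, c: Any): Struct = Struct(Vector("a" -> a, "b" -> b, "c" -> c))

  // Add or replace the field at a dotted path such as List("b", "a", "b").
  // Every intermediate path element must already exist and hold a Struct.
  def withFieldPath(s: Struct, path: List[String], value: Any): Struct = path match {
    case Nil => throw new IllegalArgumentException("path must not be empty")
    case name :: Nil => addOrReplace(s, name, value)
    case name :: rest =>
      s.get(name) match {
        case Some(inner: Struct) => addOrReplace(s, name, withFieldPath(inner, rest, value))
        case Some(other) => throw new IllegalArgumentException(s"field $name is not a struct: $other")
        case None => throw new IllegalArgumentException(s"No such struct field $name")
      }
  }

  // Replace the field(s) called `name` in place, or append a new field at the end.
  private def addOrReplace(s: Struct, name: String, value: Any): Struct =
    if (s.fields.exists(_._1 == name)) Struct(s.fields.map { case (n, v) => (n, if (n == name) value else v) })
    else Struct(s.fields :+ (name -> value))
}
```

Applied to the example row, `withFieldPath(row, List("b", "a", "b"), 5)` turns `abc(4, null, 6)` into `abc(4, 5, 6)` and leaves every other value as it was.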
This is the first of possibly several methods that could be added to the `Column` class to make nested data easier to manipulate. Other methods under discussion in [SPARK-22231](https://issues.apache.org/jira/browse/SPARK-22231) include `drop` and `renameField`; these should be added in separate PRs.
### Does this PR introduce any user-facing change?
Yes. This PR adds a new public method, `Column.withField`, available since 3.1.0.
### How was this patch tested?
New unit tests were added. Jenkins must pass them.
### Related JIRAs:
- https://issues.apache.org/jira/browse/SPARK-22231
- https://issues.apache.org/jira/browse/SPARK-16483
Closes #27066 from fqaiser94/SPARK-22231-withField.
Lead-authored-by: fqaiser94@gmail.com <fq...@gmail.com>
Co-authored-by: fqaiser94 <fq...@gmail.com>
Co-authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../catalyst/expressions/complexTypeCreator.scala | 58 +++
.../sql/catalyst/optimizer/ComplexTypes.scala | 13 +-
.../spark/sql/catalyst/optimizer/Optimizer.scala | 7 +-
.../spark/sql/catalyst/optimizer/WithFields.scala | 42 ++
.../optimizer/CombineWithFieldsSuite.scala | 76 ++++
.../sql/catalyst/optimizer/complexTypesSuite.scala | 57 +++
.../main/scala/org/apache/spark/sql/Column.scala | 66 +++
.../apache/spark/sql/ColumnExpressionSuite.scala | 499 +++++++++++++++++++++
8 files changed, 815 insertions(+), 3 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1b4a705..cf7cc3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -539,3 +539,61 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
override def prettyName: String = "str_to_map"
}
+
+/**
+ * Adds/replaces field in struct by name.
+ */
+case class WithFields(
+ structExpr: Expression,
+ names: Seq[String],
+ valExprs: Seq[Expression]) extends Unevaluable {
+
+ assert(names.length == valExprs.length)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (!structExpr.dataType.isInstanceOf[StructType]) {
+ TypeCheckResult.TypeCheckFailure(
+ "struct argument should be struct type, got: " + structExpr.dataType.catalogString)
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ override def children: Seq[Expression] = structExpr +: valExprs
+
+ override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]
+
+ override def foldable: Boolean = structExpr.foldable && valExprs.forall(_.foldable)
+
+ override def nullable: Boolean = structExpr.nullable
+
+ override def prettyName: String = "with_fields"
+
+ lazy val evalExpr: Expression = {
+ val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+ case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
+ }
+
+ val addOrReplaceExprs = names.zip(valExprs)
+
+ val resolver = SQLConf.get.resolver
+ val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
+ case (resultExprs, newExpr @ (newExprName, _)) =>
+ if (resultExprs.exists(x => resolver(x._1, newExprName))) {
+ resultExprs.map {
+ case (name, _) if resolver(name, newExprName) => newExpr
+ case x => x
+ }
+ } else {
+ resultExprs :+ newExpr
+ }
+ }.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+
+ val expr = CreateNamedStruct(newExprs)
+ if (structExpr.nullable) {
+ If(IsNull(structExpr), Literal(null, expr.dataType), expr)
+ } else {
+ expr
+ }
+ }
+}
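The `foldLeft` in `evalExpr` carries the add/replace semantics. The sketch below restates it over plain Scala values instead of Catalyst expressions, so the ordering rules are easy to check: each `(name, value)` pair is applied in turn, a pair replaces every existing field whose name it resolves to (duplicates included), and otherwise it is appended at the end. The `resolver` parameter stands in for `SQLConf.get.resolver`; passing a case-insensitive comparison mimics Spark's default `spark.sql.caseSensitive=false`. The object and method names are hypothetical, and the null-struct branch (`If(IsNull(structExpr), ...)`) is left out.

```scala
object WithFieldsFoldSketch {
  type Resolver = (String, String) => Boolean

  // Same fold shape as WithFields.evalExpr, but over (name, value) pairs instead of expressions.
  def addOrReplace[V](
      existing: Seq[(String, V)],
      updates: Seq[(String, V)],
      resolver: Resolver): Seq[(String, V)] =
    updates.foldLeft(existing) { case (result, update @ (updateName, _)) =>
      if (result.exists { case (name, _) => resolver(name, updateName) }) {
        // Replace every field whose name resolves to updateName; the new name is kept.
        result.map {
          case (name, _) if resolver(name, updateName) => update
          case other => other
        }
      } else {
        // No match: append the new field after the existing ones.
        result :+ update
      }
    }

  val caseInsensitive: Resolver = (a, b) => a.equalsIgnoreCase(b)
  val caseSensitive: Resolver = (a, b) => a == b
}
```

Applying one update at a time is also why `withField("b", lit(2)).withField("b", lit(20))` ends up with `20`, as the "replace fields in struct in given order" test later in this diff expects.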
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index f79dabf..1c33a2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -39,7 +39,18 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
createNamedStruct.valExprs(ordinal)
-
+ case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
+ val name = w.dataType(ordinal).name
+ val matches = names.zip(valExprs).filter(_._1 == name)
+ if (matches.nonEmpty) {
+ // return last matching element as that is the final value for the field being extracted.
+ // For example, if a user submits a query like this:
+ // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
+ // we want to return `lit(2)` (and not `lit(1)`).
+ matches.last._2
+ } else {
+ GetStructField(struct, ordinal, maybeName)
+ }
// Remove redundant array indexing.
case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
// Instead of selecting the field on the entire array, select it from each member
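The new `SimplifyExtractValueOps` case is a small tree rewrite with two outcomes: if the extracted field was written by the `WithFields`, the last matching value wins; otherwise the extraction is pushed down to the original struct. The real rule extracts by ordinal, and keeping the same ordinal on the pushed-down `GetStructField` appears safe because `WithFields` replaces existing fields in place and only appends new ones, so a field not named in `names` sits at the same position in the original struct. The sketch below models this over a hypothetical toy tree (`Attr`, `Lit`, `WithFieldsE`, `GetField`) and, as a simplification, extracts by name rather than by ordinal.

```scala
object ExtractSimplifySketch {
  // Hypothetical, heavily simplified expression tree for illustration only.
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Lit(value: Any) extends Expr
  final case class WithFieldsE(struct: Expr, names: Seq[String], values: Seq[Expr]) extends Expr
  final case class GetField(child: Expr, name: String) extends Expr

  // Bottom-up simplification of field extraction on top of WithFieldsE.
  def simplify(e: Expr): Expr = e match {
    case GetField(child, name) =>
      simplify(child) match {
        case WithFieldsE(struct, names, values) =>
          names.zip(values).filter(_._1 == name).lastOption match {
            case Some((_, v)) => v                         // last write to `name` wins
            case None => simplify(GetField(struct, name))  // untouched field: read the original struct
          }
        case other => GetField(other, name)
      }
    case WithFieldsE(struct, names, values) => WithFieldsE(simplify(struct), names, values.map(simplify))
    case leaf => leaf
  }
}
```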
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e800ee3..1b14157 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -107,6 +107,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateSerialization,
RemoveRedundantAliases,
RemoveNoopOperators,
+ CombineWithFields,
SimplifyExtractValueOps,
CombineConcats) ++
extendedOperatorOptimizationRules
@@ -207,7 +208,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
CollapseProject,
RemoveNoopOperators) :+
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
- Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers)
+ Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
+ Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)
// remove any batches with no rules. this may happen when subclasses do not add optional rules.
batches.filter(_.rules.nonEmpty)
@@ -240,7 +242,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
PullupCorrelatedPredicates.ruleName ::
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName ::
- NormalizeFloatingNumbers.ruleName :: Nil
+ NormalizeFloatingNumbers.ruleName ::
+ ReplaceWithFieldsExpression.ruleName :: Nil
/**
* Optimize all the subqueries inside expression.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
new file mode 100644
index 0000000..05c9086
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.WithFields
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+
+/**
+ * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
+ */
+object CombineWithFields extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
+ WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
+ }
+}
+
+/**
+ * Replaces [[WithFields]] expression with an evaluable expression.
+ */
+object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case w: WithFields => w.evalExpr
+ }
+}
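`CombineWithFields` is valid because applying `names1` and then `names2` one after another gives the same result as a single fold over `names1 ++ names2`. A single top-down pass only merges one pair per nesting level, so a three-deep chain still has two levels after one pass; this is presumably why the test suite below runs the rule in a `FixedPoint(10)` batch. The sketch below, over a hypothetical toy tree (`Attr`, `Lit`, `WithFieldsE`), imitates a top-down pass followed by re-application until nothing changes.

```scala
object CombineWithFieldsSketch {
  // Hypothetical toy expression tree for illustration only.
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Lit(value: Int) extends Expr
  final case class WithFieldsE(struct: Expr, names: Seq[String], values: Seq[Expr]) extends Expr

  // One top-down pass: merge a WithFieldsE that directly wraps another,
  // then continue into the children of the result (not the result itself).
  def combineOnce(e: Expr): Expr = e match {
    case WithFieldsE(WithFieldsE(struct, names1, values1), names2, values2) =>
      WithFieldsE(combineOnce(struct), names1 ++ names2, (values1 ++ values2).map(combineOnce))
    case WithFieldsE(struct, names, values) =>
      WithFieldsE(combineOnce(struct), names, values.map(combineOnce))
    case leaf => leaf
  }

  // Re-apply until the tree stops changing or maxIterations passes have run.
  @scala.annotation.tailrec
  def combine(e: Expr, maxIterations: Int = 10): Expr = {
    val next = combineOnce(e)
    if (next == e || maxIterations <= 1) next else combine(next, maxIterations - 1)
  }
}
```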
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala
new file mode 100644
index 0000000..a3e0bbc
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+
+class CombineWithFieldsSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
+ }
+
+ private val testRelation = LocalRelation('a.struct('a1.int))
+
+ test("combines two WithFields") {
+ val originalQuery = testRelation
+ .select(Alias(
+ WithFields(
+ WithFields(
+ 'a,
+ Seq("b1"),
+ Seq(Literal(4))),
+ Seq("c1"),
+ Seq(Literal(5))), "out")())
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = testRelation
+ .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")())
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("combines three WithFields") {
+ val originalQuery = testRelation
+ .select(Alias(
+ WithFields(
+ WithFields(
+ WithFields(
+ 'a,
+ Seq("b1"),
+ Seq(Literal(4))),
+ Seq("c1"),
+ Seq(Literal(5))),
+ Seq("d1"),
+ Seq(Literal(6))), "out")())
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = testRelation
+ .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")())
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index d557460..c71e7db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -452,4 +452,61 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
}
+
+ private val structAttr = 'struct1.struct('a.int)
+ private val testStructRelation = LocalRelation(structAttr)
+
+ test("simplify GetStructField on WithFields that is not changing the attribute being extracted") {
+ val query = testStructRelation.select(
+ GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 0, Some("a")) as "outerAtt")
+ val expected = testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAtt")
+ checkRule(query, expected)
+ }
+
+ test("simplify GetStructField on WithFields that is changing the attribute being extracted") {
+ val query = testStructRelation.select(
+ GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 1, Some("b")) as "outerAtt")
+ val expected = testStructRelation.select(Literal(1) as "outerAtt")
+ checkRule(query, expected)
+ }
+
+ test(
+ "simplify GetStructField on WithFields that is changing the attribute being extracted twice") {
+ val query = testStructRelation
+ .select(GetStructField(WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2))), 1,
+ Some("b")) as "outerAtt")
+ val expected = testStructRelation.select(Literal(2) as "outerAtt")
+ checkRule(query, expected)
+ }
+
+ test("collapse multiple GetStructField on the same WithFields") {
+ val query = testStructRelation
+ .select(WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2")
+ .select(
+ GetStructField('struct2, 0, Some("a")) as "struct1A",
+ GetStructField('struct2, 1, Some("b")) as "struct1B")
+ val expected = testStructRelation.select(
+ GetStructField('struct1, 0, Some("a")) as "struct1A",
+ Literal(2) as "struct1B")
+ checkRule(query, expected)
+ }
+
+ test("collapse multiple GetStructField on different WithFields") {
+ val query = testStructRelation
+ .select(
+ WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2",
+ WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3")
+ .select(
+ GetStructField('struct2, 0, Some("a")) as "struct2A",
+ GetStructField('struct2, 1, Some("b")) as "struct2B",
+ GetStructField('struct3, 0, Some("a")) as "struct3A",
+ GetStructField('struct3, 1, Some("b")) as "struct3B")
+ val expected = testStructRelation
+ .select(
+ GetStructField('struct1, 0, Some("a")) as "struct2A",
+ Literal(2) as "struct2B",
+ GetStructField('struct1, 0, Some("a")) as "struct3A",
+ Literal(3) as "struct3B")
+ checkRule(query, expected)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e6f7b1d..da542c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -871,6 +871,72 @@ class Column(val expr: Expression) extends Logging {
*/
def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) }
+ // scalastyle:off line.size.limit
+ /**
+ * An expression that adds/replaces field in `StructType` by name.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".withField("c", lit(3)))
+ * // result: {"a":1,"b":2,"c":3}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".withField("b", lit(3)))
+ * // result: {"a":1,"b":3}
+ *
+ * val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+ * df.select($"struct_col".withField("c", lit(3)))
+ * // result: null of type struct<a:int,b:int,c:int>
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+ * df.select($"struct_col".withField("b", lit(100)))
+ * // result: {"a":1,"b":100,"b":100}
+ *
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a.c", lit(3)))
+ * // result: {"a":{"a":1,"b":2,"c":3}}
+ *
+ * val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+ * df.select($"struct_col".withField("a.c", lit(3)))
+ * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.1.0
+ */
+ // scalastyle:on line.size.limit
+ def withField(fieldName: String, col: Column): Column = withExpr {
+ require(fieldName != null, "fieldName cannot be null")
+ require(col != null, "col cannot be null")
+
+ val nameParts = if (fieldName.isEmpty) {
+ fieldName :: Nil
+ } else {
+ CatalystSqlParser.parseMultipartIdentifier(fieldName)
+ }
+ withFieldHelper(expr, nameParts, Nil, col.expr)
+ }
+
+ private def withFieldHelper(
+ struct: Expression,
+ namePartsRemaining: Seq[String],
+ namePartsDone: Seq[String],
+ value: Expression) : WithFields = {
+ val name = namePartsRemaining.head
+ if (namePartsRemaining.length == 1) {
+ WithFields(struct, name :: Nil, value :: Nil)
+ } else {
+ val newNamesRemaining = namePartsRemaining.tail
+ val newNamesDone = namePartsDone :+ name
+ val newValue = withFieldHelper(
+ struct = UnresolvedExtractValue(struct, Literal(name)),
+ namePartsRemaining = newNamesRemaining,
+ namePartsDone = newNamesDone,
+ value = value)
+ WithFields(struct, name :: Nil, newValue :: Nil)
+ }
+ }
+
/**
* An expression that gets a field by name in a `StructType`.
*
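`withField` never builds a deep expression by hand: `withFieldHelper` peels one name off the parsed path per level, extracts the current child struct with `UnresolvedExtractValue`, recurses into it, and writes the rebuilt child back with a single-name `WithFields`. The sketch below reproduces only that recursion shape over a hypothetical toy tree (`Col`, `Lit`, `Extract` standing in for `UnresolvedExtractValue`, and `WithFieldsE`). It takes an already-split path; the real method splits `fieldName` with `CatalystSqlParser.parseMultipartIdentifier`, so a backquoted part such as `` `a.b` `` stays a single name, and an empty `fieldName` becomes a single empty name.

```scala
object WithFieldPathDesugarSketch {
  // Hypothetical toy expression tree for illustration only.
  sealed trait Expr
  final case class Col(name: String) extends Expr
  final case class Lit(value: Int) extends Expr
  final case class Extract(struct: Expr, field: String) extends Expr // stands in for UnresolvedExtractValue
  final case class WithFieldsE(struct: Expr, names: Seq[String], values: Seq[Expr]) extends Expr

  // Same recursion shape as withFieldHelper: one WithFieldsE per path element.
  def desugar(struct: Expr, path: List[String], value: Expr): WithFieldsE = path match {
    case name :: Nil => WithFieldsE(struct, Seq(name), Seq(value))
    case name :: rest =>
      // Rebuild the child struct found at `name`, then write it back into `struct`.
      WithFieldsE(struct, Seq(name), Seq(desugar(Extract(struct, name), rest, value)))
    case Nil => throw new IllegalArgumentException("path must not be empty")
  }
}
```

For the PR's example, `desugar(Col("a"), List("b", "a", "b"), Lit(5))` yields three nested single-field updates; after `CombineWithFields` and the final `ReplaceWithFieldsExpression` batch, that is expanded into roughly the same `struct(...)` rebuild users previously wrote by hand.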
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index fa06484..131ab1b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -923,4 +923,503 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString))
assert(inSet.sql === "('a' IN ('a', 'b'))")
}
+
+ def checkAnswerAndSchema(
+ df: => DataFrame,
+ expectedAnswer: Seq[Row],
+ expectedSchema: StructType): Unit = {
+
+ checkAnswer(df, expectedAnswer)
+ assert(df.schema == expectedSchema)
+ }
+
+ private lazy val structType = StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false)))
+
+ private lazy val structLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, null, 3)) :: Nil),
+ StructType(Seq(StructField("a", structType, nullable = false))))
+
+ private lazy val nullStructLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(null) :: Nil),
+ StructType(Seq(StructField("a", structType, nullable = true))))
+
+ private lazy val structLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false))),
+ nullable = false))))
+
+ private lazy val nullStructLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(null)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = true))),
+ nullable = false))))
+
+ private lazy val structLevel3: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ test("withField should throw an exception if called on a non-StructType column") {
+ intercept[AnalysisException] {
+ testData.withColumn("key", $"key".withField("a", lit(2)))
+ }.getMessage should include("struct argument should be struct type, got: int")
+ }
+
+ test("withField should throw an exception if either fieldName or col argument are null") {
+ intercept[IllegalArgumentException] {
+ structLevel1.withColumn("a", $"a".withField(null, lit(2)))
+ }.getMessage should include("fieldName cannot be null")
+
+ intercept[IllegalArgumentException] {
+ structLevel1.withColumn("a", $"a".withField("b", null))
+ }.getMessage should include("col cannot be null")
+
+ intercept[IllegalArgumentException] {
+ structLevel1.withColumn("a", $"a".withField(null, null))
+ }.getMessage should include("fieldName cannot be null")
+ }
+
+ test("withField should throw an exception if any intermediate structs don't exist") {
+ intercept[AnalysisException] {
+ structLevel2.withColumn("a", 'a.withField("x.b", lit(2)))
+ }.getMessage should include("No such struct field x in a")
+
+ intercept[AnalysisException] {
+ structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2)))
+ }.getMessage should include("No such struct field x in a")
+ }
+
+ test("withField should throw an exception if intermediate field is not a struct") {
+ intercept[AnalysisException] {
+ structLevel1.withColumn("a", 'a.withField("b.a", lit(2)))
+ }.getMessage should include("struct argument should be struct type, got: int")
+ }
+
+ test("withField should throw an exception if intermediate field reference is ambiguous") {
+ intercept[AnalysisException] {
+ val structLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false),
+ StructField("a", structType, nullable = false))),
+ nullable = false))))
+
+ structLevel2.withColumn("a", 'a.withField("a.b", lit(2)))
+ }.getMessage should include("Ambiguous reference to fields")
+ }
+
+ test("withField should add field with no name") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", $"a".withField("", lit(4))),
+ Row(Row(1, null, 3, 4)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field to struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(4))),
+ Row(Row(1, null, 3, 4)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field to null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel1.withColumn("a", $"a".withField("d", lit(4))),
+ Row(null) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = true))))
+ }
+
+ test("withField should add field to nested null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))),
+ Row(Row(null)) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should add null field to struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))),
+ Row(Row(1, null, 3, null)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should add multiple fields to struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))),
+ Row(Row(1, null, 3, 4, 5)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false),
+ StructField("e", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field to nested struct") {
+ Seq(
+ structLevel2.withColumn("a", 'a.withField("a.d", lit(4))),
+ structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4))))
+ ).foreach { df =>
+ checkAnswerAndSchema(
+ df,
+ Row(Row(Row(1, null, 3, 4))) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("withField should add field to deeply nested struct") {
+ checkAnswerAndSchema(
+ structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))),
+ Row(Row(Row(Row(1, null, 3, 4)))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace field in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ Row(Row(1, 2, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace field in null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))),
+ Row(null) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", StringType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = true))))
+ }
+
+ test("withField should replace field in nested null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))),
+ Row(Row(null)) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", StringType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should replace field with null value in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))),
+ Row(Row(1, null, null)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should replace multiple fields in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))),
+ Row(Row(10, 20, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace field in nested struct") {
+ Seq(
+ structLevel2.withColumn("a", $"a".withField("a.b", lit(2))),
+ structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2))))
+ ).foreach { df =>
+ checkAnswerAndSchema(
+ df,
+ Row(Row(Row(1, 2, 3))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("withField should replace field in deeply nested struct") {
+ checkAnswerAndSchema(
+ structLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))),
+ Row(Row(Row(Row(1, 2, 3)))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace all fields with given name in struct") {
+ val structLevel1 = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("b", lit(100))),
+ Row(Row(1, 100, 100)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace fields in struct in given order") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))),
+ Row(Row(1, 20, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field and then replace same field in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))),
+ Row(Row(1, null, 3, 5)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should handle fields with dots in their name if correctly quoted") {
+ val df: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a.b", StructType(Seq(
+ StructField("c.d", IntegerType, nullable = false),
+ StructField("e.f", IntegerType, nullable = true),
+ StructField("g.h", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ checkAnswerAndSchema(
+ df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))),
+ Row(Row(Row(1, 2, 3))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a.b", StructType(Seq(
+ StructField("c.d", IntegerType, nullable = false),
+ StructField("e.f", IntegerType, nullable = false),
+ StructField("g.h", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ intercept[AnalysisException] {
+ df.withColumn("a", 'a.withField("a.b.e.f", lit(2)))
+ }.getMessage should include("No such struct field a in a.b")
+ }
+
+ private lazy val mixedCaseStructLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, 1)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("B", IntegerType, nullable = false))),
+ nullable = false))))
+
+ test("withField should replace field in struct even if casing is different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ checkAnswerAndSchema(
+ mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))),
+ Row(Row(2, 1)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("A", IntegerType, nullable = false),
+ StructField("B", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswerAndSchema(
+ mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ Row(Row(1, 2)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("withField should add field to struct because casing is different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ checkAnswerAndSchema(
+ mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))),
+ Row(Row(1, 1, 2)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("B", IntegerType, nullable = false),
+ StructField("A", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswerAndSchema(
+ mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ Row(Row(1, 1, 2)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("B", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ private lazy val mixedCaseStructLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, 1), Row(1, 1))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false),
+ StructField("B", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ test("withField should replace nested field in struct even if casing is different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ checkAnswerAndSchema(
+ mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))),
+ Row(Row(Row(2, 1), Row(1, 1))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("A", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false),
+ StructField("B", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ checkAnswerAndSchema(
+ mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))),
+ Row(Row(Row(1, 1), Row(2, 1))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false),
+ StructField("b", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("withField should throw an exception because casing is different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ intercept[AnalysisException] {
+ mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2)))
+ }.getMessage should include("No such struct field A in a, B")
+
+ intercept[AnalysisException] {
+ mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2)))
+ }.getMessage should include("No such struct field b in a, B")
+ }
+ }
}
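The tests above fix several resolution rules for `withField`:

- Every field that matches the last name part is replaced, and it takes the requested spelling of the name.
- A new field is appended when nothing matches.
- Intermediate name parts must already exist.
- Backticks protect dots inside names.
- A null struct stays null.

Below is a minimal pure-Scala model of those rules, offered as an aid for reading the expectations. It is a hypothetical sketch, not Spark's implementation, and rests on these assumptions:

- `Field`, `WithFieldModel` and `parsePath` are invented names.
- Struct values are modeled as `Vector[Field]`.
- The `caseSensitive` flag stands in for the `SQLConf.CASE_SENSITIVE` setting that the tests toggle with `withSQLConf`.
- Ambiguous-reference errors and type checking are not modeled.

```scala
// Hypothetical model of the withField behaviour exercised by the tests above.
// NOT Spark's implementation: a struct is a Vector[Field], a null struct is null.
final case class Field(name: String, value: Any)

object WithFieldModel {
  // Split a path such as "`a.b`.`e.f`" into Seq("a.b", "e.f"). Dots separate
  // name parts unless they are inside backticks; "``" inside quotes is a literal backtick.
  def parsePath(path: String): Seq[String] = {
    val parts = scala.collection.mutable.ArrayBuffer.empty[String]
    val cur = new StringBuilder
    var inQuotes = false
    var i = 0
    while (i < path.length) {
      val ch = path.charAt(i)
      if (ch == '`') {
        if (inQuotes && i + 1 < path.length && path.charAt(i + 1) == '`') {
          cur.append('`'); i += 1 // escaped backtick: keep one, skip the other
        } else inQuotes = !inQuotes
      } else if (ch == '.' && !inQuotes) {
        parts += cur.toString; cur.clear()
      } else cur.append(ch)
      i += 1
    }
    parts += cur.toString
    parts.toSeq
  }

  private def matches(a: String, b: String, caseSensitive: Boolean): Boolean =
    if (caseSensitive) a == b else a.equalsIgnoreCase(b)

  // Replace every matching field (renaming it to `name`), or append if none match.
  private def setField(struct: Vector[Field], name: String, value: Any,
                       caseSensitive: Boolean): Vector[Field] =
    if (struct.exists(f => matches(f.name, name, caseSensitive)))
      struct.map(f => if (matches(f.name, name, caseSensitive)) Field(name, value) else f)
    else struct :+ Field(name, value)

  def withField(struct: Vector[Field], path: String, value: Any,
                caseSensitive: Boolean = false): Vector[Field] =
    go(struct, parsePath(path), value, caseSensitive)

  private def go(struct: Vector[Field], names: Seq[String], value: Any,
                 caseSensitive: Boolean): Vector[Field] =
    if (struct == null) null // withField on a null struct yields null
    else if (names.length == 1) setField(struct, names.head, value, caseSensitive)
    else {
      // Intermediate parts must exist; the first match is used (ambiguity not modeled).
      val name = names.head
      val child = struct.find(f => matches(f.name, name, caseSensitive)).getOrElse(
        throw new IllegalArgumentException(
          s"No such struct field $name in ${struct.map(_.name).mkString(", ")}"))
      val updated = go(child.value.asInstanceOf[Vector[Field]], names.tail, value, caseSensitive)
      // Rebuild the parent by replacing the child, which also renames it to `name`.
      setField(struct, name, updated, caseSensitive)
    }
}
```

In this model, intermediate levels are rebuilt with the same replace step as the leaf. That is why, in the case-insensitive tests, both a replaced leaf (`a` becoming `A`) and a replaced intermediate struct (`B` becoming `b`) end up with the spelling passed to `withField`.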