Posted to commits@spark.apache.org by li...@apache.org on 2017/10/01 17:49:29 UTC
spark git commit: [SPARK-22001][ML][SQL] ImputerModel can do withColumn for all input columns in one pass
Repository: spark
Updated Branches:
refs/heads/master 02c91e03f -> 3ca367083
[SPARK-22001][ML][SQL] ImputerModel can do withColumn for all input columns in one pass
## What changes were proposed in this pull request?
SPARK-21690 made `Imputer` one-pass by parallelizing the computation over all input columns. When transforming a dataset with `ImputerModel`, however, we still call `withColumn` on each input column sequentially. We can instead handle all input columns at once by adding a `withColumns` API to `Dataset`.
For now, the new `withColumns` API is for internal use only.
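
To illustrate the difference, here is a minimal sketch (the `SparkSession`, data, and column names are illustrative assumptions, not from the patch; note that in this patch `withColumns` is `private[spark]`, so the one-pass call only compiles from code inside the `org.apache.spark` package, such as the test suite):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq((1, "a"), (2, "b")).toDF("key", "value")

// Before: each withColumn call builds and analyzes a new Dataset.
val sequential = df
  .withColumn("newCol1", col("key") + 1)
  .withColumn("newCol2", col("key") + 2)

// After: a single projection adds both columns in one pass.
// (private[spark] here, so callable only from Spark-internal code.)
val onePass = df.withColumns(Seq("newCol1", "newCol2"),
  Seq(col("key") + 1, col("key") + 2))
```

Since each `withColumn` call produces a new analyzed plan, folding n columns into one projection avoids n rounds of analysis.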
## How was this patch tested?
Existing tests cover the `ImputerModel` change. New tests were added for the `withColumns` API.
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #19229 from viirya/SPARK-22001.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3ca36708
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3ca36708
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3ca36708
Branch: refs/heads/master
Commit: 3ca367083e196e6487207211e6c49d4bbfe31288
Parents: 02c91e0
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Sun Oct 1 10:49:22 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Sun Oct 1 10:49:22 2017 -0700
----------------------------------------------------------------------
.../org/apache/spark/ml/feature/Imputer.scala | 10 ++--
.../scala/org/apache/spark/sql/Dataset.scala | 42 +++++++++++-----
.../org/apache/spark/sql/DataFrameSuite.scala | 52 ++++++++++++++++++++
3 files changed, 86 insertions(+), 18 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/3ca36708/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 1f36ece..4663f16 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -223,20 +223,18 @@ class ImputerModel private[ml] (
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
- var outputDF = dataset
val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
- $(inputCols).zip($(outputCols)).zip(surrogates).foreach {
+ val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map {
case ((inputCol, outputCol), surrogate) =>
val inputType = dataset.schema(inputCol).dataType
val ic = col(inputCol)
- outputDF = outputDF.withColumn(outputCol,
- when(ic.isNull, surrogate)
+ when(ic.isNull, surrogate)
.when(ic === $(missingValue), surrogate)
.otherwise(ic)
- .cast(inputType))
+ .cast(inputType)
}
- outputDF.toDF()
+ dataset.withColumns($(outputCols), newCols).toDF()
}
override def transformSchema(schema: StructType): StructType = {
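
For context, a minimal `Imputer` usage sketch (the session, data, and column names are illustrative assumptions, not from the patch). After this change, `transform` builds all output columns with a single `withColumns` call rather than one `withColumn` per column:

```scala
import org.apache.spark.ml.feature.Imputer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Toy data; missingValue defaults to Double.NaN, strategy to "mean".
val df = Seq((1.0, Double.NaN), (2.0, 3.0), (Double.NaN, 5.0)).toDF("a", "b")

val model = new Imputer()
  .setInputCols(Array("a", "b"))
  .setOutputCols(Array("a_imputed", "b_imputed"))
  .fit(df)

// All imputed columns are now added in a single projection.
model.transform(df).show()
```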
http://git-wip-us.apache.org/repos/asf/spark/blob/3ca36708/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ab0c412..f2a76a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2083,22 +2083,40 @@ class Dataset[T] private[sql](
* @group untypedrel
* @since 2.0.0
*/
- def withColumn(colName: String, col: Column): DataFrame = {
+ def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col))
+
+ /**
+ * Returns a new Dataset by adding columns or replacing existing columns that have
+ * the same names.
+ */
+ private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {
+ require(colNames.size == cols.size,
+ s"The size of column names: ${colNames.size} isn't equal to " +
+ s"the size of columns: ${cols.size}")
+ SchemaUtils.checkColumnNameDuplication(
+ colNames,
+ "in given column names",
+ sparkSession.sessionState.conf.caseSensitiveAnalysis)
+
val resolver = sparkSession.sessionState.analyzer.resolver
val output = queryExecution.analyzed.output
- val shouldReplace = output.exists(f => resolver(f.name, colName))
- if (shouldReplace) {
- val columns = output.map { field =>
- if (resolver(field.name, colName)) {
- col.as(colName)
- } else {
- Column(field)
- }
+
+ val columnMap = colNames.zip(cols).toMap
+
+ val replacedAndExistingColumns = output.map { field =>
+ columnMap.find { case (colName, _) =>
+ resolver(field.name, colName)
+ } match {
+ case Some((colName: String, col: Column)) => col.as(colName)
+ case _ => Column(field)
}
- select(columns : _*)
- } else {
- select(Column("*"), col.as(colName))
}
+
+ val newColumns = columnMap.filter { case (colName, col) =>
+ !output.exists(f => resolver(f.name, colName))
+ }.map { case (colName, col) => col.as(colName) }
+
+ select(replacedAndExistingColumns ++ newColumns : _*)
}
/**
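
The net effect of `withColumns` is a single projection that replaces name-matching columns in place (preserving their position) and appends the remaining names at the end. A sketch of the equivalent public-API `select` (names and values are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df2 = Seq((1, 2), (2, 3)).toDF("x", "y")

// Equivalent to df2.withColumns(Seq("x", "newCol"), Seq(col("x") + 1, col("y") + 10)):
val projected = df2.select(
  (col("x") + 1).as("x"),       // "x" matches an existing column, so it is replaced in place
  col("y"),                     // other existing columns pass through unchanged
  (col("y") + 10).as("newCol")  // names with no match are appended at the end
)
projected.show()
```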
http://git-wip-us.apache.org/repos/asf/spark/blob/3ca36708/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0e2f2e5..672deea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -641,6 +641,49 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
}
+ test("withColumns") {
+ val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
+ Seq(col("key") + 1, col("key") + 2))
+ checkAnswer(
+ df,
+ testData.collect().map { case Row(key: Int, value: String) =>
+ Row(key, value, key + 1, key + 2)
+ }.toSeq)
+ assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))
+
+ val err = intercept[IllegalArgumentException] {
+ testData.toDF().withColumns(Seq("newCol1"),
+ Seq(col("key") + 1, col("key") + 2))
+ }
+ assert(
+ err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2"))
+
+ val err2 = intercept[AnalysisException] {
+ testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
+ Seq(col("key") + 1, col("key") + 2))
+ }
+ assert(err2.getMessage.contains("Found duplicate column(s)"))
+ }
+
+ test("withColumns: case sensitive") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
+ Seq(col("key") + 1, col("key") + 2))
+ checkAnswer(
+ df,
+ testData.collect().map { case Row(key: Int, value: String) =>
+ Row(key, value, key + 1, key + 2)
+ }.toSeq)
+ assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1"))
+
+ val err = intercept[AnalysisException] {
+ testData.toDF().withColumns(Seq("newCol1", "newCol1"),
+ Seq(col("key") + 1, col("key") + 2))
+ }
+ assert(err.getMessage.contains("Found duplicate column(s)"))
+ }
+ }
+
test("replace column using withColumn") {
val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
val df3 = df2.withColumn("x", df2("x") + 1)
@@ -649,6 +692,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(2) :: Row(3) :: Row(4) :: Nil)
}
+ test("replace column using withColumns") {
+ val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y")
+ val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"),
+ Seq(df2("x") + 1, df2("y"), df2("y") + 1))
+ checkAnswer(
+ df3.select("x", "newCol1", "newCol2"),
+ Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil)
+ }
+
test("drop column using drop") {
val df = testData.drop("key")
checkAnswer(