Posted to commits@spark.apache.org by li...@apache.org on 2017/10/01 17:49:29 UTC

spark git commit: [SPARK-22001][ML][SQL] ImputerModel can do withColumn for all input columns at one pass

Repository: spark
Updated Branches:
  refs/heads/master 02c91e03f -> 3ca367083


[SPARK-22001][ML][SQL] ImputerModel can do withColumn for all input columns at one pass

## What changes were proposed in this pull request?

SPARK-21690 made `Imputer` one-pass by parallelizing the computation over all input columns. When transforming a dataset with `ImputerModel`, however, we still call `withColumn` on each input column sequentially. By adding a `withColumns` API to `Dataset`, we can apply all of the imputed columns at once.

For now, the new `withColumns` API is for internal use only; a minimal usage sketch is shown below.
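
A hedged sketch of how the internal API is meant to be called (the DataFrame and column names below are illustrative, and `withColumns` is `private[spark]`, so it is only callable from code inside the `org.apache.spark` package):

```scala
import org.apache.spark.sql.functions.col

// Illustrative DataFrame; assumes a SparkSession named `spark` is in scope.
val df = spark.range(5).toDF("key")

// Add two derived columns in a single pass instead of chaining withColumn twice.
val out = df.withColumns(
  Seq("keyPlusOne", "keyPlusTwo"),
  Seq(col("key") + 1, col("key") + 2))

out.show()
```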

## How was this patch tested?

Existing tests cover the change to `ImputerModel`. New tests were added for the `withColumns` API.

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #19229 from viirya/SPARK-22001.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3ca36708
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3ca36708
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3ca36708

Branch: refs/heads/master
Commit: 3ca367083e196e6487207211e6c49d4bbfe31288
Parents: 02c91e0
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Sun Oct 1 10:49:22 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Sun Oct 1 10:49:22 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Imputer.scala   | 10 ++--
 .../scala/org/apache/spark/sql/Dataset.scala    | 42 +++++++++++-----
 .../org/apache/spark/sql/DataFrameSuite.scala   | 52 ++++++++++++++++++++
 3 files changed, 86 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3ca36708/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 1f36ece..4663f16 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -223,20 +223,18 @@ class ImputerModel private[ml] (
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    var outputDF = dataset
     val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
 
-    $(inputCols).zip($(outputCols)).zip(surrogates).foreach {
+    val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map {
       case ((inputCol, outputCol), surrogate) =>
         val inputType = dataset.schema(inputCol).dataType
         val ic = col(inputCol)
-        outputDF = outputDF.withColumn(outputCol,
-          when(ic.isNull, surrogate)
+        when(ic.isNull, surrogate)
           .when(ic === $(missingValue), surrogate)
           .otherwise(ic)
-          .cast(inputType))
+          .cast(inputType)
     }
-    outputDF.toDF()
+    dataset.withColumns($(outputCols), newCols).toDF()
   }
 
   override def transformSchema(schema: StructType): StructType = {

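For context, a hedged sketch of the user-facing flow that benefits from this change: `ImputerModel.transform` now emits all output columns through a single `Dataset.withColumns` call instead of one `withColumn` per column. The data and column names below are illustrative:

```scala
import org.apache.spark.ml.feature.Imputer

// Illustrative data; NaN marks the entries to be imputed.
// Assumes a SparkSession named `spark` is in scope.
val df = spark.createDataFrame(Seq(
  (1.0, Double.NaN),
  (2.0, 3.0),
  (Double.NaN, 5.0)
)).toDF("a", "b")

// Both input columns are imputed together; transform() produces both
// output columns in one projection via Dataset.withColumns.
val model = new Imputer()
  .setInputCols(Array("a", "b"))
  .setOutputCols(Array("a_imputed", "b_imputed"))
  .setStrategy("mean")
  .fit(df)

model.transform(df).show()
```
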
http://git-wip-us.apache.org/repos/asf/spark/blob/3ca36708/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ab0c412..f2a76a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2083,22 +2083,40 @@ class Dataset[T] private[sql](
    * @group untypedrel
    * @since 2.0.0
    */
-  def withColumn(colName: String, col: Column): DataFrame = {
+  def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col))
+
+  /**
+   * Returns a new Dataset by adding columns or replacing the existing columns that have
+   * the same names.
+   */
+  private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {
+    require(colNames.size == cols.size,
+      s"The size of column names: ${colNames.size} isn't equal to " +
+        s"the size of columns: ${cols.size}")
+    SchemaUtils.checkColumnNameDuplication(
+      colNames,
+      "in given column names",
+      sparkSession.sessionState.conf.caseSensitiveAnalysis)
+
     val resolver = sparkSession.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
-    val shouldReplace = output.exists(f => resolver(f.name, colName))
-    if (shouldReplace) {
-      val columns = output.map { field =>
-        if (resolver(field.name, colName)) {
-          col.as(colName)
-        } else {
-          Column(field)
-        }
+
+    val columnMap = colNames.zip(cols).toMap
+
+    val replacedAndExistingColumns = output.map { field =>
+      columnMap.find { case (colName, _) =>
+        resolver(field.name, colName)
+      } match {
+        case Some((colName: String, col: Column)) => col.as(colName)
+        case _ => Column(field)
       }
-      select(columns : _*)
-    } else {
-      select(Column("*"), col.as(colName))
     }
+
+    val newColumns = columnMap.filter { case (colName, col) =>
+      !output.exists(f => resolver(f.name, colName))
+    }.map { case (colName, col) => col.as(colName) }
+
+    select(replacedAndExistingColumns ++ newColumns : _*)
   }
 
   /**

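To illustrate the replace-or-append behavior implemented above (mirroring the new `DataFrameSuite` tests further down), a small hedged sketch; the data is illustrative and, as noted, `withColumns` is only callable from within the `org.apache.spark` package:

```scala
import org.apache.spark.sql.functions.col
// Assumes a SparkSession named `spark` is in scope.
import spark.implicits._

val df = Seq((1, 2), (2, 3)).toDF("x", "y")

// "x" already exists, so its expression replaces that column in place;
// "z" is new, so it is appended after the existing columns, following the
// select(replacedAndExistingColumns ++ newColumns) ordering in the code above.
val result = df.withColumns(Seq("x", "z"), Seq(col("x") + 1, col("y") * 10))

result.show()
// Resulting schema order: x (replaced), y, z (appended)
```
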
http://git-wip-us.apache.org/repos/asf/spark/blob/3ca36708/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0e2f2e5..672deea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -641,6 +641,49 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
   }
 
+  test("withColumns") {
+    val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
+      Seq(col("key") + 1, col("key") + 2))
+    checkAnswer(
+      df,
+      testData.collect().map { case Row(key: Int, value: String) =>
+        Row(key, value, key + 1, key + 2)
+      }.toSeq)
+    assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))
+
+    val err = intercept[IllegalArgumentException] {
+      testData.toDF().withColumns(Seq("newCol1"),
+        Seq(col("key") + 1, col("key") + 2))
+    }
+    assert(
+      err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2"))
+
+    val err2 = intercept[AnalysisException] {
+      testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
+        Seq(col("key") + 1, col("key") + 2))
+    }
+    assert(err2.getMessage.contains("Found duplicate column(s)"))
+  }
+
+  test("withColumns: case sensitive") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
+        Seq(col("key") + 1, col("key") + 2))
+      checkAnswer(
+        df,
+        testData.collect().map { case Row(key: Int, value: String) =>
+          Row(key, value, key + 1, key + 2)
+        }.toSeq)
+      assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1"))
+
+      val err = intercept[AnalysisException] {
+        testData.toDF().withColumns(Seq("newCol1", "newCol1"),
+          Seq(col("key") + 1, col("key") + 2))
+      }
+      assert(err.getMessage.contains("Found duplicate column(s)"))
+    }
+  }
+
   test("replace column using withColumn") {
     val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
     val df3 = df2.withColumn("x", df2("x") + 1)
@@ -649,6 +692,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row(2) :: Row(3) :: Row(4) :: Nil)
   }
 
+  test("replace column using withColumns") {
+    val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y")
+    val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"),
+      Seq(df2("x") + 1, df2("y"), df2("y") + 1))
+    checkAnswer(
+      df3.select("x", "newCol1", "newCol2"),
+      Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil)
+  }
+
   test("drop column using drop") {
     val df = testData.drop("key")
     checkAnswer(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org