Posted to commits@spark.apache.org by sr...@apache.org on 2020/02/25 02:20:44 UTC

[spark] branch branch-3.0 updated: [SPARK-30939][ML] Correctly set output col when StringIndexer.setOutputCols is used

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 52363d4  [SPARK-30939][ML] Correctly set output col when StringIndexer.setOutputCols is used
52363d4 is described below

commit 52363d4d975d0e49212245193234447ed2d38152
Author: Sean Owen <sr...@gmail.com>
AuthorDate: Mon Feb 24 20:18:10 2020 -0600

    [SPARK-30939][ML] Correctly set output col when StringIndexer.setOutputCols is used
    
    ### What changes were proposed in this pull request?
    
    Set each supplied output column name, as intended, when StringIndexer transforms an input after setOutputCols is used.
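    
    A minimal usage sketch of the multi-column API this fix concerns (the DataFrame, column names, and local session below are illustrative assumptions, not part of this patch):
    
    ```scala
    import org.apache.spark.ml.feature.StringIndexer
    import org.apache.spark.sql.SparkSession
    
    val spark = SparkSession.builder().master("local[*]").appName("indexer-sketch").getOrCreate()
    import spark.implicits._
    
    // Two string columns to index in one pass.
    val df = Seq(("a", "x"), ("b", "y"), ("a", "y")).toDF("cat1", "cat2")
    
    // Multi-column API: each supplied output name should appear in the result schema.
    val indexer = new StringIndexer()
      .setInputCols(Array("cat1", "cat2"))
      .setOutputCols(Array("cat1Idx", "cat2Idx"))
    
    val indexed = indexer.fit(df).transform(df)
    indexed.printSchema()  // cat1Idx and cat2Idx should appear as DoubleType columns
    ```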
    
    ### Why are the changes needed?
    
    Otherwise the output column names are wrong: in the multi-column case each output schema field was named from the single-column outputCol param rather than the supplied name, so downstream pipeline components that reference the supplied names fail.
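    
    As a hypothetical illustration of the downstream failure mode (the VectorAssembler stage and column names here are assumptions, not from this patch), a later stage that looks up the supplied output names only passes schema validation if StringIndexer reports those names:
    
    ```scala
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
    
    val indexer = new StringIndexer()
      .setInputCols(Array("cat1", "cat2"))
      .setOutputCols(Array("cat1Idx", "cat2Idx"))
    
    // This stage references the supplied output names. Before the fix below, the
    // indexer's transformSchema named its output fields from the single-column
    // outputCol param instead of the supplied names, so a stage like this could
    // not resolve its inputs.
    val assembler = new VectorAssembler()
      .setInputCols(Array("cat1Idx", "cat2Idx"))
      .setOutputCol("features")
    
    val pipeline = new Pipeline().setStages(Array(indexer, assembler))
    // pipeline.fit(df) validates schemas across all stages before any training runs.
    ```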
    
    ### Does this PR introduce any user-facing change?
    
    Yes, in the sense that it fixes incorrect behavior; otherwise no.
    
    ### How was this patch tested?
    
    Existing tests plus new direct tests of the schema.
    
    Closes #27684 from srowen/SPARK-30939.
    
    Authored-by: Sean Owen <sr...@gmail.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
    (cherry picked from commit cc8d356e4faf1b581d330119c42ef73df466e828)
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../org/apache/spark/ml/feature/StringIndexer.scala     |  2 +-
 .../apache/spark/ml/feature/StringIndexerSuite.scala    | 17 +++++++++++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 9f9f097..be32f44 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -107,7 +107,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
         s"but got $inputDataType.")
     require(schema.fields.forall(_.name != outputColName),
       s"Output column $outputColName already exists.")
-    NominalAttribute.defaultAttr.withName($(outputCol)).toStructField()
+    NominalAttribute.defaultAttr.withName(outputColName).toStructField()
   }
 
   /** Validates and transforms the input schema. */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index b5ce2ba..9481408 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -96,6 +96,23 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("StringIndexer.transformSchema") {
+    val idxToStr = new StringIndexer().setInputCol("input").setOutputCol("output")
+    val inSchema = StructType(Seq(StructField("input", StringType)))
+    val outSchema = idxToStr.transformSchema(inSchema)
+    assert(outSchema("output").dataType === DoubleType)
+  }
+
+  test("StringIndexer.transformSchema multi col") {
+    val idxToStr = new StringIndexer().setInputCols(Array("input", "input2")).
+      setOutputCols(Array("output", "output2"))
+    val inSchema = StructType(Seq(StructField("input", StringType),
+      StructField("input2", StringType)))
+    val outSchema = idxToStr.transformSchema(inSchema)
+    assert(outSchema("output").dataType === DoubleType)
+    assert(outSchema("output2").dataType === DoubleType)
+  }
+
   test("StringIndexerUnseen") {
     val data = Seq((0, "a"), (1, "b"), (4, "b"))
     val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))

