You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2020/02/25 02:20:44 UTC
[spark] branch branch-3.0 updated: [SPARK-30939][ML] Correctly set
output col when StringIndexer.setOutputCols is used
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 52363d4 [SPARK-30939][ML] Correctly set output col when StringIndexer.setOutputCols is used
52363d4 is described below
commit 52363d4d975d0e49212245193234447ed2d38152
Author: Sean Owen <sr...@gmail.com>
AuthorDate: Mon Feb 24 20:18:10 2020 -0600
[SPARK-30939][ML] Correctly set output col when StringIndexer.setOutputCols is used
### What changes were proposed in this pull request?
Set the supplied output col name as intended when StringIndexer transforms an input after setOutputCols is used.
### Why are the changes needed?
The output col names are wrong otherwise and downstream pipeline components fail.
### Does this PR introduce any user-facing change?
Yes, in the sense that it fixes incorrect behavior; otherwise, no.
### How was this patch tested?
Existing tests plus new direct tests of the schema.
Closes #27684 from srowen/SPARK-30939.
Authored-by: Sean Owen <sr...@gmail.com>
Signed-off-by: Sean Owen <sr...@gmail.com>
(cherry picked from commit cc8d356e4faf1b581d330119c42ef73df466e828)
Signed-off-by: Sean Owen <sr...@gmail.com>
---
.../org/apache/spark/ml/feature/StringIndexer.scala | 2 +-
.../apache/spark/ml/feature/StringIndexerSuite.scala | 17 +++++++++++++++++
2 files changed, 18 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 9f9f097..be32f44 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -107,7 +107,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
s"but got $inputDataType.")
require(schema.fields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
- NominalAttribute.defaultAttr.withName($(outputCol)).toStructField()
+ NominalAttribute.defaultAttr.withName(outputColName).toStructField()
}
/** Validates and transforms the input schema. */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index b5ce2ba..9481408 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -96,6 +96,23 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("StringIndexer.transformSchema") {
+   val indexer = new StringIndexer().setInputCol("input").setOutputCol("output")
+   val inSchema = StructType(Seq(StructField("input", StringType)))
+   val outSchema = indexer.transformSchema(inSchema)
+   assert(outSchema("output").dataType === DoubleType) // looked up by the requested output name
+ }
+
+ test("StringIndexer.transformSchema multi col") {
+   // Each name passed to setOutputCols is looked up in the transformed schema and checked.
+   val indexer = new StringIndexer()
+     .setInputCols(Array("input", "input2"))
+     .setOutputCols(Array("output", "output2"))
+   val inputFields = Seq("input", "input2").map(StructField(_, StringType))
+   val transformed = indexer.transformSchema(StructType(inputFields))
+   Seq("output", "output2").foreach(c => assert(transformed(c).dataType === DoubleType))
+ }
+
test("StringIndexerUnseen") {
val data = Seq((0, "a"), (1, "b"), (4, "b"))
val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org