You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/02/02 20:16:37 UTC
spark git commit: [SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication
Repository: spark
Updated Branches:
refs/heads/master 358300c79 -> b1835d727
[SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication
Fixes the problem and verifies the fix with a new test in the test suite.
Also adds an optional parameter, nullable (Boolean), to SchemaUtils.appendColumn
and deduplicates the SchemaUtils.appendColumn functions.
Author: Grzegorz Chilkiewicz <gr...@codilime.com>
Closes #10741 from grzegorz-chilkiewicz/master.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b1835d72
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b1835d72
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b1835d72
Branch: refs/heads/master
Commit: b1835d727234fdff42aa8cadd17ddcf43b0bed15
Parents: 358300c
Author: Grzegorz Chilkiewicz <gr...@codilime.com>
Authored: Tue Feb 2 11:16:24 2016 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Feb 2 11:16:24 2016 -0800
----------------------------------------------------------------------
.../apache/spark/ml/feature/StopWordsRemover.scala | 4 +---
.../scala/org/apache/spark/ml/util/SchemaUtils.scala | 8 +++-----
.../spark/ml/feature/StopWordsRemoverSuite.scala | 15 +++++++++++++++
3 files changed, 19 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/b1835d72/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index b93c9ed..e53ef30 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -149,9 +149,7 @@ class StopWordsRemover(override val uid: String)
val inputType = schema($(inputCol)).dataType
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be ArrayType(StringType) but got $inputType.")
- val outputFields = schema.fields :+
- StructField($(outputCol), inputType, schema($(inputCol)).nullable)
- StructType(outputFields)
+ SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
}
override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
http://git-wip-us.apache.org/repos/asf/spark/blob/b1835d72/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index e71dd9e..76021ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -71,12 +71,10 @@ private[spark] object SchemaUtils {
def appendColumn(
schema: StructType,
colName: String,
- dataType: DataType): StructType = {
+ dataType: DataType,
+ nullable: Boolean = false): StructType = {
if (colName.isEmpty) return schema
- val fieldNames = schema.fieldNames
- require(!fieldNames.contains(colName), s"Column $colName already exists.")
- val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
- StructType(outputFields)
+ appendColumn(schema, StructField(colName, dataType, nullable))
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/b1835d72/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index fb217e0..a5b24c1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -89,4 +89,19 @@ class StopWordsRemoverSuite
.setCaseSensitive(true)
testDefaultReadWrite(t)
}
+
+ test("StopWordsRemover output column already exists") {
+ val outputCol = "expected"
+ val remover = new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol(outputCol)
+ val dataSet = sqlContext.createDataFrame(Seq(
+ (Seq("The", "the", "swift"), Seq("swift"))
+ )).toDF("raw", outputCol)
+
+ val thrown = intercept[IllegalArgumentException] {
+ testStopWordsRemover(remover, dataSet)
+ }
+ assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.")
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org