You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/06/04 00:16:30 UTC

spark git commit: [SPARK-8051] [MLLIB] make StringIndexerModel silent if input column does not exist

Repository: spark
Updated Branches:
  refs/heads/master d3e026f87 -> 26c9d7a0f


[SPARK-8051] [MLLIB] make StringIndexerModel silent if input column does not exist

This is just a workaround for a bigger problem. Some pipeline stages, e.g. `StringIndexerModel`, may not be effective during prediction, and they should not complain about missing required columns. jkbradley

Author: Xiangrui Meng <me...@databricks.com>

Closes #6595 from mengxr/SPARK-8051 and squashes the following commits:

b6a36b9 [Xiangrui Meng] add doc
f143fd4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-8051
8ee7c7e [Xiangrui Meng] use SparkFunSuite
e112394 [Xiangrui Meng] make StringIndexerModel silent if input column does not exist


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/26c9d7a0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/26c9d7a0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/26c9d7a0

Branch: refs/heads/master
Commit: 26c9d7a0f975009e22ec91e5c0b5cfcada79b35e
Parents: d3e026f
Author: Xiangrui Meng <me...@databricks.com>
Authored: Wed Jun 3 15:16:24 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed Jun 3 15:16:24 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/StringIndexer.scala | 16 +++++++++++++++-
 .../spark/ml/feature/StringIndexerSuite.scala       |  8 ++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/26c9d7a0/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index a2dc8a8..f4e2507 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -88,6 +88,9 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
 /**
  * :: Experimental ::
  * Model fitted by [[StringIndexer]].
+ * NOTE: During transformation, if the input column does not exist,
+ * [[StringIndexerModel.transform]] would return the input dataset unmodified.
+ * This is a temporary fix for the case when target labels do not exist during prediction.
  */
 @Experimental
 class StringIndexerModel private[ml] (
@@ -112,6 +115,12 @@ class StringIndexerModel private[ml] (
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transform(dataset: DataFrame): DataFrame = {
+    if (!dataset.schema.fieldNames.contains($(inputCol))) {
+      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
+        "Skip StringIndexerModel.")
+      return dataset
+    }
+
     val indexer = udf { label: String =>
       if (labelToIndex.contains(label)) {
         labelToIndex(label)
@@ -128,6 +137,11 @@ class StringIndexerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    if (schema.fieldNames.contains($(inputCol))) {
+      validateAndTransformSchema(schema)
+    } else {
+      // If the input column does not exist during transformation, we skip StringIndexerModel.
+      schema
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/26c9d7a0/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index cbf1e8d..5f557e1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -60,4 +60,12 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
     assert(output === expected)
   }
+
+  test("StringIndexerModel should keep silent if the input column does not exist.") {
+    val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+    val df = sqlContext.range(0L, 10L)
+    assert(indexerModel.transform(df).eq(df))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org