You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that was not preserved in this plain-text copy.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/04/30 01:35:20 UTC

spark git commit: [SPARK-7259] [ML] VectorIndexer: do not copy non-ML metadata to output column

Repository: spark
Updated Branches:
  refs/heads/master f8cbb0a4b -> b1ef6a60f


[SPARK-7259] [ML] VectorIndexer: do not copy non-ML metadata to output column

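Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column.  Removed ml.util.TestingUtils since VectorIndexer was its only user.  Also moved the vector-length check from an eager check of the first row in transform() into the per-row transform function, so a length mismatch now surfaces as a SparkException when the transformed data is evaluated (rather than as an IllegalArgumentException from transform()).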

CC: mengxr

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #5789 from jkbradley/vector-indexer-metadata and squashes the following commits:

b28e159 [Joseph K. Bradley] Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column.  Removed ml.util.TestingUtils since VectorIndexer was the only use.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b1ef6a60
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b1ef6a60
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b1ef6a60

Branch: refs/heads/master
Commit: b1ef6a60ff6ea2adb43c6544e5311c11f4364f64
Parents: f8cbb0a
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Wed Apr 29 16:35:17 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed Apr 29 16:35:17 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/VectorIndexer.scala | 69 ++++++++++----------
 .../spark/ml/feature/VectorIndexerSuite.scala   |  7 +-
 .../org/apache/spark/ml/util/TestingUtils.scala | 60 -----------------
 3 files changed, 37 insertions(+), 99 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b1ef6a60/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 452faa0..1e5ffd1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -233,6 +233,7 @@ private object VectorIndexer {
  *  - Continuous features (columns) are left unchanged.
  * This also appends metadata to the output column, marking features as Numeric (continuous),
  * Nominal (categorical), or Binary (either continuous or categorical).
+ * Non-ML metadata is not carried over from the input to the output column.
  *
  * This maintains vector sparsity.
  *
@@ -283,34 +284,40 @@ class VectorIndexerModel private[ml] (
 
   // TODO: Check more carefully about whether this whole class will be included in a closure.
 
+  /** Per-vector transform function */
   private val transformFunc: Vector => Vector = {
-    val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted
+    val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted
     val localVectorMap = categoryMaps
-    val f: Vector => Vector = {
-      case dv: DenseVector =>
-        val tmpv = dv.copy
-        localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
-          tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
-        }
-        tmpv
-      case sv: SparseVector =>
-        // We use the fact that categorical value 0 is always mapped to index 0.
-        val tmpv = sv.copy
-        var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices
-        var k = 0 // index into non-zero elements of sparse vector
-        while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) {
-          val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx)
-          if (featureIndex < tmpv.indices(k)) {
-            catFeatureIdx += 1
-          } else if (featureIndex > tmpv.indices(k)) {
-            k += 1
-          } else {
-            tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
-            catFeatureIdx += 1
-            k += 1
+    val localNumFeatures = numFeatures
+    val f: Vector => Vector = { (v: Vector) =>
+      assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" +
+        s" $numFeatures but found length ${v.size}")
+      v match {
+        case dv: DenseVector =>
+          val tmpv = dv.copy
+          localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
+            tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
           }
-        }
-        tmpv
+          tmpv
+        case sv: SparseVector =>
+          // We use the fact that categorical value 0 is always mapped to index 0.
+          val tmpv = sv.copy
+          var catFeatureIdx = 0 // index into sortedCatFeatureIndices
+          var k = 0 // index into non-zero elements of sparse vector
+          while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) {
+            val featureIndex = sortedCatFeatureIndices(catFeatureIdx)
+            if (featureIndex < tmpv.indices(k)) {
+              catFeatureIdx += 1
+            } else if (featureIndex > tmpv.indices(k)) {
+              k += 1
+            } else {
+              tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
+              catFeatureIdx += 1
+              k += 1
+            }
+          }
+          tmpv
+      }
     }
     f
   }
@@ -326,13 +333,6 @@ class VectorIndexerModel private[ml] (
     val map = extractParamMap(paramMap)
     val newField = prepOutputField(dataset.schema, map)
     val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
-    // For now, just check the first row of inputCol for vector length.
-    val firstRow = dataset.select(map(inputCol)).take(1)
-    if (firstRow.length != 0) {
-      val actualNumFeatures = firstRow(0).getAs[Vector](0).size
-      require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" +
-        s" $numFeatures but found length $actualNumFeatures")
-    }
     dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
   }
 
@@ -345,6 +345,7 @@ class VectorIndexerModel private[ml] (
       s"VectorIndexerModel requires output column parameter: $outputCol")
     SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
 
+    // If the input metadata specifies numFeatures, compare with expected numFeatures.
     val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
     val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
       Some(origAttrGroup.attributes.get.length)
@@ -364,7 +365,7 @@ class VectorIndexerModel private[ml] (
    * Prepare the output column field, including per-feature metadata.
    * @param schema  Input schema
    * @param map  Parameter map (with this class' embedded parameter map folded in)
-   * @return  Output column field
+   * @return  Output column field.  This field does not contain non-ML metadata.
    */
   private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
     val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
@@ -391,6 +392,6 @@ class VectorIndexerModel private[ml] (
       partialFeatureAttributes
     }
     val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
-    newAttributeGroup.toStructField(schema(map(inputCol)).metadata)
+    newAttributeGroup.toStructField()
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b1ef6a60/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 1b261b2..38dc83b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -23,7 +23,6 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.SparkException
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.util.TestingUtils
 import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
@@ -111,8 +110,8 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
     val model = vectorIndexer.fit(densePoints1) // vectors of length 3
     model.transform(densePoints1) // should work
     model.transform(sparsePoints1) // should work
-    intercept[IllegalArgumentException] {
-      model.transform(densePoints2)
+    intercept[SparkException] {
+      model.transform(densePoints2).collect()
       println("Did not throw error when fit, transform were called on vectors of different lengths")
     }
     intercept[SparkException] {
@@ -245,8 +244,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
           // TODO: Once input features marked as categorical are handled correctly, check that here.
       }
     }
-    // Check that non-ML metadata are preserved.
-    TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed")
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b1ef6a60/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
deleted file mode 100644
index c44cb61..0000000
--- a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.util
-
-import org.apache.spark.ml.Transformer
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.MetadataBuilder
-import org.scalatest.FunSuite
-
-private[ml] object TestingUtils extends FunSuite {
-
-  /**
-   * Test whether unrelated metadata are preserved for this transformer.
-   * This attaches extra metadata to a column, transforms the column, and check to ensure the
-   * extra metadata have not changed.
-   * @param data  Input dataset
-   * @param transformer  Transformer to test
-   * @param inputCol  Unique input column for Transformer.  This must be the ONLY input column.
-   * @param outputCol  Output column to test for metadata presence.
-   */
-  def testPreserveMetadata(
-      data: DataFrame,
-      transformer: Transformer,
-      inputCol: String,
-      outputCol: String): Unit = {
-    // Create some fake metadata
-    val origMetadata = data.schema(inputCol).metadata
-    val metaKey = "__testPreserveMetadata__fake_key"
-    val metaValue = 12345
-    assert(!origMetadata.contains(metaKey),
-      s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey")
-    val newMetadata =
-      new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build()
-    // Add metadata to the inputCol
-    val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
-    // Transform, and ensure extra metadata was not affected
-    val transformed = transformer.transform(withMetadata)
-    val transMetadata = transformed.schema(outputCol).metadata
-    assert(transMetadata.contains(metaKey),
-      "Unit test with testPreserveMetadata failed; extra metadata key was not present.")
-    assert(transMetadata.getLong(metaKey) === metaValue,
-      "Unit test with testPreserveMetadata failed; extra metadata value was wrong." +
-        s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}")
-  }
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org