You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/05/29 01:32:55 UTC

spark git commit: [SPARK-7198] [MLLIB] VectorAssembler should output ML attributes

Repository: spark
Updated Branches:
  refs/heads/master 3e312a5ed -> 7859ab659


[SPARK-7198] [MLLIB] VectorAssembler should output ML attributes

`VectorAssembler` should carry over ML attributes from its input columns to the assembled output column. Where an input has no ML attributes, we assume its values are numeric. This PR handles the following cases:

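1. DoubleType with ML attribute: carry over the attribute, named after the input column
2. DoubleType without ML attribute: assume a numeric attribute
3. Other scalar type (numeric or boolean): assume a numeric attribute
4. VectorType with all ML attributes defined: carry over, prefixing each named attribute with the input column name (`<column>_<attribute>`)
5. VectorType with only the number of attributes in its metadata: assume all attributes are numeric
6. VectorType without ML attributes: check the first row to get the number of attributes (the vector size) and assume all are numeric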

jkbradley

Author: Xiangrui Meng <me...@databricks.com>

Closes #6452 from mengxr/SPARK-7198 and squashes the following commits:

a9d2469 [Xiangrui Meng] add space
facdb1f [Xiangrui Meng] VectorAssembler should output ML attributes


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7859ab65
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7859ab65
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7859ab65

Branch: refs/heads/master
Commit: 7859ab659eecbcf2d8b9a274a4e9e4f5186a528c
Parents: 3e312a5
Author: Xiangrui Meng <me...@databricks.com>
Authored: Thu May 28 16:32:51 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu May 28 16:32:51 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/feature/VectorAssembler.scala      | 51 ++++++++++++++++++--
 .../spark/ml/feature/VectorAssemblerSuite.scala | 37 ++++++++++++++
 2 files changed, 83 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7859ab65/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 514ffb0..229ee27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuilder
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
@@ -37,7 +38,7 @@ import org.apache.spark.sql.types._
 class VectorAssembler(override val uid: String)
   extends Transformer with HasInputCols with HasOutputCol {
 
-  def this() = this(Identifiable.randomUID("va"))
+  def this() = this(Identifiable.randomUID("vecAssembler"))
 
   /** @group setParam */
   def setInputCols(value: Array[String]): this.type = set(inputCols, value)
@@ -46,19 +47,59 @@ class VectorAssembler(override val uid: String)
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transform(dataset: DataFrame): DataFrame = {
+    // Schema transformation.
+    val schema = dataset.schema
+    lazy val first = dataset.first()
+    val attrs = $(inputCols).flatMap { c =>
+      val field = schema(c)
+      val index = schema.fieldIndex(c)
+      field.dataType match {
+        case DoubleType =>
+          val attr = Attribute.fromStructField(field)
+          // If the input column doesn't have ML attribute, assume numeric.
+          if (attr == UnresolvedAttribute) {
+            Some(NumericAttribute.defaultAttr.withName(c))
+          } else {
+            Some(attr.withName(c))
+          }
+        case _: NumericType | BooleanType =>
+          // If the input column type is a compatible scalar type, assume numeric.
+          Some(NumericAttribute.defaultAttr.withName(c))
+        case _: VectorUDT =>
+          val group = AttributeGroup.fromStructField(field)
+          if (group.attributes.isDefined) {
+            // If attributes are defined, copy them with updated names.
+            group.attributes.get.map { attr =>
+              if (attr.name.isDefined) {
+                // TODO: Define a rigorous naming scheme.
+                attr.withName(c + "_" + attr.name.get)
+              } else {
+                attr
+              }
+            }
+          } else {
+            // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
+            // from metadata, check the first row.
+            val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
+            Array.fill(numAttrs)(NumericAttribute.defaultAttr)
+          }
+      }
+    }
+    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
+
+    // Data transformation.
     val assembleFunc = udf { r: Row =>
       VectorAssembler.assemble(r.toSeq: _*)
     }
-    val schema = dataset.schema
-    val inputColNames = $(inputCols)
-    val args = inputColNames.map { c =>
+    val args = $(inputCols).map { c =>
       schema(c).dataType match {
         case DoubleType => dataset(c)
         case _: VectorUDT => dataset(c)
         case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
       }
     }
-    dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol)))
+
+    dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata))
   }
 
   override def transformSchema(schema: StructType): StructType = {

http://git-wip-us.apache.org/repos/asf/spark/blob/7859ab65/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index d0cd62c..43534e8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.ml.feature
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkException
+import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions.col
 
 class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -61,4 +63,39 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
         assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0)))
     }
   }
+
+  test("ML attributes") {
+    val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
+    val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)
+    val user = new AttributeGroup("user", Array(
+      NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
+      NumericAttribute.defaultAttr.withName("salary")))
+    val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
+    val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
+      .select(
+        col("browser").as("browser", browser.toMetadata()),
+        col("hour").as("hour", hour.toMetadata()),
+        col("count"), // "count" is an integer column without ML attribute
+        col("user").as("user", user.toMetadata()),
+        col("ad")) // "ad" is a vector column without ML attribute
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("browser", "hour", "count", "user", "ad"))
+      .setOutputCol("features")
+    val output = assembler.transform(df)
+    val schema = output.schema
+    val features = AttributeGroup.fromStructField(schema("features"))
+    assert(features.size === 7)
+    val browserOut = features.getAttr(0)
+    assert(browserOut === browser.withIndex(0).withName("browser"))
+    val hourOut = features.getAttr(1)
+    assert(hourOut === hour.withIndex(1).withName("hour"))
+    val countOut = features.getAttr(2)
+    assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2))
+    val userGenderOut = features.getAttr(3)
+    assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
+    val userSalaryOut = features.getAttr(4)
+    assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
+    assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
+    assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org