You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/06/22 03:25:42 UTC

spark git commit: [SPARK-7426] [MLLIB] [ML] Updated Attribute.fromStructField to allow any NumericType.

Repository: spark
Updated Branches:
  refs/heads/master a1894422a -> 47c1d5629


[SPARK-7426] [MLLIB] [ML] Updated Attribute.fromStructField to allow any NumericType.

Updated `Attribute.fromStructField` to allow any `NumericType`, rather than just `DoubleType`, and added unit tests for some of the other `NumericType`s (`LongType` and `DecimalType`).

Author: Mike Dusenberry <du...@gmail.com>

Closes #6540 from dusenberrymw/SPARK-7426_AttributeFactory.fromStructField_Should_Allow_NumericTypes and squashes the following commits:

87fecb3 [Mike Dusenberry] Updated Attribute.fromStructField to allow any NumericType, rather than just DoubleType, and added unit tests for a few of the other NumericTypes.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/47c1d562
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/47c1d562
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/47c1d562

Branch: refs/heads/master
Commit: 47c1d5629373566df9d12fdc4ceb22f38b869482
Parents: a189442
Author: Mike Dusenberry <du...@gmail.com>
Authored: Sun Jun 21 18:25:36 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Sun Jun 21 18:25:36 2015 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/ml/attribute/attributes.scala   | 4 ++--
 .../scala/org/apache/spark/ml/attribute/AttributeSuite.scala    | 5 +++++
 2 files changed, 7 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/47c1d562/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index ce43a45..e479f16 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute
 import scala.annotation.varargs
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField}
+import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField}
 
 /**
  * :: DeveloperApi ::
@@ -127,7 +127,7 @@ private[attribute] trait AttributeFactory {
    * Creates an [[Attribute]] from a [[StructField]] instance.
    */
   def fromStructField(field: StructField): Attribute = {
-    require(field.dataType == DoubleType)
+    require(field.dataType.isInstanceOf[NumericType])
     val metadata = field.metadata
     val mlAttr = AttributeKeys.ML_ATTR
     if (metadata.contains(mlAttr)) {

http://git-wip-us.apache.org/repos/asf/spark/blob/47c1d562/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 72b575d..c5fd2f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -215,5 +215,10 @@ class AttributeSuite extends SparkFunSuite {
     assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
     val fldWithMeta = new StructField("x", DoubleType, false, metadata)
     assert(Attribute.fromStructField(fldWithMeta).isNumeric)
+    // Attribute.fromStructField should accept any NumericType, not just DoubleType
+    val longFldWithMeta = new StructField("x", LongType, false, metadata)
+    assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
+    val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
+    assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org