You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by ya...@apache.org on 2017/01/26 07:14:40 UTC
incubator-hivemall git commit: Close #25: [HIVEMALL-34] Fix a bug where
mllib vectors were wrongly used in some functions
Repository: incubator-hivemall
Updated Branches:
refs/heads/master 468849441 -> b90999664
Close #25: [HIVEMALL-34] Fix a bug where mllib vectors were wrongly used in some functions
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/b9099966
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/b9099966
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/b9099966
Branch: refs/heads/master
Commit: b90999664b9c39edb291cabcb5baa12a89069fca
Parents: 4688494
Author: Takeshi YAMAMURO <li...@gmail.com>
Authored: Thu Jan 26 16:13:42 2017 +0900
Committer: Takeshi YAMAMURO <li...@gmail.com>
Committed: Thu Jan 26 16:13:42 2017 +0900
----------------------------------------------------------------------
.../org/apache/spark/sql/hive/HivemallOps.scala | 6 ++--
.../apache/spark/sql/hive/HivemallUtils.scala | 37 ++++++++++----------
.../apache/spark/sql/hive/HiveUdfSuite.scala | 22 +++++-------
3 files changed, 29 insertions(+), 36 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b9099966/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 9bde84f..f233a2a 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.HivemallFeature
-import org.apache.spark.ml.linalg.{DenseVector => SDV, SparseVector => SSV, VectorUDT}
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, VectorUDT}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -764,12 +764,12 @@ final class HivemallOps(df: DataFrame) extends Logging {
StructField("feature", StringType) :: StructField("weight", DoubleType) :: Nil)
val explodeFunc: Row => TraversableOnce[InternalRow] = (row: Row) => {
row.get(0) match {
- case dv: SDV =>
+ case dv: DenseVector =>
dv.values.zipWithIndex.map {
case (value, index) =>
InternalRow(UTF8String.fromString(s"$index"), value)
}
- case sv: SSV =>
+ case sv: SparseVector =>
sv.values.zip(sv.indices).map {
case (value, index) =>
InternalRow(UTF8String.fromString(s"$index"), value)
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b9099966/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala
index 6924347..b7b7071 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala
@@ -18,8 +18,7 @@
*/
package org.apache.spark.sql.hive
-import org.apache.spark.ml.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV}
-import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.expressions.UserDefinedFunction
@@ -46,17 +45,12 @@ object HivemallUtils {
@inline implicit def toStringArrayLiteral(i: Seq[String]): Column =
Column(Literal.create(i, ArrayType(StringType)))
- /**
- * Transforms `org.apache.spark.ml.linalg.Vector` into Hivemall features.
- */
- def to_hivemall_features: UserDefinedFunction = udf(_to_hivemall_features)
-
- private[hive] def _to_hivemall_features = (v: SV) => v match {
- case dv: SDV =>
+ def to_hivemall_features_func(): Vector => Array[String] = {
+ case dv: DenseVector =>
dv.values.zipWithIndex.map {
case (value, index) => s"$index:$value"
}
- case sv: SSV =>
+ case sv: SparseVector =>
sv.values.zip(sv.indices).map {
case (value, index) => s"$index:$value"
}
@@ -64,21 +58,15 @@ object HivemallUtils {
throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}")
}
- /**
- * Returns a new vector with `1.0` (bias) appended to the input vector.
- * @group ftvec
- */
- def append_bias: UserDefinedFunction = udf(_append_bias)
-
- private[hive] def _append_bias = (v: SV) => v match {
- case dv: SDV =>
+ def append_bias_func(): Vector => Vector = {
+ case dv: DenseVector =>
val inputValues = dv.values
val inputLength = inputValues.length
val outputValues = Array.ofDim[Double](inputLength + 1)
System.arraycopy(inputValues, 0, outputValues, 0, inputLength)
outputValues(inputLength) = 1.0
Vectors.dense(outputValues)
- case sv: SSV =>
+ case sv: SparseVector =>
val inputValues = sv.values
val inputIndices = sv.indices
val inputValuesLength = inputValues.length
@@ -95,6 +83,17 @@ object HivemallUtils {
}
/**
+ * Transforms `org.apache.spark.ml.linalg.Vector` into Hivemall features.
+ */
+ def to_hivemall_features: UserDefinedFunction = udf(to_hivemall_features_func)
+
+ /**
+ * Returns a new vector with `1.0` (bias) appended to the input vector.
+ * @group ftvec
+ */
+ def append_bias: UserDefinedFunction = udf(append_bias_func)
+
+ /**
* Make up a function object from a Hivemall model.
*/
def funcModel(df: DataFrame, dense: Boolean = false, dims: Int = maxDims): UserDefinedFunction = {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b9099966/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
index f8622c6..d53ef73 100644
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
@@ -118,7 +118,7 @@ final class HiveUdfWithVectorSuite extends VectorQueryTest {
test("to_hivemall_features") {
mllibTrainDf.createOrReplaceTempView("mllibTrainDf")
- hiveContext.udf.register("to_hivemall_features", _to_hivemall_features)
+ hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func)
checkAnswer(
sql(
s"""
@@ -134,16 +134,10 @@ final class HiveUdfWithVectorSuite extends VectorQueryTest {
)
}
- ignore("append_bias") {
+ test("append_bias") {
mllibTrainDf.createOrReplaceTempView("mllibTrainDf")
- hiveContext.udf.register("append_bias", _append_bias)
- hiveContext.udf.register("to_hivemall_features", _to_hivemall_features)
- /**
- * TODO: This test throws an exception:
- * Failed to analyze query: org.apache.spark.sql.AnalysisException: cannot resolve
- * 'UDF(UDF(features))' due to data type mismatch: argument 1 requires vector type,
- * however, 'UDF(features)' is of vector type.; line 2 pos 8
- */
+ hiveContext.udf.register("append_bias", append_bias_func)
+ hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func)
checkAnswer(
sql(
s"""
@@ -151,10 +145,10 @@ final class HiveUdfWithVectorSuite extends VectorQueryTest {
| FROM mllibTrainDF
""".stripMargin),
Seq(
- Row(Seq("0:1.0", "0:1.0", "2:2.0", "4:3.0")),
- Row(Seq("0:1.0", "0:1.0", "3:1.5", "4:2.1", "6:1.2")),
- Row(Seq("0:1.0", "0:1.1", "3:1.0", "4:2.3", "6:1.0")),
- Row(Seq("0:1.0", "1:4.0", "3:5.0", "5:6.0"))
+ Row(Seq("0:1.0", "2:2.0", "4:3.0", "7:1.0")),
+ Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2", "7:1.0")),
+ Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0", "7:1.0")),
+ Row(Seq("1:4.0", "3:5.0", "5:6.0", "7:1.0"))
)
)
}