Posted to commits@hivemall.apache.org by my...@apache.org on 2017/09/13 12:18:26 UTC
[2/4] incubator-hivemall git commit: Close #122:
[HIVEMALL-133][SPARK] Support spark-v2.2 in the hivemall-spark module
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala
new file mode 100644
index 0000000..c9d0ba0
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.{ArrayDeque, Locale, TimeZone}
+
+import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.streaming.MemoryPlan
+import org.apache.spark.sql.types.{Metadata, ObjectType}
+
+
+abstract class QueryTest extends PlanTest {
+
+ protected def spark: SparkSession
+
+ // Timezone is fixed to America/Los_Angeles for timezone-sensitive tests (timestamp_*)
+ TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
+ // Add Locale setting
+ Locale.setDefault(Locale.US)
+
+ /**
+ * Runs the plan and makes sure the answer contains all of the keywords.
+ */
+ def checkKeywordsExist(df: DataFrame, keywords: String*): Unit = {
+ val outputs = df.collect().map(_.mkString).mkString
+ for (key <- keywords) {
+ assert(outputs.contains(key), s"Failed for $df ($key doesn't exist in result)")
+ }
+ }
+
+ /**
+ * Runs the plan and makes sure the answer does NOT contain any of the keywords.
+ */
+ def checkKeywordsNotExist(df: DataFrame, keywords: String*): Unit = {
+ val outputs = df.collect().map(_.mkString).mkString
+ for (key <- keywords) {
+ assert(!outputs.contains(key), s"Failed for $df ($key existed in the result)")
+ }
+ }
+
+ /**
+ * Evaluates a dataset to make sure that the result of calling collect matches the given
+ * expected answer.
+ */
+ protected def checkDataset[T](
+ ds: => Dataset[T],
+ expectedAnswer: T*): Unit = {
+ val result = getResult(ds)
+
+ if (!compare(result.toSeq, expectedAnswer)) {
+ fail(
+ s"""
+ |Decoded objects do not match expected objects:
+ |expected: $expectedAnswer
+ |actual: ${result.toSeq}
+ |${ds.exprEnc.deserializer.treeString}
+ """.stripMargin)
+ }
+ }
+
+ /**
+ * Evaluates a dataset to make sure that the result of calling collect matches the given
+ * expected answer, after sorting.
+ */
+ protected def checkDatasetUnorderly[T : Ordering](
+ ds: => Dataset[T],
+ expectedAnswer: T*): Unit = {
+ val result = getResult(ds)
+
+ if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) {
+ fail(
+ s"""
+ |Decoded objects do not match expected objects:
+ |expected: $expectedAnswer
+ |actual: ${result.toSeq}
+ |${ds.exprEnc.deserializer.treeString}
+ """.stripMargin)
+ }
+ }
+
+ private def getResult[T](ds: => Dataset[T]): Array[T] = {
+ val analyzedDS = try ds catch {
+ case ae: AnalysisException =>
+ if (ae.plan.isDefined) {
+ fail(
+ s"""
+ |Failed to analyze query: $ae
+ |${ae.plan.get}
+ |
+ |${stackTraceToString(ae)}
+ """.stripMargin)
+ } else {
+ throw ae
+ }
+ }
+ assertEmptyMissingInput(analyzedDS)
+
+ try ds.collect() catch {
+ case e: Exception =>
+ fail(
+ s"""
+ |Exception collecting dataset as objects
+ |${ds.exprEnc}
+ |${ds.exprEnc.deserializer.treeString}
+ |${ds.queryExecution}
+ """.stripMargin, e)
+ }
+ }
+
+ private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
+ case (null, null) => true
+ case (null, _) => false
+ case (_, null) => false
+ case (a: Array[_], b: Array[_]) =>
+ a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)}
+ case (a: Iterable[_], b: Iterable[_]) =>
+ a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)}
+ case (a, b) => a == b
+ }
+
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ *
+ * @param df the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ val analyzedDF = try df catch {
+ case ae: AnalysisException =>
+ if (ae.plan.isDefined) {
+ fail(
+ s"""
+ |Failed to analyze query: $ae
+ |${ae.plan.get}
+ |
+ |${stackTraceToString(ae)}
+ |""".stripMargin)
+ } else {
+ throw ae
+ }
+ }
+
+ assertEmptyMissingInput(analyzedDF)
+
+ QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
+ checkAnswer(df, Seq(expectedAnswer))
+ }
+
+ protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = {
+ checkAnswer(df, expectedAnswer.collect())
+ }
+
+ /**
+ * Runs the plan and makes sure the answer is within absTol of the expected result.
+ *
+ * @param dataFrame the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param absTol the absolute tolerance between actual and expected answers.
+ */
+ protected def checkAggregatesWithTol(dataFrame: DataFrame,
+ expectedAnswer: Seq[Row],
+ absTol: Double): Unit = {
+ // TODO: catch exceptions in data frame execution
+ val actualAnswer = dataFrame.collect()
+ require(actualAnswer.length == expectedAnswer.length,
+ s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}")
+
+ actualAnswer.zip(expectedAnswer).foreach {
+ case (actualRow, expectedRow) =>
+ QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol)
+ }
+ }
+
+ protected def checkAggregatesWithTol(dataFrame: DataFrame,
+ expectedAnswer: Row,
+ absTol: Double): Unit = {
+ checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol)
+ }
+
+ /**
+ * Asserts that a given [[Dataset]] will be executed using the given number of cached results.
+ */
+ def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = {
+ val planWithCaching = query.queryExecution.withCachedData
+ val cachedData = planWithCaching collect {
+ case cached: InMemoryRelation => cached
+ }
+
+ assert(
+ cachedData.size == numCachedTables,
+ s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
+ planWithCaching)
+ }
+
+ /**
+ * Asserts that a given [[Dataset]] has no missing inputs in its analyzed, optimized, or physical plan.
+ */
+ def assertEmptyMissingInput(query: Dataset[_]): Unit = {
+ assert(query.queryExecution.analyzed.missingInput.isEmpty,
+ s"The analyzed logical plan has missing inputs:\n${query.queryExecution.analyzed}")
+ assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
+ s"The optimized logical plan has missing inputs:\n${query.queryExecution.optimizedPlan}")
+ assert(query.queryExecution.executedPlan.missingInput.isEmpty,
+ s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}")
+ }
+}
+
+object QueryTest {
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * If an exception is thrown during execution, or the contents of the DataFrame do not
+ * match the expected result, an error message is returned. Otherwise, [[None]] is
+ * returned.
+ *
+ * @param df the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice.
+ */
+ def checkAnswer(
+ df: DataFrame,
+ expectedAnswer: Seq[Row],
+ checkToRDD: Boolean = true): Option[String] = {
+ val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+ if (checkToRDD) {
+ df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791]
+ }
+
+ val sparkAnswer = try df.collect().toSeq catch {
+ case e: Exception =>
+ val errorMessage =
+ s"""
+ |Exception thrown while executing query:
+ |${df.queryExecution}
+ |== Exception ==
+ |$e
+ |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ sameRows(expectedAnswer, sparkAnswer, isSorted).map { results =>
+ s"""
+ |Results do not match for query:
+ |Timezone: ${TimeZone.getDefault}
+ |Timezone Env: ${sys.env.getOrElse("TZ", "")}
+ |
+ |${df.queryExecution}
+ |== Results ==
+ |$results
+ """.stripMargin
+ }
+ }
+
+
+ def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
+ // Converts data to types for which we can do equality comparisons using Scala collections.
+ // For the BigDecimal type, the Scala type has a better definition of equality (similar to
+ // Java's java.math.BigDecimal.compareTo).
+ // For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
+ // the equality test.
+ val converted: Seq[Row] = answer.map(prepareRow)
+ if (!isSorted) converted.sortBy(_.toString()) else converted
+ }
+
+ // We need to call prepareRow recursively to handle schemas with struct types.
+ def prepareRow(row: Row): Row = {
+ Row.fromSeq(row.toSeq.map {
+ case null => null
+ case d: java.math.BigDecimal => BigDecimal(d)
+ // Convert array to Seq for easy equality check.
+ case b: Array[_] => b.toSeq
+ case r: Row => prepareRow(r)
+ case o => o
+ })
+ }
+
+ def sameRows(
+ expectedAnswer: Seq[Row],
+ sparkAnswer: Seq[Row],
+ isSorted: Boolean = false): Option[String] = {
+ if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) {
+ val errorMessage =
+ s"""
+ |== Results ==
+ |${sideBySide(
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer, isSorted).map(_.toString()),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+ None
+ }
+
+ /**
+ * Runs the plan and makes sure the answer is within absTol of the expected result.
+ *
+ * @param actualAnswer the actual result in a [[Row]].
+ * @param expectedAnswer the expected result in a [[Row]].
+ * @param absTol the absolute tolerance between actual and expected answers.
+ */
+ protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = {
+ require(actualAnswer.length == expectedAnswer.length,
+ s"actual answer length ${actualAnswer.length} != " +
+ s"expected answer length ${expectedAnswer.length}")
+
+ // TODO: support other numeric types besides Double
+ // TODO: support struct types?
+ actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach {
+ case (actual: Double, expected: Double) =>
+ assert(math.abs(actual - expected) < absTol,
+ s"actual answer $actual not within $absTol of correct answer $expected")
+ case (actual, expected) =>
+ assert(actual == expected, s"$actual did not equal $expected")
+ }
+ }
+
+ def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = {
+ checkAnswer(df, expectedAnswer.asScala) match {
+ case Some(errorMessage) => errorMessage
+ case None => null
+ }
+ }
+}
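
The helpers above are intended to be mixed into concrete suites: `checkAnswer` compares
collected rows ignoring order unless the query itself sorts, and `checkDataset` compares
decoded objects. A minimal, hypothetical usage sketch (the suite name and the throwaway
local session are illustrative only; it assumes the same `org.apache.spark.sql` package and
imports as the file above):

    class ExampleQuerySuite extends QueryTest {
      // Assumption: build a small local session; real suites typically reuse a shared one.
      override protected def spark: SparkSession =
        SparkSession.builder().master("local[1]").appName("example").getOrCreate()

      test("checkAnswer ignores row order for unsorted queries") {
        val session = spark
        import session.implicits._
        val df = Seq((1, "a"), (2, "b")).toDF("id", "value")
        // The expected rows may be listed in any order because the query has no sort.
        checkAnswer(df, Row(2, "b") :: Row(1, "a") :: Nil)
      }
    }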
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
new file mode 100644
index 0000000..a4aeaa6
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Provides helper methods for comparing plans.
+ */
+abstract class PlanTest extends SparkFunSuite with PredicateHelper {
+
+ protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)
+
+ /**
+ * Since attribute references are given globally unique ids during analysis,
+ * we must normalize them to check if two different queries are identical.
+ */
+ protected def normalizeExprIds(plan: LogicalPlan) = {
+ plan transformAllExpressions {
+ case s: ScalarSubquery =>
+ s.copy(exprId = ExprId(0))
+ case e: Exists =>
+ e.copy(exprId = ExprId(0))
+ case l: ListQuery =>
+ l.copy(exprId = ExprId(0))
+ case a: AttributeReference =>
+ AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
+ case a: Alias =>
+ Alias(a.child, a.name)(exprId = ExprId(0))
+ case ae: AggregateExpression =>
+ ae.copy(resultId = ExprId(0))
+ }
+ }
+
+ /**
+ * Normalizes plans:
+ * - Filter: the conjuncts of a filter condition are normalized. For instance,
+ * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), and (expr 3 && (expr 1 && expr 2))
+ * will all be treated as equivalent.
+ * - Sample: the seed is replaced by 0L.
+ * - Join: join conditions are re-sorted by hashCode.
+ */
+ protected def normalizePlan(plan: LogicalPlan): LogicalPlan = {
+ plan transform {
+ case filter @ Filter(condition: Expression, child: LogicalPlan) =>
+ Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
+ .reduce(And), child)
+ case sample: Sample =>
+ sample.copy(seed = 0L)(true)
+ case join @ Join(left, right, joinType, condition) if condition.isDefined =>
+ val newCondition =
+ splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
+ .reduce(And)
+ Join(left, right, joinType, Some(newCondition))
+ }
+ }
+
+ /**
+ * Rewrites [[EqualTo]] and [[EqualNullSafe]] operators into a canonical operand order. The
+ * following cases will be equivalent:
+ * 1. (a = b), (b = a);
+ * 2. (a <=> b), (b <=> a).
+ */
+ private def rewriteEqual(condition: Expression): Expression = condition match {
+ case eq @ EqualTo(l: Expression, r: Expression) =>
+ Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
+ case eq @ EqualNullSafe(l: Expression, r: Expression) =>
+ Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
+ case _ => condition // Don't reorder.
+ }
+
+ /** Fails the test if the two plans do not match */
+ protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
+ val normalized1 = normalizePlan(normalizeExprIds(plan1))
+ val normalized2 = normalizePlan(normalizeExprIds(plan2))
+ if (normalized1 != normalized2) {
+ fail(
+ s"""
+ |== FAIL: Plans do not match ===
+ |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
+ """.stripMargin)
+ }
+ }
+
+ /** Fails the test if the two expressions do not match */
+ protected def compareExpressions(e1: Expression, e2: Expression): Unit = {
+ comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation))
+ }
+
+ /** Fails the test if the join order in the two plans does not match */
+ protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) {
+ val normalized1 = normalizePlan(normalizeExprIds(plan1))
+ val normalized2 = normalizePlan(normalizeExprIds(plan2))
+ if (!sameJoinPlan(normalized1, normalized2)) {
+ fail(
+ s"""
+ |== FAIL: Plans do not match ===
+ |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
+ """.stripMargin)
+ }
+ }
+
+ /** Considers symmetry for joins when comparing plans. */
+ private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
+ (plan1, plan2) match {
+ case (j1: Join, j2: Join) =>
+ (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
+ (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
+ case (p1: Project, p2: Project) =>
+ p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child)
+ case _ =>
+ plan1 == plan2
+ }
+ }
+}
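
A quick way to see what normalizePlan buys: two Filter plans whose conjuncts are written in
different orders compare as equal. A hypothetical sketch (attribute and relation names are
illustrative; it assumes the catalyst expression/plan imports used above plus
`org.apache.spark.sql.types.IntegerType`):

    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val relation = LocalRelation(a, b)

    // The same predicate, with its conjuncts written in different orders.
    val plan1 = Filter(And(GreaterThan(a, Literal(1)), LessThan(b, Literal(5))), relation)
    val plan2 = Filter(And(LessThan(b, Literal(5)), GreaterThan(a, Literal(1))), relation)

    // Passes: both conditions are split into conjuncts and re-sorted by hashCode first.
    comparePlans(plan1, plan2)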
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala
new file mode 100644
index 0000000..8283503
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.Benchmark
+
+/**
+ * Common base trait for micro benchmarks that are supposed to run standalone (i.e. not together
+ * with other test suites).
+ */
+private[sql] trait BenchmarkBase extends SparkFunSuite {
+
+ lazy val sparkSession = SparkSession.builder
+ .master("local[1]")
+ .appName("microbenchmark")
+ .config("spark.sql.shuffle.partitions", 1)
+ .config("spark.sql.autoBroadcastJoinThreshold", 1)
+ .getOrCreate()
+
+ /** Runs function `f` with whole stage codegen on and off. */
+ def runBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = {
+ val benchmark = new Benchmark(name, cardinality)
+
+ benchmark.addCase(s"$name wholestage off", numIters = 2) { iter =>
+ sparkSession.conf.set("spark.sql.codegen.wholeStage", value = false)
+ f
+ }
+
+ benchmark.addCase(s"$name wholestage on", numIters = 5) { iter =>
+ sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true)
+ f
+ }
+
+ benchmark.run()
+ }
+
+}
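
A hypothetical standalone benchmark built on the trait above (the suite name, query, and
cardinality are illustrative; it assumes the same `org.apache.spark.sql.execution.benchmark`
package, since the trait is `private[sql]`):

    class RangeSumBenchmark extends BenchmarkBase {
      test("sum over a long range") {
        val N = 10L * 1000 * 1000
        runBenchmark("range/sum", N) {
          // The same body is timed with whole-stage codegen off (2 iters) and on (5 iters).
          sparkSession.range(N).selectExpr("sum(id)").collect()
        }
      }
    }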
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
new file mode 100644
index 0000000..b145b7f
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.hive.HivemallUtils._
+import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest
+import org.apache.spark.sql.test.VectorQueryTest
+
+final class HiveUdfWithFeatureSuite extends HivemallFeatureQueryTest {
+ import hiveContext.implicits._
+ import hiveContext._
+
+ test("hivemall_version") {
+ sql(s"""
+ | CREATE TEMPORARY FUNCTION hivemall_version
+ | AS '${classOf[hivemall.HivemallVersionUDF].getName}'
+ """.stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT DISTINCT hivemall_version()"),
+ Row("0.4.2-rc.2")
+ )
+
+ // sql("DROP TEMPORARY FUNCTION IF EXISTS hivemall_version")
+ // reset()
+ }
+
+ test("train_logregr") {
+ TinyTrainData.createOrReplaceTempView("TinyTrainData")
+ sql(s"""
+ | CREATE TEMPORARY FUNCTION train_logregr
+ | AS '${classOf[hivemall.regression.LogressUDTF].getName}'
+ """.stripMargin)
+ sql(s"""
+ | CREATE TEMPORARY FUNCTION add_bias
+ | AS '${classOf[hivemall.ftvec.AddBiasUDFWrapper].getName}'
+ """.stripMargin)
+
+ val model = sql(
+ s"""
+ | SELECT feature, AVG(weight) AS weight
+ | FROM (
+ | SELECT train_logregr(add_bias(features), label) AS (feature, weight)
+ | FROM TinyTrainData
+ | ) t
+ | GROUP BY feature
+ """.stripMargin)
+
+ checkAnswer(
+ model.select($"feature"),
+ Seq(Row("0"), Row("1"), Row("2"))
+ )
+
+ // TODO: Why isn't 'train_logregr' registered in the HiveMetaStore?
+ // ERROR RetryingHMSHandler: MetaException(message:NoSuchObjectException
+ // (message:Function default.train_logregr does not exist))
+ //
+ // hiveContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_logregr")
+ // hiveContext.reset()
+ }
+
+ test("each_top_k") {
+ val testDf = Seq(
+ ("a", "1", 0.5, Array(0, 1, 2)),
+ ("b", "5", 0.1, Array(3)),
+ ("a", "3", 0.8, Array(2, 5)),
+ ("c", "6", 0.3, Array(1, 3)),
+ ("b", "4", 0.3, Array(2)),
+ ("a", "2", 0.6, Array(1))
+ ).toDF("key", "value", "score", "data")
+
+ import testDf.sqlContext.implicits._
+ testDf.repartition($"key").sortWithinPartitions($"key").createOrReplaceTempView("TestData")
+ sql(s"""
+ | CREATE TEMPORARY FUNCTION each_top_k
+ | AS '${classOf[hivemall.tools.EachTopKUDTF].getName}'
+ """.stripMargin)
+
+ // Compute top-1 rows for each group
+ checkAnswer(
+ sql("SELECT each_top_k(1, key, score, key, value) FROM TestData"),
+ Row(1, 0.8, "a", "3") ::
+ Row(1, 0.3, "b", "4") ::
+ Row(1, 0.3, "c", "6") ::
+ Nil
+ )
+
+ // Compute reverse top-1 rows for each group
+ checkAnswer(
+ sql("SELECT each_top_k(-1, key, score, key, value) FROM TestData"),
+ Row(1, 0.5, "a", "1") ::
+ Row(1, 0.1, "b", "5") ::
+ Row(1, 0.3, "c", "6") ::
+ Nil
+ )
+ }
+}
+
+final class HiveUdfWithVectorSuite extends VectorQueryTest {
+ import hiveContext._
+
+ test("to_hivemall_features") {
+ mllibTrainDf.createOrReplaceTempView("mllibTrainDf")
+ hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func)
+ checkAnswer(
+ sql(
+ s"""
+ | SELECT to_hivemall_features(features)
+ | FROM mllibTrainDf
+ """.stripMargin),
+ Seq(
+ Row(Seq("0:1.0", "2:2.0", "4:3.0")),
+ Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2")),
+ Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0")),
+ Row(Seq("1:4.0", "3:5.0", "5:6.0"))
+ )
+ )
+ }
+
+ test("append_bias") {
+ mllibTrainDf.createOrReplaceTempView("mllibTrainDf")
+ hiveContext.udf.register("append_bias", append_bias_func)
+ hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func)
+ checkAnswer(
+ sql(
+ s"""
+ | SELECT to_hivemall_features(append_bias(features))
+ | FROM mllibTrainDf
+ """.stripMargin),
+ Seq(
+ Row(Seq("0:1.0", "2:2.0", "4:3.0", "7:1.0")),
+ Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2", "7:1.0")),
+ Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0", "7:1.0")),
+ Row(Seq("1:4.0", "3:5.0", "5:6.0", "7:1.0"))
+ )
+ )
+ }
+
+ ignore("explode_vector") {
+ // TODO: Spark-2.0 does not support user-defined generator functions in
+ // `org.apache.spark.sql.UDFRegistration`.
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
new file mode 100644
index 0000000..6b5d4cd
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -0,0 +1,961 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.HivemallGroupedDataset._
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.HivemallUtils._
+import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.VectorQueryTest
+import org.apache.spark.sql.types._
+import org.apache.spark.test.TestFPWrapper._
+import org.apache.spark.test.TestUtils
+
+final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
+
+ test("anomaly") {
+ import hiveContext.implicits._
+ val df = spark.range(1000).selectExpr("id AS time", "rand() AS x")
+ // TODO: Test results more strictly
+ assert(df.sort($"time".asc).select(changefinder($"x")).count === 1000)
+ assert(df.sort($"time".asc).select(sst($"x", lit("-th 0.005"))).count === 1000)
+ }
+
+ test("knn.similarity") {
+ val df1 = DummyInputData.select(cosine_sim(lit2(Seq(1, 2, 3, 4)), lit2(Seq(3, 4, 5, 6))))
+ assert(df1.collect.apply(0).getFloat(0) ~== 0.500f)
+
+ val df2 = DummyInputData.select(jaccard(lit(5), lit(6)))
+ assert(df2.collect.apply(0).getFloat(0) ~== 0.96875f)
+
+ val df3 = DummyInputData.select(angular_similarity(lit2(Seq(1, 2, 3)), lit2(Seq(4, 5, 6))))
+ assert(df3.collect.apply(0).getFloat(0) ~== 0.500f)
+
+ val df4 = DummyInputData.select(euclid_similarity(lit2(Seq(5, 3, 1)), lit2(Seq(2, 8, 3))))
+ assert(df4.collect.apply(0).getFloat(0) ~== 0.33333334f)
+
+ val df5 = DummyInputData.select(distance2similarity(lit(1.0)))
+ assert(df5.collect.apply(0).getFloat(0) ~== 0.5f)
+ }
+
+ test("knn.distance") {
+ val df1 = DummyInputData.select(hamming_distance(lit(1), lit(3)))
+ checkAnswer(df1, Row(1) :: Nil)
+
+ val df2 = DummyInputData.select(popcnt(lit(1)))
+ checkAnswer(df2, Row(1) :: Nil)
+
+ val df3 = DummyInputData.select(kld(lit(0.1), lit(0.5), lit(0.2), lit(0.5)))
+ assert(df3.collect.apply(0).getDouble(0) ~== 0.01)
+
+ val df4 = DummyInputData.select(
+ euclid_distance(lit2(Seq("0.1", "0.5")), lit2(Seq("0.2", "0.5"))))
+ assert(df4.collect.apply(0).getFloat(0) ~== 1.4142135f)
+
+ val df5 = DummyInputData.select(
+ cosine_distance(lit2(Seq("0.8", "0.3")), lit2(Seq("0.4", "0.6"))))
+ assert(df5.collect.apply(0).getFloat(0) ~== 1.0f)
+
+ val df6 = DummyInputData.select(
+ angular_distance(lit2(Seq("0.1", "0.1")), lit2(Seq("0.3", "0.8"))))
+ assert(df6.collect.apply(0).getFloat(0) ~== 0.50f)
+
+ val df7 = DummyInputData.select(
+ manhattan_distance(lit2(Seq("0.7", "0.8")), lit2(Seq("0.5", "0.6"))))
+ assert(df7.collect.apply(0).getFloat(0) ~== 4.0f)
+
+ val df8 = DummyInputData.select(
+ minkowski_distance(lit2(Seq("0.1", "0.2")), lit2(Seq("0.2", "0.2")), lit2(1.0)))
+ assert(df8.collect.apply(0).getFloat(0) ~== 2.0f)
+ }
+
+ test("knn.lsh") {
+ import hiveContext.implicits._
+ assert(IntList2Data.minhash(lit(1), $"target").count() > 0)
+
+ assert(DummyInputData.select(bbit_minhash(lit2(Seq("1:0.1", "2:0.5")), lit(false))).count
+ == DummyInputData.count)
+ assert(DummyInputData.select(minhashes(lit2(Seq("1:0.1", "2:0.5")), lit(false))).count
+ == DummyInputData.count)
+ }
+
+ test("ftvec - add_bias") {
+ import hiveContext.implicits._
+ checkAnswer(TinyTrainData.select(add_bias($"features")),
+ Row(Seq("1:0.8", "2:0.2", "0:1.0")) ::
+ Row(Seq("2:0.7", "0:1.0")) ::
+ Row(Seq("1:0.9", "0:1.0")) ::
+ Nil
+ )
+ }
+
+ test("ftvec - extract_feature") {
+ val df = DummyInputData.select(extract_feature(lit("1:0.8")))
+ checkAnswer(df, Row("1") :: Nil)
+ }
+
+ test("ftvec - extract_weight") {
+ val df = DummyInputData.select(extract_weight(lit("3:0.1")))
+ assert(df.collect.apply(0).getDouble(0) ~== 0.1)
+ }
+
+ test("ftvec - explode_array") {
+ import hiveContext.implicits._
+ val df = TinyTrainData.explode_array($"features").select($"feature")
+ checkAnswer(df, Row("1:0.8") :: Row("2:0.2") :: Row("2:0.7") :: Row("1:0.9") :: Nil)
+ }
+
+ test("ftvec - add_feature_index") {
+ import hiveContext.implicits._
+ val doubleListData = Seq(Array(0.8, 0.5), Array(0.3, 0.1), Array(0.2)).toDF("data")
+ checkAnswer(
+ doubleListData.select(add_feature_index($"data")),
+ Row(Seq("1:0.8", "2:0.5")) ::
+ Row(Seq("1:0.3", "2:0.1")) ::
+ Row(Seq("1:0.2")) ::
+ Nil
+ )
+ }
+
+ test("ftvec - sort_by_feature") {
+ // import hiveContext.implicits._
+ val intFloatMapData = {
+ // TODO: Use `toDF`
+ val rowRdd = hiveContext.sparkContext.parallelize(
+ Row(Map(1 -> 0.3f, 2 -> 0.1f, 3 -> 0.5f)) ::
+ Row(Map(2 -> 0.4f, 1 -> 0.2f)) ::
+ Row(Map(2 -> 0.4f, 3 -> 0.2f, 1 -> 0.1f, 4 -> 0.6f)) ::
+ Nil
+ )
+ hiveContext.createDataFrame(
+ rowRdd,
+ StructType(
+ StructField("data", MapType(IntegerType, FloatType), true) ::
+ Nil)
+ )
+ }
+ val sortedKeys = intFloatMapData.select(sort_by_feature(intFloatMapData.col("data")))
+ .collect.map {
+ case Row(m: Map[Int, Float]) => m.keysIterator.toSeq
+ }
+ assert(sortedKeys.toSet === Set(Seq(1, 2, 3), Seq(1, 2), Seq(1, 2, 3, 4)))
+ }
+
+ test("ftvec.hash") {
+ assert(DummyInputData.select(mhash(lit("test"))).count == DummyInputData.count)
+ assert(DummyInputData.select(org.apache.spark.sql.hive.HivemallOps.sha1(lit("test"))).count ==
+ DummyInputData.count)
+ // TODO: The tests below fail because:
+ // org.apache.spark.sql.AnalysisException: List type in java is unsupported because JVM type
+ // erasure makes spark fail to catch a component type in List<>;
+ //
+ // assert(DummyInputData.select(array_hash_values(lit2(Seq("aaa", "bbb")))).count
+ // == DummyInputData.count)
+ // assert(DummyInputData.select(
+ // prefixed_hash_values(lit2(Seq("ccc", "ddd")), lit("prefix"))).count
+ // == DummyInputData.count)
+ }
+
+ test("ftvec.scaling") {
+ val df1 = TinyTrainData.select(rescale(lit(2.0f), lit(1.0), lit(5.0)))
+ assert(df1.collect.apply(0).getFloat(0) === 0.25f)
+ val df2 = TinyTrainData.select(zscore(lit(1.0f), lit(0.5), lit(0.5)))
+ assert(df2.collect.apply(0).getFloat(0) === 1.0f)
+ val df3 = TinyTrainData.select(normalize(TinyTrainData.col("features")))
+ checkAnswer(
+ df3,
+ Row(Seq("1:0.9701425", "2:0.24253562")) ::
+ Row(Seq("2:1.0")) ::
+ Row(Seq("1:1.0")) ::
+ Nil)
+ }
+
+ test("ftvec.selection - chi2") {
+ import hiveContext.implicits._
+
+ // See also hivemall.ftvec.selection.ChiSquareUDFTest
+ val df = Seq(
+ Seq(
+ Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996),
+ Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3),
+ Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998)
+ ) -> Seq(
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589)))
+ .toDF("arg0", "arg1")
+
+ val result = df.select(chi2(df("arg0"), df("arg1"))).collect
+ assert(result.length == 1)
+ val chi2Val = result.head.getAs[Row](0).getAs[Seq[Double]](0)
+ val pVal = result.head.getAs[Row](0).getAs[Seq[Double]](1)
+
+ (chi2Val, Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759))
+ .zipped
+ .foreach((actual, expected) => assert(actual ~== expected))
+
+ (pVal, Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15))
+ .zipped
+ .foreach((actual, expected) => assert(actual ~== expected))
+ }
+
+ test("ftvec.conv - quantify") {
+ import hiveContext.implicits._
+ val testDf = Seq((1, "aaa", true), (2, "bbb", false), (3, "aaa", false)).toDF
+ // This test is done in a single partition because `HivemallOps#quantify` assigns identifiers
+ // for non-numerical values in each partition.
+ checkAnswer(
+ testDf.coalesce(1).quantify(lit(true) +: testDf.cols: _*),
+ Row(1, 0, 0) :: Row(2, 1, 1) :: Row(3, 0, 1) :: Nil)
+ }
+
+ test("ftvec.amplify") {
+ import hiveContext.implicits._
+ assert(TinyTrainData.amplify(lit(3), $"label", $"features").count() == 9)
+ assert(TinyTrainData.part_amplify(lit(3)).count() == 9)
+ // TODO: The test below fails because:
+ // java.lang.RuntimeException: Unsupported literal type class scala.Tuple3
+ // (-buf 128,label,features)
+ //
+ // assert(TinyTrainData.rand_amplify(lit(3), lit("-buf 8", $"label", $"features")).count() == 9)
+ }
+
+ ignore("ftvec.conv") {
+ import hiveContext.implicits._
+
+ val df1 = Seq((0.0, "1:0.1" :: "3:0.3" :: Nil), (1.0, "2:0.2" :: Nil)).toDF("a", "b")
+ checkAnswer(
+ df1.select(to_dense_features(df1("b"), lit(3))),
+ Row(Array(0.1f, 0.0f, 0.3f)) :: Row(Array(0.0f, 0.2f, 0.0f)) :: Nil
+ )
+ val df2 = Seq((0.1, 0.2, 0.3), (0.2, 0.5, 0.4)).toDF("a", "b", "c")
+ checkAnswer(
+ df2.select(to_sparse_features(df2("a"), df2("b"), df2("c"))),
+ Row(Seq("1:0.1", "2:0.2", "3:0.3")) :: Row(Seq("1:0.2", "2:0.5", "3:0.4")) :: Nil
+ )
+ }
+
+ test("ftvec.trans") {
+ import hiveContext.implicits._
+
+ val df1 = Seq((1, -3, 1), (2, -2, 1)).toDF("a", "b", "c")
+ checkAnswer(
+ df1.binarize_label($"a", $"b", $"c"),
+ Row(1, 1) :: Row(1, 1) :: Row(1, 1) :: Nil
+ )
+
+ val df2 = Seq((0.1f, 0.2f), (0.5f, 0.3f)).toDF("a", "b")
+ checkAnswer(
+ df2.select(vectorize_features(lit2(Seq("a", "b")), df2("a"), df2("b"))),
+ Row(Seq("a:0.1", "b:0.2")) :: Row(Seq("a:0.5", "b:0.3")) :: Nil
+ )
+
+ val df3 = Seq(("c11", "c12"), ("c21", "c22")).toDF("a", "b")
+ checkAnswer(
+ df3.select(categorical_features(lit2(Seq("a", "b")), df3("a"), df3("b"))),
+ Row(Seq("a#c11", "b#c12")) :: Row(Seq("a#c21", "b#c22")) :: Nil
+ )
+
+ val df4 = Seq((0.1, 0.2, 0.3), (0.2, 0.5, 0.4)).toDF("a", "b", "c")
+ checkAnswer(
+ df4.select(indexed_features(df4("a"), df4("b"), df4("c"))),
+ Row(Seq("1:0.1", "2:0.2", "3:0.3")) :: Row(Seq("1:0.2", "2:0.5", "3:0.4")) :: Nil
+ )
+
+ val df5 = Seq(("xxx", "yyy", 0), ("zzz", "yyy", 1)).toDF("a", "b", "c").coalesce(1)
+ checkAnswer(
+ df5.quantified_features(lit(true), df5("a"), df5("b"), df5("c")),
+ Row(Seq(0.0, 0.0, 0.0)) :: Row(Seq(1.0, 0.0, 1.0)) :: Nil
+ )
+
+ val df6 = Seq((0.1, 0.2), (0.5, 0.3)).toDF("a", "b")
+ checkAnswer(
+ df6.select(quantitative_features(lit2(Seq("a", "b")), df6("a"), df6("b"))),
+ Row(Seq("a:0.1", "b:0.2")) :: Row(Seq("a:0.5", "b:0.3")) :: Nil
+ )
+ }
+
+ test("misc - hivemall_version") {
+ checkAnswer(DummyInputData.select(hivemall_version()), Row("0.4.2-rc.2"))
+ }
+
+ test("misc - rowid") {
+ assert(DummyInputData.select(rowid()).distinct.count == DummyInputData.count)
+ }
+
+ test("misc - each_top_k") {
+ import hiveContext.implicits._
+ val inputDf = Seq(
+ ("a", "1", 0.5, 0.1, Array(0, 1, 2)),
+ ("b", "5", 0.1, 0.2, Array(3)),
+ ("a", "3", 0.8, 0.8, Array(2, 5)),
+ ("c", "6", 0.3, 0.3, Array(1, 3)),
+ ("b", "4", 0.3, 0.4, Array(2)),
+ ("a", "2", 0.6, 0.5, Array(1))
+ ).toDF("key", "value", "x", "y", "data")
+
+ // Compute top-1 rows for each group
+ val distance = sqrt(inputDf("x") * inputDf("x") + inputDf("y") * inputDf("y")).as("score")
+ val top1Df = inputDf.each_top_k(lit(1), distance, $"key".as("group"))
+ assert(top1Df.schema.toSet === Set(
+ StructField("rank", IntegerType, nullable = true),
+ StructField("score", DoubleType, nullable = true),
+ StructField("key", StringType, nullable = true),
+ StructField("value", StringType, nullable = true),
+ StructField("x", DoubleType, nullable = true),
+ StructField("y", DoubleType, nullable = true),
+ StructField("data", ArrayType(IntegerType, containsNull = false), nullable = true)
+ ))
+ checkAnswer(
+ top1Df.select($"rank", $"key", $"value", $"data"),
+ Row(1, "a", "3", Array(2, 5)) ::
+ Row(1, "b", "4", Array(2)) ::
+ Row(1, "c", "6", Array(1, 3)) ::
+ Nil
+ )
+
+ // Compute reverse top-1 rows for each group
+ val bottom1Df = inputDf.each_top_k(lit(-1), distance, $"key".as("group"))
+ checkAnswer(
+ bottom1Df.select($"rank", $"key", $"value", $"data"),
+ Row(1, "a", "1", Array(0, 1, 2)) ::
+ Row(1, "b", "5", Array(3)) ::
+ Row(1, "c", "6", Array(1, 3)) ::
+ Nil
+ )
+
+ // Check that the expected exceptions are thrown for invalid arguments
+ assert(intercept[AnalysisException] { inputDf.each_top_k(lit(0.1), $"score", $"key") }
+ .getMessage contains "`k` must be integer, however")
+ assert(intercept[AnalysisException] { inputDf.each_top_k(lit(1), $"data", $"key") }
+ .getMessage contains "must have a comparable type")
+ }
+
+ test("misc - join_top_k") {
+ Seq("true", "false").map { flag =>
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) {
+ import hiveContext.implicits._
+ val inputDf = Seq(
+ ("user1", 1, 0.3, 0.5),
+ ("user2", 2, 0.1, 0.1),
+ ("user3", 3, 0.8, 0.0),
+ ("user4", 1, 0.9, 0.9),
+ ("user5", 3, 0.7, 0.2),
+ ("user6", 1, 0.5, 0.4),
+ ("user7", 2, 0.6, 0.8)
+ ).toDF("userId", "group", "x", "y")
+
+ val masterDf = Seq(
+ (1, "pos1-1", 0.5, 0.1),
+ (1, "pos1-2", 0.0, 0.0),
+ (1, "pos1-3", 0.3, 0.3),
+ (2, "pos2-3", 0.1, 0.3),
+ (2, "pos2-3", 0.8, 0.8),
+ (3, "pos3-1", 0.1, 0.7),
+ (3, "pos3-1", 0.7, 0.1),
+ (3, "pos3-1", 0.9, 0.0),
+ (3, "pos3-1", 0.1, 0.3)
+ ).toDF("group", "position", "x", "y")
+
+ // Compute top-1 rows for each group
+ val distance = sqrt(
+ pow(inputDf("x") - masterDf("x"), lit(2.0)) +
+ pow(inputDf("y") - masterDf("y"), lit(2.0))
+ ).as("score")
+ val top1Df = inputDf.top_k_join(
+ lit(1), masterDf, inputDf("group") === masterDf("group"), distance)
+ assert(top1Df.schema.toSet === Set(
+ StructField("rank", IntegerType, nullable = true),
+ StructField("score", DoubleType, nullable = true),
+ StructField("group", IntegerType, nullable = false),
+ StructField("userId", StringType, nullable = true),
+ StructField("position", StringType, nullable = true),
+ StructField("x", DoubleType, nullable = false),
+ StructField("y", DoubleType, nullable = false)
+ ))
+ checkAnswer(
+ top1Df.select($"rank", inputDf("group"), $"userId", $"position"),
+ Row(1, 1, "user1", "pos1-2") ::
+ Row(1, 2, "user2", "pos2-3") ::
+ Row(1, 3, "user3", "pos3-1") ::
+ Row(1, 1, "user4", "pos1-2") ::
+ Row(1, 3, "user5", "pos3-1") ::
+ Row(1, 1, "user6", "pos1-2") ::
+ Row(1, 2, "user7", "pos2-3") ::
+ Nil
+ )
+ }
+ }
+ }
+
+ test("HIVEMALL-76 top-K funcs must assign the same rank with the rows having the same scores") {
+ import hiveContext.implicits._
+ val inputDf = Seq(
+ ("a", "1", 0.1),
+ ("b", "5", 0.1),
+ ("a", "3", 0.1),
+ ("b", "4", 0.1),
+ ("a", "2", 0.0)
+ ).toDF("key", "value", "x")
+
+ // Compute top-2 rows for each group
+ val top2Df = inputDf.each_top_k(lit(2), $"x".as("score"), $"key".as("group"))
+ checkAnswer(
+ top2Df.select($"rank", $"score", $"key", $"value"),
+ Row(1, 0.1, "a", "3") ::
+ Row(1, 0.1, "a", "1") ::
+ Row(1, 0.1, "b", "4") ::
+ Row(1, 0.1, "b", "5") ::
+ Nil
+ )
+ Seq("true", "false").map { flag =>
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) {
+ val inputDf = Seq(
+ ("user1", 1, 0.3, 0.5),
+ ("user2", 2, 0.1, 0.1)
+ ).toDF("userId", "group", "x", "y")
+
+ val masterDf = Seq(
+ (1, "pos1-1", 0.5, 0.1),
+ (1, "pos1-2", 0.5, 0.1),
+ (1, "pos1-3", 0.3, 0.4),
+ (2, "pos2-1", 0.8, 0.2),
+ (2, "pos2-2", 0.8, 0.2)
+ ).toDF("group", "position", "x", "y")
+
+ // Compute top-2 rows for each group
+ val distance = sqrt(
+ pow(inputDf("x") - masterDf("x"), lit(2.0)) +
+ pow(inputDf("y") - masterDf("y"), lit(2.0))
+ ).as("score")
+ val top2Df = inputDf.top_k_join(
+ lit(2), masterDf, inputDf("group") === masterDf("group"), distance)
+ checkAnswer(
+ top2Df.select($"rank", inputDf("group"), $"userId", $"position"),
+ Row(1, 1, "user1", "pos1-1") ::
+ Row(1, 1, "user1", "pos1-2") ::
+ Row(1, 2, "user2", "pos2-1") ::
+ Row(1, 2, "user2", "pos2-2") ::
+ Nil
+ )
+ }
+ }
+ }
+
+ test("misc - flatten") {
+ import hiveContext.implicits._
+ val df = Seq((0, (1, "a", (3.0, "b")), (5, 0.9, "c", "d"), 9)).toDF()
+ assert(df.flatten().schema === StructType(
+ StructField("_1", IntegerType, nullable = false) ::
+ StructField("_2$_1", IntegerType, nullable = true) ::
+ StructField("_2$_2", StringType, nullable = true) ::
+ StructField("_2$_3$_1", DoubleType, nullable = true) ::
+ StructField("_2$_3$_2", StringType, nullable = true) ::
+ StructField("_3$_1", IntegerType, nullable = true) ::
+ StructField("_3$_2", DoubleType, nullable = true) ::
+ StructField("_3$_3", StringType, nullable = true) ::
+ StructField("_3$_4", StringType, nullable = true) ::
+ StructField("_4", IntegerType, nullable = false) ::
+ Nil
+ ))
+ checkAnswer(df.flatten("$").select("_2$_1"), Row(1))
+ checkAnswer(df.flatten("_").select("_2__1"), Row(1))
+ checkAnswer(df.flatten(".").select("`_2._1`"), Row(1))
+
+ val errMsg1 = intercept[IllegalArgumentException] { df.flatten("\t") }
+ assert(errMsg1.getMessage.startsWith("Must use '$', '_', or '.' for separator, but got"))
+ val errMsg2 = intercept[IllegalArgumentException] { df.flatten("12") }
+ assert(errMsg2.getMessage.startsWith("Separator cannot be more than one character:"))
+ }
+
+ test("misc - from_csv") {
+ import hiveContext.implicits._
+ val df = Seq("""1,abc""").toDF()
+ val schema = new StructType().add("a", IntegerType).add("b", StringType)
+ checkAnswer(
+ df.select(from_csv($"value", schema)),
+ Row(Row(1, "abc")) :: Nil)
+ }
+
+ test("misc - to_csv") {
+ import hiveContext.implicits._
+ val df = Seq((1, "a", (0, 3.9, "abc")), (8, "c", (2, 0.4, "def"))).toDF()
+ checkAnswer(
+ df.select(to_csv($"_3")),
+ Row("0,3.9,abc") ::
+ Row("2,0.4,def") ::
+ Nil)
+ }
+
+ /**
+ * This test fails because:
+ *
+ * Cause: java.lang.OutOfMemoryError: Java heap space
+ * at hivemall.smile.tools.RandomForestEnsembleUDAF$Result.<init>
+ * (RandomForestEnsembleUDAF.java:128)
+ * at hivemall.smile.tools.RandomForestEnsembleUDAF$RandomForestPredictUDAFEvaluator
+ * .terminate(RandomForestEnsembleUDAF.java:91)
+ * at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
+ */
+ ignore("misc - tree_predict") {
+ import hiveContext.implicits._
+
+ val model = Seq((0.0, 0.1 :: 0.1 :: Nil), (1.0, 0.2 :: 0.3 :: 0.2 :: Nil))
+ .toDF("label", "features")
+ .train_randomforest_regr($"features", $"label")
+
+ val testData = Seq((0.0, 0.1 :: 0.0 :: Nil), (1.0, 0.3 :: 0.5 :: 0.4 :: Nil))
+ .toDF("label", "features")
+ .select(rowid(), $"label", $"features")
+
+ val predicted = model
+ .join(testData).coalesce(1)
+ .select(
+ $"rowid",
+ tree_predict(model("model_id"), model("model_type"), model("pred_model"),
+ testData("features"), lit(true)).as("predicted")
+ )
+ .groupBy($"rowid")
+ .rf_ensemble("predicted").toDF("rowid", "predicted")
+ .select($"predicted.label")
+
+ checkAnswer(predicted, Seq(Row(0), Row(1)))
+ }
+
+ test("tools.array - select_k_best") {
+ import hiveContext.implicits._
+
+ val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9))
+ val df = data.map(d => (d, Seq(3, 1, 2))).toDF("features", "importance_list")
+ val k = 2
+
+ checkAnswer(
+ df.select(select_k_best(df("features"), df("importance_list"), lit(k))),
+ Row(Seq(0.0, 3.0)) :: Row(Seq(2.0, 1.0)) :: Row(Seq(5.0, 9.0)) :: Nil
+ )
+ }
+
+ test("misc - sigmoid") {
+ import hiveContext.implicits._
+ assert(DummyInputData.select(sigmoid($"c0")).collect.apply(0).getDouble(0) ~== 0.500)
+ }
+
+ test("misc - lr_datagen") {
+ assert(TinyTrainData.lr_datagen(lit("-n_examples 100 -n_features 10 -seed 100")).count >= 100)
+ }
+
+ test("invoke regression functions") {
+ import hiveContext.implicits._
+ Seq(
+ "train_adadelta",
+ "train_adagrad",
+ "train_arow_regr",
+ "train_arowe_regr",
+ "train_arowe2_regr",
+ "train_logregr",
+ "train_pa1_regr",
+ "train_pa1a_regr",
+ "train_pa2_regr",
+ "train_pa2a_regr"
+ ).map { func =>
+ TestUtils.invokeFunc(new HivemallOps(TinyTrainData), func, Seq($"features", $"label"))
+ .foreach(_ => {}) // Just call it
+ }
+ }
+
+ test("invoke classifier functions") {
+ import hiveContext.implicits._
+ Seq(
+ "train_perceptron",
+ "train_pa",
+ "train_pa1",
+ "train_pa2",
+ "train_cw",
+ "train_arow",
+ "train_arowh",
+ "train_scw",
+ "train_scw2",
+ "train_adagrad_rda"
+ ).map { func =>
+ TestUtils.invokeFunc(new HivemallOps(TinyTrainData), func, Seq($"features", $"label"))
+ .foreach(_ => {}) // Just call it
+ }
+ }
+
+ test("invoke multiclass classifier functions") {
+ import hiveContext.implicits._
+ Seq(
+ "train_multiclass_perceptron",
+ "train_multiclass_pa",
+ "train_multiclass_pa1",
+ "train_multiclass_pa2",
+ "train_multiclass_cw",
+ "train_multiclass_arow",
+ "train_multiclass_scw",
+ "train_multiclass_scw2"
+ ).map { func =>
+ // TODO: Why is the label type restricted to [Int|Text] only for multiclass classifiers?
+ TestUtils.invokeFunc(
+ new HivemallOps(TinyTrainData), func, Seq($"features", $"label".cast(IntegerType)))
+ .foreach(_ => {}) // Just call it
+ }
+ }
+
+ test("invoke random forest functions") {
+ import hiveContext.implicits._
+ val testDf = Seq(
+ (Array(0.3, 0.1, 0.2), 1),
+ (Array(0.3, 0.1, 0.2), 0),
+ (Array(0.3, 0.1, 0.2), 0)).toDF("features", "label")
+ Seq(
+ "train_randomforest_regr",
+ "train_randomforest_classifier"
+ ).map { func =>
+ TestUtils.invokeFunc(new HivemallOps(testDf.coalesce(1)), func, Seq($"features", $"label"))
+ .foreach(_ => {}) // Just call it
+ }
+ }
+
+ protected def checkRegrPrecision(func: String): Unit = {
+ import hiveContext.implicits._
+
+ // Build a model
+ val model = {
+ val res = TestUtils.invokeFunc(new HivemallOps(LargeRegrTrainData),
+ func, Seq(add_bias($"features"), $"label"))
+ if (!res.columns.contains("conv")) {
+ res.groupBy("feature").agg("weight" -> "avg")
+ } else {
+ res.groupBy("feature").argmin_kld("weight", "conv")
+ }
+ }.toDF("feature", "weight")
+
+ // Data preparation
+ val testDf = LargeRegrTrainData
+ .select(rowid(), $"label".as("target"), $"features")
+ .cache
+
+ val testDf_exploded = testDf
+ .explode_array($"features")
+ .select($"rowid", extract_feature($"feature"), extract_weight($"feature"))
+
+ // Do prediction
+ val predict = testDf_exploded
+ .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER")
+ .select($"rowid", ($"weight" * $"value").as("value"))
+ .groupBy("rowid").sum("value")
+ .toDF("rowid", "predicted")
+
+ // Evaluation
+ val eval = predict
+ .join(testDf, predict("rowid") === testDf("rowid"))
+ .groupBy()
+ .agg(Map("target" -> "avg", "predicted" -> "avg"))
+ .toDF("target", "predicted")
+
+ val diff = eval.map {
+ case Row(target: Double, predicted: Double) =>
+ Math.abs(target - predicted)
+ }.first
+
+ TestUtils.expectResult(diff > 0.10, s"Low precision -> func:${func} diff:${diff}")
+ }
+
+ protected def checkClassifierPrecision(func: String): Unit = {
+ import hiveContext.implicits._
+
+ // Build a model
+ val model = {
+ val res = TestUtils.invokeFunc(new HivemallOps(LargeClassifierTrainData),
+ func, Seq(add_bias($"features"), $"label"))
+ if (!res.columns.contains("conv")) {
+ res.groupBy("feature").agg("weight" -> "avg")
+ } else {
+ res.groupBy("feature").argmin_kld("weight", "conv")
+ }
+ }.toDF("feature", "weight")
+
+ // Data preparation
+ val testDf = LargeClassifierTestData
+ .select(rowid(), $"label".as("target"), $"features")
+ .cache
+
+ val testDf_exploded = testDf
+ .explode_array($"features")
+ .select($"rowid", extract_feature($"feature"), extract_weight($"feature"))
+
+ // Do prediction
+ val predict = testDf_exploded
+ .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER")
+ .select($"rowid", ($"weight" * $"value").as("value"))
+ .groupBy("rowid").sum("value")
+ /**
+ * TODO: This statement emits the warning below:
+ *
+ * WARN Column: Constructing trivially true equals predicate, 'rowid#1323 = rowid#1323'.
+ * Perhaps you need to use aliases.
+ */
+ .select($"rowid", when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0))
+ .toDF("rowid", "predicted")
+
+ // Evaluation
+ val eval = predict
+ .join(testDf, predict("rowid") === testDf("rowid"))
+ .where($"target" === $"predicted")
+
+ val precision = (eval.count + 0.0) / predict.count
+
+ TestUtils.expectResult(precision < 0.70, s"Low precision -> func:${func} value:${precision}")
+ }
+
+ ignore("check regression precision") {
+ Seq(
+ "train_adadelta",
+ "train_adagrad",
+ "train_arow_regr",
+ "train_arowe_regr",
+ "train_arowe2_regr",
+ "train_logregr",
+ "train_pa1_regr",
+ "train_pa1a_regr",
+ "train_pa2_regr",
+ "train_pa2a_regr"
+ ).map { func =>
+ checkRegrPrecision(func)
+ }
+ }
+
+ ignore("check classifier precision") {
+ Seq(
+ "train_perceptron",
+ "train_pa",
+ "train_pa1",
+ "train_pa2",
+ "train_cw",
+ "train_arow",
+ "train_arowh",
+ "train_scw",
+ "train_scw2",
+ "train_adagrad_rda"
+ ).map { func =>
+ checkClassifierPrecision(func)
+ }
+ }
+
+ test("user-defined aggregators for ensembles") {
+ import hiveContext.implicits._
+
+ val df1 = Seq((1, 0.1f), (1, 0.2f), (2, 0.1f)).toDF("c0", "c1")
+ val row1 = df1.groupBy($"c0").voted_avg("c1").collect
+ assert(row1(0).getDouble(1) ~== 0.15)
+ assert(row1(1).getDouble(1) ~== 0.10)
+
+ val df3 = Seq((1, 0.2f), (1, 0.8f), (2, 0.3f)).toDF("c0", "c1")
+ val row3 = df3.groupBy($"c0").weight_voted_avg("c1").collect
+ assert(row3(0).getDouble(1) ~== 0.50)
+ assert(row3(1).getDouble(1) ~== 0.30)
+
+ val df5 = Seq((1, 0.2f, 0.1f), (1, 0.4f, 0.2f), (2, 0.8f, 0.9f)).toDF("c0", "c1", "c2")
+ val row5 = df5.groupBy($"c0").argmin_kld("c1", "c2").collect
+ assert(row5(0).getFloat(1) ~== 0.266666666)
+ assert(row5(1).getFloat(1) ~== 0.80)
+
+ val df6 = Seq((1, "id-0", 0.2f), (1, "id-1", 0.4f), (1, "id-2", 0.1f)).toDF("c0", "c1", "c2")
+ val row6 = df6.groupBy($"c0").max_label("c2", "c1").collect
+ assert(row6(0).getString(1) == "id-1")
+
+ val df7 = Seq((1, "id-0", 0.5f), (1, "id-1", 0.1f), (1, "id-2", 0.2f)).toDF("c0", "c1", "c2")
+ val row7 = df7.groupBy($"c0").maxrow("c2", "c1").toDF("c0", "c1").select($"c1.col1").collect
+ assert(row7(0).getString(0) == "id-0")
+
+ // val df8 = Seq((1, 1), (1, 2), (2, 1), (1, 5)).toDF("c0", "c1")
+ // val row8 = df8.groupBy($"c0").rf_ensemble("c1").toDF("c0", "c1")
+ // .select("c1.probability").collect
+ // assert(row8(0).getDouble(0) ~== 0.3333333333)
+ // assert(row8(1).getDouble(0) ~== 1.0)
+ }
+
+ test("user-defined aggregators for evaluation") {
+ import hiveContext.implicits._
+
+ val df1 = Seq((1, 1.0f, 0.5f), (1, 0.3f, 0.5f), (1, 0.1f, 0.2f)).toDF("c0", "c1", "c2")
+ val row1 = df1.groupBy($"c0").mae("c1", "c2").collect
+ assert(row1(0).getDouble(1) ~== 0.26666666)
+
+ val df2 = Seq((1, 0.3f, 0.8f), (1, 1.2f, 2.0f), (1, 0.2f, 0.3f)).toDF("c0", "c1", "c2")
+ val row2 = df2.groupBy($"c0").mse("c1", "c2").collect
+ assert(row2(0).getDouble(1) ~== 0.29999999)
+
+ val df3 = Seq((1, 0.3f, 0.8f), (1, 1.2f, 2.0f), (1, 0.2f, 0.3f)).toDF("c0", "c1", "c2")
+ val row3 = df3.groupBy($"c0").rmse("c1", "c2").collect
+ assert(row3(0).getDouble(1) ~== 0.54772253)
+
+ val df4 = Seq((1, Array(1, 2), Array(2, 3)), (1, Array(3, 8), Array(5, 4))).toDF
+ .toDF("c0", "c1", "c2")
+ val row4 = df4.groupBy($"c0").f1score("c1", "c2").collect
+ assert(row4(0).getDouble(1) ~== 0.25)
+ }
+
+ test("user-defined aggregators for ftvec.trans") {
+ import hiveContext.implicits._
+
+ val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10),
+ (1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9),
+ (1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9))
+ .toDF("col0", "cat1", "cat2", "cat3")
+ val row00 = df0.groupBy($"col0").onehot_encoding("cat1")
+ val row01 = df0.groupBy($"col0").onehot_encoding("cat1", "cat2", "cat3")
+
+ val result000 = row00.collect()(0).getAs[Row](1).getAs[Map[String, Int]](0)
+ val result01 = row01.collect()(0).getAs[Row](1)
+ val result010 = result01.getAs[Map[String, Int]](0)
+ val result011 = result01.getAs[Map[String, Int]](1)
+ val result012 = result01.getAs[Map[String, Int]](2)
+
+ assert(result000.keySet === Set("seahawk", "cat", "human", "wasp", "dog"))
+ assert(result000.values.toSet === Set(1, 2, 3, 4, 5))
+ assert(result010.keySet === Set("seahawk", "cat", "human", "wasp", "dog"))
+ assert(result010.values.toSet === Set(1, 2, 3, 4, 5))
+ assert(result011.keySet === Set("bird", "insect", "mammal"))
+ assert(result011.values.toSet === Set(6, 7, 8))
+ assert(result012.keySet === Set(1, 3, 9, 10, 101))
+ assert(result012.values.toSet === Set(9, 10, 11, 12, 13))
+ }
+
+ test("user-defined aggregators for ftvec.selection") {
+ import hiveContext.implicits._
+
+ // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest
+ // binary class
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 4.7,3.2,1.3,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.9,3.1,4.9,1.5 | 1 |
+ // +-----------------+-------+
+ val df0 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)),
+ (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)),
+ (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1)))
+ .toDF("c0", "arg0", "arg1")
+ val row0 = df0.groupBy($"c0").snr("arg0", "arg1").collect
+ (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
+ .zipped
+ .foreach((actual, expected) => assert(actual ~== expected))
+
+ // multiple class
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.3,3.3,6.0,2.5 | 2 |
+ // | 5.8,2.7,5.1,1.9 | 2 |
+ // +-----------------+-------+
+ val df1 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
+ (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
+ (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
+ .toDF("c0", "arg0", "arg1")
+ val row1 = df1.groupBy($"c0").snr("arg0", "arg1").collect
+ (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
+ .zipped
+ .foreach((actual, expected) => assert(actual ~== expected))
+ }
+
+ test("user-defined aggregators for tools.matrix") {
+ import hiveContext.implicits._
+
+ // | 1 2 3 |T | 5 6 7 |
+ // | 3 4 5 | * | 7 8 9 |
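+ // = | 26 30 34 |
+ //   | 38 44 50 |
+ //   | 50 58 66 |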
+ val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
+ .toDF("c0", "arg0", "arg1")
+
+ checkAnswer(df0.groupBy($"c0").transpose_and_dot("arg0", "arg1"),
+ Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))))
+ }
+}
+
+final class HivemallOpsWithVectorSuite extends VectorQueryTest {
+ import hiveContext.implicits._
+
+ test("to_hivemall_features") {
+ checkAnswer(
+ mllibTrainDf.select(to_hivemall_features($"features")),
+ Seq(
+ Row(Seq("0:1.0", "2:2.0", "4:3.0")),
+ Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2")),
+ Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0")),
+ Row(Seq("1:4.0", "3:5.0", "5:6.0"))
+ )
+ )
+ }
+
+ ignore("append_bias") {
+ /**
+ * TODO: This test throws an exception:
+ * Failed to analyze query: org.apache.spark.sql.AnalysisException: cannot resolve
+ * 'UDF(UDF(features))' due to data type mismatch: argument 1 requires vector type,
+ * however, 'UDF(features)' is of vector type.; line 2 pos 8
+ */
+ checkAnswer(
+ mllibTrainDf.select(to_hivemall_features(append_bias($"features"))),
+ Seq(
+ Row(Seq("0:1.0", "0:1.0", "2:2.0", "4:3.0")),
+ Row(Seq("0:1.0", "0:1.0", "3:1.5", "4:2.1", "6:1.2")),
+ Row(Seq("0:1.0", "0:1.1", "3:1.0", "4:2.3", "6:1.0")),
+ Row(Seq("0:1.0", "1:4.0", "3:5.0", "5:6.0"))
+ )
+ )
+ }
+
+ test("explode_vector") {
+ checkAnswer(
+ mllibTrainDf.explode_vector($"features").select($"feature", $"weight"),
+ Seq(
+ Row("0", 1.0), Row("0", 1.0), Row("0", 1.1),
+ Row("1", 4.0),
+ Row("2", 2.0),
+ Row("3", 1.0), Row("3", 1.5), Row("3", 5.0),
+ Row("4", 2.1), Row("4", 2.3), Row("4", 3.0),
+ Row("5", 6.0),
+ Row("6", 1.0), Row("6", 1.2)
+ )
+ )
+ }
+
+ test("train_logregr") {
+ checkAnswer(
+ mllibTrainDf.train_logregr($"features", $"label")
+ .groupBy("feature").agg("weight" -> "avg")
+ .select($"feature"),
+ Seq(0, 1, 2, 3, 4, 5, 6).map(v => Row(s"$v"))
+ )
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala
new file mode 100644
index 0000000..ad23e8f
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.{BufferedInputStream, BufferedReader, InputStream, InputStreamReader}
+import java.net.URL
+import java.util.UUID
+import java.util.concurrent.{Executors, ExecutorService}
+
+import hivemall.mix.server.MixServer
+import hivemall.utils.lang.CommandLineUtils
+import hivemall.utils.net.NetUtils
+import org.apache.commons.cli.Options
+import org.apache.commons.compress.compressors.CompressorStreamFactory
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.HivemallLabeledPoint
+import org.apache.spark.sql.{Column, DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.HivemallGroupedDataset._
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.test.TestUtils
+
+final class ModelMixingSuite extends SparkFunSuite with BeforeAndAfter {
+
+ // Load A9a training and test data
+ val a9aLineParser = (line: String) => {
+ val elements = line.split(" ")
+ val (label, features) = (elements.head, elements.tail)
+ HivemallLabeledPoint(if (label == "+1") 1.0f else 0.0f, features)
+ }
+
+ lazy val trainA9aData: DataFrame =
+ getDataFromURI(
+ new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a").openStream,
+ a9aLineParser)
+
+ lazy val testA9aData: DataFrame =
+ getDataFromURI(
+ new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a.t").openStream,
+ a9aLineParser)
+
+ // Load KDD2010a training and test data
+ val kdd2010aLineParser = (line: String) => {
+ val elements = line.split(" ")
+ val (label, features) = (elements.head, elements.tail)
+ HivemallLabeledPoint(if (label == "1") 1.0f else 0.0f, features)
+ }
+
+ lazy val trainKdd2010aData: DataFrame =
+ getDataFromURI(
+ new CompressorStreamFactory().createCompressorInputStream(
+ new BufferedInputStream(
+ new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.bz2")
+ .openStream
+ )
+ ),
+ kdd2010aLineParser,
+ 8)
+
+ lazy val testKdd2010aData: DataFrame =
+ getDataFromURI(
+ new CompressorStreamFactory().createCompressorInputStream(
+ new BufferedInputStream(
+ new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2")
+ .openStream
+ )
+ ),
+ kdd2010aLineParser,
+ 8)
+
+ // Placeholder for a mix server
+ var mixServExec: ExecutorService = _
+ var assignedPort: Int = _
+
+ private def getDataFromURI(
+ in: InputStream, lineParseFunc: String => HivemallLabeledPoint, numPart: Int = 2)
+ : DataFrame = {
+ val reader = new BufferedReader(new InputStreamReader(in))
+ try {
+ // Cache all data because the stream is closed soon after reading
+ val lines = FileIterator(reader.readLine()).toSeq
+ val rdd = TestHive.sparkContext.parallelize(lines, numPart).map(lineParseFunc)
+ val df = rdd.toDF.cache
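+ // Force evaluation so that all lines are cached before the reader is closed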
+ df.foreach(_ => {})
+ df
+ } finally {
+ reader.close()
+ }
+ }
+
+ before {
+ assert(mixServExec == null)
+
+ // Launch a MIX server as a thread
+ assignedPort = NetUtils.getAvailablePort
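+ // MixServer#getOptions is not publicly accessible, so invoke it via reflection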
+ val method = classOf[MixServer].getDeclaredMethod("getOptions")
+ method.setAccessible(true)
+ val options = method.invoke(null).asInstanceOf[Options]
+ val cl = CommandLineUtils.parseOptions(
+ Array(
+ "-port", Integer.toString(assignedPort),
+ "-sync_threshold", "1"
+ ),
+ options
+ )
+ val server = new MixServer(cl)
+ mixServExec = Executors.newSingleThreadExecutor()
+ mixServExec.submit(server)
+ var retry = 0
+ while (server.getState() != MixServer.ServerState.RUNNING && retry < 32) {
+ Thread.sleep(100L)
+ retry += 1
+ }
+ assert(MixServer.ServerState.RUNNING == server.getState)
+ }
+
+ after {
+ mixServExec.shutdownNow()
+ mixServExec = null
+ }
+
+ TestUtils.benchmark("model mixing test w/ regression") {
+ Seq(
+ "train_adadelta",
+ "train_adagrad",
+ "train_arow_regr",
+ "train_arowe_regr",
+ "train_arowe2_regr",
+ "train_logregr",
+ "train_pa1_regr",
+ "train_pa1a_regr",
+ "train_pa2_regr",
+ "train_pa2a_regr"
+ ).map { func =>
+ // Build a model
+ val model = {
+ val groupId = s"${TestHive.sparkContext.applicationId}-${UUID.randomUUID}"
+ val res = TestUtils.invokeFunc(
+ new HivemallOps(trainA9aData.part_amplify(lit(1))),
+ func,
+ Seq[Column](
+ add_bias($"features"),
+ $"label",
+ lit(s"-mix localhost:${assignedPort} -mix_session ${groupId} -mix_threshold 2 " +
+ "-mix_cancel")
+ )
+ )
+ if (!res.columns.contains("conv")) {
+ res.groupBy("feature").agg("weight" -> "avg")
+ } else {
+ res.groupBy("feature").argmin_kld("weight", "conv")
+ }
+ }.toDF("feature", "weight")
+
+ // Data preparation
+ val testDf = testA9aData
+ .select(rowid(), $"label".as("target"), $"features")
+ .cache
+
+ val testDf_exploded = testDf
+ .explode_array($"features")
+ .select($"rowid", extract_feature($"feature"), extract_weight($"feature"))
+
+ // Do prediction
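+ // (predicted value = sum of model weight * feature value over the features of each rowid)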
+ val predict = testDf_exploded
+ .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER")
+ .select($"rowid", ($"weight" * $"value").as("value"))
+ .groupBy("rowid").sum("value")
+ .toDF("rowid", "predicted")
+
+ // Evaluation
+ val eval = predict
+ .join(testDf, predict("rowid") === testDf("rowid"))
+ .groupBy()
+ .agg(Map("target" -> "avg", "predicted" -> "avg"))
+ .toDF("target", "predicted")
+
+ val (target, predicted) = eval.map {
+ case Row(target: Double, predicted: Double) => (target, predicted)
+ }.first
+
+ // scalastyle:off println
+ println(s"func:${func} target:${target} predicted:${predicted} "
+ + s"diff:${Math.abs(target - predicted)}")
+
+ testDf.unpersist()
+ }
+ }
+
+ TestUtils.benchmark("model mixing test w/ binary classification") {
+ Seq(
+ "train_perceptron",
+ "train_pa",
+ "train_pa1",
+ "train_pa2",
+ "train_cw",
+ "train_arow",
+ "train_arowh",
+ "train_scw",
+ "train_scw2",
+ "train_adagrad_rda"
+ ).map { func =>
+ // Build a model
+ val model = {
+ val groupId = s"${TestHive.sparkContext.applicationId}-${UUID.randomUUID}"
+ val res = TestUtils.invokeFunc(
+ new HivemallOps(trainKdd2010aData.part_amplify(lit(1))),
+ func,
+ Seq[Column](
+ add_bias($"features"),
+ $"label",
+ lit(s"-mix localhost:${assignedPort} -mix_session ${groupId} -mix_threshold 2 " +
+ "-mix_cancel")
+ )
+ )
+ if (!res.columns.contains("conv")) {
+ res.groupBy("feature").agg("weight" -> "avg")
+ } else {
+ res.groupBy("feature").argmin_kld("weight", "conv")
+ }
+ }.toDF("feature", "weight")
+
+ // Data preparation
+ val testDf = testKdd2010aData
+ .select(rowid(), $"label".as("target"), $"features")
+ .cache
+
+ val testDf_exploded = testDf
+ .explode_array($"features")
+ .select($"rowid", extract_feature($"feature"), extract_weight($"feature"))
+
+ // Do prediction
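+ // (apply a sigmoid to the summed score of each rowid and threshold it at 0.5)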
+ val predict = testDf_exploded
+ .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER")
+ .select($"rowid", ($"weight" * $"value").as("value"))
+ .groupBy("rowid").sum("value")
+ .select($"rowid", when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0))
+ .toDF("rowid", "predicted")
+
+ // Evaluation
+ val eval = predict
+ .join(testDf, predict("rowid") === testDf("rowid"))
+ .where($"target" === $"predicted")
+
+ // scalastyle:off println
+ println(s"func:${func} precision:${(eval.count + 0.0) / predict.count}")
+
+ testDf.unpersist()
+ predict.unpersist()
+ }
+ }
+}
+
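+/**
+ * Wraps a by-name expression into an iterator that re-evaluates the expression on each
+ * `next()` call and terminates when it returns null (e.g., BufferedReader#readLine at EOF).
+ */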
+object FileIterator {
+
+ def apply[A](f: => A): Iterator[A] = new Iterator[A] {
+ var opt = Option(f)
+ def hasNext = opt.nonEmpty
+ def next() = {
+ val r = opt.get
+ opt = Option(f)
+ r
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala
new file mode 100644
index 0000000..89ed086
--- /dev/null
+++ b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.File
+
+import hivemall.xgboost._
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.HivemallGroupedDataset._
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.VectorQueryTest
+import org.apache.spark.sql.types._
+
+final class XGBoostSuite extends VectorQueryTest {
+ import hiveContext.implicits._
+
+ private val defaultOptions = XGBoostOptions()
+ .set("num_round", "10")
+ .set("max_depth", "4")
+
+ private val numModels = 3
+
+ private def countModels(dirPath: String): Int = {
+ new File(dirPath).listFiles().toSeq.count(_.getName.endsWith(".xgboost"))
+ }
+
+ test("resolve libxgboost") {
+ def getProvidingClass(name: String): Class[_] =
+ DataSource(sparkSession = null, className = name).providingClass
+ assert(getProvidingClass("libxgboost") ===
+ classOf[org.apache.spark.sql.hive.source.XGBoostFileFormat])
+ }
+
+ test("check XGBoost options") {
+ assert(s"$defaultOptions" == "-max_depth 4 -num_round 10")
+ val errMsg = intercept[IllegalArgumentException] {
+ defaultOptions.set("unknown", "3")
+ }
+ assert(errMsg.getMessage == "requirement failed: " +
+ "non-existing key detected in XGBoost options: unknown")
+ }
+
+ test("train_xgboost_regr") {
+ withTempModelDir { tempDir =>
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+
+ // Save built models in persistent storage
+ mllibTrainDf.repartition(numModels)
+ .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}"))
+ .write.format("libxgboost").save(tempDir)
+
+ // Check #models generated by XGBoost
+ assert(countModels(tempDir) == numModels)
+
+ // Load the saved models
+ val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir)
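+ // Cross-join each model with the test data and average the per-model predictions per rowid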
+ val predict = model.join(mllibTestDf)
+ .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model")
+ .groupBy("rowid").avg()
+ .toDF("rowid", "predicted")
+
+ val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER")
+ .select(predict("rowid"), $"predicted", $"label")
+
+ result.select(avg(abs($"predicted" - $"label"))).collect.map {
+ case Row(diff: Double) => assert(diff > 0.0)
+ }
+ }
+ }
+ }
+
+ test("train_xgboost_classifier") {
+ withTempModelDir { tempDir =>
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+
+ mllibTrainDf.repartition(numModels)
+ .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}"))
+ .write.format("libxgboost").save(tempDir)
+
+ // Check #models generated by XGBoost
+ assert(countModels(tempDir) == numModels)
+
+ val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir)
+ val predict = model.join(mllibTestDf)
+ .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model")
+ .groupBy("rowid").avg()
+ .toDF("rowid", "predicted")
+
+ val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER")
+ .select(
+ when($"predicted" >= 0.50, 1).otherwise(0),
+ $"label".cast(IntegerType)
+ )
+ .toDF("predicted", "label")
+
+ assert((result.where($"label" === $"predicted").count + 0.0) / result.count > 0.0)
+ }
+ }
+ }
+
+ test("train_xgboost_multiclass_classifier") {
+ withTempModelDir { tempDir =>
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+
+ mllibTrainDf.repartition(numModels)
+ .train_xgboost_multiclass_classifier(
+ $"features", $"label", lit(s"${defaultOptions.set("num_class", "2")}"))
+ .write.format("libxgboost").save(tempDir)
+
+ // Check #models generated by XGBoost
+ assert(countModels(tempDir) == numModels)
+
+ val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir)
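+ // For each rowid, pick the class label with the highest predicted probability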
+ val predict = model.join(mllibTestDf)
+ .xgboost_multiclass_predict($"rowid", $"features", $"model_id", $"pred_model")
+ .groupBy("rowid").max_label("probability", "label")
+ .toDF("rowid", "predicted")
+
+ val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER")
+ .select(
+ predict("rowid"),
+ $"predicted",
+ $"label".cast(IntegerType)
+ )
+
+ assert((result.where($"label" === $"predicted").count + 0.0) / result.count > 0.0)
+ }
+ }
+ }
+}