You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 08:02:33 UTC
[43/50] [abbrv] incubator-hivemall git commit: Merge branch
'feature/feature_selection' of https://github.com/amaya382/hivemall into
feature_selection
Merge branch 'feature/feature_selection' of
https://github.com/amaya382/hivemall into feature_selection
# Conflicts:
# core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
# core/src/main/java/hivemall/utils/math/StatsUtils.java
# spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
# spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
# spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
# spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/67ba9631
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/67ba9631
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/67ba9631
Branch: refs/heads/JIRA-22/pr-385
Commit: 67ba9631af3c231b7abd145134d17237b6aca0a5
Parents: 69496fa ce4a489
Author: myui <yu...@gmail.com>
Authored: Mon Nov 21 18:19:45 2016 +0900
Committer: myui <yu...@gmail.com>
Committed: Mon Nov 21 18:19:45 2016 +0900
----------------------------------------------------------------------
.../hivemall/ftvec/selection/ChiSquareUDF.java | 155 ++++++++
.../ftvec/selection/SignalNoiseRatioUDAF.java | 349 +++++++++++++++++++
.../hivemall/tools/array/SelectKBestUDF.java | 143 ++++++++
.../tools/matrix/TransposeAndDotUDAF.java | 213 +++++++++++
.../java/hivemall/utils/hadoop/HiveUtils.java | 22 +-
.../java/hivemall/utils/math/StatsUtils.java | 91 +++++
.../ftvec/selection/ChiSquareUDFTest.java | 80 +++++
.../selection/SignalNoiseRatioUDAFTest.java | 348 ++++++++++++++++++
.../tools/array/SelectKBeatUDFTest.java | 65 ++++
.../tools/matrix/TransposeAndDotUDAFTest.java | 58 +++
resources/ddl/define-all-as-permanent.hive | 20 ++
resources/ddl/define-all.hive | 20 ++
resources/ddl/define-all.spark | 20 ++
resources/ddl/define-udfs.td.hql | 4 +
.../apache/spark/sql/hive/GroupedDataEx.scala | 21 ++
.../org/apache/spark/sql/hive/HivemallOps.scala | 18 +
.../spark/sql/hive/HivemallOpsSuite.scala | 100 ++++++
.../spark/sql/hive/HivemallGroupedDataset.scala | 25 ++
.../org/apache/spark/sql/hive/HivemallOps.scala | 20 ++
.../spark/sql/hive/HivemallOpsSuite.scala | 103 ++++++
20 files changed, 1873 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --cc core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index d8b1aef,c752188..8188b7a
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@@ -242,10 -240,16 +242,20 @@@ public final class HiveUtils
return category == Category.LIST;
}
+    /**
+     * Tells whether the given inspector describes a Hive MAP value.
+     *
+     * @param oi inspector to examine
+     * @return true when the inspector's category is {@code Category.MAP}
+     */
+    public static boolean isMapOI(@Nonnull final ObjectInspector oi) {
+        final Category category = oi.getCategory();
+        return category == Category.MAP;
+    }
+
+    /**
+     * Tells whether the given inspector describes a list whose element inspector is
+     * accepted by {@code isNumberOI}.
+     *
+     * @param oi inspector to examine
+     * @return true only when {@code isListOI(oi)} holds and the element inspector passes
+     *         {@code isNumberOI}
+     */
+    public static boolean isNumberListOI(@Nonnull final ObjectInspector oi) {
+        // Guard first: the cast below is only valid for list inspectors.
+        if (!isListOI(oi)) {
+            return false;
+        }
+        final ObjectInspector elemOI = ((ListObjectInspector) oi).getListElementObjectInspector();
+        return isNumberOI(elemOI);
+    }
+
+    /**
+     * Tells whether the given inspector describes a list whose elements are themselves
+     * lists accepted by {@code isNumberListOI}.
+     *
+     * @param oi inspector to examine
+     * @return true only when {@code isListOI(oi)} holds and the element inspector passes
+     *         {@code isNumberListOI}
+     */
+    public static boolean isNumberListListOI(@Nonnull final ObjectInspector oi) {
+        // Guard first: the cast below is only valid for list inspectors.
+        if (!isListOI(oi)) {
+            return false;
+        }
+        final ObjectInspector rowOI = ((ListObjectInspector) oi).getListElementObjectInspector();
+        return isNumberListOI(rowOI);
+    }
+
/**
 * Tells whether the given type info belongs to the PRIMITIVE category.
 *
 * @param typeInfo type descriptor to examine
 * @return true when the category is {@code ObjectInspector.Category.PRIMITIVE}
 */
public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) {
    final ObjectInspector.Category category = typeInfo.getCategory();
    return category == ObjectInspector.Category.PRIMITIVE;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
----------------------------------------------------------------------
diff --cc spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
index fd4da64,2482c62..8f78a7f
--- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
+++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
@@@ -267,13 -266,25 +267,34 @@@ final class GroupedDataEx protected[sql
}
/**
+   * Runs Hivemall's one-hot encoding aggregate over the given columns of each group.
+   *
+   * @param features names of the columns handed to the UDAF, in order
+   * @return the aggregated [[DataFrame]] produced by `toDF`
+   * @see hivemall.ftvec.trans.OnehotEncodingUDAF
+   */
+  def onehot_encoding(features: String*): DataFrame = {
+    val udaf = HiveUDAFFunction(
+        new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"),
+        features.map(df.col(_).expr),
+        isUDAFBridgeRequired = false)
+      // The merge resolution lost the three lines below, leaving this method unclosed
+      // and swallowing the definitions that follow it.
+      .toAggregateExpression()
+    toDF(Seq(Alias(udaf, udaf.prettyString)()))
+  }
++
++  /**
+   * Runs Hivemall's signal-noise-ratio aggregate over two columns of each group.
+   *
+   * @param X name of the column passed as the first UDAF argument
+   * @param Y name of the column passed as the second UDAF argument
+   * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
+   */
+  def snr(X: String, Y: String): DataFrame = {
+    val wrapper = new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF")
+    val args = Seq(X, Y).map(name => df.col(name).expr)
+    val udaf = HiveUDAFFunction(wrapper, args, isUDAFBridgeRequired = false)
+      .toAggregateExpression()
+    val named = Alias(udaf, udaf.prettyString)()
+    toDF(Seq(named))
+  }
+
+  /**
+   * Runs Hivemall's transpose-and-dot aggregate over two columns of each group.
+   *
+   * @param X name of the column passed as the first UDAF argument
+   * @param Y name of the column passed as the second UDAF argument
+   * @see hivemall.tools.matrix.TransposeAndDotUDAF
+   */
+  def transpose_and_dot(X: String, Y: String): DataFrame = {
+    val wrapper = new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF")
+    val args = Seq(X, Y).map(name => df.col(name).expr)
+    val udaf = HiveUDAFFunction(wrapper, args, isUDAFBridgeRequired = false)
       .toAggregateExpression()
     val named = Alias(udaf, udaf.prettyString)()
     toDF(Seq(named))
   }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --cc spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 901056d,c7016c0..c231105
--- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@@ -534,30 -570,63 +575,89 @@@ final class HivemallOpsSuite extends Hi
assert(row4(0).getDouble(1) ~== 0.25)
}
+  test("user-defined aggregators for ftvec.trans") {
+    import hiveContext.implicits._
+
+    val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10),
+        (1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9),
+        (1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9))
+      .toDF("col0", "cat1", "cat2", "cat3")
+
+    val row00 = df0.groupby($"col0").onehot_encoding("cat1")
+    val row01 = df0.groupby($"col0").onehot_encoding("cat1", "cat2", "cat3")
+
+    // Column 1 of the single result row is a struct holding one map per encoded column.
+    val result000 = row00.collect()(0).getAs[Row](1).getAs[Map[String, Int]](0)
+    val result01 = row01.collect()(0).getAs[Row](1)
+    val result010 = result01.getAs[Map[String, Int]](0)
+    val result011 = result01.getAs[Map[String, Int]](1)
+    val result012 = result01.getAs[Map[String, Int]](2)
+
+    // Only the set of assigned indices is checked, not which key gets which index.
+    assert(result000.keySet === Set("seahawk", "cat", "human", "wasp", "dog"))
+    assert(result000.values.toSet === Set(1, 2, 3, 4, 5))
+    assert(result010.keySet === Set("seahawk", "cat", "human", "wasp", "dog"))
+    assert(result010.values.toSet === Set(1, 2, 3, 4, 5))
+    assert(result011.keySet === Set("bird", "insect", "mammal"))
+    assert(result011.values.toSet === Set(6, 7, 8))
+    assert(result012.keySet === Set(1, 3, 9, 10, 101))
+    assert(result012.values.toSet === Set(9, 10, 11, 12, 13))
+  }
++
+  test("user-defined aggregators for ftvec.selection") {
+    import hiveContext.implicits._
+
+    // Compares, element by element, the array in column 1 of the first result row
+    // against the expected signal-noise ratios.
+    def assertSnr(rows: Array[Row], expected: Seq[Double]): Unit = {
+      rows(0).getAs[Seq[Double]](1).zip(expected).foreach {
+        case (actual, wanted) => assert(actual ~== wanted)
+      }
+    }
+
+    // Fixtures follow hivemall.ftvec.selection.SignalNoiseRatioUDAFTest.
+    // Case 1: two classes, labels given as one-hot arrays
+    // +-----------------+-------+
+    // |     features    | class |
+    // +-----------------+-------+
+    // | 5.1,3.5,1.4,0.2 |     0 |
+    // | 4.9,3.0,1.4,0.2 |     0 |
+    // | 4.7,3.2,1.3,0.2 |     0 |
+    // | 7.0,3.2,4.7,1.4 |     1 |
+    // | 6.4,3.2,4.5,1.5 |     1 |
+    // | 6.9,3.1,4.9,1.5 |     1 |
+    // +-----------------+-------+
+    val binary = Seq(
+        (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)),
+        (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)),
+        (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1)))
+      .toDF("c0", "arg0", "arg1")
+    assertSnr(
+      binary.groupby($"c0").snr("arg0", "arg1").collect,
+      Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
+
+    // Case 2: three classes
+    // +-----------------+-------+
+    // |     features    | class |
+    // +-----------------+-------+
+    // | 5.1,3.5,1.4,0.2 |     0 |
+    // | 4.9,3.0,1.4,0.2 |     0 |
+    // | 7.0,3.2,4.7,1.4 |     1 |
+    // | 6.4,3.2,4.5,1.5 |     1 |
+    // | 6.3,3.3,6.0,2.5 |     2 |
+    // | 5.8,2.7,5.1,1.9 |     2 |
+    // +-----------------+-------+
+    val multi = Seq(
+        (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
+        (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
+        (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
+      .toDF("c0", "arg0", "arg1")
+    assertSnr(
+      multi.groupby($"c0").snr("arg0", "arg1").collect,
+      Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
+  }
+
+  test("user-defined aggregators for tools.matrix") {
+    import hiveContext.implicits._
+
+    // Inputs, one matrix row per DataFrame row:
+    // | 1 2 3 |T    | 5 6 7 |
+    // | 3 4 5 |  *  | 7 8 9 |
+    val input = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
+      .toDF("c0", "arg0", "arg1")
+
+    val expected =
+      Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))
+
+    // checkAnswer fails here for an undetermined reason (possibly a type mismatch), although
+    // it works in the spark-2.0 suite, so the collected rows are compared directly instead.
+    val actual = input.groupby($"c0").transpose_and_dot("arg0", "arg1").collect()
+    assert(actual === expected)
   }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
----------------------------------------------------------------------
diff --cc spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
index 8ac7185,0000000..73757f6
mode 100644,000000..100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
@@@ -1,277 -1,0 +1,302 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.RelationalGroupedDataset
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.plans.logical.Pivot
+import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
+import org.apache.spark.sql.types._
+
+/**
+ * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
+ *
+ * Every public method wraps one Hivemall Hive UDAF and applies it to the grouped data.
+ *
+ * @groupname ensemble
+ * @groupname ftvec.trans
+ * @groupname ftvec.selection
+ * @groupname tools.matrix
+ * @groupname evaluation
+ */
+final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) {
+
+  /**
+   * @see hivemall.ensemble.bagging.VotedAvgUDAF
+   * @group ensemble
+   */
+  def voted_avg(weight: String): DataFrame = {
+    // checkType(weight, NumericType)
+    // This used to wrap WeightVotedAvgUDAF (copy-paste from weight_voted_avg).
+    callUDAF(
+      "voted_avg",
+      "hivemall.ensemble.bagging.VotedAvgUDAF",
+      Seq(weight),
+      isUDAFBridgeRequired = true)
+  }
+
+  /**
+   * @see hivemall.ensemble.bagging.WeightVotedAvgUDAF
+   * @group ensemble
+   */
+  def weight_voted_avg(weight: String): DataFrame = {
+    // checkType(weight, NumericType)
+    callUDAF(
+      "weight_voted_avg",
+      "hivemall.ensemble.bagging.WeightVotedAvgUDAF",
+      Seq(weight),
+      isUDAFBridgeRequired = true)
+  }
+
+  /**
+   * @see hivemall.ensemble.ArgminKLDistanceUDAF
+   * @group ensemble
+   */
+  def argmin_kld(weight: String, conv: String): DataFrame = {
+    // checkType(weight, NumericType)
+    // checkType(conv, NumericType)
+    callUDAF(
+      "argmin_kld",
+      "hivemall.ensemble.ArgminKLDistanceUDAF",
+      Seq(weight, conv),
+      isUDAFBridgeRequired = true)
+  }
+
+  /**
+   * @see hivemall.ensemble.MaxValueLabelUDAF
+   * @group ensemble
+   */
+  def max_label(score: String, label: String): DataFrame = {
+    // checkType(score, NumericType)
+    checkType(label, StringType)
+    callUDAF(
+      "max_label",
+      "hivemall.ensemble.MaxValueLabelUDAF",
+      Seq(score, label),
+      isUDAFBridgeRequired = true)
+  }
+
+  /**
+   * @see hivemall.ensemble.MaxRowUDAF
+   * @group ensemble
+   */
+  def maxrow(score: String, label: String): DataFrame = {
+    // checkType(score, NumericType)
+    checkType(label, StringType)
+    callUDAF(
+      "maxrow",
+      "hivemall.ensemble.MaxRowUDAF",
+      Seq(score, label),
+      isUDAFBridgeRequired = false)
+  }
+
+  /**
+   * @see hivemall.smile.tools.RandomForestEnsembleUDAF
+   * @group ensemble
+   */
+  def rf_ensemble(predict: String): DataFrame = {
+    // checkType(predict, NumericType)
+    callUDAF(
+      "rf_ensemble",
+      "hivemall.smile.tools.RandomForestEnsembleUDAF",
+      Seq(predict),
+      isUDAFBridgeRequired = true)
+  }
+
+  /**
+   * @see hivemall.ftvec.trans.OnehotEncodingUDAF
+   * @group ftvec.trans
+   */
+  def onehot_encoding(cols: String*): DataFrame = {
+    callUDAF(
+      "onehot_encoding",
+      "hivemall.ftvec.trans.OnehotEncodingUDAF",
+      cols,
+      isUDAFBridgeRequired = false)
+  }
+
++  /**
++   * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
++   * @group ftvec.selection
++   */
++  def snr(X: String, Y: String): DataFrame = {
++    callUDAF(
++      "snr",
++      "hivemall.ftvec.selection.SignalNoiseRatioUDAF",
++      Seq(X, Y),
++      isUDAFBridgeRequired = false)
++  }
++
++  /**
++   * @see hivemall.tools.matrix.TransposeAndDotUDAF
++   * @group tools.matrix
++   */
++  def transpose_and_dot(X: String, Y: String): DataFrame = {
++    callUDAF(
++      "transpose_and_dot",
++      "hivemall.tools.matrix.TransposeAndDotUDAF",
++      Seq(X, Y),
++      isUDAFBridgeRequired = false)
++  }
++
+  /**
+   * @see hivemall.evaluation.MeanAbsoluteErrorUDAF
+   * @group evaluation
+   */
+  def mae(predict: String, target: String): DataFrame = {
+    checkType(predict, FloatType)
+    checkType(target, FloatType)
+    callUDAF(
+      "mae",
+      "hivemall.evaluation.MeanAbsoluteErrorUDAF",
+      Seq(predict, target),
+      isUDAFBridgeRequired = true)
+  }
+
+  /**
+   * @see hivemall.evaluation.MeanSquaredErrorUDAF
+   * @group evaluation
+   */
+  def mse(predict: String, target: String): DataFrame = {
+    checkType(predict, FloatType)
+    checkType(target, FloatType)
+    callUDAF(
+      "mse",
+      "hivemall.evaluation.MeanSquaredErrorUDAF",
+      Seq(predict, target),
+      isUDAFBridgeRequired = true)
+  }
+
+  /**
+   * @see hivemall.evaluation.RootMeanSquaredErrorUDAF
+   * @group evaluation
+   */
+  def rmse(predict: String, target: String): DataFrame = {
+    checkType(predict, FloatType)
+    checkType(target, FloatType)
+    callUDAF(
+      "rmse",
+      "hivemall.evaluation.RootMeanSquaredErrorUDAF",
+      Seq(predict, target),
+      isUDAFBridgeRequired = true)
+  }
+
+  /**
+   * @see hivemall.evaluation.FMeasureUDAF
+   * @group evaluation
+   */
+  def f1score(predict: String, target: String): DataFrame = {
+    // checkType(target, ArrayType(IntegerType))
+    // checkType(predict, ArrayType(IntegerType))
+    callUDAF(
+      "f1score",
+      "hivemall.evaluation.FMeasureUDAF",
+      Seq(predict, target),
+      isUDAFBridgeRequired = true)
+  }
+
+  /**
+   * [[RelationalGroupedDataset]] has the three values as private fields, so, to inject Hivemall
+   * aggregate functions, we fetch them via Java Reflections.
+   */
+  private val df = getPrivateField[DataFrame]("org$apache$spark$sql$RelationalGroupedDataset$$df")
+  private val groupingExprs = getPrivateField[Seq[Expression]]("groupingExprs")
+  private val groupType = getPrivateField[RelationalGroupedDataset.GroupType]("groupType")
+
+  /** Reads the declared field `name` of the wrapped `groupBy` object through reflection. */
+  private def getPrivateField[T](name: String): T = {
+    val field = groupBy.getClass.getDeclaredField(name)
+    field.setAccessible(true)
+    field.get(groupBy).asInstanceOf[T]
+  }
+
+  /**
+   * Builds the Hive UDAF `className` over the named columns of `df` and aggregates the grouped
+   * data with it; the result column is aliased with the aggregate's pretty name.
+   *
+   * @param name function name handed to [[HiveUDAFFunction]]
+   * @param className fully-qualified class name of the Hive UDAF to wrap
+   * @param cols names of the input columns, in argument order
+   * @param isUDAFBridgeRequired forwarded unchanged to [[HiveUDAFFunction]]
+   */
+  private def callUDAF(
+      name: String,
+      className: String,
+      cols: Seq[String],
+      isUDAFBridgeRequired: Boolean): DataFrame = {
+    val udaf = HiveUDAFFunction(
+        name,
+        new HiveFunctionWrapper(className),
+        cols.map(df.col(_).expr),
+        isUDAFBridgeRequired = isUDAFBridgeRequired)
+      .toAggregateExpression()
+    toDF(Seq(Alias(udaf, udaf.prettyName)()))
+  }
+
+  /**
+   * Turns the aggregate expressions into a [[DataFrame]] according to the grouping type
+   * (group-by, rollup, cube or pivot) of the wrapped [[RelationalGroupedDataset]].
+   */
+  private def toDF(aggExprs: Seq[Expression]): DataFrame = {
+    // Grouping columns are kept in the output only when the session is configured to do so.
+    val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      groupingExprs ++ aggExprs
+    } else {
+      aggExprs
+    }
+
+    val aliasedAgg = aggregates.map(alias)
+
+    groupType match {
+      case RelationalGroupedDataset.GroupByType =>
+        Dataset.ofRows(
+          df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
+      case RelationalGroupedDataset.RollupType =>
+        Dataset.ofRows(
+          df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
+      case RelationalGroupedDataset.CubeType =>
+        Dataset.ofRows(
+          df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
+      case RelationalGroupedDataset.PivotType(pivotCol, values) =>
+        val aliasedGrps = groupingExprs.map(alias)
+        Dataset.ofRows(
+          df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
+    }
+  }
+
+  /** Makes sure the expression is named, aliasing it with its pretty name when it is not. */
+  private def alias(expr: Expression): NamedExpression = expr match {
+    case u: UnresolvedAttribute => UnresolvedAlias(u)
+    case expr: NamedExpression => expr
+    case expr: Expression => Alias(expr, expr.prettyName)()
+  }
+
+  /**
+   * Fails with an [[AnalysisException]] unless column `colName` of `df` has exactly the
+   * `expected` data type.
+   */
+  private def checkType(colName: String, expected: DataType): Unit = {
+    val dataType = df.resolve(colName).dataType
+    if (dataType != expected) {
+      throw new AnalysisException(
+        s""""$colName" must be $expected, however it is $dataType""")
+    }
+  }
+}
+
+object HivemallGroupedDataset {
+
+  /**
+   * Implicitly inject the [[HivemallGroupedDataset]] into [[RelationalGroupedDataset]].
+   */
+  implicit def relationalGroupedDatasetToHivemallOne(
+      groupBy: RelationalGroupedDataset): HivemallGroupedDataset = {
+    new HivemallGroupedDataset(groupBy)
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --cc spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index a093e07,8446677..8bea975
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@@ -1,31 -1,28 +1,37 @@@
/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
-
package org.apache.spark.sql.hive
+import org.apache.spark.sql.{AnalysisException, Column, Row}
+import org.apache.spark.sql.functions
+import org.apache.spark.sql.hive.HivemallGroupedDataset._
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.HivemallUtils._
+import org.apache.spark.sql.types._
+import org.apache.spark.test.{HivemallFeatureQueryTest, TestUtils, VectorQueryTest}
+import org.apache.spark.test.TestDoubleWrapper._
+ import org.apache.spark.sql.hive.HivemallOps._
+ import org.apache.spark.sql.hive.HivemallUtils._
+ import org.apache.spark.sql.types._
+ import org.apache.spark.sql.{AnalysisException, Column, Row, functions}
+ import org.apache.spark.test.TestDoubleWrapper._
+ import org.apache.spark.test.{HivemallFeatureQueryTest, TestUtils, VectorQueryTest}
final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
@@@ -636,30 -685,63 +681,88 @@@
assert(row4(0).getDouble(1) ~== 0.25)
}
+  test("user-defined aggregators for ftvec.trans") {
+    import hiveContext.implicits._
+
+    val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10),
+        (1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9),
+        (1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9))
+      .toDF("col0", "cat1", "cat2", "cat3")
+    val row00 = df0.groupBy($"col0").onehot_encoding("cat1")
+    val row01 = df0.groupBy($"col0").onehot_encoding("cat1", "cat2", "cat3")
+
+    // Column 1 of the single result row is a struct holding one map per encoded column.
+    val result000 = row00.collect()(0).getAs[Row](1).getAs[Map[String, Int]](0)
+    val result01 = row01.collect()(0).getAs[Row](1)
+    val result010 = result01.getAs[Map[String, Int]](0)
+    val result011 = result01.getAs[Map[String, Int]](1)
+    val result012 = result01.getAs[Map[String, Int]](2)
+
+    // Only the set of assigned indices is checked, not which key gets which index.
+    assert(result000.keySet === Set("seahawk", "cat", "human", "wasp", "dog"))
+    assert(result000.values.toSet === Set(1, 2, 3, 4, 5))
+    assert(result010.keySet === Set("seahawk", "cat", "human", "wasp", "dog"))
+    assert(result010.values.toSet === Set(1, 2, 3, 4, 5))
+    assert(result011.keySet === Set("bird", "insect", "mammal"))
+    assert(result011.values.toSet === Set(6, 7, 8))
+    assert(result012.keySet === Set(1, 3, 9, 10, 101))
+    assert(result012.values.toSet === Set(9, 10, 11, 12, 13))
+  }
++
+  test("user-defined aggregators for ftvec.selection") {
+    import hiveContext.implicits._
+
+    // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest
+    // binary class
+    // +-----------------+-------+
+    // |     features    | class |
+    // +-----------------+-------+
+    // | 5.1,3.5,1.4,0.2 |     0 |
+    // | 4.9,3.0,1.4,0.2 |     0 |
+    // | 4.7,3.2,1.3,0.2 |     0 |
+    // | 7.0,3.2,4.7,1.4 |     1 |
+    // | 6.4,3.2,4.5,1.5 |     1 |
+    // | 6.9,3.1,4.9,1.5 |     1 |
+    // +-----------------+-------+
+    val df0 = Seq(
+        (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)),
+        (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)),
+        (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1)))
+      .toDF("c0", "arg0", "arg1")
+    // Use Dataset.groupBy so the implicit HivemallGroupedDataset conversion applies, as in
+    // the ftvec.trans test of this suite; `groupby` is the spark-1.6 style entry point.
+    val row0 = df0.groupBy($"c0").snr("arg0", "arg1").collect
+    (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
+      .zipped
+      .foreach((actual, expected) => assert(actual ~== expected))
+
+    // multiple class
+    // +-----------------+-------+
+    // |     features    | class |
+    // +-----------------+-------+
+    // | 5.1,3.5,1.4,0.2 |     0 |
+    // | 4.9,3.0,1.4,0.2 |     0 |
+    // | 7.0,3.2,4.7,1.4 |     1 |
+    // | 6.4,3.2,4.5,1.5 |     1 |
+    // | 6.3,3.3,6.0,2.5 |     2 |
+    // | 5.8,2.7,5.1,1.9 |     2 |
+    // +-----------------+-------+
+    val df1 = Seq(
+        (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
+        (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
+        (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
+      .toDF("c0", "arg0", "arg1")
+    val row1 = df1.groupBy($"c0").snr("arg0", "arg1").collect
+    (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
+      .zipped
+      .foreach((actual, expected) => assert(actual ~== expected))
+  }
+
+  test("user-defined aggregators for tools.matrix") {
+    import hiveContext.implicits._
+
+    // Inputs, one matrix row per DataFrame row:
+    // | 1 2 3 |T    | 5 6 7 |
+    // | 3 4 5 |  *  | 7 8 9 |
+    val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
+      .toDF("c0", "arg0", "arg1")
+
+    // Use Dataset.groupBy so the implicit HivemallGroupedDataset conversion applies;
+    // `groupby` is the spark-1.6 style entry point.
+    checkAnswer(df0.groupBy($"c0").transpose_and_dot("arg0", "arg1"),
+      Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))))
   }
}