Posted to commits@hivemall.apache.org by my...@apache.org on 2019/10/04 05:28:56 UTC
[incubator-hivemall] branch master updated: [HIVEMALL-267] Drop Spark DataFrame support (SparkSQL remains supported)
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new ff3693d [HIVEMALL-267] Drop Spark DataFrame support (SparkSQL remains supported)
ff3693d is described below
commit ff3693d122b6e681b793985f06076a2c56561619
Author: Makoto Yui <my...@apache.org>
AuthorDate: Fri Oct 4 14:28:49 2019 +0900
[HIVEMALL-267] Drop Spark DataFrame support (SparkSQL remains supported)
## What changes were proposed in this pull request?
Drop Spark DataFrame support (SparkSQL remains supported).
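For reference, the SparkSQL route that this change keeps looks roughly like the sketch below (not part of this commit). The `training` table and its `features`/`label` columns are hypothetical; the sketch assumes spark-shell was launched with the hivemall-all jar and that `resources/ddl/define-all.spark` has been loaded, as described in the updated installation guide further down in this patch.

```scala
// Rough sketch of the SparkSQL path that remains supported (assumptions noted above).
// Start the shell with the Hivemall jar, e.g.:
//   $ spark-shell --jars target/hivemall-all-<version>-incubating-SNAPSHOT.jar
// then register the Hivemall functions as temporary functions:
//   scala> :load resources/ddl/define-all.spark

// Train a logistic regression model purely through SQL; the per-task weights
// emitted by the train_logregr UDTF are averaged per feature to form the model.
val modelDf = spark.sql("""
  SELECT feature, AVG(weight) AS weight
  FROM (
    SELECT train_logregr(add_bias(features), label) AS (feature, weight)
    FROM training
  ) t
  GROUP BY feature
""")
modelDf.show()
```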
## What type of PR is it?
Hot Fix, Refactoring
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-267
## How was this patch tested?
unit tests, manual tests
## Checklist
(Please remove this section if not needed; check `x` for YES, blank for NO)
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [ ] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <my...@apache.org>
Closes #201 from myui/HIVEMALL-267.
---
.gitignore | 5 -
bin/format_header.sh | 5 -
bin/run_travis_tests.sh | 10 +-
bin/spark-shell | 137 --
conf/spark-defaults.conf | 31 -
docs/gitbook/SUMMARY.md | 8 +-
docs/gitbook/spark/binaryclass/a9a_df.md | 105 -
docs/gitbook/spark/getting_started/README.md | 20 -
docs/gitbook/spark/getting_started/installation.md | 31 +-
docs/gitbook/spark/misc/functions.md | 116 -
docs/gitbook/spark/misc/misc.md | 18 -
docs/gitbook/spark/misc/topk_join.md | 216 --
docs/gitbook/spark/regression/e2006_df.md | 109 -
pom.xml | 1 -
resources/ddl/import-packages.spark | 28 -
spark/common/pom.xml | 64 -
...LogisticRegressionDataGeneratorUDTFWrapper.java | 109 -
.../java/hivemall/ftvec/AddBiasUDFWrapper.java | 84 -
.../hivemall/ftvec/AddFeatureIndexUDFWrapper.java | 85 -
.../hivemall/ftvec/ExtractFeatureUDFWrapper.java | 73 -
.../hivemall/ftvec/ExtractWeightUDFWrapper.java | 74 -
.../hivemall/ftvec/SortByFeatureUDFWrapper.java | 95 -
.../ftvec/scaling/L2NormalizationUDFWrapper.java | 97 -
.../java/hivemall/knn/lsh/MinHashesUDFWrapper.java | 94 -
.../hivemall/tools/mapred/RowIdUDFWrapper.java | 71 -
.../main/scala/hivemall/HivemallException.scala | 25 -
.../spark/ml/feature/HivemallLabeledPoint.scala | 82 -
spark/pom.xml | 311 ---
spark/scalastyle-config.xml | 333 ---
spark/spark-2.2/bin/mvn-zinc | 99 -
spark/spark-2.2/extra-src/README.md | 20 -
.../scala/org/apache/spark/sql/hive/HiveShim.scala | 279 ---
spark/spark-2.2/pom.xml | 142 --
.../java/hivemall/xgboost/XGBoostOptions.scala | 59 -
...org.apache.spark.sql.sources.DataSourceRegister | 1 -
.../spark-2.2/src/main/resources/log4j.properties | 29 -
.../scala/hivemall/tools/RegressionDatagen.scala | 67 -
.../spark/sql/catalyst/expressions/EachTopK.scala | 135 --
.../sql/catalyst/plans/logical/JoinTopK.scala | 68 -
.../catalyst/utils/InternalRowPriorityQueue.scala | 76 -
.../spark/sql/execution/UserProvidedPlanner.scala | 83 -
.../execution/datasources/csv/csvExpressions.scala | 169 --
.../execution/joins/ShuffledHashJoinTopKExec.scala | 405 ----
.../spark/sql/hive/HivemallGroupedDataset.scala | 636 ------
.../org/apache/spark/sql/hive/HivemallOps.scala | 2260 --------------------
.../org/apache/spark/sql/hive/HivemallUtils.scala | 146 --
.../spark/sql/hive/internal/HivemallOpsImpl.scala | 79 -
.../spark/sql/hive/source/XGBoostFileFormat.scala | 163 --
.../spark/streaming/HivemallStreamingOps.scala | 47 -
.../src/test/resources/data/files/README.md | 22 -
.../src/test/resources/data/files/complex.seq | 0
.../src/test/resources/data/files/episodes.avro | 0
.../src/test/resources/data/files/json.txt | 0
.../src/test/resources/data/files/kv1.txt | 0
.../src/test/resources/data/files/kv3.txt | 0
.../spark-2.2/src/test/resources/log4j.properties | 24 -
.../scala/hivemall/mix/server/MixServerSuite.scala | 124 --
.../hivemall/tools/RegressionDatagenSuite.scala | 33 -
.../scala/org/apache/spark/SparkFunSuite.scala | 51 -
.../ml/feature/HivemallLabeledPointSuite.scala | 36 -
.../scala/org/apache/spark/sql/QueryTest.scala | 360 ----
.../apache/spark/sql/catalyst/plans/PlanTest.scala | 137 --
.../sql/execution/benchmark/BenchmarkBase.scala | 56 -
.../org/apache/spark/sql/hive/HiveUdfSuite.scala | 161 --
.../apache/spark/sql/hive/HivemallOpsSuite.scala | 1397 ------------
.../apache/spark/sql/hive/ModelMixingSuite.scala | 286 ---
.../org/apache/spark/sql/hive/XGBoostSuite.scala | 151 --
.../spark/sql/hive/benchmark/MiscBenchmark.scala | 268 ---
.../sql/hive/test/HivemallFeatureQueryTest.scala | 102 -
.../spark/sql/hive/test/TestHiveSingleton.scala | 39 -
.../org/apache/spark/sql/test/SQLTestData.scala | 315 ---
.../org/apache/spark/sql/test/SQLTestUtils.scala | 336 ---
.../apache/spark/sql/test/VectorQueryTest.scala | 89 -
.../streaming/HivemallOpsWithFeatureSuite.scala | 155 --
.../scala/org/apache/spark/test/TestUtils.scala | 65 -
spark/spark-2.3/bin/mvn-zinc | 99 -
spark/spark-2.3/extra-src/README.md | 20 -
.../scala/org/apache/spark/sql/hive/HiveShim.scala | 279 ---
spark/spark-2.3/pom.xml | 190 --
.../java/hivemall/xgboost/XGBoostOptions.scala | 59 -
...org.apache.spark.sql.sources.DataSourceRegister | 1 -
.../spark-2.3/src/main/resources/log4j.properties | 29 -
.../scala/hivemall/tools/RegressionDatagen.scala | 67 -
.../spark/sql/catalyst/expressions/EachTopK.scala | 135 --
.../sql/catalyst/plans/logical/JoinTopK.scala | 68 -
.../catalyst/utils/InternalRowPriorityQueue.scala | 76 -
.../spark/sql/execution/UserProvidedPlanner.scala | 83 -
.../execution/datasources/csv/csvExpressions.scala | 169 --
.../execution/joins/ShuffledHashJoinTopKExec.scala | 402 ----
.../spark/sql/hive/HivemallGroupedDataset.scala | 636 ------
.../org/apache/spark/sql/hive/HivemallOps.scala | 2260 --------------------
.../org/apache/spark/sql/hive/HivemallUtils.scala | 146 --
.../spark/sql/hive/internal/HivemallOpsImpl.scala | 79 -
.../spark/sql/hive/source/XGBoostFileFormat.scala | 165 --
.../spark/streaming/HivemallStreamingOps.scala | 47 -
.../src/test/resources/data/files/README.md | 22 -
.../src/test/resources/data/files/complex.seq | 0
.../src/test/resources/data/files/episodes.avro | 0
.../src/test/resources/data/files/json.txt | 0
.../src/test/resources/data/files/kv1.txt | 0
.../src/test/resources/data/files/kv3.txt | 0
.../spark-2.3/src/test/resources/log4j.properties | 24 -
.../scala/hivemall/mix/server/MixServerSuite.scala | 124 --
.../hivemall/tools/RegressionDatagenSuite.scala | 33 -
.../ml/feature/HivemallLabeledPointSuite.scala | 36 -
.../benchmark/BenchmarkBaseAccessor.scala | 23 -
.../org/apache/spark/sql/hive/HiveUdfSuite.scala | 161 --
.../apache/spark/sql/hive/HivemallOpsSuite.scala | 1398 ------------
.../apache/spark/sql/hive/ModelMixingSuite.scala | 286 ---
.../org/apache/spark/sql/hive/XGBoostSuite.scala | 154 --
.../spark/sql/hive/benchmark/MiscBenchmark.scala | 268 ---
.../sql/hive/test/HivemallFeatureQueryTest.scala | 102 -
.../apache/spark/sql/test/VectorQueryTest.scala | 89 -
.../streaming/HivemallOpsWithFeatureSuite.scala | 155 --
.../scala/org/apache/spark/test/TestUtils.scala | 65 -
src/site/markdown/overview.md | 2 +-
116 files changed, 16 insertions(+), 19543 deletions(-)
diff --git a/.gitignore b/.gitignore
index 3ba5593..0742ad9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,13 +10,8 @@ logs
*.iml
.DS_Store
*~
-bin/scala*
-bin/spark-*-bin-*
bin/apache-maven-*
-scalastyle-output.xml
-scalastyle.txt
derby.log
-spark/bin/zinc-*
*.dylib
*.so
.classpath
diff --git a/bin/format_header.sh b/bin/format_header.sh
index f2c063b..f14207a 100755
--- a/bin/format_header.sh
+++ b/bin/format_header.sh
@@ -34,8 +34,3 @@ HIVEMALL_HOME=`pwd`
mvn license:format
-cd $HIVEMALL_HOME/spark/spark-common
-mvn license:format -P spark-2.0
-
-cd $HIVEMALL_HOME/spark/spark-2.0
-mvn license:format -P spark-2.0
diff --git a/bin/run_travis_tests.sh b/bin/run_travis_tests.sh
index f5b5da6..3f0b090 100755
--- a/bin/run_travis_tests.sh
+++ b/bin/run_travis_tests.sh
@@ -31,15 +31,7 @@ fi
set -ev
-cd $HIVEMALL_HOME/spark
-
-export MAVEN_OPTS="-XX:MaxMetaspaceSize=256m"
-
-# spark-2.2 runs on Java 8+
-if [ ! -z "$(java -version 2>&1 | grep 1.8)" ]; then
- mvn -q scalastyle:check clean -Djava.source.version=1.8 -Djava.target.version=1.8 \
- -pl spark-2.2,spark-2.3 -am test
-fi
+mvn clean test
exit 0
diff --git a/bin/spark-shell b/bin/spark-shell
deleted file mode 100755
index d19c67b..0000000
--- a/bin/spark-shell
+++ /dev/null
@@ -1,137 +0,0 @@
-#!/usr/bin/env bash
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Determine the current working directory
-_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
-# Preserve the calling directory
-_CALLING_DIR="$(pwd)"
-
-# Download any application given a URL
-## Arg1 - Remote URL
-## Arg2 - Local file name
-download_app() {
- local remote_url="$1"
- local local_name="$2"
-
- # setup `curl` and `wget` options
- local curl_opts="--progress-bar -L"
- local wget_opts="--progress=bar:force"
-
- # check if we already have the given application
- # check if we have curl installed
- # download application
- [ ! -f "${local_name}" ] && [ $(command -v curl) ] && \
- echo "exec: curl ${curl_opts} ${remote_url}" 1>&2 && \
- curl ${curl_opts} "${remote_url}" > "${local_name}"
- # if the file still doesn't exist, lets try `wget` and cross our fingers
- [ ! -f "${local_name}" ] && [ $(command -v wget) ] && \
- echo "exec: wget ${wget_opts} ${remote_url}" 1>&2 && \
- wget ${wget_opts} -O "${local_name}" "${remote_url}"
- # if both were unsuccessful, exit
- [ ! -f "${local_name}" ] && \
- echo -n "ERROR: Cannot download $2 with cURL or wget; " && \
- echo "please install manually and try again." && \
- exit 2
-}
-
-# Installs any application tarball given a URL, the expected tarball name,
-# and, optionally, a checkable binary path to determine if the binary has
-# already been installed
-## Arg1 - URL
-## Arg2 - Tarball Name
-## Arg3 - Checkable Binary
-install_app() {
- local remote_tarball="$1/$2"
- local local_tarball="${_DIR}/$2"
- local binary="${_DIR}/$3"
-
- if [ -z "$3" -o ! -f "$binary" ]; then
- download_app "${remote_tarball}" "${local_tarball}"
- cd "${_DIR}" && tar -xzf "$2"
- rm -rf "$local_tarball"
- fi
-}
-
-# Determine the Spark version from the root pom.xml file and
-# install Spark under the bin/ folder if needed.
-install_spark() {
- local SPARK_VERSION=`grep "<spark.version>" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'`
- local HADOOP_VERSION=`grep "<hadoop.version>" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}' | cut -d '.' -f1-2`
- local SPARK_DIR="${_DIR}/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}"
- local APACHE_MIRROR=${APACHE_MIRROR:-'http://d3kbcqa49mib13.cloudfront.net'}
-
- install_app \
- "${APACHE_MIRROR}" \
- "spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz" \
- "spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/bin/spark-shell"
-
- SPARK_BIN="${SPARK_DIR}/bin/spark-shell"
-}
-
-# Determine the Maven version from the root pom.xml file and
-# install maven under the build/ folder if needed.
-install_mvn() {
- local MVN_VERSION="3.3.9"
- MVN_BIN="$(command -v mvn)"
- if [ "$MVN_BIN" ]; then
- local MVN_DETECTED_VERSION="$(mvn --version | head -n1 | awk '{print $3}')"
- fi
- # See simple version normalization: http://stackoverflow.com/questions/16989598/bash-comparing-version-numbers
- function version { echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }'; }
- if [ $(version $MVN_DETECTED_VERSION) -lt $(version $MVN_VERSION) ]; then
- local APACHE_MIRROR=${APACHE_MIRROR:-'https://www.apache.org/dyn/closer.lua?action=download&filename='}
-
- install_app \
- "${APACHE_MIRROR}/maven/maven-3/${MVN_VERSION}/binaries" \
- "apache-maven-${MVN_VERSION}-bin.tar.gz" \
- "apache-maven-${MVN_VERSION}/bin/mvn"
-
- MVN_BIN="${_DIR}/apache-maven-${MVN_VERSION}/bin/mvn"
- fi
-}
-
-# Compile hivemall for the latest Spark release
-compile_hivemall() {
- local HIVEMALL_VERSION=`grep "<version>" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'`
- local SCALA_VERSION=`grep "<scala.binary.version>" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'`
- local SPARK_VERSION=`grep "<spark.binary.version>" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'`
-
- HIVEMALL_BIN="${_DIR}/../target/hivemall-spark-${SPARK_VERSION}_${SCALA_VERSION}-${HIVEMALL_VERSION}-with-dependencies.jar"
- if [ ! -f "${HIVEMALL_BIN}" ]; then
- install_mvn && ${MVN_BIN} validate && ${MVN_BIN} -f "${_DIR}/../pom.xml" clean package -P"spark-${SPARK_VERSION}" -DskipTests
- if [ $? = 127 ]; then
- echo "Failed to compile hivemall for spark-${SPARK_VERSION}"
- exit 1
- fi
- fi
-}
-
-# Install the proper version of Spark for launching spark-shell
-install_spark
-
-# Compile hivemall for the Spark version
-compile_hivemall
-
-# Reset the current working directory
-cd "${_CALLING_DIR}"
-
-echo "Using \`spark-shell\` from path: $SPARK_BIN" 1>&2
-
-# Last, call the `spark-shell` command as usual
-${SPARK_BIN} --properties-file ${_DIR}/../conf/spark-defaults.conf "$@"
-
diff --git a/conf/spark-defaults.conf b/conf/spark-defaults.conf
deleted file mode 100644
index 52a43fb..0000000
--- a/conf/spark-defaults.conf
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Default system properties included when running spark-submit.
-# This is useful for setting default environmental settings.
-
-# Example:
-# spark.master spark://master:7077
-# spark.eventLog.enabled true
-# spark.eventLog.dir hdfs://namenode:8021/directory
-# spark.serializer org.apache.spark.serializer.KryoSerializer
-# spark.driver.memory 5g
-# spark.executor.extraJavaOptions -XX:+PrintGCDetails -Dkey=value -Dnumbers="one two three"
-
-# We assume that the latest Spark loads this configuration via ./bin/spark-shell
-spark.jars ./target/hivemall-spark-2.1_2.11-0.4.2-rc.2-with-dependencies.jar
-
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 02fc97e..c50e78e 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -185,23 +185,17 @@
* [Lat/Lon functions](geospatial/latlon.md)
-## Part XIII - Hivemall on Spark
+## Part XIII - Hivemall on SparkSQL
* [Getting Started](spark/getting_started/README.md)
* [Installation](spark/getting_started/installation.md)
* [Binary Classification](spark/binaryclass/index.md)
- * [a9a Tutorial for DataFrame](spark/binaryclass/a9a_df.md)
* [a9a Tutorial for SQL](spark/binaryclass/a9a_sql.md)
* [Regression](spark/binaryclass/index.md)
- * [E2006-tfidf Regression Tutorial for DataFrame](spark/regression/e2006_df.md)
* [E2006-tfidf Regression Tutorial for SQL](spark/regression/e2006_sql.md)
-* [Generic Features](spark/misc/misc.md)
- * [Top-k Join Processing](spark/misc/topk_join.md)
- * [Other Utility Functions](spark/misc/functions.md)
-
## Part XIV - Hivemall on Docker
* [Getting Started](docker/getting_started.md)
diff --git a/docs/gitbook/spark/binaryclass/a9a_df.md b/docs/gitbook/spark/binaryclass/a9a_df.md
deleted file mode 100644
index b9cb68b..0000000
--- a/docs/gitbook/spark/binaryclass/a9a_df.md
+++ /dev/null
@@ -1,105 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-
-a9a
-===
-https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#a9a
-
-Data preparation
-================
-
-```sh
-$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a
-$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a.t
-```
-
-```scala
-scala> :paste
-val rawTrainDf = spark.read.format("libsvm").load("a9a")
-
-val (max, min) = rawTrainDf.select(max($"label"), min($"label")).collect.map {
- case Row(max: Double, min: Double) => (max, min)
-}
-
-val trainDf = rawTrainDf.select(
- // `label` must be [0.0, 1.0]
- rescale($"label", lit(min), lit(max)).as("label"),
- $"features"
- )
-
-scala> trainDf.printSchema
-root
- |-- label: float (nullable = true)
- |-- features: vector (nullable = true)
-
-scala> :paste
-val testDf = spark.read.format("libsvm").load("a9a.t")
- .select(rowid(), rescale($"label", lit(min), lit(max)).as("label"), $"features")
- .explode_vector($"features")
- .select($"rowid", $"label".as("target"), $"feature", $"weight".as("value"))
- .cache
-
-scala> testDf.printSchema
-root
- |-- rowid: string (nullable = true)
- |-- target: float (nullable = true)
- |-- feature: string (nullable = true)
- |-- value: double (nullable = true)
-```
-
-Tutorials
-================
-
-[Logistic Regression]
----
-
-#Training
-
-```scala
-scala> :paste
-val modelDf = trainDf
- .train_logregr(append_bias($"features"), $"label")
- .groupBy("feature").avg("weight")
- .toDF("feature", "weight")
- .cache
-```
-
-#Test
-
-```scala
-scala> :paste
-val predictDf = testDf
- .join(modelDf, testDf("feature") === modelDf("feature"), "LEFT_OUTER")
- .select($"rowid", ($"weight" * $"value").as("value"))
- .groupBy("rowid").sum("value")
- .select(
- $"rowid",
- when(sigmoid($"sum(value)") > 0.5, 1.0).otherwise(0.0).as("predicted")
- )
-```
-
-#Evaluation
-
-```scala
-scala> val df = predictDf.join(testDf, predictDf("rowid").as("id") === testDf("rowid"), "INNER")
-
-scala> (df.where($"target" === $"predicted").count + 0.0) / df.count
-Double = 0.8327921286841418
-```
-
diff --git a/docs/gitbook/spark/getting_started/README.md b/docs/gitbook/spark/getting_started/README.md
deleted file mode 100644
index e4f5b68..0000000
--- a/docs/gitbook/spark/getting_started/README.md
+++ /dev/null
@@ -1,20 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-
-# Summary
diff --git a/docs/gitbook/spark/getting_started/installation.md b/docs/gitbook/spark/getting_started/installation.md
index 7b9595d..d30b230 100644
--- a/docs/gitbook/spark/getting_started/installation.md
+++ b/docs/gitbook/spark/getting_started/installation.md
@@ -24,10 +24,6 @@ Prerequisites
* Java 7 or later
* `hivemall-spark-xxx-with-dependencies.jar` that can be found in [the ASF distribution mirror](https://www.apache.org/dyn/closer.cgi/incubator/hivemall/).
* [define-all.spark](https://github.com/apache/incubator-hivemall/blob/master/resources/ddl/define-all.spark)
-* [import-packages.spark](https://github.com/apache/incubator-hivemall/blob/master/resources/ddl/import-packages.spark)
-
-> #### Caution
-> You need to use a specific `hivemall-spark-xxx-with-dependencies.jar` for each Spark version.
Installation
============
@@ -35,18 +31,7 @@ Installation
First, you download a compiled Spark package from [the Spark official web page](https://spark.apache.org/downloads.html) and invoke spark-shell with a compiled Hivemall binary.
```
-$ ./bin/spark-shell --jars hivemall-spark-xxx-with-dependencies.jar
-```
-
-> #### Notice
-> If you would like to try Hivemall functions on the latest release of Spark, you just say `bin/spark-shell` in a Hivemall package.
-> This command automatically downloads the latest Spark version, compiles Hivemall for the version, and invokes spark-shell with the compiled Hivemall binary.
-
-Then, you load scripts for Hivemall functions.
-
-```
-scala> :load resources/ddl/define-all.spark
-scala> :load resources/ddl/import-packages.spark
+$ spark-shell --jars target/hivemall-all-<version>-incubating-SNAPSHOT.jar
```
Installation via [Spark Packages](https://spark-packages.org/package/apache-hivemall/apache-hivemall)
@@ -55,8 +40,18 @@ Installation via [Spark Packages](https://spark-packages.org/package/apache-hive
In another way to install Hivemall, you can use a `--packages` option.
```
-$ ./bin/spark-shell --packages apache-hivemall:apache-hivemall:0.5.1-<spark version>
+$ spark-shell --packages org.apache.hivemall:hivemall-all:<version>
```
-You need to set your Spark version at `<spark version>`, e.g., `spark2.2` for Spark v2.2.x.
+You find available Hivemall versions on [Maven repository](https://mvnrepository.com/artifact/org.apache.hivemall/hivemall-all/0.5.2-incubating).
+
+
+> #### Notice
+> If you would like to try Hivemall functions on the latest release of Spark, you just say `bin/spark-shell` in a Hivemall package.
+> This command automatically downloads the latest Spark version, compiles Hivemall for the version, and invokes spark-shell with the compiled Hivemall binary.
+
+Then, you load scripts for Hivemall functions.
+```
+scala> :load resources/ddl/define-all.spark
+```
\ No newline at end of file
diff --git a/docs/gitbook/spark/misc/functions.md b/docs/gitbook/spark/misc/functions.md
deleted file mode 100644
index fdc2292..0000000
--- a/docs/gitbook/spark/misc/functions.md
+++ /dev/null
@@ -1,116 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-
-flatten
-================
-
-`df.flatten()` flattens a nested schema of `df` into a flat one.
-
-## Usage
-
-```scala
-scala> val df = Seq((0, (1, (3.0, "a")), (5, 0.9))).toDF()
-scala> df.printSchema
-root
- |-- _1: integer (nullable = false)
- |-- _2: struct (nullable = true)
- | |-- _1: integer (nullable = false)
- | |-- _2: struct (nullable = true)
- | | |-- _1: double (nullable = false)
- | | |-- _2: string (nullable = true)
- |-- _3: struct (nullable = true)
- | |-- _1: integer (nullable = false)
- | |-- _2: double (nullable = false)
-
-scala> df.flatten(separator = "$").printSchema
-root
- |-- _1: integer (nullable = false)
- |-- _2$_1: integer (nullable = true)
- |-- _2$_2$_1: double (nullable = true)
- |-- _2$_2$_2: string (nullable = true)
- |-- _3$_1: integer (nullable = true)
- |-- _3$_2: double (nullable = true)
-```
-
-from_csv
-================
-
-This function parses a column containing a CSV string into a `StructType`
-with the specified schema.
-
-## Usage
-
-```scala
-scala> val df = Seq("1, abc, 0.8").toDF()
-
-scala> df.printSchema
-root
- |-- value: string (nullable = true)
-
-scala> val schema = new StructType().add("a", IntegerType).add("b", StringType).add("c", DoubleType)
-
-scala> df.select(from_csv($"value", schema)).printSchema
-root
- |-- csvtostruct(value): struct (nullable = true)
- | |-- a: integer (nullable = true)
- | |-- b: string (nullable = true)
- | |-- c: double (nullable = true)
-
-scala> df.select(from_csv($"value", schema)).show
-+------------------+
-|csvtostruct(value)|
-+------------------+
-| [1, abc,0.8]|
-+------------------+
-```
-
-to_csv
-================
-
-This function converts a column containing a `StructType` into a CSV string
-with the specified schema.
-
-## Usage
-
-```scala
-scala> val df = Seq((1, "a", (0, 3.9, "abc")), (8, "c", (2, 0.4, "def"))).toDF()
-
-scala> df.printSchema
-root
- |-- _1: integer (nullable = false)
- |-- _2: string (nullable = true)
- |-- _3: struct (nullable = true)
- | |-- _1: integer (nullable = false)
- | |-- _2: double (nullable = false)
- | |-- _3: string (nullable = true)
-
-scala> df.select(to_csv($"_3"))
-
-scala> df.select(to_csv($"_3")).printSchema
-root
- |-- structtocsv(_3): string (nullable = true)
-
-scala> df.select(to_csv($"_3")).show
-+---------------+
-|structtocsv(_3)|
-+---------------+
-| 0,3.9,abc|
-| 2,0.4,def|
-+---------------+
-```
diff --git a/docs/gitbook/spark/misc/misc.md b/docs/gitbook/spark/misc/misc.md
deleted file mode 100644
index 0475c9c..0000000
--- a/docs/gitbook/spark/misc/misc.md
+++ /dev/null
@@ -1,18 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
diff --git a/docs/gitbook/spark/misc/topk_join.md b/docs/gitbook/spark/misc/topk_join.md
deleted file mode 100644
index eb10ab2..0000000
--- a/docs/gitbook/spark/misc/topk_join.md
+++ /dev/null
@@ -1,216 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-
-`leftDf.top_k_join(k: Column, rightDf: DataFrame, joinExprs: Column, score: Column)` only joins the top-k records of `rightDf` for each `leftDf` record with a join condition `joinExprs`. An output schema of this operation is the joined schema of `leftDf` and `rightDf` plus (rank: Int, score: `score` type).
-
-`top_k_join` is much IO-efficient as compared to regular joining + ranking operations because `top_k_join` drops unsatisfied records and writes only top-k records to disks during joins.
-
-> #### Caution
-> * `top_k_join` is supported in the DataFrame of Spark v2.1.0 or later.
-> * A type of `score` must be ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, or DecimalType.
-> * If `k` is less than 0, the order is reverse and `top_k_join` joins the tail-K records of `rightDf`.
-
-# Usage
-
-For example, we have two tables below;
-
-- An input table (`leftDf`)
-
-```scala
-scala> :paste
-val leftDf = Seq(
- (1, "b", 0.3, 0.3),
- (2, "a", 0.5, 0.4),
- (3, "a", 0.1, 0.8),
- (4, "c", 0.2, 0.2),
- (5, "a", 0.1, 0.4),
- (6, "b", 0.8, 0.8)
-).toDF("userId", "group", "x", "y")
-
-scala> leftDf.show
-+------+-----+---+---+
-|userId|group| x| y|
-+------+-----+---+---+
-| 1| b|0.3|0.3|
-| 2| a|0.5|0.4|
-| 3| a|0.1|0.8|
-| 4| c|0.2|0.2|
-| 5| a|0.1|0.4|
-| 6| b|0.8|0.8|
-+------+-----+---+---+
-```
-
-- A reference table (`rightDf`)
-
-```scala
-scala> :paste
-val rightDf = Seq(
- ("a", "pos1", 0.0, 0.1),
- ("a", "pos2", 0.9, 0.3),
- ("a", "pos3", 0.3, 0.2),
- ("b", "pos4", 0.5, 0.7),
- ("b", "pos5", 0.4, 0.2),
- ("c", "pos6", 0.8, 0.7),
- ("c", "pos7", 0.3, 0.3),
- ("c", "pos8", 0.4, 0.2),
- ("c", "pos9", 0.3, 0.8)
-).toDF("group", "position", "x", "y")
-
-scala> rightDf.show
-+-----+--------+---+---+
-|group|position| x| y|
-+-----+--------+---+---+
-| a| pos1|0.0|0.1|
-| a| pos2|0.9|0.3|
-| a| pos3|0.3|0.2|
-| b| pos4|0.5|0.7|
-| b| pos5|0.4|0.2|
-| c| pos6|0.8|0.7|
-| c| pos7|0.3|0.3|
-| c| pos8|0.4|0.2|
-| c| pos9|0.3|0.8|
-+-----+--------+---+---+
-```
-
-In the two tables, the example computes the nearest `position` for `userId` in each `group`.
-The standard way using DataFrame window functions would be as follows:
-
-```scala
-scala> paste:
-val computeDistanceFunc =
- sqrt(pow(inputDf("x") - masterDf("x"), lit(2.0)) + pow(inputDf("y") - masterDf("y"), lit(2.0)))
-
-val resultDf = leftDf.join(
- right = rightDf,
- joinExpr = leftDf("group") === rightDf("group")
- )
- .select(inputDf("group"), $"userId", $"posId", computeDistanceFunc.as("score"))
- .withColumn("rank", rank().over(Window.partitionBy($"group", $"userId").orderBy($"score".desc)))
- .where($"rank" <= 1)
-```
-
-You can use `top_k_join` as follows:
-
-```scala
-scala> paste:
-import org.apache.spark.sql.hive.HivemallOps._
-
-val resultDf = leftDf.top_k_join(
- k = lit(-1),
- right = rightDf,
- joinExpr = leftDf("group") === rightDf("group"),
- score = computeDistanceFunc.as("score")
- )
-```
-
-The result is as follows:
-
-```scala
-scala> resultDf.show
-+----+-------------------+------+-----+---+---+-----+--------+---+---+
-|rank| score|userId|group| x| y|group|position| x| y|
-+----+-------------------+------+-----+---+---+-----+--------+---+---+
-| 1|0.09999999999999998| 4| c|0.2|0.2| c| pos9|0.3|0.8|
-| 1|0.10000000000000003| 1| b|0.3|0.3| b| pos5|0.4|0.2|
-| 1|0.30000000000000004| 6| b|0.8|0.8| b| pos4|0.5|0.7|
-| 1| 0.2| 2| a|0.5|0.4| a| pos3|0.3|0.2|
-| 1| 0.1| 3| a|0.1|0.8| a| pos1|0.0|0.1|
-| 1| 0.1| 5| a|0.1|0.4| a| pos1|0.0|0.1|
-+----+-------------------+------+-----+---+---+-----+--------+---+---+
-```
-
-`top_k_join` is also useful for Spark Vector users.
-If you'd like to filter the records having the smallest squared distances between vectors, you can use `top_k_join` as follows;
-
-```scala
-scala> import org.apache.spark.ml.linalg._
-scala> import org.apache.spark.sql.hive.HivemallOps._
-scala> paste:
-val leftDf = Seq(
- (1, "a", Vectors.dense(Array(1.0, 0.5, 0.6, 0.2))),
- (2, "b", Vectors.dense(Array(0.2, 0.3, 0.4, 0.1))),
- (3, "a", Vectors.dense(Array(0.8, 0.4, 0.2, 0.6))),
- (4, "a", Vectors.dense(Array(0.2, 0.7, 0.4, 0.8))),
- (5, "c", Vectors.dense(Array(0.4, 0.5, 0.6, 0.2))),
- (6, "c", Vectors.dense(Array(0.3, 0.9, 1.0, 0.1)))
-).toDF("userId", "group", "vector")
-
-scala> leftDf.show
-+------+-----+-----------------+
-|userId|group| vector|
-+------+-----+-----------------+
-| 1| a|[1.0,0.5,0.6,0.2]|
-| 2| b|[0.2,0.3,0.4,0.1]|
-| 3| a|[0.8,0.4,0.2,0.6]|
-| 4| a|[0.2,0.7,0.4,0.8]|
-| 5| c|[0.4,0.5,0.6,0.2]|
-| 6| c|[0.3,0.9,1.0,0.1]|
-+------+-----+-----------------+
-
-scala> paste:
-val rightDf = Seq(
- ("a", "pos-1", Vectors.dense(Array(0.3, 0.4, 0.3, 0.5))),
- ("a", "pos-2", Vectors.dense(Array(0.9, 0.2, 0.8, 0.3))),
- ("a", "pos-3", Vectors.dense(Array(1.0, 0.0, 0.3, 0.1))),
- ("a", "pos-4", Vectors.dense(Array(0.1, 0.8, 0.5, 0.7))),
- ("b", "pos-5", Vectors.dense(Array(0.3, 0.3, 0.3, 0.8))),
- ("b", "pos-6", Vectors.dense(Array(0.0, 0.7, 0.5, 0.6))),
- ("b", "pos-7", Vectors.dense(Array(0.1, 0.8, 0.4, 0.5))),
- ("c", "pos-8", Vectors.dense(Array(0.8, 0.3, 0.2, 0.1))),
- ("c", "pos-9", Vectors.dense(Array(0.7, 0.5, 0.8, 0.3)))
- ).toDF("group", "position", "vector")
-
-scala> rightDf.show
-+-----+--------+-----------------+
-|group|position| vector|
-+-----+--------+-----------------+
-| a| pos-1|[0.3,0.4,0.3,0.5]|
-| a| pos-2|[0.9,0.2,0.8,0.3]|
-| a| pos-3|[1.0,0.0,0.3,0.1]|
-| a| pos-4|[0.1,0.8,0.5,0.7]|
-| b| pos-5|[0.3,0.3,0.3,0.8]|
-| b| pos-6|[0.0,0.7,0.5,0.6]|
-| b| pos-7|[0.1,0.8,0.4,0.5]|
-| c| pos-8|[0.8,0.3,0.2,0.1]|
-| c| pos-9|[0.7,0.5,0.8,0.3]|
-+-----+--------+-----------------+
-
-scala> paste:
-val sqDistFunc = udf { (v1: Vector, v2: Vector) => Vectors.sqdist(v1, v2) }
-
-val resultDf = leftDf.top_k_join(
- k = lit(-1),
- right = rightDf,
- joinExpr = leftDf("group") === rightDf("group"),
- score = sqDistFunc(leftDf("vector"), rightDf("vector")).as("score")
-)
-
-scala> resultDf.show
-+----+-------------------+------+-----+-----------------+-----+--------+-----------------+
-|rank| score|userId|group| vector|group|position| vector|
-+----+-------------------+------+-----+-----------------+-----+--------+-----------------+
-| 1|0.13999999999999996| 5| c|[0.4,0.5,0.6,0.2]| c| pos-9|[0.7,0.5,0.8,0.3]|
-| 1|0.39999999999999997| 6| c|[0.3,0.9,1.0,0.1]| c| pos-9|[0.7,0.5,0.8,0.3]|
-| 1|0.42000000000000004| 2| b|[0.2,0.3,0.4,0.1]| b| pos-7|[0.1,0.8,0.4,0.5]|
-| 1|0.15000000000000002| 1| a|[1.0,0.5,0.6,0.2]| a| pos-2|[0.9,0.2,0.8,0.3]|
-| 1| 0.27| 3| a|[0.8,0.4,0.2,0.6]| a| pos-1|[0.3,0.4,0.3,0.5]|
-| 1|0.04000000000000003| 4| a|[0.2,0.7,0.4,0.8]| a| pos-4|[0.1,0.8,0.5,0.7]|
-+----+-------------------+------+-----+-----------------+-----+--------+-----------------+
-```
-
diff --git a/docs/gitbook/spark/regression/e2006_df.md b/docs/gitbook/spark/regression/e2006_df.md
deleted file mode 100644
index 015ee00..0000000
--- a/docs/gitbook/spark/regression/e2006_df.md
+++ /dev/null
@@ -1,109 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-
-E2006
-===
-https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression.html#E2006-tfidf
-
-Data preparation
-================
-
-```sh
-$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/E2006.train.bz2
-$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/E2006.test.bz2
-```
-
-```scala
-scala> :paste
-val rawTrainDf = spark.read.format("libsvm").load("E2006.train.bz2")
-
-val (max, min) = rawTrainDf.select(max($"label"), min($"label")).collect.map {
- case Row(max: Double, min: Double) => (max, min)
-}
-
-val trainDf = rawTrainDf.select(
- // `label` must be [0.0, 1.0]
- rescale($"label", lit(min), lit(max).as("label"),
- $"features"
- )
-
-scala> trainDf.printSchema
-root
- |-- label: float (nullable = true)
- |-- features: vector (nullable = true)
-
-scala> :paste
-val testDf = spark.read.format("libsvm").load("E2006.test.bz2")
- .select(rowid(), rescale($"label", lit(min), lit(max)).as("label"), $"features")
- .explode_vector($"features")
- .select($"rowid", $"label".as("target"), $"feature", $"weight".as("value"))
- .cache
-
-scala> df.printSchema
-root
- |-- rowid: string (nullable = true)
- |-- target: float (nullable = true)
- |-- feature: string (nullable = true)
- |-- value: double (nullable = true)
-```
-
-Tutorials
-================
-
-[AROWe2]
----
-
-#Training
-
-```scala
-scala> :paste
-val modelDf = trainDf
- .train_arowe2_regr(append_bias($"features"), $"label")
- .groupBy("feature").avg("weight")
- .toDF("feature", "weight")
- .cache
-```
-
-#Test
-
-```scala
-scala> :paste
-val predictDf = testDf
- .join(modelDf, testDf("feature") === modelDf("feature"), "LEFT_OUTER")
- .select($"rowid", ($"weight" * $"value").as("value"))
- .groupBy("rowid").sum("value")
- .select($"rowid", sigmoid($"sum(value)").as("predicted"))
-```
-
-#Evaluation
-
-```scala
-scala> :paste
-predictDf
- .join(testDf, predictDf("rowid").as("id") === testDf("rowid"), "INNER")
- .groupBy().avg("target", "predicted")
- .show()
-
-+------------------+------------------+
-| avg(target)| avg(predicted)|
-+------------------+------------------+
-|0.5489154884487879|0.6030108853227014|
-+------------------+------------------+
-```
-
diff --git a/pom.xml b/pom.xml
index b55b5f7..157b7db 100644
--- a/pom.xml
+++ b/pom.xml
@@ -250,7 +250,6 @@
<module>nlp</module>
<module>xgboost</module>
<module>mixserv</module>
- <module>spark</module>
<module>dist</module>
<module>tools</module>
</modules>
diff --git a/resources/ddl/import-packages.spark b/resources/ddl/import-packages.spark
deleted file mode 100644
index c3a4955..0000000
--- a/resources/ddl/import-packages.spark
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/**
- * An initialization script for DataFrame use
- */
-
-import org.apache.spark.sql.hive.HivemallOps._
-import org.apache.spark.sql.hive.HivemallGroupedDataset._
-import org.apache.spark.sql.hive.HivemallUtils._
-import hivemall.xgboost.XGBoostOptions
-
diff --git a/spark/common/pom.xml b/spark/common/pom.xml
deleted file mode 100644
index 745475f..0000000
--- a/spark/common/pom.xml
+++ /dev/null
@@ -1,64 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
- <modelVersion>4.0.0</modelVersion>
-
- <parent>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall-spark</artifactId>
- <version>0.6.0-incubating-SNAPSHOT</version>
- <relativePath>../pom.xml</relativePath>
- </parent>
-
- <artifactId>hivemall-spark-common</artifactId>
- <name>Hivemall on Spark Common</name>
- <packaging>jar</packaging>
-
- <properties>
- <main.basedir>${project.parent.parent.basedir}</main.basedir>
- </properties>
-
- <dependencies>
- <!-- provided scope -->
- <dependency>
- <groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-common</artifactId>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-mapreduce-client-core</artifactId>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.hive</groupId>
- <artifactId>hive-exec</artifactId>
- <scope>provided</scope>
- </dependency>
-
- <!-- compile scope -->
- <dependency>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall-core</artifactId>
- <scope>compile</scope>
- </dependency>
- </dependencies>
-
-</project>
-
diff --git a/spark/common/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTFWrapper.java b/spark/common/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTFWrapper.java
deleted file mode 100644
index cf10ed7..0000000
--- a/spark/common/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTFWrapper.java
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.dataset;
-
-import hivemall.UDTFWithOptions;
-
-import java.lang.reflect.Field;
-import java.lang.reflect.Method;
-import java.util.Random;
-
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Options;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.generic.Collector;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-
-/**
- * A wrapper of [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]]. This wrapper is needed
- * because Spark cannot handle HadoopUtils#getTaskId() correctly.
- */
-@Description(name = "lr_datagen",
- value = "_FUNC_(options string) - Generates a logistic regression dataset")
-public final class LogisticRegressionDataGeneratorUDTFWrapper extends UDTFWithOptions {
- private transient LogisticRegressionDataGeneratorUDTF udtf =
- new LogisticRegressionDataGeneratorUDTF();
-
- @Override
- protected Options getOptions() {
- Options options = null;
- try {
- Method m = udtf.getClass().getDeclaredMethod("getOptions");
- m.setAccessible(true);
- options = (Options) m.invoke(udtf);
- } catch (Exception e) {
- e.printStackTrace();
- }
- return options;
- }
-
- @SuppressWarnings("all")
- @Override
- protected CommandLine processOptions(ObjectInspector[] objectInspectors)
- throws UDFArgumentException {
- CommandLine commands = null;
- try {
- Method m = udtf.getClass().getDeclaredMethod("processOptions");
- m.setAccessible(true);
- commands = (CommandLine) m.invoke(udtf, objectInspectors);
- } catch (Exception e) {
- e.printStackTrace();
- }
- return commands;
- }
-
- @Override
- public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- try {
- // Extract a collector for LogisticRegressionDataGeneratorUDTF
- Field collector = GenericUDTF.class.getDeclaredField("collector");
- collector.setAccessible(true);
- udtf.setCollector((Collector) collector.get(this));
-
- // To avoid HadoopUtils#getTaskId()
- Class<?> clazz = udtf.getClass();
- Field rnd1 = clazz.getDeclaredField("rnd1");
- Field rnd2 = clazz.getDeclaredField("rnd2");
- Field r_seed = clazz.getDeclaredField("r_seed");
- r_seed.setAccessible(true);
- final long seed = r_seed.getLong(udtf) + (int) Thread.currentThread().getId();
- rnd1.setAccessible(true);
- rnd2.setAccessible(true);
- rnd1.set(udtf, new Random(seed));
- rnd2.set(udtf, new Random(seed + 1));
- } catch (Exception e) {
- e.printStackTrace();
- }
- return udtf.initialize(argOIs);
- }
-
- @Override
- public void process(Object[] objects) throws HiveException {
- udtf.process(objects);
- }
-
- @Override
- public void close() throws HiveException {
- udtf.close();
- }
-}
diff --git a/spark/common/src/main/java/hivemall/ftvec/AddBiasUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/AddBiasUDFWrapper.java
deleted file mode 100644
index e7da7cb..0000000
--- a/spark/common/src/main/java/hivemall/ftvec/AddBiasUDFWrapper.java
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.ftvec;
-
-import java.util.Arrays;
-import java.util.List;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
-
-/**
- * A wrapper of [[hivemall.ftvec.AddBiasUDF]].
- *
- * NOTE: This is needed to avoid the issue of Spark reflection. That is, spark cannot handle List<>
- * as a return type in Hive UDF. Therefore, the type must be passed via ObjectInspector.
- */
-@Description(name = "add_bias",
- value = "_FUNC_(features in array<string>) - Returns features with a bias as array<string>")
-@UDFType(deterministic = true, stateful = false)
-public class AddBiasUDFWrapper extends GenericUDF {
- private AddBiasUDF udf = new AddBiasUDF();
- private ListObjectInspector argumentOI = null;
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
- if (arguments.length != 1) {
- throw new UDFArgumentLengthException(
- "add_bias() has an single arguments: array<string> features");
- }
-
- switch (arguments[0].getCategory()) {
- case LIST:
- argumentOI = (ListObjectInspector) arguments[0];
- ObjectInspector elmOI = argumentOI.getListElementObjectInspector();
- if (elmOI.getCategory().equals(Category.PRIMITIVE)) {
- if (((PrimitiveObjectInspector) elmOI).getPrimitiveCategory() == PrimitiveCategory.STRING) {
- break;
- }
- }
- default:
- throw new UDFArgumentTypeException(0, "Type mismatch: features");
- }
-
- return ObjectInspectorFactory.getStandardListObjectInspector(
- argumentOI.getListElementObjectInspector());
- }
-
- @Override
- public Object evaluate(DeferredObject[] arguments) throws HiveException {
- assert (arguments.length == 1);
- @SuppressWarnings("unchecked")
- final List<String> input = (List<String>) argumentOI.getList(arguments[0].get());
- return udf.evaluate(input);
- }
-
- @Override
- public String getDisplayString(String[] children) {
- return "add_bias(" + Arrays.toString(children) + ")";
- }
-}
diff --git a/spark/common/src/main/java/hivemall/ftvec/AddFeatureIndexUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/AddFeatureIndexUDFWrapper.java
deleted file mode 100644
index 6be3a9e..0000000
--- a/spark/common/src/main/java/hivemall/ftvec/AddFeatureIndexUDFWrapper.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.ftvec;
-
-import java.util.Arrays;
-import java.util.List;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
-/**
- * A wrapper of [[hivemall.ftvec.AddFeatureIndexUDF]].
- *
- * NOTE: This is needed to avoid the issue of Spark reflection. That is, spark cannot handle List<>
- * as a return type in Hive UDF. Therefore, the type must be passed via ObjectInspector.
- */
-@Description(name = "add_feature_index",
- value = "_FUNC_(dense features in array<double>) - Returns a feature vector with feature indices")
-@UDFType(deterministic = true, stateful = false)
-public class AddFeatureIndexUDFWrapper extends GenericUDF {
- private AddFeatureIndexUDF udf = new AddFeatureIndexUDF();
- private ListObjectInspector argumentOI = null;
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
- if (arguments.length != 1) {
- throw new UDFArgumentLengthException(
- "add_feature_index() has an single arguments: array<double> features");
- }
-
- switch (arguments[0].getCategory()) {
- case LIST:
- argumentOI = (ListObjectInspector) arguments[0];
- ObjectInspector elmOI = argumentOI.getListElementObjectInspector();
- if (elmOI.getCategory().equals(Category.PRIMITIVE)) {
- if (((PrimitiveObjectInspector) elmOI).getPrimitiveCategory() == PrimitiveCategory.DOUBLE) {
- break;
- }
- }
- default:
- throw new UDFArgumentTypeException(0, "Type mismatch: features");
- }
-
- return ObjectInspectorFactory.getStandardListObjectInspector(
- PrimitiveObjectInspectorFactory.javaStringObjectInspector);
- }
-
- @Override
- public Object evaluate(DeferredObject[] arguments) throws HiveException {
- assert (arguments.length == 1);
- @SuppressWarnings("unchecked")
- final List<Double> input = (List<Double>) argumentOI.getList(arguments[0].get());
- return udf.evaluate(input);
- }
-
- @Override
- public String getDisplayString(String[] children) {
- return "add_feature_index(" + Arrays.toString(children) + ")";
- }
-}
diff --git a/spark/common/src/main/java/hivemall/ftvec/ExtractFeatureUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/ExtractFeatureUDFWrapper.java
deleted file mode 100644
index 5924468..0000000
--- a/spark/common/src/main/java/hivemall/ftvec/ExtractFeatureUDFWrapper.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.ftvec;
-
-import java.util.Arrays;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
-/**
- * A wrapper of [[hivemall.ftvec.ExtractFeatureUDF]].
- *
- * NOTE: This is needed to avoid the issue of Spark reflection. That is, spark cannot handle List<>
- * as a return type in Hive UDF. Therefore, the type must be passed via ObjectInspector.
- */
-@Description(name = "extract_feature",
- value = "_FUNC_(feature in string) - Returns a parsed feature as string")
-@UDFType(deterministic = true, stateful = false)
-public class ExtractFeatureUDFWrapper extends GenericUDF {
- private ExtractFeatureUDF udf = new ExtractFeatureUDF();
- private PrimitiveObjectInspector argumentOI = null;
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
- if (arguments.length != 1) {
- throw new UDFArgumentLengthException(
- "extract_feature() has an single arguments: string feature");
- }
-
- argumentOI = (PrimitiveObjectInspector) arguments[0];
- if (argumentOI.getPrimitiveCategory() != PrimitiveCategory.STRING) {
- throw new UDFArgumentTypeException(0, "Type mismatch: feature");
- }
-
- return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
- }
-
- @Override
- public Object evaluate(DeferredObject[] arguments) throws HiveException {
- assert (arguments.length == 1);
- final String input = (String) argumentOI.getPrimitiveJavaObject(arguments[0].get());
- return udf.evaluate(input);
- }
-
- @Override
- public String getDisplayString(String[] children) {
- return "extract_feature(" + Arrays.toString(children) + ")";
- }
-}
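
Since SparkSQL remains supported, the Hive UDF that this deleted wrapper delegated to can still be registered and called directly from a Spark session. A minimal sketch, assuming a Spark session with Hive support and the Hivemall jar already on the classpath; the table name `samples` and column `fv` are hypothetical:

    import org.apache.spark.sql.SparkSession

    // Sketch only: assumes Hive support is enabled and the Hivemall jar is on the classpath.
    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()

    // Register the plain Hive UDF that the deleted wrapper used to wrap.
    spark.sql("CREATE TEMPORARY FUNCTION extract_feature AS 'hivemall.ftvec.ExtractFeatureUDF'")

    // 'samples' is a hypothetical table with a string feature column 'fv' such as 'height:173.0'.
    spark.sql("SELECT extract_feature(fv) FROM samples").show()
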
diff --git a/spark/common/src/main/java/hivemall/ftvec/ExtractWeightUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/ExtractWeightUDFWrapper.java
deleted file mode 100644
index b5ef807..0000000
--- a/spark/common/src/main/java/hivemall/ftvec/ExtractWeightUDFWrapper.java
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.ftvec;
-
-import java.util.Arrays;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
-/**
- * A wrapper of [[hivemall.ftvec.ExtractWeightUDF]].
- *
- * NOTE: This is needed to work around a Spark reflection issue: Spark cannot handle List<>
- * as a return type in a Hive UDF, so the type must be passed via an ObjectInspector.
- */
-@Description(name = "extract_weight",
- value = "_FUNC_(feature in string) - Returns the weight of a feature as string")
-@UDFType(deterministic = true, stateful = false)
-public class ExtractWeightUDFWrapper extends GenericUDF {
- private ExtractWeightUDF udf = new ExtractWeightUDF();
- private PrimitiveObjectInspector argumentOI = null;
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
- if (arguments.length != 1) {
- throw new UDFArgumentLengthException(
- "extract_weight() has an single arguments: string feature");
- }
-
- argumentOI = (PrimitiveObjectInspector) arguments[0];
- if (argumentOI.getPrimitiveCategory() != PrimitiveCategory.STRING) {
- throw new UDFArgumentTypeException(0, "Type mismatch: feature");
- }
-
- return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
- PrimitiveCategory.DOUBLE);
- }
-
- @Override
- public Object evaluate(DeferredObject[] arguments) throws HiveException {
- assert (arguments.length == 1);
- final String input = (String) argumentOI.getPrimitiveJavaObject(arguments[0].get());
- return udf.evaluate(input);
- }
-
- @Override
- public String getDisplayString(String[] children) {
- return "extract_weight(" + Arrays.toString(children) + ")";
- }
-}
diff --git a/spark/common/src/main/java/hivemall/ftvec/SortByFeatureUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/SortByFeatureUDFWrapper.java
deleted file mode 100644
index e13e030..0000000
--- a/spark/common/src/main/java/hivemall/ftvec/SortByFeatureUDFWrapper.java
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.ftvec;
-
-import java.util.Arrays;
-import java.util.Map;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
-import org.apache.hadoop.io.FloatWritable;
-import org.apache.hadoop.io.IntWritable;
-
-/**
- * A wrapper of [[hivemall.ftvec.SortByFeatureUDF]].
- *
- * NOTE: This is needed to work around a Spark reflection issue: Spark cannot handle Map<>
- * as a return type in a Hive UDF, so the type must be passed via an ObjectInspector.
- */
-@Description(name = "sort_by_feature",
- value = "_FUNC_(map in map<int,float>) - Returns a sorted map")
-@UDFType(deterministic = true, stateful = false)
-public class SortByFeatureUDFWrapper extends GenericUDF {
- private SortByFeatureUDF udf = new SortByFeatureUDF();
- private MapObjectInspector argumentOI = null;
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
- if (arguments.length != 1) {
- throw new UDFArgumentLengthException(
- "sorted_by_feature() has an single arguments: map<int, float> map");
- }
-
- switch (arguments[0].getCategory()) {
- case MAP:
- argumentOI = (MapObjectInspector) arguments[0];
- ObjectInspector keyOI = argumentOI.getMapKeyObjectInspector();
- ObjectInspector valueOI = argumentOI.getMapValueObjectInspector();
- if (keyOI.getCategory().equals(Category.PRIMITIVE)
- && valueOI.getCategory().equals(Category.PRIMITIVE)) {
- final PrimitiveCategory keyCategory =
- ((PrimitiveObjectInspector) keyOI).getPrimitiveCategory();
- final PrimitiveCategory valueCategory =
- ((PrimitiveObjectInspector) valueOI).getPrimitiveCategory();
- if (keyCategory == PrimitiveCategory.INT
- && valueCategory == PrimitiveCategory.FLOAT) {
- break;
- }
- }
- default:
- throw new UDFArgumentTypeException(0, "Type mismatch: map");
- }
-
-
- return ObjectInspectorFactory.getStandardMapObjectInspector(
- argumentOI.getMapKeyObjectInspector(), argumentOI.getMapValueObjectInspector());
- }
-
- @Override
- public Object evaluate(DeferredObject[] arguments) throws HiveException {
- assert (arguments.length == 1);
- @SuppressWarnings("unchecked")
- final Map<IntWritable, FloatWritable> input =
- (Map<IntWritable, FloatWritable>) argumentOI.getMap(arguments[0].get());
- return udf.evaluate(input);
- }
-
- @Override
- public String getDisplayString(String[] children) {
- return "sort_by_feature(" + Arrays.toString(children) + ")";
- }
-}
diff --git a/spark/common/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDFWrapper.java
deleted file mode 100644
index dcdba24..0000000
--- a/spark/common/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDFWrapper.java
+++ /dev/null
@@ -1,97 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.ftvec.scaling;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.io.Text;
-
-/**
- * A wrapper of [[hivemall.ftvec.scaling.L2NormalizationUDF]].
- *
- * NOTE: This is needed to work around a Spark reflection issue: Spark 1.3 cannot handle
- * List<> as a return type in a Hive UDF, so the type must be passed via an ObjectInspector.
- * The issue was reported as SPARK-6747, so a future Spark release will make this wrapper obsolete.
- */
-public class L2NormalizationUDFWrapper extends GenericUDF {
- private L2NormalizationUDF udf = new L2NormalizationUDF();
-
- private transient List<Text> retValue = new ArrayList<Text>();
- private transient Converter toListText = null;
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
- if (arguments.length != 1) {
- throw new UDFArgumentLengthException("normalize() has an only single argument.");
- }
-
- switch (arguments[0].getCategory()) {
- case LIST:
- ObjectInspector elmOI =
- ((ListObjectInspector) arguments[0]).getListElementObjectInspector();
- if (elmOI.getCategory().equals(Category.PRIMITIVE)) {
- if (((PrimitiveObjectInspector) elmOI).getPrimitiveCategory() == PrimitiveCategory.STRING) {
- break;
- }
- }
- default:
- throw new UDFArgumentTypeException(0,
- "normalize() must have List[String] as an argument, but "
- + arguments[0].getTypeName() + " was found.");
- }
-
- // Create an ObjectInspector converter for the arguments
- ObjectInspector outputElemOI = ObjectInspectorFactory.getReflectionObjectInspector(
- Text.class, ObjectInspectorOptions.JAVA);
- ObjectInspector outputOI =
- ObjectInspectorFactory.getStandardListObjectInspector(outputElemOI);
- toListText = ObjectInspectorConverters.getConverter(arguments[0], outputOI);
-
- ObjectInspector listElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
- ObjectInspector returnElemOI = ObjectInspectorUtils.getStandardObjectInspector(listElemOI);
- return ObjectInspectorFactory.getStandardListObjectInspector(returnElemOI);
- }
-
- @Override
- public Object evaluate(DeferredObject[] arguments) throws HiveException {
- assert (arguments.length == 1);
- @SuppressWarnings("unchecked")
- final List<Text> input = (List<Text>) toListText.convert(arguments[0].get());
- retValue = udf.evaluate(input);
- return retValue;
- }
-
- @Override
- public String getDisplayString(String[] children) {
- return "normalize(" + Arrays.toString(children) + ")";
- }
-}
diff --git a/spark/common/src/main/java/hivemall/knn/lsh/MinHashesUDFWrapper.java b/spark/common/src/main/java/hivemall/knn/lsh/MinHashesUDFWrapper.java
deleted file mode 100644
index 3c1fe9b..0000000
--- a/spark/common/src/main/java/hivemall/knn/lsh/MinHashesUDFWrapper.java
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.knn.lsh;
-
-import java.util.Arrays;
-import java.util.List;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-
-/** A wrapper of [[hivemall.knn.lsh.MinHashesUDF]]. */
-@Description(name = "minhashes",
- value = "_FUNC_(features in array<string>, noWeight in boolean) - Returns hashed features as array<int>")
-@UDFType(deterministic = true, stateful = false)
-public class MinHashesUDFWrapper extends GenericUDF {
- private MinHashesUDF udf = new MinHashesUDF();
- private ListObjectInspector featuresOI = null;
- private PrimitiveObjectInspector noWeightOI = null;
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
- if (arguments.length != 2) {
- throw new UDFArgumentLengthException(
- "minhashes() has 2 arguments: array<string> features, boolean noWeight");
- }
-
- // Check argument types
- switch (arguments[0].getCategory()) {
- case LIST:
- featuresOI = (ListObjectInspector) arguments[0];
- ObjectInspector elmOI = featuresOI.getListElementObjectInspector();
- if (elmOI.getCategory().equals(Category.PRIMITIVE)) {
- if (((PrimitiveObjectInspector) elmOI).getPrimitiveCategory() == PrimitiveCategory.STRING) {
- break;
- }
- }
- default:
- throw new UDFArgumentTypeException(0, "Type mismatch: features");
- }
-
- noWeightOI = (PrimitiveObjectInspector) arguments[1];
- if (noWeightOI.getPrimitiveCategory() != PrimitiveCategory.BOOLEAN) {
- throw new UDFArgumentException("Type mismatch: noWeight");
- }
-
- return ObjectInspectorFactory.getStandardListObjectInspector(
- PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
- PrimitiveCategory.INT));
- }
-
- @Override
- public Object evaluate(DeferredObject[] arguments) throws HiveException {
- assert (arguments.length == 2);
- @SuppressWarnings("unchecked")
- final List<String> features = (List<String>) featuresOI.getList(arguments[0].get());
- final Boolean noWeight =
- PrimitiveObjectInspectorUtils.getBoolean(arguments[1].get(), noWeightOI);
- return udf.evaluate(features, noWeight);
- }
-
- @Override
- public String getDisplayString(String[] children) {
- /**
- * TODO: Need to return hive-specific type names.
- */
- return "minhashes(" + Arrays.toString(children) + ")";
- }
-}
diff --git a/spark/common/src/main/java/hivemall/tools/mapred/RowIdUDFWrapper.java b/spark/common/src/main/java/hivemall/tools/mapred/RowIdUDFWrapper.java
deleted file mode 100644
index e907c38..0000000
--- a/spark/common/src/main/java/hivemall/tools/mapred/RowIdUDFWrapper.java
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.tools.mapred;
-
-import java.util.UUID;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
-/** An alternative implementation of [[hivemall.tools.mapred.RowIdUDF]]. */
-@Description(name = "rowid",
- value = "_FUNC_() - Returns a generated row id of a form {TASK_ID}-{UUID}-{SEQUENCE_NUMBER}")
-@UDFType(deterministic = false, stateful = true)
-public class RowIdUDFWrapper extends GenericUDF {
- // RowIdUDF is not used directly because Spark cannot
- // handle HadoopUtils#getTaskId().
-
- private long sequence;
- private long taskId;
-
- public RowIdUDFWrapper() {
- this.sequence = 0L;
- this.taskId = Thread.currentThread().getId();
- }
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
- if (arguments.length != 0) {
- throw new UDFArgumentLengthException("row_number() has no argument.");
- }
-
- return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
- }
-
- @Override
- public Object evaluate(DeferredObject[] arguments) throws HiveException {
- assert (arguments.length == 0);
- sequence++;
- /**
- * TODO: Check if it is unique over all tasks in executors of Spark.
- */
- return taskId + "-" + UUID.randomUUID() + "-" + sequence;
- }
-
- @Override
- public String getDisplayString(String[] children) {
- return "row_number()";
- }
-}
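
With this wrapper removed, a comparable per-row identifier can still be produced with built-in Spark SQL functions. A minimal sketch, assuming an existing DataFrame `df` (hypothetical); the UUID prefix merely mimics the {TASK_ID}-{UUID}-{SEQUENCE_NUMBER} shape described above:

    import java.util.UUID
    import org.apache.spark.sql.functions.{concat_ws, lit, monotonically_increasing_id}

    // Sketch only: 'df' is a hypothetical input DataFrame.
    // monotonically_increasing_id() is unique per row within a job; the random
    // UUID prefix distinguishes separate runs, loosely mirroring the deleted UDF.
    val runId = UUID.randomUUID().toString
    val withRowId = df.withColumn("rowid",
      concat_ws("-", lit(runId), monotonically_increasing_id()))
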
diff --git a/spark/common/src/main/scala/hivemall/HivemallException.scala b/spark/common/src/main/scala/hivemall/HivemallException.scala
deleted file mode 100644
index 53f6756..0000000
--- a/spark/common/src/main/scala/hivemall/HivemallException.scala
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall
-
-class HivemallException(message: String, cause: Throwable)
- extends Exception(message, cause) {
-
- def this(message: String) = this(message, null)
-}
diff --git a/spark/common/src/main/scala/org/apache/spark/ml/feature/HivemallLabeledPoint.scala b/spark/common/src/main/scala/org/apache/spark/ml/feature/HivemallLabeledPoint.scala
deleted file mode 100644
index 3fb2d18..0000000
--- a/spark/common/src/main/scala/org/apache/spark/ml/feature/HivemallLabeledPoint.scala
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.spark.ml.feature
-
-import java.util.StringTokenizer
-
-import scala.collection.mutable.ListBuffer
-
-import hivemall.HivemallException
-
-// Used for DataFrame#explode
-case class HivemallFeature(feature: String)
-
-/**
- * Class that represents the features and labels of a data point for Hivemall.
- *
- * @param label Label for this data point.
- * @param features List of features for this data point.
- */
-case class HivemallLabeledPoint(label: Float = 0.0f, features: Seq[String]) {
- override def toString: String = {
- "%s,%s".format(label, features.mkString("[", ",", "]"))
- }
-}
-
-object HivemallLabeledPoint {
-
- // Simple parser for HivemallLabeledPoint
- def parse(s: String): HivemallLabeledPoint = {
- val (label, features) = s.indexOf(',') match {
- case d if d > 0 => (s.substring(0, d), s.substring(d + 1))
- case _ => ("0.0", "[]") // Dummy
- }
- HivemallLabeledPoint(label.toFloat, parseTuple(new StringTokenizer(features, "[],", true)))
- }
-
- // TODO: Support parsing rows without labels
- private[this] def parseTuple(tokenizer: StringTokenizer): Seq[String] = {
- val items = ListBuffer.empty[String]
- var parsing = true
- var allowDelim = false
- while (parsing && tokenizer.hasMoreTokens()) {
- val token = tokenizer.nextToken()
- if (token == "[") {
- items ++= parseTuple(tokenizer)
- parsing = false
- allowDelim = true
- } else if (token == ",") {
- if (allowDelim) {
- allowDelim = false
- } else {
- throw new HivemallException("Found ',' at a wrong position.")
- }
- } else if (token == "]") {
- parsing = false
- } else {
- items.append(token)
- allowDelim = true
- }
- }
- if (parsing) {
- throw new HivemallException(s"A tuple must end with ']'.")
- }
- items
- }
-}
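
For reference, the parser above accepts a line of the form "<label>,[feature1,feature2,...]". An illustrative usage sketch, only to document the accepted format since the class itself is being removed:

    // "1.0,[height:173.0,weight:60.2]" parses to
    //   label    = 1.0f
    //   features = Seq("height:173.0", "weight:60.2")
    val point = HivemallLabeledPoint.parse("1.0,[height:173.0,weight:60.2]")
    assert(point.label == 1.0f)
    assert(point.features == Seq("height:173.0", "weight:60.2"))
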
diff --git a/spark/pom.xml b/spark/pom.xml
deleted file mode 100644
index b4288c7..0000000
--- a/spark/pom.xml
+++ /dev/null
@@ -1,311 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
- <modelVersion>4.0.0</modelVersion>
-
- <parent>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall</artifactId>
- <version>0.6.0-incubating-SNAPSHOT</version>
- <relativePath>../pom.xml</relativePath>
- </parent>
-
- <artifactId>hivemall-spark</artifactId>
- <packaging>pom</packaging>
- <name>Hivemall on Apache Spark</name>
-
- <modules>
- <module>common</module>
- <module>spark-2.2</module>
- <module>spark-2.3</module>
- </modules>
-
- <properties>
- <main.basedir>${project.parent.basedir}</main.basedir>
- <scala.version>2.11.8</scala.version>
- <scala.binary.version>2.11</scala.binary.version>
- <scalatest.jvm.opts>-ea -Xms768m -Xmx1024m -XX:PermSize=128m -XX:MaxMetaspaceSize=512m -XX:ReservedCodeCacheSize=512m</scalatest.jvm.opts>
- </properties>
-
- <dependencyManagement>
- <dependencies>
- <!-- compile scope -->
- <dependency>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall-core</artifactId>
- <version>${project.version}</version>
- <scope>compile</scope>
- <exclusions>
- <exclusion>
- <groupId>io.netty</groupId>
- <artifactId>netty-all</artifactId>
- </exclusion>
- </exclusions>
- </dependency>
- <dependency>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall-xgboost</artifactId>
- <version>${project.version}</version>
- <scope>compile</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.commons</groupId>
- <artifactId>commons-compress</artifactId>
- <version>1.8</version>
- <scope>compile</scope>
- </dependency>
-
- <!-- provided scope -->
- <dependency>
- <groupId>org.scala-lang</groupId>
- <artifactId>scala-library</artifactId>
- <version>${scala.version}</version>
- <scope>provided</scope>
- </dependency>
-
- <!-- test dependencies -->
- <dependency>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall-mixserv</artifactId>
- <version>${project.version}</version>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>org.scalatest</groupId>
- <artifactId>scalatest_${scala.binary.version}</artifactId>
- <version>2.2.6</version>
- <scope>test</scope>
- </dependency>
- </dependencies>
- </dependencyManagement>
-
- <build>
- <directory>target</directory>
- <outputDirectory>target/classes</outputDirectory>
- <finalName>${project.artifactId}-${project.version}</finalName>
- <testOutputDirectory>target/test-classes</testOutputDirectory>
-
- <pluginManagement>
- <plugins>
- <plugin>
- <groupId>net.alchim31.maven</groupId>
- <artifactId>scala-maven-plugin</artifactId>
- <version>3.2.2</version>
- </plugin>
- <plugin>
- <groupId>org.scalatest</groupId>
- <artifactId>scalatest-maven-plugin</artifactId>
- <version>1.0</version>
- <configuration>
- <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
- <junitxml>.</junitxml>
- <filereports>SparkTestSuite.txt</filereports>
- <argLine>${scalatest.jvm.opts}</argLine>
- <stderr />
- <environmentVariables>
- <SPARK_PREPEND_CLASSES>1</SPARK_PREPEND_CLASSES>
- <SPARK_SCALA_VERSION>${scala.binary.version}</SPARK_SCALA_VERSION>
- <SPARK_TESTING>1</SPARK_TESTING>
- <JAVA_HOME>${env.JAVA_HOME}</JAVA_HOME>
- <PATH>${env.JAVA_HOME}/bin:${env.PATH}</PATH>
- </environmentVariables>
- <systemProperties>
- <log4j.configuration>file:src/test/resources/log4j.properties</log4j.configuration>
- <derby.system.durability>test</derby.system.durability>
- <java.awt.headless>true</java.awt.headless>
- <java.io.tmpdir>${project.build.directory}/tmp</java.io.tmpdir>
- <spark.testing>1</spark.testing>
- <spark.ui.enabled>false</spark.ui.enabled>
- <spark.ui.showConsoleProgress>false</spark.ui.showConsoleProgress>
- <spark.unsafe.exceptionOnMemoryLeak>true</spark.unsafe.exceptionOnMemoryLeak>
- <!-- Needed by sql/hive tests. -->
- <test.src.tables>__not_used__</test.src.tables>
- </systemProperties>
- <tagsToExclude>${test.exclude.tags}</tagsToExclude>
- </configuration>
- </plugin>
- <!-- hivemall-spark_xx-xx-with-dependencies.jar including minimum dependencies -->
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-shade-plugin</artifactId>
- <executions>
- <execution>
- <id>jar-with-dependencies</id>
- <phase>package</phase>
- <goals>
- <goal>shade</goal>
- </goals>
- <configuration>
- <finalName>${project.artifactId}-${project.version}-with-dependencies</finalName>
- <outputDirectory>${main.basedir}/target</outputDirectory>
- <minimizeJar>false</minimizeJar>
- <createDependencyReducedPom>false</createDependencyReducedPom>
- <createSourcesJar>true</createSourcesJar>
- <artifactSet>
- <includes>
- <include>org.apache.hivemall:hivemall-spark-common</include>
- <!-- hivemall-core -->
- <include>org.apache.hivemall:hivemall-core</include>
- <!--
- Since `netty-all` is bundled in Spark, we don't need to include it here
- <include>io.netty:netty-all</include>
- -->
- <include>com.github.haifengl:smile-core</include>
- <include>com.github.haifengl:smile-math</include>
- <include>com.github.haifengl:smile-data</include>
- <include>org.tukaani:xz</include>
- <include>org.apache.commons:commons-math3</include>
- <include>org.roaringbitmap:RoaringBitmap</include>
- <include>it.unimi.dsi:fastutil</include>
- <include>com.clearspring.analytics:stream</include>
- <!-- hivemall-nlp -->
- <include>org.apache.hivemall:hivemall-nlp</include>
- <include>org.apache.lucene:lucene-analyzers-kuromoji</include>
- <include>org.apache.lucene:lucene-analyzers-smartcn</include>
- <include>org.apache.lucene:lucene-analyzers-common</include>
- <include>org.apache.lucene:lucene-core</include>
- <!-- hivemall-xgboost -->
- <include>org.apache.hivemall:hivemall-xgboost</include>
- <include>io.github.myui:xgboost4j</include>
- <include>com.esotericsoftware.kryo:kryo</include>
- </includes>
- </artifactSet>
- <transformers>
- <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
- <manifestEntries>
- <Implementation-Title>${project.name}</Implementation-Title>
- <Implementation-Version>${project.version}</Implementation-Version>
- <Implementation-Vendor>${project.organization.name}</Implementation-Vendor>
- </manifestEntries>
- </transformer>
- </transformers>
- <filters>
- <filter>
- <artifact>org.apache.lucene:*</artifact>
- <includes>
- <include>**</include>
- </includes>
- </filter>
- <filter>
- <artifact>com.esotericsoftware.kryo:kryo</artifact>
- <includes>
- <include>**</include>
- </includes>
- </filter>
- <filter>
- <artifact>*:*</artifact>
- <excludes>
- <exclude>META-INF/LICENSE.txt</exclude>
- <exclude>META-INF/NOTICE.txt</exclude>
- <exclude>META-INF/*.SF</exclude>
- <exclude>META-INF/*.DSA</exclude>
- <exclude>META-INF/*.RSA</exclude>
- <exclude>*.jar</exclude>
- <exclude>tracker.py</exclude>
- </excludes>
- </filter>
- </filters>
- </configuration>
- </execution>
- </executions>
- </plugin>
- <plugin>
- <groupId>org.scalastyle</groupId>
- <artifactId>scalastyle-maven-plugin</artifactId>
- <version>0.8.0</version>
- </plugin>
- </plugins>
- </pluginManagement>
-
- <plugins>
- <!-- disable Java API compatibility checks for Spark modules -->
- <plugin>
- <groupId>org.codehaus.mojo</groupId>
- <artifactId>animal-sniffer-maven-plugin</artifactId>
- <configuration>
- <skip>true</skip>
- </configuration>
- </plugin>
- <plugin>
- <groupId>org.scalastyle</groupId>
- <artifactId>scalastyle-maven-plugin</artifactId>
- <configuration>
- <verbose>false</verbose>
- <failOnViolation>true</failOnViolation>
- <includeTestSourceDirectory>true</includeTestSourceDirectory>
- <failOnWarning>false</failOnWarning>
- <sourceDirectory>${basedir}/src/main/scala</sourceDirectory>
- <testSourceDirectory>${basedir}/src/test/scala</testSourceDirectory>
- <configLocation>${main.basedir}/spark/scalastyle-config.xml</configLocation>
- <outputFile>${basedir}/target/scalastyle-output.xml</outputFile>
- <inputEncoding>${project.build.sourceEncoding}</inputEncoding>
- <outputEncoding>${project.reporting.outputEncoding}</outputEncoding>
- </configuration>
- <executions>
- <execution>
- <goals>
- <goal>check</goal>
- </goals>
- </execution>
- </executions>
- </plugin>
- <plugin>
- <groupId>net.alchim31.maven</groupId>
- <artifactId>scala-maven-plugin</artifactId>
- <executions>
- <execution>
- <id>scala-compile-first</id>
- <phase>process-resources</phase>
- <goals>
- <goal>add-source</goal>
- <goal>compile</goal>
- </goals>
- </execution>
- <execution>
- <id>scala-test-compile</id>
- <phase>process-test-resources</phase>
- <goals>
- <goal>testCompile</goal>
- </goals>
- </execution>
- </executions>
- <!-- For incremental compilation -->
- <configuration>
- <scalaVersion>${scala.version}</scalaVersion>
- <recompileMode>incremental</recompileMode>
- <useZincServer>true</useZincServer>
- <args>
- <arg>-unchecked</arg>
- <arg>-deprecation</arg>
- <!-- TODO: To enable this option, we need to fix many warnings -->
- <!-- <arg>-feature</arg> -->
- </args>
- <jvmArgs>
- <jvmArg>-Xms768m</jvmArg>
- <jvmArg>-Xmx1024m</jvmArg>
- <jvmArg>-XX:PermSize=128m</jvmArg>
- <jvmArg>-XX:MaxMetaspaceSize=512m</jvmArg>
- <jvmArg>-XX:ReservedCodeCacheSize=512m</jvmArg>
- </jvmArgs>
- </configuration>
- </plugin>
- </plugins>
- </build>
-
-</project>
diff --git a/spark/scalastyle-config.xml b/spark/scalastyle-config.xml
deleted file mode 100644
index 13d1c47..0000000
--- a/spark/scalastyle-config.xml
+++ /dev/null
@@ -1,333 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-
-<!--
-If you wish to turn off checking for a section of code, you can put a comment in the source
-before and after the section, with the following syntax:
-
- // scalastyle:off
- ... // stuff that breaks the styles
- // scalastyle:on
-
-You can also disable only one rule, by specifying its rule id, as specified in:
- http://www.scalastyle.org/rules-0.7.0.html
-
- // scalastyle:off no.finalize
- override def finalize(): Unit = ...
- // scalastyle:on no.finalize
-
-This file is divided into 3 sections:
- (1) rules that we enforce.
- (2) rules that we would like to enforce, but haven't cleaned up the codebase to turn on yet
- (or we need to make the scalastyle rule more configurable).
- (3) rules that we don't want to enforce.
--->
-
-<scalastyle>
- <name>Scalastyle standard configuration</name>
-
- <!-- ================================================================================ -->
- <!-- rules we enforce -->
- <!-- ================================================================================ -->
-
- <check level="error" class="org.scalastyle.file.FileTabChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true">
- <parameters>
- <parameter name="header"><![CDATA[/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */]]></parameter>
- </parameters>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.SpacesAfterPlusChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.scalariform.SpacesBeforePlusChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.file.WhitespaceEndOfLineChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.file.FileLineLengthChecker" enabled="true">
- <parameters>
- <parameter name="maxLineLength"><![CDATA[100]]></parameter>
- <parameter name="tabSize"><![CDATA[2]]></parameter>
- <parameter name="ignoreImports">true</parameter>
- </parameters>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.ClassNamesChecker" enabled="true">
- <parameters><parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter></parameters>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.ObjectNamesChecker" enabled="true">
- <parameters><parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter></parameters>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.PackageObjectNamesChecker" enabled="true">
- <parameters><parameter name="regex"><![CDATA[^[a-z][A-Za-z]*$]]></parameter></parameters>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.ParameterNumberChecker" enabled="true">
- <parameters><parameter name="maxParameters"><![CDATA[10]]></parameter></parameters>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.NoFinalizeChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.scalariform.CovariantEqualsChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.scalariform.StructuralTypeChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.scalariform.UppercaseLChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.scalariform.IfBraceChecker" enabled="true">
- <parameters>
- <parameter name="singleLineAllowed"><![CDATA[true]]></parameter>
- <parameter name="doubleLineAllowed"><![CDATA[true]]></parameter>
- </parameters>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.PublicMethodsHaveTypeChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.file.NewLineAtEofChecker" enabled="true"></check>
-
- <check customId="nonascii" level="error" class="org.scalastyle.scalariform.NonASCIICharacterChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.scalariform.SpaceAfterCommentStartChecker" enabled="true"></check>
-
- <check level="error" class="org.scalastyle.scalariform.EnsureSingleSpaceBeforeTokenChecker" enabled="true">
- <parameters>
- <parameter name="tokens">ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW</parameter>
- </parameters>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.EnsureSingleSpaceAfterTokenChecker" enabled="true">
- <parameters>
- <parameter name="tokens">ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW</parameter>
- </parameters>
- </check>
-
- <!-- ??? usually shouldn't be checked into the code base. -->
- <check level="error" class="org.scalastyle.scalariform.NotImplementedErrorUsage" enabled="true"></check>
-
- <!-- As of SPARK-7977 all printlns need to be wrapped in '// scalastyle:off/on println' -->
- <check customId="println" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
- <parameters><parameter name="regex">^println$</parameter></parameters>
- <customMessage><![CDATA[Are you sure you want to println? If yes, wrap the code block with
- // scalastyle:off println
- println(...)
- // scalastyle:on println]]></customMessage>
- </check>
-
- <check customId="visiblefortesting" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
- <parameters><parameter name="regex">@VisibleForTesting</parameter></parameters>
- <customMessage><![CDATA[
- @VisibleForTesting causes classpath issues. Please note this in the java doc instead (SPARK-11615).
- ]]></customMessage>
- </check>
-
- <check customId="runtimeaddshutdownhook" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
- <parameters><parameter name="regex">Runtime\.getRuntime\.addShutdownHook</parameter></parameters>
- <customMessage><![CDATA[
- Are you sure that you want to use Runtime.getRuntime.addShutdownHook? In most cases, you should use
- ShutdownHookManager.addShutdownHook instead.
- If you must use Runtime.getRuntime.addShutdownHook, wrap the code block with
- // scalastyle:off runtimeaddshutdownhook
- Runtime.getRuntime.addShutdownHook(...)
- // scalastyle:on runtimeaddshutdownhook
- ]]></customMessage>
- </check>
-
- <check customId="mutablesynchronizedbuffer" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
- <parameters><parameter name="regex">mutable\.SynchronizedBuffer</parameter></parameters>
- <customMessage><![CDATA[
- Are you sure that you want to use mutable.SynchronizedBuffer? In most cases, you should use
- java.util.concurrent.ConcurrentLinkedQueue instead.
- If you must use mutable.SynchronizedBuffer, wrap the code block with
- // scalastyle:off mutablesynchronizedbuffer
- mutable.SynchronizedBuffer[...]
- // scalastyle:on mutablesynchronizedbuffer
- ]]></customMessage>
- </check>
-
- <check customId="classforname" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
- <parameters><parameter name="regex">Class\.forName</parameter></parameters>
- <customMessage><![CDATA[
- Are you sure that you want to use Class.forName? In most cases, you should use Utils.classForName instead.
- If you must use Class.forName, wrap the code block with
- // scalastyle:off classforname
- Class.forName(...)
- // scalastyle:on classforname
- ]]></customMessage>
- </check>
-
- <check customId="awaitresult" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
- <parameters><parameter name="regex">Await\.result</parameter></parameters>
- <customMessage><![CDATA[
- Are you sure that you want to use Await.result? In most cases, you should use ThreadUtils.awaitResult instead.
- If you must use Await.result, wrap the code block with
- // scalastyle:off awaitresult
- Await.result(...)
- // scalastyle:on awaitresult
- ]]></customMessage>
- </check>
-
- <!-- As of SPARK-9613 JavaConversions should be replaced with JavaConverters -->
- <check customId="javaconversions" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
- <parameters><parameter name="regex">JavaConversions</parameter></parameters>
- <customMessage>Instead of importing implicits in scala.collection.JavaConversions._, import
- scala.collection.JavaConverters._ and use .asScala / .asJava methods</customMessage>
- </check>
-
- <check customId="commonslang2" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
- <parameters><parameter name="regex">org\.apache\.commons\.lang\.</parameter></parameters>
- <customMessage>Use Commons Lang 3 classes (package org.apache.commons.lang3.*) instead
- of Commons Lang 2 (package org.apache.commons.lang.*)</customMessage>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.ImportOrderChecker" enabled="true">
- <parameters>
- <parameter name="groups">java,scala,3rdParty,spark</parameter>
- <parameter name="group.java">javax?\..*</parameter>
- <parameter name="group.scala">scala\..*</parameter>
- <parameter name="group.3rdParty">(?!org\.apache\.spark\.).*</parameter>
- <parameter name="group.spark">org\.apache\.spark\..*</parameter>
- </parameters>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.DisallowSpaceBeforeTokenChecker" enabled="true">
- <parameters>
- <parameter name="tokens">COMMA</parameter>
- </parameters>
- </check>
-
- <!-- SPARK-3854: Single Space between ')' and '{' -->
- <check customId="SingleSpaceBetweenRParenAndLCurlyBrace" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
- <parameters><parameter name="regex">\)\{</parameter></parameters>
- <customMessage><![CDATA[
- Single Space between ')' and `{`.
- ]]></customMessage>
- </check>
-
- <check customId="NoScalaDoc" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
- <parameters><parameter name="regex">(?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*]</parameter></parameters>
- <customMessage>Use Javadoc style indentation for multiline comments</customMessage>
- </check>
-
- <check customId="OmitBracesInCase" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
- <parameters><parameter name="regex">case[^\n>]*=>\s*\{</parameter></parameters>
- <customMessage>Omit braces in case clauses.</customMessage>
- </check>
-
- <!-- SPARK-16877: Avoid Java annotations -->
- <check customId="OverrideJavaCase" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
- <parameters><parameter name="regex">^Override$</parameter></parameters>
- <customMessage>override modifier should be used instead of @java.lang.Override.</customMessage>
- </check>
-
- <check level="error" class="org.scalastyle.scalariform.DeprecatedJavaChecker" enabled="true"></check>
-
- <!-- ================================================================================ -->
- <!-- rules we'd like to enforce, but haven't cleaned up the codebase yet -->
- <!-- ================================================================================ -->
-
- <!-- We cannot turn the following two on, because it'd fail a lot of string interpolation use cases. -->
- <!-- Ideally the following two rules should be configurable to rule out string interpolation. -->
- <check level="error" class="org.scalastyle.scalariform.NoWhitespaceBeforeLeftBracketChecker" enabled="false"></check>
- <check level="error" class="org.scalastyle.scalariform.NoWhitespaceAfterLeftBracketChecker" enabled="false"></check>
-
- <!-- This breaks symbolic method names so we don't turn it on. -->
- <!-- Maybe we should update it to allow basic symbolic names, and then we are good to go. -->
- <check level="error" class="org.scalastyle.scalariform.MethodNamesChecker" enabled="false">
- <parameters>
- <parameter name="regex"><![CDATA[^[a-z][A-Za-z0-9]*$]]></parameter>
- </parameters>
- </check>
-
- <!-- Should turn this on, but we have a few places that need to be fixed first -->
- <check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="true"></check>
-
- <!-- ================================================================================ -->
- <!-- rules we don't want -->
- <!-- ================================================================================ -->
-
- <check level="error" class="org.scalastyle.scalariform.IllegalImportsChecker" enabled="false">
- <parameters><parameter name="illegalImports"><![CDATA[sun._,java.awt._]]></parameter></parameters>
- </check>
-
- <!-- We want the opposite of this: NewLineAtEofChecker -->
- <check level="error" class="org.scalastyle.file.NoNewLineAtEofChecker" enabled="false"></check>
-
- <!-- This one complains about all kinds of random things. Disable. -->
- <check level="error" class="org.scalastyle.scalariform.SimplifyBooleanExpressionChecker" enabled="false"></check>
-
- <!-- We use return quite a bit for control flows and guards -->
- <check level="error" class="org.scalastyle.scalariform.ReturnChecker" enabled="false"></check>
-
- <!-- We use null a lot in low level code and to interface with 3rd party code -->
- <check level="error" class="org.scalastyle.scalariform.NullChecker" enabled="false"></check>
-
- <!-- Doesn't seem super big deal here ... -->
- <check level="error" class="org.scalastyle.scalariform.NoCloneChecker" enabled="false"></check>
-
- <!-- Doesn't seem super big deal here ... -->
- <check level="error" class="org.scalastyle.file.FileLengthChecker" enabled="false">
- <parameters><parameter name="maxFileLength">800></parameter></parameters>
- </check>
-
- <!-- Doesn't seem super big deal here ... -->
- <check level="error" class="org.scalastyle.scalariform.NumberOfTypesChecker" enabled="false">
- <parameters><parameter name="maxTypes">30</parameter></parameters>
- </check>
-
- <!-- Doesn't seem super big deal here ... -->
- <check level="error" class="org.scalastyle.scalariform.CyclomaticComplexityChecker" enabled="false">
- <parameters><parameter name="maximum">10</parameter></parameters>
- </check>
-
- <!-- Doesn't seem super big deal here ... -->
- <check level="error" class="org.scalastyle.scalariform.MethodLengthChecker" enabled="false">
- <parameters><parameter name="maxLength">50</parameter></parameters>
- </check>
-
- <!-- Not exactly feasible to enforce this right now. -->
- <!-- It is also infrequent that somebody introduces a new class with a lot of methods. -->
- <check level="error" class="org.scalastyle.scalariform.NumberOfMethodsInTypeChecker" enabled="false">
- <parameters><parameter name="maxMethods"><![CDATA[30]]></parameter></parameters>
- </check>
-
- <!-- Doesn't seem super big deal here, and we have a lot of magic numbers ... -->
- <check level="error" class="org.scalastyle.scalariform.MagicNumberChecker" enabled="false">
- <parameters><parameter name="ignore">-1,0,1,2,3</parameter></parameters>
- </check>
-
-</scalastyle>
diff --git a/spark/spark-2.2/bin/mvn-zinc b/spark/spark-2.2/bin/mvn-zinc
deleted file mode 100755
index 581c45f..0000000
--- a/spark/spark-2.2/bin/mvn-zinc
+++ /dev/null
@@ -1,99 +0,0 @@
-#!/usr/bin/env bash
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Copied from commit 48682f6bf663e54cb63b7e95a4520d34b6fa890b in Apache Spark
-
-# Determine the current working directory
-_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
-# Preserve the calling directory
-_CALLING_DIR="$(pwd)"
-# Options used during compilation
-_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxMetaspaceSize=512M -XX:ReservedCodeCacheSize=512m"
-
-# Installs any application tarball given a URL, the expected tarball name,
-# and, optionally, a checkable binary path to determine if the binary has
-# already been installed
-## Arg1 - URL
-## Arg2 - Tarball Name
-## Arg3 - Checkable Binary
-install_app() {
- local remote_tarball="$1/$2"
- local local_tarball="${_DIR}/$2"
- local binary="${_DIR}/$3"
- local curl_opts="--progress-bar -L"
- local wget_opts="--progress=bar:force ${wget_opts}"
-
- if [ -z "$3" -o ! -f "$binary" ]; then
- # check if we already have the tarball
- # check if we have curl installed
- # download application
- [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \
- echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \
- curl ${curl_opts} "${remote_tarball}" > "${local_tarball}"
- # if the file still doesn't exist, let's try `wget` and cross our fingers
- [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \
- echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \
- wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}"
- # if both were unsuccessful, exit
- [ ! -f "${local_tarball}" ] && \
- echo -n "ERROR: Cannot download $2 with cURL or wget; " && \
- echo "please install manually and try again." && \
- exit 2
- cd "${_DIR}" && tar -xzf "$2"
- rm -rf "$local_tarball"
- fi
-}
-
-# Install zinc under the bin/ folder
-install_zinc() {
- local zinc_path="zinc-0.3.9/bin/zinc"
- [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1
- install_app \
- "http://downloads.typesafe.com/zinc/0.3.9" \
- "zinc-0.3.9.tgz" \
- "${zinc_path}"
- ZINC_BIN="${_DIR}/${zinc_path}"
-}
-
-# Setup healthy defaults for the Zinc port if none were provided from
-# the environment
-ZINC_PORT=${ZINC_PORT:-"3030"}
-
-# Install Zinc for the bin/
-install_zinc
-
-# Reset the current working directory
-cd "${_CALLING_DIR}"
-
-# Now that zinc is installed, check its status and, if it's
-# not running or was just installed, start it
-if [ ! -f "${ZINC_BIN}" ]; then
- exit -1
-fi
-if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then
- export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"}
- "${ZINC_BIN}" -shutdown -port ${ZINC_PORT}
- "${ZINC_BIN}" -start -port ${ZINC_PORT} &>/dev/null
-fi
-
-# Set any `mvn` options if not already present
-export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"}
-
-# Last, call the `mvn` command as usual
-mvn -DzincPort=${ZINC_PORT} "$@"
diff --git a/spark/spark-2.2/extra-src/README.md b/spark/spark-2.2/extra-src/README.md
deleted file mode 100644
index bdffa37..0000000
--- a/spark/spark-2.2/extra-src/README.md
+++ /dev/null
@@ -1,20 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-
-Copied from the Spark v2.2.0 release.
diff --git a/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
deleted file mode 100644
index 9e98948..0000000
--- a/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
+++ /dev/null
@@ -1,279 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import java.io.{InputStream, OutputStream}
-import java.rmi.server.UID
-
-import scala.collection.JavaConverters._
-import scala.language.implicitConversions
-import scala.reflect.ClassTag
-
-import com.google.common.base.Objects
-import org.apache.avro.Schema
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
-import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro
-import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
-import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils}
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
-import org.apache.hadoop.io.Writable
-import org.apache.hive.com.esotericsoftware.kryo.Kryo
-import org.apache.hive.com.esotericsoftware.kryo.io.{Input, Output}
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.types.Decimal
-import org.apache.spark.util.Utils
-
-private[hive] object HiveShim {
- // Precision and scale to pass for unlimited decimals; these are the same as the precision and
- // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
- val UNLIMITED_DECIMAL_PRECISION = 38
- val UNLIMITED_DECIMAL_SCALE = 18
- val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro"
-
- /*
- * This function became private in hive-0.13, but we have to do this to work around a Hive bug
- */
- private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) {
- val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "")
- val result: StringBuilder = new StringBuilder(old)
- var first: Boolean = old.isEmpty
-
- for (col <- cols) {
- if (first) {
- first = false
- } else {
- result.append(',')
- }
- result.append(col)
- }
- conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString)
- }
-
- /*
- * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null
- */
- def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) {
- if (ids != null) {
- ColumnProjectionUtils.appendReadColumns(conf, ids.asJava)
- }
- if (names != null) {
- appendReadColumnNames(conf, names)
- }
- }
-
- /*
- * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that
- * needs to be initialized before serialization.
- */
- def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = {
- w match {
- case w: AvroGenericRecordWritable =>
- w.setRecordReaderID(new UID())
- // In Hive 1.1, the record's schema may need to be initialized manually or an NPE will
- // be thrown.
- if (w.getFileSchema() == null) {
- serDeProps
- .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName())
- .foreach { kv =>
- w.setFileSchema(new Schema.Parser().parse(kv._2))
- }
- }
- case _ =>
- }
- w
- }
-
- def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
- if (hdoi.preferWritable()) {
- Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue,
- hdoi.precision(), hdoi.scale())
- } else {
- Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
- }
- }
-
- /**
- * This class provides the UDF creation and also the UDF instance serialization and
- * de-serialization cross process boundary.
- *
- * A detailed discussion can be found at https://github.com/apache/spark/pull/3640
- *
- * @param functionClassName UDF class name
- * @param instance optional UDF instance which contains additional information (for macro)
- */
- private[hive] case class HiveFunctionWrapper(var functionClassName: String,
- private var instance: AnyRef = null) extends java.io.Externalizable {
-
- // for Serialization
- def this() = this(null)
-
- override def hashCode(): Int = {
- if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
- Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody())
- } else {
- functionClassName.hashCode()
- }
- }
-
- override def equals(other: Any): Boolean = other match {
- case a: HiveFunctionWrapper if functionClassName == a.functionClassName =>
- // In case of udf macro, check to make sure they point to the same underlying UDF
- if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
- a.instance.asInstanceOf[GenericUDFMacro].getBody() ==
- instance.asInstanceOf[GenericUDFMacro].getBody()
- } else {
- true
- }
- case _ => false
- }
-
- @transient
- def deserializeObjectByKryo[T: ClassTag](
- kryo: Kryo,
- in: InputStream,
- clazz: Class[_]): T = {
- val inp = new Input(in)
- val t: T = kryo.readObject(inp, clazz).asInstanceOf[T]
- inp.close()
- t
- }
-
- @transient
- def serializeObjectByKryo(
- kryo: Kryo,
- plan: Object,
- out: OutputStream) {
- val output: Output = new Output(out)
- kryo.writeObject(output, plan)
- output.close()
- }
-
- def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
- deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz)
- .asInstanceOf[UDFType]
- }
-
- def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
- serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out)
- }
-
- def writeExternal(out: java.io.ObjectOutput) {
- // output the function name
- out.writeUTF(functionClassName)
-
- // Write a flag if instance is null or not
- out.writeBoolean(instance != null)
- if (instance != null) {
- // Some of the UDF are serializable, but some others are not
- // Hive Utilities can handle both cases
- val baos = new java.io.ByteArrayOutputStream()
- serializePlan(instance, baos)
- val functionInBytes = baos.toByteArray
-
- // output the function bytes
- out.writeInt(functionInBytes.length)
- out.write(functionInBytes, 0, functionInBytes.length)
- }
- }
-
- def readExternal(in: java.io.ObjectInput) {
- // read the function name
- functionClassName = in.readUTF()
-
- if (in.readBoolean()) {
- // if the instance is not null
- // read the function in bytes
- val functionInBytesLength = in.readInt()
- val functionInBytes = new Array[Byte](functionInBytesLength)
- in.readFully(functionInBytes)
-
- // deserialize the function object via Hive Utilities
- instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes),
- Utils.getContextOrSparkClassLoader.loadClass(functionClassName))
- }
- }
-
- def createFunction[UDFType <: AnyRef](): UDFType = {
- if (instance != null) {
- instance.asInstanceOf[UDFType]
- } else {
- val func = Utils.getContextOrSparkClassLoader
- .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
- if (!func.isInstanceOf[UDF]) {
- // We cache the function if it's not a simple UDF,
- // as we always have to create a new instance for a simple UDF
- instance = func
- }
- func
- }
- }
- }
-
- /*
- * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not.
- * Fix it through wrapper.
- */
- implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = {
- val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed)
- f.setCompressCodec(w.compressCodec)
- f.setCompressType(w.compressType)
- f.setTableInfo(w.tableInfo)
- f.setDestTableId(w.destTableId)
- f
- }
-
- /*
- * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not.
- * Fix it through wrapper.
- */
- private[hive] class ShimFileSinkDesc(
- var dir: String,
- var tableInfo: TableDesc,
- var compressed: Boolean)
- extends Serializable with Logging {
- var compressCodec: String = _
- var compressType: String = _
- var destTableId: Int = _
-
- def setCompressed(compressed: Boolean) {
- this.compressed = compressed
- }
-
- def getDirName(): String = dir
-
- def setDestTableId(destTableId: Int) {
- this.destTableId = destTableId
- }
-
- def setTableInfo(tableInfo: TableDesc) {
- this.tableInfo = tableInfo
- }
-
- def setCompressCodec(intermediateCompressorCodec: String) {
- compressCodec = intermediateCompressorCodec
- }
-
- def setCompressType(intermediateCompressType: String) {
- compressType = intermediateCompressType
- }
- }
-}
diff --git a/spark/spark-2.2/pom.xml b/spark/spark-2.2/pom.xml
deleted file mode 100644
index c7a9997..0000000
--- a/spark/spark-2.2/pom.xml
+++ /dev/null
@@ -1,142 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
- <modelVersion>4.0.0</modelVersion>
-
- <parent>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall-spark</artifactId>
- <version>0.6.0-incubating-SNAPSHOT</version>
- <relativePath>../pom.xml</relativePath>
- </parent>
-
- <artifactId>hivemall-spark2.2</artifactId>
- <name>Hivemall on Spark 2.2</name>
- <packaging>jar</packaging>
-
- <properties>
- <main.basedir>${project.parent.parent.basedir}</main.basedir>
- <spark.version>2.2.0</spark.version>
- <spark.binary.version>2.2</spark.binary.version>
- <hadoop.version>2.6.5</hadoop.version>
- <scalatest.jvm.opts>-ea -Xms768m -Xmx2g -XX:MetaspaceSize=128m -XX:MaxMetaspaceSize=512m -XX:ReservedCodeCacheSize=512m</scalatest.jvm.opts>
- <maven.compiler.source>1.8</maven.compiler.source>
- <maven.compiler.target>1.8</maven.compiler.target>
- </properties>
-
- <dependencies>
- <!-- compile scope -->
- <dependency>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall-core</artifactId>
- <scope>compile</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall-xgboost</artifactId>
- <scope>compile</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall-spark-common</artifactId>
- <version>${project.version}</version>
- <scope>compile</scope>
- </dependency>
-
- <!-- provided scope -->
- <dependency>
- <groupId>org.scala-lang</groupId>
- <artifactId>scala-library</artifactId>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.spark</groupId>
- <artifactId>spark-core_${scala.binary.version}</artifactId>
- <version>${spark.version}</version>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.spark</groupId>
- <artifactId>spark-sql_${scala.binary.version}</artifactId>
- <version>${spark.version}</version>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.spark</groupId>
- <artifactId>spark-hive_${scala.binary.version}</artifactId>
- <version>${spark.version}</version>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.spark</groupId>
- <artifactId>spark-streaming_${scala.binary.version}</artifactId>
- <version>${spark.version}</version>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.spark</groupId>
- <artifactId>spark-mllib_${scala.binary.version}</artifactId>
- <version>${spark.version}</version>
- <scope>provided</scope>
- </dependency>
-
- <!-- test dependencies -->
- <dependency>
- <groupId>org.apache.hivemall</groupId>
- <artifactId>hivemall-mixserv</artifactId>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>org.scalatest</groupId>
- <artifactId>scalatest_${scala.binary.version}</artifactId>
- <scope>test</scope>
- </dependency>
- </dependencies>
-
- <build>
- <plugins>
- <!-- hivemall-spark_xx-xx-with-dependencies.jar including minimum dependencies -->
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-shade-plugin</artifactId>
- </plugin>
- <!-- disable surefire because there is no java test -->
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-surefire-plugin</artifactId>
- <configuration>
- <skipTests>true</skipTests>
- </configuration>
- </plugin>
- <!-- then, enable scalatest -->
- <plugin>
- <groupId>org.scalatest</groupId>
- <artifactId>scalatest-maven-plugin</artifactId>
- <executions>
- <execution>
- <id>test</id>
- <goals>
- <goal>test</goal>
- </goals>
- </execution>
- </executions>
- </plugin>
- </plugins>
- </build>
-</project>
diff --git a/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala b/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala
deleted file mode 100644
index 3e0f274..0000000
--- a/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package hivemall.xgboost
-
-import scala.collection.mutable
-
-import org.apache.commons.cli.Options
-import org.apache.spark.annotation.AlphaComponent
-
-/**
- * :: AlphaComponent ::
- * A utility class to generate a sequence of options used in XGBoost.
- */
-@AlphaComponent
-case class XGBoostOptions() {
- private val params: mutable.Map[String, String] = mutable.Map.empty
- private val options: Options = {
- new XGBoostUDTF() {
- def options(): Options = super.getOptions()
- }.options()
- }
-
- private def isValidKey(key: String): Boolean = {
- // TODO: Is there another way to handle all the XGBoost options?
- options.hasOption(key) || key == "num_class"
- }
-
- def set(key: String, value: String): XGBoostOptions = {
- require(isValidKey(key), s"non-existing key detected in XGBoost options: ${key}")
- params.put(key, value)
- this
- }
-
- def help(): Unit = {
- import scala.collection.JavaConversions._
- options.getOptions.map { case option => println(option) }
- }
-
- override def toString(): String = {
- params.map { case (key, value) => s"-$key $value" }.mkString(" ")
- }
-}
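
For context, the deleted XGBoostOptions above only accumulated key/value pairs and rendered them as a hyphen-prefixed option string for the XGBoost UDTF. A minimal, self-contained sketch of that formatting (not the removed class itself; key validation against XGBoostUDTF's Options is omitted and the parameter names are illustrative only):

    object XGBoostOptionStringSketch {
      def main(args: Array[String]): Unit = {
        // Preserve insertion order, like repeated set(key, value) calls did
        val params = scala.collection.mutable.LinkedHashMap(
          "objective" -> "binary:logistic",
          "num_round" -> "100")
        // Same formatting as the removed XGBoostOptions.toString
        val optionString = params.map { case (k, v) => s"-$k $v" }.mkString(" ")
        println(optionString) // -objective binary:logistic -num_round 100
      }
    }
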
diff --git a/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
deleted file mode 100644
index b49e20a..0000000
--- a/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ /dev/null
@@ -1 +0,0 @@
-org.apache.spark.sql.hive.source.XGBoostFileFormat
diff --git a/spark/spark-2.2/src/main/resources/log4j.properties b/spark/spark-2.2/src/main/resources/log4j.properties
deleted file mode 100644
index ef4f606..0000000
--- a/spark/spark-2.2/src/main/resources/log4j.properties
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Set everything to be logged to the console
-log4j.rootCategory=INFO, console
-log4j.appender.console=org.apache.log4j.ConsoleAppender
-log4j.appender.console.target=System.err
-log4j.appender.console.layout=org.apache.log4j.PatternLayout
-log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
-
-# Settings to quiet third party logs that are too verbose
-log4j.logger.org.eclipse.jetty=INFO
-log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
-log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
-log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
diff --git a/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala b/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala
deleted file mode 100644
index a2b7f60..0000000
--- a/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package hivemall.tools
-
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.hive.HivemallOps._
-import org.apache.spark.sql.types._
-
-object RegressionDatagen {
-
- /**
- * Generate data for regression/classification.
- * See [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]]
- * for the details of arguments below.
- */
- def exec(sc: SQLContext,
- n_partitions: Int = 2,
- min_examples: Int = 1000,
- n_features: Int = 10,
- n_dims: Int = 200,
- seed: Int = 43,
- dense: Boolean = false,
- prob_one: Float = 0.6f,
- sort: Boolean = false,
- cl: Boolean = false): DataFrame = {
-
- require(n_partitions > 0, "Positive #n_partitions required.")
- require(min_examples > 0, "Positive #min_examples required.")
- require(n_features > 0, "Positive #n_features required.")
- require(n_dims > 0, "Positive #n_dims required.")
-
- // Calculate #examples to generate in each partition
- val n_examples = (min_examples + n_partitions - 1) / n_partitions
-
- val df = sc.createDataFrame(
- sc.sparkContext.parallelize((0 until n_partitions).map(Row(_)), n_partitions),
- StructType(
- StructField("data", IntegerType, true) ::
- Nil)
- )
- import sc.implicits._
- df.lr_datagen(
- lit(s"-n_examples $n_examples -n_features $n_features -n_dims $n_dims -prob_one $prob_one"
- + (if (dense) " -dense" else "")
- + (if (sort) " -sort" else "")
- + (if (cl) " -cl" else ""))
- ).select($"label".cast(DoubleType).as("label"), $"features")
- }
-}
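
The removed RegressionDatagen above was a thin driver around the lr_datagen UDTF: it split min_examples across partitions (rounding up per partition) and returned a DataFrame with label and features columns. A hypothetical spark-shell invocation, assuming a SparkSession named spark and the now-removed hivemall-spark2.2 jar on the classpath; all parameter values are illustrative only:

    // Generates roughly 50,000 dense classification rows across 4 partitions
    val df = hivemall.tools.RegressionDatagen.exec(
      spark.sqlContext,
      n_partitions = 4,
      min_examples = 50000,
      n_features = 10,
      dense = true,
      cl = true)
    df.show(5)
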
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
deleted file mode 100644
index 15bc068..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import scala.collection.mutable
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue
-import org.apache.spark.sql.types._
-
-trait TopKHelper {
-
- def k: Int
- def scoreType: DataType
-
- @transient val ScoreTypes = TypeCollection(
- ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType
- )
-
- protected case class ScoreWriter(writer: UnsafeRowWriter, ordinal: Int) {
-
- def write(v: Any): Unit = scoreType match {
- case ByteType => writer.write(ordinal, v.asInstanceOf[Byte])
- case ShortType => writer.write(ordinal, v.asInstanceOf[Short])
- case IntegerType => writer.write(ordinal, v.asInstanceOf[Int])
- case LongType => writer.write(ordinal, v.asInstanceOf[Long])
- case FloatType => writer.write(ordinal, v.asInstanceOf[Float])
- case DoubleType => writer.write(ordinal, v.asInstanceOf[Double])
- case d: DecimalType => writer.write(ordinal, v.asInstanceOf[Decimal], d.precision, d.scale)
- }
- }
-
- protected lazy val scoreOrdering = {
- val ordering = TypeUtils.getInterpretedOrdering(scoreType)
- if (k > 0) ordering else ordering.reverse
- }
-
- protected lazy val reverseScoreOrdering = scoreOrdering.reverse
-
- protected lazy val queue: InternalRowPriorityQueue = {
- new InternalRowPriorityQueue(Math.abs(k), (x: Any, y: Any) => scoreOrdering.compare(x, y))
- }
-}
-
-case class EachTopK(
- k: Int,
- scoreExpr: Expression,
- groupExprs: Seq[Expression],
- elementSchema: StructType,
- children: Seq[Attribute])
- extends Generator with TopKHelper with CodegenFallback {
-
- override val scoreType: DataType = scoreExpr.dataType
-
- private lazy val groupingProjection: UnsafeProjection = UnsafeProjection.create(groupExprs)
- private lazy val scoreProjection: UnsafeProjection = UnsafeProjection.create(scoreExpr :: Nil)
-
- // The grouping key of the current partition
- private var currentGroupingKeys: UnsafeRow = _
-
- override def checkInputDataTypes(): TypeCheckResult = {
- if (!ScoreTypes.acceptsType(scoreExpr.dataType)) {
- TypeCheckResult.TypeCheckFailure(s"$scoreExpr must have a comparable type")
- } else {
- TypeCheckResult.TypeCheckSuccess
- }
- }
-
- private def topKRowsForGroup(): Seq[InternalRow] = if (queue.size > 0) {
- val outputRows = queue.iterator.toSeq.sortBy(_._1)(scoreOrdering).reverse
- val (headScore, _) = outputRows.head
- val rankNum = outputRows.scanLeft((1, headScore)) {
- case ((rank, prevScore), (score, _)) =>
- if (prevScore == score) (rank, score) else (rank + 1, score)
- }.tail
- val buf = mutable.ArrayBuffer[InternalRow]()
- var i = 0
- while (rankNum.length > i) {
- val rank = rankNum(i)._1
- val row = new JoinedRow(InternalRow.fromSeq(rank :: Nil), outputRows(i)._2)
- buf.append(row)
- i += 1
- }
- buf
- } else {
- Seq.empty
- }
-
- override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
- val groupingKeys = groupingProjection(input)
- val ret = if (currentGroupingKeys != groupingKeys) {
- val topKRows = topKRowsForGroup()
- currentGroupingKeys = groupingKeys.copy()
- queue.clear()
- topKRows
- } else {
- Iterator.empty
- }
- queue += Tuple2(scoreProjection(input).get(0, scoreType), input)
- ret
- }
-
- override def terminate(): TraversableOnce[InternalRow] = {
- if (queue.size > 0) {
- val topKRows = topKRowsForGroup()
- queue.clear()
- topKRows
- } else {
- Iterator.empty
- }
- }
-
- // TODO: Need to support codegen
- // protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode
-}
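
The rank bookkeeping in topKRowsForGroup above gives tied scores the same rank and only advances the rank when the score changes. A standalone sketch of that scanLeft logic on plain doubles (scores assumed already sorted best-first), illustrative only:

    object RankSketch {
      def main(args: Array[String]): Unit = {
        val scores = Seq(0.9, 0.9, 0.7, 0.5)
        // Seed with rank 1 and the head score, then drop the seed with .tail
        val ranks = scores.scanLeft((1, scores.head)) {
          case ((rank, prev), s) => if (prev == s) (rank, s) else (rank + 1, s)
        }.tail.map(_._1)
        println(ranks) // List(1, 1, 2, 3)
      }
    }
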
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala
deleted file mode 100644
index 556cdc3..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.catalyst.plans.logical
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.types.{BooleanType, IntegerType}
-
-case class JoinTopK(
- k: Int,
- left: LogicalPlan,
- right: LogicalPlan,
- joinType: JoinType,
- condition: Option[Expression])(
- val scoreExpr: NamedExpression,
- private[sql] val rankAttr: Seq[Attribute] = AttributeReference("rank", IntegerType)() :: Nil)
- extends BinaryNode with PredicateHelper {
-
- override def output: Seq[Attribute] = joinType match {
- case Inner => rankAttr ++ Seq(scoreExpr.toAttribute) ++ left.output ++ right.output
- }
-
- override def references: AttributeSet = {
- AttributeSet((expressions ++ Seq(scoreExpr)).flatMap(_.references))
- }
-
- override protected def validConstraints: Set[Expression] = joinType match {
- case Inner if condition.isDefined =>
- left.constraints.union(right.constraints)
- .union(splitConjunctivePredicates(condition.get).toSet)
- }
-
- override protected final def otherCopyArgs: Seq[AnyRef] = {
- scoreExpr :: rankAttr :: Nil
- }
-
- def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
-
- lazy val resolvedExceptNatural: Boolean = {
- childrenResolved &&
- expressions.forall(_.resolved) &&
- duplicateResolved &&
- condition.forall(_.dataType == BooleanType)
- }
-
- override lazy val resolved: Boolean = joinType match {
- case Inner => resolvedExceptNatural
- case tpe => throw new AnalysisException(s"Unsupported using join type $tpe")
- }
-}
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala
deleted file mode 100644
index 12c20fb..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.catalyst.utils
-
-import java.io.Serializable
-import java.util.{PriorityQueue => JPriorityQueue}
-
-import scala.collection.JavaConverters._
-import scala.collection.generic.Growable
-
-import org.apache.spark.sql.catalyst.InternalRow
-
-private[sql] class InternalRowPriorityQueue(
- maxSize: Int,
- compareFunc: (Any, Any) => Int
- ) extends Iterable[(Any, InternalRow)] with Growable[(Any, InternalRow)] with Serializable {
-
- private[this] val ordering = new Ordering[(Any, InternalRow)] {
- override def compare(x: (Any, InternalRow), y: (Any, InternalRow)): Int =
- compareFunc(x._1, y._1)
- }
-
- private val underlying = new JPriorityQueue[(Any, InternalRow)](maxSize, ordering)
-
- override def iterator: Iterator[(Any, InternalRow)] = underlying.iterator.asScala
-
- override def size: Int = underlying.size
-
- override def ++=(xs: TraversableOnce[(Any, InternalRow)]): this.type = {
- xs.foreach { this += _ }
- this
- }
-
- override def +=(elem: (Any, InternalRow)): this.type = {
- if (size < maxSize) {
- underlying.offer((elem._1, elem._2.copy()))
- } else {
- maybeReplaceLowest(elem)
- }
- this
- }
-
- override def +=(elem1: (Any, InternalRow), elem2: (Any, InternalRow), elems: (Any, InternalRow)*)
- : this.type = {
- this += elem1 += elem2 ++= elems
- }
-
- override def clear() { underlying.clear() }
-
- private def maybeReplaceLowest(a: (Any, InternalRow)): Boolean = {
- val head = underlying.peek()
- if (head != null && ordering.gt(a, head)) {
- underlying.poll()
- underlying.offer((a._1, a._2.copy()))
- } else {
- false
- }
- }
-}
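
InternalRowPriorityQueue above is a bounded priority queue: once maxSize elements are held, a new element only enters by evicting the current minimum (maybeReplaceLowest). A standalone sketch of the same idea on plain integers, illustrative only:

    import java.util.{PriorityQueue => JPriorityQueue}

    object BoundedTopKSketch {
      def main(args: Array[String]): Unit = {
        val k = 3
        // Min-heap capped at k elements; the head is always the smallest kept value
        val queue = new JPriorityQueue[Int](k, Ordering.Int)
        Seq(5, 1, 9, 7, 3, 8).foreach { v =>
          if (queue.size < k) queue.offer(v)
          else if (v > queue.peek()) { queue.poll(); queue.offer(v) }
        }
        println(queue) // the three largest values (7, 8, 9), in heap order
      }
    }
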
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala
deleted file mode 100644
index 09d60a6..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Strategy
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.logical.{JoinTopK, LogicalPlan}
-import org.apache.spark.sql.internal.SQLConf
-
-private object ExtractJoinTopKKeys extends Logging with PredicateHelper {
- /** (k, scoreExpr, rankAttr, joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
- type ReturnType =
- (Int, NamedExpression, Seq[Attribute], JoinType, Seq[Expression], Seq[Expression],
- Option[Expression], LogicalPlan, LogicalPlan)
-
- def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
- case join @ JoinTopK(k, left, right, joinType, condition) =>
- logDebug(s"Considering join on: $condition")
- val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
- val joinKeys = predicates.flatMap {
- case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
- case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
- // Replace null with a default value for the join key, so that rows containing null
- // can still be joined together
- case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
- Some((Coalesce(Seq(l, Literal.default(l.dataType))),
- Coalesce(Seq(r, Literal.default(r.dataType)))))
- case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
- Some((Coalesce(Seq(r, Literal.default(r.dataType))),
- Coalesce(Seq(l, Literal.default(l.dataType)))))
- case other => None
- }
- val otherPredicates = predicates.filterNot {
- case EqualTo(l, r) =>
- canEvaluate(l, left) && canEvaluate(r, right) ||
- canEvaluate(l, right) && canEvaluate(r, left)
- case other => false
- }
-
- if (joinKeys.nonEmpty) {
- val (leftKeys, rightKeys) = joinKeys.unzip
- logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
- Some((k, join.scoreExpr, join.rankAttr, joinType, leftKeys, rightKeys,
- otherPredicates.reduceOption(And), left, right))
- } else {
- None
- }
-
- case p =>
- None
- }
-}
-
-private[sql] class UserProvidedPlanner(val conf: SQLConf) extends Strategy {
-
- override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case ExtractJoinTopKKeys(
- k, scoreExpr, rankAttr, _, leftKeys, rightKeys, condition, left, right) =>
- Seq(joins.ShuffledHashJoinTopKExec(
- k, leftKeys, rightKeys, condition, planLater(left), planLater(right))(scoreExpr, rankAttr))
- case _ =>
- Nil
- }
-}
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
deleted file mode 100644
index 1f56c90..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
+++ /dev/null
@@ -1,169 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.csv
-
-import com.univocity.parsers.csv.CsvWriter
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression, UnaryExpression}
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-/**
- * Converts a csv input string to a [[StructType]] with the specified schema.
- *
- * TODO: Move this class into org.apache.spark.sql.catalyst.expressions in Spark-v2.2+
- */
-case class CsvToStruct(
- schema: StructType,
- options: Map[String, String],
- child: Expression,
- timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
-
- def this(schema: StructType, options: Map[String, String], child: Expression) =
- this(schema, options, child, None)
-
- override def nullable: Boolean = true
-
- @transient private lazy val csvOptions = new CSVOptions(options, timeZoneId.get)
- @transient private lazy val csvParser = new UnivocityParser(schema, schema, csvOptions)
-
- private def parse(input: String): InternalRow = csvParser.parse(input)
-
- override def dataType: DataType = schema
-
- override def nullSafeEval(csv: Any): Any = {
- try parse(csv.toString) catch { case _: RuntimeException => null }
- }
-
- override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
-
- override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
- copy(timeZoneId = Option(timeZoneId))
-}
-
-private class CsvGenerator(schema: StructType, options: CSVOptions) {
-
- // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
- // When the value is null, this converter should not be called.
- private type ValueConverter = (InternalRow, Int) => String
-
- // `ValueConverter`s for all values in the fields of the schema
- private val valueConverters: Array[ValueConverter] =
- schema.map(_.dataType).map(makeConverter).toArray
-
- private def makeConverter(dataType: DataType): ValueConverter = dataType match {
- case DateType =>
- (row: InternalRow, ordinal: Int) =>
- options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
-
- case TimestampType =>
- (row: InternalRow, ordinal: Int) =>
- options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
-
- case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
-
- case dt: DataType =>
- (row: InternalRow, ordinal: Int) =>
- row.get(ordinal, dt).toString
- }
-
- def convertRow(row: InternalRow): Seq[String] = {
- var i = 0
- val values = new Array[String](row.numFields)
- while (i < row.numFields) {
- if (!row.isNullAt(i)) {
- values(i) = valueConverters(i).apply(row, i)
- } else {
- values(i) = options.nullValue
- }
- i += 1
- }
- values
- }
-}
-
-/**
- * Converts a [[StructType]] to a csv output string.
- */
-case class StructToCsv(
- options: Map[String, String],
- child: Expression,
- timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
- override def nullable: Boolean = true
-
- @transient
- private lazy val params = new CSVOptions(options, timeZoneId.get)
-
- @transient
- private lazy val dataSchema = child.dataType.asInstanceOf[StructType]
-
- @transient
- private lazy val writer = new CsvGenerator(dataSchema, params)
-
- override def dataType: DataType = StringType
-
- private def verifySchema(schema: StructType): Unit = {
- def verifyType(dataType: DataType): Unit = dataType match {
- case ByteType | ShortType | IntegerType | LongType | FloatType |
- DoubleType | BooleanType | _: DecimalType | TimestampType |
- DateType | StringType =>
-
- case udt: UserDefinedType[_] => verifyType(udt.sqlType)
-
- case _ =>
- throw new UnsupportedOperationException(
- s"CSV data source does not support ${dataType.simpleString} data type.")
- }
-
- schema.foreach(field => verifyType(field.dataType))
- }
-
- override def checkInputDataTypes(): TypeCheckResult = {
- if (StructType.acceptsType(child.dataType)) {
- try {
- verifySchema(child.dataType.asInstanceOf[StructType])
- TypeCheckResult.TypeCheckSuccess
- } catch {
- case e: UnsupportedOperationException =>
- TypeCheckResult.TypeCheckFailure(e.getMessage)
- }
- } else {
- TypeCheckResult.TypeCheckFailure(
- s"$prettyName requires that the expression is a struct expression.")
- }
- }
-
- override def nullSafeEval(row: Any): Any = {
- val rowStr = writer.convertRow(row.asInstanceOf[InternalRow])
- .mkString(params.delimiter.toString)
- UTF8String.fromString(rowStr)
- }
-
- override def inputTypes: Seq[AbstractDataType] = StructType :: Nil
-
- override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
- copy(timeZoneId = Option(timeZoneId))
-}
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
deleted file mode 100644
index 0067bbb..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
+++ /dev/null
@@ -1,405 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.TaskContext
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric._
-import org.apache.spark.sql.types._
-
-abstract class PriorityQueueShim {
-
- def insert(score: Any, row: InternalRow): Unit
- def get(): Iterator[InternalRow]
- def clear(): Unit
-}
-
-case class ShuffledHashJoinTopKExec(
- k: Int,
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- condition: Option[Expression],
- left: SparkPlan,
- right: SparkPlan)(
- scoreExpr: NamedExpression,
- rankAttr: Seq[Attribute])
- extends BinaryExecNode with TopKHelper with HashJoin with CodegenSupport {
-
- override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
-
- override val scoreType: DataType = scoreExpr.dataType
- override val joinType: JoinType = Inner
- override val buildSide: BuildSide = BuildRight // Only support `BuildRight`
-
- private lazy val scoreProjection: UnsafeProjection =
- UnsafeProjection.create(scoreExpr :: Nil, left.output ++ right.output)
-
- private lazy val boundCondition = if (condition.isDefined) {
- (r: InternalRow) => newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval(r)
- } else {
- (r: InternalRow) => true
- }
-
- private lazy val topKAttr = rankAttr :+ scoreExpr.toAttribute
-
- private lazy val _priorityQueue = new PriorityQueueShim {
-
- private val q: InternalRowPriorityQueue = queue
- private val joinedRow = new JoinedRow
-
- override def insert(score: Any, row: InternalRow): Unit = {
- q += Tuple2(score, row)
- }
-
- override def get(): Iterator[InternalRow] = {
- val outputRows = queue.iterator.toSeq.reverse
- val (headScore, _) = outputRows.head
- val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) =>
- if (prevScore == score) (rank, score) else (rank + 1, score)
- }
- val topKRow = new UnsafeRow(2)
- val bufferHolder = new BufferHolder(topKRow)
- val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2)
- val scoreWriter = ScoreWriter(unsafeRowWriter, 1)
- outputRows.zip(rankNum.map(_._1)).map { case ((score, row), index) =>
- // Writes to an UnsafeRow directly
- bufferHolder.reset()
- unsafeRowWriter.write(0, index)
- scoreWriter.write(score)
- topKRow.setTotalSize(bufferHolder.totalSize())
- joinedRow.apply(topKRow, row)
- }.iterator
- }
-
- override def clear(): Unit = q.clear()
- }
-
- override def output: Seq[Attribute] = joinType match {
- case Inner => topKAttr ++ left.output ++ right.output
- }
-
- override protected final def otherCopyArgs: Seq[AnyRef] = {
- scoreExpr :: rankAttr :: Nil
- }
-
- override def requiredChildDistribution: Seq[Distribution] =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
- def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
- val context = TaskContext.get()
- val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
- context.addTaskCompletionListener(_ => relation.close())
- relation
- }
-
- override protected def createResultProjection(): (InternalRow) => InternalRow = joinType match {
- case Inner =>
- // Always put the stream side on the left to simplify the implementation;
- // both the left and right side could be null
- UnsafeProjection.create(
- output, (topKAttr ++ streamedPlan.output ++ buildPlan.output).map(_.withNullability(true)))
- }
-
- protected def InnerJoin(
- streamedIter: Iterator[InternalRow],
- hashedRelation: HashedRelation,
- numOutputRows: SQLMetric): Iterator[InternalRow] = {
- val joinRow = new JoinedRow
- val joinKeysProj = streamSideKeyGenerator()
- val joinedIter = streamedIter.flatMap { srow =>
- joinRow.withLeft(srow)
- val joinKeys = joinKeysProj(srow) // `joinKeys` is also a grouping key
- val matches = hashedRelation.get(joinKeys)
- if (matches != null) {
- matches.map(joinRow.withRight).filter(boundCondition).foreach { resultRow =>
- _priorityQueue.insert(scoreProjection(resultRow).get(0, scoreType), resultRow)
- }
- val iter = _priorityQueue.get()
- _priorityQueue.clear()
- iter
- } else {
- Seq.empty
- }
- }
- val resultProj = createResultProjection()
- (joinedIter ++ queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering)
- .map(_._2)).map { r =>
- resultProj(r)
- }
- }
-
- override protected def doExecute(): RDD[InternalRow] = {
- streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
- val hashed = buildHashedRelation(buildIter)
- InnerJoin(streamIter, hashed, null)
- }
- }
-
- override def inputRDDs(): Seq[RDD[InternalRow]] = {
- left.execute() :: right.execute() :: Nil
- }
-
- // Accessor for generated code
- def priorityQueue(): PriorityQueueShim = _priorityQueue
-
- /**
- * Add a state of HashedRelation and return the variable name for it.
- */
- private def prepareHashedRelation(ctx: CodegenContext): String = {
- // create a name for HashedRelation
- val joinExec = ctx.addReferenceObj("joinExec", this)
- val relationTerm = ctx.freshName("relation")
- val clsName = HashedRelation.getClass.getName.replace("$", "")
- ctx.addMutableState(clsName, relationTerm,
- s"""
- | $relationTerm = ($clsName) $joinExec.buildHashedRelation(inputs[1]);
- | incPeakExecutionMemory($relationTerm.estimatedSize());
- """.stripMargin)
- relationTerm
- }
-
- /**
- * Creates variables for left part of result row.
- *
- * In order to defer the access until after the condition check and to access each column
- * only once in the loop, the variables should be declared separately from the column access,
- * so we can't use the codegen of BoundReference here.
- */
- private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
- ctx.INPUT_ROW = leftRow
- left.output.zipWithIndex.map { case (a, i) =>
- val value = ctx.freshName("value")
- val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
- // declare it as class member, so we can access the column before or in the loop.
- ctx.addMutableState(ctx.javaType(a.dataType), value, "")
- if (a.nullable) {
- val isNull = ctx.freshName("isNull")
- ctx.addMutableState("boolean", isNull, "")
- val code =
- s"""
- |$isNull = $leftRow.isNullAt($i);
- |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
- """.stripMargin
- ExprCode(code, isNull, value)
- } else {
- ExprCode(s"$value = $valueCode;", "false", value)
- }
- }
- }
-
- /**
- * Creates the variables for the right part of the result row, using BoundReference, since
- * the right part is accessed inside the loop.
- */
- private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
- ctx.INPUT_ROW = rightRow
- right.output.zipWithIndex.map { case (a, i) =>
- BoundReference(i, a.dataType, a.nullable).genCode(ctx)
- }
- }
-
- /**
- * Returns the code for generating join key for stream side, and expression of whether the key
- * has any null in it or not.
- */
- private def genStreamSideJoinKey(ctx: CodegenContext, leftRow: String): (ExprCode, String) = {
- ctx.INPUT_ROW = leftRow
- if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
- // generate the join key as Long
- val ev = streamedKeys.head.genCode(ctx)
- (ev, ev.isNull)
- } else {
- // generate the join key as UnsafeRow
- val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
- (ev, s"${ev.value}.anyNull()")
- }
- }
-
- private def createScoreVar(ctx: CodegenContext, row: String): ExprCode = {
- ctx.INPUT_ROW = row
- BindReferences.bindReference(scoreExpr, left.output ++ right.output).genCode(ctx)
- }
-
- private def createResultVars(ctx: CodegenContext, resultRow: String): Seq[ExprCode] = {
- ctx.INPUT_ROW = resultRow
- output.zipWithIndex.map { case (a, i) =>
- val value = ctx.freshName("value")
- val valueCode = ctx.getValue(resultRow, a.dataType, i.toString)
- // declare it as class member, so we can access the column before or in the loop.
- ctx.addMutableState(ctx.javaType(a.dataType), value, "")
- if (a.nullable) {
- val isNull = ctx.freshName("isNull")
- ctx.addMutableState("boolean", isNull, "")
- val code =
- s"""
- |$isNull = $resultRow.isNullAt($i);
- |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
- """.stripMargin
- ExprCode(code, isNull, value)
- } else {
- ExprCode(s"$value = $valueCode;", "false", value)
- }
- }
- }
-
- /**
- * Splits variables based on whether they are used by the condition or not, and returns the
- * code to create these variables before the condition and after the condition.
- *
- * Only a few columns are used by the condition, so we can skip accessing the columns that are
- * not used by the condition for rows that the condition also filters out.
- */
- private def splitVarsByCondition(
- attributes: Seq[Attribute],
- variables: Seq[ExprCode]): (String, String) = {
- if (condition.isDefined) {
- val condRefs = condition.get.references
- val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
- condRefs.contains(a)
- }
- val beforeCond = evaluateVariables(used.map(_._2))
- val afterCond = evaluateVariables(notUsed.map(_._2))
- (beforeCond, afterCond)
- } else {
- (evaluateVariables(variables), "")
- }
- }
-
- override def doProduce(ctx: CodegenContext): String = {
- ctx.copyResult = true
-
- val topKJoin = ctx.addReferenceObj("topKJoin", this)
-
- // Prepare a priority queue for top-K computing
- val pQueue = ctx.freshName("queue")
- ctx.addMutableState(classOf[PriorityQueueShim].getName, pQueue,
- s"$pQueue = $topKJoin.priorityQueue();")
-
- // Prepare variables for a left side
- val leftIter = ctx.freshName("leftIter")
- ctx.addMutableState("scala.collection.Iterator", leftIter, s"$leftIter = inputs[0];")
- val leftRow = ctx.freshName("leftRow")
- ctx.addMutableState("InternalRow", leftRow, "")
- val leftVars = createLeftVars(ctx, leftRow)
-
- // Prepare variables for a right side
- val rightRow = ctx.freshName("rightRow")
- val rightVars = createRightVar(ctx, rightRow)
-
- // Build a hashed relation from a right side
- val buildRelation = prepareHashedRelation(ctx)
-
- // Project join keys from a left side
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, leftRow)
-
- // Prepare variables for joined rows
- val joinedRow = ctx.freshName("joinedRow")
- val joinedRowCls = classOf[JoinedRow].getName
- ctx.addMutableState(joinedRowCls, joinedRow, s"$joinedRow = new $joinedRowCls();")
-
- // Project score values from joined rows
- val scoreVar = createScoreVar(ctx, joinedRow)
-
- // Prepare variables for output rows
- val resultRow = ctx.freshName("resultRow")
- val resultVars = createResultVars(ctx, resultRow)
-
- val (beforeLoop, condCheck) = if (condition.isDefined) {
- // Split the code of creating variables based on whether it's used by condition or not.
- val loaded = ctx.freshName("loaded")
- val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
- val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
- // Generate code for condition
- ctx.currentVars = leftVars ++ rightVars
- val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
- // evaluate the columns those used by condition before loop
- val before = s"""
- |boolean $loaded = false;
- |$leftBefore
- """.stripMargin
-
- val checking = s"""
- |$rightBefore
- |${cond.code}
- |if (${cond.isNull} || !${cond.value}) continue;
- |if (!$loaded) {
- | $loaded = true;
- | $leftAfter
- |}
- |$rightAfter
- """.stripMargin
- (before, checking)
- } else {
- (evaluateVariables(leftVars), "")
- }
-
- val numOutput = metricTerm(ctx, "numOutputRows")
-
- val matches = ctx.freshName("matches")
- val topKRows = ctx.freshName("topKRows")
- val iteratorCls = classOf[Iterator[UnsafeRow]].getName
-
- s"""
- |$leftRow = null;
- |while ($leftIter.hasNext()) {
- | $leftRow = (InternalRow) $leftIter.next();
- |
- | // Generate join key for stream side
- | ${keyEv.code}
- |
- | // Find matches from HashedRelation
- | $iteratorCls $matches = $anyNull? null : ($iteratorCls)$buildRelation.get(${keyEv.value});
- | if ($matches == null) continue;
- |
- | // Join top-K right rows
- | while ($matches.hasNext()) {
- | ${beforeLoop.trim}
- | InternalRow $rightRow = (InternalRow) $matches.next();
- | ${condCheck.trim}
- | InternalRow row = $joinedRow.apply($leftRow, $rightRow);
- | // Compute a score for the `row`
- | ${scoreVar.code}
- | $pQueue.insert(${scoreVar.value}, row);
- | }
- |
- | // Get top-K rows
- | $iteratorCls $topKRows = $pQueue.get();
- | $pQueue.clear();
- |
- | // Output top-K rows
- | while ($topKRows.hasNext()) {
- | InternalRow $resultRow = (InternalRow) $topKRows.next();
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
- |
- | if (shouldStop()) return;
- |}
- """.stripMargin
- }
-}
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
deleted file mode 100644
index 2982d9c..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
+++ /dev/null
@@ -1,636 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.RelationalGroupedDataset
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
-import org.apache.spark.sql.catalyst.plans.logical.Pivot
-import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
-import org.apache.spark.sql.types._
-
-/**
- * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
- *
- * @groupname classifier
- * @groupname ensemble
- * @groupname evaluation
- * @groupname topicmodel
- * @groupname ftvec.selection
- * @groupname ftvec.text
- * @groupname ftvec.trans
- * @groupname tools.array
- * @groupname tools.bits
- * @groupname tools.list
- * @groupname tools.map
- * @groupname tools.matrix
- * @groupname tools.math
- *
- * A list of unsupported functions is as follows:
- * * ftvec.conv
- * - conv2dense
- * - build_bins
- */
-final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) {
-
- /**
- * @see hivemall.classifier.KPAPredictUDAF
- * @group classifier
- */
- def kpa_predict(xh: String, xk: String, w0: String, w1: String, w2: String, w3: String)
- : DataFrame = {
- checkType(xh, DoubleType)
- checkType(xk, DoubleType)
- checkType(w0, FloatType)
- checkType(w1, FloatType)
- checkType(w2, FloatType)
- checkType(w3, FloatType)
- val udaf = HiveUDAFFunction(
- "kpa_predict",
- new HiveFunctionWrapper("hivemall.classifier.KPAPredictUDAF"),
- Seq(xh, xk, w0, w1, w2, w3).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.ensemble.bagging.VotedAvgUDAF
- * @group ensemble
- */
- def voted_avg(weight: String): DataFrame = {
- checkType(weight, DoubleType)
- val udaf = HiveUDAFFunction(
- "voted_avg",
- new HiveFunctionWrapper("hivemall.ensemble.bagging.WeightVotedAvgUDAF"),
- Seq(weight).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.ensemble.bagging.WeightVotedAvgUDAF
- * @group ensemble
- */
- def weight_voted_avg(weight: String): DataFrame = {
- checkType(weight, DoubleType)
- val udaf = HiveUDAFFunction(
- "weight_voted_avg",
- new HiveFunctionWrapper("hivemall.ensemble.bagging.WeightVotedAvgUDAF"),
- Seq(weight).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.ensemble.ArgminKLDistanceUDAF
- * @group ensemble
- */
- def argmin_kld(weight: String, conv: String): DataFrame = {
- checkType(weight, FloatType)
- checkType(conv, FloatType)
- val udaf = HiveUDAFFunction(
- "argmin_kld",
- new HiveFunctionWrapper("hivemall.ensemble.ArgminKLDistanceUDAF"),
- Seq(weight, conv).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.ensemble.MaxValueLabelUDAF
- * @group ensemble
- */
- def max_label(score: String, label: String): DataFrame = {
- // checkType(score, DoubleType)
- checkType(label, StringType)
- val udaf = HiveUDAFFunction(
- "max_label",
- new HiveFunctionWrapper("hivemall.ensemble.MaxValueLabelUDAF"),
- Seq(score, label).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.ensemble.MaxRowUDAF
- * @group ensemble
- */
- def maxrow(score: String, label: String): DataFrame = {
- checkType(score, DoubleType)
- checkType(label, StringType)
- val udaf = HiveUDAFFunction(
- "maxrow",
- new HiveFunctionWrapper("hivemall.ensemble.MaxRowUDAF"),
- Seq(score, label).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.smile.tools.RandomForestEnsembleUDAF
- * @group ensemble
- */
- @scala.annotation.varargs
- def rf_ensemble(yhat: String, others: String*): DataFrame = {
- checkType(yhat, IntegerType)
- val udaf = HiveUDAFFunction(
- "rf_ensemble",
- new HiveFunctionWrapper("hivemall.smile.tools.RandomForestEnsembleUDAF"),
- (yhat +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.MeanAbsoluteErrorUDAF
- * @group evaluation
- */
- def mae(predict: String, target: String): DataFrame = {
- checkType(predict, DoubleType)
- checkType(target, DoubleType)
- val udaf = HiveUDAFFunction(
- "mae",
- new HiveFunctionWrapper("hivemall.evaluation.MeanAbsoluteErrorUDAF"),
- Seq(predict, target).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.MeanSquaredErrorUDAF
- * @group evaluation
- */
- def mse(predict: String, target: String): DataFrame = {
- checkType(predict, DoubleType)
- checkType(target, DoubleType)
- val udaf = HiveUDAFFunction(
- "mse",
- new HiveFunctionWrapper("hivemall.evaluation.MeanSquaredErrorUDAF"),
- Seq(predict, target).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.RootMeanSquaredErrorUDAF
- * @group evaluation
- */
- def rmse(predict: String, target: String): DataFrame = {
- checkType(predict, DoubleType)
- checkType(target, DoubleType)
- val udaf = HiveUDAFFunction(
- "rmse",
- new HiveFunctionWrapper("hivemall.evaluation.RootMeanSquaredErrorUDAF"),
- Seq(predict, target).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.R2UDAF
- * @group evaluation
- */
- def r2(predict: String, target: String): DataFrame = {
- checkType(predict, DoubleType)
- checkType(target, DoubleType)
- val udaf = HiveUDAFFunction(
- "r2",
- new HiveFunctionWrapper("hivemall.evaluation.R2UDAF"),
- Seq(predict, target).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.LogarithmicLossUDAF
- * @group evaluation
- */
- def logloss(predict: String, target: String): DataFrame = {
- checkType(predict, DoubleType)
- checkType(target, DoubleType)
- val udaf = HiveUDAFFunction(
- "logloss",
- new HiveFunctionWrapper("hivemall.evaluation.LogarithmicLossUDAF"),
- Seq(predict, target).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.F1ScoreUDAF
- * @group evaluation
- */
- def f1score(predict: String, target: String): DataFrame = {
- // checkType(target, ArrayType(IntegerType, false))
- // checkType(predict, ArrayType(IntegerType, false))
- val udaf = HiveUDAFFunction(
- "f1score",
- new HiveFunctionWrapper("hivemall.evaluation.F1ScoreUDAF"),
- Seq(predict, target).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.NDCGUDAF
- * @group evaluation
- */
- @scala.annotation.varargs
- def ndcg(rankItems: String, correctItems: String, others: String*): DataFrame = {
- val udaf = HiveUDAFFunction(
- "ndcg",
- new HiveFunctionWrapper("hivemall.evaluation.NDCGUDAF"),
- (rankItems +: correctItems +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.PrecisionUDAF
- * @group evaluation
- */
- @scala.annotation.varargs
- def precision_at(rankItems: String, correctItems: String, others: String*): DataFrame = {
- val udaf = HiveUDAFFunction(
- "precision_at",
- new HiveFunctionWrapper("hivemall.evaluation.PrecisionUDAF"),
- (rankItems +: correctItems +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.RecallUDAF
- * @group evaluation
- */
- @scala.annotation.varargs
- def recall_at(rankItems: String, correctItems: String, others: String*): DataFrame = {
- val udaf = HiveUDAFFunction(
- "recall_at",
- new HiveFunctionWrapper("hivemall.evaluation.RecallUDAF"),
- (rankItems +: correctItems +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.HitRateUDAF
- * @group evaluation
- */
- @scala.annotation.varargs
- def hitrate(rankItems: String, correctItems: String, others: String*): DataFrame = {
- val udaf = HiveUDAFFunction(
- "hitrate",
- new HiveFunctionWrapper("hivemall.evaluation.HitRateUDAF"),
- (rankItems +: correctItems +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.MRRUDAF
- * @group evaluation
- */
- @scala.annotation.varargs
- def mrr(rankItems: String, correctItems: String, others: String*): DataFrame = {
- val udaf = HiveUDAFFunction(
- "mrr",
- new HiveFunctionWrapper("hivemall.evaluation.MRRUDAF"),
- (rankItems +: correctItems +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.MAPUDAF
- * @group evaluation
- */
- @scala.annotation.varargs
- def average_precision(rankItems: String, correctItems: String, others: String*): DataFrame = {
- val udaf = HiveUDAFFunction(
- "average_precision",
- new HiveFunctionWrapper("hivemall.evaluation.MAPUDAF"),
- (rankItems +: correctItems +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.evaluation.AUCUDAF
- * @group evaluation
- */
- @scala.annotation.varargs
- def auc(args: String*): DataFrame = {
- val udaf = HiveUDAFFunction(
- "auc",
- new HiveFunctionWrapper("hivemall.evaluation.AUCUDAF"),
- args.map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.topicmodel.LDAPredictUDAF
- * @group topicmodel
- */
- @scala.annotation.varargs
- def lda_predict(word: String, value: String, label: String, lambda: String, others: String*)
- : DataFrame = {
- checkType(word, StringType)
- checkType(value, DoubleType)
- checkType(label, IntegerType)
- checkType(lambda, DoubleType)
- val udaf = HiveUDAFFunction(
- "lda_predict",
- new HiveFunctionWrapper("hivemall.topicmodel.LDAPredictUDAF"),
- (word +: value +: label +: lambda +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.topicmodel.PLSAPredictUDAF
- * @group topicmodel
- */
- @scala.annotation.varargs
- def plsa_predict(word: String, value: String, label: String, prob: String, others: String*)
- : DataFrame = {
- checkType(word, StringType)
- checkType(value, DoubleType)
- checkType(label, IntegerType)
- checkType(prob, DoubleType)
- val udaf = HiveUDAFFunction(
- "plsa_predict",
- new HiveFunctionWrapper("hivemall.topicmodel.PLSAPredictUDAF"),
- (word +: value +: label +: prob +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.ftvec.text.TermFrequencyUDAF
- * @group ftvec.text
- */
- def tf(text: String): DataFrame = {
- checkType(text, StringType)
- val udaf = HiveUDAFFunction(
- "tf",
- new HiveFunctionWrapper("hivemall.ftvec.text.TermFrequencyUDAF"),
- Seq(text).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.ftvec.trans.OnehotEncodingUDAF
- * @group ftvec.trans
- */
- @scala.annotation.varargs
- def onehot_encoding(feature: String, others: String*): DataFrame = {
- val udaf = HiveUDAFFunction(
- "onehot_encoding",
- new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"),
- (feature +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
- * @group ftvec.selection
- */
- def snr(feature: String, label: String): DataFrame = {
- val udaf = HiveUDAFFunction(
- "snr",
- new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"),
- Seq(feature, label).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.tools.array.ArrayAvgGenericUDAF
- * @group tools.array
- */
- def array_avg(ar: String): DataFrame = {
- val udaf = HiveUDAFFunction(
- "array_avg",
- new HiveFunctionWrapper("hivemall.tools.array.ArrayAvgGenericUDAF"),
- Seq(ar).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.tools.array.ArraySumUDAF
- * @group tools.array
- */
- def array_sum(ar: String): DataFrame = {
- val udaf = HiveUDAFFunction(
- "array_sum",
- new HiveFunctionWrapper("hivemall.tools.array.ArraySumUDAF"),
- Seq(ar).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.tools.bits.BitsCollectUDAF
- * @group tools.bits
- */
- def bits_collect(x: String): DataFrame = {
- val udaf = HiveUDAFFunction(
- "bits_collect",
- new HiveFunctionWrapper("hivemall.tools.bits.BitsCollectUDAF"),
- Seq(x).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.tools.list.UDAFToOrderedList
- * @group tools.list
- */
- @scala.annotation.varargs
- def to_ordered_list(value: String, others: String*): DataFrame = {
- val udaf = HiveUDAFFunction(
- "to_ordered_list",
- new HiveFunctionWrapper("hivemall.tools.list.UDAFToOrderedList"),
- (value +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.tools.map.UDAFToMap
- * @group tools.map
- */
- def to_map(key: String, value: String): DataFrame = {
- val udaf = HiveUDAFFunction(
- "to_map",
- new HiveFunctionWrapper("hivemall.tools.map.UDAFToMap"),
- Seq(key, value).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.tools.map.UDAFToOrderedMap
- * @group tools.map
- */
- @scala.annotation.varargs
- def to_ordered_map(key: String, value: String, others: String*): DataFrame = {
- val udaf = HiveUDAFFunction(
- "to_ordered_map",
- new HiveFunctionWrapper("hivemall.tools.map.UDAFToOrderedMap"),
- (key +: value +: others).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.tools.matrix.TransposeAndDotUDAF
- * @group tools.matrix
- */
- def transpose_and_dot(matrix0_row: String, matrix1_row: String): DataFrame = {
- val udaf = HiveUDAFFunction(
- "transpose_and_dot",
- new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"),
- Seq(matrix0_row, matrix1_row).map(df(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * @see hivemall.tools.math.L2NormUDAF
- * @group tools.math
- */
- def l2_norm(xi: String): DataFrame = {
- val udaf = HiveUDAFFunction(
- "l2_norm",
- new HiveFunctionWrapper("hivemall.tools.math.L2NormUDAF"),
- Seq(xi).map(df(_).expr),
- isUDAFBridgeRequired = true)
- .toAggregateExpression()
- toDF(Alias(udaf, udaf.prettyName)() :: Nil)
- }
-
- /**
- * [[RelationalGroupedDataset]] keeps the three values below as private fields, so, to inject
- * Hivemall aggregate functions, we fetch them via Java reflection.
- */
- private val df = getPrivateField[DataFrame]("org$apache$spark$sql$RelationalGroupedDataset$$df")
- private val groupingExprs = getPrivateField[Seq[Expression]]("groupingExprs")
- private val groupType = getPrivateField[RelationalGroupedDataset.GroupType]("groupType")
-
- private def getPrivateField[T](name: String): T = {
- val field = groupBy.getClass.getDeclaredField(name)
- field.setAccessible(true)
- field.get(groupBy).asInstanceOf[T]
- }
-
- private def toDF(aggExprs: Seq[Expression]): DataFrame = {
- val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
- groupingExprs ++ aggExprs
- } else {
- aggExprs
- }
-
- val aliasedAgg = aggregates.map(alias)
-
- groupType match {
- case RelationalGroupedDataset.GroupByType =>
- Dataset.ofRows(
- df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
- case RelationalGroupedDataset.RollupType =>
- Dataset.ofRows(
- df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
- case RelationalGroupedDataset.CubeType =>
- Dataset.ofRows(
- df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
- case RelationalGroupedDataset.PivotType(pivotCol, values) =>
- val aliasedGrps = groupingExprs.map(alias)
- Dataset.ofRows(
- df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
- }
- }
-
- private def alias(expr: Expression): NamedExpression = expr match {
- case u: UnresolvedAttribute => UnresolvedAlias(u)
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyName)()
- }
-
- private def checkType(colName: String, expected: DataType) = {
- val dataType = df.resolve(colName).dataType
- if (dataType != expected) {
- throw new AnalysisException(
- s""""$colName" must be $expected, however it is $dataType""")
- }
- }
-}
-
-object HivemallGroupedDataset {
-
- /**
- * Implicitly inject the [[HivemallGroupedDataset]] into [[RelationalGroupedDataset]].
- */
- implicit def relationalGroupedDatasetToHivemallOne(
- groupBy: RelationalGroupedDataset): HivemallGroupedDataset = {
- new HivemallGroupedDataset(groupBy)
- }
-}
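
For reference, the implicit conversion above is what let callers invoke these Hive UDAF wrappers directly on a grouped DataFrame. A minimal usage sketch follows, assuming a Spark 2.2 session with the removed Hivemall Spark module still on the classpath; the column names ("grp", "predicted", "target") are illustrative.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hive.HivemallGroupedDataset._  // brings the implicit wrapper into scope

val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
import spark.implicits._

// Illustrative data: (group, predicted, target); the two metric columns are DoubleType.
val df = Seq(("a", 1.0, 1.5), ("a", 2.0, 1.0), ("b", 0.5, 0.0))
  .toDF("grp", "predicted", "target")

// groupBy yields a RelationalGroupedDataset; `rmse` is resolved via the implicit conversion.
val perGroupRmse = df.groupBy("grp").rmse("predicted", "target")
perGroupRmse.show()
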
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
deleted file mode 100644
index 4a97d38..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ /dev/null
@@ -1,2260 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import java.util.UUID
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.HivemallFeature
-import org.apache.spark.ml.linalg.{DenseVector, SparseVector, VectorUDT}
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, JoinTopK, LogicalPlan}
-import org.apache.spark.sql.execution.UserProvidedPlanner
-import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, StructToCsv}
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-
-/**
- * Hivemall wrapper and some utility functions for DataFrame. The functions below derive
- * from `resources/ddl/define-all-as-permanent.hive`.
- *
- * @groupname regression
- * @groupname classifier
- * @groupname classifier.multiclass
- * @groupname recommend
- * @groupname topicmodel
- * @groupname geospatial
- * @groupname smile
- * @groupname xgboost
- * @groupname anomaly
- * @groupname knn.similarity
- * @groupname knn.distance
- * @groupname knn.lsh
- * @groupname ftvec
- * @groupname ftvec.amplify
- * @groupname ftvec.hashing
- * @groupname ftvec.paring
- * @groupname ftvec.scaling
- * @groupname ftvec.selection
- * @groupname ftvec.conv
- * @groupname ftvec.trans
- * @groupname ftvec.ranking
- * @groupname tools
- * @groupname tools.array
- * @groupname tools.bits
- * @groupname tools.compress
- * @groupname tools.map
- * @groupname tools.text
- * @groupname misc
- *
- * A list of unsupported functions is as follows:
- * * smile
- * - guess_attribute_types
- * * mapred functions
- * - taskid
- * - jobid
- * - rownum
- * - distcache_gets
- * - jobconf_gets
- * * matrix factorization
- * - mf_predict
- * - train_mf_sgd
- * - train_mf_adagrad
- * - train_bprmf
- * - bprmf_predict
- * * Factorization Machine
- * - fm_predict
- * - train_fm
- * - train_ffm
- * - ffm_predict
- */
-final class HivemallOps(df: DataFrame) extends Logging {
- import internal.HivemallOpsImpl._
-
- private lazy val _sparkSession = df.sparkSession
- private lazy val _strategy = new UserProvidedPlanner(_sparkSession.sqlContext.conf)
-
- /**
- * @see [[hivemall.regression.GeneralRegressorUDTF]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_regressor(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.GeneralRegressorUDTF",
- "train_regressor",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.regression.AdaDeltaUDTF]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_adadelta_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.AdaDeltaUDTF",
- "train_adadelta_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.regression.AdaGradUDTF]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_adagrad_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.AdaGradUDTF",
- "train_adagrad_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.regression.AROWRegressionUDTF]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_arow_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.AROWRegressionUDTF",
- "train_arow_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.regression.AROWRegressionUDTF.AROWe]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_arowe_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.AROWRegressionUDTF$AROWe",
- "train_arowe_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.regression.AROWRegressionUDTF.AROWe2]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_arowe2_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.AROWRegressionUDTF$AROWe2",
- "train_arowe2_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.regression.LogressUDTF]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_logistic_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.LogressUDTF",
- "train_logistic_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_pa1_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.PassiveAggressiveRegressionUDTF",
- "train_pa1_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF.PA1a]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_pa1a_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.PassiveAggressiveRegressionUDTF$PA1a",
- "train_pa1a_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF.PA2]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_pa2_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.PassiveAggressiveRegressionUDTF$PA2",
- "train_pa2_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF.PA2a]]
- * @group regression
- */
- @scala.annotation.varargs
- def train_pa2a_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a",
- "train_pa2a_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.GeneralClassifierUDTF]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_classifier(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.GeneralClassifierUDTF",
- "train_classifier",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.PerceptronUDTF]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_perceptron(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.PerceptronUDTF",
- "train_perceptron",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.PassiveAggressiveUDTF]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_pa(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.PassiveAggressiveUDTF",
- "train_pa",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.PassiveAggressiveUDTF.PA1]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_pa1(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.PassiveAggressiveUDTF$PA1",
- "train_pa1",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.PassiveAggressiveUDTF.PA2]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_pa2(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.PassiveAggressiveUDTF$PA2",
- "train_pa2",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.ConfidenceWeightedUDTF]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_cw(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.ConfidenceWeightedUDTF",
- "train_cw",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.classifier.AROWClassifierUDTF]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_arow(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.AROWClassifierUDTF",
- "train_arow",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.classifier.AROWClassifierUDTF.AROWh]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_arowh(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.AROWClassifierUDTF$AROWh",
- "train_arowh",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.classifier.SoftConfideceWeightedUDTF.SCW1]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_scw(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.SoftConfideceWeightedUDTF$SCW1",
- "train_scw",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.classifier.SoftConfideceWeightedUDTF.SCW2]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_scw2(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.SoftConfideceWeightedUDTF$SCW2",
- "train_scw2",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.classifier.AdaGradRDAUDTF]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_adagrad_rda(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.AdaGradRDAUDTF",
- "train_adagrad_rda",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.KernelExpansionPassiveAggressiveUDTF]]
- * @group classifier
- */
- @scala.annotation.varargs
- def train_kpa(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.KernelExpansionPassiveAggressiveUDTF",
- "train_kpa",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("h", "hk", "w0", "w1", "w2", "w3")
- )
- }
-
- /**
- * @see [[hivemall.classifier.multiclass.MulticlassPerceptronUDTF]]
- * @group classifier.multiclass
- */
- @scala.annotation.varargs
- def train_multiclass_perceptron(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.multiclass.MulticlassPerceptronUDTF",
- "train_multiclass_perceptron",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("label", "feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF]]
- * @group classifier.multiclass
- */
- @scala.annotation.varargs
- def train_multiclass_pa(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF",
- "train_multiclass_pa",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("label", "feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF.PA1]]
- * @group classifier.multiclass
- */
- @scala.annotation.varargs
- def train_multiclass_pa1(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF$PA1",
- "train_multiclass_pa1",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("label", "feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF.PA2]]
- * @group classifier.multiclass
- */
- @scala.annotation.varargs
- def train_multiclass_pa2(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF$PA2",
- "train_multiclass_pa2",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("label", "feature", "weight")
- )
- }
-
- /**
- * @see [[hivemall.classifier.multiclass.MulticlassConfidenceWeightedUDTF]]
- * @group classifier.multiclass
- */
- @scala.annotation.varargs
- def train_multiclass_cw(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.multiclass.MulticlassConfidenceWeightedUDTF",
- "train_multiclass_cw",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("label", "feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF]]
- * @group classifier.multiclass
- */
- @scala.annotation.varargs
- def train_multiclass_arow(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF",
- "train_multiclass_arow",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("label", "feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF.AROWh]]
- * @group classifier.multiclass
- */
- @scala.annotation.varargs
- def train_multiclass_arowh(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF$AROWh",
- "train_multiclass_arowh",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("label", "feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF.SCW1]]
- * @group classifier.multiclass
- */
- @scala.annotation.varargs
- def train_multiclass_scw(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW1",
- "train_multiclass_scw",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("label", "feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF.SCW2]]
- * @group classifier.multiclass
- */
- @scala.annotation.varargs
- def train_multiclass_scw2(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW2",
- "train_multiclass_scw2",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("label", "feature", "weight", "conv")
- )
- }
-
- /**
- * @see [[hivemall.recommend.SlimUDTF]]
- * @group recommend
- */
- @scala.annotation.varargs
- def train_slim(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.recommend.SlimUDTF",
- "train_slim",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("j", "nn", "w")
- )
- }
-
- /**
- * @see [[hivemall.topicmodel.LDAUDTF]]
- * @group topicmodel
- */
- @scala.annotation.varargs
- def train_lda(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.topicmodel.LDAUDTF",
- "train_lda",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("topic", "word", "score")
- )
- }
-
- /**
- * @see [[hivemall.topicmodel.PLSAUDTF]]
- * @group topicmodel
- */
- @scala.annotation.varargs
- def train_plsa(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.topicmodel.PLSAUDTF",
- "train_plsa",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("topic", "word", "score")
- )
- }
-
- /**
- * @see [[hivemall.smile.regression.RandomForestRegressionUDTF]]
- * @group smile
- */
- @scala.annotation.varargs
- def train_randomforest_regressor(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.smile.regression.RandomForestRegressionUDTF",
- "train_randomforest_regressor",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("model_id", "model_type", "pred_model", "var_importance", "oob_errors", "oob_tests")
- )
- }
-
- /**
- * @see [[hivemall.smile.classification.RandomForestClassifierUDTF]]
- * @group smile
- */
- @scala.annotation.varargs
- def train_randomforest_classifier(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.smile.classification.RandomForestClassifierUDTF",
- "train_randomforest_classifier",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("model_id", "model_type", "pred_model", "var_importance", "oob_errors", "oob_tests")
- )
- }
-
- /**
- * :: Experimental ::
- * @see [[hivemall.xgboost.regression.XGBoostRegressionUDTF]]
- * @group xgboost
- */
- @Experimental
- @scala.annotation.varargs
- def train_xgboost_regr(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.xgboost.regression.XGBoostRegressionUDTF",
- "train_xgboost_regr",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("model_id", "pred_model")
- )
- }
-
- /**
- * :: Experimental ::
- * @see [[hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF]]
- * @group xgboost
- */
- @Experimental
- @scala.annotation.varargs
- def train_xgboost_classifier(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF",
- "train_xgboost_classifier",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("model_id", "pred_model")
- )
- }
-
- /**
- * :: Experimental ::
- * @see [[hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF]]
- * @group xgboost
- */
- @Experimental
- @scala.annotation.varargs
- def train_xgboost_multiclass_classifier(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF",
- "train_xgboost_multiclass_classifier",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("model_id", "pred_model")
- )
- }
-
- /**
- * :: Experimental ::
- * @see [[hivemall.xgboost.tools.XGBoostPredictUDTF]]
- * @group xgboost
- */
- @Experimental
- @scala.annotation.varargs
- def xgboost_predict(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.xgboost.tools.XGBoostPredictUDTF",
- "xgboost_predict",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("rowid", "predicted")
- )
- }
-
- /**
- * :: Experimental ::
- * @see [[hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF]]
- * @group xgboost
- */
- @Experimental
- @scala.annotation.varargs
- def xgboost_multiclass_predict(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF",
- "xgboost_multiclass_predict",
- setMixServs(toHivemallFeatures(exprs)),
- Seq("rowid", "label", "probability")
- )
- }
-
- /**
- * @see [[hivemall.knn.similarity.DIMSUMMapperUDTF]]
- * @group knn.similarity
- */
- @scala.annotation.varargs
- def dimsum_mapper(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.knn.similarity.DIMSUMMapperUDTF",
- "dimsum_mapper",
- exprs,
- Seq("j", "k", "b_jk")
- )
- }
-
- /**
- * @see [[hivemall.knn.lsh.MinHashUDTF]]
- * @group knn.lsh
- */
- @scala.annotation.varargs
- def minhash(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.knn.lsh.MinHashUDTF",
- "minhash",
- exprs,
- Seq("clusterid", "item")
- )
- }
-
- /**
- * @see [[hivemall.ftvec.amplify.AmplifierUDTF]]
- * @group ftvec.amplify
- */
- @scala.annotation.varargs
- def amplify(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.ftvec.amplify.AmplifierUDTF",
- "amplify",
- exprs,
- Seq("clusterid", "item")
- )
- }
-
- /**
- * @see [[hivemall.ftvec.amplify.RandomAmplifierUDTF]]
- * @group ftvec.amplify
- */
- @scala.annotation.varargs
- def rand_amplify(exprs: Column*): DataFrame = withTypedPlan {
- throw new UnsupportedOperationException("`rand_amplify` not supported yet")
- }
-
- /**
- * Amplifies and shuffles data inside partitions.
- * @group ftvec.amplify
- */
- def part_amplify(xtimes: Column): DataFrame = {
- val xtimesInt = xtimes.expr match {
- case Literal(v: Any, IntegerType) => v.asInstanceOf[Int]
- case e => throw new AnalysisException("`xtimes` must be integer, however " + e)
- }
- val rdd = df.rdd.mapPartitions({ iter =>
- val elems = iter.flatMap{ row =>
- Seq.fill[Row](xtimesInt)(row)
- }
- // Need to check how this shuffling affects results
- scala.util.Random.shuffle(elems)
- }, true)
- df.sqlContext.createDataFrame(rdd, df.schema)
- }
-
- /**
- * Quantifies input columns.
- * @see [[hivemall.ftvec.conv.QuantifyColumnsUDTF]]
- * @group ftvec.conv
- */
- @scala.annotation.varargs
- def quantify(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.ftvec.conv.QuantifyColumnsUDTF",
- "quantify",
- exprs,
- (0 until exprs.size - 1).map(i => s"c$i")
- )
- }
-
- /**
- * @see [[hivemall.ftvec.trans.BinarizeLabelUDTF]]
- * @group ftvec.trans
- */
- @scala.annotation.varargs
- def binarize_label(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.ftvec.trans.BinarizeLabelUDTF",
- "binarize_label",
- exprs,
- (0 until exprs.size - 1).map(i => s"c$i")
- )
- }
-
- /**
- * @see [[hivemall.ftvec.trans.QuantifiedFeaturesUDTF]]
- * @group ftvec.trans
- */
- @scala.annotation.varargs
- def quantified_features(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.ftvec.trans.QuantifiedFeaturesUDTF",
- "quantified_features",
- exprs,
- Seq("features")
- )
- }
-
- /**
- * @see [[hivemall.ftvec.ranking.BprSamplingUDTF]]
- * @group ftvec.ranking
- */
- @scala.annotation.varargs
- def bpr_sampling(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.ftvec.ranking.BprSamplingUDTF",
- "bpr_sampling",
- exprs,
- Seq("user", "pos_item", "neg_item")
- )
- }
-
- /**
- * @see [[hivemall.ftvec.ranking.ItemPairsSamplingUDTF]]
- * @group ftvec.ranking
- */
- @scala.annotation.varargs
- def item_pairs_sampling(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.ftvec.ranking.ItemPairsSamplingUDTF",
- "item_pairs_sampling",
- exprs,
- Seq("pos_item_id", "neg_item_id")
- )
- }
-
- /**
- * @see [[hivemall.ftvec.ranking.PopulateNotInUDTF]]
- * @group ftvec.ranking
- */
- @scala.annotation.varargs
- def populate_not_in(exprs: Column*): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.ftvec.ranking.PopulateNotInUDTF",
- "populate_not_in",
- exprs,
- Seq("item")
- )
- }
-
- /**
- * Splits Seq[String] into pieces.
- * @group ftvec
- */
- def explode_array(features: Column): DataFrame = {
- df.explode(features) { case Row(v: Seq[_]) =>
- // Type erasure removes the component type in Seq
- v.map(s => HivemallFeature(s.asInstanceOf[String]))
- }
- }
-
- /**
- * Splits [[Vector]] into pieces.
- * @group ftvec
- */
- def explode_vector(features: Column): DataFrame = {
- val elementSchema = StructType(
- StructField("feature", StringType) :: StructField("weight", DoubleType) :: Nil)
- val explodeFunc: Row => TraversableOnce[InternalRow] = (row: Row) => {
- row.get(0) match {
- case dv: DenseVector =>
- dv.values.zipWithIndex.map {
- case (value, index) =>
- InternalRow(UTF8String.fromString(s"$index"), value)
- }
- case sv: SparseVector =>
- sv.values.zip(sv.indices).map {
- case (value, index) =>
- InternalRow(UTF8String.fromString(s"$index"), value)
- }
- }
- }
- withTypedPlan {
- Generate(
- UserDefinedGenerator(elementSchema, explodeFunc, features.expr :: Nil),
- join = true, outer = false, None,
- generatorOutput = Nil,
- df.logicalPlan)
- }
- }
-
- /**
- * @see [[hivemall.tools.GenerateSeriesUDTF]]
- * @group tools
- */
- def generate_series(start: Column, end: Column): DataFrame = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.tools.GenerateSeriesUDTF",
- "generate_series",
- start :: end :: Nil,
- Seq("generate_series")
- )
- }
-
- /**
- * Returns `top-k` records for each `group`.
- * @group misc
- */
- def each_top_k(k: Column, score: Column, group: Column*): DataFrame = withTypedPlan {
- val kInt = k.expr match {
- case Literal(v: Any, IntegerType) => v.asInstanceOf[Int]
- case e => throw new AnalysisException("`k` must be integer, however " + e)
- }
- if (kInt == 0) {
- throw new AnalysisException("`k` must not be 0")
- }
- val clusterDf = df.repartition(group: _*).sortWithinPartitions(group: _*)
- .select(score, Column("*"))
- val analyzedPlan = clusterDf.queryExecution.analyzed
- val inputAttrs = analyzedPlan.output
- val scoreExpr = BindReferences.bindReference(analyzedPlan.expressions.head, inputAttrs)
- val groupNames = group.map { _.expr match {
- case ne: NamedExpression => ne.name
- case ua: UnresolvedAttribute => ua.name
- }}
- val groupExprs = analyzedPlan.expressions.filter {
- case ne: NamedExpression => groupNames.contains(ne.name)
- }.map { e =>
- BindReferences.bindReference(e, inputAttrs)
- }
- val rankField = StructField("rank", IntegerType)
- Generate(
- generator = EachTopK(
- k = kInt,
- scoreExpr = scoreExpr,
- groupExprs = groupExprs,
- elementSchema = StructType(
- rankField +: inputAttrs.map(d => StructField(d.name, d.dataType))
- ),
- children = inputAttrs
- ),
- join = false,
- outer = false,
- qualifier = None,
- generatorOutput = Seq(rankField.name).map(UnresolvedAttribute(_)) ++ inputAttrs,
- child = analyzedPlan
- )
- }
-
- /**
- * :: Experimental ::
- * Joins the two input tables with the given keys, keeping the top-k highest `score` values.
- * @group misc
- */
- @Experimental
- def top_k_join(k: Column, right: DataFrame, joinExprs: Column, score: Column)
- : DataFrame = withTypedPlanInCustomStrategy {
- val kInt = k.expr match {
- case Literal(v: Any, IntegerType) => v.asInstanceOf[Int]
- case e => throw new AnalysisException("`k` must be integer, however " + e)
- }
- if (kInt == 0) {
- throw new AnalysisException("`k` must not be 0")
- }
- JoinTopK(kInt, df.logicalPlan, right.logicalPlan, Inner, Option(joinExprs.expr))(score.named)
- }
-
- private def doFlatten(schema: StructType, separator: Char, prefixParts: Seq[String] = Seq.empty)
- : Seq[Column] = {
- schema.fields.flatMap { f =>
- val colNameParts = prefixParts :+ f.name
- f.dataType match {
- case st: StructType =>
- doFlatten(st, separator, colNameParts)
- case _ =>
- col(colNameParts.mkString(".")).as(colNameParts.mkString(separator.toString)) :: Nil
- }
- }
- }
-
- // Converts string representation of a character to actual character
- @throws[IllegalArgumentException]
- private def toChar(str: String): Char = {
- if (str.length == 1) {
- str.charAt(0) match {
- case '$' | '_' | '.' => str.charAt(0)
- case _ => throw new IllegalArgumentException(
- "Must use '$', '_', or '.' for separator, but got " + str)
- }
- } else {
- throw new IllegalArgumentException(
- s"Separator cannot be more than one character: $str")
- }
- }
-
- /**
- * Flattens a nested schema into a flat one.
- * @group misc
- *
- * For example:
- * {{{
- * scala> val df = Seq((0, (1, (3.0, "a")), (5, 0.9))).toDF()
- * scala> df.printSchema
- * root
- * |-- _1: integer (nullable = false)
- * |-- _2: struct (nullable = true)
- * | |-- _1: integer (nullable = false)
- * | |-- _2: struct (nullable = true)
- * | | |-- _1: double (nullable = false)
- * | | |-- _2: string (nullable = true)
- * |-- _3: struct (nullable = true)
- * | |-- _1: integer (nullable = false)
- * | |-- _2: double (nullable = false)
- *
- * scala> df.flatten(separator = "$").printSchema
- * root
- * |-- _1: integer (nullable = false)
- * |-- _2$_1: integer (nullable = true)
- * |-- _2$_2$_1: double (nullable = true)
- * |-- _2$_2$_2: string (nullable = true)
- * |-- _3$_1: integer (nullable = true)
- * |-- _3$_2: double (nullable = true)
- * }}}
- */
- def flatten(separator: String = "$"): DataFrame =
- df.select(doFlatten(df.schema, toChar(separator)): _*)
-
- /**
- * @see [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]]
- * @group misc
- */
- @scala.annotation.varargs
- def lr_datagen(exprs: Column*): Dataset[Row] = withTypedPlan {
- planHiveGenericUDTF(
- df,
- "hivemall.dataset.LogisticRegressionDataGeneratorUDTFWrapper",
- "lr_datagen",
- exprs,
- Seq("label", "features")
- )
- }
-
- /**
- * Returns all the columns as Seq[Column] in this [[DataFrame]].
- */
- private[sql] def cols: Seq[Column] = {
- df.schema.fields.map(col => df.col(col.name)).toSeq
- }
-
- /**
- * :: Experimental ::
- * If the '-mix' parameter is not given in the 3rd argument,
- * it is set from the environment variable
- * 'HIVEMALL_MIX_SERVERS'.
- *
- * TODO: This only works if '--deploy-mode' is 'client';
- * otherwise, HIVEMALL_MIX_SERVERS needs to be set
- * on all possible Spark workers.
- */
- @Experimental
- private def setMixServs(exprs: Seq[Column]): Seq[Column] = {
- val mixes = System.getenv("HIVEMALL_MIX_SERVERS")
- if (mixes != null && !mixes.isEmpty()) {
- val groupId = df.sqlContext.sparkContext.applicationId + "-" + UUID.randomUUID
- logInfo(s"set '${mixes}' as default mix servers (session: ${groupId})")
- exprs.size match {
- case 2 => exprs :+ Column(
- Literal.create(s"-mix ${mixes} -mix_session ${groupId}", StringType))
- /** TODO: Add code for the case where exprs.size == 3. */
- case _ => exprs
- }
- } else {
- exprs
- }
- }
-
- /**
- * If the input is a [[Vector]], transform it into Hivemall features.
- */
- @inline private def toHivemallFeatures(exprs: Seq[Column]): Seq[Column] = {
- df.select(exprs: _*).queryExecution.analyzed.schema.zip(exprs).map {
- case (StructField(_, _: VectorUDT, _, _), c) => HivemallUtils.to_hivemall_features(c)
- case (_, c) => c
- }
- }
-
- /**
- * A convenient function to wrap a logical plan and produce a DataFrame.
- */
- @inline private def withTypedPlan(logicalPlan: => LogicalPlan): DataFrame = {
- val queryExecution = _sparkSession.sessionState.executePlan(logicalPlan)
- val outputSchema = queryExecution.sparkPlan.schema
- new Dataset[Row](df.sparkSession, queryExecution, RowEncoder(outputSchema))
- }
-
- @inline private def withTypedPlanInCustomStrategy(logicalPlan: => LogicalPlan)
- : DataFrame = {
- // Inject custom strategies
- if (!_sparkSession.experimental.extraStrategies.contains(_strategy)) {
- _sparkSession.experimental.extraStrategies = Seq(_strategy)
- }
- withTypedPlan(logicalPlan)
- }
-}
-
-object HivemallOps {
- import internal.HivemallOpsImpl._
-
- /**
- * Implicitly inject the [[HivemallOps]] into [[DataFrame]].
- */
- implicit def dataFrameToHivemallOps(df: DataFrame): HivemallOps =
- new HivemallOps(df)
-
- /**
- * @see [[hivemall.HivemallVersionUDF]]
- * @group misc
- */
- def hivemall_version(): Column = withExpr {
- planHiveUDF(
- "hivemall.HivemallVersionUDF",
- "hivemall_version",
- Nil
- )
- }
-
- /**
- * @see [[hivemall.geospatial.TileUDF]]
- * @group geospatial
- */
- def tile(lat: Column, lon: Column, zoom: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.geospatial.TileUDF",
- "tile",
- lat :: lon :: zoom :: Nil
- )
- }
-
- /**
- * @see [[hivemall.geospatial.MapURLUDF]]
- * @group geospatial
- */
- @scala.annotation.varargs
- def map_url(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.geospatial.MapURLUDF",
- "map_url",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.geospatial.Lat2TileYUDF]]
- * @group geospatial
- */
- def lat2tiley(lat: Column, zoom: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.geospatial.Lat2TileYUDF",
- "lat2tiley",
- lat :: zoom :: Nil
- )
- }
-
- /**
- * @see [[hivemall.geospatial.Lon2TileXUDF]]
- * @group geospatial
- */
- def lon2tilex(lon: Column, zoom: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.geospatial.Lon2TileXUDF",
- "lon2tilex",
- lon :: zoom :: Nil
- )
- }
-
- /**
- * @see [[hivemall.geospatial.TileX2LonUDF]]
- * @group geospatial
- */
- def tilex2lon(x: Column, zoom: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.geospatial.TileX2LonUDF",
- "tilex2lon",
- x :: zoom :: Nil
- )
- }
-
- /**
- * @see [[hivemall.geospatial.TileY2LatUDF]]
- * @group geospatial
- */
- def tiley2lat(y: Column, zoom: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.geospatial.TileY2LatUDF",
- "tiley2lat",
- y :: zoom :: Nil
- )
- }
-
- /**
- * @see [[hivemall.geospatial.HaversineDistanceUDF]]
- * @group geospatial
- */
- @scala.annotation.varargs
- def haversine_distance(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.geospatial.HaversineDistanceUDF",
- "haversine_distance",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.smile.tools.TreePredictUDF]]
- * @group smile
- */
- @scala.annotation.varargs
- def tree_predict(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.smile.tools.TreePredictUDF",
- "tree_predict",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.smile.tools.TreeExportUDF]]
- * @group smile
- */
- @scala.annotation.varargs
- def tree_export(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.smile.tools.TreeExportUDF",
- "tree_export",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.anomaly.ChangeFinderUDF]]
- * @group anomaly
- */
- @scala.annotation.varargs
- def changefinder(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.anomaly.ChangeFinderUDF",
- "changefinder",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.anomaly.SingularSpectrumTransformUDF]]
- * @group anomaly
- */
- @scala.annotation.varargs
- def sst(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.anomaly.SingularSpectrumTransformUDF",
- "sst",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.similarity.CosineSimilarityUDF]]
- * @group knn.similarity
- */
- @scala.annotation.varargs
- def cosine_similarity(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.knn.similarity.CosineSimilarityUDF",
- "cosine_similarity",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.similarity.JaccardIndexUDF]]
- * @group knn.similarity
- */
- @scala.annotation.varargs
- def jaccard_similarity(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.knn.similarity.JaccardIndexUDF",
- "jaccard_similarity",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.similarity.AngularSimilarityUDF]]
- * @group knn.similarity
- */
- @scala.annotation.varargs
- def angular_similarity(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.knn.similarity.AngularSimilarityUDF",
- "angular_similarity",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.similarity.EuclidSimilarity]]
- * @group knn.similarity
- */
- @scala.annotation.varargs
- def euclid_similarity(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.knn.similarity.EuclidSimilarity",
- "euclid_similarity",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.similarity.Distance2SimilarityUDF]]
- * @group knn.similarity
- */
- @scala.annotation.varargs
- def distance2similarity(exprs: Column*): Column = withExpr {
- // TODO: Need a wrapper class because of using unsupported types
- planHiveGenericUDF(
- "hivemall.knn.similarity.Distance2SimilarityUDF",
- "distance2similarity",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.distance.HammingDistanceUDF]]
- * @group knn.distance
- */
- @scala.annotation.varargs
- def hamming_distance(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.knn.distance.HammingDistanceUDF",
- "hamming_distance",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.distance.PopcountUDF]]
- * @group knn.distance
- */
- @scala.annotation.varargs
- def popcnt(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.knn.distance.PopcountUDF",
- "popcnt",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.distance.KLDivergenceUDF]]
- * @group knn.distance
- */
- @scala.annotation.varargs
- def kld(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.knn.distance.KLDivergenceUDF",
- "kld",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.distance.EuclidDistanceUDF]]
- * @group knn.distance
- */
- @scala.annotation.varargs
- def euclid_distance(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.knn.distance.EuclidDistanceUDF",
- "euclid_distance",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.distance.CosineDistanceUDF]]
- * @group knn.distance
- */
- @scala.annotation.varargs
- def cosine_distance(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.knn.distance.CosineDistanceUDF",
- "cosine_distance",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.distance.AngularDistanceUDF]]
- * @group knn.distance
- */
- @scala.annotation.varargs
- def angular_distance(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.knn.distance.AngularDistanceUDF",
- "angular_distance",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.distance.JaccardDistanceUDF]]
- * @group knn.distance
- */
- @scala.annotation.varargs
- def jaccard_distance(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.knn.distance.JaccardDistanceUDF",
- "jaccard_distance",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.distance.ManhattanDistanceUDF]]
- * @group knn.distance
- */
- @scala.annotation.varargs
- def manhattan_distance(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.knn.distance.ManhattanDistanceUDF",
- "manhattan_distance",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.distance.MinkowskiDistanceUDF]]
- * @group knn.distance
- */
- @scala.annotation.varargs
- def minkowski_distance(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.knn.distance.MinkowskiDistanceUDF",
- "minkowski_distance",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.lsh.bBitMinHashUDF]]
- * @group knn.lsh
- */
- @scala.annotation.varargs
- def bbit_minhash(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.knn.lsh.bBitMinHashUDF",
- "bbit_minhash",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.knn.lsh.MinHashesUDFWrapper]]
- * @group knn.lsh
- */
- @scala.annotation.varargs
- def minhashes(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.knn.lsh.MinHashesUDFWrapper",
- "minhashes",
- exprs
- )
- }
-
- /**
- * Returns new features with `1.0` (bias) appended to the input features.
- * @see [[hivemall.ftvec.AddBiasUDFWrapper]]
- * @group ftvec
- */
- def add_bias(expr: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.AddBiasUDFWrapper",
- "add_bias",
- expr :: Nil
- )
- }
-
- /**
- * @see [[hivemall.ftvec.ExtractFeatureUDFWrapper]]
- * @group ftvec
- *
- * TODO: This throws java.lang.ClassCastException because
- * HiveInspectors.toInspector has a bug in Spark.
- * Need to fix it later.
- */
- def extract_feature(expr: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.ExtractFeatureUDFWrapper",
- "extract_feature",
- expr :: Nil
- )
- }.as("feature")
-
- /**
- * @see [[hivemall.ftvec.ExtractWeightUDFWrapper]]
- * @group ftvec
- *
- * TODO: This throws java.lang.ClassCastException because
- * HiveInspectors.toInspector has a bug in Spark.
- * Need to fix it later.
- */
- def extract_weight(expr: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.ExtractWeightUDFWrapper",
- "extract_weight",
- expr :: Nil
- )
- }.as("value")
-
- /**
- * @see [[hivemall.ftvec.AddFeatureIndexUDFWrapper]]
- * @group ftvec
- */
- def add_feature_index(features: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.AddFeatureIndexUDFWrapper",
- "add_feature_index",
- features :: Nil
- )
- }
-
- /**
- * @see [[hivemall.ftvec.SortByFeatureUDFWrapper]]
- * @group ftvec
- */
- def sort_by_feature(expr: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.SortByFeatureUDFWrapper",
- "sort_by_feature",
- expr :: Nil
- )
- }
-
- /**
- * @see [[hivemall.ftvec.hashing.MurmurHash3UDF]]
- * @group ftvec.hashing
- */
- def mhash(expr: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.ftvec.hashing.MurmurHash3UDF",
- "mhash",
- expr :: Nil
- )
- }
-
- /**
- * @see [[hivemall.ftvec.hashing.Sha1UDF]]
- * @group ftvec.hashing
- */
- @scala.annotation.varargs
- def sha1(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.ftvec.hashing.Sha1UDF",
- "sha1",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.hashing.ArrayHashValuesUDF]]
- * @group ftvec.hashing
- */
- @scala.annotation.varargs
- def array_hash_values(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.ftvec.hashing.ArrayHashValuesUDF",
- "array_hash_values",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.hashing.ArrayPrefixedHashValuesUDF]]
- * @group ftvec.hashing
- */
- @scala.annotation.varargs
- def prefixed_hash_values(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.ftvec.hashing.ArrayPrefixedHashValuesUDF",
- "prefixed_hash_values",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.hashing.FeatureHashingUDF]]
- * @group ftvec.hashing
- */
- @scala.annotation.varargs
- def feature_hashing(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.hashing.FeatureHashingUDF",
- "feature_hashing",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.pairing.PolynomialFeaturesUDF]]
- * @group ftvec.pairing
- */
- @scala.annotation.varargs
- def polynomial_features(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.ftvec.pairing.PolynomialFeaturesUDF",
- "polynomial_features",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.pairing.PoweredFeaturesUDF]]
- * @group ftvec.pairing
- */
- @scala.annotation.varargs
- def powered_features(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.ftvec.pairing.PoweredFeaturesUDF",
- "powered_features",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.scaling.RescaleUDF]]
- * @group ftvec.scaling
- */
- def rescale(value: Column, max: Column, min: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.ftvec.scaling.RescaleUDF",
- "rescale",
- value.cast(FloatType) :: max :: min :: Nil
- )
- }
-
- /**
- * @see [[hivemall.ftvec.scaling.ZScoreUDF]]
- * @group ftvec.scaling
- */
- @scala.annotation.varargs
- def zscore(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.ftvec.scaling.ZScoreUDF",
- "zscore",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.scaling.L2NormalizationUDFWrapper]]
- * @group ftvec.scaling
- */
- def l2_normalize(expr: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.scaling.L2NormalizationUDFWrapper",
- "normalize",
- expr :: Nil
- )
- }
-
- /**
- * @see [[hivemall.ftvec.selection.ChiSquareUDF]]
- * @group ftvec.selection
- */
- def chi2(observed: Column, expected: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.selection.ChiSquareUDF",
- "chi2",
- Seq(observed, expected)
- )
- }
-
- /**
- * @see [[hivemall.ftvec.conv.ToDenseFeaturesUDF]]
- * @group ftvec.conv
- */
- @scala.annotation.varargs
- def to_dense_features(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.ftvec.conv.ToDenseFeaturesUDF",
- "to_dense_features",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.conv.ToSparseFeaturesUDF]]
- * @group ftvec.conv
- */
- @scala.annotation.varargs
- def to_sparse_features(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.ftvec.conv.ToSparseFeaturesUDF",
- "to_sparse_features",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.binning.FeatureBinningUDF]]
- * @group ftvec.conv
- */
- @scala.annotation.varargs
- def feature_binning(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.binning.FeatureBinningUDF",
- "feature_binning",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.trans.VectorizeFeaturesUDF]]
- * @group ftvec.trans
- */
- @scala.annotation.varargs
- def vectorize_features(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.trans.VectorizeFeaturesUDF",
- "vectorize_features",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.trans.CategoricalFeaturesUDF]]
- * @group ftvec.trans
- */
- @scala.annotation.varargs
- def categorical_features(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.trans.CategoricalFeaturesUDF",
- "categorical_features",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.trans.FFMFeaturesUDF]]
- * @group ftvec.trans
- */
- @scala.annotation.varargs
- def ffm_features(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.trans.FFMFeaturesUDF",
- "ffm_features",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.trans.IndexedFeatures]]
- * @group ftvec.trans
- */
- @scala.annotation.varargs
- def indexed_features(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.trans.IndexedFeatures",
- "indexed_features",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.trans.QuantitativeFeaturesUDF]]
- * @group ftvec.trans
- */
- @scala.annotation.varargs
- def quantitative_features(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.trans.QuantitativeFeaturesUDF",
- "quantitative_features",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.ftvec.trans.AddFieldIndicesUDF]]
- * @group ftvec.trans
- */
- def add_field_indices(features: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.ftvec.trans.AddFieldIndicesUDF",
- "add_field_indices",
- features :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.ConvertLabelUDF]]
- * @group tools
- */
- def convert_label(label: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.ConvertLabelUDF",
- "convert_label",
- label :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.RankSequenceUDF]]
- * @group tools
- */
- def x_rank(key: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.RankSequenceUDF",
- "x_rank",
- key :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.array.AllocFloatArrayUDF]]
- * @group tools.array
- */
- def float_array(nDims: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.array.AllocFloatArrayUDF",
- "float_array",
- nDims :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.array.ArrayRemoveUDF]]
- * @group tools.array
- */
- def array_remove(original: Column, target: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.array.ArrayRemoveUDF",
- "array_remove",
- original :: target :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.array.SortAndUniqArrayUDF]]
- * @group tools.array
- */
- def sort_and_uniq_array(ar: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.array.SortAndUniqArrayUDF",
- "sort_and_uniq_array",
- ar :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.array.SubarrayEndWithUDF]]
- * @group tools.array
- */
- def subarray_endwith(original: Column, key: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.array.SubarrayEndWithUDF",
- "subarray_endwith",
- original :: key :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.array.ArrayConcatUDF]]
- * @group tools.array
- */
- @scala.annotation.varargs
- def array_concat(arrays: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.array.ArrayConcatUDF",
- "array_concat",
- arrays
- )
- }
-
- /**
- * @see [[hivemall.tools.array.SubarrayUDF]]
- * @group tools.array
- */
- def subarray(original: Column, fromIndex: Column, toIndex: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.array.SubarrayUDF",
- "subarray",
- original :: fromIndex :: toIndex :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.array.ArraySliceUDF]]
- * @group tools.array
- */
- def array_slice(original: Column, fromIndex: Column, length: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.array.ArraySliceUDF",
- "array_slice",
- original :: fromIndex :: length :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.array.ToStringArrayUDF]]
- * @group tools.array
- */
- def to_string_array(ar: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.array.ToStringArrayUDF",
- "to_string_array",
- ar :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.array.ArrayIntersectUDF]]
- * @group tools.array
- */
- @scala.annotation.varargs
- def array_intersect(arrays: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.array.ArrayIntersectUDF",
- "array_intersect",
- arrays
- )
- }
-
- /**
- * @see [[hivemall.tools.array.SelectKBestUDF]]
- * @group tools.array
- */
- def select_k_best(X: Column, importanceList: Column, k: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.array.SelectKBestUDF",
- "select_k_best",
- Seq(X, importanceList, k)
- )
- }
-
- /**
- * @see [[hivemall.tools.bits.ToBitsUDF]]
- * @group tools.bits
- */
- def to_bits(indexes: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.bits.ToBitsUDF",
- "to_bits",
- indexes :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.bits.UnBitsUDF]]
- * @group tools.bits
- */
- def unbits(bitset: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.bits.UnBitsUDF",
- "unbits",
- bitset :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.bits.BitsORUDF]]
- * @group tools.bits
- */
- @scala.annotation.varargs
- def bits_or(bits: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.bits.BitsORUDF",
- "bits_or",
- bits
- )
- }
-
- /**
- * @see [[hivemall.tools.compress.InflateUDF]]
- * @group tools.compress
- */
- @scala.annotation.varargs
- def inflate(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.compress.InflateUDF",
- "inflate",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.tools.compress.DeflateUDF]]
- * @group tools.compress
- */
- @scala.annotation.varargs
- def deflate(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.compress.DeflateUDF",
- "deflate",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.tools.map.MapGetSumUDF]]
- * @group tools.map
- */
- @scala.annotation.varargs
- def map_get_sum(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.map.MapGetSumUDF",
- "map_get_sum",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.tools.map.MapTailNUDF]]
- * @group tools.map
- */
- @scala.annotation.varargs
- def map_tail_n(exprs: Column*): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.map.MapTailNUDF",
- "map_tail_n",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.tools.text.TokenizeUDF]]
- * @group tools.text
- */
- @scala.annotation.varargs
- def tokenize(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.text.TokenizeUDF",
- "tokenize",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.tools.text.StopwordUDF]]
- * @group tools.text
- */
- def is_stopword(word: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.text.StopwordUDF",
- "is_stopword",
- word :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.text.SingularizeUDF]]
- * @group tools.text
- */
- def singularize(word: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.text.SingularizeUDF",
- "singularize",
- word :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.text.SplitWordsUDF]]
- * @group tools.text
- */
- @scala.annotation.varargs
- def split_words(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.text.SplitWordsUDF",
- "split_words",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.tools.text.NormalizeUnicodeUDF]]
- * @group tools.text
- */
- @scala.annotation.varargs
- def normalize_unicode(exprs: Column*): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.text.NormalizeUnicodeUDF",
- "normalize_unicode",
- exprs
- )
- }
-
- /**
- * @see [[hivemall.tools.text.Base91UDF]]
- * @group tools.text
- */
- def base91(bin: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.text.Base91UDF",
- "base91",
- bin :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.text.Unbase91UDF]]
- * @group tools.text
- */
- def unbase91(base91String: Column): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.text.Unbase91UDF",
- "unbase91",
- base91String :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.text.WordNgramsUDF]]
- * @group tools.text
- */
- def word_ngrams(words: Column, minSize: Column, maxSize: Column): Column = withExpr {
- planHiveUDF(
- "hivemall.tools.text.WordNgramsUDF",
- "word_ngrams",
- words :: minSize :: maxSize :: Nil
- )
- }
-
- /**
- * @see [[hivemall.tools.math.SigmoidGenericUDF]]
- * @group misc
- */
- def sigmoid(expr: Column): Column = {
- val one: () => Literal = () => Literal.create(1.0, DoubleType)
- Column(one()) / (Column(one()) + exp(-expr))
- }
-
- /**
- * @see [[hivemall.tools.mapred.RowIdUDFWrapper]]
- * @group misc
- */
- def rowid(): Column = withExpr {
- planHiveGenericUDF(
- "hivemall.tools.mapred.RowIdUDFWrapper",
- "rowid",
- Nil
- )
- }.as("rowid")
-
- /**
- * Parses a column containing a CSV string into a [[StructType]] with the specified schema.
- * Returns `null`, in the case of an unparseable string.
- * @group misc
- *
- * @param e a string column containing CSV data.
- * @param schema the schema to use when parsing the csv string
- * @param options options to control how the csv is parsed. accepts the same options as the
- * csv data source.
- */
- def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
- CsvToStruct(schema, options, e.expr)
- }
-
- /**
- * Parses a column containing a CSV string into a [[StructType]] with the specified schema.
- * Returns `null`, in the case of an unparseable string.
- * @group misc
- *
- * @param e a string column containing CSV data.
- * @param schema the schema to use when parsing the csv string
- */
- def from_csv(e: Column, schema: StructType): Column =
- from_csv(e, schema, Map.empty[String, String])
-
- /**
- * Converts a column containing a [[StructType]] into a CSV string with the specified schema.
- * Throws an exception, in the case of an unsupported type.
- * @group misc
- *
- * @param e a struct column.
- * @param options options to control how the struct column is converted into a csv string.
- * accepts the same options as the csv data source.
- */
- def to_csv(e: Column, options: Map[String, String]): Column = withExpr {
- StructToCsv(options, e.expr)
- }
-
- /**
- * Converts a column containing a [[StructType]] into a CSV string with the specified schema.
- * Throws an exception, in the case of an unsupported type.
- * @group misc
- *
- * @param e a struct column.
- */
- def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String])
-
- /**
- * A convenient function to wrap an expression and produce a Column.
- */
- @inline private def withExpr(expr: Expression): Column = Column(expr)
-}
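Note: the wrappers removed above were DataFrame-only conveniences; the Hive UDF classes they delegated to stay callable from SparkSQL, which this patch keeps supported. A minimal sketch (not part of this patch; it assumes the Hivemall jar is on the session classpath) that registers one of the classes referenced above and calls it from SQL:

import org.apache.spark.sql.SparkSession

object SparkSqlUdfSketch {
  def main(args: Array[String]): Unit = {
    // Hive support is required so that CREATE TEMPORARY FUNCTION can resolve Hive UDF classes.
    val spark = SparkSession.builder()
      .appName("hivemall-sparksql-sketch")
      .enableHiveSupport()
      .getOrCreate()

    // EuclidDistanceUDF is one of the classes whose DataFrame wrapper is dropped above.
    spark.sql(
      "CREATE TEMPORARY FUNCTION euclid_distance AS 'hivemall.knn.distance.EuclidDistanceUDF'")

    // Features use Hivemall's "index:value" string format.
    spark.sql(
      "SELECT euclid_distance(array('1:1.0','2:2.0'), array('1:2.0','2:4.0')) AS dist"
    ).show()

    spark.stop()
  }
}

The same pattern applies to any of the other UDF classes named in the removed wrappers.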
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala
deleted file mode 100644
index 70cf00b..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.expressions.UserDefinedFunction
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
-
-object HivemallUtils {
-
- // # of maximum dimensions for feature vectors
- private[this] val maxDims = 100000000
-
- /**
- * Check whether the given schema contains a column of the required data type.
- * @param colName column name
- * @param dataType required column data type
- */
- private[this] def checkColumnType(schema: StructType, colName: String, dataType: DataType)
- : Unit = {
- val actualDataType = schema(colName).dataType
- require(actualDataType.equals(dataType),
- s"Column $colName must be of type $dataType but was actually $actualDataType.")
- }
-
- def to_vector_func(dense: Boolean, dims: Int): Seq[String] => Vector = {
- if (dense) {
- // Dense features
- i: Seq[String] => {
- val features = new Array[Double](dims)
- i.map { ft =>
- val s = ft.split(":").ensuring(_.size == 2)
- features(s(0).toInt) = s(1).toDouble
- }
- Vectors.dense(features)
- }
- } else {
- // Sparse features
- i: Seq[String] => {
- val features = i.map { ft =>
- // val s = ft.split(":").ensuring(_.size == 2)
- val s = ft.split(":")
- (s(0).toInt, s(1).toDouble)
- }
- Vectors.sparse(dims, features)
- }
- }
- }
-
- def to_hivemall_features_func(): Vector => Array[String] = {
- case dv: DenseVector =>
- dv.values.zipWithIndex.map {
- case (value, index) => s"$index:$value"
- }
- case sv: SparseVector =>
- sv.values.zip(sv.indices).map {
- case (value, index) => s"$index:$value"
- }
- case v =>
- throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}")
- }
-
- def append_bias_func(): Vector => Vector = {
- case dv: DenseVector =>
- val inputValues = dv.values
- val inputLength = inputValues.length
- val outputValues = Array.ofDim[Double](inputLength + 1)
- System.arraycopy(inputValues, 0, outputValues, 0, inputLength)
- outputValues(inputLength) = 1.0
- Vectors.dense(outputValues)
- case sv: SparseVector =>
- val inputValues = sv.values
- val inputIndices = sv.indices
- val inputValuesLength = inputValues.length
- val dim = sv.size
- val outputValues = Array.ofDim[Double](inputValuesLength + 1)
- val outputIndices = Array.ofDim[Int](inputValuesLength + 1)
- System.arraycopy(inputValues, 0, outputValues, 0, inputValuesLength)
- System.arraycopy(inputIndices, 0, outputIndices, 0, inputValuesLength)
- outputValues(inputValuesLength) = 1.0
- outputIndices(inputValuesLength) = dim
- Vectors.sparse(dim + 1, outputIndices, outputValues)
- case v =>
- throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}")
- }
-
- /**
- * Transforms Hivemall features into a [[Vector]].
- */
- def to_vector(dense: Boolean = false, dims: Int = maxDims): UserDefinedFunction = {
- udf(to_vector_func(dense, dims))
- }
-
- /**
- * Transforms a [[Vector]] into Hivemall features.
- */
- def to_hivemall_features: UserDefinedFunction = udf(to_hivemall_features_func)
-
- /**
- * Returns a new [[Vector]] with `1.0` (bias) appended to the input [[Vector]].
- * @group ftvec
- */
- def append_bias: UserDefinedFunction = udf(append_bias_func)
-
- /**
- * Builds a [[Vector]]-based model from a table of Hivemall models
- */
- def vectorized_model(df: DataFrame, dense: Boolean = false, dims: Int = maxDims)
- : UserDefinedFunction = {
- checkColumnType(df.schema, "feature", StringType)
- checkColumnType(df.schema, "weight", DoubleType)
-
- import df.sqlContext.implicits._
- val intercept = df
- .where($"feature" === "0")
- .select($"weight")
- .map { case Row(weight: Double) => weight}
- .reduce(_ + _)
- val weights = to_vector_func(dense, dims)(
- df.select($"feature", $"weight")
- .where($"feature" !== "0")
- .map { case Row(label: String, feature: Double) => s"${label}:$feature"}
- .collect.toSeq)
-
- udf((input: Vector) => BLAS.dot(input, weights) + intercept)
- }
-}
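Note: the core conversion done by the removed to_vector_func above can be written as a small standalone helper. This is a sketch only, making the same assumptions as the deleted code (well-formed "index:value" tokens and indices smaller than dims):

import org.apache.spark.ml.linalg.{Vector, Vectors}

object FeatureConversionSketch {
  // Hivemall-style "index:value" strings -> a SparseVector of the given dimensionality.
  def toSparseVector(features: Seq[String], dims: Int): Vector = {
    val entries = features.map { ft =>
      val Array(index, value) = ft.split(":", 2)
      (index.toInt, value.toDouble)
    }
    Vectors.sparse(dims, entries)
  }

  def main(args: Array[String]): Unit = {
    // "1:0.5" means feature index 1 with weight 0.5.
    println(toSparseVector(Seq("1:0.5", "3:0.3", "8:0.1"), dims = 16))
  }
}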
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala
deleted file mode 100644
index 179b146..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.hive.internal
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan}
-import org.apache.spark.sql.hive._
-import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
-
-/**
- * This is an implementation class for [[org.apache.spark.sql.hive.HivemallOps]].
- * This class mainly uses internal Spark classes (e.g., `Generate` and `HiveGenericUDTF`) whose
- * interfaces are unstable and may evolve in upcoming releases.
- * Therefore, the objective of this class is to isolate these unstable parts
- * from [[org.apache.spark.sql.hive.HivemallOps]].
- */
-private[hive] object HivemallOpsImpl extends Logging {
-
- def planHiveUDF(
- className: String,
- funcName: String,
- argumentExprs: Seq[Column]): Expression = {
- HiveSimpleUDF(
- name = funcName,
- funcWrapper = new HiveFunctionWrapper(className),
- children = argumentExprs.map(_.expr)
- )
- }
-
- def planHiveGenericUDF(
- className: String,
- funcName: String,
- argumentExprs: Seq[Column]): Expression = {
- HiveGenericUDF(
- name = funcName,
- funcWrapper = new HiveFunctionWrapper(className),
- children = argumentExprs.map(_.expr)
- )
- }
-
- def planHiveGenericUDTF(
- df: DataFrame,
- className: String,
- funcName: String,
- argumentExprs: Seq[Column],
- outputAttrNames: Seq[String]): LogicalPlan = {
- Generate(
- generator = HiveGenericUDTF(
- name = funcName,
- funcWrapper = new HiveFunctionWrapper(className),
- children = argumentExprs.map(_.expr)
- ),
- join = false,
- outer = false,
- qualifier = None,
- generatorOutput = outputAttrNames.map(UnresolvedAttribute(_)),
- child = df.logicalPlan)
- }
-}
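Note: as the comment above says, this helper only existed to reach Spark's unstable internal planning classes (HiveSimpleUDF, HiveGenericUDF, Generate). A hedged sketch of the stable route that remains after this patch: register the function once and let the public SQL expression parser plan it, here from the DataFrame side via selectExpr (the column names are made up for the example):

import org.apache.spark.sql.SparkSession

object SelectExprSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("hivemall-selectexpr-sketch")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._

    spark.sql(
      "CREATE TEMPORARY FUNCTION euclid_distance AS 'hivemall.knn.distance.EuclidDistanceUDF'")

    // One row of Hivemall-style feature vectors; the column names are illustrative.
    val df = Seq((Seq("1:1.0", "2:2.0"), Seq("1:2.0", "2:4.0"))).toDF("ftvec1", "ftvec2")

    // The public expression parser resolves the registered Hive UDF, so callers no
    // longer need to plan HiveGenericUDF expressions themselves.
    df.selectExpr("euclid_distance(ftvec1, ftvec2) AS dist").show()

    spark.stop()
  }
}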
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala
deleted file mode 100644
index 65cdf24..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala
+++ /dev/null
@@ -1,163 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.hive.source
-
-import java.io.File
-import java.io.IOException
-import java.net.URI
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, Path}
-import org.apache.hadoop.io.IOUtils
-import org.apache.hadoop.io.compress.GzipCodec
-import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
-import org.apache.hadoop.util.ReflectionUtils
-
-import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types._
-import org.apache.spark.util.SerializableConfiguration
-
-private[source] final class XGBoostOutputWriter(
- path: String,
- dataSchema: StructType,
- context: TaskAttemptContext)
- extends OutputWriter {
-
- private val hadoopConf = new SerializableConfiguration(new Configuration())
-
- override def write(row: InternalRow): Unit = {
- val fields = row.toSeq(dataSchema)
- val model = fields(1).asInstanceOf[Array[Byte]]
- val filePath = new Path(new URI(s"$path"))
- val fs = filePath.getFileSystem(hadoopConf.value)
- val outputFile = fs.create(filePath)
- outputFile.write(model)
- outputFile.close()
- }
-
- override def close(): Unit = {}
-}
-
-object XGBoostOutputWriter {
-
- /** Returns the compression codec extension to be used in a file name (e.g., ".gzip"). */
- def getCompressionExtension(context: TaskAttemptContext): String = {
- if (FileOutputFormat.getCompressOutput(context)) {
- val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec])
- ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension
- } else {
- ""
- }
- }
-}
-
-final class XGBoostFileFormat extends FileFormat with DataSourceRegister {
-
- override def shortName(): String = "libxgboost"
-
- override def toString: String = "XGBoost"
-
- private def verifySchema(dataSchema: StructType): Unit = {
- if (
- dataSchema.size != 2 ||
- !dataSchema(0).dataType.sameType(StringType) ||
- !dataSchema(1).dataType.sameType(BinaryType)
- ) {
- throw new IOException(s"Illegal schema for XGBoost data, schema=$dataSchema")
- }
- }
-
- override def inferSchema(
- sparkSession: SparkSession,
- options: Map[String, String],
- files: Seq[FileStatus]): Option[StructType] = {
- Some(
- StructType(
- StructField("model_id", StringType, nullable = false) ::
- StructField("pred_model", BinaryType, nullable = false) :: Nil)
- )
- }
-
- override def prepareWrite(
- sparkSession: SparkSession,
- job: Job,
- options: Map[String, String],
- dataSchema: StructType): OutputWriterFactory = {
- new OutputWriterFactory {
- override def newInstance(
- path: String,
- dataSchema: StructType,
- context: TaskAttemptContext): OutputWriter = {
- new XGBoostOutputWriter(path, dataSchema, context)
- }
-
- override def getFileExtension(context: TaskAttemptContext): String = {
- XGBoostOutputWriter.getCompressionExtension(context) + ".xgboost"
- }
- }
- }
-
- override def buildReader(
- sparkSession: SparkSession,
- dataSchema: StructType,
- partitionSchema: StructType,
- requiredSchema: StructType,
- filters: Seq[Filter],
- options: Map[String, String],
- hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
- verifySchema(dataSchema)
- val broadcastedHadoopConf =
- sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
-
- (file: PartitionedFile) => {
- val model = new Array[Byte](file.length.asInstanceOf[Int])
- val filePath = new Path(new URI(file.filePath))
- val fs = filePath.getFileSystem(broadcastedHadoopConf.value.value)
-
- var in: FSDataInputStream = null
- try {
- in = fs.open(filePath)
- IOUtils.readFully(in, model, 0, model.length)
- } finally {
- IOUtils.closeStream(in)
- }
-
- val converter = RowEncoder(dataSchema)
- val fullOutput = dataSchema.map { f =>
- AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
- }
- val requiredOutput = fullOutput.filter { a =>
- requiredSchema.fieldNames.contains(a.name)
- }
- val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
- (requiredColumns(
- converter.toRow(Row(new File(file.filePath).getName, model)))
- :: Nil
- ).toIterator
- }
- }
-}
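Note: the file format removed above registered itself under the data source short name "libxgboost" and exposed a fixed two-column schema (model_id string, pred_model binary). For context, a sketch of how it was addressed through the generic data source API before this patch; the paths are hypothetical:

import org.apache.spark.sql.SparkSession

object XGBoostFileFormatSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("xgboost-fileformat-sketch")
      .getOrCreate()

    // Reading yields the fixed schema inferred above: (model_id string, pred_model binary).
    val models = spark.read.format("libxgboost").load("/tmp/xgboost-models")
    models.printSchema()

    // Writing expects the same two-column layout, one serialized model per row.
    models.write.format("libxgboost").save("/tmp/xgboost-models-copy")

    spark.stop()
  }
}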
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala
deleted file mode 100644
index a6bbb4b..0000000
--- a/spark/spark-2.2/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.spark.streaming
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.ml.feature.HivemallLabeledPoint
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.apache.spark.streaming.dstream.DStream
-
-final class HivemallStreamingOps(ds: DStream[HivemallLabeledPoint]) {
-
- def predict[U: ClassTag](f: DataFrame => DataFrame)(implicit sqlContext: SQLContext)
- : DStream[Row] = {
- ds.transform[Row] { rdd: RDD[HivemallLabeledPoint] =>
- f(sqlContext.createDataFrame(rdd)).rdd
- }
- }
-}
-
-object HivemallStreamingOps {
-
- /**
- * Implicitly inject the [[HivemallStreamingOps]] into [[DStream]].
- */
- implicit def dataFrameToHivemallStreamingOps(ds: DStream[HivemallLabeledPoint])
- : HivemallStreamingOps = {
- new HivemallStreamingOps(ds)
- }
-}
diff --git a/spark/spark-2.2/src/test/resources/data/files/README.md b/spark/spark-2.2/src/test/resources/data/files/README.md
deleted file mode 100644
index 238d472..0000000
--- a/spark/spark-2.2/src/test/resources/data/files/README.md
+++ /dev/null
@@ -1,22 +0,0 @@
-<!--
- Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
--->
-
-The files in this dir exist to prevent exceptions in o.a.s.sql.hive.test.TestHive.
-We need to fix this issue in the future.
-
diff --git a/spark/spark-2.2/src/test/resources/data/files/complex.seq b/spark/spark-2.2/src/test/resources/data/files/complex.seq
deleted file mode 100644
index e69de29..0000000
diff --git a/spark/spark-2.2/src/test/resources/data/files/episodes.avro b/spark/spark-2.2/src/test/resources/data/files/episodes.avro
deleted file mode 100644
index e69de29..0000000
diff --git a/spark/spark-2.2/src/test/resources/data/files/json.txt b/spark/spark-2.2/src/test/resources/data/files/json.txt
deleted file mode 100644
index e69de29..0000000
diff --git a/spark/spark-2.2/src/test/resources/data/files/kv1.txt b/spark/spark-2.2/src/test/resources/data/files/kv1.txt
deleted file mode 100644
index e69de29..0000000
diff --git a/spark/spark-2.2/src/test/resources/data/files/kv3.txt b/spark/spark-2.2/src/test/resources/data/files/kv3.txt
deleted file mode 100644
index e69de29..0000000
diff --git a/spark/spark-2.2/src/test/resources/log4j.properties b/spark/spark-2.2/src/test/resources/log4j.properties
deleted file mode 100644
index c6e4297..0000000
--- a/spark/spark-2.2/src/test/resources/log4j.properties
+++ /dev/null
@@ -1,24 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Set everything to be logged to the console
-log4j.rootCategory=FATAL, console
-log4j.appender.console=org.apache.log4j.ConsoleAppender
-log4j.appender.console.target=System.err
-log4j.appender.console.layout=org.apache.log4j.PatternLayout
-log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
-
diff --git a/spark/spark-2.2/src/test/scala/hivemall/mix/server/MixServerSuite.scala b/spark/spark-2.2/src/test/scala/hivemall/mix/server/MixServerSuite.scala
deleted file mode 100644
index 9bbd3f0..0000000
--- a/spark/spark-2.2/src/test/scala/hivemall/mix/server/MixServerSuite.scala
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package hivemall.mix.server
-
-import java.util.Random
-import java.util.concurrent.{Executors, ExecutorService, TimeUnit}
-import java.util.logging.Logger
-
-import hivemall.mix.MixMessage.MixEventName
-import hivemall.mix.client.MixClient
-import hivemall.mix.server.MixServer.ServerState
-import hivemall.model.{DenseModel, PredictionModel, WeightValue}
-import hivemall.utils.io.IOUtils
-import hivemall.utils.lang.CommandLineUtils
-import hivemall.utils.net.NetUtils
-import org.scalatest.{BeforeAndAfter, FunSuite}
-
-class MixServerSuite extends FunSuite with BeforeAndAfter {
-
- private[this] var server: MixServer = _
- private[this] var executor : ExecutorService = _
- private[this] var port: Int = _
-
- private[this] val rand = new Random(43)
- private[this] val counter = Stream.from(0).iterator
-
- private[this] val eachTestTime = 100
- private[this] val logger =
- Logger.getLogger(classOf[MixServerSuite].getName)
-
- before {
- this.port = NetUtils.getAvailablePort
- this.server = new MixServer(
- CommandLineUtils.parseOptions(
- Array("-port", s"${port}", "-sync_threshold", "3"),
- MixServer.getOptions()
- )
- )
- this.executor = Executors.newSingleThreadExecutor
- this.executor.submit(server)
- var retry = 0
- while (server.getState() != ServerState.RUNNING && retry < 50) {
- Thread.sleep(1000L)
- retry += 1
- }
- assert(server.getState == ServerState.RUNNING)
- }
-
- after { this.executor.shutdown() }
-
- private[this] def clientDriver(
- groupId: String, model: PredictionModel, numMsg: Int = 1000000): Unit = {
- var client: MixClient = null
- try {
- client = new MixClient(MixEventName.average, groupId, s"localhost:${port}", false, 2, model)
- model.configureMix(client, false)
- model.configureClock()
-
- for (_ <- 0 until numMsg) {
- val feature = Integer.valueOf(rand.nextInt(model.size))
- model.set(feature, new WeightValue(1.0f))
- }
-
- while (true) { Thread.sleep(eachTestTime * 1000 + 100L) }
- assert(model.getNumMixed > 0)
- } finally {
- IOUtils.closeQuietly(client)
- }
- }
-
- private[this] def fixedGroup: (String, () => String) =
- ("fixed", () => "fixed")
- private[this] def uniqueGroup: (String, () => String) =
- ("unique", () => s"${counter.next}")
-
- Seq(65536).map { ndims =>
- Seq(4).map { nclient =>
- Seq(fixedGroup, uniqueGroup).map { id =>
-        val testName = s"dense-dim:${ndims}-client:${nclient}-${id._1}"
- ignore(testName) {
- val clients = Executors.newCachedThreadPool()
- val numClients = nclient
- val models = (0 until numClients).map(i => new DenseModel(ndims, false))
- (0 until numClients).map { i =>
- clients.submit(new Runnable() {
- override def run(): Unit = {
- try {
- clientDriver(
- s"${testName}-${id._2}",
- models(i)
- )
- } catch {
- case e: InterruptedException =>
- assert(false, e.getMessage)
- }
- }
- })
- }
- clients.awaitTermination(eachTestTime, TimeUnit.SECONDS)
- clients.shutdown()
- val nMixes = models.map(d => d.getNumMixed).reduce(_ + _)
- logger.info(s"${testName} --> ${(nMixes + 0.0) / eachTestTime} mixes/s")
- }
- }
- }
- }
-}
diff --git a/spark/spark-2.2/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala b/spark/spark-2.2/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala
deleted file mode 100644
index c127276..0000000
--- a/spark/spark-2.2/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package hivemall.tools
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.sql.hive.test.TestHive
-
-class RegressionDatagenSuite extends FunSuite {
-
- test("datagen") {
- val df = RegressionDatagen.exec(
- TestHive, min_examples = 10000, n_features = 100, n_dims = 65536, dense = false, cl = true)
- assert(df.count() >= 10000)
- }
-}
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/SparkFunSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/SparkFunSuite.scala
deleted file mode 100644
index ed1bb6a..0000000
--- a/spark/spark-2.2/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark
-
-// scalastyle:off
-import org.scalatest.{FunSuite, Outcome}
-
-import org.apache.spark.internal.Logging
-
-/**
- * Base abstract class for all unit tests in Spark for handling common functionality.
- */
-private[spark] abstract class SparkFunSuite extends FunSuite with Logging {
-// scalastyle:on
-
- /**
- * Log the suite name and the test name before and after each test.
- *
- * Subclasses should never override this method. If they wish to run
- * custom code before and after each test, they should mix in the
- * {{org.scalatest.BeforeAndAfter}} trait instead.
- */
- final protected override def withFixture(test: NoArgTest): Outcome = {
- val testName = test.text
- val suiteName = this.getClass.getName
- val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s")
- try {
- logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n")
- test()
- } finally {
- logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n")
- }
- }
-}
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala
deleted file mode 100644
index 903dc0a..0000000
--- a/spark/spark-2.2/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.ml.feature
-
-import org.apache.spark.SparkFunSuite
-
-class HivemallLabeledPointSuite extends SparkFunSuite {
-
- test("toString") {
- val lp = HivemallLabeledPoint(1.0f, Seq("1:0.5", "3:0.3", "8:0.1"))
- assert(lp.toString === "1.0,[1:0.5,3:0.3,8:0.1]")
- }
-
- test("parse") {
- val lp = HivemallLabeledPoint.parse("1.0,[1:0.5,3:0.3,8:0.1]")
- assert(lp.label === 1.0)
- assert(lp.features === Seq("1:0.5", "3:0.3", "8:0.1"))
- }
-}
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala
deleted file mode 100644
index c9d0ba0..0000000
--- a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ /dev/null
@@ -1,360 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql
-
-import java.util.{ArrayDeque, Locale, TimeZone}
-
-import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
-import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.streaming.MemoryPlan
-import org.apache.spark.sql.types.{Metadata, ObjectType}
-
-
-abstract class QueryTest extends PlanTest {
-
- protected def spark: SparkSession
-
- // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
- TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
- // Add Locale setting
- Locale.setDefault(Locale.US)
-
- /**
- * Runs the plan and makes sure the answer contains all of the keywords.
- */
- def checkKeywordsExist(df: DataFrame, keywords: String*): Unit = {
- val outputs = df.collect().map(_.mkString).mkString
- for (key <- keywords) {
- assert(outputs.contains(key), s"Failed for $df ($key doesn't exist in result)")
- }
- }
-
- /**
- * Runs the plan and makes sure the answer does NOT contain any of the keywords.
- */
- def checkKeywordsNotExist(df: DataFrame, keywords: String*): Unit = {
- val outputs = df.collect().map(_.mkString).mkString
- for (key <- keywords) {
- assert(!outputs.contains(key), s"Failed for $df ($key existed in the result)")
- }
- }
-
- /**
- * Evaluates a dataset to make sure that the result of calling collect matches the given
- * expected answer.
- */
- protected def checkDataset[T](
- ds: => Dataset[T],
- expectedAnswer: T*): Unit = {
- val result = getResult(ds)
-
- if (!compare(result.toSeq, expectedAnswer)) {
- fail(
- s"""
- |Decoded objects do not match expected objects:
- |expected: $expectedAnswer
- |actual: ${result.toSeq}
- |${ds.exprEnc.deserializer.treeString}
- """.stripMargin)
- }
- }
-
- /**
- * Evaluates a dataset to make sure that the result of calling collect matches the given
- * expected answer, after sorting.
- */
- protected def checkDatasetUnorderly[T : Ordering](
- ds: => Dataset[T],
- expectedAnswer: T*): Unit = {
- val result = getResult(ds)
-
- if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) {
- fail(
- s"""
- |Decoded objects do not match expected objects:
- |expected: $expectedAnswer
- |actual: ${result.toSeq}
- |${ds.exprEnc.deserializer.treeString}
- """.stripMargin)
- }
- }
-
- private def getResult[T](ds: => Dataset[T]): Array[T] = {
- val analyzedDS = try ds catch {
- case ae: AnalysisException =>
- if (ae.plan.isDefined) {
- fail(
- s"""
- |Failed to analyze query: $ae
- |${ae.plan.get}
- |
- |${stackTraceToString(ae)}
- """.stripMargin)
- } else {
- throw ae
- }
- }
- assertEmptyMissingInput(analyzedDS)
-
- try ds.collect() catch {
- case e: Exception =>
- fail(
- s"""
- |Exception collecting dataset as objects
- |${ds.exprEnc}
- |${ds.exprEnc.deserializer.treeString}
- |${ds.queryExecution}
- """.stripMargin, e)
- }
- }
-
- private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
- case (null, null) => true
- case (null, _) => false
- case (_, null) => false
- case (a: Array[_], b: Array[_]) =>
- a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)}
- case (a: Iterable[_], b: Iterable[_]) =>
- a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)}
- case (a, b) => a == b
- }
-
- /**
- * Runs the plan and makes sure the answer matches the expected result.
- *
- * @param df the [[DataFrame]] to be executed
- * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
- */
- protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
- val analyzedDF = try df catch {
- case ae: AnalysisException =>
- if (ae.plan.isDefined) {
- fail(
- s"""
- |Failed to analyze query: $ae
- |${ae.plan.get}
- |
- |${stackTraceToString(ae)}
- |""".stripMargin)
- } else {
- throw ae
- }
- }
-
- assertEmptyMissingInput(analyzedDF)
-
- QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
- case Some(errorMessage) => fail(errorMessage)
- case None =>
- }
- }
-
- protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
- checkAnswer(df, Seq(expectedAnswer))
- }
-
- protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = {
- checkAnswer(df, expectedAnswer.collect())
- }
-
- /**
- * Runs the plan and makes sure the answer is within absTol of the expected result.
- *
- * @param dataFrame the [[DataFrame]] to be executed
- * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
- * @param absTol the absolute tolerance between actual and expected answers.
- */
- protected def checkAggregatesWithTol(dataFrame: DataFrame,
- expectedAnswer: Seq[Row],
- absTol: Double): Unit = {
- // TODO: catch exceptions in data frame execution
- val actualAnswer = dataFrame.collect()
- require(actualAnswer.length == expectedAnswer.length,
- s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}")
-
- actualAnswer.zip(expectedAnswer).foreach {
- case (actualRow, expectedRow) =>
- QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol)
- }
- }
-
- protected def checkAggregatesWithTol(dataFrame: DataFrame,
- expectedAnswer: Row,
- absTol: Double): Unit = {
- checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol)
- }
-
- /**
- * Asserts that a given [[Dataset]] will be executed using the given number of cached results.
- */
- def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = {
- val planWithCaching = query.queryExecution.withCachedData
- val cachedData = planWithCaching collect {
- case cached: InMemoryRelation => cached
- }
-
- assert(
- cachedData.size == numCachedTables,
- s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
- planWithCaching)
- }
-
- /**
- * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans.
- */
- def assertEmptyMissingInput(query: Dataset[_]): Unit = {
- assert(query.queryExecution.analyzed.missingInput.isEmpty,
- s"The analyzed logical plan has missing inputs:\n${query.queryExecution.analyzed}")
- assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
- s"The optimized logical plan has missing inputs:\n${query.queryExecution.optimizedPlan}")
- assert(query.queryExecution.executedPlan.missingInput.isEmpty,
- s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}")
- }
-}
-
-object QueryTest {
- /**
- * Runs the plan and makes sure the answer matches the expected result.
- * If there was an exception during the execution or the contents of the DataFrame do not
- * match the expected result, an error message will be returned. Otherwise, [[None]] will
- * be returned.
- *
- * @param df the [[DataFrame]] to be executed
- * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
- * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice.
- */
- def checkAnswer(
- df: DataFrame,
- expectedAnswer: Seq[Row],
- checkToRDD: Boolean = true): Option[String] = {
- val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
- if (checkToRDD) {
- df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791]
- }
-
- val sparkAnswer = try df.collect().toSeq catch {
- case e: Exception =>
- val errorMessage =
- s"""
- |Exception thrown while executing query:
- |${df.queryExecution}
- |== Exception ==
- |$e
- |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin
- return Some(errorMessage)
- }
-
- sameRows(expectedAnswer, sparkAnswer, isSorted).map { results =>
- s"""
- |Results do not match for query:
- |Timezone: ${TimeZone.getDefault}
- |Timezone Env: ${sys.env.getOrElse("TZ", "")}
- |
- |${df.queryExecution}
- |== Results ==
- |$results
- """.stripMargin
- }
- }
-
-
- def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
- // Converts data to types that we can do equality comparison using Scala collections.
- // For BigDecimal type, the Scala type has a better definition of equality test (similar to
- // Java's java.math.BigDecimal.compareTo).
- * For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
- // equality test.
- val converted: Seq[Row] = answer.map(prepareRow)
- if (!isSorted) converted.sortBy(_.toString()) else converted
- }
-
- // We need to call prepareRow recursively to handle schemas with struct types.
- def prepareRow(row: Row): Row = {
- Row.fromSeq(row.toSeq.map {
- case null => null
- case d: java.math.BigDecimal => BigDecimal(d)
- // Convert array to Seq for easy equality check.
- case b: Array[_] => b.toSeq
- case r: Row => prepareRow(r)
- case o => o
- })
- }
-
- def sameRows(
- expectedAnswer: Seq[Row],
- sparkAnswer: Seq[Row],
- isSorted: Boolean = false): Option[String] = {
- if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) {
- val errorMessage =
- s"""
- |== Results ==
- |${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
- prepareAnswer(expectedAnswer, isSorted).map(_.toString()),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")}
- """.stripMargin
- return Some(errorMessage)
- }
- None
- }
-
- /**
- * Runs the plan and makes sure the answer is within absTol of the expected result.
- *
- * @param actualAnswer the actual result in a [[Row]].
- * @param expectedAnswer the expected result in a [[Row]].
- * @param absTol the absolute tolerance between actual and expected answers.
- */
- protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = {
- require(actualAnswer.length == expectedAnswer.length,
- s"actual answer length ${actualAnswer.length} != " +
- s"expected answer length ${expectedAnswer.length}")
-
- // TODO: support other numeric types besides Double
- // TODO: support struct types?
- actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach {
- case (actual: Double, expected: Double) =>
- assert(math.abs(actual - expected) < absTol,
- s"actual answer $actual not within $absTol of correct answer $expected")
- case (actual, expected) =>
- assert(actual == expected, s"$actual did not equal $expected")
- }
- }
-
- def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = {
- checkAnswer(df, expectedAnswer.asScala) match {
- case Some(errorMessage) => errorMessage
- case None => null
- }
- }
-}
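
For orientation, here is a minimal, self-contained sketch (in Scala, not part of the removed sources) of the comparison the checkAnswer/prepareAnswer helpers above perform: unless the query contains an explicit Sort, the collected rows are compared order-insensitively, so expected rows may be listed in any order. The session setup and data are illustrative assumptions.

import org.apache.spark.sql.{Row, SparkSession}

object CheckAnswerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("check-answer-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((2, "b"), (1, "a")).toDF("id", "name")
    // In a suite mixing in the helper above, this would simply be:
    //   checkAnswer(df, Row(1, "a") :: Row(2, "b") :: Nil)
    // Here we reproduce the order-insensitive comparison done by prepareAnswer/sameRows.
    val expected = Seq(Row(1, "a"), Row(2, "b"))
    val actual = df.collect().toSeq
    assert(actual.sortBy(_.toString) == expected.sortBy(_.toString))
    spark.stop()
  }
}
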
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
deleted file mode 100644
index a4aeaa6..0000000
--- a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ /dev/null
@@ -1,137 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.catalyst.plans
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.internal.SQLConf
-
-/**
- * Provides helper methods for comparing plans.
- */
-abstract class PlanTest extends SparkFunSuite with PredicateHelper {
-
- protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)
-
- /**
- * Since attribute references are given globally unique ids during analysis,
- * we must normalize them to check if two different queries are identical.
- */
- protected def normalizeExprIds(plan: LogicalPlan) = {
- plan transformAllExpressions {
- case s: ScalarSubquery =>
- s.copy(exprId = ExprId(0))
- case e: Exists =>
- e.copy(exprId = ExprId(0))
- case l: ListQuery =>
- l.copy(exprId = ExprId(0))
- case a: AttributeReference =>
- AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
- case a: Alias =>
- Alias(a.child, a.name)(exprId = ExprId(0))
- case ae: AggregateExpression =>
- ae.copy(resultId = ExprId(0))
- }
- }
-
- /**
- * Normalizes plans:
- * - Filter the filter conditions that appear in a plan. For instance,
- * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)),
- * etc., will all now be equivalent.
- * The seed in a Sample will be replaced by 0L.
- * - Join conditions will be resorted by hashCode.
- */
- protected def normalizePlan(plan: LogicalPlan): LogicalPlan = {
- plan transform {
- case filter @ Filter(condition: Expression, child: LogicalPlan) =>
- Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
- .reduce(And), child)
- case sample: Sample =>
- sample.copy(seed = 0L)(true)
- case join @ Join(left, right, joinType, condition) if condition.isDefined =>
- val newCondition =
- splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
- .reduce(And)
- Join(left, right, joinType, Some(newCondition))
- }
- }
-
- /**
- * Rewrite [[EqualTo]] and [[EqualNullSafe]] operators into a fixed operand order. The following cases will be
- * equivalent:
- * 1. (a = b), (b = a);
- * 2. (a <=> b), (b <=> a).
- */
- private def rewriteEqual(condition: Expression): Expression = condition match {
- case eq @ EqualTo(l: Expression, r: Expression) =>
- Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
- case eq @ EqualNullSafe(l: Expression, r: Expression) =>
- Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
- case _ => condition // Don't reorder.
- }
-
- /** Fails the test if the two plans do not match */
- protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
- val normalized1 = normalizePlan(normalizeExprIds(plan1))
- val normalized2 = normalizePlan(normalizeExprIds(plan2))
- if (normalized1 != normalized2) {
- fail(
- s"""
- |== FAIL: Plans do not match ===
- |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
- """.stripMargin)
- }
- }
-
- /** Fails the test if the two expressions do not match */
- protected def compareExpressions(e1: Expression, e2: Expression): Unit = {
- comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation))
- }
-
- /** Fails the test if the join order in the two plans do not match */
- protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) {
- val normalized1 = normalizePlan(normalizeExprIds(plan1))
- val normalized2 = normalizePlan(normalizeExprIds(plan2))
- if (!sameJoinPlan(normalized1, normalized2)) {
- fail(
- s"""
- |== FAIL: Plans do not match ===
- |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
- """.stripMargin)
- }
- }
-
- /** Consider symmetry for joins when comparing plans. */
- private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
- (plan1, plan2) match {
- case (j1: Join, j2: Join) =>
- (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
- (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
- case (p1: Project, p2: Project) =>
- p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child)
- case _ =>
- plan1 == plan2
- }
- }
-}
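
A short, hedged illustration (not from the removed file) of why normalizeExprIds is needed: analyzing the same query twice yields attributes with fresh expression IDs, so two structurally identical plans are not equal until the IDs are rewritten to ExprId(0). The column name and data below are illustrative assumptions.

import org.apache.spark.sql.SparkSession

object ExprIdSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("exprid-sketch").getOrCreate()
    import spark.implicits._

    val plan1 = Seq(1, 2, 3).toDF("id").filter($"id" > 1).queryExecution.analyzed
    val plan2 = Seq(1, 2, 3).toDF("id").filter($"id" > 1).queryExecution.analyzed

    // Logically the same query, but the attribute ExprIds differ, so the raw trees are
    // typically not equal; a PlanTest-style comparePlans passes only after normalizeExprIds
    // and normalizePlan have been applied to both sides.
    println(plan1 == plan2) // usually false, purely because of differing ExprIds
    println(plan1.treeString)
    spark.stop()
  }
}
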
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala
deleted file mode 100644
index 8283503..0000000
--- a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.execution.benchmark
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.util.Benchmark
-
-/**
- * Common base trait for micro benchmarks that are supposed to run standalone (i.e. not together
- * with other test suites).
- */
-private[sql] trait BenchmarkBase extends SparkFunSuite {
-
- lazy val sparkSession = SparkSession.builder
- .master("local[1]")
- .appName("microbenchmark")
- .config("spark.sql.shuffle.partitions", 1)
- .config("spark.sql.autoBroadcastJoinThreshold", 1)
- .getOrCreate()
-
- /** Runs function `f` with whole stage codegen on and off. */
- def runBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = {
- val benchmark = new Benchmark(name, cardinality)
-
- benchmark.addCase(s"$name wholestage off", numIters = 2) { iter =>
- sparkSession.conf.set("spark.sql.codegen.wholeStage", value = false)
- f
- }
-
- benchmark.addCase(s"$name wholestage on", numIters = 5) { iter =>
- sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true)
- f
- }
-
- benchmark.run()
- }
-
-}
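
A brief, hedged usage sketch for the trait above; the benchmark name, cardinality, and the measured query are illustrative assumptions rather than part of the removed file.

package org.apache.spark.sql.execution.benchmark

// A micro benchmark would sit next to the others and be run manually, hence `ignore`.
class SimpleSumBenchmark extends BenchmarkBase {
  ignore("sum over a generated range") {
    val N = 1L << 20
    // runBenchmark executes the body once with whole-stage codegen disabled and once enabled.
    runBenchmark("sum", N) {
      sparkSession.range(N).selectExpr("sum(id)").collect()
    }
  }
}
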
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
deleted file mode 100644
index 234f562..0000000
--- a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
+++ /dev/null
@@ -1,161 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.hive.HivemallUtils._
-import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest
-import org.apache.spark.sql.test.VectorQueryTest
-
-final class HiveUdfWithFeatureSuite extends HivemallFeatureQueryTest {
- import hiveContext.implicits._
- import hiveContext._
-
- test("hivemall_version") {
- sql(s"""
- | CREATE TEMPORARY FUNCTION hivemall_version
- | AS '${classOf[hivemall.HivemallVersionUDF].getName}'
- """.stripMargin)
-
- checkAnswer(
- sql(s"SELECT DISTINCT hivemall_version()"),
- Row("0.6.0-incubating-SNAPSHOT")
- )
-
- // sql("DROP TEMPORARY FUNCTION IF EXISTS hivemall_version")
- // reset()
- }
-
- test("train_logregr") {
- TinyTrainData.createOrReplaceTempView("TinyTrainData")
- sql(s"""
- | CREATE TEMPORARY FUNCTION train_logregr
- | AS '${classOf[hivemall.regression.LogressUDTF].getName}'
- """.stripMargin)
- sql(s"""
- | CREATE TEMPORARY FUNCTION add_bias
- | AS '${classOf[hivemall.ftvec.AddBiasUDFWrapper].getName}'
- """.stripMargin)
-
- val model = sql(
- s"""
- | SELECT feature, AVG(weight) AS weight
- | FROM (
- | SELECT train_logregr(add_bias(features), label) AS (feature, weight)
- | FROM TinyTrainData
- | ) t
- | GROUP BY feature
- """.stripMargin)
-
- checkAnswer(
- model.select($"feature"),
- Seq(Row("0"), Row("1"), Row("2"))
- )
-
- // TODO: Why is 'train_logregr' not registered in the HiveMetaStore?
- // ERROR RetryingHMSHandler: MetaException(message:NoSuchObjectException
- // (message:Function default.train_logregr does not exist))
- //
- // hiveContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_logregr")
- // hiveContext.reset()
- }
-
- test("each_top_k") {
- val testDf = Seq(
- ("a", "1", 0.5, Array(0, 1, 2)),
- ("b", "5", 0.1, Array(3)),
- ("a", "3", 0.8, Array(2, 5)),
- ("c", "6", 0.3, Array(1, 3)),
- ("b", "4", 0.3, Array(2)),
- ("a", "2", 0.6, Array(1))
- ).toDF("key", "value", "score", "data")
-
- import testDf.sqlContext.implicits._
- testDf.repartition($"key").sortWithinPartitions($"key").createOrReplaceTempView("TestData")
- sql(s"""
- | CREATE TEMPORARY FUNCTION each_top_k
- | AS '${classOf[hivemall.tools.EachTopKUDTF].getName}'
- """.stripMargin)
-
- // Compute top-1 rows for each group
- checkAnswer(
- sql("SELECT each_top_k(1, key, score, key, value) FROM TestData"),
- Row(1, 0.8, "a", "3") ::
- Row(1, 0.3, "b", "4") ::
- Row(1, 0.3, "c", "6") ::
- Nil
- )
-
- // Compute reverse top-1 rows for each group
- checkAnswer(
- sql("SELECT each_top_k(-1, key, score, key, value) FROM TestData"),
- Row(1, 0.5, "a", "1") ::
- Row(1, 0.1, "b", "5") ::
- Row(1, 0.3, "c", "6") ::
- Nil
- )
- }
-}
-
-final class HiveUdfWithVectorSuite extends VectorQueryTest {
- import hiveContext._
-
- test("to_hivemall_features") {
- mllibTrainDf.createOrReplaceTempView("mllibTrainDf")
- hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func)
- checkAnswer(
- sql(
- s"""
- | SELECT to_hivemall_features(features)
- | FROM mllibTrainDf
- """.stripMargin),
- Seq(
- Row(Seq("0:1.0", "2:2.0", "4:3.0")),
- Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2")),
- Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0")),
- Row(Seq("1:4.0", "3:5.0", "5:6.0"))
- )
- )
- }
-
- test("append_bias") {
- mllibTrainDf.createOrReplaceTempView("mllibTrainDf")
- hiveContext.udf.register("append_bias", append_bias_func)
- hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func)
- checkAnswer(
- sql(
- s"""
- | SELECT to_hivemall_features(append_bias(features))
- | FROM mllibTrainDF
- """.stripMargin),
- Seq(
- Row(Seq("0:1.0", "2:2.0", "4:3.0", "7:1.0")),
- Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2", "7:1.0")),
- Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0", "7:1.0")),
- Row(Seq("1:4.0", "3:5.0", "5:6.0", "7:1.0"))
- )
- )
- }
-
- ignore("explode_vector") {
- // TODO: Spark 2.0 does not support user-defined generator functions in
- // `org.apache.spark.sql.UDFRegistration`.
- }
-}
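
With the DataFrame wrappers removed by this commit, the route exercised above, registering Hivemall classes as temporary functions and calling them from SparkSQL, is the one that remains supported. A minimal, hedged sketch follows; the session settings are illustrative assumptions, while the UDF class name is the one already used in the suite above.

import org.apache.spark.sql.SparkSession

object HivemallSqlSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("hivemall-sql-sketch")
      .enableHiveSupport() // required so CREATE TEMPORARY FUNCTION can load Hive UDFs
      .getOrCreate()

    spark.sql("CREATE TEMPORARY FUNCTION hivemall_version AS 'hivemall.HivemallVersionUDF'")
    spark.sql("SELECT hivemall_version()").show()
    spark.stop()
  }
}
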
diff --git a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
deleted file mode 100644
index f8d377a..0000000
--- a/spark/spark-2.2/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ /dev/null
@@ -1,1397 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import org.apache.spark.sql.{AnalysisException, Row}
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.hive.HivemallGroupedDataset._
-import org.apache.spark.sql.hive.HivemallOps._
-import org.apache.spark.sql.hive.HivemallUtils._
-import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.VectorQueryTest
-import org.apache.spark.sql.types._
-import org.apache.spark.test.TestFPWrapper._
-import org.apache.spark.test.TestUtils
-
-
-class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
-
- test("anomaly") {
- import hiveContext.implicits._
- val df = spark.range(1000).selectExpr("id AS time", "rand() AS x")
- // TODO: Test results more strictly
- assert(df.sort($"time".asc).select(changefinder($"x")).count === 1000)
- assert(df.sort($"time".asc).select(sst($"x", lit("-th 0.005"))).count === 1000)
- }
-
- test("knn.similarity") {
- import hiveContext.implicits._
-
- val df1 = DummyInputData.select(
- cosine_similarity(typedLit(Seq(1, 2, 3, 4)), typedLit(Seq(3, 4, 5, 6))))
- val rows1 = df1.collect
- assert(rows1.length == 1)
- assert(rows1(0).getFloat(0) ~== 0.500f)
-
- val df2 = DummyInputData.select(jaccard_similarity(lit(5), lit(6)))
- val rows2 = df2.collect
- assert(rows2.length == 1)
- assert(rows2(0).getFloat(0) ~== 0.96875f)
-
- val df3 = DummyInputData.select(
- angular_similarity(typedLit(Seq(1, 2, 3)), typedLit(Seq(4, 5, 6))))
- val rows3 = df3.collect
- assert(rows3.length == 1)
- assert(rows3(0).getFloat(0) ~== 0.500f)
-
- val df4 = DummyInputData.select(
- euclid_similarity(typedLit(Seq(5, 3, 1)), typedLit(Seq(2, 8, 3))))
- val rows4 = df4.collect
- assert(rows4.length == 1)
- assert(rows4(0).getFloat(0) ~== 0.33333334f)
-
- val df5 = DummyInputData.select(distance2similarity(lit(1.0)))
- val rows5 = df5.collect
- assert(rows5.length == 1)
- assert(rows5(0).getFloat(0) ~== 0.5f)
-
- val df6 = Seq((Seq("1:0.3", "4:0.1"), Map(0 -> 0.5))).toDF("a", "b")
- // TODO: Currently, we just check that no exception is thrown
- assert(df6.dimsum_mapper(df6("a"), df6("b")).collect.isEmpty)
- }
-
- test("knn.distance") {
- val df1 = DummyInputData.select(hamming_distance(lit(1), lit(3)))
- checkAnswer(df1, Row(1))
-
- val df2 = DummyInputData.select(popcnt(lit(1)))
- checkAnswer(df2, Row(1))
-
- val rows3 = DummyInputData.select(kld(lit(0.1), lit(0.5), lit(0.2), lit(0.5))).collect
- assert(rows3.length === 1)
- assert(rows3(0).getDouble(0) ~== 0.01)
-
- val rows4 = DummyInputData.select(
- euclid_distance(typedLit(Seq("0.1", "0.5")), typedLit(Seq("0.2", "0.5")))).collect
- assert(rows4.length === 1)
- assert(rows4(0).getFloat(0) ~== 1.4142135f)
-
- val rows5 = DummyInputData.select(
- cosine_distance(typedLit(Seq("0.8", "0.3")), typedLit(Seq("0.4", "0.6")))).collect
- assert(rows5.length === 1)
- assert(rows5(0).getFloat(0) ~== 1.0f)
-
- val rows6 = DummyInputData.select(
- angular_distance(typedLit(Seq("0.1", "0.1")), typedLit(Seq("0.3", "0.8")))).collect
- assert(rows6.length === 1)
- assert(rows6(0).getFloat(0) ~== 0.50f)
-
- val rows7 = DummyInputData.select(
- manhattan_distance(typedLit(Seq("0.7", "0.8")), typedLit(Seq("0.5", "0.6")))).collect
- assert(rows7.length === 1)
- assert(rows7(0).getFloat(0) ~== 4.0f)
-
- val rows8 = DummyInputData.select(
- minkowski_distance(typedLit(Seq("0.1", "0.2")), typedLit(Seq("0.2", "0.2")), typedLit(1.0))
- ).collect
- assert(rows8.length === 1)
- assert(rows8(0).getFloat(0) ~== 2.0f)
-
- val rows9 = DummyInputData.select(
- jaccard_distance(typedLit(Seq("0.3", "0.8")), typedLit(Seq("0.1", "0.2")))).collect
- assert(rows9.length === 1)
- assert(rows9(0).getFloat(0) ~== 1.0f)
- }
-
- test("knn.lsh") {
- import hiveContext.implicits._
- checkAnswer(
- IntList2Data.minhash(lit(1), $"target"),
- Row(1016022700, 1) ::
- Row(1264890450, 1) ::
- Row(1304330069, 1) ::
- Row(1321870696, 1) ::
- Row(1492709716, 1) ::
- Row(1511363108, 1) ::
- Row(1601347428, 1) ::
- Row(1974434012, 1) ::
- Row(2022223284, 1) ::
- Row(326269457, 1) ::
- Row(50559334, 1) ::
- Row(716040854, 1) ::
- Row(759249519, 1) ::
- Row(809187771, 1) ::
- Row(900899651, 1) ::
- Nil
- )
- checkAnswer(
- DummyInputData.select(bbit_minhash(typedLit(Seq("1:0.1", "2:0.5")), lit(false))),
- Row("31175986876675838064867796245644543067")
- )
- checkAnswer(
- DummyInputData.select(minhashes(typedLit(Seq("1:0.1", "2:0.5")), lit(false))),
- Row(Seq(1571683640, 987207869, 370931990, 988455638, 846963275))
- )
- }
-
- test("ftvec - add_bias") {
- import hiveContext.implicits._
- checkAnswer(TinyTrainData.select(add_bias($"features")),
- Row(Seq("1:0.8", "2:0.2", "0:1.0")) ::
- Row(Seq("2:0.7", "0:1.0")) ::
- Row(Seq("1:0.9", "0:1.0")) ::
- Nil
- )
- }
-
- test("ftvec - extract_feature") {
- val df = DummyInputData.select(extract_feature(lit("1:0.8")))
- checkAnswer(df, Row("1"))
- }
-
- test("ftvec - extract_weight") {
- val rows = DummyInputData.select(extract_weight(lit("3:0.1"))).collect
- assert(rows.length === 1)
- assert(rows(0).getDouble(0) ~== 0.1)
- }
-
- test("ftvec - explode_array") {
- import hiveContext.implicits._
- val df = TinyTrainData.explode_array($"features").select($"feature")
- checkAnswer(df, Row("1:0.8") :: Row("2:0.2") :: Row("2:0.7") :: Row("1:0.9") :: Nil)
- }
-
- test("ftvec - add_feature_index") {
- import hiveContext.implicits._
- val doubleListData = Seq(Array(0.8, 0.5), Array(0.3, 0.1), Array(0.2)).toDF("data")
- checkAnswer(
- doubleListData.select(add_feature_index($"data")),
- Row(Seq("1:0.8", "2:0.5")) ::
- Row(Seq("1:0.3", "2:0.1")) ::
- Row(Seq("1:0.2")) ::
- Nil
- )
- }
-
- test("ftvec - sort_by_feature") {
- // import hiveContext.implicits._
- val intFloatMapData = {
- // TODO: Use `toDF`
- val rowRdd = hiveContext.sparkContext.parallelize(
- Row(Map(1 -> 0.3f, 2 -> 0.1f, 3 -> 0.5f)) ::
- Row(Map(2 -> 0.4f, 1 -> 0.2f)) ::
- Row(Map(2 -> 0.4f, 3 -> 0.2f, 1 -> 0.1f, 4 -> 0.6f)) ::
- Nil
- )
- hiveContext.createDataFrame(
- rowRdd,
- StructType(
- StructField("data", MapType(IntegerType, FloatType), true) ::
- Nil)
- )
- }
- val sortedKeys = intFloatMapData.select(sort_by_feature(intFloatMapData.col("data")))
- .collect.map {
- case Row(m: Map[Int, Float]) => m.keysIterator.toSeq
- }
- assert(sortedKeys.toSet === Set(Seq(1, 2, 3), Seq(1, 2), Seq(1, 2, 3, 4)))
- }
-
- test("ftvec.hash") {
- checkAnswer(DummyInputData.select(mhash(lit("test"))), Row(4948445))
- checkAnswer(DummyInputData.select(HivemallOps.sha1(lit("test"))), Row(12184508))
- checkAnswer(DummyInputData.select(feature_hashing(typedLit(Seq("1:0.1", "3:0.5")))),
- Row(Seq("11293631:0.1", "4331412:0.5")))
- checkAnswer(DummyInputData.select(array_hash_values(typedLit(Seq("aaa", "bbb")))),
- Row(Seq(4063537, 8459207)))
- checkAnswer(DummyInputData.select(
- prefixed_hash_values(typedLit(Seq("ccc", "ddd")), lit("prefix"))),
- Row(Seq("prefix7873825", "prefix8965544")))
- }
-
- test("ftvec.parting") {
- checkAnswer(DummyInputData.select(polynomial_features(typedLit(Seq("2:0.4", "6:0.1")), lit(2))),
- Row(Seq("2:0.4", "2^2:0.16000001", "2^6:0.040000003", "6:0.1", "6^6:0.010000001")))
- checkAnswer(DummyInputData.select(powered_features(typedLit(Seq("4:0.8", "5:0.2")), lit(2))),
- Row(Seq("4:0.8", "4^2:0.64000005", "5:0.2", "5^2:0.040000003")))
- }
-
- test("ftvec.scaling") {
- val rows1 = TinyTrainData.select(rescale(lit(2.0f), lit(1.0), lit(5.0))).collect
- assert(rows1.length === 3)
- assert(rows1(0).getFloat(0) ~== 0.25f)
- assert(rows1(1).getFloat(0) ~== 0.25f)
- assert(rows1(2).getFloat(0) ~== 0.25f)
- val rows2 = TinyTrainData.select(zscore(lit(1.0f), lit(0.5), lit(0.5))).collect
- assert(rows2.length === 3)
- assert(rows2(0).getFloat(0) ~== 1.0f)
- assert(rows2(1).getFloat(0) ~== 1.0f)
- assert(rows2(2).getFloat(0) ~== 1.0f)
- val df3 = TinyTrainData.select(l2_normalize(TinyTrainData.col("features")))
- checkAnswer(
- df3,
- Row(Seq("1:0.9701425", "2:0.24253562")) ::
- Row(Seq("2:1.0")) ::
- Row(Seq("1:1.0")) ::
- Nil)
- }
-
- test("ftvec.selection - chi2") {
- import hiveContext.implicits._
-
- // See also hivemall.ftvec.selection.ChiSquareUDFTest
- val df = Seq(
- Seq(
- Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996),
- Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3),
- Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998)
- ) -> Seq(
- Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
- Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
- Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589)))
- .toDF("arg0", "arg1")
-
- val rows = df.select(chi2(df("arg0"), df("arg1"))).collect
- assert(rows.length == 1)
- val chi2Val = rows.head.getAs[Row](0).getAs[Seq[Double]](0)
- val pVal = rows.head.getAs[Row](0).getAs[Seq[Double]](1)
-
- (chi2Val, Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759))
- .zipped
- .foreach((actual, expected) => assert(actual ~== expected))
-
- (pVal, Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15))
- .zipped
- .foreach((actual, expected) => assert(actual ~== expected))
- }
-
- test("ftvec.conv - quantify") {
- import hiveContext.implicits._
- val testDf = Seq((1, "aaa", true), (2, "bbb", false), (3, "aaa", false)).toDF
- // This test is done in a single partition because `HivemallOps#quantify` assigns identifiers
- // for non-numerical values in each partition.
- checkAnswer(
- testDf.coalesce(1).quantify(lit(true) +: testDf.cols: _*),
- Row(1, 0, 0) :: Row(2, 1, 1) :: Row(3, 0, 1) :: Nil)
- }
-
- test("ftvec.amplify") {
- import hiveContext.implicits._
- assert(TinyTrainData.amplify(lit(3), $"label", $"features").count() == 9)
- assert(TinyTrainData.part_amplify(lit(3)).count() == 9)
- // TODO: The test below failed because:
- // java.lang.RuntimeException: Unsupported literal type class scala.Tuple3
- // (-buf 128,label,features)
- //
- // assert(TinyTrainData.rand_amplify(lit(3), lit("-buf 8", $"label", $"features")).count() == 9)
- }
-
- test("ftvec.conv") {
- import hiveContext.implicits._
-
- checkAnswer(
- DummyInputData.select(to_dense_features(typedLit(Seq("0:0.1", "1:0.3")), lit(1))),
- Row(Array(0.1f, 0.3f))
- )
- checkAnswer(
- DummyInputData.select(to_sparse_features(typedLit(Seq(0.1f, 0.2f, 0.3f)))),
- Row(Seq("0:0.1", "1:0.2", "2:0.3"))
- )
- checkAnswer(
- DummyInputData.select(feature_binning(typedLit(Seq("1")), typedLit(Map("1" -> Seq(0, 3))))),
- Row(Seq("1"))
- )
- }
-
- test("ftvec.trans") {
- import hiveContext.implicits._
-
- checkAnswer(
- DummyInputData.select(vectorize_features(typedLit(Seq("a", "b")), lit(0.1f), lit(0.2f))),
- Row(Seq("a:0.1", "b:0.2"))
- )
- checkAnswer(
- DummyInputData.select(categorical_features(typedLit(Seq("a", "b")), lit("c11"), lit("c12"))),
- Row(Seq("a#c11", "b#c12"))
- )
- checkAnswer(
- DummyInputData.select(indexed_features(lit(0.1), lit(0.2), lit(0.3))),
- Row(Seq("1:0.1", "2:0.2", "3:0.3"))
- )
- checkAnswer(
- DummyInputData.select(quantitative_features(typedLit(Seq("a", "b")), lit(0.1), lit(0.2))),
- Row(Seq("a:0.1", "b:0.2"))
- )
- checkAnswer(
- DummyInputData.select(ffm_features(typedLit(Seq("1", "2")), lit(0.5), lit(0.2))),
- Row(Seq("190:140405:1", "111:1058718:1"))
- )
- checkAnswer(
- DummyInputData.select(add_field_indices(typedLit(Seq("0.5", "0.1")))),
- Row(Seq("1:0.5", "2:0.1"))
- )
-
- val df1 = Seq((1, -3, 1), (2, -2, 1)).toDF("a", "b", "c")
- checkAnswer(
- df1.binarize_label($"a", $"b", $"c"),
- Row(1, 1) :: Row(1, 1) :: Row(1, 1) :: Nil
- )
- val df2 = Seq(("xxx", "yyy", 0), ("zzz", "yyy", 1)).toDF("a", "b", "c").coalesce(1)
- checkAnswer(
- df2.quantified_features(lit(true), df2("a"), df2("b"), df2("c")),
- Row(Seq(0.0, 0.0, 0.0)) :: Row(Seq(1.0, 0.0, 1.0)) :: Nil
- )
- }
-
- test("ftvec.ranking") {
- import hiveContext.implicits._
-
- val df1 = Seq((1, 0 :: 3 :: 4 :: Nil), (2, 8 :: 9 :: Nil)).toDF("a", "b").coalesce(1)
- checkAnswer(
- df1.bpr_sampling($"a", $"b"),
- Row(1, 0, 7) ::
- Row(1, 3, 6) ::
- Row(2, 8, 0) ::
- Row(2, 8, 4) ::
- Row(2, 9, 7) ::
- Nil
- )
- val df2 = Seq(1 :: 8 :: 9 :: Nil, 0 :: 3 :: Nil).toDF("a").coalesce(1)
- checkAnswer(
- df2.item_pairs_sampling($"a", lit(3)),
- Row(0, 1) ::
- Row(1, 0) ::
- Row(3, 2) ::
- Nil
- )
- val df3 = Seq(3 :: 5 :: Nil, 0 :: Nil).toDF("a").coalesce(1)
- checkAnswer(
- df3.populate_not_in($"a", lit(1)),
- Row(0) ::
- Row(1) ::
- Row(1) ::
- Nil
- )
- }
-
- test("tools") {
- // checkAnswer(
- // DummyInputData.select(convert_label(lit(5))),
- // Nil
- // )
- checkAnswer(
- DummyInputData.select(x_rank(lit("abc"))),
- Row(1)
- )
- }
-
- test("tools.array") {
- checkAnswer(
- DummyInputData.select(float_array(lit(3))),
- Row(Seq())
- )
- checkAnswer(
- DummyInputData.select(array_remove(typedLit(Seq(1, 2, 3)), lit(2))),
- Row(Seq(1, 3))
- )
- checkAnswer(
- DummyInputData.select(sort_and_uniq_array(typedLit(Seq(2, 1, 3, 1)))),
- Row(Seq(1, 2, 3))
- )
- checkAnswer(
- DummyInputData.select(subarray_endwith(typedLit(Seq(1, 2, 3, 4, 5)), lit(4))),
- Row(Seq(1, 2, 3, 4))
- )
- checkAnswer(
- DummyInputData.select(
- array_concat(typedLit(Seq(1, 2)), typedLit(Seq(3)), typedLit(Seq(4, 5)))),
- Row(Seq(1, 2, 3, 4, 5))
- )
- checkAnswer(
- DummyInputData.select(subarray(typedLit(Seq(1, 2, 3, 4, 5)), lit(2), lit(4))),
- Row(Seq(3, 4))
- )
- checkAnswer(
- DummyInputData.select(array_slice(typedLit(Seq(1, 2, 3, 4, 5)), lit(2), lit(4))),
- Row(Seq(3, 4, 5))
- )
- checkAnswer(
- DummyInputData.select(to_string_array(typedLit(Seq(1, 2, 3, 4, 5)))),
- Row(Seq("1", "2", "3", "4", "5"))
- )
- checkAnswer(
- DummyInputData.select(array_intersect(typedLit(Seq(1, 2, 3)), typedLit(Seq(2, 3, 4)))),
- Row(Seq(2, 3))
- )
- }
-
- test("tools.array - select_k_best") {
- import hiveContext.implicits._
-
- val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9))
- val df = data.map(d => (d, Seq(3, 1, 2))).toDF("features", "importance_list")
- val k = 2
-
- checkAnswer(
- df.select(select_k_best(df("features"), df("importance_list"), lit(k))),
- Row(Seq(0.0, 3.0)) :: Row(Seq(2.0, 1.0)) :: Row(Seq(5.0, 9.0)) :: Nil
- )
- }
-
- test("tools.bits") {
- checkAnswer(
- DummyInputData.select(to_bits(typedLit(Seq(1, 3, 9)))),
- Row(Seq(522L))
- )
- checkAnswer(
- DummyInputData.select(unbits(typedLit(Seq(1L, 3L)))),
- Row(Seq(0L, 64L, 65L))
- )
- checkAnswer(
- DummyInputData.select(bits_or(typedLit(Seq(1L, 3L)), typedLit(Seq(8L, 23L)))),
- Row(Seq(9L, 23L))
- )
- }
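
For reference, the expected to_bits value above follows from setting bits 1, 3, and 9 of a single long: 2^1 + 2^3 + 2^9 = 2 + 8 + 512 = 522. A one-line plain-Scala check (no Spark needed):

assert(Seq(1, 3, 9).map(1L << _).sum == 522L)
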
-
- test("tools.compress") {
- checkAnswer(
- DummyInputData.select(inflate(deflate(lit("input text")))),
- Row("input text")
- )
- }
-
- test("tools.map") {
- val rows = DummyInputData.select(
- map_get_sum(typedLit(Map(1 -> 0.2f, 2 -> 0.5f, 4 -> 0.8f)), typedLit(Seq(1, 4)))
- ).collect
- assert(rows.length === 1)
- assert(rows(0).getDouble(0) ~== 1.0f)
-
- checkAnswer(
- DummyInputData.select(map_tail_n(typedLit(Map(1 -> 2, 2 -> 5)), lit(1))),
- Row(Map(2 -> 5))
- )
- }
-
- test("tools.text") {
- checkAnswer(
- DummyInputData.select(tokenize(lit("This is a pen"))),
- Row("This" :: "is" :: "a" :: "pen" :: Nil)
- )
- checkAnswer(
- DummyInputData.select(is_stopword(lit("because"))),
- Row(true)
- )
- checkAnswer(
- DummyInputData.select(singularize(lit("between"))),
- Row("between")
- )
- checkAnswer(
- DummyInputData.select(split_words(lit("Hello, world"))),
- Row("Hello," :: "world" :: Nil)
- )
- checkAnswer(
- DummyInputData.select(normalize_unicode(lit("abcdefg"))),
- Row("abcdefg")
- )
- checkAnswer(
- DummyInputData.select(base91(typedLit("input text".getBytes))),
- Row("xojg[@TX;R..B")
- )
- checkAnswer(
- DummyInputData.select(unbase91(lit("XXXX"))),
- Row(68 :: -120 :: 8 :: Nil)
- )
- checkAnswer(
- DummyInputData.select(word_ngrams(typedLit("abcd" :: "efg" :: "hij" :: Nil), lit(2), lit(2))),
- Row("abcd efg" :: "efg hij" :: Nil)
- )
- }
-
- test("tools - generated_series") {
- checkAnswer(
- DummyInputData.generate_series(lit(0), lit(3)),
- Row(0) :: Row(1) :: Row(2) :: Row(3) :: Nil
- )
- }
-
- test("geospatial") {
- val rows1 = DummyInputData.select(tilex2lon(lit(1), lit(6))).collect
- assert(rows1.length === 1)
- assert(rows1(0).getDouble(0) ~== -174.375)
-
- val rows2 = DummyInputData.select(tiley2lat(lit(1), lit(3))).collect
- assert(rows2.length === 1)
- assert(rows2(0).getDouble(0) ~== 79.17133464081945)
-
- val rows3 = DummyInputData.select(
- haversine_distance(lit(0.3), lit(0.1), lit(0.4), lit(0.1))).collect
- assert(rows3.length === 1)
- assert(rows3(0).getDouble(0) ~== 11.119492664455878)
-
- checkAnswer(
- DummyInputData.select(tile(lit(0.1), lit(0.8), lit(3))),
- Row(28)
- )
- checkAnswer(
- DummyInputData.select(map_url(lit(0.1), lit(0.8), lit(3))),
- Row("http://tile.openstreetmap.org/3/4/3.png")
- )
- checkAnswer(
- DummyInputData.select(lat2tiley(lit(0.3), lit(3))),
- Row(3)
- )
- checkAnswer(
- DummyInputData.select(lon2tilex(lit(0.4), lit(2))),
- Row(2)
- )
- }
-
- test("misc - hivemall_version") {
- checkAnswer(DummyInputData.select(hivemall_version()), Row("0.6.0-incubating-SNAPSHOT"))
- }
-
- test("misc - rowid") {
- assert(DummyInputData.select(rowid()).distinct.count == DummyInputData.count)
- }
-
- test("misc - each_top_k") {
- import hiveContext.implicits._
- val inputDf = Seq(
- ("a", "1", 0.5, 0.1, Array(0, 1, 2)),
- ("b", "5", 0.1, 0.2, Array(3)),
- ("a", "3", 0.8, 0.8, Array(2, 5)),
- ("c", "6", 0.3, 0.3, Array(1, 3)),
- ("b", "4", 0.3, 0.4, Array(2)),
- ("a", "2", 0.6, 0.5, Array(1))
- ).toDF("key", "value", "x", "y", "data")
-
- // Compute top-1 rows for each group
- val distance = sqrt(inputDf("x") * inputDf("x") + inputDf("y") * inputDf("y")).as("score")
- val top1Df = inputDf.each_top_k(lit(1), distance, $"key".as("group"))
- assert(top1Df.schema.toSet === Set(
- StructField("rank", IntegerType, nullable = true),
- StructField("score", DoubleType, nullable = true),
- StructField("key", StringType, nullable = true),
- StructField("value", StringType, nullable = true),
- StructField("x", DoubleType, nullable = true),
- StructField("y", DoubleType, nullable = true),
- StructField("data", ArrayType(IntegerType, containsNull = false), nullable = true)
- ))
- checkAnswer(
- top1Df.select($"rank", $"key", $"value", $"data"),
- Row(1, "a", "3", Array(2, 5)) ::
- Row(1, "b", "4", Array(2)) ::
- Row(1, "c", "6", Array(1, 3)) ::
- Nil
- )
-
- // Compute reverse top-1 rows for each group
- val bottom1Df = inputDf.each_top_k(lit(-1), distance, $"key".as("group"))
- checkAnswer(
- bottom1Df.select($"rank", $"key", $"value", $"data"),
- Row(1, "a", "1", Array(0, 1, 2)) ::
- Row(1, "b", "5", Array(3)) ::
- Row(1, "c", "6", Array(1, 3)) ::
- Nil
- )
-
- // Check that the expected exceptions are thrown for invalid arguments
- assert(intercept[AnalysisException] { inputDf.each_top_k(lit(0.1), $"score", $"key") }
- .getMessage contains "`k` must be integer, however")
- assert(intercept[AnalysisException] { inputDf.each_top_k(lit(1), $"data", $"key") }
- .getMessage contains "must have a comparable type")
- }
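
Although this DataFrame-side each_top_k API is being removed, the same top-k-per-group computation remains available through SparkSQL by registering hivemall.tools.EachTopKUDTF, exactly as HiveUdfWithFeatureSuite does earlier in this diff. A hedged, stand-alone sketch; the view name and columns are illustrative assumptions, and the input must be clustered and sorted by the grouping key.

import org.apache.spark.sql.SparkSession

object EachTopKSqlSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("each-top-k-sketch")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._

    val df = Seq(("a", "1", 0.5), ("a", "3", 0.8), ("b", "4", 0.3), ("b", "5", 0.1))
      .toDF("key", "value", "score")
    // each_top_k assumes rows for the same key are contiguous and sorted within a partition.
    df.repartition($"key").sortWithinPartitions($"key").createOrReplaceTempView("t")

    spark.sql("CREATE TEMPORARY FUNCTION each_top_k AS 'hivemall.tools.EachTopKUDTF'")
    spark.sql("SELECT each_top_k(1, key, score, key, value) FROM t").show()
    spark.stop()
  }
}
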
-
- test("misc - join_top_k") {
- Seq("true", "false").map { flag =>
- withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) {
- import hiveContext.implicits._
- val inputDf = Seq(
- ("user1", 1, 0.3, 0.5),
- ("user2", 2, 0.1, 0.1),
- ("user3", 3, 0.8, 0.0),
- ("user4", 1, 0.9, 0.9),
- ("user5", 3, 0.7, 0.2),
- ("user6", 1, 0.5, 0.4),
- ("user7", 2, 0.6, 0.8)
- ).toDF("userId", "group", "x", "y")
-
- val masterDf = Seq(
- (1, "pos1-1", 0.5, 0.1),
- (1, "pos1-2", 0.0, 0.0),
- (1, "pos1-3", 0.3, 0.3),
- (2, "pos2-3", 0.1, 0.3),
- (2, "pos2-3", 0.8, 0.8),
- (3, "pos3-1", 0.1, 0.7),
- (3, "pos3-1", 0.7, 0.1),
- (3, "pos3-1", 0.9, 0.0),
- (3, "pos3-1", 0.1, 0.3)
- ).toDF("group", "position", "x", "y")
-
- // Compute top-1 rows for each group
- val distance = sqrt(
- pow(inputDf("x") - masterDf("x"), lit(2.0)) +
- pow(inputDf("y") - masterDf("y"), lit(2.0))
- ).as("score")
- val top1Df = inputDf.top_k_join(
- lit(1), masterDf, inputDf("group") === masterDf("group"), distance)
- assert(top1Df.schema.toSet === Set(
- StructField("rank", IntegerType, nullable = true),
- StructField("score", DoubleType, nullable = true),
- StructField("group", IntegerType, nullable = false),
- StructField("userId", StringType, nullable = true),
- StructField("position", StringType, nullable = true),
- StructField("x", DoubleType, nullable = false),
- StructField("y", DoubleType, nullable = false)
- ))
- checkAnswer(
- top1Df.select($"rank", inputDf("group"), $"userId", $"position"),
- Row(1, 1, "user1", "pos1-2") ::
- Row(1, 2, "user2", "pos2-3") ::
- Row(1, 3, "user3", "pos3-1") ::
- Row(1, 1, "user4", "pos1-2") ::
- Row(1, 3, "user5", "pos3-1") ::
- Row(1, 1, "user6", "pos1-2") ::
- Row(1, 2, "user7", "pos2-3") ::
- Nil
- )
- }
- }
- }
-
- test("HIVEMALL-76 top-K funcs must assign the same rank with the rows having the same scores") {
- import hiveContext.implicits._
- val inputDf = Seq(
- ("a", "1", 0.1),
- ("b", "5", 0.1),
- ("a", "3", 0.1),
- ("b", "4", 0.1),
- ("a", "2", 0.0)
- ).toDF("key", "value", "x")
-
- // Compute top-2 rows for each group
- val top2Df = inputDf.each_top_k(lit(2), $"x".as("score"), $"key".as("group"))
- checkAnswer(
- top2Df.select($"rank", $"score", $"key", $"value"),
- Row(1, 0.1, "a", "3") ::
- Row(1, 0.1, "a", "1") ::
- Row(1, 0.1, "b", "4") ::
- Row(1, 0.1, "b", "5") ::
- Nil
- )
- Seq("true", "false").map { flag =>
- withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) {
- val inputDf = Seq(
- ("user1", 1, 0.3, 0.5),
- ("user2", 2, 0.1, 0.1)
- ).toDF("userId", "group", "x", "y")
-
- val masterDf = Seq(
- (1, "pos1-1", 0.5, 0.1),
- (1, "pos1-2", 0.5, 0.1),
- (1, "pos1-3", 0.3, 0.4),
- (2, "pos2-1", 0.8, 0.2),
- (2, "pos2-2", 0.8, 0.2)
- ).toDF("group", "position", "x", "y")
-
- // Compute top-2 rows for each group
- val distance = sqrt(
- pow(inputDf("x") - masterDf("x"), lit(2.0)) +
- pow(inputDf("y") - masterDf("y"), lit(2.0))
- ).as("score")
- val top2Df = inputDf.top_k_join(
- lit(2), masterDf, inputDf("group") === masterDf("group"), distance)
- checkAnswer(
- top2Df.select($"rank", inputDf("group"), $"userId", $"position"),
- Row(1, 1, "user1", "pos1-1") ::
- Row(1, 1, "user1", "pos1-2") ::
- Row(1, 2, "user2", "pos2-1") ::
- Row(1, 2, "user2", "pos2-2") ::
- Nil
- )
- }
- }
- }
-
- test("misc - flatten") {
- import hiveContext.implicits._
- val df = Seq((0, (1, "a", (3.0, "b")), (5, 0.9, "c", "d"), 9)).toDF()
- assert(df.flatten().schema === StructType(
- StructField("_1", IntegerType, nullable = false) ::
- StructField("_2$_1", IntegerType, nullable = true) ::
- StructField("_2$_2", StringType, nullable = true) ::
- StructField("_2$_3$_1", DoubleType, nullable = true) ::
- StructField("_2$_3$_2", StringType, nullable = true) ::
- StructField("_3$_1", IntegerType, nullable = true) ::
- StructField("_3$_2", DoubleType, nullable = true) ::
- StructField("_3$_3", StringType, nullable = true) ::
- StructField("_3$_4", StringType, nullable = true) ::
- StructField("_4", IntegerType, nullable = false) ::
- Nil
- ))
- checkAnswer(df.flatten("$").select("_2$_1"), Row(1))
- checkAnswer(df.flatten("_").select("_2__1"), Row(1))
- checkAnswer(df.flatten(".").select("`_2._1`"), Row(1))
-
- val errMsg1 = intercept[IllegalArgumentException] { df.flatten("\t") }
- assert(errMsg1.getMessage.startsWith("Must use '$', '_', or '.' for separator, but got"))
- val errMsg2 = intercept[IllegalArgumentException] { df.flatten("12") }
- assert(errMsg2.getMessage.startsWith("Separator cannot be more than one character:"))
- }
-
- test("misc - from_csv") {
- import hiveContext.implicits._
- val df = Seq("""1,abc""").toDF()
- val schema = new StructType().add("a", IntegerType).add("b", StringType)
- checkAnswer(
- df.select(from_csv($"value", schema)),
- Row(Row(1, "abc")))
- }
-
- test("misc - to_csv") {
- import hiveContext.implicits._
- val df = Seq((1, "a", (0, 3.9, "abc")), (8, "c", (2, 0.4, "def"))).toDF()
- checkAnswer(
- df.select(to_csv($"_3")),
- Row("0,3.9,abc") ::
- Row("2,0.4,def") ::
- Nil)
- }
-
- /**
- * This test fails because:
- *
- * Cause: java.lang.OutOfMemoryError: Java heap space
- * at hivemall.smile.tools.RandomForestEnsembleUDAF$Result.<init>
- * (RandomForestEnsembleUDAF.java:128)
- * at hivemall.smile.tools.RandomForestEnsembleUDAF$RandomForestPredictUDAFEvaluator
- * .terminate(RandomForestEnsembleUDAF.java:91)
- * at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
- */
- ignore("misc - tree_predict") {
- import hiveContext.implicits._
-
- val model = Seq((0.0, 0.1 :: 0.1 :: Nil), (1.0, 0.2 :: 0.3 :: 0.2 :: Nil))
- .toDF("label", "features")
- .train_randomforest_regressor($"features", $"label")
-
- val testData = Seq((0.0, 0.1 :: 0.0 :: Nil), (1.0, 0.3 :: 0.5 :: 0.4 :: Nil))
- .toDF("label", "features")
- .select(rowid(), $"label", $"features")
-
- val predicted = model
- .join(testData).coalesce(1)
- .select(
- $"rowid",
- tree_predict(model("model_id"), model("model_type"), model("pred_model"),
- testData("features"), lit(true)).as("predicted")
- )
- .groupBy($"rowid")
- .rf_ensemble("predicted").toDF("rowid", "predicted")
- .select($"predicted.label")
-
- checkAnswer(predicted, Seq(Row(0), Row(1)))
- }
-
- test("misc - sigmoid") {
- import hiveContext.implicits._
- val rows = DummyInputData.select(sigmoid($"c0")).collect
- assert(rows.length === 1)
- assert(rows(0).getDouble(0) ~== 0.500)
- }
-
- test("misc - lr_datagen") {
- assert(TinyTrainData.lr_datagen(lit("-n_examples 100 -n_features 10 -seed 100")).count >= 100)
- }
-
- test("invoke regression functions") {
- import hiveContext.implicits._
- Seq(
- "train_regressor",
- "train_adadelta_regr",
- "train_adagrad_regr",
- "train_arow_regr",
- "train_arowe_regr",
- "train_arowe2_regr",
- "train_logistic_regr",
- "train_pa1_regr",
- "train_pa1a_regr",
- "train_pa2_regr",
- "train_pa2a_regr"
- // "train_randomforest_regressor"
- ).map { func =>
- TestUtils.invokeFunc(new HivemallOps(TinyTrainData), func, Seq($"features", $"label"))
- .foreach(_ => {}) // Just call it
- }
- }
-
- test("invoke classifier functions") {
- import hiveContext.implicits._
- Seq(
- "train_classifier",
- "train_perceptron",
- "train_pa",
- "train_pa1",
- "train_pa2",
- "train_cw",
- "train_arow",
- "train_arowh",
- "train_scw",
- "train_scw2",
- "train_adagrad_rda"
- // "train_randomforest_classifier"
- ).map { func =>
- TestUtils.invokeFunc(new HivemallOps(TinyTrainData), func, Seq($"features", $"label"))
- .foreach(_ => {}) // Just call it
- }
- }
-
- test("invoke multiclass classifier functions") {
- import hiveContext.implicits._
- Seq(
- "train_multiclass_perceptron",
- "train_multiclass_pa",
- "train_multiclass_pa1",
- "train_multiclass_pa2",
- "train_multiclass_cw",
- "train_multiclass_arow",
- "train_multiclass_arowh",
- "train_multiclass_scw",
- "train_multiclass_scw2"
- ).map { func =>
- // TODO: Why do only the multiclass classifiers restrict the label type to [Int|Text]?
- TestUtils.invokeFunc(
- new HivemallOps(TinyTrainData), func, Seq($"features", $"label".cast(IntegerType)))
- .foreach(_ => {}) // Just call it
- }
- }
-
- test("invoke random forest functions") {
- import hiveContext.implicits._
- val testDf = Seq(
- (Array(0.3, 0.1, 0.2), 1),
- (Array(0.3, 0.1, 0.2), 0),
- (Array(0.3, 0.1, 0.2), 0)).toDF("features", "label")
- Seq(
- "train_randomforest_regressor",
- "train_randomforest_classifier"
- ).map { func =>
- TestUtils.invokeFunc(new HivemallOps(testDf.coalesce(1)), func, Seq($"features", $"label"))
- .foreach(_ => {}) // Just call it
- }
- }
-
- test("invoke recommend functions") {
- import hiveContext.implicits._
- val df = Seq((1, Map(1 -> 0.3), Map(2 -> Map(4 -> 0.1)), 0, Map(3 -> 0.5)))
- .toDF("i", "r_i", "topKRatesOfI", "j", "r_j")
- // Just call it
- df.train_slim($"i", $"r_i", $"topKRatesOfI", $"j", $"r_j").collect
-
- }
-
- ignore("invoke topicmodel functions") {
- import hiveContext.implicits._
- val testDf = Seq(Seq("abcd", "'efghij", "klmn")).toDF("words")
- Seq(
- "train_lda",
- "train_plsa"
- ).map { func =>
- TestUtils.invokeFunc(new HivemallOps(testDf.coalesce(1)), func, Seq($"words"))
- .foreach(_ => {}) // Just call it
- }
- }
-
- protected def checkRegrPrecision(func: String): Unit = {
- import hiveContext.implicits._
-
- // Build a model
- val model = {
- val res = TestUtils.invokeFunc(new HivemallOps(LargeRegrTrainData),
- func, Seq(add_bias($"features"), $"label"))
- if (!res.columns.contains("conv")) {
- res.groupBy("feature").agg("weight" -> "avg")
- } else {
- res.groupBy("feature").argmin_kld("weight", "conv")
- }
- }.toDF("feature", "weight")
-
- // Data preparation
- val testDf = LargeRegrTrainData
- .select(rowid(), $"label".as("target"), $"features")
- .cache
-
- val testDf_exploded = testDf
- .explode_array($"features")
- .select($"rowid", extract_feature($"feature"), extract_weight($"feature"))
-
- // Do prediction
- val predict = testDf_exploded
- .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER")
- .select($"rowid", ($"weight" * $"value").as("value"))
- .groupBy("rowid").sum("value")
- .toDF("rowid", "predicted")
-
- // Evaluation
- val eval = predict
- .join(testDf, predict("rowid") === testDf("rowid"))
- .groupBy()
- .agg(Map("target" -> "avg", "predicted" -> "avg"))
- .toDF("target", "predicted")
-
- val diff = eval.map {
- case Row(target: Double, predicted: Double) =>
- Math.abs(target - predicted)
- }.first
-
- TestUtils.expectResult(diff > 0.10, s"Low precision -> func:${func} diff:${diff}")
- }
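
The helper above captures the usual Hivemall scoring recipe: average the per-feature weights into a model, explode the test feature vectors, join on feature, and sum weight * value per row. The same flow can still be written in SparkSQL. Below is a hedged sketch in which the tiny datasets, the rowid column, and the Hive UDF class names registered for add_bias, extract_feature, and extract_weight are illustrative assumptions (only hivemall.regression.LogressUDTF appears verbatim in this diff).

import org.apache.spark.sql.SparkSession

object RegrScoringSqlSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]").appName("regr-scoring-sketch").enableHiveSupport().getOrCreate()
    import spark.implicits._

    // Register the Hivemall functions used below (class names assumed to be on the classpath).
    spark.sql("CREATE TEMPORARY FUNCTION train_logregr AS 'hivemall.regression.LogressUDTF'")
    spark.sql("CREATE TEMPORARY FUNCTION add_bias AS 'hivemall.ftvec.AddBiasUDF'")
    spark.sql("CREATE TEMPORARY FUNCTION extract_feature AS 'hivemall.ftvec.ExtractFeatureUDF'")
    spark.sql("CREATE TEMPORARY FUNCTION extract_weight AS 'hivemall.ftvec.ExtractWeightUDF'")

    Seq((0.0f, Seq("1:0.8", "2:0.2")), (1.0f, Seq("2:0.7")), (0.0f, Seq("1:0.9")))
      .toDF("label", "features").createOrReplaceTempView("train")
    Seq((1L, Seq("1:0.8", "2:0.2")), (2L, Seq("2:0.7")))
      .toDF("rowid", "features").createOrReplaceTempView("test")

    // Build the per-feature weights, mirroring the model query used in the tests above.
    spark.sql(
      """SELECT feature, AVG(weight) AS weight
        |FROM (
        |  SELECT train_logregr(add_bias(features), label) AS (feature, weight)
        |  FROM train
        |) t
        |GROUP BY feature""".stripMargin).createOrReplaceTempView("model")

    // Explode the test features, join against the model, and aggregate the score per row.
    spark.sql(
      """SELECT t.rowid, SUM(m.weight * t.value) AS predicted
        |FROM (
        |  SELECT rowid, extract_feature(fv) AS feature, extract_weight(fv) AS value
        |  FROM test LATERAL VIEW explode(features) x AS fv
        |) t
        |LEFT OUTER JOIN model m ON t.feature = m.feature
        |GROUP BY t.rowid""".stripMargin).show()

    spark.stop()
  }
}
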
-
- protected def checkClassifierPrecision(func: String): Unit = {
- import hiveContext.implicits._
-
- // Build a model
- val model = {
- val res = TestUtils.invokeFunc(new HivemallOps(LargeClassifierTrainData),
- func, Seq(add_bias($"features"), $"label"))
- if (!res.columns.contains("conv")) {
- res.groupBy("feature").agg("weight" -> "avg")
- } else {
- res.groupBy("feature").argmin_kld("weight", "conv")
- }
- }.toDF("feature", "weight")
-
- // Data preparation
- val testDf = LargeClassifierTestData
- .select(rowid(), $"label".as("target"), $"features")
- .cache
-
- val testDf_exploded = testDf
- .explode_array($"features")
- .select($"rowid", extract_feature($"feature"), extract_weight($"feature"))
-
- // Do prediction
- val predict = testDf_exploded
- .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER")
- .select($"rowid", ($"weight" * $"value").as("value"))
- .groupBy("rowid").sum("value")
- /**
- * TODO: This statement triggers the warning below:
- *
- * WARN Column: Constructing trivially true equals predicate, 'rowid#1323 = rowid#1323'.
- * Perhaps you need to use aliases.
- */
- .select($"rowid", when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0))
- .toDF("rowid", "predicted")
-
- // Evaluation
- val eval = predict
- .join(testDf, predict("rowid") === testDf("rowid"))
- .where($"target" === $"predicted")
-
- val precision = (eval.count + 0.0) / predict.count
-
... 10439 lines suppressed ...