Posted to commits@hivemall.apache.org by my...@apache.org on 2017/09/13 12:18:28 UTC
[4/4] incubator-hivemall git commit: Close #122:
[HIVEMALL-133][SPARK] Support spark-v2.2 in the hivemall-spark module
Close #122: [HIVEMALL-133][SPARK] Support spark-v2.2 in the hivemall-spark module
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8bf6dd9e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8bf6dd9e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8bf6dd9e
Branch: refs/heads/master
Commit: 8bf6dd9e760b1d4bfdf9046fdf09e62f46f97d37
Parents: 688daa5
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Wed Sep 13 21:18:06 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Wed Sep 13 21:18:06 2017 +0900
----------------------------------------------------------------------
.travis.yml | 4 +-
bin/run_travis_tests.sh | 47 +
pom.xml | 44 +
spark/spark-2.2/bin/mvn-zinc | 99 ++
spark/spark-2.2/extra-src/README.md | 1 +
.../org/apache/spark/sql/hive/HiveShim.scala | 279 ++++
spark/spark-2.2/pom.xml | 269 +++
.../java/hivemall/xgboost/XGBoostOptions.scala | 59 +
....apache.spark.sql.sources.DataSourceRegister | 1 +
.../src/main/resources/log4j.properties | 12 +
.../hivemall/tools/RegressionDatagen.scala | 67 +
.../sql/catalyst/expressions/EachTopK.scala | 133 ++
.../sql/catalyst/plans/logical/JoinTopK.scala | 68 +
.../utils/InternalRowPriorityQueue.scala | 76 +
.../sql/execution/UserProvidedPlanner.scala | 83 +
.../datasources/csv/csvExpressions.scala | 169 ++
.../joins/ShuffledHashJoinTopKExec.scala | 405 +++++
.../spark/sql/hive/HivemallGroupedDataset.scala | 304 ++++
.../org/apache/spark/sql/hive/HivemallOps.scala | 1538 ++++++++++++++++++
.../apache/spark/sql/hive/HivemallUtils.scala | 146 ++
.../sql/hive/internal/HivemallOpsImpl.scala | 79 +
.../sql/hive/source/XGBoostFileFormat.scala | 163 ++
.../src/test/resources/data/files/README.md | 3 +
.../src/test/resources/data/files/complex.seq | 0
.../src/test/resources/data/files/episodes.avro | 0
.../src/test/resources/data/files/json.txt | 0
.../src/test/resources/data/files/kv1.txt | 0
.../src/test/resources/data/files/kv3.txt | 0
.../src/test/resources/log4j.properties | 7 +
.../hivemall/mix/server/MixServerSuite.scala | 124 ++
.../hivemall/tools/RegressionDatagenSuite.scala | 33 +
.../scala/org/apache/spark/SparkFunSuite.scala | 51 +
.../ml/feature/HivemallLabeledPointSuite.scala | 36 +
.../scala/org/apache/spark/sql/QueryTest.scala | 360 ++++
.../spark/sql/catalyst/plans/PlanTest.scala | 137 ++
.../sql/execution/benchmark/BenchmarkBase.scala | 56 +
.../apache/spark/sql/hive/HiveUdfSuite.scala | 161 ++
.../spark/sql/hive/HivemallOpsSuite.scala | 961 +++++++++++
.../spark/sql/hive/ModelMixingSuite.scala | 286 ++++
.../apache/spark/sql/hive/XGBoostSuite.scala | 151 ++
.../sql/hive/benchmark/MiscBenchmark.scala | 268 +++
.../hive/test/HivemallFeatureQueryTest.scala | 113 ++
.../spark/sql/hive/test/TestHiveSingleton.scala | 39 +
.../org/apache/spark/sql/test/SQLTestData.scala | 315 ++++
.../apache/spark/sql/test/SQLTestUtils.scala | 336 ++++
.../apache/spark/sql/test/VectorQueryTest.scala | 89 +
.../streaming/HivemallOpsWithFeatureSuite.scala | 155 ++
.../scala/org/apache/spark/test/TestUtils.scala | 65 +
48 files changed, 7789 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/.travis.yml
----------------------------------------------------------------------
diff --git a/.travis.yml b/.travis.yml
index 96f8f4e..c64c5ff 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -34,9 +34,7 @@ notifications:
email: false
script:
- - mvn -q scalastyle:check test -Pspark-2.1
- # test the spark-2.0 modules only in the following runs
- - mvn -q scalastyle:check clean -Pspark-2.0 -pl spark/spark-2.0 -am test -Dtest=none
+ - ./bin/run_travis_tests.sh
after_success:
- mvn clean cobertura:cobertura coveralls:report
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/bin/run_travis_tests.sh
----------------------------------------------------------------------
diff --git a/bin/run_travis_tests.sh b/bin/run_travis_tests.sh
new file mode 100755
index 0000000..f1bffec
--- /dev/null
+++ b/bin/run_travis_tests.sh
@@ -0,0 +1,47 @@
+#!/bin/sh
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+if [ "$HIVEMALL_HOME" = "" ]; then
+ if [ -e ../bin/${0##*/} ]; then
+ HIVEMALL_HOME=".."
+ elif [ -e ./bin/${0##*/} ]; then
+ HIVEMALL_HOME="."
+ else
+ echo "env HIVEMALL_HOME not defined"
+ exit 1
+ fi
+fi
+
+set -ev
+
+cd $HIVEMALL_HOME
+
+mvn -q scalastyle:check test -Pspark-2.1
+
+# Tests the spark-2.2/spark-2.0 modules only in the following runs
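+# (The spark-2.2 module requires Java 8, hence the `java -version` check below.)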
+if [[ ! -z "$(java -version 2>&1 | grep 1.8)" ]]; then
+ mvn -q scalastyle:check clean -Djava.source.version=1.8 -Djava.target.version=1.8 \
+ -Pspark-2.2 -pl spark/spark-2.2 -am test -Dtest=none
+fi
+
+mvn -q scalastyle:check clean -Pspark-2.0 -pl spark/spark-2.0 -am test -Dtest=none
+
+exit 0
+
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 3d7040c..8a543e6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -267,6 +267,50 @@
<profiles>
<profile>
+ <id>spark-2.2</id>
+ <modules>
+ <module>spark/spark-2.2</module>
+ <module>spark/spark-common</module>
+ </modules>
+ <properties>
+ <spark.version>2.2.0</spark.version>
+ <spark.binary.version>2.2</spark.binary.version>
+ </properties>
+ <build>
+ <plugins>
+ <!-- Spark-2.2 only supports Java 8 -->
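+ <!-- A local build sketch mirroring bin/run_travis_tests.sh (assumes it is run from the repository root):
+ mvn -q -Djava.source.version=1.8 -Djava.target.version=1.8 \
+ -Pspark-2.2 -pl spark/spark-2.2 -am test -Dtest=none
+ -->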
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-enforcer-plugin</artifactId>
+ <version>1.4.1</version>
+ <executions>
+ <execution>
+ <id>enforce-versions</id>
+ <phase>validate</phase>
+ <goals>
+ <goal>enforce</goal>
+ </goals>
+ <configuration>
+ <rules>
+ <requireProperty>
+ <property>java.source.version</property>
+ <regex>1.8</regex>
+ <regexMessage>When -Pspark-2.2 is set, java.source.version must be 1.8</regexMessage>
+ </requireProperty>
+ <requireProperty>
+ <property>java.target.version</property>
+ <regex>1.8</regex>
+ <regexMessage>When -Pspark-2.2 is set, java.target.version must be 1.8</regexMessage>
+ </requireProperty>
+ </rules>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
<id>spark-2.1</id>
<modules>
<module>spark/spark-2.1</module>
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/bin/mvn-zinc
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/bin/mvn-zinc b/spark/spark-2.2/bin/mvn-zinc
new file mode 100755
index 0000000..759b0a5
--- /dev/null
+++ b/spark/spark-2.2/bin/mvn-zinc
@@ -0,0 +1,99 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Copied from commit 48682f6bf663e54cb63b7e95a4520d34b6fa890b in Apache Spark
+
+# Determine the current working directory
+_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+# Preserve the calling directory
+_CALLING_DIR="$(pwd)"
+# Options used during compilation
+_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"
+
+# Installs any application tarball given a URL, the expected tarball name,
+# and, optionally, a checkable binary path to determine if the binary has
+# already been installed
+## Arg1 - URL
+## Arg2 - Tarball Name
+## Arg3 - Checkable Binary
+install_app() {
+ local remote_tarball="$1/$2"
+ local local_tarball="${_DIR}/$2"
+ local binary="${_DIR}/$3"
+ local curl_opts="--progress-bar -L"
+ local wget_opts="--progress=bar:force ${wget_opts}"
+
+ if [ -z "$3" -o ! -f "$binary" ]; then
+ # check if we already have the tarball
+ # check if we have curl installed
+ # download application
+ [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \
+ echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \
+ curl ${curl_opts} "${remote_tarball}" > "${local_tarball}"
+ # if the file still doesn't exist, let's try `wget` and cross our fingers
+ [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \
+ echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \
+ wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}"
+ # if both were unsuccessful, exit
+ [ ! -f "${local_tarball}" ] && \
+ echo -n "ERROR: Cannot download $2 with cURL or wget; " && \
+ echo "please install manually and try again." && \
+ exit 2
+ cd "${_DIR}" && tar -xzf "$2"
+ rm -rf "$local_tarball"
+ fi
+}
+
+# Install zinc under the bin/ folder
+install_zinc() {
+ local zinc_path="zinc-0.3.9/bin/zinc"
+ [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1
+ install_app \
+ "http://downloads.typesafe.com/zinc/0.3.9" \
+ "zinc-0.3.9.tgz" \
+ "${zinc_path}"
+ ZINC_BIN="${_DIR}/${zinc_path}"
+}
+
+# Setup healthy defaults for the Zinc port if none were provided from
+# the environment
+ZINC_PORT=${ZINC_PORT:-"3030"}
+
+# Install Zinc for the bin/
+install_zinc
+
+# Reset the current working directory
+cd "${_CALLING_DIR}"
+
+# Now that zinc is ensured to be installed, check its status and, if it's
+# not running or was just installed, start it
+if [ ! -f "${ZINC_BIN}" ]; then
+ exit -1
+fi
+if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then
+ export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"}
+ "${ZINC_BIN}" -shutdown -port ${ZINC_PORT}
+ "${ZINC_BIN}" -start -port ${ZINC_PORT} &>/dev/null
+fi
+
+# Set any `mvn` options if not already present
+export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"}
+
+# Last, call the `mvn` command as usual
+mvn -DzincPort=${ZINC_PORT} "$@"
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/extra-src/README.md
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/extra-src/README.md b/spark/spark-2.2/extra-src/README.md
new file mode 100644
index 0000000..1d89d0a
--- /dev/null
+++ b/spark/spark-2.2/extra-src/README.md
@@ -0,0 +1 @@
+Copied from the Spark v2.2.0 release.
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
new file mode 100644
index 0000000..9e98948
--- /dev/null
+++ b/spark/spark-2.2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.{InputStream, OutputStream}
+import java.rmi.server.UID
+
+import scala.collection.JavaConverters._
+import scala.language.implicitConversions
+import scala.reflect.ClassTag
+
+import com.google.common.base.Objects
+import org.apache.avro.Schema
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
+import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro
+import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
+import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
+import org.apache.hadoop.io.Writable
+import org.apache.hive.com.esotericsoftware.kryo.Kryo
+import org.apache.hive.com.esotericsoftware.kryo.io.{Input, Output}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.types.Decimal
+import org.apache.spark.util.Utils
+
+private[hive] object HiveShim {
+ // Precision and scale to pass for unlimited decimals; these are the same as the precision and
+ // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
+ val UNLIMITED_DECIMAL_PRECISION = 38
+ val UNLIMITED_DECIMAL_SCALE = 18
+ val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro"
+
+ /*
+ * This function became private in hive-0.13, but we have to do this to work around a hive bug
+ */
+ private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) {
+ val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "")
+ val result: StringBuilder = new StringBuilder(old)
+ var first: Boolean = old.isEmpty
+
+ for (col <- cols) {
+ if (first) {
+ first = false
+ } else {
+ result.append(',')
+ }
+ result.append(col)
+ }
+ conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString)
+ }
+
+ /*
+ * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null
+ */
+ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) {
+ if (ids != null) {
+ ColumnProjectionUtils.appendReadColumns(conf, ids.asJava)
+ }
+ if (names != null) {
+ appendReadColumnNames(conf, names)
+ }
+ }
+
+ /*
+ * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that
+ * needs to be initialized before serialization.
+ */
+ def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = {
+ w match {
+ case w: AvroGenericRecordWritable =>
+ w.setRecordReaderID(new UID())
+ // In Hive 1.1, the record's schema may need to be initialized manually or an NPE will
+ // be thrown.
+ if (w.getFileSchema() == null) {
+ serDeProps
+ .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName())
+ .foreach { kv =>
+ w.setFileSchema(new Schema.Parser().parse(kv._2))
+ }
+ }
+ case _ =>
+ }
+ w
+ }
+
+ def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
+ if (hdoi.preferWritable()) {
+ Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue,
+ hdoi.precision(), hdoi.scale())
+ } else {
+ Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
+ }
+ }
+
+ /**
+ * This class provides the UDF creation and also the UDF instance serialization and
+ * de-serialization across process boundaries.
+ *
+ * A detailed discussion can be found at https://github.com/apache/spark/pull/3640
+ *
+ * @param functionClassName UDF class name
+ * @param instance optional UDF instance which contains additional information (for macro)
+ */
+ private[hive] case class HiveFunctionWrapper(var functionClassName: String,
+ private var instance: AnyRef = null) extends java.io.Externalizable {
+
+ // for Serialization
+ def this() = this(null)
+
+ override def hashCode(): Int = {
+ if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
+ Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody())
+ } else {
+ functionClassName.hashCode()
+ }
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case a: HiveFunctionWrapper if functionClassName == a.functionClassName =>
+ // In case of udf macro, check to make sure they point to the same underlying UDF
+ if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
+ a.instance.asInstanceOf[GenericUDFMacro].getBody() ==
+ instance.asInstanceOf[GenericUDFMacro].getBody()
+ } else {
+ true
+ }
+ case _ => false
+ }
+
+ @transient
+ def deserializeObjectByKryo[T: ClassTag](
+ kryo: Kryo,
+ in: InputStream,
+ clazz: Class[_]): T = {
+ val inp = new Input(in)
+ val t: T = kryo.readObject(inp, clazz).asInstanceOf[T]
+ inp.close()
+ t
+ }
+
+ @transient
+ def serializeObjectByKryo(
+ kryo: Kryo,
+ plan: Object,
+ out: OutputStream) {
+ val output: Output = new Output(out)
+ kryo.writeObject(output, plan)
+ output.close()
+ }
+
+ def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
+ deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz)
+ .asInstanceOf[UDFType]
+ }
+
+ def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
+ serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out)
+ }
+
+ def writeExternal(out: java.io.ObjectOutput) {
+ // output the function name
+ out.writeUTF(functionClassName)
+
+ // Write a flag if instance is null or not
+ out.writeBoolean(instance != null)
+ if (instance != null) {
+ // Some of the UDF are serializable, but some others are not
+ // Hive Utilities can handle both cases
+ val baos = new java.io.ByteArrayOutputStream()
+ serializePlan(instance, baos)
+ val functionInBytes = baos.toByteArray
+
+ // output the function bytes
+ out.writeInt(functionInBytes.length)
+ out.write(functionInBytes, 0, functionInBytes.length)
+ }
+ }
+
+ def readExternal(in: java.io.ObjectInput) {
+ // read the function name
+ functionClassName = in.readUTF()
+
+ if (in.readBoolean()) {
+ // if the instance is not null
+ // read the function in bytes
+ val functionInBytesLength = in.readInt()
+ val functionInBytes = new Array[Byte](functionInBytesLength)
+ in.readFully(functionInBytes)
+
+ // deserialize the function object via Hive Utilities
+ instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes),
+ Utils.getContextOrSparkClassLoader.loadClass(functionClassName))
+ }
+ }
+
+ def createFunction[UDFType <: AnyRef](): UDFType = {
+ if (instance != null) {
+ instance.asInstanceOf[UDFType]
+ } else {
+ val func = Utils.getContextOrSparkClassLoader
+ .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+ if (!func.isInstanceOf[UDF]) {
+ // We cache the function if it's not a simple UDF,
+ // as we always have to create a new instance for a simple UDF
+ instance = func
+ }
+ func
+ }
+ }
+ }
+
+ /*
+ * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not.
+ * Fix it through wrapper.
+ */
+ implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = {
+ val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed)
+ f.setCompressCodec(w.compressCodec)
+ f.setCompressType(w.compressType)
+ f.setTableInfo(w.tableInfo)
+ f.setDestTableId(w.destTableId)
+ f
+ }
+
+ /*
+ * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not.
+ * Fix it through wrapper.
+ */
+ private[hive] class ShimFileSinkDesc(
+ var dir: String,
+ var tableInfo: TableDesc,
+ var compressed: Boolean)
+ extends Serializable with Logging {
+ var compressCodec: String = _
+ var compressType: String = _
+ var destTableId: Int = _
+
+ def setCompressed(compressed: Boolean) {
+ this.compressed = compressed
+ }
+
+ def getDirName(): String = dir
+
+ def setDestTableId(destTableId: Int) {
+ this.destTableId = destTableId
+ }
+
+ def setTableInfo(tableInfo: TableDesc) {
+ this.tableInfo = tableInfo
+ }
+
+ def setCompressCodec(intermediateCompressorCodec: String) {
+ compressCodec = intermediateCompressorCodec
+ }
+
+ def setCompressType(intermediateCompressType: String) {
+ compressType = intermediateCompressType
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/pom.xml
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/pom.xml b/spark/spark-2.2/pom.xml
new file mode 100644
index 0000000..85a296f
--- /dev/null
+++ b/spark/spark-2.2/pom.xml
@@ -0,0 +1,269 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+
+ <parent>
+ <groupId>io.github.myui</groupId>
+ <artifactId>hivemall</artifactId>
+ <version>0.4.2-rc.2</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <artifactId>hivemall-spark</artifactId>
+ <name>Hivemall on Spark 2.2</name>
+ <packaging>jar</packaging>
+
+ <properties>
+ <PermGen>64m</PermGen>
+ <MaxPermGen>512m</MaxPermGen>
+ <CodeCacheSize>512m</CodeCacheSize>
+ <main.basedir>${project.parent.basedir}</main.basedir>
+ </properties>
+
+ <dependencies>
+ <!-- hivemall dependencies -->
+ <dependency>
+ <groupId>io.github.myui</groupId>
+ <artifactId>hivemall-core</artifactId>
+ <version>${project.version}</version>
+ <scope>compile</scope>
+ </dependency>
+ <dependency>
+ <groupId>io.github.myui</groupId>
+ <artifactId>hivemall-xgboost</artifactId>
+ <version>${project.version}</version>
+ <scope>compile</scope>
+ </dependency>
+ <dependency>
+ <groupId>io.github.myui</groupId>
+ <artifactId>hivemall-spark-common</artifactId>
+ <version>${project.version}</version>
+ <scope>compile</scope>
+ </dependency>
+
+ <!-- third-party dependencies -->
+ <dependency>
+ <groupId>org.scala-lang</groupId>
+ <artifactId>scala-library</artifactId>
+ <version>${scala.version}</version>
+ <scope>compile</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-compress</artifactId>
+ <version>1.8</version>
+ <scope>compile</scope>
+ </dependency>
+
+ <!-- other provided dependencies -->
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-hive_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+
+ <!-- test dependencies -->
+ <dependency>
+ <groupId>io.github.myui</groupId>
+ <artifactId>hivemall-mixserv</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.xerial</groupId>
+ <artifactId>xerial-core</artifactId>
+ <version>3.2.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <version>2.2.4</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <directory>target</directory>
+ <outputDirectory>target/classes</outputDirectory>
+ <finalName>${project.artifactId}-${spark.binary.version}_${scala.binary.version}-${project.version}</finalName>
+ <testOutputDirectory>target/test-classes</testOutputDirectory>
+ <plugins>
+ <!-- For incremental compilation -->
+ <plugin>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ <version>3.2.2</version>
+ <executions>
+ <execution>
+ <id>scala-compile-first</id>
+ <phase>process-resources</phase>
+ <goals>
+ <goal>compile</goal>
+ </goals>
+ </execution>
+ <execution>
+ <id>scala-test-compile-first</id>
+ <phase>process-test-resources</phase>
+ <goals>
+ <goal>testCompile</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <scalaVersion>${scala.version}</scalaVersion>
+ <recompileMode>incremental</recompileMode>
+ <useZincServer>true</useZincServer>
+ <args>
+ <arg>-unchecked</arg>
+ <arg>-deprecation</arg>
+ <!-- TODO: To enable this option, we need to fix many warnings -->
+ <!-- <arg>-feature</arg> -->
+ </args>
+ <jvmArgs>
+ <jvmArg>-Xms1024m</jvmArg>
+ <jvmArg>-Xmx1024m</jvmArg>
+ <jvmArg>-XX:PermSize=${PermGen}</jvmArg>
+ <jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg>
+ <jvmArg>-XX:ReservedCodeCacheSize=${CodeCacheSize}</jvmArg>
+ </jvmArgs>
+ </configuration>
+ </plugin>
+ <!-- hivemall-spark_xx-xx.jar -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <version>2.5</version>
+ <configuration>
+ <finalName>${project.artifactId}-${spark.binary.version}_${scala.binary.version}-${project.version}</finalName>
+ <outputDirectory>${project.parent.build.directory}</outputDirectory>
+ </configuration>
+ </plugin>
+ <!-- hivemall-spark_xx-xx-with-dependencies.jar including minimum dependencies -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-shade-plugin</artifactId>
+ <version>2.3</version>
+ <executions>
+ <execution>
+ <id>jar-with-dependencies</id>
+ <phase>package</phase>
+ <goals>
+ <goal>shade</goal>
+ </goals>
+ <configuration>
+ <finalName>${project.artifactId}-${spark.binary.version}_${scala.binary.version}-${project.version}-with-dependencies</finalName>
+ <outputDirectory>${project.parent.build.directory}</outputDirectory>
+ <minimizeJar>false</minimizeJar>
+ <createDependencyReducedPom>false</createDependencyReducedPom>
+ <artifactSet>
+ <includes>
+ <include>io.github.myui:hivemall-core</include>
+ <include>io.github.myui:hivemall-xgboost</include>
+ <include>io.github.myui:hivemall-spark-common</include>
+ <include>com.github.haifengl:smile-core</include>
+ <include>com.github.haifengl:smile-math</include>
+ <include>com.github.haifengl:smile-data</include>
+ <include>ml.dmlc:xgboost4j</include>
+ <include>com.esotericsoftware.kryo:kryo</include>
+ </includes>
+ </artifactSet>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <!-- disable surefire because there is no java test -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <version>2.7</version>
+ <configuration>
+ <skipTests>true</skipTests>
+ </configuration>
+ </plugin>
+ <!-- then, enable scalatest -->
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <version>1.0</version>
+ <configuration>
+ <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
+ <junitxml>.</junitxml>
+ <filereports>SparkTestSuite.txt</filereports>
+ <argLine>-ea -Xmx2g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine>
+ <stderr />
+ <environmentVariables>
+ <SPARK_PREPEND_CLASSES>1</SPARK_PREPEND_CLASSES>
+ <SPARK_SCALA_VERSION>${scala.binary.version}</SPARK_SCALA_VERSION>
+ <SPARK_TESTING>1</SPARK_TESTING>
+ <JAVA_HOME>${env.JAVA_HOME}</JAVA_HOME>
+ </environmentVariables>
+ <systemProperties>
+ <log4j.configuration>file:src/test/resources/log4j.properties</log4j.configuration>
+ <derby.system.durability>test</derby.system.durability>
+ <java.awt.headless>true</java.awt.headless>
+ <java.io.tmpdir>${project.build.directory}/tmp</java.io.tmpdir>
+ <spark.testing>1</spark.testing>
+ <spark.ui.enabled>false</spark.ui.enabled>
+ <spark.ui.showConsoleProgress>false</spark.ui.showConsoleProgress>
+ <spark.unsafe.exceptionOnMemoryLeak>true</spark.unsafe.exceptionOnMemoryLeak>
+ <!-- Needed by sql/hive tests. -->
+ <test.src.tables>__not_used__</test.src.tables>
+ </systemProperties>
+ <tagsToExclude>${test.exclude.tags}</tagsToExclude>
+ </configuration>
+ <executions>
+ <execution>
+ <id>test</id>
+ <goals>
+ <goal>test</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+</project>
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala b/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala
new file mode 100644
index 0000000..3e0f274
--- /dev/null
+++ b/spark/spark-2.2/src/main/java/hivemall/xgboost/XGBoostOptions.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package hivemall.xgboost
+
+import scala.collection.mutable
+
+import org.apache.commons.cli.Options
+import org.apache.spark.annotation.AlphaComponent
+
+/**
+ * :: AlphaComponent ::
+ * A utility class to generate a sequence of options used in XGBoost.
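+ *
+ * A usage sketch (using `num_class`, which `isValidKey` below always accepts):
+ * {{{
+ *   val opts = XGBoostOptions().set("num_class", "2")
+ *   opts.toString  // returns "-num_class 2"
+ * }}}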
+ */
+@AlphaComponent
+case class XGBoostOptions() {
+ private val params: mutable.Map[String, String] = mutable.Map.empty
+ private val options: Options = {
+ new XGBoostUDTF() {
+ def options(): Options = super.getOptions()
+ }.options()
+ }
+
+ private def isValidKey(key: String): Boolean = {
+ // TODO: Is there another way to handle all the XGBoost options?
+ options.hasOption(key) || key == "num_class"
+ }
+
+ def set(key: String, value: String): XGBoostOptions = {
+ require(isValidKey(key), s"non-existing key detected in XGBoost options: ${key}")
+ params.put(key, value)
+ this
+ }
+
+ def help(): Unit = {
+ import scala.collection.JavaConversions._
+ options.getOptions.map { case option => println(option) }
+ }
+
+ override def toString(): String = {
+ params.map { case (key, value) => s"-$key $value" }.mkString(" ")
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
new file mode 100644
index 0000000..b49e20a
--- /dev/null
+++ b/spark/spark-2.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -0,0 +1 @@
+org.apache.spark.sql.hive.source.XGBoostFileFormat
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/resources/log4j.properties
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/main/resources/log4j.properties b/spark/spark-2.2/src/main/resources/log4j.properties
new file mode 100644
index 0000000..72bf5b6
--- /dev/null
+++ b/spark/spark-2.2/src/main/resources/log4j.properties
@@ -0,0 +1,12 @@
+# Set everything to be logged to the console
+log4j.rootCategory=INFO, console
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+
+# Settings to quiet third party logs that are too verbose
+log4j.logger.org.eclipse.jetty=INFO
+log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
+log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala b/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala
new file mode 100644
index 0000000..a2b7f60
--- /dev/null
+++ b/spark/spark-2.2/src/main/scala/hivemall/tools/RegressionDatagen.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package hivemall.tools
+
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.types._
+
+object RegressionDatagen {
+
+ /**
+ * Generate data for regression/classification.
+ * See [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]]
+ * for the details of arguments below.
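+ *
+ * A usage sketch (assumes an active SQLContext named `sqlContext`):
+ * {{{
+ *   val df = RegressionDatagen.exec(sqlContext, min_examples = 10000, n_features = 20, dense = true)
+ *   // df has a double `label` column and a `features` column
+ * }}}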
+ */
+ def exec(sc: SQLContext,
+ n_partitions: Int = 2,
+ min_examples: Int = 1000,
+ n_features: Int = 10,
+ n_dims: Int = 200,
+ seed: Int = 43,
+ dense: Boolean = false,
+ prob_one: Float = 0.6f,
+ sort: Boolean = false,
+ cl: Boolean = false): DataFrame = {
+
+ require(n_partitions > 0, "Positive #n_partitions required.")
+ require(min_examples > 0, "Positive #min_examples required.")
+ require(n_features > 0, "Positive #n_features required.")
+ require(n_dims > 0, "Positive #n_dims required.")
+
+ // Calculate #examples to generate in each partition
+ val n_examples = (min_examples + n_partitions - 1) / n_partitions
+
+ val df = sc.createDataFrame(
+ sc.sparkContext.parallelize((0 until n_partitions).map(Row(_)), n_partitions),
+ StructType(
+ StructField("data", IntegerType, true) ::
+ Nil)
+ )
+ import sc.implicits._
+ df.lr_datagen(
+ lit(s"-n_examples $n_examples -n_features $n_features -n_dims $n_dims -prob_one $prob_one"
+ + (if (dense) " -dense" else "")
+ + (if (sort) " -sort" else "")
+ + (if (cl) " -cl" else ""))
+ ).select($"label".cast(DoubleType).as("label"), $"features")
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
new file mode 100644
index 0000000..cac2a5d
--- /dev/null
+++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue
+import org.apache.spark.sql.types._
+
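+/**
+ * Common helpers for the top-K operators: the collection of orderable score types, a writer
+ * for score values of those types, and a priority queue bounded by |k| whose ordering is
+ * reversed when `k` is negative.
+ */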
+trait TopKHelper {
+
+ def k: Int
+ def scoreType: DataType
+
+ @transient val ScoreTypes = TypeCollection(
+ ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType
+ )
+
+ protected case class ScoreWriter(writer: UnsafeRowWriter, ordinal: Int) {
+
+ def write(v: Any): Unit = scoreType match {
+ case ByteType => writer.write(ordinal, v.asInstanceOf[Byte])
+ case ShortType => writer.write(ordinal, v.asInstanceOf[Short])
+ case IntegerType => writer.write(ordinal, v.asInstanceOf[Int])
+ case LongType => writer.write(ordinal, v.asInstanceOf[Long])
+ case FloatType => writer.write(ordinal, v.asInstanceOf[Float])
+ case DoubleType => writer.write(ordinal, v.asInstanceOf[Double])
+ case d: DecimalType => writer.write(ordinal, v.asInstanceOf[Decimal], d.precision, d.scale)
+ }
+ }
+
+ protected lazy val scoreOrdering = {
+ val ordering = TypeUtils.getInterpretedOrdering(scoreType)
+ if (k > 0) ordering else ordering.reverse
+ }
+
+ protected lazy val reverseScoreOrdering = scoreOrdering.reverse
+
+ protected lazy val queue: InternalRowPriorityQueue = {
+ new InternalRowPriorityQueue(Math.abs(k), (x: Any, y: Any) => scoreOrdering.compare(x, y))
+ }
+}
+
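+/**
+ * A generator that, for each group defined by `groupExprs`, emits the top-K rows ordered by
+ * `scoreExpr` (a negative `k` selects the bottom-K instead), prepending a rank column to each
+ * output row. The queue is flushed whenever the grouping key changes, so input rows are
+ * expected to be clustered by the grouping expressions.
+ */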
+case class EachTopK(
+ k: Int,
+ scoreExpr: Expression,
+ groupExprs: Seq[Expression],
+ elementSchema: StructType,
+ children: Seq[Attribute])
+ extends Generator with TopKHelper with CodegenFallback {
+
+ override val scoreType: DataType = scoreExpr.dataType
+
+ private lazy val groupingProjection: UnsafeProjection = UnsafeProjection.create(groupExprs)
+ private lazy val scoreProjection: UnsafeProjection = UnsafeProjection.create(scoreExpr :: Nil)
+
+ // The grouping key of the current partition
+ private var currentGroupingKeys: UnsafeRow = _
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (!ScoreTypes.acceptsType(scoreExpr.dataType)) {
+ TypeCheckResult.TypeCheckFailure(s"$scoreExpr must have a comparable type")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ private def topKRowsForGroup(): Seq[InternalRow] = if (queue.size > 0) {
+ val outputRows = queue.iterator.toSeq.reverse
+ val (headScore, _) = outputRows.head
+ val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) =>
+ if (prevScore == score) (rank, score) else (rank + 1, score)
+ }
+ val topKRow = new UnsafeRow(1)
+ val bufferHolder = new BufferHolder(topKRow)
+ val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1)
+ outputRows.zip(rankNum.map(_._1)).map { case ((_, row), index) =>
+ // Writes to an UnsafeRow directly
+ bufferHolder.reset()
+ unsafeRowWriter.write(0, index)
+ topKRow.setTotalSize(bufferHolder.totalSize())
+ new JoinedRow(topKRow, row)
+ }
+ } else {
+ Seq.empty
+ }
+
+ override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+ val groupingKeys = groupingProjection(input)
+ val ret = if (currentGroupingKeys != groupingKeys) {
+ val topKRows = topKRowsForGroup()
+ currentGroupingKeys = groupingKeys.copy()
+ queue.clear()
+ topKRows
+ } else {
+ Iterator.empty
+ }
+ queue += Tuple2(scoreProjection(input).get(0, scoreType), input)
+ ret
+ }
+
+ override def terminate(): TraversableOnce[InternalRow] = {
+ if (queue.size > 0) {
+ val topKRows = topKRowsForGroup()
+ queue.clear()
+ topKRows
+ } else {
+ Iterator.empty
+ }
+ }
+
+ // TODO: Need to support codegen
+ // protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala
new file mode 100644
index 0000000..556cdc3
--- /dev/null
+++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
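+/**
+ * A logical operator for a top-K equi-join: only the K joined rows ranked highest by
+ * `scoreExpr` are retained. The output prepends a rank attribute and the score attribute to
+ * the outputs of both children; only inner joins are supported.
+ */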
+case class JoinTopK(
+ k: Int,
+ left: LogicalPlan,
+ right: LogicalPlan,
+ joinType: JoinType,
+ condition: Option[Expression])(
+ val scoreExpr: NamedExpression,
+ private[sql] val rankAttr: Seq[Attribute] = AttributeReference("rank", IntegerType)() :: Nil)
+ extends BinaryNode with PredicateHelper {
+
+ override def output: Seq[Attribute] = joinType match {
+ case Inner => rankAttr ++ Seq(scoreExpr.toAttribute) ++ left.output ++ right.output
+ }
+
+ override def references: AttributeSet = {
+ AttributeSet((expressions ++ Seq(scoreExpr)).flatMap(_.references))
+ }
+
+ override protected def validConstraints: Set[Expression] = joinType match {
+ case Inner if condition.isDefined =>
+ left.constraints.union(right.constraints)
+ .union(splitConjunctivePredicates(condition.get).toSet)
+ }
+
+ override protected final def otherCopyArgs: Seq[AnyRef] = {
+ scoreExpr :: rankAttr :: Nil
+ }
+
+ def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
+
+ lazy val resolvedExceptNatural: Boolean = {
+ childrenResolved &&
+ expressions.forall(_.resolved) &&
+ duplicateResolved &&
+ condition.forall(_.dataType == BooleanType)
+ }
+
+ override lazy val resolved: Boolean = joinType match {
+ case Inner => resolvedExceptNatural
+ case tpe => throw new AnalysisException(s"Unsupported join type: $tpe")
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala
new file mode 100644
index 0000000..12c20fb
--- /dev/null
+++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.utils
+
+import java.io.Serializable
+import java.util.{PriorityQueue => JPriorityQueue}
+
+import scala.collection.JavaConverters._
+import scala.collection.generic.Growable
+
+import org.apache.spark.sql.catalyst.InternalRow
+
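+/**
+ * A bounded priority queue holding at most `maxSize` (score, row) pairs. Once full, a newly
+ * offered element replaces the current minimum (per `compareFunc`) only if it compares
+ * greater, so the queue retains the best `maxSize` elements seen so far. Rows are defensively
+ * copied on insertion.
+ */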
+private[sql] class InternalRowPriorityQueue(
+ maxSize: Int,
+ compareFunc: (Any, Any) => Int
+ ) extends Iterable[(Any, InternalRow)] with Growable[(Any, InternalRow)] with Serializable {
+
+ private[this] val ordering = new Ordering[(Any, InternalRow)] {
+ override def compare(x: (Any, InternalRow), y: (Any, InternalRow)): Int =
+ compareFunc(x._1, y._1)
+ }
+
+ private val underlying = new JPriorityQueue[(Any, InternalRow)](maxSize, ordering)
+
+ override def iterator: Iterator[(Any, InternalRow)] = underlying.iterator.asScala
+
+ override def size: Int = underlying.size
+
+ override def ++=(xs: TraversableOnce[(Any, InternalRow)]): this.type = {
+ xs.foreach { this += _ }
+ this
+ }
+
+ override def +=(elem: (Any, InternalRow)): this.type = {
+ if (size < maxSize) {
+ underlying.offer((elem._1, elem._2.copy()))
+ } else {
+ maybeReplaceLowest(elem)
+ }
+ this
+ }
+
+ override def +=(elem1: (Any, InternalRow), elem2: (Any, InternalRow), elems: (Any, InternalRow)*)
+ : this.type = {
+ this += elem1 += elem2 ++= elems
+ }
+
+ override def clear() { underlying.clear() }
+
+ private def maybeReplaceLowest(a: (Any, InternalRow)): Boolean = {
+ val head = underlying.peek()
+ if (head != null && ordering.gt(a, head)) {
+ underlying.poll()
+ underlying.offer((a._1, a._2.copy()))
+ } else {
+ false
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala
new file mode 100644
index 0000000..09d60a6
--- /dev/null
+++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.{JoinTopK, LogicalPlan}
+import org.apache.spark.sql.internal.SQLConf
+
+private object ExtractJoinTopKKeys extends Logging with PredicateHelper {
+ /** (k, scoreExpr, joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
+ type ReturnType =
+ (Int, NamedExpression, Seq[Attribute], JoinType, Seq[Expression], Seq[Expression],
+ Option[Expression], LogicalPlan, LogicalPlan)
+
+ def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+ case join @ JoinTopK(k, left, right, joinType, condition) =>
+ logDebug(s"Considering join on: $condition")
+ val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
+ val joinKeys = predicates.flatMap {
+ case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
+ case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
+ // Replace null with a default value for the join key, so that rows with null in it
+ // can be joined together
+ case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
+ Some((Coalesce(Seq(l, Literal.default(l.dataType))),
+ Coalesce(Seq(r, Literal.default(r.dataType)))))
+ case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
+ Some((Coalesce(Seq(r, Literal.default(r.dataType))),
+ Coalesce(Seq(l, Literal.default(l.dataType)))))
+ case other => None
+ }
+ val otherPredicates = predicates.filterNot {
+ case EqualTo(l, r) =>
+ canEvaluate(l, left) && canEvaluate(r, right) ||
+ canEvaluate(l, right) && canEvaluate(r, left)
+ case other => false
+ }
+
+ if (joinKeys.nonEmpty) {
+ val (leftKeys, rightKeys) = joinKeys.unzip
+ logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
+ Some((k, join.scoreExpr, join.rankAttr, joinType, leftKeys, rightKeys,
+ otherPredicates.reduceOption(And), left, right))
+ } else {
+ None
+ }
+
+ case p =>
+ None
+ }
+}
+
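+/**
+ * A planner strategy that maps a logical [[JoinTopK]] with at least one equi-join key onto the
+ * physical [[joins.ShuffledHashJoinTopKExec]] operator; any other plan is left to the built-in
+ * strategies.
+ */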
+private[sql] class UserProvidedPlanner(val conf: SQLConf) extends Strategy {
+
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case ExtractJoinTopKKeys(
+ k, scoreExpr, rankAttr, _, leftKeys, rightKeys, condition, left, right) =>
+ Seq(joins.ShuffledHashJoinTopKExec(
+ k, leftKeys, rightKeys, condition, planLater(left), planLater(right))(scoreExpr, rankAttr))
+ case _ =>
+ Nil
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
new file mode 100644
index 0000000..1f56c90
--- /dev/null
+++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import com.univocity.parsers.csv.CsvWriter
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Converts a csv input string to a [[StructType]] with the specified schema.
+ *
+ * TODO: Move this class into org.apache.spark.sql.catalyst.expressions in Spark-v2.2+
+ */
+case class CsvToStruct(
+ schema: StructType,
+ options: Map[String, String],
+ child: Expression,
+ timeZoneId: Option[String] = None)
+ extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+
+ def this(schema: StructType, options: Map[String, String], child: Expression) =
+ this(schema, options, child, None)
+
+ override def nullable: Boolean = true
+
+ @transient private lazy val csvOptions = new CSVOptions(options, timeZoneId.get)
+ @transient private lazy val csvParser = new UnivocityParser(schema, schema, csvOptions)
+
+ private def parse(input: String): InternalRow = csvParser.parse(input)
+
+ override def dataType: DataType = schema
+
+ override def nullSafeEval(csv: Any): Any = {
+ try parse(csv.toString) catch { case _: RuntimeException => null }
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
+
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+}
+
+private class CsvGenerator(schema: StructType, options: CSVOptions) {
+
+ // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
+ // When the value is null, this converter should not be called.
+ private type ValueConverter = (InternalRow, Int) => String
+
+ // `ValueConverter`s for all values in the fields of the schema
+ private val valueConverters: Array[ValueConverter] =
+ schema.map(_.dataType).map(makeConverter).toArray
+
+ private def makeConverter(dataType: DataType): ValueConverter = dataType match {
+ case DateType =>
+ (row: InternalRow, ordinal: Int) =>
+ options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
+
+ case TimestampType =>
+ (row: InternalRow, ordinal: Int) =>
+ options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
+
+ case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
+
+ case dt: DataType =>
+ (row: InternalRow, ordinal: Int) =>
+ row.get(ordinal, dt).toString
+ }
+
+ def convertRow(row: InternalRow): Seq[String] = {
+ var i = 0
+ val values = new Array[String](row.numFields)
+ while (i < row.numFields) {
+ if (!row.isNullAt(i)) {
+ values(i) = valueConverters(i).apply(row, i)
+ } else {
+ values(i) = options.nullValue
+ }
+ i += 1
+ }
+ values
+ }
+}
+
+/**
+ * Converts a [[StructType]] to a CSV output string.
+ */
+case class StructToCsv(
+ options: Map[String, String],
+ child: Expression,
+ timeZoneId: Option[String] = None)
+ extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+ override def nullable: Boolean = true
+
+ @transient
+ private lazy val params = new CSVOptions(options, timeZoneId.get)
+
+ @transient
+ private lazy val dataSchema = child.dataType.asInstanceOf[StructType]
+
+ @transient
+ private lazy val writer = new CsvGenerator(dataSchema, params)
+
+ override def dataType: DataType = StringType
+
+ private def verifySchema(schema: StructType): Unit = {
+ def verifyType(dataType: DataType): Unit = dataType match {
+ case ByteType | ShortType | IntegerType | LongType | FloatType |
+ DoubleType | BooleanType | _: DecimalType | TimestampType |
+ DateType | StringType =>
+
+ case udt: UserDefinedType[_] => verifyType(udt.sqlType)
+
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"CSV data source does not support ${dataType.simpleString} data type.")
+ }
+
+ schema.foreach(field => verifyType(field.dataType))
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (StructType.acceptsType(child.dataType)) {
+ try {
+ verifySchema(child.dataType.asInstanceOf[StructType])
+ TypeCheckResult.TypeCheckSuccess
+ } catch {
+ case e: UnsupportedOperationException =>
+ TypeCheckResult.TypeCheckFailure(e.getMessage)
+ }
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ s"$prettyName requires that the expression is a struct expression.")
+ }
+ }
+
+ override def nullSafeEval(row: Any): Any = {
+ val rowStr = writer.convertRow(row.asInstanceOf[InternalRow])
+ .mkString(params.delimiter.toString)
+ UTF8String.fromString(rowStr)
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = StructType :: Nil
+
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+}
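+
+// A minimal usage sketch (illustrative only, not part of this patch): given a struct-typed
+// column, the expression serializes each row to a single CSV line, e.g.:
+//   import org.apache.spark.sql.functions.struct
+//   df.select(new Column(StructToCsv(Map.empty, struct(df("a"), df("b")).expr)))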
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8bf6dd9e/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
new file mode 100644
index 0000000..0067bbb
--- /dev/null
+++ b/spark/spark-2.2/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric._
+import org.apache.spark.sql.types._
+
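+// Thin shim over the top-K priority queue so that both the interpreted join path and the
+// whole-stage generated code can insert scored rows, drain the retained rows, and reset the
+// queue through a stable interface.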
+abstract class PriorityQueueShim {
+
+ def insert(score: Any, row: InternalRow): Unit
+ def get(): Iterator[InternalRow]
+ def clear(): Unit
+}
+
+case class ShuffledHashJoinTopKExec(
+ k: Int,
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan)(
+ scoreExpr: NamedExpression,
+ rankAttr: Seq[Attribute])
+ extends BinaryExecNode with TopKHelper with HashJoin with CodegenSupport {
+
+ override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+ override val scoreType: DataType = scoreExpr.dataType
+ override val joinType: JoinType = Inner
+ override val buildSide: BuildSide = BuildRight // Only support `BuildRight`
+
+ private lazy val scoreProjection: UnsafeProjection =
+ UnsafeProjection.create(scoreExpr :: Nil, left.output ++ right.output)
+
+ private lazy val boundCondition = if (condition.isDefined) {
+ (r: InternalRow) => newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval(r)
+ } else {
+ (r: InternalRow) => true
+ }
+
+ private lazy val topKAttr = rankAttr :+ scoreExpr.toAttribute
+
+ private lazy val _priorityQueue = new PriorityQueueShim {
+
+ private val q: InternalRowPriorityQueue = queue
+ private val joinedRow = new JoinedRow
+
+ override def insert(score: Any, row: InternalRow): Unit = {
+ q += Tuple2(score, row)
+ }
+
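+ // Drains the retained rows: reverses the queue's iteration order, assigns a rank that only
+ // increases when the score changes (ties share a rank), and prepends a two-field UnsafeRow
+ // holding (rank, score) to each row.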
+ override def get(): Iterator[InternalRow] = {
+ val outputRows = queue.iterator.toSeq.reverse
+ val (headScore, _) = outputRows.head
+ val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) =>
+ if (prevScore == score) (rank, score) else (rank + 1, score)
+ }
+ val topKRow = new UnsafeRow(2)
+ val bufferHolder = new BufferHolder(topKRow)
+ val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2)
+ val scoreWriter = ScoreWriter(unsafeRowWriter, 1)
+ outputRows.zip(rankNum.map(_._1)).map { case ((score, row), index) =>
+ // Writes to an UnsafeRow directly
+ bufferHolder.reset()
+ unsafeRowWriter.write(0, index)
+ scoreWriter.write(score)
+ topKRow.setTotalSize(bufferHolder.totalSize())
+ joinedRow.apply(topKRow, row)
+ }.iterator
+ }
+
+ override def clear(): Unit = q.clear()
+ }
+
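+ // The operator prepends the rank and score attributes to the joined (left ++ right) output.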
+ override def output: Seq[Attribute] = joinType match {
+ case Inner => topKAttr ++ left.output ++ right.output
+ }
+
+ override protected final def otherCopyArgs: Seq[AnyRef] = {
+ scoreExpr :: rankAttr :: Nil
+ }
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
+ val context = TaskContext.get()
+ val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
+ context.addTaskCompletionListener(_ => relation.close())
+ relation
+ }
+
+ override protected def createResultProjection(): (InternalRow) => InternalRow = joinType match {
+ case Inner =>
+ // Always put the stream side on the left to simplify the implementation;
+ // both the left and right sides could be null
+ UnsafeProjection.create(
+ output, (topKAttr ++ streamedPlan.output ++ buildPlan.output).map(_.withNullability(true)))
+ }
+
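+ // For each streamed row, scores the matching build-side rows that pass the join condition,
+ // feeds them into the priority queue, and emits the queue's contents through the result
+ // projection.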
+ protected def InnerJoin(
+ streamedIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation,
+ numOutputRows: SQLMetric): Iterator[InternalRow] = {
+ val joinRow = new JoinedRow
+ val joinKeysProj = streamSideKeyGenerator()
+ val joinedIter = streamedIter.flatMap { srow =>
+ joinRow.withLeft(srow)
+ val joinKeys = joinKeysProj(srow) // `joinKeys` is also a grouping key
+ val matches = hashedRelation.get(joinKeys)
+ if (matches != null) {
+ matches.map(joinRow.withRight).filter(boundCondition).foreach { resultRow =>
+ _priorityQueue.insert(scoreProjection(resultRow).get(0, scoreType), resultRow)
+ }
+ val iter = _priorityQueue.get()
+ _priorityQueue.clear()
+ iter
+ } else {
+ Seq.empty
+ }
+ }
+ val resultProj = createResultProjection()
+ (joinedIter ++ queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering)
+ .map(_._2)).map { r =>
+ resultProj(r)
+ }
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
+ val hashed = buildHashedRelation(buildIter)
+ InnerJoin(streamIter, hashed, null)
+ }
+ }
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ left.execute() :: right.execute() :: Nil
+ }
+
+ // Accessor for generated code
+ def priorityQueue(): PriorityQueueShim = _priorityQueue
+
+ /**
+ * Add a state of HashedRelation and return the variable name for it.
+ */
+ private def prepareHashedRelation(ctx: CodegenContext): String = {
+ // create a name for HashedRelation
+ val joinExec = ctx.addReferenceObj("joinExec", this)
+ val relationTerm = ctx.freshName("relation")
+ val clsName = HashedRelation.getClass.getName.replace("$", "")
+ ctx.addMutableState(clsName, relationTerm,
+ s"""
+ | $relationTerm = ($clsName) $joinExec.buildHashedRelation(inputs[1]);
+ | incPeakExecutionMemory($relationTerm.estimatedSize());
+ """.stripMargin)
+ relationTerm
+ }
+
+ /**
+ * Creates variables for the left part of the result row.
+ *
+ * In order to defer the access until after the condition and to read each column only once
+ * in the loop, the variables are declared separately from the code that accesses the columns,
+ * so we can't use the codegen of BoundReference here.
+ */
+ private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
+ ctx.INPUT_ROW = leftRow
+ left.output.zipWithIndex.map { case (a, i) =>
+ val value = ctx.freshName("value")
+ val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
+ // declare it as a class member, so we can access the column before or inside the loop.
+ ctx.addMutableState(ctx.javaType(a.dataType), value, "")
+ if (a.nullable) {
+ val isNull = ctx.freshName("isNull")
+ ctx.addMutableState("boolean", isNull, "")
+ val code =
+ s"""
+ |$isNull = $leftRow.isNullAt($i);
+ |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
+ """.stripMargin
+ ExprCode(code, isNull, value)
+ } else {
+ ExprCode(s"$value = $valueCode;", "false", value)
+ }
+ }
+ }
+
+ /**
+ * Creates the variables for the right part of the result row using BoundReference, since the
+ * right part is accessed inside the loop.
+ */
+ private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
+ ctx.INPUT_ROW = rightRow
+ right.output.zipWithIndex.map { case (a, i) =>
+ BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ }
+ }
+
+ /**
+ * Returns the code for generating the join key for the stream side, and an expression
+ * indicating whether the key has any null in it.
+ */
+ private def genStreamSideJoinKey(ctx: CodegenContext, leftRow: String): (ExprCode, String) = {
+ ctx.INPUT_ROW = leftRow
+ if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
+ // generate the join key as Long
+ val ev = streamedKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ private def createScoreVar(ctx: CodegenContext, row: String): ExprCode = {
+ ctx.INPUT_ROW = row
+ BindReferences.bindReference(scoreExpr, left.output ++ right.output).genCode(ctx)
+ }
+
+ private def createResultVars(ctx: CodegenContext, resultRow: String): Seq[ExprCode] = {
+ ctx.INPUT_ROW = resultRow
+ output.zipWithIndex.map { case (a, i) =>
+ val value = ctx.freshName("value")
+ val valueCode = ctx.getValue(resultRow, a.dataType, i.toString)
+ // declare it as a class member, so we can access the column before or inside the loop.
+ ctx.addMutableState(ctx.javaType(a.dataType), value, "")
+ if (a.nullable) {
+ val isNull = ctx.freshName("isNull")
+ ctx.addMutableState("boolean", isNull, "")
+ val code =
+ s"""
+ |$isNull = $resultRow.isNullAt($i);
+ |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
+ """.stripMargin
+ ExprCode(code, isNull, value)
+ } else {
+ ExprCode(s"$value = $valueCode;", "false", value)
+ }
+ }
+ }
+
+ /**
+ * Splits variables based on whether they are used by the condition or not, and returns the
+ * code to create these variables before and after the condition.
+ *
+ * If only a few columns are used by the condition, we can skip accessing the remaining
+ * columns for rows that the condition filters out.
+ */
+ private def splitVarsByCondition(
+ attributes: Seq[Attribute],
+ variables: Seq[ExprCode]): (String, String) = {
+ if (condition.isDefined) {
+ val condRefs = condition.get.references
+ val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
+ condRefs.contains(a)
+ }
+ val beforeCond = evaluateVariables(used.map(_._2))
+ val afterCond = evaluateVariables(notUsed.map(_._2))
+ (beforeCond, afterCond)
+ } else {
+ (evaluateVariables(variables), "")
+ }
+ }
+
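+ // Whole-stage codegen: the generated loop reads each streamed row, probes the hashed build
+ // side, scores every joined row that passes the condition, feeds it into the priority queue,
+ // and then emits the top-K rows for that streamed row.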
+ override def doProduce(ctx: CodegenContext): String = {
+ ctx.copyResult = true
+
+ val topKJoin = ctx.addReferenceObj("topKJoin", this)
+
+ // Prepare a priority queue for top-K computing
+ val pQueue = ctx.freshName("queue")
+ ctx.addMutableState(classOf[PriorityQueueShim].getName, pQueue,
+ s"$pQueue = $topKJoin.priorityQueue();")
+
+ // Prepare variables for the left side
+ val leftIter = ctx.freshName("leftIter")
+ ctx.addMutableState("scala.collection.Iterator", leftIter, s"$leftIter = inputs[0];")
+ val leftRow = ctx.freshName("leftRow")
+ ctx.addMutableState("InternalRow", leftRow, "")
+ val leftVars = createLeftVars(ctx, leftRow)
+
+ // Prepare variables for the right side
+ val rightRow = ctx.freshName("rightRow")
+ val rightVars = createRightVar(ctx, rightRow)
+
+ // Build a hashed relation from the right side
+ val buildRelation = prepareHashedRelation(ctx)
+
+ // Project join keys from the left side
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, leftRow)
+
+ // Prepare variables for joined rows
+ val joinedRow = ctx.freshName("joinedRow")
+ val joinedRowCls = classOf[JoinedRow].getName
+ ctx.addMutableState(joinedRowCls, joinedRow, s"$joinedRow = new $joinedRowCls();")
+
+ // Project score values from joined rows
+ val scoreVar = createScoreVar(ctx, joinedRow)
+
+ // Prepare variables for output rows
+ val resultRow = ctx.freshName("resultRow")
+ val resultVars = createResultVars(ctx, resultRow)
+
+ val (beforeLoop, condCheck) = if (condition.isDefined) {
+ // Split the code for creating variables based on whether they are used by the condition or not.
+ val loaded = ctx.freshName("loaded")
+ val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+ val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+ // Generate code for condition
+ ctx.currentVars = leftVars ++ rightVars
+ val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
+ // evaluate the columns used by the condition before the loop
+ val before = s"""
+ |boolean $loaded = false;
+ |$leftBefore
+ """.stripMargin
+
+ val checking = s"""
+ |$rightBefore
+ |${cond.code}
+ |if (${cond.isNull} || !${cond.value}) continue;
+ |if (!$loaded) {
+ | $loaded = true;
+ | $leftAfter
+ |}
+ |$rightAfter
+ """.stripMargin
+ (before, checking)
+ } else {
+ (evaluateVariables(leftVars), "")
+ }
+
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val matches = ctx.freshName("matches")
+ val topKRows = ctx.freshName("topKRows")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+ s"""
+ |$leftRow = null;
+ |while ($leftIter.hasNext()) {
+ | $leftRow = (InternalRow) $leftIter.next();
+ |
+ | // Generate join key for stream side
+ | ${keyEv.code}
+ |
+ | // Find matches from HashedRelation
+ | $iteratorCls $matches = $anyNull ? null : ($iteratorCls)$buildRelation.get(${keyEv.value});
+ | if ($matches == null) continue;
+ |
+ | // Join top-K right rows
+ | while ($matches.hasNext()) {
+ | ${beforeLoop.trim}
+ | InternalRow $rightRow = (InternalRow) $matches.next();
+ | ${condCheck.trim}
+ | InternalRow row = $joinedRow.apply($leftRow, $rightRow);
+ | // Compute a score for the `row`
+ | ${scoreVar.code}
+ | $pQueue.insert(${scoreVar.value}, row);
+ | }
+ |
+ | // Get top-K rows
+ | $iteratorCls $topKRows = $pQueue.get();
+ | $pQueue.clear();
+ |
+ | // Output top-K rows
+ | while ($topKRows.hasNext()) {
+ | InternalRow $resultRow = (InternalRow) $topKRows.next();
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |
+ | if (shouldStop()) return;
+ |}
+ """.stripMargin
+ }
+}