You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/03/23 02:10:52 UTC
[spark] branch branch-3.3 updated: [SPARK-32268][SQL] Row-level Runtime Filtering
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new 2b59a96 [SPARK-32268][SQL] Row-level Runtime Filtering
2b59a96 is described below
commit 2b59a9616c5922d515db430752044869abd93746
Author: Abhishek Somani <ab...@databricks.com>
AuthorDate: Wed Mar 23 09:57:28 2022 +0800
[SPARK-32268][SQL] Row-level Runtime Filtering
### What changes were proposed in this pull request?
This PR proposes row-level runtime filters in Spark to reduce intermediate data volume for operators like shuffle, join and aggregate, and hence improve performance. We propose two mechanisms to do this, semi-join filters and Bloom filters, which co-exist side by side behind separate feature configs.
[Design Doc](https://docs.google.com/document/d/16IEuyLeQlubQkH8YuVuXWKo2-grVIoDJqQpHZrE7q04/edit?usp=sharing) with more details.
### Why are the changes needed?
With semi-join filters, we see 9 queries improve on the TPC-DS 3TB benchmark, and no regressions.
With Bloom filters, we see 10 queries improve on the TPC-DS 3TB benchmark, and no regressions.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added tests
Closes #35789 from somani/rf.
Lead-authored-by: Abhishek Somani <ab...@databricks.com>
Co-authored-by: Abhishek Somani <ab...@gmail.com>
Co-authored-by: Yuming Wang <yu...@ebay.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit 1f4e4c812a9dc6d7e35631c1663c1ba6f6d9b721)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../org/apache/spark/util/sketch/BloomFilter.java | 7 +
.../apache/spark/util/sketch/BloomFilterImpl.java | 5 +
.../expressions/BloomFilterMightContain.scala | 113 +++++
.../aggregate/BloomFilterAggregate.scala | 179 ++++++++
.../sql/catalyst/expressions/objects/objects.scala | 2 +
.../sql/catalyst/expressions/predicates.scala | 16 +
.../catalyst/expressions/regexpExpressions.scala | 5 +-
.../catalyst/optimizer/InjectRuntimeFilter.scala | 303 +++++++++++++
.../spark/sql/catalyst/trees/TreePatterns.scala | 3 +
.../org/apache/spark/sql/internal/SQLConf.scala | 80 ++++
.../spark/sql/execution/SparkOptimizer.scala | 2 +
.../dynamicpruning/PartitionPruning.scala | 15 -
.../spark/sql/BloomFilterAggregateQuerySuite.scala | 215 +++++++++
.../spark/sql/InjectRuntimeFilterSuite.scala | 503 +++++++++++++++++++++
14 files changed, 1432 insertions(+), 16 deletions(-)
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index c53987e..2a6e270 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -164,6 +164,13 @@ public abstract class BloomFilter {
public abstract void writeTo(OutputStream out) throws IOException;
/**
+ * @return the number of set bits in this {@link BloomFilter}.
+ */
+ public long cardinality() {
+ throw new UnsupportedOperationException("Not implemented");
+ }
+
+ /**
* Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close
* the stream.
*/
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index e7766ee..ccf1833 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -207,6 +207,11 @@ class BloomFilterImpl extends BloomFilter implements Serializable {
return this;
}
+ @Override
+ public long cardinality() {
+ return this.bits.cardinality();
+ }
+
private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other)
throws IncompatibleMergeException {
// Duplicates the logic of `isCompatible` here to provide better error message.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
new file mode 100644
index 0000000..cf052f8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.io.ByteArrayInputStream
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
+import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
+import org.apache.spark.sql.types._
+import org.apache.spark.util.sketch.BloomFilter
+
+/**
+ * An internal scalar function that returns the membership check result (either true or false)
+ * for values of `valueExpression` in the Bloom filter represented by `bloomFilterExpression`.
+ * Note that since the function is "might contain", always returning true regardless is not
+ * wrong.
+ * Note that this expression requires that `bloomFilterExpression` is either a constant value or
+ * an uncorrelated scalar subquery. This is sufficient for the Bloom filter join rewrite.
+ *
+ * @param bloomFilterExpression the Binary data of Bloom filter.
+ * @param valueExpression the Long value to be tested for the membership of `bloomFilterExpression`.
+ */
+case class BloomFilterMightContain(
+ bloomFilterExpression: Expression,
+ valueExpression: Expression) extends BinaryExpression {
+
+ override def nullable: Boolean = true
+ override def left: Expression = bloomFilterExpression
+ override def right: Expression = valueExpression
+ override def prettyName: String = "might_contain"
+ override def dataType: DataType = BooleanType
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) |
+ (BinaryType, LongType) =>
+ bloomFilterExpression match {
+ case e : Expression if e.foldable => TypeCheckResult.TypeCheckSuccess
+ case subquery : PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ =>
+ TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " +
+ "should be either a constant value or a scalar subquery expression")
+ }
+ case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+ s"been ${BinaryType.simpleString} followed by a value with ${LongType.simpleString}, " +
+ s"but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].")
+ }
+ }
+
+ override protected def withNewChildrenInternal(
+ newBloomFilterExpression: Expression,
+ newValueExpression: Expression): BloomFilterMightContain =
+ copy(bloomFilterExpression = newBloomFilterExpression,
+ valueExpression = newValueExpression)
+
+ // The bloom filter created from `bloomFilterExpression`.
+ @transient private lazy val bloomFilter = {
+ val bytes = bloomFilterExpression.eval().asInstanceOf[Array[Byte]]
+ if (bytes == null) null else deserialize(bytes)
+ }
+
+ override def eval(input: InternalRow): Any = {
+ if (bloomFilter == null) {
+ null
+ } else {
+ val value = valueExpression.eval(input)
+ if (value == null) null else bloomFilter.mightContainLong(value.asInstanceOf[Long])
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ if (bloomFilter == null) {
+ ev.copy(isNull = TrueLiteral, value = JavaCode.defaultLiteral(dataType))
+ } else {
+ val bf = ctx.addReferenceObj("bloomFilter", bloomFilter, classOf[BloomFilter].getName)
+ val valueEval = valueExpression.genCode(ctx)
+ ev.copy(code = code"""
+ ${valueEval.code}
+ boolean ${ev.isNull} = ${valueEval.isNull};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.value} = $bf.mightContainLong((Long)${valueEval.value});
+ }""")
+ }
+ }
+
+ final def deserialize(bytes: Array[Byte]): BloomFilter = {
+ val in = new ByteArrayInputStream(bytes)
+ val bloomFilter = BloomFilter.readFrom(in)
+ in.close()
+ bloomFilter
+ }
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
new file mode 100644
index 0000000..c734bca
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.ByteArrayInputStream
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.util.sketch.BloomFilter
+
+/**
+ * An internal aggregate function that creates a Bloom filter from input values.
+ *
+ * @param child Child expression of Long values for creating a Bloom filter.
+ * @param estimatedNumItemsExpression The number of estimated distinct items (optional).
+ * @param numBitsExpression The number of bits to use (optional).
+ */
+case class BloomFilterAggregate(
+ child: Expression,
+ estimatedNumItemsExpression: Expression,
+ numBitsExpression: Expression,
+ override val mutableAggBufferOffset: Int,
+ override val inputAggBufferOffset: Int)
+ extends TypedImperativeAggregate[BloomFilter] with TernaryLike[Expression] {
+
+ def this(child: Expression, estimatedNumItemsExpression: Expression,
+ numBitsExpression: Expression) = {
+ this(child, estimatedNumItemsExpression, numBitsExpression, 0, 0)
+ }
+
+ def this(child: Expression, estimatedNumItemsExpression: Expression) = {
+ this(child, estimatedNumItemsExpression,
+ // 1 byte per item.
+ Multiply(estimatedNumItemsExpression, Literal(8L)))
+ }
+
+ def this(child: Expression) = {
+ this(child, Literal(SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_EXPECTED_NUM_ITEMS)),
+ Literal(SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_NUM_BITS)))
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (first.dataType, second.dataType, third.dataType) match {
+ case (_, NullType, _) | (_, _, NullType) =>
+ TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as size arguments")
+ case (LongType, LongType, LongType) =>
+ if (!estimatedNumItemsExpression.foldable) {
+ TypeCheckFailure("The estimated number of items provided must be a constant literal")
+ } else if (estimatedNumItems <= 0L) {
+ TypeCheckFailure("The estimated number of items must be a positive value " +
+ s" (current value = $estimatedNumItems)")
+ } else if (!numBitsExpression.foldable) {
+ TypeCheckFailure("The number of bits provided must be a constant literal")
+ } else if (numBits <= 0L) {
+ TypeCheckFailure("The number of bits must be a positive value " +
+ s" (current value = $numBits)")
+ } else {
+ require(estimatedNumItems <=
+ SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))
+ require(numBits <= SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
+ TypeCheckSuccess
+ }
+ case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+ s"been a ${LongType.simpleString} value followed with two ${LongType.simpleString} size " +
+ s"arguments, but it's [${first.dataType.catalogString}, " +
+ s"${second.dataType.catalogString}, ${third.dataType.catalogString}]")
+ }
+ }
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = BinaryType
+
+ override def prettyName: String = "bloom_filter_agg"
+
+ // Mark as lazy so that `estimatedNumItems` is not evaluated during tree transformation.
+ private lazy val estimatedNumItems: Long =
+ Math.min(estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue,
+ SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))
+
+ // Mark as lazy so that `numBits` is not evaluated during tree transformation.
+ private lazy val numBits: Long =
+ Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue,
+ SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
+
+ override def first: Expression = child
+
+ override def second: Expression = estimatedNumItemsExpression
+
+ override def third: Expression = numBitsExpression
+
+ override protected def withNewChildrenInternal(
+ newChild: Expression,
+ newEstimatedNumItemsExpression: Expression,
+ newNumBitsExpression: Expression): BloomFilterAggregate = {
+ copy(child = newChild, estimatedNumItemsExpression = newEstimatedNumItemsExpression,
+ numBitsExpression = newNumBitsExpression)
+ }
+
+ override def createAggregationBuffer(): BloomFilter = {
+ BloomFilter.create(estimatedNumItems, numBits)
+ }
+
+ override def update(buffer: BloomFilter, inputRow: InternalRow): BloomFilter = {
+ val value = child.eval(inputRow)
+ // Ignore null values.
+ if (value == null) {
+ return buffer
+ }
+ buffer.putLong(value.asInstanceOf[Long])
+ buffer
+ }
+
+ override def merge(buffer: BloomFilter, other: BloomFilter): BloomFilter = {
+ buffer.mergeInPlace(other)
+ }
+
+ override def eval(buffer: BloomFilter): Any = {
+ if (buffer.cardinality() == 0) {
+ // There's no set bit in the Bloom filter and hence no non-null value was processed.
+ return null
+ }
+ serialize(buffer)
+ }
+
+ override def withNewMutableAggBufferOffset(newOffset: Int): BloomFilterAggregate =
+ copy(mutableAggBufferOffset = newOffset)
+
+ override def withNewInputAggBufferOffset(newOffset: Int): BloomFilterAggregate =
+ copy(inputAggBufferOffset = newOffset)
+
+ override def serialize(obj: BloomFilter): Array[Byte] = {
+ BloomFilterAggregate.serialize(obj)
+ }
+
+ override def deserialize(bytes: Array[Byte]): BloomFilter = {
+ BloomFilterAggregate.deserialize(bytes)
+ }
+}
+
+object BloomFilterAggregate {
+ final def serialize(obj: BloomFilter): Array[Byte] = {
+ // BloomFilterImpl.writeTo() writes 2 integers (version number and num hash functions), hence
+ // the +8
+ val size = (obj.bitSize() / 8) + 8
+ require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size")
+ val out = new ByteArrayOutputStream(size.intValue())
+ obj.writeTo(out)
+ out.close()
+ out.toByteArray
+ }
+
+ final def deserialize(bytes: Array[Byte]): BloomFilter = {
+ val in = new ByteArrayInputStream(bytes)
+ val bloomFilter = BloomFilter.readFrom(in)
+ in.close()
+ bloomFilter
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 6974ada..2c879be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -360,6 +360,8 @@ case class Invoke(
lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
+ final override val nodePatterns: Seq[TreePattern] = Seq(INVOKE)
+
override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
override def children: Seq[Expression] = targetObject +: arguments
override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index a2fd668..d16e09c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -287,6 +287,22 @@ trait PredicateHelper extends AliasHelper with Logging {
}
}
}
+
+ /**
+ * Returns whether an expression is likely to be selective
+ */
+ def isLikelySelective(e: Expression): Boolean = e match {
+ case Not(expr) => isLikelySelective(expr)
+ case And(l, r) => isLikelySelective(l) || isLikelySelective(r)
+ case Or(l, r) => isLikelySelective(l) && isLikelySelective(r)
+ case _: StringRegexExpression => true
+ case _: BinaryComparison => true
+ case _: In | _: InSet => true
+ case _: StringPredicate => true
+ case BinaryPredicate(_) => true
+ case _: MultiLikeBase => true
+ case _ => false
+ }
}
@ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 368cbfd..bfaaba5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure,
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -627,6 +627,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
@transient private var lastReplacementInUTF8: UTF8String = _
// result buffer write by Matcher
@transient private lazy val result: StringBuffer = new StringBuffer
+ final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE)
override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = {
if (!p.equals(lastRegex)) {
@@ -751,6 +752,8 @@ abstract class RegExpExtractBase
// last regex pattern, we cache it for performance concern
@transient private var pattern: Pattern = _
+ final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY)
+
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
override def first: Expression = subject
override def second: Expression = regexp
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
new file mode 100644
index 0000000..35d0189
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete}
+import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+/**
+ * Insert a filter on one side of the join if the other side has a selective predicate.
+ * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something
+ * else in the future.
+ */
+object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper {
+
+ // Wraps `expr` with a hash function if its byte size is larger than an integer.
+ private def mayWrapWithHash(expr: Expression): Expression = {
+ if (expr.dataType.defaultSize > IntegerType.defaultSize) {
+ new Murmur3Hash(Seq(expr))
+ } else {
+ expr
+ }
+ }
+
+ private def injectFilter(
+ filterApplicationSideExp: Expression,
+ filterApplicationSidePlan: LogicalPlan,
+ filterCreationSideExp: Expression,
+ filterCreationSidePlan: LogicalPlan): LogicalPlan = {
+ require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled)
+ if (conf.runtimeFilterBloomFilterEnabled) {
+ injectBloomFilter(
+ filterApplicationSideExp,
+ filterApplicationSidePlan,
+ filterCreationSideExp,
+ filterCreationSidePlan
+ )
+ } else {
+ injectInSubqueryFilter(
+ filterApplicationSideExp,
+ filterApplicationSidePlan,
+ filterCreationSideExp,
+ filterCreationSidePlan
+ )
+ }
+ }
+
+ private def injectBloomFilter(
+ filterApplicationSideExp: Expression,
+ filterApplicationSidePlan: LogicalPlan,
+ filterCreationSideExp: Expression,
+ filterCreationSidePlan: LogicalPlan): LogicalPlan = {
+ // Skip if the filter creation side is too big
+ if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterCreationSideThreshold) {
+ return filterApplicationSidePlan
+ }
+ val rowCount = filterCreationSidePlan.stats.rowCount
+ val bloomFilterAgg =
+ if (rowCount.isDefined && rowCount.get.longValue > 0L) {
+ new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)),
+ Literal(rowCount.get.longValue))
+ } else {
+ new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
+ }
+ val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
+ val alias = Alias(aggExp, "bloomFilter")()
+ val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan))
+ val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
+ val filter = BloomFilterMightContain(bloomFilterSubquery,
+ new XxHash64(Seq(filterApplicationSideExp)))
+ Filter(filter, filterApplicationSidePlan)
+ }
+
+ private def injectInSubqueryFilter(
+ filterApplicationSideExp: Expression,
+ filterApplicationSidePlan: LogicalPlan,
+ filterCreationSideExp: Expression,
+ filterCreationSidePlan: LogicalPlan): LogicalPlan = {
+ require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType)
+ val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp)
+ val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)()
+ val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan)
+ if (!canBroadcastBySize(aggregate, conf)) {
+ // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold,
+ // i.e., the semi-join will be a shuffled join, which is not worthwhile.
+ return filterApplicationSidePlan
+ }
+ val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)),
+ ListQuery(aggregate, childOutputs = aggregate.output))
+ Filter(filter, filterApplicationSidePlan)
+ }
+
+ /**
+ * Returns whether the plan is a non-streaming, likely-selective filter over a scan (leaf node).
+ * Also checks that the filters avoid expensive expressions (UDFs, Invoke, JSON parsing, LIKE and
+ * regexp functions) so that we do not add a subquery that might have an expensive computation.
+ */
+ private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = {
+ val ret = plan match {
+ case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] =>
+ filters.forall(isSimpleExpression) &&
+ filters.exists(isLikelySelective)
+ case _ => false
+ }
+ !plan.isStreaming && ret
+ }
+
+ private def isSimpleExpression(e: Expression): Boolean = {
+ !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY,
+ REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE)
+ }
+
+ private def canFilterLeft(joinType: JoinType): Boolean = joinType match {
+ case Inner | RightOuter => true
+ case _ => false
+ }
+
+ private def canFilterRight(joinType: JoinType): Boolean = joinType match {
+ case Inner | LeftOuter => true
+ case _ => false
+ }
+
+ private def isProbablyShuffleJoin(left: LogicalPlan,
+ right: LogicalPlan, hint: JoinHint): Boolean = {
+ !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) &&
+ !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf)
+ }
+
+ private def probablyHasShuffle(plan: LogicalPlan): Boolean = {
+ plan.collectFirst {
+ case j@Join(left, right, _, _, hint)
+ if isProbablyShuffleJoin(left, right, hint) => j
+ case a: Aggregate => a
+ }.nonEmpty
+ }
+
+ // Returns the max scan byte size in the subtree rooted at `filterApplicationSide`.
+ private def maxScanByteSize(filterApplicationSide: LogicalPlan): BigInt = {
+ val defaultSizeInBytes = conf.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES)
+ filterApplicationSide.collect({
+ case leaf: LeafNode => leaf
+ }).map(scan => {
+ // DEFAULT_SIZE_IN_BYTES means there's no byte size information in stats. Since we avoid
+ // creating a Bloom filter when the filter application side is very small, using 0 as the
+ // byte size when the actual size is unknown avoids the regression of applying a BF to a
+ // small table.
+ if (scan.stats.sizeInBytes == defaultSizeInBytes) BigInt(0) else scan.stats.sizeInBytes
+ }).max
+ }
+
+ // Returns true if `filterApplicationSide` satisfies the byte size requirement to apply a
+ // Bloom filter; false otherwise.
+ private def satisfyByteSizeRequirement(filterApplicationSide: LogicalPlan): Boolean = {
+ // In case `filterApplicationSide` is a union of many small tables, disseminating the Bloom
+ // filter to each small task might be more costly than scanning those tables directly. Thus,
+ // we use max rather than sum here.
+ val maxScanSize = maxScanByteSize(filterApplicationSide)
+ maxScanSize >=
+ conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD)
+ }
+
+ /**
+ * Check that:
+ * - The filterApplicationSideExp can be pushed down through joins and aggregates (i.e., the
+ * expression references originate from a single leaf node)
+ * - The filter creation side has a selective predicate
+ * - The current join is a shuffle join or a broadcast join that has a shuffle below it
+ * - The max filterApplicationSide scan size is greater than a configurable threshold
+ */
+ private def filteringHasBenefit(
+ filterApplicationSide: LogicalPlan,
+ filterCreationSide: LogicalPlan,
+ filterApplicationSideExp: Expression,
+ hint: JoinHint): Boolean = {
+ findExpressionAndTrackLineageDown(filterApplicationSideExp,
+ filterApplicationSide).isDefined && isSelectiveFilterOverScan(filterCreationSide) &&
+ (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) ||
+ probablyHasShuffle(filterApplicationSide)) &&
+ satisfyByteSizeRequirement(filterApplicationSide)
+ }
+
+ def hasRuntimeFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression,
+ rightKey: Expression): Boolean = {
+ if (conf.runtimeFilterBloomFilterEnabled) {
+ hasBloomFilter(left, right, leftKey, rightKey)
+ } else {
+ hasInSubquery(left, right, leftKey, rightKey)
+ }
+ }
+
+ // This checks if there is already a DPP filter, as this rule is called just after DPP.
+ def hasDynamicPruningSubquery(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ leftKey: Expression,
+ rightKey: Expression): Boolean = {
+ (left, right) match {
+ case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) =>
+ pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey)
+ case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) =>
+ pruningKey.fastEquals(rightKey) ||
+ hasDynamicPruningSubquery(left, plan, leftKey, rightKey)
+ case _ => false
+ }
+ }
+
+ def hasBloomFilter(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ leftKey: Expression,
+ rightKey: Expression): Boolean = {
+ findBloomFilterWithExp(left, leftKey) || findBloomFilterWithExp(right, rightKey)
+ }
+
+ private def findBloomFilterWithExp(plan: LogicalPlan, key: Expression): Boolean = {
+ plan.find {
+ case Filter(condition, _) =>
+ splitConjunctivePredicates(condition).exists {
+ case BloomFilterMightContain(_, XxHash64(Seq(valueExpression), _))
+ if valueExpression.fastEquals(key) => true
+ case _ => false
+ }
+ case _ => false
+ }.isDefined
+ }
+
+ def hasInSubquery(left: LogicalPlan, right: LogicalPlan, leftKey: Expression,
+ rightKey: Expression): Boolean = {
+ (left, right) match {
+ case (Filter(InSubquery(Seq(key),
+ ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) =>
+ key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey)))
+ case (_, Filter(InSubquery(Seq(key),
+ ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) =>
+ key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey)))
+ case _ => false
+ }
+ }
+
+ private def tryInjectRuntimeFilter(plan: LogicalPlan): LogicalPlan = {
+ var filterCounter = 0
+ val numFilterThreshold = conf.getConf(SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD)
+ plan transformUp {
+ case join @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, left, right, hint) =>
+ var newLeft = left
+ var newRight = right
+ (leftKeys, rightKeys).zipped.foreach((l, r) => {
+ // Check if:
+ // 1. There is already a DPP filter on the key
+ // 2. There is already a runtime filter (Bloom filter or IN subquery) on the key
+ // 3. The keys are simple cheap expressions
+ if (filterCounter < numFilterThreshold &&
+ !hasDynamicPruningSubquery(left, right, l, r) &&
+ !hasRuntimeFilter(newLeft, newRight, l, r) &&
+ isSimpleExpression(l) && isSimpleExpression(r)) {
+ val oldLeft = newLeft
+ val oldRight = newRight
+ if (canFilterLeft(joinType) && filteringHasBenefit(left, right, l, hint)) {
+ newLeft = injectFilter(l, newLeft, r, right)
+ }
+ // Did we actually inject on the left? If not, try on the right
+ if (newLeft.fastEquals(oldLeft) && canFilterRight(joinType) &&
+ filteringHasBenefit(right, left, r, hint)) {
+ newRight = injectFilter(r, newRight, l, left)
+ }
+ if (!newLeft.fastEquals(oldLeft) || !newRight.fastEquals(oldRight)) {
+ filterCounter = filterCounter + 1
+ }
+ }
+ })
+ join.withNewChildren(Seq(newLeft, newRight))
+ }
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ case s: Subquery if s.correlated => plan
+ case _ if !conf.runtimeFilterSemiJoinReductionEnabled &&
+ !conf.runtimeFilterBloomFilterEnabled => plan
+ case _ => tryInjectRuntimeFilter(plan)
+ }
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index b595966..3cf45d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -54,6 +54,7 @@ object TreePattern extends Enumeration {
val IN_SUBQUERY: Value = Value
val INSET: Value = Value
val INTERSECT: Value = Value
+ val INVOKE: Value = Value
val JSON_TO_STRUCT: Value = Value
val LAMBDA_FUNCTION: Value = Value
val LAMBDA_VARIABLE: Value = Value
@@ -72,6 +73,8 @@ object TreePattern extends Enumeration {
val PIVOT: Value = Value
val PLAN_EXPRESSION: Value = Value
val PYTHON_UDF: Value = Value
+ val REGEXP_EXTRACT_FAMILY: Value = Value
+ val REGEXP_REPLACE: Value = Value
val RUNTIME_REPLACEABLE: Value = Value
val SCALAR_SUBQUERY: Value = Value
val SCALA_UDF: Value = Value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 3314dd1..1bba8b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -341,6 +341,77 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ // Feature flag (off by default) for runtime semi-join reduction. InjectRuntimeFilter.apply
+ // only rewrites a plan when this flag or the Bloom filter flag is enabled.
+ val RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED =
+ buildConf("spark.sql.optimizer.runtimeFilter.semiJoinReduction.enabled")
+ .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " +
+ "to insert a semi join in the other side to reduce the amount of shuffle data.")
+ .version("3.3.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ // Upper bound (>= 0, default 10) on the number of non-DPP runtime filters injected into a
+ // single query. InjectRuntimeFilter stops adding filters once its counter reaches this value.
+ val RUNTIME_FILTER_NUMBER_THRESHOLD =
+ buildConf("spark.sql.optimizer.runtimeFilter.number.threshold")
+ .doc("The total number of injected runtime filters (non-DPP) for a single " +
+ "query. This is to prevent driver OOMs with too many Bloom filters.")
+ .version("3.3.0")
+ .intConf
+ .checkValue(threshold => threshold >= 0, "The threshold should be >= 0")
+ .createWithDefault(10)
+
+  /** Feature flag (off by default) for injecting Bloom filter based runtime filters. */
+  val RUNTIME_BLOOM_FILTER_ENABLED =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.enabled")
+      .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " +
+        "to insert a bloom filter in the other side to reduce the amount of shuffle data.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  /** Maximum estimated size, in bytes, of the plan a Bloom filter is built from. */
+  val RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold")
+      .doc("Size threshold of the bloom filter creation side plan. Estimated size needs to be " +
+        "under this value to try to inject bloom filter.")
+      .version("3.3.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("10MB")
+
+  /**
+   * Minimum aggregated scan size, in bytes, of the side a Bloom filter is applied to.
+   * The key is camel-cased ("SizeThreshold") consistently with the sibling
+   * `creationSideThreshold` key.
+   */
+  val RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold")
+      .doc("Byte size threshold of the Bloom filter application side plan's aggregated scan " +
+        "size. Aggregated scan byte size of the Bloom filter application side needs to be over " +
+        "this value to inject a bloom filter.")
+      .version("3.3.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("10GB")
+
+  /**
+   * Default expected number of items used to size a runtime Bloom filter. Must be positive:
+   * the Bloom filter aggregate rejects non-positive estimates, so the value is validated here
+   * rather than failing later.
+   */
+  val RUNTIME_BLOOM_FILTER_EXPECTED_NUM_ITEMS =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.expectedNumItems")
+      .doc("The default number of expected items for the runtime bloom filter")
+      .version("3.3.0")
+      .longConf
+      .checkValue(numItems => numItems > 0, "The expected number of items must be positive")
+      .createWithDefault(1000000L)
+
+  /** Upper bound on the expected number of items of a runtime Bloom filter; must be positive. */
+  val RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.maxNumItems")
+      .doc("The max allowed number of expected items for the runtime bloom filter")
+      .version("3.3.0")
+      .longConf
+      .checkValue(numItems => numItems > 0, "The max number of items must be positive")
+      .createWithDefault(4000000L)
+
+  /** Default number of bits of a runtime Bloom filter; must be positive. */
+  val RUNTIME_BLOOM_FILTER_NUM_BITS =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.numBits")
+      .doc("The default number of bits to use for the runtime bloom filter")
+      .version("3.3.0")
+      .longConf
+      .checkValue(numBits => numBits > 0, "The number of bits must be positive")
+      .createWithDefault(8388608L)
+
+  /** Upper bound on the number of bits of a runtime Bloom filter; must be positive. */
+  val RUNTIME_BLOOM_FILTER_MAX_NUM_BITS =
+    buildConf("spark.sql.optimizer.runtime.bloomFilter.maxNumBits")
+      .doc("The max number of bits to use for the runtime bloom filter")
+      .version("3.3.0")
+      .longConf
+      .checkValue(numBits => numBits > 0, "The max number of bits must be positive")
+      .createWithDefault(67108864L)
+
val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed")
.doc("When set to true Spark SQL will automatically select a compression codec for each " +
"column based on statistics of the data.")
@@ -3750,6 +3821,15 @@ class SQLConf extends Serializable with Logging {
def dynamicPartitionPruningReuseBroadcastOnly: Boolean =
getConf(DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY)
+  /** Whether runtime semi-join reduction filters may be injected. */
+  def runtimeFilterSemiJoinReductionEnabled: Boolean =
+    getConf(RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED)
+
+  /** Whether runtime Bloom filters may be injected. */
+  def runtimeFilterBloomFilterEnabled: Boolean = getConf(RUNTIME_BLOOM_FILTER_ENABLED)
+
+  /** Maximum estimated size, in bytes, of the plan a runtime Bloom filter is built from. */
+  def runtimeFilterCreationSideThreshold: Long =
+    getConf(RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD)
+
def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS)
def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 7e8fb4a..743cb59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -43,6 +43,8 @@ class SparkOptimizer(
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
Batch("PartitionPruning", Once,
PartitionPruning) :+
+ Batch("InjectRuntimeFilter", FixedPoint(1),
+ InjectRuntimeFilter) :+
Batch("Pushdown Filters from PartitionPruning", fixedPoint,
PushDownPredicates) :+
Batch("Cleanup filters that cannot be pushed down", Once,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
index 3b5fc4a..89d6603 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
@@ -194,21 +194,6 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper {
scanOverhead + cachedOverhead
}
- /**
- * Returns whether an expression is likely to be selective
- */
- private def isLikelySelective(e: Expression): Boolean = e match {
- case Not(expr) => isLikelySelective(expr)
- case And(l, r) => isLikelySelective(l) || isLikelySelective(r)
- case Or(l, r) => isLikelySelective(l) && isLikelySelective(r)
- case _: StringRegexExpression => true
- case _: BinaryComparison => true
- case _: In | _: InSet => true
- case _: StringPredicate => true
- case BinaryPredicate(_) => true
- case _: MultiLikeBase => true
- case _ => false
- }
/**
* Search a filtering predicate in a given logical plan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
new file mode 100644
index 0000000..025593b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Query tests for the Bloom filter aggregate (`bloom_filter_agg`) and the Bloom filter
+ * membership function (`might_contain`).
+ *
+ * The suite registers both functions in [[FunctionRegistry.builtin]] so they can be called from
+ * SQL, and drops them again in `afterAll`.
+ */
+class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+  val funcId_might_contain = new FunctionIdentifier("might_contain")
+
+  override def beforeAll(): Unit = {
+    // Registration happens here instead of in the constructor so nothing is written into the
+    // global builtin registry when the suite is instantiated but never run. It is done before
+    // `super.beforeAll()` creates the shared session, matching the old constructor-time order.
+    // NOTE(review): this assumes the session picks up builtin functions when it is created.
+    // Confirm before reordering.
+
+    // Register 'bloom_filter_agg' to builtin.
+    FunctionRegistry.builtin.registerFunction(funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) => children.size match {
+        case 1 => new BloomFilterAggregate(children.head)
+        case 2 => new BloomFilterAggregate(children.head, children(1))
+        case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+      })
+
+    // Register 'might_contain' to builtin.
+    FunctionRegistry.builtin.registerFunction(funcId_might_contain,
+      new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
+      (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+
+    super.beforeAll()
+  }
+
+  override def afterAll(): Unit = {
+    FunctionRegistry.builtin.dropFunction(funcId_bloom_filter_agg)
+    FunctionRegistry.builtin.dropFunction(funcId_might_contain)
+    super.afterAll()
+  }
+
+  test("Test bloom_filter_agg and might_contain") {
+    val conf = SQLConf.get
+    val table = "bloom_filter_test"
+    for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
+      conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) {
+      for (numBits <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
+        conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))) {
+        val sqlString = s"""
+          |SELECT every(might_contain(
+          | (SELECT bloom_filter_agg(col,
+          | cast($numEstimatedItems as long),
+          | cast($numBits as long))
+          | FROM $table),
+          | col)) positive_membership_test,
+          | every(might_contain(
+          | (SELECT bloom_filter_agg(col,
+          | cast($numEstimatedItems as long),
+          | cast($numBits as long))
+          | FROM values (-1L), (100001L), (20000L) as t(col)),
+          | col)) negative_membership_test
+          |FROM $table
+          """.stripMargin
+        withTempView(table) {
+          (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 10000L))
+            .toDF("col").createOrReplaceTempView(table)
+          // Validate error messages as well as answers when there's no error.
+          if (numEstimatedItems <= 0) {
+            val exception = intercept[AnalysisException] {
+              spark.sql(sqlString)
+            }
+            assert(exception.getMessage.contains(
+              "The estimated number of items must be a positive value"))
+          } else if (numBits <= 0) {
+            val exception = intercept[AnalysisException] {
+              spark.sql(sqlString)
+            }
+            assert(exception.getMessage.contains("The number of bits must be a positive value"))
+          } else {
+            checkAnswer(spark.sql(sqlString), Row(true, false))
+          }
+        }
+      }
+    }
+  }
+
+  test("Test that bloom_filter_agg errors out disallowed input value types") {
+    val exception1 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT bloom_filter_agg(a)
+        |FROM values (1.2), (2.5) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception1.getMessage.contains(
+      "Input to function bloom_filter_agg should have been a bigint value"))
+
+    val exception2 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT bloom_filter_agg(a, 2)
+        |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception2.getMessage.contains(
+      "function bloom_filter_agg should have been a bigint value followed with two bigint"))
+
+    val exception3 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT bloom_filter_agg(a, cast(2 as long), 5)
+        |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception3.getMessage.contains(
+      "function bloom_filter_agg should have been a bigint value followed with two bigint"))
+
+    val exception4 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT bloom_filter_agg(a, null, 5)
+        |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception4.getMessage.contains("Null typed values cannot be used as size arguments"))
+
+    val exception5 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT bloom_filter_agg(a, 5, null)
+        |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception5.getMessage.contains("Null typed values cannot be used as size arguments"))
+  }
+
+  test("Test that might_contain errors out disallowed input value types") {
+    val exception1 = intercept[AnalysisException] {
+      spark.sql("""|SELECT might_contain(1.0, 1L)"""
+        .stripMargin)
+    }
+    assert(exception1.getMessage.contains(
+      "Input to function might_contain should have been binary followed by a value with bigint"))
+
+    val exception2 = intercept[AnalysisException] {
+      spark.sql("""|SELECT might_contain(NULL, 0.1)"""
+        .stripMargin)
+    }
+    assert(exception2.getMessage.contains(
+      "Input to function might_contain should have been binary followed by a value with bigint"))
+  }
+
+  test("Test that might_contain errors out non-constant Bloom filter") {
+    val exception1 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT might_contain(cast(a as binary), cast(5 as long))
+        |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception1.getMessage.contains(
+      "The Bloom filter binary input to might_contain should be either a constant value or " +
+        "a scalar subquery expression"))
+
+    val exception2 = intercept[AnalysisException] {
+      spark.sql("""
+        |SELECT might_contain((select cast(a as binary)), cast(5 as long))
+        |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+        .stripMargin)
+    }
+    assert(exception2.getMessage.contains(
+      "The Bloom filter binary input to might_contain should be either a constant value or " +
+        "a scalar subquery expression"))
+  }
+
+  test("Test that might_contain can take a constant value input") {
+    checkAnswer(spark.sql(
+      """SELECT might_contain(
+        |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267',
+        |cast(201 as long))""".stripMargin),
+      Row(false))
+  }
+
+  test("Test that bloom_filter_agg produces a NULL with empty input") {
+    checkAnswer(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)"""),
+      Row(null))
+  }
+
+  test("Test NULL inputs for might_contain") {
+    checkAnswer(spark.sql(
+      """
+        |SELECT might_contain(null, null) both_null,
+        | might_contain(null, 1L) null_bf,
+        | might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)),
+        | null) null_value
+        """.stripMargin),
+      Row(null, null, null))
+  }
+
+  test("Test that a query with bloom_filter_agg has partial aggregates") {
+    val executedPlan = spark.sql(
+      """SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""")
+      .queryExecution.executedPlan
+    // Unwrap the adaptive wrapper when AQE is on; otherwise inspect the plan directly
+    // instead of failing with a ClassCastException.
+    val physicalPlan = executedPlan match {
+      case adaptive: AdaptiveSparkPlanExec => adaptive.inputPlan
+      case other => other
+    }
+    assert(physicalPlan.collect({ case agg: BaseAggregateExec => agg }).size == 2)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
new file mode 100644
index 0000000..a5e27fb
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -0,0 +1,503 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSparkSession {
+
+  /**
+   * Creates and analyzes the test tables:
+   * - bf1 from its own data;
+   * - bf2, bf3, bf4 and bf5part from one shared data set;
+   * - bf5filtered from the same shared data, restricted to a5 > 30.
+   *
+   * Each table has six nullable int columns suffixed with the table number (a1..f1, a2..f2,
+   * ...). bf5part and bf5filtered are partitioned by f5.
+   */
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    val data1: Seq[Seq[Any]] = Seq(Seq(null, 47, null, 4, 6, 48),
+      Seq(73, 63, null, 92, null, null),
+      Seq(76, 10, 74, 98, 37, 5),
+      Seq(0, 63, null, null, null, null),
+      Seq(15, 77, null, null, null, null),
+      Seq(null, 57, 33, 55, null, 58),
+      Seq(4, 0, 86, null, 96, 14),
+      Seq(28, 16, 58, null, null, null),
+      Seq(1, 88, null, 8, null, 79),
+      Seq(59, null, null, null, 20, 25),
+      Seq(1, 50, null, 94, 94, null),
+      Seq(null, null, null, 67, 51, 57),
+      Seq(77, 50, 8, 90, 16, 21),
+      Seq(34, 28, null, 5, null, 64),
+      Seq(null, null, 88, 11, 63, 79),
+      Seq(92, 94, 23, 1, null, 64),
+      Seq(57, 56, null, 83, null, null),
+      Seq(null, 35, 8, 35, null, 70),
+      Seq(null, 8, null, 35, null, 87),
+      Seq(9, null, null, 60, null, 5),
+      Seq(null, 15, 66, null, 83, null))
+
+    // Previously duplicated verbatim as data2, data3, data4 and data5part.
+    val sharedData: Seq[Seq[Any]] = Seq(Seq(67, 17, 45, 91, null, null),
+      Seq(98, 63, 0, 89, null, 40),
+      Seq(null, 76, 68, 75, 20, 19),
+      Seq(8, null, null, null, 78, null),
+      Seq(48, 62, null, null, 11, 98),
+      Seq(84, null, 99, 65, 66, 51),
+      Seq(98, null, null, null, 42, 51),
+      Seq(10, 3, 29, null, 68, 8),
+      Seq(85, 36, 41, null, 28, 71),
+      Seq(89, null, 94, 95, 67, 21),
+      Seq(44, null, 24, 33, null, 6),
+      Seq(null, 6, 78, 31, null, 69),
+      Seq(59, 2, 63, 9, 66, 20),
+      Seq(5, 23, 10, 86, 68, null),
+      Seq(null, 63, 99, 55, 9, 65),
+      Seq(57, 62, 68, 5, null, 0),
+      Seq(75, null, 15, null, 81, null),
+      Seq(53, null, 6, 68, 28, 13),
+      Seq(null, null, null, null, 89, 23),
+      Seq(36, 73, 40, null, 8, null),
+      Seq(24, null, null, 40, null, null))
+
+    saveBfTable("bf1", "1", data1)
+    saveBfTable("bf2", "2", sharedData)
+    saveBfTable("bf3", "3", sharedData)
+    saveBfTable("bf4", "4", sharedData)
+
+    val rows5 = bfRowRdd(sharedData)
+    spark.createDataFrame(rows5, bfSchema("5")).write.partitionBy("f5")
+      .saveAsTable("bf5part")
+    spark.createDataFrame(rows5, bfSchema("5")).filter("a5 > 30")
+      .write.partitionBy("f5")
+      .saveAsTable("bf5filtered")
+
+    Seq("bf1" -> "1", "bf2" -> "2", "bf3" -> "3", "bf4" -> "4",
+      "bf5part" -> "5", "bf5filtered" -> "5").foreach { case (table, suffix) =>
+      sql(s"analyze table $table compute statistics for columns " +
+        bfColumns(suffix).mkString(", "))
+    }
+  }
+
+  /** Column names a<suffix> .. f<suffix> shared by every test table. */
+  private def bfColumns(suffix: String): Seq[String] =
+    Seq("a", "b", "c", "d", "e", "f").map(_ + suffix)
+
+  /** Schema of six nullable int columns named by [[bfColumns]]. */
+  private def bfSchema(suffix: String): StructType =
+    bfColumns(suffix).foldLeft(new StructType()) { (schema, name) =>
+      schema.add(name, IntegerType, nullable = true)
+    }
+
+  /** Distributes the given rows as an RDD of [[Row]]s. */
+  private def bfRowRdd(data: Seq[Seq[Any]]) =
+    spark.sparkContext.parallelize(data).map(s => Row.fromSeq(s))
+
+  /** Writes `data` as the managed table `name` using [[bfSchema]] for `suffix`. */
+  private def saveBfTable(name: String, suffix: String, data: Seq[Seq[Any]]): Unit =
+    spark.createDataFrame(bfRowRdd(data), bfSchema(suffix)).write.saveAsTable(name)
+
+  /** Drops every table created in `beforeAll`, always delegating to the parent afterwards. */
+  protected override def afterAll(): Unit = {
+    try {
+      Seq("bf1", "bf2", "bf3", "bf4", "bf5part", "bf5filtered")
+        .foreach(table => sql(s"DROP TABLE IF EXISTS $table"))
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  /**
+   * Runs `query` with both runtime filter features disabled (the baseline) and then with exactly
+   * one of them enabled, checking that the answers match.
+   *
+   * @param query         SQL to run
+   * @param testSemiJoin  enable semi-join reduction when true, otherwise the Bloom filter
+   * @param shouldReplace whether the enabled feature is expected to change the optimized plan
+   */
+  def checkWithAndWithoutFeatureEnabled(query: String, testSemiJoin: Boolean,
+      shouldReplace: Boolean): Unit = {
+    // Assigned inside the withSQLConf blocks below and read after them.
+    var baselinePlan: LogicalPlan = null
+    var baselineRows: Array[Row] = null
+    var rewrittenPlan: LogicalPlan = null
+
+    withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+      SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+      baselinePlan = sql(query).queryExecution.optimizedPlan
+      baselineRows = sql(query).collect()
+    }
+
+    // Exactly one of the two features is switched on.
+    val semiJoinFlag = testSemiJoin.toString
+    val bloomFilterFlag = (!testSemiJoin).toString
+    withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> semiJoinFlag,
+      SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> bloomFilterFlag) {
+      rewrittenPlan = sql(query).queryExecution.optimizedPlan
+      checkAnswer(sql(query), baselineRows)
+      if (!testSemiJoin) {
+        val rewrittenCount = getNumBloomFilters(rewrittenPlan)
+        val baselineCount = getNumBloomFilters(baselinePlan)
+        if (shouldReplace) {
+          assert(rewrittenCount > baselineCount)
+        } else {
+          assert(rewrittenCount == baselineCount)
+        }
+      }
+    }
+
+    if (testSemiJoin) {
+      if (!shouldReplace) {
+        comparePlans(baselinePlan, rewrittenPlan)
+      } else {
+        val rewrittenNormalized = normalizePlan(normalizeExprIds(rewrittenPlan))
+        val baselineNormalized = normalizePlan(normalizeExprIds(baselinePlan))
+        assert(rewrittenNormalized != baselineNormalized)
+      }
+    }
+  }
+
+  /**
+   * Counts the runtime Bloom filters in `plan`.
+   *
+   * Every `BloomFilterAggregate` found in a scalar subquery of a Filter condition must have
+   * literal size arguments. The number of such aggregates must equal the number of
+   * `BloomFilterMightContain` calls in Filter conditions. Aggregate expressions that are not
+   * Bloom filter aggregates are ignored rather than raising a MatchError.
+   *
+   * @return the number of `BloomFilterMightContain` expressions in Filter conditions
+   */
+  def getNumBloomFilters(plan: LogicalPlan): Int = {
+    val numBloomFilterAggs = plan.collect {
+      case Filter(condition, _) => condition.collect {
+        case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery =>
+          subquery.plan.collect {
+            case Aggregate(_, aggregateExpressions, _) =>
+              aggregateExpressions.collect {
+                case Alias(AggregateExpression(bfAgg: BloomFilterAggregate, _, _, _, _), _) =>
+                  assert(bfAgg.estimatedNumItemsExpression.isInstanceOf[Literal])
+                  assert(bfAgg.numBitsExpression.isInstanceOf[Literal])
+                  1
+              }.sum
+          }.sum
+      }.sum
+    }.sum
+    val numMightContains = plan.collect {
+      case Filter(condition, _) => condition.collect {
+        case BloomFilterMightContain(_, _) => 1
+      }.sum
+    }.sum
+    assert(numBloomFilterAggs == numMightContains)
+    numMightContains
+  }
+
+  /** Asserts that enabling semi-join reduction changes the optimized plan of `query`. */
+  def assertRewroteSemiJoin(query: String): Unit =
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = true)
+
+  /** Asserts that enabling semi-join reduction leaves the optimized plan of `query` as is. */
+  def assertDidNotRewriteSemiJoin(query: String): Unit =
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = false)
+
+  /** Asserts that enabling the Bloom filter adds at least one filter to `query`'s plan. */
+  def assertRewroteWithBloomFilter(query: String): Unit =
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = true)
+
+  /** Asserts that enabling the Bloom filter adds no filter to `query`'s plan. */
+  def assertDidNotRewriteWithBloomFilter(query: String): Unit =
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = false)
+
+  test("Runtime semi join reduction: simple") {
+    // Filter creation side is 3409 bytes
+    // Filter application side scan is 3362 bytes
+    val sizeConfs = Seq(
+      SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000")
+    withSQLConf(sizeConfs: _*) {
+      val joinQuery = "select * from bf1 join bf2 on bf1.c1 = bf2.c2"
+      assertRewroteSemiJoin(s"$joinQuery where bf2.a2 = 62")
+      assertDidNotRewriteSemiJoin(joinQuery)
+    }
+  }
+
+  test("Runtime semi join reduction: two joins") {
+    val sizeConfs = Seq(
+      SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000")
+    withSQLConf(sizeConfs: _*) {
+      val query = "select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " +
+        "and bf3.c3 = bf2.c2 where bf2.a2 = 5"
+      assertRewroteSemiJoin(query)
+    }
+  }
+
+  test("Runtime semi join reduction: three joins") {
+    val sizeConfs = Seq(
+      SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000")
+    withSQLConf(sizeConfs: _*) {
+      val query = "select * from bf1 join bf2 join bf3 join bf4 on " +
+        "bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 = 5"
+      assertRewroteSemiJoin(query)
+    }
+  }
+
+  test("Runtime semi join reduction: simple expressions only") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      // `square` is registered through spark.udf.register as a temporary function; drop it
+      // afterwards so it does not leak into other tests sharing this session.
+      withUserDefinedFunction("square" -> true) {
+        val squared = (s: Long) => {
+          s * s
+        }
+        spark.udf.register("square", squared)
+        // A UDF in either the filter or the join key is not a "simple expression", so no
+        // runtime filter may be injected.
+        assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " +
+          "bf1.c1 = bf2.c2 where square(bf2.a2) = 62")
+        assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " +
+          "bf1.c1 = square(bf2.c2) where bf2.a2= 62")
+      }
+    }
+  }
+
+  test("Runtime bloom filter join: simple") {
+    val sizeConfs = Seq(
+      SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000")
+    withSQLConf(sizeConfs: _*) {
+      val joinQuery = "select * from bf1 join bf2 on bf1.c1 = bf2.c2"
+      // Only the selective predicate on bf2 makes a Bloom filter worthwhile.
+      assertRewroteWithBloomFilter(s"$joinQuery where bf2.a2 = 62")
+      assertDidNotRewriteWithBloomFilter(joinQuery)
+    }
+  }
+
+  test("Runtime bloom filter join: two filters single join") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+        "bf1.b1 = bf2.b2 where bf2.a2 = 62"
+      // Assigned inside the withSQLConf blocks below and read after them.
+      var baselinePlan: LogicalPlan = null
+      var baselineRows: Array[Row] = null
+      var bloomPlan: LogicalPlan = null
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        baselinePlan = sql(query).queryExecution.optimizedPlan
+        baselineRows = sql(query).collect()
+      }
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+        bloomPlan = sql(query).queryExecution.optimizedPlan
+        checkAnswer(sql(query), baselineRows)
+      }
+      // One new filter per equi-join key pair (c1 = c2 and b1 = b2).
+      assert(getNumBloomFilters(bloomPlan) == getNumBloomFilters(baselinePlan) + 2)
+    }
+  }
+
+  test("Runtime bloom filter join: test the number of filter threshold") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+        "bf1.b1 = bf2.b2 where bf2.a2 = 62"
+      // Assigned inside the withSQLConf blocks below and read after them.
+      var baselinePlan: LogicalPlan = null
+      var baselineRows: Array[Row] = null
+      var bloomPlan: LogicalPlan = null
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        baselinePlan = sql(query).queryExecution.optimizedPlan
+        baselineRows = sql(query).collect()
+      }
+
+      (0 to 3).foreach { maxFilters =>
+        withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+          SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true",
+          SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD.key -> maxFilters.toString) {
+          bloomPlan = sql(query).queryExecution.optimizedPlan
+          checkAnswer(sql(query), baselineRows)
+        }
+        // Two equi-join key pairs, so at most two filters regardless of the threshold.
+        val expectedNewFilters = math.min(maxFilters, 2)
+        assert(getNumBloomFilters(bloomPlan) ==
+          getNumBloomFilters(baselinePlan) + expectedNewFilters)
+      }
+    }
+  }
+
+  test("Runtime bloom filter join: insert one bloom filter per column") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      // bf1.c1 appears in both join conditions; it should still get only one Bloom filter.
+      val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+        "bf1.c1 = bf2.b2 where bf2.a2 = 62"
+      // Assigned inside the withSQLConf blocks below and read after them.
+      var baselinePlan: LogicalPlan = null
+      var baselineRows: Array[Row] = null
+      var bloomPlan: LogicalPlan = null
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        baselinePlan = sql(query).queryExecution.optimizedPlan
+        baselineRows = sql(query).collect()
+      }
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+        bloomPlan = sql(query).queryExecution.optimizedPlan
+        checkAnswer(sql(query), baselineRows)
+      }
+      assert(getNumBloomFilters(bloomPlan) == getNumBloomFilters(baselinePlan) + 1)
+    }
+  }
+
+  test("Runtime bloom filter join: do not add bloom filter if dpp filter exists " +
+    "on the same column") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      // The only join key on bf5part is f5. Per the test name, a DPP filter is expected on
+      // that column, so no bloom filter should be added for it.
+      val sameColumnQuery =
+        "select * from bf5part join bf2 on bf5part.f5 = bf2.c2 where bf2.a2 = 62"
+      assertDidNotRewriteWithBloomFilter(sameColumnQuery)
+    }
+  }
+
+  test("Runtime bloom filter join: add bloom filter if dpp filter exists on " +
+    "a different column") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      // Joins on both c5 and f5. The f5 key is the DPP-covered column (see the test name),
+      // while c5 is a different column, so a bloom filter is still expected.
+      val twoKeyQuery = "select * from bf5part join bf2 on " +
+        "bf5part.c5 = bf2.c2 and bf5part.f5 = bf2.f2 where bf2.a2 = 62"
+      assertRewroteWithBloomFilter(twoKeyQuery)
+    }
+  }
+
+  test("Runtime bloom filter join: BF rewrite triggering threshold test") {
+    // Filter creation side data size is 3409 bytes. On the filter application side, an
+    // individual scan's byte size is 3362.
+    val simpleJoin = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 where bf2.a2 = 62"
+    // `bf5filtered` has 14168 bytes and `bf2` has 3409 bytes.
+    val unionJoin = "select * from " +
+      "(select * from bf5filtered union all select * from bf2) t " +
+      "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5"
+
+    // Application scan (3362) is above 3000 and creation side (3409) is within 4000: rewrite.
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000") {
+      assertRewroteWithBloomFilter(simpleJoin)
+    }
+
+    // Creation side (3409) is far above the 50-byte creation-side threshold: no rewrite.
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "50") {
+      assertDidNotRewriteWithBloomFilter(simpleJoin)
+    }
+
+    // Application scan (3362) is below the 5000-byte threshold: the rewrite is not worth it.
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000") {
+      assertDidNotRewriteWithBloomFilter(simpleJoin)
+    }
+
+    // The largest scan on the application side, not any individual scan, is compared
+    // against the threshold.
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "32",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000") {
+      // Largest scan (14168) exceeds 5000: rewrite.
+      withSQLConf(
+        SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000") {
+        assertRewroteWithBloomFilter(unionJoin)
+      }
+      // Largest scan (14168) is below 15000: no rewrite.
+      withSQLConf(
+        SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "15000") {
+        assertDidNotRewriteWithBloomFilter(unionJoin)
+      }
+    }
+  }
+
+  test("Runtime bloom filter join: simple expressions only") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      val squared = (s: Long) => {
+        s * s
+      }
+      // `spark.udf.register` adds a session-scoped temporary function. Without cleanup it
+      // would outlive this test and remain visible to later tests on the same session.
+      spark.udf.register("square", squared)
+      try {
+        // A UDF call in the creation-side predicate is expected to block the rewrite.
+        assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " +
+          "bf1.c1 = bf2.c2 where square(bf2.a2) = 62")
+        // A UDF call wrapping a join key is expected to block the rewrite as well.
+        assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " +
+          "bf1.c1 = square(bf2.c2) where bf2.a2 = 62")
+      } finally {
+        sql("DROP TEMPORARY FUNCTION IF EXISTS square")
+      }
+    }
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org